OTVision 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. OTVision/__init__.py +30 -0
  2. OTVision/application/__init__.py +0 -0
  3. OTVision/application/configure_logger.py +23 -0
  4. OTVision/application/detect/__init__.py +0 -0
  5. OTVision/application/detect/get_detect_cli_args.py +9 -0
  6. OTVision/application/detect/update_detect_config_with_cli_args.py +95 -0
  7. OTVision/application/get_config.py +25 -0
  8. OTVision/config.py +754 -0
  9. OTVision/convert/__init__.py +0 -0
  10. OTVision/convert/convert.py +318 -0
  11. OTVision/dataformat.py +70 -0
  12. OTVision/detect/__init__.py +0 -0
  13. OTVision/detect/builder.py +48 -0
  14. OTVision/detect/cli.py +166 -0
  15. OTVision/detect/detect.py +296 -0
  16. OTVision/detect/otdet.py +103 -0
  17. OTVision/detect/plugin_av/__init__.py +0 -0
  18. OTVision/detect/plugin_av/rotate_frame.py +37 -0
  19. OTVision/detect/yolo.py +277 -0
  20. OTVision/domain/__init__.py +0 -0
  21. OTVision/domain/cli.py +42 -0
  22. OTVision/helpers/__init__.py +0 -0
  23. OTVision/helpers/date.py +26 -0
  24. OTVision/helpers/files.py +538 -0
  25. OTVision/helpers/formats.py +139 -0
  26. OTVision/helpers/log.py +131 -0
  27. OTVision/helpers/machine.py +71 -0
  28. OTVision/helpers/video.py +54 -0
  29. OTVision/track/__init__.py +0 -0
  30. OTVision/track/iou.py +282 -0
  31. OTVision/track/iou_util.py +140 -0
  32. OTVision/track/preprocess.py +451 -0
  33. OTVision/track/track.py +422 -0
  34. OTVision/transform/__init__.py +0 -0
  35. OTVision/transform/get_homography.py +156 -0
  36. OTVision/transform/reference_points_picker.py +462 -0
  37. OTVision/transform/transform.py +352 -0
  38. OTVision/version.py +13 -0
  39. OTVision/view/__init__.py +0 -0
  40. OTVision/view/helpers/OTC.ico +0 -0
  41. OTVision/view/view.py +90 -0
  42. OTVision/view/view_convert.py +128 -0
  43. OTVision/view/view_detect.py +146 -0
  44. OTVision/view/view_helpers.py +417 -0
  45. OTVision/view/view_track.py +131 -0
  46. OTVision/view/view_transform.py +140 -0
  47. otvision-0.5.3.dist-info/METADATA +47 -0
  48. otvision-0.5.3.dist-info/RECORD +50 -0
  49. otvision-0.5.3.dist-info/WHEEL +4 -0
  50. otvision-0.5.3.dist-info/licenses/LICENSE +674 -0
@@ -0,0 +1,131 @@
1
+ """
2
+ OTVision helpers for logging
3
+ """
4
+
5
+ # Copyright (C) 2022 OpenTrafficCam Contributors
6
+ # <https://github.com/OpenTrafficCam>
7
+ # <team@opentrafficcam.org>
8
+ #
9
+ # This program is free software: you can redistribute it and/or modify
10
+ # it under the terms of the GNU General Public License as published by
11
+ # the Free Software Foundation, either version 3 of the License, or
12
+ # (at your option) any later version.
13
+ #
14
+ # This program is distributed in the hope that it will be useful,
15
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
16
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17
+ # GNU General Public License for more details.
18
+ #
19
+ # You should have received a copy of the GNU General Public License
20
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
21
+
22
+
23
+ import logging
24
+ import sys
25
+ from datetime import datetime
26
+ from pathlib import Path
27
+
28
# Name under which the shared OTVision logger is registered with `logging`.
LOGGER_NAME = "OTVision Logger"

# The default log file is named after the moment this module is imported.
DEFAULT_LOG_NAME = datetime.now().strftime(r"%Y-%m-%d_%H-%M-%S")
LOG_EXT = "log"
DEFAULT_LOG_FILE = Path(f"logs/{DEFAULT_LOG_NAME}.{LOG_EXT}")

# Level names accepted for handlers, in ascending severity.
VALID_LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]

# Layout of a single log record.
LOG_FORMAT: str = (
    "%(asctime)s %(levelname)s "
    "(%(filename)s::%(funcName)s::%(lineno)d): %(message)s"
)

# Numeric value of every valid level name (same values as the stdlib constants).
LOG_LEVEL_INTEGERS = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}
47
+
48
+
49
class LogFileAlreadyExists(Exception):
    """Raised when a log file already exists and must not be overwritten."""
51
+
52
+
53
class _OTVisionLogger:
    """Wrapper that configures the shared OTVision ``logging.Logger``.

    Should only be instantiated once, in the same module as this class.
    To access the logger elsewhere, use ``logging.getLogger(LOGGER_NAME)``
    with ``LOGGER_NAME`` from the module where this class is defined.
    """

    def __init__(self, name: str = LOGGER_NAME) -> None:
        self.logger = logging.getLogger(name=name)
        # The logger itself passes every record on; each handler filters by
        # the level it was added with.
        self.logger.setLevel("DEBUG")
        self._set_formatter()

    def _set_formatter(self) -> None:
        """Create the formatter that every handler of this logger shares."""
        self.formatter = logging.Formatter(LOG_FORMAT)

    def _add_handler(self, handler: logging.Handler, level: str) -> None:
        """Apply the shared formatter and ``level`` to ``handler`` and attach it."""
        handler.setFormatter(self.formatter)
        handler.setLevel(level=level)
        self.logger.addHandler(handler)

    def add_file_handler(
        self,
        log_file: Path = DEFAULT_LOG_FILE,
        level: str = "DEBUG",
        overwrite: bool = False,
    ) -> None:
        """Add a file handler to the already existing global instance of
        _OTVisionLogger.

        Should only be used once in each of OTVision's command line or
        graphical user interfaces.

        Args:
            log_file (Path): file path to write the logs.
                Defaults to DEFAULT_LOG_FILE.
            level (str): Logging level of the file handler.
                One from "DEBUG", "INFO", "WARNING", "ERROR" or "CRITICAL".
                Defaults to "DEBUG".
            overwrite (bool): if True, overwrite existing log file.
                Defaults to False.

        Raises:
            LogFileAlreadyExists: if ``log_file`` exists and ``overwrite`` is
                False.

        IMPORTANT:
            log_file and level are not intended to be optional, they have to be
            provided in every case. The default values provided are a safety net.
        """
        if log_file.exists() and not overwrite:
            raise LogFileAlreadyExists(
                f"Log file '{log_file}' already exists. "
                "Please specify option to overwrite the log file when using the CLI."
            )
        log_file.parent.mkdir(parents=True, exist_ok=True)
        log_file.touch()
        # mode="w" truncates the file, so an overwritten log starts empty.
        file_handler = logging.FileHandler(log_file, mode="w")
        self._add_handler(file_handler, level)

    def add_console_handler(self, level: str = "WARNING") -> None:
        """Add a console handler to the already existing global instance of
        _OTVisionLogger.

        Should only be used once in each of OTVision's command line or
        graphical user interfaces. Records are written to ``sys.stdout``.

        Args:
            level (str): Logging level of the console handler.
                One from "DEBUG", "INFO", "WARNING", "ERROR" or "CRITICAL".
                Defaults to "WARNING".

        IMPORTANT:
            level is not intended to be optional, it has to be provided
            in every case. The default value provided is a safety net.
        """
        console_handler = logging.StreamHandler(sys.stdout)
        self._add_handler(console_handler, level)

    def _remove_handlers(self) -> None:
        """Detach every handler that is currently attached to the logger."""
        # Iterate over a snapshot: removeHandler() mutates logger.handlers, and
        # removing from the live list while iterating it skips handlers.
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)


# This should be the only place where _OTVisionLogger is instantiated directly.
# In all other modules that should be logged from, use
# logging.getLogger(LOGGER_NAME).

log = _OTVisionLogger()
@@ -0,0 +1,71 @@
1
+ """
2
+ OTVision helpers to gather information about the machine and the system
3
+ """
4
+
5
+ # Copyright (C) 2022 OpenTrafficCam Contributors
6
+ # <https://github.com/OpenTrafficCam>
7
+ # <team@opentrafficcam.org>
8
+ #
9
+ # This program is free software: you can redistribute it and/or modify
10
+ # it under the terms of the GNU General Public License as published by
11
+ # the Free Software Foundation, either version 3 of the License, or
12
+ # (at your option) any later version.
13
+ #
14
+ # This program is distributed in the hope that it will be useful,
15
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
16
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17
+ # GNU General Public License for more details.
18
+ #
19
+ # You should have received a copy of the GNU General Public License
20
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
21
+
22
+
23
+ import platform
24
+
25
def _leading_int(text: str) -> int:
    """Return the integer formed by the leading ASCII digits of ``text``.

    Pre-release or locally built interpreters report version components such
    as "0rc1" or "0+", which a plain ``int()`` call rejects. Returns 0 if
    ``text`` does not start with a digit.
    """
    digits = ""
    for char in text:
        if char not in "0123456789":
            break
        digits += char
    return int(digits) if digits else 0


OS = platform.system().replace("Darwin", "Mac")
"""OS OTVision is currently running on"""

ON_WINDOWS = OS == "Windows"
"""Whether OS is Windows or not"""

ON_LINUX = OS == "Linux"
"""Whether OS is Linux or not"""

ON_MAC = OS == "Mac"
"""Whether OS is MacOS or not"""

OS_RELEASE = platform.release()
"""Release of the OS OTVision is currently running on"""

OS_VERSION = platform.version()
"""Specific version of the OS OTVision is currently running on"""

PY_MAJOR_VERSION = _leading_int(platform.python_version_tuple()[0])
"""Python major version digit (e.g. 3 for 3.9.5) OTVision is currently running with"""

PY_MINOR_VERSION = _leading_int(platform.python_version_tuple()[1])
"""Python minor version digit (e.g. 9 for 3.9.5) OTVision is currently running with"""

PY_PATCH_VERSION = _leading_int(platform.python_version_tuple()[2])
"""Python patch version digit (e.g. 5 for 3.9.5) OTVision is currently running with"""
51
+
52
+
53
def _has_cuda() -> bool:
    """Report whether PyTorch can use CUDA on this machine.

    torch is imported inside the function, so importing this module does not
    import torch.

    Returns:
        bool: the result of ``torch.cuda.is_available()``
    """
    from torch import cuda

    return cuda.is_available()
62
+
63
+
64
def print_has_cuda() -> None:
    """Print whether CUDA is available on this machine.

    Writes a single line to stdout containing the result of ``_has_cuda()``.
    Nothing is returned.
    """

    print(f"This system has cuda: {_has_cuda()}")
@@ -0,0 +1,54 @@
1
+ from datetime import timedelta
2
+ from pathlib import Path
3
+
4
+ from moviepy.video.io.VideoFileClip import VideoFileClip
5
+
6
+
7
def get_video_dimensions(video: Path) -> tuple[int, int]:
    """Read the frame size of a video file.

    The clip is opened only for the duration of the lookup and closed again
    before returning.

    Args:
        video (Path): the video file

    Returns:
        tuple[int, int]: width and height of video, as reported by the
            clip's ``size`` attribute
    """
    clip_path = str(video)
    with VideoFileClip(clip_path) as clip:
        return clip.size
19
+
20
+
21
def get_fps(video: Path) -> float:
    """Read the frame rate of a video file.

    The clip is opened only for the duration of the lookup and closed again
    before returning.

    Args:
        video (Path): the video file

    Returns:
        float: the video's fps, as reported by the clip's ``fps`` attribute
    """
    clip_path = str(video)
    with VideoFileClip(clip_path) as clip:
        return clip.fps
33
+
34
+
35
def get_duration(video_file: Path) -> timedelta:
    """Get the duration of the video.

    Args:
        video_file (Path): path to video file; it is resolved to an absolute
            path before the clip is opened

    Returns:
        timedelta: duration of the video, built from the clip's ``duration``
            interpreted as seconds
    """
    clip_path = str(video_file.absolute())
    with VideoFileClip(clip_path) as clip:
        duration = timedelta(seconds=clip.duration)
    return duration
44
+
45
+
46
def get_number_of_frames(video_file: Path) -> int:
    """Get the number of frames of the video.

    Args:
        video_file (Path): path to video file; it is resolved to an absolute
            path before the clip is opened

    Returns:
        int: number of frames of the video, as reported by the clip reader's
            ``nframes`` attribute
    """
    with VideoFileClip(str(video_file.absolute())) as clip:
        return clip.reader.nframes
File without changes
OTVision/track/iou.py ADDED
@@ -0,0 +1,282 @@
1
+ """
2
+ OTVision module to track road users in frames detected by OTVision
3
+ """
4
+
5
+ # based on IOU Tracker written by Erik Bochinski originally licensed under the
6
+ # MIT License, see
7
+ # https://github.com/bochinski/iou-tracker/blob/master/LICENSE.
8
+
9
+
10
+ # Copyright (C) 2022 OpenTrafficCam Contributors
11
+ # <https://github.com/OpenTrafficCam>
12
+ # <team@opentrafficcam.org>
13
+ #
14
+ # This program is free software: you can redistribute it and/or modify
15
+ # it under the terms of the GNU General Public License as published by
16
+ # the Free Software Foundation, either version 3 of the License, or
17
+ # (at your option) any later version.
18
+ #
19
+ # This program is distributed in the hope that it will be useful,
20
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
21
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22
+ # GNU General Public License for more details.
23
+ #
24
+ # You should have received a copy of the GNU General Public License
25
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
26
+ from collections import defaultdict
27
+ from dataclasses import dataclass
28
+ from typing import Iterator
29
+
30
+ from tqdm import tqdm
31
+
32
+ from OTVision.config import CONFIG
33
+ from OTVision.dataformat import (
34
+ AGE,
35
+ BBOXES,
36
+ CENTER,
37
+ CLASS,
38
+ CONFIDENCE,
39
+ DETECTIONS,
40
+ FINISHED,
41
+ FIRST,
42
+ FRAMES,
43
+ MAX_CLASS,
44
+ MAX_CONF,
45
+ START_FRAME,
46
+ TRACK_ID,
47
+ H,
48
+ W,
49
+ X,
50
+ Y,
51
+ )
52
+
53
+ from .iou_util import iou
54
+
55
+
56
class TrackedDetections:
    """Detections of one tracking run together with track id bookkeeping.

    Stores the detections, the ids that were detected, and the subset of ids
    whose tracks are still considered active.
    """

    def __init__(
        self,
        detections: dict[str, dict],
        detected_ids: set[int],
        active_track_ids: set[int],
    ) -> None:
        self._detections = detections
        self._detected_ids = detected_ids
        self._active_track_ids = active_track_ids

    def update_active_track_ids(self, new_active_ids: set[int]) -> None:
        """Keep only those active ids that also appear in ``new_active_ids``.

        Ids can only be dropped here; ids in ``new_active_ids`` that were not
        active before are not added.
        """
        still_active: set[int] = set()
        for track_id in self._active_track_ids:
            if track_id in new_active_ids:
                still_active.add(track_id)
        self._active_track_ids = still_active

    def is_finished(self) -> bool:
        """Return True once no active track id is left."""
        return not self._active_track_ids
74
+
75
+
76
@dataclass(frozen=True)
class TrackingResult:
    """Immutable bundle of the values returned by ``track_iou``."""

    # Detections of the processed frames plus detected / still-active track ids.
    tracked_detections: TrackedDetections
    # Track dicts that were still active after the last processed frame.
    active_tracks: list[dict]
    # Maps a track id to the last frame in which that track had a detection.
    last_track_frame: dict[int, int]
81
+
82
+
83
def make_bbox(obj: dict) -> tuple[float, float, float, float]:
    """Convert a center-based box (xywh) into corner coordinates (xyxy).

    Args:
        obj (dict): dict holding the x center, y center, width and height of
            a box under the keys X, Y, W and H

    Returns:
        tuple[float, float, float, float]: xmin, ymin, xmax, ymax
    """
    x_center, y_center = obj[X], obj[Y]
    half_width, half_height = obj[W] / 2, obj[H] / 2
    return (
        x_center - half_width,
        y_center - half_height,
        x_center + half_width,
        y_center + half_height,
    )
98
+
99
+
100
def center(obj: dict) -> tuple[float, float]:
    """Retrieve the center coordinates of a box.

    Args:
        obj (dict): dict holding the center coordinates under the keys X and Y

    Returns:
        tuple[float, float]: the values stored under X and Y, in that order
    """
    x_center = obj[X]
    y_center = obj[Y]
    return x_center, y_center
110
+
111
+
112
def id_generator() -> Iterator[int]:
    """Yield consecutive integer ids without end, starting at 1."""
    next_id: int = 1
    while True:
        yield next_id
        next_id += 1
117
+
118
+
119
def track_iou(
    detections: list,  # TODO: Type hint nested list during refactoring
    sigma_l: float = CONFIG["TRACK"]["IOU"]["SIGMA_L"],
    sigma_h: float = CONFIG["TRACK"]["IOU"]["SIGMA_H"],
    sigma_iou: float = CONFIG["TRACK"]["IOU"]["SIGMA_IOU"],
    t_min: int = CONFIG["TRACK"]["IOU"]["T_MIN"],
    t_miss_max: int = CONFIG["TRACK"]["IOU"]["T_MISS_MAX"],
    previous_active_tracks: list = [],
    vehicle_id_generator: Iterator[int] = id_generator(),
) -> TrackingResult:  # sourcery skip: low-code-quality
    """
    Simple IOU based tracker.
    See "High-Speed Tracking-by-Detection Without Using Image Information
    by E. Bochinski, V. Eiselein, T. Sikora" for
    more information.

    Each active track is greedily matched to the detection of the current
    frame with the highest IOU against the track's last bounding box.
    Detections left unmatched start new tracks.

    Args:
        detections: detections per frame. Although annotated as list, it is
            iterated for frame numbers and indexed as
            ``detections[frame_num][DETECTIONS]``.
        sigma_l (float): low detection threshold. Detections with a
            confidence below it are ignored.
        sigma_h (float): high detection threshold, compared against a
            track's maximum confidence when the track is dropped.
        sigma_iou (float): IOU threshold a best match has to reach to be
            appended to a track.
        t_min (int): minimum track length in frames (last frame minus first
            frame), checked when a track is dropped.
        t_miss_max (int): an unmatched track stays active, with its AGE
            incremented, as long as its AGE is below this value. AGE is
            reset to 0 on every match.
        previous_active_tracks (list): a list of remaining active tracks
            from previous iterations. The list itself is not modified, but
            the track dicts in it are updated in place.
        vehicle_id_generator (Iterator[int]): provides ids for new tracks.
            NOTE(review): the default generator is created once at function
            definition, so all calls relying on the default share it —
            presumably intended to keep ids unique across calls; confirm.

    Returns:
        TrackingResult: new detections, a list of active tracks
            and a lookup dict for each track's last detection frame.

    Raises:
        ValueError: if one of the threshold parameters has the wrong type.
    """

    _check_types(sigma_l, sigma_h, sigma_iou, t_min, t_miss_max)

    # Copy into a fresh list so the caller's list object is left untouched.
    tracks_active: list = []
    tracks_active.extend(previous_active_tracks)
    # tracks_finished = []

    # Collected for dropped tracks that pass sigma_h and t_min; not part of
    # the returned result.
    vehIDs_finished: list = []
    new_detections: dict = {}

    for frame_num in tqdm(detections, desc="Tracked frames", unit=" frames"):
        detections_frame = detections[frame_num][DETECTIONS]
        # apply low threshold to detections
        dets = [det for det in detections_frame if det[CONFIDENCE] >= sigma_l]
        new_detections[frame_num] = {}
        updated_tracks: list = []
        saved_tracks: list = []
        for track in tracks_active:
            if dets:
                # get det with highest iou
                best_match = max(
                    dets, key=lambda x: iou(track[BBOXES][-1], make_bbox(x))
                )
                if iou(track[BBOXES][-1], make_bbox(best_match)) >= sigma_iou:
                    track[FRAMES].append(int(frame_num))
                    track[BBOXES].append(make_bbox(best_match))
                    track[CENTER].append(center(best_match))
                    track[CONFIDENCE].append(best_match[CONFIDENCE])
                    track[CLASS].append(best_match[CLASS])
                    track[MAX_CONF] = max(track[MAX_CONF], best_match[CONFIDENCE])
                    track[AGE] = 0

                    updated_tracks.append(track)

                    # remove best matching detection from detections so it
                    # cannot be assigned to a second track
                    del dets[dets.index(best_match)]
                    # best_match[TRACK_ID] = track[TRACK_ID]
                    best_match[FIRST] = False
                    new_detections[frame_num][track[TRACK_ID]] = best_match

            # if track was not updated (a track updated in this iteration is
            # always the last element of updated_tracks)
            if not updated_tracks or track is not updated_tracks[-1]:
                # keep the track alive while it has not been missing too long
                if track[AGE] < t_miss_max:
                    track[AGE] += 1
                    saved_tracks.append(track)
                # otherwise the track is dropped; it counts as finished only
                # if it is confident and long enough
                elif (
                    track[MAX_CONF] >= sigma_h
                    and track[FRAMES][-1] - track[FRAMES][0] >= t_min
                ):
                    # tracks_finished.append(track)
                    vehIDs_finished.append(track[TRACK_ID])
        # TODO: age of the tracks
        # create new tracks from all detections that matched no existing track
        new_tracks = []
        for det in dets:
            vehID = next(vehicle_id_generator)
            new_tracks.append(
                {
                    FRAMES: [int(frame_num)],
                    BBOXES: [make_bbox(det)],
                    CENTER: [center(det)],
                    CONFIDENCE: [det[CONFIDENCE]],
                    CLASS: [det[CLASS]],
                    MAX_CLASS: det[CLASS],
                    MAX_CONF: det[CONFIDENCE],
                    TRACK_ID: vehID,
                    START_FRAME: int(frame_num),
                    AGE: 0,
                }
            )
            # det[TRACK_ID] = vehID
            det[FIRST] = True
            new_detections[frame_num][vehID] = det
        tracks_active = updated_tracks + saved_tracks + new_tracks

    # finish all remaining active tracks
    # tracks_finished += [
    #     track
    #     for track in tracks_active
    #     if (
    #         track["max_conf"] >= sigma_h
    #         and track["frames"][-1] - track["frames"][0] >= t_min
    #     )
    # ]

    # for track in tracks_finished:
    #     track["max_class"] = pd.Series(track["class"]).mode().iat[0]

    # TODO: #82 Use dict comprehensions in track_iou
    # save last occurrence frame of tracks; unknown ids default to -1
    last_track_frame: dict[int, int] = defaultdict(lambda: -1)

    for frame_num, frame_det in tqdm(
        new_detections.items(), desc="New detection frames", unit=" frames"
    ):
        for vehID, det in frame_det.items():
            det[FINISHED] = False
            det[TRACK_ID] = vehID
            last_track_frame[vehID] = max(frame_num, last_track_frame[vehID])

    # return tracks_finished
    # TODO: #83 Remove unnecessary code (e.g. for tracks_finished) from track_iou

    active_track_ids = {t[TRACK_ID] for t in tracks_active}
    detected_ids = set(last_track_frame.keys())
    return TrackingResult(
        TrackedDetections(
            detections=new_detections,
            detected_ids=detected_ids,
            # only ids detected in this call AND still active afterwards
            active_track_ids={_id for _id in detected_ids if _id in active_track_ids},
        ),
        active_tracks=tracks_active,
        last_track_frame=last_track_frame,
    )
    # return new_detections, tracks_active, last_track_frame
266
+
267
+
268
def _check_types(
    sigma_l: float, sigma_h: float, sigma_iou: float, t_min: int, t_miss_max: int
) -> None:
    """Validate the parameter types of the IOU tracker.

    The three sigma thresholds have to be int or float, t_min and t_miss_max
    have to be int. Parameters are checked in signature order and the first
    violation is reported.

    Raises:
        ValueError: if a parameter has the wrong type
    """
    checks = (
        (sigma_l, (int, float), "sigma_l has to be int or float"),
        (sigma_h, (int, float), "sigma_h has to be int or float"),
        (sigma_iou, (int, float), "sigma_iou has to be int or float"),
        (t_min, int, "t_min has to be int"),
        (t_miss_max, int, "t_miss_max has to be int"),
    )
    for value, allowed_types, message in checks:
        if not isinstance(value, allowed_types):
            raise ValueError(message)
@@ -0,0 +1,140 @@
1
+ """
2
+ Utils for using iou tracker
3
+ """
4
+
5
+ # ---------------------------------------------------------
6
+ # IOU Tracker
7
+ # Copyright (c) 2017 TU Berlin, Communication Systems Group
8
+ # Licensed under The MIT License, see
9
+ # https://github.com/bochinski/iou-tracker/blob/master/LICENSE
10
+ # for details.
11
+ # Written by Erik Bochinski
12
+ # ---------------------------------------------------------
13
+
14
+ from typing import Union
15
+
16
+ import numpy as np
17
+
18
+
19
+ # TODO: Remove if not needed
20
def nms(
    boxes: np.ndarray,
    scores: np.ndarray,
    overlapThresh: float,
    classes: Union[np.ndarray, None] = None,
) -> Union[tuple[np.ndarray, np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]]:
    """
    Perform non-maximum suppression. Based on Malisiewicz et al.

    Boxes are visited from highest to lowest score. Each visited box is kept,
    and every remaining box whose intersection with it covers more than
    ``overlapThresh`` of that remaining box's own area is discarded.

    Args:
        boxes (numpy.ndarray): boxes to process; columns 0-3 are read as
            x1, y1, x2, y2
        scores (numpy.ndarray): corresponding scores for each box
        overlapThresh (float): overlap threshold for boxes to merge
        classes (numpy.ndarray, optional): class ids for each box.

    Returns:
        (tuple): tuple containing:

            boxes (numpy.ndarray): nms boxes
            scores (numpy.ndarray): nms scores
            classes (numpy.ndarray, optional): nms classes if specified
    """
    # signed-integer inputs are converted to float, since the overlap
    # computation below involves divisions
    if boxes.dtype.kind == "i":
        boxes = boxes.astype("float")
    if scores.dtype.kind == "i":
        scores = scores.astype("float")

    left, top, right, bottom = (boxes[:, column] for column in range(4))

    # box areas, using the inclusive pixel convention (+1 per side)
    areas = (right - left + 1) * (bottom - top + 1)

    # candidate indexes in ascending score order: the best one is always last
    remaining = np.argsort(scores)
    keep = []

    while len(remaining) > 0:
        best = remaining[-1]
        others = remaining[:-1]
        keep.append(best)

        # intersection rectangle of the best box with every other candidate
        inter_left = np.maximum(left[best], left[others])
        inter_top = np.maximum(top[best], top[others])
        inter_right = np.minimum(right[best], right[others])
        inter_bottom = np.minimum(bottom[best], bottom[others])

        inter_width = np.maximum(0, inter_right - inter_left + 1)
        inter_height = np.maximum(0, inter_bottom - inter_top + 1)

        # share of each candidate's own area that the best box covers
        overlap = (inter_width * inter_height) / areas[others]

        # drop the best box itself and every candidate it overlaps too much
        remaining = others[~(overlap > overlapThresh)]

    if classes is None:
        return boxes[keep], scores[keep]
    return boxes[keep], scores[keep], classes[keep]
101
+
102
+
103
def iou(
    bbox1: Union[list[float], tuple[float, float, float, float]],
    bbox2: Union[list[float], tuple[float, float, float, float]],
) -> float:
    """
    Calculates the intersection-over-union of two bounding boxes.

    Args:
        bbox1 (list of floats): bounding box in format x1,y1,x2,y2.
        bbox2 (list of floats): bounding box in format x1,y1,x2,y2.

    Returns:
        float: intersection-over-union of bbox1 and bbox2. 0.0 if the boxes
            do not overlap or only touch.
    """

    bbox1 = [float(x) for x in bbox1]
    bbox2 = [float(x) for x in bbox2]

    (x0_1, y0_1, x1_1, y1_1) = bbox1
    (x0_2, y0_2, x1_2, y1_2) = bbox2

    # get the overlap rectangle
    overlap_x0 = max(x0_1, x0_2)
    overlap_y0 = max(y0_1, y0_2)
    overlap_x1 = min(x1_1, x1_2)
    overlap_y1 = min(y1_1, y1_2)

    # check if there is an overlap; return a float to honour the annotation
    if overlap_x1 - overlap_x0 <= 0 or overlap_y1 - overlap_y0 <= 0:
        return 0.0

    # if yes, calculate the ratio of the overlap to the unified size
    size_1 = (x1_1 - x0_1) * (y1_1 - y0_1)
    size_2 = (x1_2 - x0_2) * (y1_2 - y0_2)
    size_intersection = (overlap_x1 - overlap_x0) * (overlap_y1 - overlap_y0)
    size_union = size_1 + size_2 - size_intersection

    return size_intersection / size_union