supervisely 6.73.418__py3-none-any.whl → 6.73.420__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/entity_annotation/figure_api.py +89 -45
- supervisely/nn/inference/inference.py +61 -45
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/session.py +4 -4
- supervisely/nn/model/model_api.py +31 -20
- supervisely/nn/model/prediction.py +11 -0
- supervisely/nn/model/prediction_session.py +33 -6
- supervisely/nn/tracker/__init__.py +1 -2
- supervisely/nn/tracker/base_tracker.py +44 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +31 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +259 -0
- supervisely/project/project.py +212 -74
- {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/METADATA +3 -1
- {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/RECORD +29 -42
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/nn/tracker/{utils → botsort/osnet_reid}/__init__.py +0 -0
- {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/LICENSE +0 -0
- {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/WHEEL +0 -0
- {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/top_level.txt +0 -0
|
@@ -1,273 +0,0 @@
|
|
|
1
|
-
from typing import Dict, List, Tuple, Union
|
|
2
|
-
|
|
3
|
-
# pylint: disable=import-error
|
|
4
|
-
import clip
|
|
5
|
-
import numpy as np
|
|
6
|
-
|
|
7
|
-
from supervisely import Annotation, Label, VideoAnnotation
|
|
8
|
-
from supervisely.nn.tracker.deep_sort import generate_clip_detections as gdet
|
|
9
|
-
from supervisely.nn.tracker.deep_sort import preprocessing
|
|
10
|
-
from supervisely.nn.tracker.deep_sort.deep_sort import nn_matching
|
|
11
|
-
from supervisely.nn.tracker.deep_sort.deep_sort.detection import (
|
|
12
|
-
Detection as dsDetection,
|
|
13
|
-
)
|
|
14
|
-
from supervisely.nn.tracker.deep_sort.deep_sort.track import Track as dsTrack
|
|
15
|
-
from supervisely.nn.tracker.deep_sort.deep_sort.track import TrackState
|
|
16
|
-
from supervisely.nn.tracker.deep_sort.deep_sort.tracker import Tracker as dsTracker
|
|
17
|
-
from supervisely.nn.tracker.tracker import BaseDetection, BaseTrack, BaseTracker
|
|
18
|
-
from supervisely.sly_logger import logger
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class Detection(BaseDetection, dsDetection):
    """DeepSORT detection that also remembers the Supervisely label it was built from.

    Box, confidence and appearance feature are stored by the DeepSORT
    ``Detection`` initializer (``BaseDetection.__init__`` is not called).
    The label is stored through the ``sly_label`` property inherited from
    ``BaseDetection``.
    """

    def __init__(self, sly_label: Label, tlwh, confidence, feature):
        # Let DeepSORT set up its own fields first, then attach the label.
        dsDetection.__init__(self, tlwh, confidence, feature)
        self.sly_label = sly_label

    def get_sly_label(self):
        """Return the Supervisely label attached to this detection."""
        label = self.sly_label
        return label
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class Track(BaseTrack, dsTrack):
    """DeepSORT track that keeps the Supervisely label of its most recent match."""

    def __init__(
        self,
        mean,
        covariance,
        track_id,
        n_init,
        max_age,
        detection: Detection = None,
    ):
        """Create a track from a Kalman state and (optionally) its first detection.

        :param mean: Initial Kalman filter state mean.
        :param covariance: Initial Kalman filter state covariance.
        :param track_id: Identifier assigned by the tracker.
        :param n_init: Forwarded to the DeepSORT track.
        :param max_age: Forwarded to the DeepSORT track.
        :param detection: Detection that spawned the track; its feature and label are copied.
        """
        dsTrack.__init__(self, mean, covariance, track_id, n_init, max_age, feature=None)

        # Mark the track confirmed right away, overriding whatever initial
        # state dsTrack.__init__ assigned.
        self.state = TrackState.Confirmed

        if detection is None:
            self.features = []
            self._sly_label = None
            self.class_num = None
            return

        self.features = [detection.feature]
        self._sly_label = detection.get_sly_label()
        # Despite its name, class_num holds the object class *name* of the first label.
        self.class_num = self._sly_label.obj_class.name

    def get_sly_label(self):
        """Return the label of the last detection matched to this track, or None."""
        return self._sly_label

    def clean_sly_label(self):
        """Forget the currently attached label."""
        self._sly_label = None

    def update(self, kf, detection: Detection):
        """Apply the DeepSORT update with ``detection`` and adopt its label."""
        dsTrack.update(self, kf, detection)
        self._sly_label = detection.get_sly_label()
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class _dsTracker(dsTracker):
    """Extend deep sort tracker to support Supervisely labels."""

    def update(self, detections: List[Detection]):
        """Perform measurement update and track management.

        Parameters
        ----------
        detections : List[Detection]
            Detections (with attached Supervisely labels) for the current time step.
        """
        # Drop labels left over from the previous frame. Presumably only tracks
        # updated with a detection in this call (via Track.update) get a label back.
        for existing_track in self.tracks:
            existing_track.clean_sly_label()

        dsTracker.update(self, detections)

    def _initiate_track(self, detection: Detection):
        """Start a new label-aware Track from an unmatched detection."""
        mean, covariance = self.kf.initiate(detection.to_xyah())
        new_track = Track(mean, covariance, self._next_id, self.n_init, self.max_age, detection)
        self.tracks.append(new_track)
        self._next_id += 1
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
class DeepSortTracker(BaseTracker):
    """DeepSORT tracker that uses CLIP (ViT-B/32) embeddings as appearance features.

    Per-frame Supervisely annotations are converted to detections, embedded with
    CLIP, filtered with non-maximum suppression (class names are passed along)
    and associated across frames by the DeepSORT tracker.
    """

    def __init__(self, settings: Dict = None):
        """Load CLIP and build the underlying DeepSORT tracker.

        :param settings: Overrides for :meth:`default_settings`; may also contain
            ``"device"`` (handled by ``BaseTracker``).
        :type settings: dict, optional
        """
        if settings is None:
            settings = {}
        super().__init__(settings)
        model_filename = "ViT-B/32"  # initialize deep sort
        logger.info("Loading CLIP...")
        model, transform = clip.load(model_filename, device=self.device)
        self.encoder = gdet.create_box_encoder(model, transform, batch_size=1, device=self.device)
        metric = nn_matching.NearestNeighborDistanceMetric(  # calculate cosine distance metric
            "cosine", self.args.max_cosine_distance, self.args.nn_budget
        )
        self.tracker = _dsTracker(metric, n_init=1)

    def default_settings(self):
        """Return default DeepSORT settings (NMS overlap, cosine distance limit, NN budget)."""
        return {"nms_max_overlap": 1.0, "max_cosine_distance": 0.6, "nn_budget": None}

    def track(
        self,
        source: Union[List[np.ndarray], List[str], str],
        frame_to_annotation: Dict[int, Annotation],
        frame_shape: Tuple[int, int],
        pbar_cb=None,
    ) -> VideoAnnotation:
        """
        Track objects in the video using DeepSort algorithm.

        :param source: List of images, paths to images or path to the video file.
        :type source: List[np.ndarray] | List[str] | str
        :param frame_to_annotation: Dictionary with frame index as key and Annotation as value.
        :type frame_to_annotation: Dict[int, Annotation]
        :param frame_shape: Size of the frame (height, width).
        :type frame_shape: Tuple[int, int]
        :param pbar_cb: Callback to update progress bar.
        :type pbar_cb: Callable, optional

        :return: Video annotation with tracked objects.
        :rtype: VideoAnnotation

        :raises ValueError: If number of images and annotations are not the same.

        :Usage example:

         .. code-block:: python

            import supervisely as sly
            from supervisely.nn.tracker import DeepSortTracker

            api = sly.Api()

            project_id = 12345
            video_id = 12345678
            video_path = "video.mp4"

            # Download video and get video info
            video_info = api.video.get_info_by_id(video_id)
            frame_shape = (video_info.frame_height, video_info.frame_width)
            api.video.download_path(id=video_id, path=video_path)

            # Run inference app to get detections
            task_id = 12345  # detection app task id
            session = sly.nn.inference.Session(api, task_id)
            annotations = session.inference_video_id(video_id, 0, video_info.frames_count)
            frame_to_annotation = {i: ann for i, ann in enumerate(annotations)}

            # Run tracker
            tracker = DeepSortTracker()
            video_ann = tracker.track(video_path, frame_to_annotation, frame_shape)

            # Upload result
            model_meta = session.get_model_meta()
            project_meta = sly.ProjectMeta.from_json(api.project.get_meta(project_id))
            project_meta = project_meta.merge(model_meta)
            api.project.update_meta(project_id, project_meta)
            api.video.annotation.append(video_id, video_ann)
        """
        if not isinstance(source, str):
            if len(source) != len(frame_to_annotation):
                raise ValueError("Number of images and annotations should be the same")

        tracks_data = {}
        logger.info("Starting deep_sort tracking with CLIP...")

        for frame_index, img in enumerate(self.frames_generator(source)):
            tracks_data = self.update(
                img, frame_to_annotation[frame_index], frame_index, tracks_data
            )

            if pbar_cb is not None:
                pbar_cb()

        tracks_data = self.clear_empty_ids(tracker_annotations=tracks_data)

        return self.get_annotation(
            tracks_data=tracks_data,
            frame_shape=frame_shape,
            frames_count=len(frame_to_annotation),
        )

    def update(
        self, img, annotation: Annotation, frame_index, tracks_data: Dict[int, List[Dict]] = None
    ):
        """Run one tracking step on a single frame.

        :param img: Frame image passed to the CLIP box encoder.
        :param annotation: Detections for this frame.
        :param frame_index: Index of the frame; used as the key in ``tracks_data``.
        :param tracks_data: Accumulated per-frame results; a new dict is created if None.
        :return: ``tracks_data`` with an entry for ``frame_index`` holding the ids and
            labels of tracks that are confirmed or were updated at most 1 step ago.
        """
        import torch

        detections = []
        try:
            pred, sly_labels = self.convert_annotation(annotation)
            # A frame without labels is a normal case, not an error: the tracker
            # is still stepped below with no detections.
            if len(pred) > 0:
                det = torch.tensor(pred)

                # Process detections
                bboxes = det[:, :4].clone().cpu()
                # convert_annotation rows are (top, left, height, width);
                # reorder to (x, y, w, h) == (left, top, width, height).
                bboxes = [bbox[[1, 0, 3, 2]] for bbox in bboxes]
                confs = det[:, 4]

                # encode detections with CLIP and feed to tracker
                features = self.encoder(img, bboxes)
                candidates = [
                    Detection(sly_label, bbox, conf, feature)
                    for bbox, conf, feature, sly_label in zip(bboxes, confs, features, sly_labels)
                ]

                # run non-maxima suppression
                boxes = np.array([d.tlwh for d in candidates])
                scores = np.array([d.confidence for d in candidates])
                class_nums = np.array([d.sly_label.obj_class.name for d in candidates])
                indices_of_alive_labels = preprocessing.non_max_suppression(
                    boxes, class_nums, self.args.nms_max_overlap, scores
                )
                detections = [candidates[i] for i in indices_of_alive_labels]
        except Exception:
            import traceback

            # Never feed a half-processed (e.g. not NMS-filtered) list to the tracker.
            detections = []
            logger.info(f"frame {frame_index} skipped on tracking")
            logger.debug(traceback.format_exc())

        # Call the tracker
        self.tracker.predict()
        self.tracker.update(detections)

        if tracks_data is None:
            tracks_data = {}
        self.update_track_data(
            tracks_data=tracks_data,
            tracks=[
                track
                for track in self.tracker.tracks
                if track.is_confirmed() or track.time_since_update <= 1
            ],
            frame_index=frame_index,
        )
        return tracks_data

    def update_track_data(self, tracks_data: dict, tracks: List[BaseTrack], frame_index: int):
        """Store ids and labels of ``tracks`` that currently carry a label under ``frame_index``.

        Track ids are shifted down by one (the tracker numbers tracks from 1).
        """
        track_id_data = []
        labels_data = []

        for curr_track in tracks:
            track_id = curr_track.track_id - 1  # track_id starts from 1

            if curr_track.get_sly_label() is not None:
                track_id_data.append(track_id)
                labels_data.append(curr_track.get_sly_label())

        tracks_data[frame_index] = {"ids": track_id_data, "labels": labels_data}

        return tracks_data

    def clear_empty_ids(self, tracker_annotations):
        """Renumber track ids to a dense 0..N-1 range in order of first appearance.

        ``tracker_annotations`` is modified in place and also returned.
        """
        id_mappings = {}
        for data in tracker_annotations.values():
            # setdefault evaluates len(id_mappings) before inserting, so each
            # unseen id gets the next ordinal.
            data["ids"] = [
                id_mappings.setdefault(current_id, len(id_mappings)) for current_id in data["ids"]
            ]
        return tracker_annotations
|
|
@@ -1,285 +0,0 @@
|
|
|
1
|
-
import argparse
|
|
2
|
-
import os
|
|
3
|
-
from contextlib import contextmanager
|
|
4
|
-
from typing import Dict, List, Union
|
|
5
|
-
|
|
6
|
-
import cv2
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
from supervisely import (
|
|
10
|
-
Annotation,
|
|
11
|
-
Frame,
|
|
12
|
-
FrameCollection,
|
|
13
|
-
Label,
|
|
14
|
-
Rectangle,
|
|
15
|
-
VideoAnnotation,
|
|
16
|
-
VideoFigure,
|
|
17
|
-
VideoObject,
|
|
18
|
-
VideoObjectCollection,
|
|
19
|
-
)
|
|
20
|
-
from supervisely.sly_logger import logger
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class BaseDetection:
    """
    This class represents a bounding box detection in a single image.

    Parameters
    ----------
    tlwh : array_like
        Bounding box in format `(x, y, w, h)`.
    confidence : float
        Detector confidence score.
    feature : array_like | NoneType
        A feature vector that describes the object contained in this image.
    sly_label : Label | NoneType
        A Supervisely Label object

    Attributes
    ----------
    tlwh : ndarray
        Bounding box in format `(top left x, top left y, width, height)`.
    confidence : float
        Detector confidence score.
    feature : ndarray(float32) | NoneType
        A feature vector that describes the object contained in this image,
        or None when no feature was given.
    sly_label : Label | NoneType
        A Supervisely Label object

    """

    def __init__(self, tlwh, confidence: float, feature=None, sly_label: Label = None):
        self.tlwh = np.asarray(tlwh, dtype=float)
        self.confidence = float(confidence)
        # Keep a missing feature as None: np.asarray(None, dtype=np.float32)
        # would silently produce a 0-d NaN array instead.
        self.feature = None if feature is None else np.asarray(feature, dtype=np.float32)
        self._sly_label = sly_label

    def __iter__(self):
        """Iterate as ``x, y, w, h, confidence, feature``."""
        return iter([*self.tlwh, self.confidence, self.feature])

    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    def xyah(self):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = self.tlwh.copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    @property
    def sly_label(self):
        """Supervisely label attached to this detection (may be None)."""
        return self._sly_label

    @sly_label.setter
    def sly_label(self, sly_label: Label):
        self._sly_label = sly_label

    def get_sly_label(self):
        """Return the attached Supervisely label (may be None)."""
        return self.sly_label

    def set_sly_label(self, sly_label: Label):
        """Attach ``sly_label`` to this detection."""
        self.sly_label = sly_label

    def clean_sly_label(self):
        """Detach the Supervisely label."""
        self.sly_label = None
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
class BaseTrack:
    """Minimal interface every tracker-specific track object provides."""

    def __init__(self, track_id, *args, **kwargs):
        # Only the identifier is kept; any extra arguments are accepted and ignored.
        self.track_id = track_id

    def get_sly_label(self):
        """Return the Supervisely label currently attached to the track.

        Must be implemented by subclasses.
        """
        raise NotImplementedError
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
class BaseTracker:
    """Base class for trackers that turn per-frame detections into a VideoAnnotation.

    Subclasses provide :meth:`default_settings`, :meth:`track` and :meth:`update`.
    This class handles settings parsing, device selection, frame reading, label
    conversion and assembly of the resulting :class:`VideoAnnotation`.
    """

    def __init__(self, settings: Dict):
        self.settings = settings
        self.args = self.parse_settings(settings)
        self.device = self.select_device(device=self.args.device)

    def select_device(self, device="", batch_size=None):
        """Pick the torch device to run on and log what was selected.

        :param device: ``"cpu"``, ``""`` (auto) or a comma-separated list of CUDA
            indices such as ``"0"`` or ``"0,1"`` (written to ``CUDA_VISIBLE_DEVICES``).
        :param batch_size: If given and several GPUs are visible, must be a multiple
            of the GPU count.
        :return: ``torch.device("cuda:0")`` when CUDA is used, otherwise ``torch.device("cpu")``.
        """
        import torch  # pylint: disable=import-error

        # device = 'cpu' or '0' or '0,1,2,3'
        cpu_request = device.lower() == "cpu"
        if device and not cpu_request:  # if device requested other than 'cpu'
            os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable
            assert (
                torch.cuda.is_available()
            ), f"CUDA unavailable, invalid device {device} requested"  # check availablity

        cuda = False if cpu_request else torch.cuda.is_available()
        if cuda:
            c = 1024**2  # bytes to MB
            ng = torch.cuda.device_count()
            if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
                assert (
                    batch_size % ng == 0
                ), f"batch-size {batch_size} not multiple of GPU count {ng}"
            x = [torch.cuda.get_device_properties(i) for i in range(ng)]
            s = f"Using torch {torch.__version__} "
            for i, d in enumerate((device or "0").split(",")):
                if i == 1:
                    s = " " * len(s)
                logger.info(f"{s}CUDA:{d} ({x[i].name}, {x[i].total_memory / c}MB)")
        else:
            logger.info(f"Using torch {torch.__version__} CPU")

        logger.info("")  # skip a line
        return torch.device("cuda:0" if cuda else "cpu")

    def parse_settings(self, settings: Dict) -> argparse.Namespace:
        """Merge ``settings`` over :meth:`default_settings` and return them as a Namespace.

        ``settings`` may be None or empty. ``device`` defaults to ``""`` when absent.
        """
        _settings = self.default_settings()
        if settings:
            _settings.update(settings)
        _settings.setdefault("device", "")
        return argparse.Namespace(**_settings)

    def default_settings(self):
        """To be overridden by subclasses."""
        return {}

    @contextmanager
    def _video_frames_generator(self, video_path: str):
        """Open ``video_path`` and provide an iterator over its frames.

        Use as ``with self._video_frames_generator(path) as frames: ...``.
        The capture is released when the ``with`` block exits.
        """
        cap = cv2.VideoCapture(video_path)

        def _read_frames():
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                yield frame

        try:
            # A @contextmanager generator must yield exactly once, so yield a
            # single frame iterator rather than each frame.
            yield _read_frames()
        finally:
            cap.release()

    def frames_generator(self, source: str):
        """Return an iterable of frames for ``source``.

        :param source: Path to a video file, a list of image paths, or a list of images.
        :return: A lazy frame generator for a video path, a list of images read with
            ``cv2.imread`` for a list of paths, or ``source`` itself otherwise.
        """
        if isinstance(source, str):

            def _gen():
                with self._video_frames_generator(source) as frames:
                    yield from frames

            return _gen()
        # ``source and ...`` guards against indexing an empty list.
        if isinstance(source, list) and source and isinstance(source[0], str):
            return [cv2.imread(img) for img in source]
        return source

    def track(
        self,
        source: Union[List[np.ndarray], List[str], str],
        frame_to_annotation: Dict[int, Annotation],
        pbar_cb=None,
    ):
        """To be overridden by subclasses."""
        raise NotImplementedError()

    def convert_annotation(self, annotation_for_frame: Annotation):
        """Convert a frame annotation into rows for the tracker.

        Confidence is read from the ``"confidence"`` tag, then the ``"conf"`` tag,
        and defaults to 1.0.

        :return: ``(rows, labels)`` where each row is
            ``[top, left, height, width, confidence]`` of the label's bounding box
            (i.e. y, x, h, w order; callers reorder as needed) and ``labels`` are
            the corresponding source labels.
        """
        formatted_predictions = []
        sly_labels = []

        for label in annotation_for_frame.labels:
            confidence = 1.0
            if label.tags.get("confidence", None) is not None:
                confidence = label.tags.get("confidence").value
            elif label.tags.get("conf", None) is not None:
                confidence = label.tags.get("conf").value

            rectangle: Rectangle = label.geometry.to_bbox()
            row = [
                rectangle.top,
                rectangle.left,
                rectangle.height,
                rectangle.width,
                confidence,
            ]

            formatted_predictions.append(row)
            sly_labels.append(label)

        return formatted_predictions, sly_labels

    def update(
        self, img, annotation: Annotation, frame_index, tracks_data: Dict[int, List[Dict]] = None
    ):
        """To be overridden by subclasses."""
        raise NotImplementedError()

    def correct_figure(self, img_size, figure):  # img_size — height, width tuple
        """Crop ``figure`` to the image canvas.

        :return: ``figure`` unchanged if its bbox is inside the canvas, otherwise the
            first piece produced by ``figure.crop``, or None if cropping leaves nothing.
        """
        canvas_rect = Rectangle.from_size(img_size)
        if canvas_rect.contains(figure.to_bbox()) is False:
            figures_after_crop = figure.crop(canvas_rect)
            if len(figures_after_crop) > 0:
                return figures_after_crop[0]
            return None
        return figure

    def update_track_data(self, tracks_data: dict, tracks: List[BaseTrack], frame_index: int):
        """Store ids and labels of ``tracks`` that currently carry a label under ``frame_index``."""
        track_id_data = []
        labels_data = []

        for curr_track in tracks:
            track_id = curr_track.track_id

            if curr_track.get_sly_label() is not None:
                track_id_data.append(track_id)
                labels_data.append(curr_track.get_sly_label())

        tracks_data[frame_index] = {"ids": track_id_data, "labels": labels_data}

        return tracks_data

    def get_annotation(self, tracks_data: Dict, frame_shape, frames_count) -> VideoAnnotation:
        """Build a VideoAnnotation with one VideoObject per track id.

        Each track's object class is the class observed most often across its
        labels (ties go to the class seen first).
        """
        # Collect object classes and, per track, how often each class was observed.
        object_classes = {}  # object_class_name -> object_class
        object_class_counter = {}  # track_id -> object_class_name -> count
        for data in tracks_data.values():
            for track_id, label in zip(data["ids"], data["labels"]):
                class_name = label.obj_class.name
                object_classes.setdefault(class_name, label.obj_class)
                counts = object_class_counter.setdefault(track_id, {})
                counts[class_name] = counts.get(class_name, 0) + 1

        # Majority vote per track; max() returns the first maximal key in insertion order.
        track_obj_classes = {
            track_id: object_classes[max(counts, key=counts.get)]
            for track_id, counts in object_class_counter.items()
        }

        # Create video objects (lazily, once per track), figures and frames.
        video_objects = {}  # track_id -> VideoObject
        frames = []
        for frame_index, data in tracks_data.items():
            frame_figures = []
            for track_id, label in zip(data["ids"], data["labels"]):
                video_object = video_objects.get(track_id)
                if video_object is None:
                    video_object = VideoObject(track_obj_classes[track_id])
                    video_objects[track_id] = video_object
                frame_figures.append(VideoFigure(video_object, label.geometry, frame_index))
            frames.append(Frame(frame_index, frame_figures))

        return VideoAnnotation(
            img_size=frame_shape,
            frames_count=frames_count,
            objects=VideoObjectCollection(list(video_objects.values())),
            frames=FrameCollection(frames),
        )
|