supervisely 6.73.417__py3-none-any.whl → 6.73.419__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. supervisely/api/entity_annotation/figure_api.py +89 -45
  2. supervisely/nn/inference/inference.py +61 -45
  3. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  4. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  5. supervisely/nn/inference/session.py +4 -4
  6. supervisely/nn/model/model_api.py +31 -20
  7. supervisely/nn/model/prediction.py +11 -0
  8. supervisely/nn/model/prediction_session.py +33 -6
  9. supervisely/nn/tracker/__init__.py +1 -2
  10. supervisely/nn/tracker/base_tracker.py +44 -0
  11. supervisely/nn/tracker/botsort/__init__.py +1 -0
  12. supervisely/nn/tracker/botsort/botsort_config.yaml +31 -0
  13. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  14. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  15. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  16. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  17. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  18. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  19. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  20. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  21. supervisely/nn/tracker/botsort_tracker.py +259 -0
  22. supervisely/project/project.py +1 -1
  23. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/METADATA +5 -3
  24. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/RECORD +29 -42
  25. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  26. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  27. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  28. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  29. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  30. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  31. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  32. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  33. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  34. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  35. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  36. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  37. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  38. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  39. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  40. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  41. supervisely/nn/tracker/tracker.py +0 -285
  42. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  43. supervisely/nn/tracking/__init__.py +0 -1
  44. supervisely/nn/tracking/boxmot.py +0 -114
  45. supervisely/nn/tracking/tracking.py +0 -24
  46. /supervisely/nn/tracker/{utils → botsort/osnet_reid}/__init__.py +0 -0
  47. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/LICENSE +0 -0
  48. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/WHEEL +0 -0
  49. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/entry_points.txt +0 -0
  50. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/top_level.txt +0 -0
@@ -1,273 +0,0 @@
1
- from typing import Dict, List, Tuple, Union
2
-
3
- # pylint: disable=import-error
4
- import clip
5
- import numpy as np
6
-
7
- from supervisely import Annotation, Label, VideoAnnotation
8
- from supervisely.nn.tracker.deep_sort import generate_clip_detections as gdet
9
- from supervisely.nn.tracker.deep_sort import preprocessing
10
- from supervisely.nn.tracker.deep_sort.deep_sort import nn_matching
11
- from supervisely.nn.tracker.deep_sort.deep_sort.detection import (
12
- Detection as dsDetection,
13
- )
14
- from supervisely.nn.tracker.deep_sort.deep_sort.track import Track as dsTrack
15
- from supervisely.nn.tracker.deep_sort.deep_sort.track import TrackState
16
- from supervisely.nn.tracker.deep_sort.deep_sort.tracker import Tracker as dsTracker
17
- from supervisely.nn.tracker.tracker import BaseDetection, BaseTrack, BaseTracker
18
- from supervisely.sly_logger import logger
19
-
20
-
21
class Detection(BaseDetection, dsDetection):
    """Deep SORT detection that remembers the Supervisely label it came from."""

    def __init__(self, sly_label: Label, tlwh, confidence, feature):
        """Store the box, score and appearance feature plus the source label.

        :param sly_label: Supervisely label this detection was built from.
        :param tlwh: Bounding box forwarded unchanged to the deep_sort detection.
        :param confidence: Detector confidence score.
        :param feature: Appearance feature vector of the detected object.
        """
        # Only the deep_sort initializer is invoked explicitly;
        # BaseDetection.__init__ is deliberately not called here.
        dsDetection.__init__(self, tlwh, confidence, feature)
        self.sly_label = sly_label

    def get_sly_label(self):
        """Return the Supervisely label attached to this detection."""
        return self.sly_label
28
-
29
-
30
class Track(BaseTrack, dsTrack):
    """Deep SORT track that keeps the Supervisely label of its latest detection."""

    def __init__(
        self,
        mean,
        covariance,
        track_id,
        n_init,
        max_age,
        detection: Detection = None,
    ):
        """Create a track that is confirmed from its very first frame.

        :param mean: Initial state mean passed to the deep_sort track.
        :param covariance: Initial state covariance passed to the deep_sort track.
        :param track_id: Identifier of the track.
        :param n_init: Forwarded to the deep_sort track.
        :param max_age: Forwarded to the deep_sort track.
        :param detection: Optional detection that spawned this track.
        """
        dsTrack.__init__(self, mean, covariance, track_id, n_init, max_age, feature=None)

        # Override whatever state the deep_sort initializer chose.
        self.state = TrackState.Confirmed
        self.features = []
        self._sly_label = None
        self.class_num = None
        if detection is None:
            return

        self.features.append(detection.feature)
        self._sly_label = detection.get_sly_label()
        # NOTE: despite the attribute name, this holds the object class *name*.
        self.class_num = self._sly_label.obj_class.name

    def get_sly_label(self):
        """Return the label of the most recent matched detection, or None."""
        return self._sly_label

    def clean_sly_label(self):
        """Forget the stored label."""
        self._sly_label = None

    def update(self, kf, detection: Detection):
        """Run the deep_sort update and remember the label of ``detection``."""
        dsTrack.update(self, kf, detection)
        self._sly_label = detection.get_sly_label()
60
-
61
-
62
class _dsTracker(dsTracker):
    """Deep SORT tracker variant whose tracks carry Supervisely labels."""

    def update(self, detections: List[Detection]):
        """Perform measurement update and track management.

        Parameters
        ----------
        detections : List[deep_sort.detection.Detection]
            A list of detections at the current time step.

        """
        # Drop the labels stored on every track before the new update runs.
        for existing_track in self.tracks:
            existing_track: Track
            existing_track.clean_sly_label()

        super().update(detections)

    def _initiate_track(self, detection: Detection):
        """Start a new label-aware track from ``detection``."""
        mean, covariance = self.kf.initiate(detection.to_xyah())
        new_track = Track(
            mean, covariance, self._next_id, self.n_init, self.max_age, detection
        )
        self.tracks.append(new_track)
        self._next_id += 1
87
-
88
-
89
class DeepSortTracker(BaseTracker):
    """Deep SORT tracker that uses CLIP embeddings as appearance features."""

    def __init__(self, settings: Dict = None):
        """Load the CLIP model and build the underlying deep_sort tracker.

        :param settings: Optional overrides for :meth:`default_settings`.
        :type settings: Dict, optional
        """
        if settings is None:
            settings = {}
        super().__init__(settings)
        model_filename = "ViT-B/32"  # initialize deep sort
        logger.info("Loading CLIP...")
        model, transform = clip.load(model_filename, device=self.device)
        self.encoder = gdet.create_box_encoder(model, transform, batch_size=1, device=self.device)
        metric = nn_matching.NearestNeighborDistanceMetric(  # calculate cosine distance metric
            "cosine", self.args.max_cosine_distance, self.args.nn_budget
        )
        self.tracker = _dsTracker(metric, n_init=1)

    def default_settings(self):
        """Return the default Deep SORT settings used when none are supplied."""
        return {"nms_max_overlap": 1.0, "max_cosine_distance": 0.6, "nn_budget": None}

    def track(
        self,
        source: Union[List[np.ndarray], List[str], str],
        frame_to_annotation: Dict[int, Annotation],
        frame_shape: Tuple[int, int],
        pbar_cb=None,
    ) -> VideoAnnotation:
        """
        Track objects in the video using DeepSort algorithm.

        :param source: List of images, paths to images or path to the video file.
        :type source: List[np.ndarray] | List[str] | str
        :param frame_to_annotation: Dictionary with frame index as key and Annotation as value.
        :type frame_to_annotation: Dict[int, Annotation]
        :param frame_shape: Size of the frame (height, width).
        :type frame_shape: Tuple[int, int]
        :param pbar_cb: Callback to update progress bar.
        :type pbar_cb: Callable, optional

        :return: Video annotation with tracked objects.
        :rtype: VideoAnnotation

        :raises ValueError: If number of images and annotations are not the same.

        :Usage example:

         .. code-block:: python

            import supervisely as sly
            from supervisely.nn.tracker import DeepSortTracker

            api = sly.Api()

            project_id = 12345
            video_id = 12345678
            video_path = "video.mp4"

            # Download video and get video info
            video_info = api.video.get_info_by_id(video_id)
            frame_shape = (video_info.frame_height, video_info.frame_width)
            api.video.download_path(id=video_id, path=video_path)

            # Run inference app to get detections
            task_id = 12345 # detection app task id
            session = sly.nn.inference.Session(api, task_id)
            annotations = session.inference_video_id(video_id, 0, video_info.frames_count)
            frame_to_annotation = {i: ann for i, ann in enumerate(annotations)}

            # Run tracker
            tracker = DeepSortTracker()
            video_ann = tracker.track(video_path, frame_to_annotation, frame_shape)

            # Upload result
            model_meta = session.get_model_meta()
            project_meta = sly.ProjectMeta.from_json(api.project.get_meta(project_id))
            project_meta = project_meta.merge(model_meta)
            api.project.update_meta(project_id, project_meta)
            api.video.annotation.append(video_id, video_ann)
        """
        # The length check is only possible for in-memory sources; a video
        # path has no length to compare against.
        if not isinstance(source, str):
            if len(source) != len(frame_to_annotation):
                raise ValueError("Number of images and annotations should be the same")

        tracks_data = {}
        logger.info("Starting deep_sort tracking with CLIP...")

        for frame_index, img in enumerate(self.frames_generator(source)):
            tracks_data = self.update(
                img, frame_to_annotation[frame_index], frame_index, tracks_data
            )

            if pbar_cb is not None:
                pbar_cb()

        tracks_data = self.clear_empty_ids(tracker_annotations=tracks_data)

        return self.get_annotation(
            tracks_data=tracks_data,
            frame_shape=frame_shape,
            frames_count=len(frame_to_annotation),
        )

    def update(
        self, img, annotation: Annotation, frame_index, tracks_data: Dict[int, List[Dict]] = None
    ):
        """Feed one frame to the tracker and record the resulting tracks.

        :param img: Frame image handed to the CLIP box encoder.
        :param annotation: Detections for this frame.
        :type annotation: Annotation
        :param frame_index: Index of the frame; used as the key in ``tracks_data``.
        :param tracks_data: Accumulated per-frame results; created when None.
        :type tracks_data: Dict[int, List[Dict]], optional
        :return: ``tracks_data`` with an entry added for ``frame_index``.
        """
        import torch

        detections = []
        try:
            pred, sly_labels = self.convert_annotation(annotation)
            # A frame without labels is a normal case, not a failure: skip the
            # tensor pipeline, which would raise on an empty 1-D tensor and
            # mis-report the frame as "skipped".
            if pred:
                det = torch.tensor(pred)

                # Process detections
                bboxes = det[:, :4].clone().cpu()
                # convert_annotation emits (top, left, height, width);
                # reorder the columns to (left, top, width, height)
                bboxes = [bbox[[1, 0, 3, 2]] for bbox in bboxes]
                confs = det[:, 4]

                # encode yolo detections and feed to tracker
                features = self.encoder(img, bboxes)
                detections = [
                    Detection(sly_label, bbox, conf, feature)
                    for bbox, conf, feature, sly_label in zip(bboxes, confs, features, sly_labels)
                ]

                # run non-maxima suppression
                boxs = np.array([d.tlwh for d in detections])
                scores = np.array([d.confidence for d in detections])
                class_nums = np.array([d.sly_label.obj_class.name for d in detections])
                indices_of_alive_labels = preprocessing.non_max_suppression(
                    boxs, class_nums, self.args.nms_max_overlap, scores
                )
                detections = [detections[i] for i in indices_of_alive_labels]
        except Exception:
            # Deliberate best-effort: a frame that cannot be processed must not
            # abort the whole video, so log it and continue with what we have.
            import traceback

            logger.info(f"frame {frame_index} skipped on tracking")
            logger.debug(traceback.format_exc())

        # Call the tracker
        self.tracker.predict()
        self.tracker.update(detections)

        if tracks_data is None:
            tracks_data = {}
        self.update_track_data(
            tracks_data=tracks_data,
            tracks=[
                track
                for track in self.tracker.tracks
                if track.is_confirmed() or track.time_since_update <= 1
            ],
            frame_index=frame_index,
        )
        return tracks_data

    def update_track_data(self, tracks_data: dict, tracks: List[BaseTrack], frame_index: int):
        """Store ids and labels of the labelled ``tracks`` under ``frame_index``.

        Tracks whose label is None are left out. Ids are shifted by one
        because deep_sort track ids start from 1.

        :return: The same ``tracks_data`` dict, updated in place.
        """
        track_id_data = []
        labels_data = []

        for curr_track in tracks:
            sly_label = curr_track.get_sly_label()
            if sly_label is not None:
                track_id_data.append(curr_track.track_id - 1)  # track_id starts from 1
                labels_data.append(sly_label)

        tracks_data[frame_index] = {"ids": track_id_data, "labels": labels_data}

        return tracks_data

    def clear_empty_ids(self, tracker_annotations):
        """Renumber track ids to consecutive integers in order of first appearance.

        :param tracker_annotations: Per-frame dicts holding an ``"ids"`` list.
        :return: The same mapping, with every ``"ids"`` list rewritten in place.
        """
        id_mappings = {}

        for data in tracker_annotations.values():
            # len(id_mappings) is evaluated before the insert, so an unseen id
            # receives the next free ordinal.
            data["ids"] = [
                id_mappings.setdefault(current_id, len(id_mappings))
                for current_id in data["ids"]
            ]

        return tracker_annotations
@@ -1,285 +0,0 @@
1
- import argparse
2
- import os
3
- from contextlib import contextmanager
4
- from typing import Dict, List, Union
5
-
6
- import cv2
7
- import numpy as np
8
-
9
- from supervisely import (
10
- Annotation,
11
- Frame,
12
- FrameCollection,
13
- Label,
14
- Rectangle,
15
- VideoAnnotation,
16
- VideoFigure,
17
- VideoObject,
18
- VideoObjectCollection,
19
- )
20
- from supervisely.sly_logger import logger
21
-
22
-
23
class BaseDetection:
    """
    This class represents a bounding box detection in a single image.

    Parameters
    ----------
    tlwh : array_like
        Bounding box in format `(x, y, w, h)`.
    confidence : float
        Detector confidence score.
    feature : array_like | NoneType
        A feature vector that describes the object contained in this image.
    sly_label : Label | NoneType
        A Supervisely Label object

    Attributes
    ----------
    tlwh : ndarray
        Bounding box in format `(top left x, top left y, width, height)`.
    confidence : float
        Detector confidence score.
    feature : ndarray | NoneType
        A feature vector that describes the object contained in this image.
    sly_label : Label | NoneType
        A Supervisely Label object

    """

    def __init__(self, tlwh, confidence: float, feature=None, sly_label: "Label" = None):
        self.tlwh = np.asarray(tlwh, dtype=float)
        self.confidence = float(confidence)
        # Keep a missing feature as None: np.asarray(None, dtype=np.float32)
        # would silently produce a 0-d NaN array instead.
        self.feature = None if feature is None else np.asarray(feature, dtype=np.float32)
        self._sly_label = sly_label

    def __iter__(self):
        """Iterate over the four box values, then the confidence, then the feature."""
        return iter([*self.tlwh, self.confidence, self.feature])

    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    def xyah(self):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = self.tlwh.copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    @property
    def sly_label(self):
        """Supervisely label attached to this detection, or None."""
        return self._sly_label

    @sly_label.setter
    def sly_label(self, sly_label: "Label"):
        self._sly_label = sly_label

    def get_sly_label(self):
        """Return the attached label (method form of :attr:`sly_label`)."""
        return self.sly_label

    def set_sly_label(self, sly_label: "Label"):
        """Attach ``sly_label`` to this detection."""
        self.sly_label = sly_label

    def clean_sly_label(self):
        """Detach the label from this detection."""
        self.sly_label = None
93
-
94
-
95
class BaseTrack:
    """Minimal track interface: an id plus access to the current label."""

    def __init__(self, track_id, *args, **kwargs):
        """Remember ``track_id``; any extra arguments are accepted and ignored."""
        self.track_id = track_id

    def get_sly_label(self):
        """Return the label of the track; concrete tracks must implement this."""
        raise NotImplementedError
101
-
102
-
103
class BaseTracker:
    """Common scaffolding for trackers that turn per-frame annotations into a video annotation."""

    def __init__(self, settings: Dict):
        """Parse ``settings`` and select the torch device.

        :param settings: Overrides merged on top of :meth:`default_settings`.
        :type settings: Dict
        """
        self.settings = settings
        self.args = self.parse_settings(settings)
        self.device = self.select_device(device=self.args.device)

    def select_device(self, device="", batch_size=None):
        """Pick the torch device and log what was selected.

        :param device: ``'cpu'``, a CUDA index such as ``'0'``, a list such as
            ``'0,1,2,3'``, or an empty string for automatic selection.
        :param batch_size: When given with several GPUs, must be a multiple of
            the GPU count.
        :return: ``torch.device("cuda:0")`` when CUDA is used, else ``torch.device("cpu")``.
        :raises AssertionError: If CUDA was requested but is unavailable, or
            ``batch_size`` is not a multiple of the GPU count.
        """
        import torch  # pylint: disable=import-error

        # device = 'cpu' or '0' or '0,1,2,3'
        cpu_request = device.lower() == "cpu"
        if device and not cpu_request:  # if device requested other than 'cpu'
            os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable
            # NOTE: asserts are kept so callers still see AssertionError, but
            # they are stripped when Python runs with -O.
            assert (
                torch.cuda.is_available()
            ), f"CUDA unavailable, invalid device {device} requested"  # check availability

        cuda = False if cpu_request else torch.cuda.is_available()
        if cuda:
            c = 1024**2  # bytes to MB
            ng = torch.cuda.device_count()
            if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
                assert (
                    batch_size % ng == 0
                ), f"batch-size {batch_size} not multiple of GPU count {ng}"
            x = [torch.cuda.get_device_properties(i) for i in range(ng)]
            s = f"Using torch {torch.__version__} "
            for i, d in enumerate((device or "0").split(",")):
                if i == 1:
                    # From the second device on, pad with spaces instead of the prefix.
                    s = " " * len(s)
                logger.info(f"{s}CUDA:{d} ({x[i].name}, {x[i].total_memory / c}MB)")
        else:
            logger.info(f"Using torch {torch.__version__} CPU")

        logger.info("")  # skip a line
        return torch.device("cuda:0" if cuda else "cpu")

    def parse_settings(self, settings: Dict) -> argparse.Namespace:
        """Merge ``settings`` over the defaults and expose them as attributes.

        A ``device`` entry is always present; it defaults to an empty string.
        """
        _settings = self.default_settings()
        _settings.update(settings)
        _settings.setdefault("device", "")
        return argparse.Namespace(**_settings)

    def default_settings(self):
        """To be overridden by subclasses."""
        return {}

    @contextmanager
    def _video_frames_generator(self, video_path: str):
        """Open ``video_path`` and yield an iterator over its frames.

        A ``@contextmanager`` generator must yield exactly once, so the single
        yielded value is a lazy frame iterator; the capture is released when
        the ``with`` block exits.
        """
        cap = cv2.VideoCapture(video_path)

        def _frames():
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                yield frame

        try:
            yield _frames()
        finally:
            cap.release()

    def frames_generator(self, source: Union[List[np.ndarray], List[str], str]):
        """Return an iterable of frames for any supported ``source``.

        :param source: Path to a video file, list of image paths, or list of
            already loaded images.
        :return: A generator for a video path, a list of images read with
            ``cv2.imread`` for a list of paths, otherwise ``source`` itself.
        """
        if isinstance(source, str):

            def _gen():
                with self._video_frames_generator(source) as frames:
                    yield from frames

            return _gen()
        # Guard against an empty list before peeking at the first element.
        if isinstance(source, list) and source and isinstance(source[0], str):
            return [cv2.imread(img) for img in source]
        return source

    def track(
        self,
        source: Union[List[np.ndarray], List[str], str],
        frame_to_annotation: "Dict[int, Annotation]",
        pbar_cb=None,
    ):
        """To be overridden by subclasses."""
        raise NotImplementedError()

    def convert_annotation(self, annotation_for_frame: "Annotation"):
        """Flatten an annotation into numeric rows plus the matching labels.

        Each row is ``[top, left, height, width, confidence]`` taken from the
        bounding box of the label geometry. The confidence comes from a
        ``confidence`` tag, else a ``conf`` tag, else defaults to 1.0.

        :return: Tuple ``(rows, labels)`` of equal length.
        """
        formatted_predictions = []
        sly_labels = []

        for label in annotation_for_frame.labels:
            confidence = 1.0
            if label.tags.get("confidence", None) is not None:
                confidence = label.tags.get("confidence").value
            elif label.tags.get("conf", None) is not None:
                confidence = label.tags.get("conf").value

            rectangle: "Rectangle" = label.geometry.to_bbox()
            row = [
                rectangle.top,
                rectangle.left,
                rectangle.height,
                rectangle.width,
                confidence,
            ]

            formatted_predictions.append(row)
            sly_labels.append(label)

        return formatted_predictions, sly_labels

    def update(
        self, img, annotation: "Annotation", frame_index, tracks_data: Dict[int, List[Dict]] = None
    ):
        """To be overridden by subclasses."""
        raise NotImplementedError()

    def correct_figure(self, img_size, figure):  # img_size — height, width tuple
        """Clip ``figure`` to the image canvas.

        :return: ``figure`` itself when its bounding box lies inside the
            canvas, the first piece left after cropping otherwise, or None
            when cropping leaves nothing.
        """
        # check figure is within image bounds
        canvas_rect = Rectangle.from_size(img_size)
        if canvas_rect.contains(figure.to_bbox()) is False:
            # crop figure
            figures_after_crop = figure.crop(canvas_rect)
            if len(figures_after_crop) > 0:
                return figures_after_crop[0]
            return None
        return figure

    def update_track_data(self, tracks_data: dict, tracks: "List[BaseTrack]", frame_index: int):
        """Store ids and labels of the labelled ``tracks`` under ``frame_index``.

        Tracks whose label is None are left out.

        :return: The same ``tracks_data`` dict, updated in place.
        """
        track_id_data = []
        labels_data = []

        for curr_track in tracks:
            sly_label = curr_track.get_sly_label()
            if sly_label is not None:
                track_id_data.append(curr_track.track_id)
                labels_data.append(sly_label)

        tracks_data[frame_index] = {"ids": track_id_data, "labels": labels_data}

        return tracks_data

    @staticmethod
    def _dominant_class_name(counters: Dict[str, int]) -> str:
        """Return the class name with the highest count; the first one wins a tie."""
        return max(counters, key=counters.get)

    def get_annotation(self, tracks_data: Dict, frame_shape, frames_count) -> "VideoAnnotation":
        """Build a video annotation from the accumulated per-frame track data.

        Every track becomes one video object whose class is the class seen
        most often among the labels of that track.

        :param tracks_data: Mapping ``frame_index -> {"ids": [...], "labels": [...]}``.
        :param frame_shape: Passed through as ``img_size``.
        :param frames_count: Passed through as ``frames_count``.
        :return: The assembled video annotation.
        """
        # Create and count object classes for each track
        object_classes = {}  # object_class_name -> object_class
        object_class_counter = {}  # track_id -> object_class_name -> count
        for data in tracks_data.values():
            for track_id, label in zip(data["ids"], data["labels"]):
                class_name = label.obj_class.name
                object_classes.setdefault(class_name, label.obj_class)
                counters = object_class_counter.setdefault(track_id, {})
                counters[class_name] = counters.get(class_name, 0) + 1

        # Assign object classes to tracks
        track_obj_classes = {  # track_id -> object_class
            track_id: object_classes[self._dominant_class_name(counters)]
            for track_id, counters in object_class_counter.items()
        }

        # Create video objects, figures and frames
        video_objects = {}  # track_id -> VideoObject
        frames = []
        for frame_index, data in tracks_data.items():
            frame_figures = []
            for track_id, label in zip(data["ids"], data["labels"]):
                video_object = video_objects.get(track_id)
                if video_object is None:
                    # Build the object only on first sight of the track.
                    video_object = VideoObject(track_obj_classes[track_id])
                    video_objects[track_id] = video_object
                frame_figures.append(VideoFigure(video_object, label.geometry, frame_index))
            frames.append(Frame(frame_index, frame_figures))

        objects = list(video_objects.values())
        return VideoAnnotation(
            img_size=frame_shape,
            frames_count=frames_count,
            objects=VideoObjectCollection(objects),
            frames=FrameCollection(frames),
        )