supervisely-6.73.356-py3-none-any.whl → supervisely-6.73.358-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/_utils.py +12 -0
- supervisely/api/annotation_api.py +3 -0
- supervisely/api/api.py +2 -2
- supervisely/api/app_api.py +27 -2
- supervisely/api/entity_annotation/tag_api.py +0 -1
- supervisely/api/labeling_job_api.py +4 -1
- supervisely/api/nn/__init__.py +0 -0
- supervisely/api/nn/deploy_api.py +821 -0
- supervisely/api/nn/neural_network_api.py +248 -0
- supervisely/api/task_api.py +26 -467
- supervisely/app/fastapi/subapp.py +1 -0
- supervisely/nn/__init__.py +2 -1
- supervisely/nn/artifacts/artifacts.py +5 -5
- supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
- supervisely/nn/experiments.py +28 -5
- supervisely/nn/inference/cache.py +178 -114
- supervisely/nn/inference/gui/gui.py +18 -35
- supervisely/nn/inference/gui/serving_gui.py +3 -1
- supervisely/nn/inference/inference.py +1421 -1265
- supervisely/nn/inference/inference_request.py +412 -0
- supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
- supervisely/nn/inference/session.py +2 -2
- supervisely/nn/inference/tracking/base_tracking.py +45 -79
- supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
- supervisely/nn/inference/tracking/mask_tracking.py +274 -250
- supervisely/nn/inference/tracking/tracker_interface.py +23 -0
- supervisely/nn/inference/uploader.py +164 -0
- supervisely/nn/model/__init__.py +0 -0
- supervisely/nn/model/model_api.py +259 -0
- supervisely/nn/model/prediction.py +311 -0
- supervisely/nn/model/prediction_session.py +632 -0
- supervisely/nn/tracking/__init__.py +1 -0
- supervisely/nn/tracking/boxmot.py +114 -0
- supervisely/nn/tracking/tracking.py +24 -0
- supervisely/nn/training/train_app.py +61 -19
- supervisely/nn/utils.py +43 -3
- supervisely/task/progress.py +12 -2
- supervisely/video/video.py +107 -1
- supervisely/volume_annotation/volume_figure.py +8 -2
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/METADATA +2 -1
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/RECORD +45 -34
- supervisely/api/neural_network_api.py +0 -202
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/LICENSE +0 -0
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/WHEEL +0 -0
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/tracking/mask_tracking.py

@@ -1,13 +1,12 @@
+import inspect
 import time
 import uuid
-from queue import Queue
-from threading import Event, Thread
 from typing import BinaryIO, Dict, List, Tuple

 import numpy as np
 from pydantic import ValidationError

-from supervisely._utils import find_value_by_keys
+from supervisely._utils import find_value_by_keys, get_valid_kwargs
 from supervisely.annotation.label import Geometry, Label
 from supervisely.annotation.obj_class import ObjClass
 from supervisely.api.api import Api
@@ -17,13 +16,14 @@ from supervisely.geometry.bitmap import Bitmap
 from supervisely.geometry.helpers import deserialize_geometry
 from supervisely.geometry.polygon import Polygon
 from supervisely.imaging import image as sly_image
+from supervisely.nn.inference.inference_request import InferenceRequest
 from supervisely.nn.inference.tracking.base_tracking import BaseTracking
 from supervisely.nn.inference.tracking.tracker_interface import (
     TrackerInterface,
     TrackerInterfaceV2,
 )
+from supervisely.nn.inference.uploader import Uploader
 from supervisely.sly_logger import logger
-from supervisely.task.progress import Progress


 class MaskTracking(BaseTracking):
@@ -92,19 +92,20 @@ class MaskTracking(BaseTracking):
         )
         return results

-    def _track(self, api: Api, context: Dict):
-
+    def _track(self, api: Api, context: Dict, inference_request: InferenceRequest):
+        video_interface = TrackerInterface(
             context=context,
             api=api,
-            load_all_frames=
+            load_all_frames=False,
             notify_in_predict=True,
             per_point_polygon_tracking=False,
             frame_loader=self.cache.download_frame,
             frames_loader=self.cache.download_frames,
         )
+        video_interface.stop += video_interface.frames_count + 1
         range_of_frames = [
-
-
+            video_interface.frames_indexes[0],
+            video_interface.frames_indexes[-1],
         ]

         if self.cache.is_persistent:
@@ -112,58 +113,201 @@ class MaskTracking(BaseTracking):
             self.cache.run_cache_task_manually(
                 api,
                 None,
-                video_id=
+                video_id=video_interface.video_id,
             )
         else:
             # if cache is not persistent, run cache task for range of frames
             self.cache.run_cache_task_manually(
                 api,
                 [range_of_frames],
-                video_id=
+                video_id=video_interface.video_id,
             )

-        api.logger.
+        api.logger.debug("frames_count = %s", video_interface.frames_count)
+        inference_request.set_stage("Downloading frames", 0, video_interface.frames_count)
         # load frames
-
+
+        def _load_frames_cb(n: int = 1):
+            inference_request.done(n)
+            video_interface._notify(pos_increment=n, task="Downloading frames")
+
+        frames = self.cache.download_frames(
+            api,
+            video_interface.video_id,
+            video_interface.frames_indexes,
+            progress_cb=_load_frames_cb,
+        )
+
         # combine several binary masks into one multilabel mask
-        i =
+        i = 1
         label2id = {}
+        multilabel_mask = np.zeros(frames[0].shape, dtype=np.uint8)
+        for (fig_id, geometry), obj_id in zip(
+            video_interface.geometries.items(),
+            video_interface.object_ids,
+        ):
+            original_geometry = geometry.clone()
+            if not isinstance(geometry, Bitmap) and not isinstance(geometry, Polygon):
+                raise TypeError(f"This app does not support {geometry.geometry_name()} tracking")
+            # convert polygon to bitmap
+            if isinstance(geometry, Polygon):
+                polygon_obj_class = ObjClass("polygon", Polygon)
+                polygon_label = Label(geometry, polygon_obj_class)
+                bitmap_obj_class = ObjClass("bitmap", Bitmap)
+                bitmap_label = polygon_label.convert(bitmap_obj_class)[0]
+                geometry = bitmap_label.geometry
+            geometry.draw(bitmap=multilabel_mask, color=i)
+            label2id[i] = {
+                "fig_id": fig_id,
+                "obj_id": obj_id,
+                "original_geometry": original_geometry.geometry_name(),
+            }
+            i += 1

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            daemon=True,
-        ).start()
-
-        try:
-            for (fig_id, geometry), obj_id in zip(
-                self.video_interface.geometries.items(),
-                self.video_interface.object_ids,
+        unique_labels = np.unique(multilabel_mask)
+        if 0 in unique_labels:
+            unique_labels = unique_labels[1:]
+        api.logger.debug("unique_labels = %s", unique_labels)
+        total_progress = len(unique_labels) * video_interface.frames_count
+        api.logger.info("Starting tracking process")
+        api.logger.debug("total_progress = %s", total_progress)
+        inference_request.set_stage(
+            InferenceRequest.Stage.INFERENCE,
+            0,
+            total_progress,
+        )
+
+        def _upload_f(items: List):
+            video_interface.add_object_geometries_on_frames(*list(zip(*items)))
+            inference_request.done(sum(item[-1] for item in items))
+
+        with Uploader(upload_f=_upload_f, logger=api.logger) as uploader:
+            # run tracker
+            tracked_multilabel_masks = self.predict(
+                frames=frames, input_mask=multilabel_mask[:, :, 0]
+            )
+            for curframe_i, mask in enumerate(
+                tracked_multilabel_masks, video_interface.frame_index
             ):
+                if curframe_i == video_interface.frame_index:
+                    continue
+                for i in unique_labels:
+                    binary_mask = mask == i
+                    fig_id = label2id[i]["fig_id"]
+                    obj_id = label2id[i]["obj_id"]
+                    geometry_type = label2id[i]["original_geometry"]
+                    if not np.any(binary_mask):
+                        api.logger.info(f"Skipping empty mask on frame {curframe_i}")
+                        inference_request.done()
+                    else:
+                        if geometry_type == "polygon":
+                            bitmap_geometry = Bitmap(binary_mask)
+                            bitmap_obj_class = ObjClass("bitmap", Bitmap)
+                            bitmap_label = Label(bitmap_geometry, bitmap_obj_class)
+                            polygon_obj_class = ObjClass("polygon", Polygon)
+                            polygon_labels = bitmap_label.convert(polygon_obj_class)
+                            geometries = [label.geometry for label in polygon_labels]
+                        else:
+                            geometries = [Bitmap(binary_mask)]
+                        uploader.put(
+                            [
+                                (
+                                    geometry,
+                                    obj_id,
+                                    curframe_i,
+                                    True if g_idx == len(geometries) - 1 else False,
+                                )
+                                for g_idx, geometry in enumerate(geometries)
+                            ]
+                        )
+                if inference_request.is_stopped() or video_interface.global_stop_indicatior:
+                    api.logger.info(
+                        "Tracking stopped by user",
+                        extra={"inference_request_uuid": inference_request.uuid},
+                    )
+                    video_interface._notify(True, task="Stop tracking")
+                    return
+                if uploader.has_exception():
+                    raise uploader.exception
+
+                api.logger.info(f"Frame {curframe_i} was successfully tracked")
+
+    def _track_async(self, api: Api, context: dict, inference_request: InferenceRequest):
+        tracker_interface = TrackerInterfaceV2(api, context, self.cache)
+        frames_count = tracker_interface.frames_count
+        figures = tracker_interface.figures
+        progress_total = frames_count * len(figures)
+        frame_range = [
+            tracker_interface.frame_indexes[0],
+            tracker_interface.frame_indexes[-1],
+        ]
+        frame_range_asc = [min(frame_range), max(frame_range)]
+
+        def _upload_f(items: List[Tuple[FigureInfo, bool]]):
+            inference_request.add_results([item[0] for item in items])
+            inference_request.done(sum(item[1] for item in items))
+
+        def _notify_f(items: List[Tuple[FigureInfo, bool]]):
+            frame_range = [
+                min(item[0].frame_index for item in items),
+                max(item[0].frame_index for item in items),
+            ]
+            tracker_interface.notify_progress(
+                inference_request.progress.current, inference_request.progress.total, frame_range
+            )
+
+        def _exception_handler(exception: Exception):
+            api.logger.error(f"Error saving predictions: {str(exception)}", exc_info=True)
+            tracker_interface.notify_progress(
+                inference_request.progress.current,
+                inference_request.progress.current,
+                frame_range_asc,
+            )
+            tracker_interface.notify_error(exception)
+            raise Exception
+
+        def _maybe_stop():
+            if inference_request.is_stopped() or tracker_interface.is_stopped():
+                if isinstance(tracker_interface.stop_reason(), Exception):
+                    raise tracker_interface.stop_reason()
+                api.logger.info(
+                    "Inference request stopped.",
+                    extra={"inference_request_uuid": inference_request.uuid},
+                )
+                tracker_interface.notify_progress(
+                    inference_request.progress.current,
+                    inference_request.progress.current,
+                    frame_range_asc,
+                )
+                return True
+            if uploader.has_exception():
+                raise uploader.exception
+            return False
+
+        # run tracker
+        frame_index = tracker_interface.frame_index
+        direction_n = tracker_interface.direction_n
+        api.logger.info("Start tracking.")
+        inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
+        with Uploader(
+            upload_f=_upload_f,
+            notify_f=_notify_f,
+            exception_handler=_exception_handler,
+            logger=api.logger,
+        ) as uploader:
+            # combine several binary masks into one multilabel mask
+            i = 0
+            label2id = {}
+            # load frames
+            frames = tracker_interface.load_all_frames()
+            frames = [frame.image for frame in frames]
+            for figure in figures:
+                figure = api.video.figure._convert_json_info(figure)
+                fig_id = figure.id
+                obj_id = figure.object_id
+                geometry = deserialize_geometry(figure.geometry_type, figure.geometry)
                 original_geometry = geometry.clone()
-                if not isinstance(geometry, Bitmap
-                stop_upload_event.set()
+                if not isinstance(geometry, (Bitmap, Polygon)):
                     raise TypeError(
                         f"This app does not support {geometry.geometry_name()} tracking"
                     )
@@ -187,26 +331,34 @@ class MaskTracking(BaseTracking):
                     "obj_id": obj_id,
                     "original_geometry": original_geometry.geometry_name(),
                 }
-
+
+            if _maybe_stop():
+                return
+
+            # predict
             tracked_multilabel_masks = self.predict(
                 frames=frames, input_mask=multilabel_mask[:, :, 0]
             )
             tracked_multilabel_masks = np.array(tracked_multilabel_masks)
+
             # decompose multilabel masks into binary masks
             for i in np.unique(tracked_multilabel_masks):
+                if _maybe_stop():
+                    return
                 if i != 0:
                     binary_masks = tracked_multilabel_masks == i
                     fig_id = label2id[i]["fig_id"]
                     obj_id = label2id[i]["obj_id"]
                     geometry_type = label2id[i]["original_geometry"]
-                    for j, mask in enumerate(binary_masks[1:]):
+                    for j, mask in enumerate(binary_masks[1:], 1):
+                        if _maybe_stop():
+                            return
+                        this_figure_index = frame_index + j * direction_n
                         # check if mask is not empty
                         if not np.any(mask):
-                            api.logger.info(
-                                f"Skipping empty mask on frame {self.video_interface.frame_index + j + 1}"
-                            )
+                            api.logger.info(f"Skipping empty mask on frame {this_figure_index}")
                             # update progress bar anyway (otherwise it will not be finished)
-
+                            inference_request.done()
                         else:
                             if geometry_type == "polygon":
                                 bitmap_geometry = Bitmap(mask)
@@ -218,184 +370,29 @@ class MaskTracking(BaseTracking):
                             else:
                                 geometries = [Bitmap(mask)]
                             for l, geometry in enumerate(geometries):
-
-
-
-
-
-
-
-
-
-
-
+                                figure_id = uuid.uuid5(
+                                    namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
+                                ).hex
+                                result_figure = api.video.figure._convert_json_info(
+                                    {
+                                        ApiField.ID: figure_id,
+                                        ApiField.OBJECT_ID: obj_id,
+                                        "meta": {"frame": this_figure_index},
+                                        ApiField.GEOMETRY_TYPE: geometry.geometry_name(),
+                                        ApiField.GEOMETRY: geometry.to_json(),
+                                        ApiField.TRACK_ID: tracker_interface.track_id,
+                                    }
                                 )
-
-
-
-
-
-
-
-
-
-    def _track_async(self, api: Api, context: dict, inference_request_uuid: str = None):
-        inference_request = self._inference_requests[inference_request_uuid]
-        tracker_interface = TrackerInterfaceV2(api, context, self.cache)
-        progress: Progress = inference_request["progress"]
-        frames_count = tracker_interface.frames_count
-        figures = tracker_interface.figures
-        progress_total = frames_count * len(figures)
-        progress.total = progress_total
-
-        def _upload_f(items: List[Tuple[FigureInfo, bool]]):
-            with inference_request["lock"]:
-                inference_request["pending_results"].extend([item[0] for item in items])
-
-        def _notify_f(items: List[Tuple[FigureInfo, bool]]):
-            items_by_object_id: Dict[int, List[Tuple[FigureInfo, bool]]] = {}
-            for item in items:
-                items_by_object_id.setdefault(item[0].object_id, []).append(item)
-
-            for object_id, object_items in items_by_object_id.items():
-                frame_range = [
-                    min(item[0].frame_index for item in object_items),
-                    max(item[0].frame_index for item in object_items),
-                ]
-                progress.iters_done_report(sum(1 for item in object_items if item[1]))
-                tracker_interface.notify_progress(progress.current, progress.total, frame_range)
-
-        # run tracker
-        frame_index = tracker_interface.frame_index
-        direction_n = tracker_interface.direction_n
-        api.logger.info("Start tracking.")
-        try:
-            with tracker_interface(_upload_f, _notify_f):
-                # combine several binary masks into one multilabel mask
-                i = 0
-                label2id = {}
-                # load frames
-                frames = tracker_interface.load_all_frames()
-                frames = [frame.image for frame in frames]
-                for figure in figures:
-                    figure = api.video.figure._convert_json_info(figure)
-                    fig_id = figure.id
-                    obj_id = figure.object_id
-                    geometry = deserialize_geometry(figure.geometry_type, figure.geometry)
-                    original_geometry = geometry.clone()
-                    if not isinstance(geometry, (Bitmap, Polygon)):
-                        raise TypeError(
-                            f"This app does not support {geometry.geometry_name()} tracking"
-                        )
-                    # convert polygon to bitmap
-                    if isinstance(geometry, Polygon):
-                        polygon_obj_class = ObjClass("polygon", Polygon)
-                        polygon_label = Label(geometry, polygon_obj_class)
-                        bitmap_obj_class = ObjClass("bitmap", Bitmap)
-                        bitmap_label = polygon_label.convert(bitmap_obj_class)[0]
-                        geometry = bitmap_label.geometry
-                    if i == 0:
-                        multilabel_mask = geometry.data.astype(int)
-                        multilabel_mask = np.zeros(frames[0].shape, dtype=np.uint8)
-                        geometry.draw(bitmap=multilabel_mask, color=[1, 1, 1])
-                        i += 1
-                    else:
-                        i += 1
-                        geometry.draw(bitmap=multilabel_mask, color=[i, i, i])
-                    label2id[i] = {
-                        "fig_id": fig_id,
-                        "obj_id": obj_id,
-                        "original_geometry": original_geometry.geometry_name(),
-                    }
-                    if inference_request["cancel_inference"]:
-                        return
-                    if tracker_interface.is_stopped():
-                        reason = tracker_interface.stop_reason()
-                        if isinstance(reason, Exception):
-                            raise reason
-                        return
-
-                # predict
-                tracked_multilabel_masks = self.predict(
-                    frames=frames, input_mask=multilabel_mask[:, :, 0]
-                )
-                tracked_multilabel_masks = np.array(tracked_multilabel_masks)
-
-                # decompose multilabel masks into binary masks
-                for i in np.unique(tracked_multilabel_masks):
-                    if inference_request["cancel_inference"]:
-                        return
-                    if tracker_interface.is_stopped():
-                        reason = tracker_interface.stop_reason()
-                        if isinstance(reason, Exception):
-                            raise reason
-                        return
-                    if i != 0:
-                        binary_masks = tracked_multilabel_masks == i
-                        fig_id = label2id[i]["fig_id"]
-                        obj_id = label2id[i]["obj_id"]
-                        geometry_type = label2id[i]["original_geometry"]
-                        for j, mask in enumerate(binary_masks[1:], 1):
-                            if inference_request["cancel_inference"]:
-                                return
-                            if tracker_interface.is_stopped():
-                                reason = tracker_interface.stop_reason()
-                                if isinstance(reason, Exception):
-                                    raise reason
-                                return
-                            this_figure_index = frame_index + j * direction_n
-                            # check if mask is not empty
-                            if not np.any(mask):
-                                api.logger.info(f"Skipping empty mask on frame {this_figure_index}")
-                                # update progress bar anyway (otherwise it will not be finished)
-                                progress.iter_done_report()
-                            else:
-                                if geometry_type == "polygon":
-                                    bitmap_geometry = Bitmap(mask)
-                                    bitmap_obj_class = ObjClass("bitmap", Bitmap)
-                                    bitmap_label = Label(bitmap_geometry, bitmap_obj_class)
-                                    polygon_obj_class = ObjClass("polygon", Polygon)
-                                    polygon_labels = bitmap_label.convert(polygon_obj_class)
-                                    geometries = [label.geometry for label in polygon_labels]
-                                else:
-                                    geometries = [Bitmap(mask)]
-                                for l, geometry in enumerate(geometries):
-                                    figure_id = uuid.uuid5(
-                                        namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
-                                    ).hex
-                                    result_figure = api.video.figure._convert_json_info(
-                                        {
-                                            ApiField.ID: figure_id,
-                                            ApiField.OBJECT_ID: obj_id,
-                                            "meta": {"frame": this_figure_index},
-                                            ApiField.GEOMETRY_TYPE: geometry.geometry_name(),
-                                            ApiField.GEOMETRY: geometry.to_json(),
-                                            ApiField.TRACK_ID: tracker_interface.track_id,
-                                        }
-                                    )
-                                    should_notify = l == len(geometries) - 1
-                                    tracker_interface.add_prediction((result_figure, should_notify))
-                                api.logger.info(
-                                    "Figure [%d, %d] tracked.",
-                                    i,
-                                    len(figures),
-                                    extra={"figure_id": figure.id},
-                                )
-        except Exception:
-            progress.message = "Error occured during tracking"
-            raise
-        else:
-            progress.message = "Ready"
-        finally:
-            progress.set(current=0, total=1, report=True)
-
-    # Implement the following methods in the derived class
-    def track(self, api: Api, state: Dict, context: Dict):
-        fn = self.send_error_data(api, context)(self._track)
-        self.schedule_task(fn, api, context)
-        return {"message": "Tracking has started."}
+                                should_notify = l == len(geometries) - 1
+                                tracker_interface.add_prediction((result_figure, should_notify))
+                            api.logger.info(
+                                "Figure [%d, %d] tracked.",
+                                i,
+                                len(figures),
+                                extra={"figure_id": figure.id},
+                            )

-    def
+    def _track_api(self, api: Api, context: Dict, inference_request: InferenceRequest):
         # unused fields:
         context["trackId"] = "auto"
         context["objectIds"] = []
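Note on the technique shared by `_track`, `_track_async` and `_track_api` above: every object's binary mask is drawn into one uint8 "multilabel" mask under its own integer label, the tracker propagates that single mask across frames, and per-object binary masks are recovered afterwards with an equality test. A minimal, self-contained numpy sketch of the compose/decompose step (variable names are illustrative, not taken from the package):

import numpy as np

# three hypothetical binary masks (H x W) for three tracked objects
h, w = 4, 6
object_masks = [np.zeros((h, w), dtype=bool) for _ in range(3)]
object_masks[0][0:2, 0:2] = True
object_masks[1][2:4, 1:3] = True
object_masks[2][1:3, 4:6] = True

# compose: draw every object into one uint8 mask with label i = 1..N
multilabel_mask = np.zeros((h, w), dtype=np.uint8)
label2obj = {}
for i, mask in enumerate(object_masks, start=1):
    multilabel_mask[mask] = i          # later objects overwrite earlier ones where they overlap
    label2obj[i] = f"object_{i}"       # stand-in for the fig_id/obj_id bookkeeping

# ... a tracker would now propagate `multilabel_mask` to the following frames ...

# decompose: recover one binary mask per label, skipping background (0)
unique_labels = np.unique(multilabel_mask)
unique_labels = unique_labels[unique_labels != 0]
for i in unique_labels:
    binary_mask = multilabel_mask == i
    print(label2obj[i], int(binary_mask.sum()), "pixels")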
@@ -405,7 +402,7 @@ class MaskTracking(BaseTracking):

         input_geometries: list = context["input_geometries"]

-
+        video_interface = TrackerInterface(
             context=context,
             api=api,
             load_all_frames=True,
@@ -417,8 +414,8 @@ class MaskTracking(BaseTracking):
         )

         range_of_frames = [
-
-
+            video_interface.frames_indexes[0],
+            video_interface.frames_indexes[-1],
         ]

         if self.cache.is_persistent:
@@ -426,24 +423,29 @@ class MaskTracking(BaseTracking):
             self.cache.run_cache_task_manually(
                 api,
                 None,
-                video_id=
+                video_id=video_interface.video_id,
             )
         else:
             # if cache is not persistent, run cache task for range of frames
             self.cache.run_cache_task_manually(
                 api,
                 [range_of_frames],
-                video_id=
+                video_id=video_interface.video_id,
             )

-
+        inference_request.set_stage("Downloading frames", 0, video_interface.frames_count)
         # load frames
-        frames = self.
+        frames = self.cache.download_frames(
+            api,
+            video_interface.video_id,
+            video_interface.frames_indexes,
+            progress_cb=inference_request.done,
+        )
         # combine several binary masks into one multilabel mask
-        i = 0
         label2id = {}

-
+        multilabel_mask = np.zeros(frames[0].shape, dtype=np.uint8)
+        for i, input_geom in enumerate(input_geometries, 1):
             geometry = self._deserialize_geometry(input_geom)
             if not isinstance(geometry, Bitmap) and not isinstance(geometry, Polygon):
                 raise TypeError(f"This app does not support {geometry.geometry_name()} tracking")
@@ -454,18 +456,24 @@ class MaskTracking(BaseTracking):
                 bitmap_obj_class = ObjClass("bitmap", Bitmap)
                 bitmap_label = polygon_label.convert(bitmap_obj_class)[0]
                 geometry = bitmap_label.geometry
-
-                multilabel_mask = geometry.data.astype(int)
-                multilabel_mask = np.zeros(frames[0].shape, dtype=np.uint8)
-                geometry.draw(bitmap=multilabel_mask, color=[1, 1, 1])
-                i += 1
-            else:
-                i += 1
-                geometry.draw(bitmap=multilabel_mask, color=[i, i, i])
+            geometry.draw(bitmap=multilabel_mask, color=i)
             label2id[i] = {
                 "original_geometry": geometry.geometry_name(),
             }

+        result_indexes = np.unique(multilabel_mask)
+        progress_total = len(result_indexes)
+        if 0 in result_indexes:
+            progress_total -= 1
+        progress_total = progress_total * video_interface.frames_count
+
+        api.logger.info("Starting tracking process")
+        inference_request.set_stage(
+            InferenceRequest.Stage.INFERENCE,
+            0,
+            progress_total,
+        )
+
         # run tracker
         tracked_multilabel_masks = self.predict(frames=frames, input_mask=multilabel_mask[:, :, 0])
         tracked_multilabel_masks = np.array(tracked_multilabel_masks)
@@ -492,8 +500,25 @@ class MaskTracking(BaseTracking):

         # predictions must be NxK masks: N=number of frames, K=number of objects
         predictions = list(map(list, zip(*predictions)))
+        inference_request.final_result = predictions
         return predictions

+    # Implement the following methods in the derived class
+    def track(self, api: Api, state: Dict, context: Dict):
+        fn = self.send_error_data(api, context)(self._track)
+        self.inference_requests_manager.schedule_task(fn, api, context)
+        return {"message": "Tracking has started."}
+
+    def track_api(self, api: Api, state: Dict, context: Dict):
+        inference_request, future = self.inference_requests_manager.schedule_task(
+            self._track_api, api, context
+        )
+        future.result()
+        logger.info(
+            "Track-api request processed.", extra={"inference_request_uuid": inference_request.uuid}
+        )
+        return inference_request.final_result
+
     def track_api_files(
         self,
         files: List[BinaryIO],
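The new `track_api` above is a thin synchronous wrapper over the asynchronous task manager: it schedules `_track_api`, blocks on the returned future, and hands back `inference_request.final_result`. A rough sketch of that schedule-then-wait pattern using only the standard library (the manager, request object and worker below are hypothetical stand-ins, not the supervisely classes):

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

@dataclass
class FakeRequest:
    final_result: object = None

class FakeManager:
    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)

    def schedule_task(self, fn, *args):
        request = FakeRequest()
        future = self._pool.submit(fn, request, *args)
        return request, future

def _track_api(request, frames):
    # pretend to track and store the result on the request object
    request.final_result = [f"mask_for_{f}" for f in frames]

manager = FakeManager()
request, future = manager.schedule_task(_track_api, ["frame0", "frame1"])
future.result()                      # block until the scheduled task finishes
print(request.final_result)          # the synchronous endpoint returns the stored result
manager._pool.shutdown()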
@@ -524,15 +549,14 @@ class MaskTracking(BaseTracking):
                 f"Batch size should be less than or equal to {self.max_batch_size} for this model."
             )

-        inference_request_uuid = uuid.uuid5(namespace=uuid.NAMESPACE_URL, name=f"{time.time()}").hex
         fn = self.send_error_data(api, context)(self._track_async)
-        self.schedule_task(fn, api, context
+        inference_request, _ = self.inference_requests_manager.schedule_task(fn, api, context)

         logger.debug(
             "Inference has scheduled from 'track_async' endpoint",
-            extra={"inference_request_uuid":
+            extra={"inference_request_uuid": inference_request.uuid},
         )
         return {
             "message": "Inference has started.",
-            "inference_request_uuid":
+            "inference_request_uuid": inference_request.uuid,
         }
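Across the methods above, the hand-rolled Queue/Thread/Event upload machinery (note the removed `from queue import Queue` and `from threading import Event, Thread` imports) is replaced by the new `Uploader` context manager plus `InferenceRequest` progress stages. The sketch below imitates the batching-context-manager idea with a plain queue and worker thread; it is an illustration of the pattern only, not the implementation in `supervisely/nn/inference/uploader.py`:

import queue
import threading

class BatchingUploader:
    """Collects items from the producer and flushes them to `upload_f` in the background."""

    def __init__(self, upload_f):
        self._upload_f = upload_f
        self._queue = queue.Queue()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self.exception = None

    def __enter__(self):
        self._worker.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self._queue.put(None)          # sentinel: no more items
        self._worker.join()
        return False                   # never swallow producer-side exceptions

    def put(self, items):
        self._queue.put(list(items))

    def has_exception(self):
        return self.exception is not None

    def _run(self):
        while True:
            batch = self._queue.get()
            if batch is None:
                return
            try:
                self._upload_f(batch)
            except Exception as e:     # surfaced to the producer via has_exception()
                self.exception = e
                return

def upload_f(items):
    print("uploading", len(items), "figures")

with BatchingUploader(upload_f) as uploader:
    for frame in range(3):
        uploader.put([("geometry", "object_id", frame)])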