supervisely 6.73.357__py3-none-any.whl → 6.73.359__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. supervisely/_utils.py +12 -0
  2. supervisely/api/annotation_api.py +3 -0
  3. supervisely/api/api.py +2 -2
  4. supervisely/api/app_api.py +27 -2
  5. supervisely/api/entity_annotation/tag_api.py +0 -1
  6. supervisely/api/nn/__init__.py +0 -0
  7. supervisely/api/nn/deploy_api.py +821 -0
  8. supervisely/api/nn/neural_network_api.py +248 -0
  9. supervisely/api/task_api.py +26 -467
  10. supervisely/app/fastapi/subapp.py +1 -0
  11. supervisely/nn/__init__.py +2 -1
  12. supervisely/nn/artifacts/artifacts.py +5 -5
  13. supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
  14. supervisely/nn/experiments.py +28 -5
  15. supervisely/nn/inference/cache.py +178 -114
  16. supervisely/nn/inference/gui/gui.py +18 -35
  17. supervisely/nn/inference/gui/serving_gui.py +3 -1
  18. supervisely/nn/inference/inference.py +1421 -1265
  19. supervisely/nn/inference/inference_request.py +412 -0
  20. supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
  21. supervisely/nn/inference/session.py +2 -2
  22. supervisely/nn/inference/tracking/base_tracking.py +45 -79
  23. supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
  24. supervisely/nn/inference/tracking/mask_tracking.py +274 -250
  25. supervisely/nn/inference/tracking/tracker_interface.py +23 -0
  26. supervisely/nn/inference/uploader.py +164 -0
  27. supervisely/nn/model/__init__.py +0 -0
  28. supervisely/nn/model/model_api.py +259 -0
  29. supervisely/nn/model/prediction.py +311 -0
  30. supervisely/nn/model/prediction_session.py +632 -0
  31. supervisely/nn/tracking/__init__.py +1 -0
  32. supervisely/nn/tracking/boxmot.py +114 -0
  33. supervisely/nn/tracking/tracking.py +24 -0
  34. supervisely/nn/training/train_app.py +61 -19
  35. supervisely/nn/utils.py +43 -3
  36. supervisely/task/progress.py +12 -2
  37. supervisely/video/video.py +107 -1
  38. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/METADATA +2 -1
  39. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/RECORD +43 -32
  40. supervisely/api/neural_network_api.py +0 -202
  41. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/LICENSE +0 -0
  42. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/WHEEL +0 -0
  43. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/entry_points.txt +0 -0
  44. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,12 @@
1
+ import inspect
1
2
  import time
2
3
  import uuid
3
- from queue import Queue
4
- from threading import Event, Thread
5
4
  from typing import BinaryIO, Dict, List, Tuple
6
5
 
7
6
  import numpy as np
8
7
  from pydantic import ValidationError
9
8
 
10
- from supervisely._utils import find_value_by_keys
9
+ from supervisely._utils import find_value_by_keys, get_valid_kwargs
11
10
  from supervisely.annotation.label import Geometry, Label
12
11
  from supervisely.annotation.obj_class import ObjClass
13
12
  from supervisely.api.api import Api
@@ -17,13 +16,14 @@ from supervisely.geometry.bitmap import Bitmap
17
16
  from supervisely.geometry.helpers import deserialize_geometry
18
17
  from supervisely.geometry.polygon import Polygon
19
18
  from supervisely.imaging import image as sly_image
19
+ from supervisely.nn.inference.inference_request import InferenceRequest
20
20
  from supervisely.nn.inference.tracking.base_tracking import BaseTracking
21
21
  from supervisely.nn.inference.tracking.tracker_interface import (
22
22
  TrackerInterface,
23
23
  TrackerInterfaceV2,
24
24
  )
25
+ from supervisely.nn.inference.uploader import Uploader
25
26
  from supervisely.sly_logger import logger
26
- from supervisely.task.progress import Progress
27
27
 
28
28
 
29
29
  class MaskTracking(BaseTracking):
@@ -92,19 +92,20 @@ class MaskTracking(BaseTracking):
92
92
  )
93
93
  return results
94
94
 
95
- def _track(self, api: Api, context: Dict):
96
- self.video_interface = TrackerInterface(
95
+ def _track(self, api: Api, context: Dict, inference_request: InferenceRequest):
96
+ video_interface = TrackerInterface(
97
97
  context=context,
98
98
  api=api,
99
- load_all_frames=True,
99
+ load_all_frames=False,
100
100
  notify_in_predict=True,
101
101
  per_point_polygon_tracking=False,
102
102
  frame_loader=self.cache.download_frame,
103
103
  frames_loader=self.cache.download_frames,
104
104
  )
105
+ video_interface.stop += video_interface.frames_count + 1
105
106
  range_of_frames = [
106
- self.video_interface.frames_indexes[0],
107
- self.video_interface.frames_indexes[-1],
107
+ video_interface.frames_indexes[0],
108
+ video_interface.frames_indexes[-1],
108
109
  ]
109
110
 
110
111
  if self.cache.is_persistent:
@@ -112,58 +113,201 @@ class MaskTracking(BaseTracking):
112
113
  self.cache.run_cache_task_manually(
113
114
  api,
114
115
  None,
115
- video_id=self.video_interface.video_id,
116
+ video_id=video_interface.video_id,
116
117
  )
117
118
  else:
118
119
  # if cache is not persistent, run cache task for range of frames
119
120
  self.cache.run_cache_task_manually(
120
121
  api,
121
122
  [range_of_frames],
122
- video_id=self.video_interface.video_id,
123
+ video_id=video_interface.video_id,
123
124
  )
124
125
 
125
- api.logger.info("Starting tracking process")
126
+ api.logger.debug("frames_count = %s", video_interface.frames_count)
127
+ inference_request.set_stage("Downloading frames", 0, video_interface.frames_count)
126
128
  # load frames
127
- frames = self.video_interface.frames
129
+
130
+ def _load_frames_cb(n: int = 1):
131
+ inference_request.done(n)
132
+ video_interface._notify(pos_increment=n, task="Downloading frames")
133
+
134
+ frames = self.cache.download_frames(
135
+ api,
136
+ video_interface.video_id,
137
+ video_interface.frames_indexes,
138
+ progress_cb=_load_frames_cb,
139
+ )
140
+
128
141
  # combine several binary masks into one multilabel mask
129
- i = 0
142
+ i = 1
130
143
  label2id = {}
144
+ multilabel_mask = np.zeros(frames[0].shape, dtype=np.uint8)
145
+ for (fig_id, geometry), obj_id in zip(
146
+ video_interface.geometries.items(),
147
+ video_interface.object_ids,
148
+ ):
149
+ original_geometry = geometry.clone()
150
+ if not isinstance(geometry, Bitmap) and not isinstance(geometry, Polygon):
151
+ raise TypeError(f"This app does not support {geometry.geometry_name()} tracking")
152
+ # convert polygon to bitmap
153
+ if isinstance(geometry, Polygon):
154
+ polygon_obj_class = ObjClass("polygon", Polygon)
155
+ polygon_label = Label(geometry, polygon_obj_class)
156
+ bitmap_obj_class = ObjClass("bitmap", Bitmap)
157
+ bitmap_label = polygon_label.convert(bitmap_obj_class)[0]
158
+ geometry = bitmap_label.geometry
159
+ geometry.draw(bitmap=multilabel_mask, color=i)
160
+ label2id[i] = {
161
+ "fig_id": fig_id,
162
+ "obj_id": obj_id,
163
+ "original_geometry": original_geometry.geometry_name(),
164
+ }
165
+ i += 1
131
166
 
132
- def _upload_loop(q: Queue, stop_event: Event, video_interface: TrackerInterface):
133
- try:
134
- while True:
135
- items = []
136
- while not q.empty():
137
- items.append(q.get_nowait())
138
- if len(items) > 0:
139
- video_interface.add_object_geometries_on_frames(*list(zip(*items)))
140
- continue
141
- if stop_event.is_set():
142
- video_interface._notify(True, task="stop tracking")
143
- return
144
- time.sleep(1)
145
- except Exception as e:
146
- api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
147
- video_interface._notify(True, task="stop tracking")
148
- video_interface.global_stop_indicatior = True
149
- raise
150
-
151
- upload_queue = Queue()
152
- stop_upload_event = Event()
153
- Thread(
154
- target=_upload_loop,
155
- args=[upload_queue, stop_upload_event, self.video_interface],
156
- daemon=True,
157
- ).start()
158
-
159
- try:
160
- for (fig_id, geometry), obj_id in zip(
161
- self.video_interface.geometries.items(),
162
- self.video_interface.object_ids,
167
+ unique_labels = np.unique(multilabel_mask)
168
+ if 0 in unique_labels:
169
+ unique_labels = unique_labels[1:]
170
+ api.logger.debug("unique_labels = %s", unique_labels)
171
+ total_progress = len(unique_labels) * video_interface.frames_count
172
+ api.logger.info("Starting tracking process")
173
+ api.logger.debug("total_progress = %s", total_progress)
174
+ inference_request.set_stage(
175
+ InferenceRequest.Stage.INFERENCE,
176
+ 0,
177
+ total_progress,
178
+ )
179
+
180
+ def _upload_f(items: List):
181
+ video_interface.add_object_geometries_on_frames(*list(zip(*items)))
182
+ inference_request.done(sum(item[-1] for item in items))
183
+
184
+ with Uploader(upload_f=_upload_f, logger=api.logger) as uploader:
185
+ # run tracker
186
+ tracked_multilabel_masks = self.predict(
187
+ frames=frames, input_mask=multilabel_mask[:, :, 0]
188
+ )
189
+ for curframe_i, mask in enumerate(
190
+ tracked_multilabel_masks, video_interface.frame_index
163
191
  ):
192
+ if curframe_i == video_interface.frame_index:
193
+ continue
194
+ for i in unique_labels:
195
+ binary_mask = mask == i
196
+ fig_id = label2id[i]["fig_id"]
197
+ obj_id = label2id[i]["obj_id"]
198
+ geometry_type = label2id[i]["original_geometry"]
199
+ if not np.any(binary_mask):
200
+ api.logger.info(f"Skipping empty mask on frame {curframe_i}")
201
+ inference_request.done()
202
+ else:
203
+ if geometry_type == "polygon":
204
+ bitmap_geometry = Bitmap(binary_mask)
205
+ bitmap_obj_class = ObjClass("bitmap", Bitmap)
206
+ bitmap_label = Label(bitmap_geometry, bitmap_obj_class)
207
+ polygon_obj_class = ObjClass("polygon", Polygon)
208
+ polygon_labels = bitmap_label.convert(polygon_obj_class)
209
+ geometries = [label.geometry for label in polygon_labels]
210
+ else:
211
+ geometries = [Bitmap(binary_mask)]
212
+ uploader.put(
213
+ [
214
+ (
215
+ geometry,
216
+ obj_id,
217
+ curframe_i,
218
+ True if g_idx == len(geometries) - 1 else False,
219
+ )
220
+ for g_idx, geometry in enumerate(geometries)
221
+ ]
222
+ )
223
+ if inference_request.is_stopped() or video_interface.global_stop_indicatior:
224
+ api.logger.info(
225
+ "Tracking stopped by user",
226
+ extra={"inference_request_uuid": inference_request.uuid},
227
+ )
228
+ video_interface._notify(True, task="Stop tracking")
229
+ return
230
+ if uploader.has_exception():
231
+ raise uploader.exception
232
+
233
+ api.logger.info(f"Frame {curframe_i} was successfully tracked")
234
+
235
+ def _track_async(self, api: Api, context: dict, inference_request: InferenceRequest):
236
+ tracker_interface = TrackerInterfaceV2(api, context, self.cache)
237
+ frames_count = tracker_interface.frames_count
238
+ figures = tracker_interface.figures
239
+ progress_total = frames_count * len(figures)
240
+ frame_range = [
241
+ tracker_interface.frame_indexes[0],
242
+ tracker_interface.frame_indexes[-1],
243
+ ]
244
+ frame_range_asc = [min(frame_range), max(frame_range)]
245
+
246
+ def _upload_f(items: List[Tuple[FigureInfo, bool]]):
247
+ inference_request.add_results([item[0] for item in items])
248
+ inference_request.done(sum(item[1] for item in items))
249
+
250
+ def _notify_f(items: List[Tuple[FigureInfo, bool]]):
251
+ frame_range = [
252
+ min(item[0].frame_index for item in items),
253
+ max(item[0].frame_index for item in items),
254
+ ]
255
+ tracker_interface.notify_progress(
256
+ inference_request.progress.current, inference_request.progress.total, frame_range
257
+ )
258
+
259
+ def _exception_handler(exception: Exception):
260
+ api.logger.error(f"Error saving predictions: {str(exception)}", exc_info=True)
261
+ tracker_interface.notify_progress(
262
+ inference_request.progress.current,
263
+ inference_request.progress.current,
264
+ frame_range_asc,
265
+ )
266
+ tracker_interface.notify_error(exception)
267
+ raise Exception
268
+
269
+ def _maybe_stop():
270
+ if inference_request.is_stopped() or tracker_interface.is_stopped():
271
+ if isinstance(tracker_interface.stop_reason(), Exception):
272
+ raise tracker_interface.stop_reason()
273
+ api.logger.info(
274
+ "Inference request stopped.",
275
+ extra={"inference_request_uuid": inference_request.uuid},
276
+ )
277
+ tracker_interface.notify_progress(
278
+ inference_request.progress.current,
279
+ inference_request.progress.current,
280
+ frame_range_asc,
281
+ )
282
+ return True
283
+ if uploader.has_exception():
284
+ raise uploader.exception
285
+ return False
286
+
287
+ # run tracker
288
+ frame_index = tracker_interface.frame_index
289
+ direction_n = tracker_interface.direction_n
290
+ api.logger.info("Start tracking.")
291
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
292
+ with Uploader(
293
+ upload_f=_upload_f,
294
+ notify_f=_notify_f,
295
+ exception_handler=_exception_handler,
296
+ logger=api.logger,
297
+ ) as uploader:
298
+ # combine several binary masks into one multilabel mask
299
+ i = 0
300
+ label2id = {}
301
+ # load frames
302
+ frames = tracker_interface.load_all_frames()
303
+ frames = [frame.image for frame in frames]
304
+ for figure in figures:
305
+ figure = api.video.figure._convert_json_info(figure)
306
+ fig_id = figure.id
307
+ obj_id = figure.object_id
308
+ geometry = deserialize_geometry(figure.geometry_type, figure.geometry)
164
309
  original_geometry = geometry.clone()
165
- if not isinstance(geometry, Bitmap) and not isinstance(geometry, Polygon):
166
- stop_upload_event.set()
310
+ if not isinstance(geometry, (Bitmap, Polygon)):
167
311
  raise TypeError(
168
312
  f"This app does not support {geometry.geometry_name()} tracking"
169
313
  )
@@ -187,26 +331,34 @@ class MaskTracking(BaseTracking):
187
331
  "obj_id": obj_id,
188
332
  "original_geometry": original_geometry.geometry_name(),
189
333
  }
190
- # run tracker
334
+
335
+ if _maybe_stop():
336
+ return
337
+
338
+ # predict
191
339
  tracked_multilabel_masks = self.predict(
192
340
  frames=frames, input_mask=multilabel_mask[:, :, 0]
193
341
  )
194
342
  tracked_multilabel_masks = np.array(tracked_multilabel_masks)
343
+
195
344
  # decompose multilabel masks into binary masks
196
345
  for i in np.unique(tracked_multilabel_masks):
346
+ if _maybe_stop():
347
+ return
197
348
  if i != 0:
198
349
  binary_masks = tracked_multilabel_masks == i
199
350
  fig_id = label2id[i]["fig_id"]
200
351
  obj_id = label2id[i]["obj_id"]
201
352
  geometry_type = label2id[i]["original_geometry"]
202
- for j, mask in enumerate(binary_masks[1:]):
353
+ for j, mask in enumerate(binary_masks[1:], 1):
354
+ if _maybe_stop():
355
+ return
356
+ this_figure_index = frame_index + j * direction_n
203
357
  # check if mask is not empty
204
358
  if not np.any(mask):
205
- api.logger.info(
206
- f"Skipping empty mask on frame {self.video_interface.frame_index + j + 1}"
207
- )
359
+ api.logger.info(f"Skipping empty mask on frame {this_figure_index}")
208
360
  # update progress bar anyway (otherwise it will not be finished)
209
- self.video_interface._notify(task="add geometry on frame")
361
+ inference_request.done()
210
362
  else:
211
363
  if geometry_type == "polygon":
212
364
  bitmap_geometry = Bitmap(mask)
@@ -218,184 +370,29 @@ class MaskTracking(BaseTracking):
218
370
  else:
219
371
  geometries = [Bitmap(mask)]
220
372
  for l, geometry in enumerate(geometries):
221
- if l == len(geometries) - 1:
222
- notify = True
223
- else:
224
- notify = False
225
- upload_queue.put(
226
- (
227
- geometry,
228
- obj_id,
229
- self.video_interface.frames_indexes[j + 1],
230
- notify,
231
- )
373
+ figure_id = uuid.uuid5(
374
+ namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
375
+ ).hex
376
+ result_figure = api.video.figure._convert_json_info(
377
+ {
378
+ ApiField.ID: figure_id,
379
+ ApiField.OBJECT_ID: obj_id,
380
+ "meta": {"frame": this_figure_index},
381
+ ApiField.GEOMETRY_TYPE: geometry.geometry_name(),
382
+ ApiField.GEOMETRY: geometry.to_json(),
383
+ ApiField.TRACK_ID: tracker_interface.track_id,
384
+ }
232
385
  )
233
- if self.video_interface.global_stop_indicatior:
234
- stop_upload_event.set()
235
- return
236
- api.logger.info(f"Figure with id {fig_id} was successfully tracked")
237
- except Exception:
238
- stop_upload_event.set()
239
- raise
240
- stop_upload_event.set()
241
-
242
- def _track_async(self, api: Api, context: dict, inference_request_uuid: str = None):
243
- inference_request = self._inference_requests[inference_request_uuid]
244
- tracker_interface = TrackerInterfaceV2(api, context, self.cache)
245
- progress: Progress = inference_request["progress"]
246
- frames_count = tracker_interface.frames_count
247
- figures = tracker_interface.figures
248
- progress_total = frames_count * len(figures)
249
- progress.total = progress_total
250
-
251
- def _upload_f(items: List[Tuple[FigureInfo, bool]]):
252
- with inference_request["lock"]:
253
- inference_request["pending_results"].extend([item[0] for item in items])
254
-
255
- def _notify_f(items: List[Tuple[FigureInfo, bool]]):
256
- items_by_object_id: Dict[int, List[Tuple[FigureInfo, bool]]] = {}
257
- for item in items:
258
- items_by_object_id.setdefault(item[0].object_id, []).append(item)
259
-
260
- for object_id, object_items in items_by_object_id.items():
261
- frame_range = [
262
- min(item[0].frame_index for item in object_items),
263
- max(item[0].frame_index for item in object_items),
264
- ]
265
- progress.iters_done_report(sum(1 for item in object_items if item[1]))
266
- tracker_interface.notify_progress(progress.current, progress.total, frame_range)
267
-
268
- # run tracker
269
- frame_index = tracker_interface.frame_index
270
- direction_n = tracker_interface.direction_n
271
- api.logger.info("Start tracking.")
272
- try:
273
- with tracker_interface(_upload_f, _notify_f):
274
- # combine several binary masks into one multilabel mask
275
- i = 0
276
- label2id = {}
277
- # load frames
278
- frames = tracker_interface.load_all_frames()
279
- frames = [frame.image for frame in frames]
280
- for figure in figures:
281
- figure = api.video.figure._convert_json_info(figure)
282
- fig_id = figure.id
283
- obj_id = figure.object_id
284
- geometry = deserialize_geometry(figure.geometry_type, figure.geometry)
285
- original_geometry = geometry.clone()
286
- if not isinstance(geometry, (Bitmap, Polygon)):
287
- raise TypeError(
288
- f"This app does not support {geometry.geometry_name()} tracking"
289
- )
290
- # convert polygon to bitmap
291
- if isinstance(geometry, Polygon):
292
- polygon_obj_class = ObjClass("polygon", Polygon)
293
- polygon_label = Label(geometry, polygon_obj_class)
294
- bitmap_obj_class = ObjClass("bitmap", Bitmap)
295
- bitmap_label = polygon_label.convert(bitmap_obj_class)[0]
296
- geometry = bitmap_label.geometry
297
- if i == 0:
298
- multilabel_mask = geometry.data.astype(int)
299
- multilabel_mask = np.zeros(frames[0].shape, dtype=np.uint8)
300
- geometry.draw(bitmap=multilabel_mask, color=[1, 1, 1])
301
- i += 1
302
- else:
303
- i += 1
304
- geometry.draw(bitmap=multilabel_mask, color=[i, i, i])
305
- label2id[i] = {
306
- "fig_id": fig_id,
307
- "obj_id": obj_id,
308
- "original_geometry": original_geometry.geometry_name(),
309
- }
310
- if inference_request["cancel_inference"]:
311
- return
312
- if tracker_interface.is_stopped():
313
- reason = tracker_interface.stop_reason()
314
- if isinstance(reason, Exception):
315
- raise reason
316
- return
317
-
318
- # predict
319
- tracked_multilabel_masks = self.predict(
320
- frames=frames, input_mask=multilabel_mask[:, :, 0]
321
- )
322
- tracked_multilabel_masks = np.array(tracked_multilabel_masks)
323
-
324
- # decompose multilabel masks into binary masks
325
- for i in np.unique(tracked_multilabel_masks):
326
- if inference_request["cancel_inference"]:
327
- return
328
- if tracker_interface.is_stopped():
329
- reason = tracker_interface.stop_reason()
330
- if isinstance(reason, Exception):
331
- raise reason
332
- return
333
- if i != 0:
334
- binary_masks = tracked_multilabel_masks == i
335
- fig_id = label2id[i]["fig_id"]
336
- obj_id = label2id[i]["obj_id"]
337
- geometry_type = label2id[i]["original_geometry"]
338
- for j, mask in enumerate(binary_masks[1:], 1):
339
- if inference_request["cancel_inference"]:
340
- return
341
- if tracker_interface.is_stopped():
342
- reason = tracker_interface.stop_reason()
343
- if isinstance(reason, Exception):
344
- raise reason
345
- return
346
- this_figure_index = frame_index + j * direction_n
347
- # check if mask is not empty
348
- if not np.any(mask):
349
- api.logger.info(f"Skipping empty mask on frame {this_figure_index}")
350
- # update progress bar anyway (otherwise it will not be finished)
351
- progress.iter_done_report()
352
- else:
353
- if geometry_type == "polygon":
354
- bitmap_geometry = Bitmap(mask)
355
- bitmap_obj_class = ObjClass("bitmap", Bitmap)
356
- bitmap_label = Label(bitmap_geometry, bitmap_obj_class)
357
- polygon_obj_class = ObjClass("polygon", Polygon)
358
- polygon_labels = bitmap_label.convert(polygon_obj_class)
359
- geometries = [label.geometry for label in polygon_labels]
360
- else:
361
- geometries = [Bitmap(mask)]
362
- for l, geometry in enumerate(geometries):
363
- figure_id = uuid.uuid5(
364
- namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
365
- ).hex
366
- result_figure = api.video.figure._convert_json_info(
367
- {
368
- ApiField.ID: figure_id,
369
- ApiField.OBJECT_ID: obj_id,
370
- "meta": {"frame": this_figure_index},
371
- ApiField.GEOMETRY_TYPE: geometry.geometry_name(),
372
- ApiField.GEOMETRY: geometry.to_json(),
373
- ApiField.TRACK_ID: tracker_interface.track_id,
374
- }
375
- )
376
- should_notify = l == len(geometries) - 1
377
- tracker_interface.add_prediction((result_figure, should_notify))
378
- api.logger.info(
379
- "Figure [%d, %d] tracked.",
380
- i,
381
- len(figures),
382
- extra={"figure_id": figure.id},
383
- )
384
- except Exception:
385
- progress.message = "Error occured during tracking"
386
- raise
387
- else:
388
- progress.message = "Ready"
389
- finally:
390
- progress.set(current=0, total=1, report=True)
391
-
392
- # Implement the following methods in the derived class
393
- def track(self, api: Api, state: Dict, context: Dict):
394
- fn = self.send_error_data(api, context)(self._track)
395
- self.schedule_task(fn, api, context)
396
- return {"message": "Tracking has started."}
386
+ should_notify = l == len(geometries) - 1
387
+ tracker_interface.add_prediction((result_figure, should_notify))
388
+ api.logger.info(
389
+ "Figure [%d, %d] tracked.",
390
+ i,
391
+ len(figures),
392
+ extra={"figure_id": figure.id},
393
+ )
397
394
 
398
- def track_api(self, api: Api, state: Dict, context: Dict):
395
+ def _track_api(self, api: Api, context: Dict, inference_request: InferenceRequest):
399
396
  # unused fields:
400
397
  context["trackId"] = "auto"
401
398
  context["objectIds"] = []
@@ -405,7 +402,7 @@ class MaskTracking(BaseTracking):
405
402
 
406
403
  input_geometries: list = context["input_geometries"]
407
404
 
408
- self.video_interface = TrackerInterface(
405
+ video_interface = TrackerInterface(
409
406
  context=context,
410
407
  api=api,
411
408
  load_all_frames=True,
@@ -417,8 +414,8 @@ class MaskTracking(BaseTracking):
417
414
  )
418
415
 
419
416
  range_of_frames = [
420
- self.video_interface.frames_indexes[0],
421
- self.video_interface.frames_indexes[-1],
417
+ video_interface.frames_indexes[0],
418
+ video_interface.frames_indexes[-1],
422
419
  ]
423
420
 
424
421
  if self.cache.is_persistent:
@@ -426,24 +423,29 @@ class MaskTracking(BaseTracking):
426
423
  self.cache.run_cache_task_manually(
427
424
  api,
428
425
  None,
429
- video_id=self.video_interface.video_id,
426
+ video_id=video_interface.video_id,
430
427
  )
431
428
  else:
432
429
  # if cache is not persistent, run cache task for range of frames
433
430
  self.cache.run_cache_task_manually(
434
431
  api,
435
432
  [range_of_frames],
436
- video_id=self.video_interface.video_id,
433
+ video_id=video_interface.video_id,
437
434
  )
438
435
 
439
- api.logger.info("Starting tracking process")
436
+ inference_request.set_stage("Downloading frames", 0, video_interface.frames_count)
440
437
  # load frames
441
- frames = self.video_interface.frames
438
+ frames = self.cache.download_frames(
439
+ api,
440
+ video_interface.video_id,
441
+ video_interface.frames_indexes,
442
+ progress_cb=inference_request.done,
443
+ )
442
444
  # combine several binary masks into one multilabel mask
443
- i = 0
444
445
  label2id = {}
445
446
 
446
- for input_geom in input_geometries:
447
+ multilabel_mask = np.zeros(frames[0].shape, dtype=np.uint8)
448
+ for i, input_geom in enumerate(input_geometries, 1):
447
449
  geometry = self._deserialize_geometry(input_geom)
448
450
  if not isinstance(geometry, Bitmap) and not isinstance(geometry, Polygon):
449
451
  raise TypeError(f"This app does not support {geometry.geometry_name()} tracking")
@@ -454,18 +456,24 @@ class MaskTracking(BaseTracking):
454
456
  bitmap_obj_class = ObjClass("bitmap", Bitmap)
455
457
  bitmap_label = polygon_label.convert(bitmap_obj_class)[0]
456
458
  geometry = bitmap_label.geometry
457
- if i == 0:
458
- multilabel_mask = geometry.data.astype(int)
459
- multilabel_mask = np.zeros(frames[0].shape, dtype=np.uint8)
460
- geometry.draw(bitmap=multilabel_mask, color=[1, 1, 1])
461
- i += 1
462
- else:
463
- i += 1
464
- geometry.draw(bitmap=multilabel_mask, color=[i, i, i])
459
+ geometry.draw(bitmap=multilabel_mask, color=i)
465
460
  label2id[i] = {
466
461
  "original_geometry": geometry.geometry_name(),
467
462
  }
468
463
 
464
+ result_indexes = np.unique(multilabel_mask)
465
+ progress_total = len(result_indexes)
466
+ if 0 in result_indexes:
467
+ progress_total -= 1
468
+ progress_total = progress_total * video_interface.frames_count
469
+
470
+ api.logger.info("Starting tracking process")
471
+ inference_request.set_stage(
472
+ InferenceRequest.Stage.INFERENCE,
473
+ 0,
474
+ progress_total,
475
+ )
476
+
469
477
  # run tracker
470
478
  tracked_multilabel_masks = self.predict(frames=frames, input_mask=multilabel_mask[:, :, 0])
471
479
  tracked_multilabel_masks = np.array(tracked_multilabel_masks)
@@ -492,8 +500,25 @@ class MaskTracking(BaseTracking):
492
500
 
493
501
  # predictions must be NxK masks: N=number of frames, K=number of objects
494
502
  predictions = list(map(list, zip(*predictions)))
503
+ inference_request.final_result = predictions
495
504
  return predictions
496
505
 
506
+ # Implement the following methods in the derived class
507
+ def track(self, api: Api, state: Dict, context: Dict):
508
+ fn = self.send_error_data(api, context)(self._track)
509
+ self.inference_requests_manager.schedule_task(fn, api, context)
510
+ return {"message": "Tracking has started."}
511
+
512
+ def track_api(self, api: Api, state: Dict, context: Dict):
513
+ inference_request, future = self.inference_requests_manager.schedule_task(
514
+ self._track_api, api, context
515
+ )
516
+ future.result()
517
+ logger.info(
518
+ "Track-api request processed.", extra={"inference_request_uuid": inference_request.uuid}
519
+ )
520
+ return inference_request.final_result
521
+
497
522
  def track_api_files(
498
523
  self,
499
524
  files: List[BinaryIO],
@@ -524,15 +549,14 @@ class MaskTracking(BaseTracking):
524
549
  f"Batch size should be less than or equal to {self.max_batch_size} for this model."
525
550
  )
526
551
 
527
- inference_request_uuid = uuid.uuid5(namespace=uuid.NAMESPACE_URL, name=f"{time.time()}").hex
528
552
  fn = self.send_error_data(api, context)(self._track_async)
529
- self.schedule_task(fn, api, context, inference_request_uuid=inference_request_uuid)
553
+ inference_request, _ = self.inference_requests_manager.schedule_task(fn, api, context)
530
554
 
531
555
  logger.debug(
532
556
  "Inference has scheduled from 'track_async' endpoint",
533
- extra={"inference_request_uuid": inference_request_uuid},
557
+ extra={"inference_request_uuid": inference_request.uuid},
534
558
  )
535
559
  return {
536
560
  "message": "Inference has started.",
537
- "inference_request_uuid": inference_request_uuid,
561
+ "inference_request_uuid": inference_request.uuid,
538
562
  }