supervisely 6.73.284__py3-none-any.whl → 6.73.285__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of supervisely might be problematic.

@@ -1,67 +1,298 @@
-import functools
-import json
 import time
+import uuid
 from pathlib import Path
 from queue import Queue
 from threading import Event, Thread
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, BinaryIO, Dict, List, Union
 
 import numpy as np
-from fastapi import BackgroundTasks, Form, Request, UploadFile
+from pydantic import ValidationError
 
-import supervisely as sly
 import supervisely.nn.inference.tracking.functional as F
+from supervisely.annotation.annotation import Annotation
 from supervisely.annotation.label import Geometry, Label
-from supervisely.nn.inference import Inference
-from supervisely.nn.inference.tracking.tracker_interface import TrackerInterface
+from supervisely.annotation.obj_class import ObjClass
+from supervisely.api.api import Api
+from supervisely.api.module_api import ApiField
+from supervisely.api.video.video_figure_api import FigureInfo
+from supervisely.geometry.graph import GraphNodes
+from supervisely.geometry.helpers import deserialize_geometry
+from supervisely.geometry.point import Point
+from supervisely.geometry.polygon import Polygon
+from supervisely.geometry.polyline import Polyline
+from supervisely.geometry.rectangle import Rectangle
+from supervisely.imaging import image as sly_image
+from supervisely.nn.inference.tracking.base_tracking import BaseTracking
+from supervisely.nn.inference.tracking.tracker_interface import (
+    TrackerInterface,
+    TrackerInterfaceV2,
+)
 from supervisely.nn.prediction_dto import Prediction, PredictionPoint
+from supervisely.sly_logger import logger
+from supervisely.task.progress import Progress
 
 
-class PointTracking(Inference):
-    def __init__(
-        self,
-        model_dir: Optional[str] = None,
-        custom_inference_settings: Optional[Union[Dict[str, Any], str]] = None,
-    ):
-        Inference.__init__(
-            self,
-            model_dir,
-            custom_inference_settings,
-            sliding_window_mode=None,
-            use_gui=False,
+class PointTracking(BaseTracking):
+    def _deserialize_geometry(self, data: dict):
+        geometry_type_str = data["type"]
+        geometry_json = data["data"]
+        return deserialize_geometry(geometry_type_str, geometry_json)
+
+    def _inference(self, frames: List[np.ndarray], geometries: List[Geometry], settings: dict):
+        updated_settings = {
+            **self.custom_inference_settings_dict,
+            **settings,
+        }
+        results = [[] for _ in range(len(frames) - 1)]
+        for geometry in geometries:
+            if isinstance(geometry, Point):
+                predictions = self._predict_point_geometries(geometry, frames, updated_settings)
+            elif isinstance(geometry, Polygon):
+                if len(geometry.interior) > 0:
+                    raise ValueError("Can't track polygons with interior.")
+                predictions = self._predict_polygon_geometries(
+                    geometry,
+                    frames,
+                    updated_settings,
+                )
+            elif isinstance(geometry, GraphNodes):
+                predictions = self._predict_graph_geometries(
+                    geometry,
+                    frames,
+                    updated_settings,
+                )
+            elif isinstance(geometry, Polyline):
+                predictions = self._predict_polyline_geometries(
+                    geometry,
+                    frames,
+                    updated_settings,
+                )
+            else:
+                raise TypeError(f"Tracking does not work with {geometry.geometry_name()}.")
+
+            for i, prediction in enumerate(predictions):
+                results[i].append({"type": geometry.geometry_name(), "data": prediction.to_json()})
+
+        return results
+
+    def _track(self, api: Api, context: dict):
+        if self.custom_inference_settings_dict.get("load_all_frames"):
+            load_all_frames = True
+        else:
+            load_all_frames = False
+        video_interface = TrackerInterface(
+            context=context,
+            api=api,
+            load_all_frames=load_all_frames,
+            frame_loader=self.cache.download_frame,
+            frames_loader=self.cache.download_frames,
         )
 
+        range_of_frames = [
+            video_interface.frames_indexes[0],
+            video_interface.frames_indexes[-1],
+        ]
+
+        if self.cache.is_persistent:
+            # if cache is persistent, run cache task for whole video
+            self.cache.run_cache_task_manually(
+                api,
+                None,
+                video_id=video_interface.video_id,
+            )
+        else:
+            # if cache is not persistent, run cache task for range of frames
+            self.cache.run_cache_task_manually(
+                api,
+                [range_of_frames],
+                video_id=video_interface.video_id,
+            )
+
+        api.logger.info("Start tracking.")
+
+        def _upload_loop(q: Queue, stop_event: Event, video_interface: TrackerInterface):
+            try:
+                while True:
+                    items = []
+                    while not q.empty():
+                        items.append(q.get_nowait())
+                    if len(items) > 0:
+                        video_interface.add_object_geometries_on_frames(*list(zip(*items)))
+                        continue
+                    if stop_event.is_set():
+                        video_interface._notify(True, task="stop tracking")
+                        return
+                    time.sleep(1)
+            except Exception as e:
+                api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
+                video_interface._notify(True, task="stop tracking")
+                video_interface.global_stop_indicatior = True
+                raise
+
+        upload_queue = Queue()
+        stop_upload_event = Event()
+        Thread(
+            target=_upload_loop,
+            args=[upload_queue, stop_upload_event, video_interface],
+            daemon=True,
+        ).start()
         try:
-            self.load_on_device(model_dir, "cuda")
-        except RuntimeError:
-            self.load_on_device(model_dir, "cpu")
-            sly.logger.warn("Failed to load model on CUDA device.")
-
-        sly.logger.debug(
-            "Smart cache params",
-            extra={
-                "ttl": sly.env.smart_cache_ttl(),
-                "maxsize": sly.env.smart_cache_size(),
-                "path": sly.env.smart_cache_container_dir(),
-            },
-        )
+            for _ in video_interface.frames_loader_generator():
+                for (fig_id, geom), obj_id in zip(
+                    video_interface.geometries.items(),
+                    video_interface.object_ids,
+                ):
+                    if isinstance(geom, Point):
+                        geometries = self._predict_point_geometries(
+                            geom,
+                            video_interface.frames_with_notification,
+                        )
+                    elif isinstance(geom, Polygon):
+                        if len(geom.interior) > 0:
+                            stop_upload_event.set()
+                            raise ValueError("Can't track polygons with interior.")
+                        geometries = self._predict_polygon_geometries(
+                            geom,
+                            video_interface.frames_with_notification,
+                        )
+                    elif isinstance(geom, GraphNodes):
+                        geometries = self._predict_graph_geometries(
+                            geom,
+                            video_interface.frames_with_notification,
+                        )
+                    elif isinstance(geom, Polyline):
+                        geometries = self._predict_polyline_geometries(
+                            geom,
+                            video_interface.frames_with_notification,
+                        )
+                    else:
+                        raise TypeError(f"Tracking does not work with {geom.geometry_name()}.")
+
+                    for frame_idx, geometry in zip(
+                        video_interface._cur_frames_indexes[1:], geometries
+                    ):
+                        upload_queue.put(
+                            (
+                                geometry,
+                                obj_id,
+                                frame_idx,
+                            )
+                        )
+                    api.logger.info(f"Object #{obj_id} tracked.")
 
-    def _deserialize_geometry(self, data: dict):
-        geometry_type_str = data["type"]
-        geometry_json = data["data"]
-        return sly.deserialize_geometry(geometry_type_str, geometry_json)
-
-    def get_info(self):
-        info = super().get_info()
-        info["task type"] = "tracking"
-        # recommended parameters:
-        # info["model_name"] = ""
-        # info["checkpoint_name"] = ""
-        # info["pretrained_on_dataset"] = ""
-        # info["device"] = ""
-        return info
-
-    def track_api(self, api: sly.Api, context: dict):
+            if video_interface.global_stop_indicatior:
+                stop_upload_event.set()
+                return
+        except Exception:
+            stop_upload_event.set()
+            raise
+        stop_upload_event.set()
+
+    def _track_async(self, api: Api, context: dict, inference_request_uuid: str):
+        inference_request = self._inference_requests[inference_request_uuid]
+        tracker_interface = TrackerInterfaceV2(api, context, self.cache)
+        progress: Progress = inference_request["progress"]
+        frames_count = tracker_interface.frames_count
+        figures = tracker_interface.figures
+        progress_total = frames_count * len(figures)
+        progress.total = progress_total
+
+        def _upload_f(items: List[FigureInfo]):
+            with inference_request["lock"]:
+                inference_request["pending_results"].extend(items)
+
+        def _notify_f(items: List[FigureInfo]):
+            items_by_object_id: Dict[int, List[FigureInfo]] = {}
+            for item in items:
+                items_by_object_id.setdefault(item.object_id, []).append(item)
+
+            for object_id, object_items in items_by_object_id.items():
+                frame_range = [
+                    min(item.frame_index for item in object_items),
+                    max(item.frame_index for item in object_items),
+                ]
+                progress.iters_done_report(len(object_items))
+                tracker_interface.notify_progress(progress.current, progress.total, frame_range)
+
+        frame_index = tracker_interface.frame_index
+        direction_n = tracker_interface.direction_n
+        api.logger.info("Start tracking.")
+        try:
+            with tracker_interface(_upload_f, _notify_f):
+                frames = tracker_interface.load_all_frames()
+                frames = [frame.image for frame in frames]
+                for figure in figures:
+                    figure = api.video.figure._convert_json_info(figure)
+                    api.logger.info("geometry:", extra={"figure": figure._asdict()})
+                    sly_geometry: Rectangle = deserialize_geometry(
+                        figure.geometry_type, figure.geometry
+                    )
+                    api.logger.info("geometry:", extra={"geometry": type(sly_geometry)})
+                    if isinstance(sly_geometry, Point):
+                        geometries = self._predict_point_geometries(
+                            sly_geometry,
+                            frames,
+                        )
+                    elif isinstance(sly_geometry, Polygon):
+                        if len(sly_geometry.interior) > 0:
+                            raise ValueError("Can't track polygons with interior.")
+                        geometries = self._predict_polygon_geometries(
+                            sly_geometry,
+                            frames,
+                        )
+                    elif isinstance(sly_geometry, GraphNodes):
+                        geometries = self._predict_graph_geometries(
+                            sly_geometry,
+                            frames,
+                        )
+                    elif isinstance(sly_geometry, Polyline):
+                        geometries = self._predict_polyline_geometries(
+                            sly_geometry,
+                            frames,
+                        )
+                    else:
+                        raise TypeError(
+                            f"Tracking does not work with {sly_geometry.geometry_name()}."
+                        )
+
+                    for i, geometry in enumerate(geometries, 1):
+                        figure_id = uuid.uuid5(
+                            namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
+                        ).hex
+                        result_figure = api.video.figure._convert_json_info(
+                            {
+                                ApiField.ID: figure_id,
+                                ApiField.OBJECT_ID: figure.object_id,
+                                "meta": {"frame": frame_index + i * direction_n},
+                                ApiField.GEOMETRY_TYPE: geometry.geometry_name(),
+                                ApiField.GEOMETRY: geometry.to_json(),
+                                ApiField.TRACK_ID: tracker_interface.track_id,
+                            }
+                        )
+                        tracker_interface.add_prediction(result_figure)
+                    api.logger.info(f"Figure #{figure.id} tracked.")
+
+                    if inference_request["cancel_inference"]:
+                        return
+                    if tracker_interface.is_stopped():
+                        reason = tracker_interface.stop_reason()
+                        if isinstance(reason, Exception):
+                            raise reason
+                        return
+        except Exception as e:
+            progress.message = "Error occured during tracking"
+            raise
+        else:
+            progress.message = "Ready"
+        finally:
+            progress.set(current=0, total=1, report=True)
+
+    def track(self, api: Api, state: Dict, context: Dict):
+        fn = self.send_error_data(api, context)(self._track)
+        self.schedule_task(fn, api, context)
+        return {"message": "Track task started."}
+
+    def track_api(self, api: Api, state: Dict, context: Dict):
         # unused fields:
         context["trackId"] = "auto"
         context["objectIds"] = []
@@ -110,24 +341,24 @@ class PointTracking(Inference):
         for _ in video_interface.frames_loader_generator():
             for input_geom in input_geometries:
                 geom = self._deserialize_geometry(input_geom)
-                if isinstance(geom, sly.Point):
+                if isinstance(geom, Point):
                     geometries = self._predict_point_geometries(
                         geom,
                         video_interface.frames,
                     )
-                elif isinstance(geom, sly.Polygon):
+                elif isinstance(geom, Polygon):
                     if len(geom.interior) > 0:
                         raise ValueError("Can't track polygons with interior.")
                     geometries = self._predict_polygon_geometries(
                         geom,
                         video_interface.frames,
                     )
-                elif isinstance(geom, sly.GraphNodes):
+                elif isinstance(geom, GraphNodes):
                     geometries = self._predict_graph_geometries(
                         geom,
                         video_interface.frames,
                     )
-                elif isinstance(geom, sly.Polyline):
+                elif isinstance(geom, Polyline):
                     geometries = self._predict_polyline_geometries(
                         geom,
                         video_interface.frames,
@@ -145,237 +376,43 @@ class PointTracking(Inference):
         predictions = list(map(list, zip(*predictions)))
         return predictions
 
-    def _inference(self, frames: List[np.ndarray], geometries: List[Geometry], settings: dict):
-        updated_settings = {
-            **self.custom_inference_settings_dict,
-            **settings,
-        }
-        results = [[] for _ in range(len(frames) - 1)]
-        for geometry in geometries:
-            if isinstance(geometry, sly.Point):
-                predictions = self._predict_point_geometries(geometry, frames, updated_settings)
-            elif isinstance(geometry, sly.Polygon):
-                if len(geometry.interior) > 0:
-                    raise ValueError("Can't track polygons with interior.")
-                predictions = self._predict_polygon_geometries(
-                    geometry,
-                    frames,
-                    updated_settings,
-                )
-            elif isinstance(geometry, sly.GraphNodes):
-                predictions = self._predict_graph_geometries(
-                    geometry,
-                    frames,
-                    updated_settings,
-                )
-            elif isinstance(geometry, sly.Polyline):
-                predictions = self._predict_polyline_geometries(
-                    geometry,
-                    frames,
-                    updated_settings,
-                )
-            else:
-                raise TypeError(f"Tracking does not work with {geometry.geometry_name()}.")
-
-            for i, prediction in enumerate(predictions):
-                results[i].append({"type": geometry.geometry_name(), "data": prediction.to_json()})
-
-        return results
-
-    def track_api_cached(self, request: Request, context: dict):
-        sly.logger.info(f"Start tracking with settings: {context}.")
-        video_id = context["video_id"]
-        frame_indexes = list(
-            range(context["frame_index"], context["frame_index"] + context["frames"] + 1)
-        )
-        geometries = map(self._deserialize_geometry, context["input_geometries"])
-        frames = self.cache.get_frames_from_cache(video_id, frame_indexes)
-        return self._inference(frames, geometries, context)
-
-    def _track_api_files(
-        self, request: Request, files: List[UploadFile], settings: str = Form("{}")
+    def track_api_files(
+        self,
+        files: List[BinaryIO],
+        settings: Dict,
     ):
-        state = json.loads(settings)
-        sly.logger.info(f"Start tracking with settings: {state}.")
-        video_id = state["video_id"]
+        logger.info(f"Start tracking with settings:", extra={"settings": settings})
         frame_indexes = list(
-            range(state["frame_index"], state["frame_index"] + state["frames"] + 1)
+            range(settings["frame_index"], settings["frame_index"] + settings["frames"] + 1)
        )
-        geometries = map(self._deserialize_geometry, state["input_geometries"])
+        geometries = map(self._deserialize_geometry, settings["input_geometries"])
         frames = []
         for file, frame_idx in zip(files, frame_indexes):
-            img_bytes = file.file.read()
-            frame = sly.image.read_bytes(img_bytes)
+            img_bytes = file.read()
+            frame = sly_image.read_bytes(img_bytes)
             frames.append(frame)
-        sly.logger.info("Start tracking.")
-        return self._inference(frames, geometries, state)
-
-    def serve(self):
-        super().serve()
-        server = self._app.get_server()
-        self.cache.add_cache_endpoint(server)
-        self.cache.add_cache_files_endpoint(server)
-
-        @server.post("/track")
-        def start_track(request: Request, task: BackgroundTasks):
-            task.add_task(track, request)
-            return {"message": "Track task started."}
-
-        @server.post("/track-api")
-        def track_api(request: Request):
-            return self.track_api(request.state.api, request.state.context)
-
-        @server.post("/track-api-files")
-        def track_api_frames_files(
-            request: Request,
-            files: List[UploadFile],
-            settings: str = Form("{}"),
-        ):
-            return self._track_api_files(request, files, settings)
-
-        def send_error_data(func):
-            @functools.wraps(func)
-            def wrapper(*args, **kwargs):
-                value = None
-                try:
-                    value = func(*args, **kwargs)
-                except Exception as exc:
-                    request: Request = args[0]
-                    context = request.state.context
-                    api: sly.Api = request.state.api
-                    track_id = context["trackId"]
-                    api.logger.error(f"An error occured: {repr(exc)}")
-
-                    api.post(
-                        "videos.notify-annotation-tool",
-                        data={
-                            "type": "videos:tracking-error",
-                            "data": {
-                                "trackId": track_id,
-                                "error": {"message": repr(exc)},
-                            },
-                        },
-                    )
-                return value
-
-            return wrapper
-
-        @send_error_data
-        def track(request: Request):
-            context = request.state.context
-            api: sly.Api = request.state.api
-
-            if self.custom_inference_settings_dict.get("load_all_frames"):
-                load_all_frames = True
-            else:
-                load_all_frames = False
-            video_interface = TrackerInterface(
-                context=context,
-                api=api,
-                load_all_frames=load_all_frames,
-                frame_loader=self.cache.download_frame,
-                frames_loader=self.cache.download_frames,
+        logger.info("Start tracking.")
+        return self._inference(frames, geometries, settings)
+
+    def track_async(self, api: Api, state: Dict, context: Dict):
+        batch_size = context.get("batch_size", self.get_batch_size())
+        if self.max_batch_size is not None and batch_size > self.max_batch_size:
+            raise ValidationError(
+                f"Batch size should be less than or equal to {self.max_batch_size} for this model."
             )
 
-            range_of_frames = [
-                video_interface.frames_indexes[0],
-                video_interface.frames_indexes[-1],
-            ]
-
-            if self.cache.is_persistent:
-                # if cache is persistent, run cache task for whole video
-                self.cache.run_cache_task_manually(
-                    api,
-                    None,
-                    video_id=video_interface.video_id,
-                )
-            else:
-                # if cache is not persistent, run cache task for range of frames
-                self.cache.run_cache_task_manually(
-                    api,
-                    [range_of_frames],
-                    video_id=video_interface.video_id,
-                )
-
-            api.logger.info("Start tracking.")
-
-            def _upload_loop(q: Queue, stop_event: Event, video_interface: TrackerInterface):
-                try:
-                    while True:
-                        items = []
-                        while not q.empty():
-                            items.append(q.get_nowait())
-                        if len(items) > 0:
-                            video_interface.add_object_geometries_on_frames(*list(zip(*items)))
-                            continue
-                        if stop_event.is_set():
-                            video_interface._notify(True, task="stop tracking")
-                            return
-                        time.sleep(1)
-                except Exception as e:
-                    api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
-                    video_interface._notify(True, task="stop tracking")
-                    video_interface.global_stop_indicatior = True
-                    raise
-
-            upload_queue = Queue()
-            stop_upload_event = Event()
-            Thread(
-                target=_upload_loop,
-                args=[upload_queue, stop_upload_event, video_interface],
-                daemon=True,
-            ).start()
-            try:
-                for _ in video_interface.frames_loader_generator():
-                    for (fig_id, geom), obj_id in zip(
-                        video_interface.geometries.items(),
-                        video_interface.object_ids,
-                    ):
-                        if isinstance(geom, sly.Point):
-                            geometries = self._predict_point_geometries(
-                                geom,
-                                video_interface.frames_with_notification,
-                            )
-                        elif isinstance(geom, sly.Polygon):
-                            if len(geom.interior) > 0:
-                                stop_upload_event.set()
-                                raise ValueError("Can't track polygons with interior.")
-                            geometries = self._predict_polygon_geometries(
-                                geom,
-                                video_interface.frames_with_notification,
-                            )
-                        elif isinstance(geom, sly.GraphNodes):
-                            geometries = self._predict_graph_geometries(
-                                geom,
-                                video_interface.frames_with_notification,
-                            )
-                        elif isinstance(geom, sly.Polyline):
-                            geometries = self._predict_polyline_geometries(
-                                geom,
-                                video_interface.frames_with_notification,
-                            )
-                        else:
-                            raise TypeError(f"Tracking does not work with {geom.geometry_name()}.")
-
-                        for frame_idx, geometry in zip(
-                            video_interface._cur_frames_indexes[1:], geometries
-                        ):
-                            upload_queue.put(
-                                (
-                                    geometry,
-                                    obj_id,
-                                    frame_idx,
-                                )
-                            )
-                        api.logger.info(f"Object #{obj_id} tracked.")
+        inference_request_uuid = uuid.uuid5(namespace=uuid.NAMESPACE_URL, name=f"{time.time()}").hex
+        fn = self.send_error_data(api, context)(self._track_async)
+        self.schedule_task(fn, api, context, inference_request_uuid=inference_request_uuid)
 
-                if video_interface.global_stop_indicatior:
-                    stop_upload_event.set()
-                    return
-            except Exception:
-                stop_upload_event.set()
-                raise
-            stop_upload_event.set()
+        logger.debug(
+            "Inference has scheduled from 'track_async' endpoint",
+            extra={"inference_request_uuid": inference_request_uuid},
+        )
+        return {
+            "message": "Inference has started.",
+            "inference_request_uuid": inference_request_uuid,
+        }
 
     def predict(
         self,
@@ -438,19 +475,19 @@ class PointTracking(Inference):
             fill_rectangles=False,
         )
 
-    def _create_label(self, dto: PredictionPoint) -> sly.Point:
-        geometry = sly.Point(row=dto.row, col=dto.col)
-        return Label(geometry, sly.ObjClass("", sly.Point))
+    def _create_label(self, dto: PredictionPoint) -> Point:
+        geometry = Point(row=dto.row, col=dto.col)
+        return Label(geometry, ObjClass("", Point))
 
     def _get_obj_class_shape(self):
-        return sly.Point
+        return Point
 
     def _predict_point_geometries(
         self,
-        geom: sly.Point,
+        geom: Point,
         frames: List[np.ndarray],
         settings: Dict[str, Any] = None,
-    ) -> List[sly.Point]:
+    ) -> List[Point]:
         if settings is None:
             settings = self.custom_inference_settings_dict
         pp_geom = PredictionPoint("point", col=geom.col, row=geom.row)
@@ -472,10 +509,10 @@ class PointTracking(Inference):
 
     def _predict_polygon_geometries(
         self,
-        geom: sly.Polygon,
+        geom: Polygon,
         frames: List[np.ndarray],
         settings: Dict[str, Any] = None,
-    ) -> List[sly.Polygon]:
+    ) -> List[Polygon]:
         if settings is None:
             settings = self.custom_inference_settings_dict
         polygon_points = F.numpy_to_dto_point(geom.exterior_np, "polygon")
@@ -501,10 +538,10 @@ class PointTracking(Inference):
 
     def _predict_graph_geometries(
         self,
-        geom: sly.GraphNodes,
+        geom: GraphNodes,
         frames: List[np.ndarray],
         settings: Dict[str, Any] = None,
-    ) -> List[sly.GraphNodes]:
+    ) -> List[GraphNodes]:
         if settings is None:
             settings = self.custom_inference_settings_dict
         points, pids = F.graph_to_dto_points(geom)
@@ -530,10 +567,10 @@ class PointTracking(Inference):
 
     def _predict_polyline_geometries(
         self,
-        geom: sly.Polyline,
+        geom: Polyline,
         frames: List[np.ndarray],
         settings: Dict[str, Any] = None,
-    ) -> List[sly.Polyline]:
+    ) -> List[Polyline]:
         if settings is None:
             settings = self.custom_inference_settings_dict
         polyline_points = F.numpy_to_dto_point(geom.exterior_np, "polyline")
@@ -558,7 +595,7 @@ class PointTracking(Inference):
 
     def _predictions_to_annotation(
         self, image: np.ndarray, predictions: List[Prediction], classes_whitelist: List[str] = None
-    ) -> sly.Annotation:
+    ) -> Annotation:
         labels = []
         for prediction in predictions:
             if (
@@ -576,6 +613,6 @@ class PointTracking(Inference):
             labels.append(label)
 
         # create annotation with correct image resolution
-        ann = sly.Annotation(img_size=image.shape[:2])
+        ann = Annotation(img_size=image.shape[:2])
         ann = ann.add_labels(labels)
         return ann
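
The sketch below illustrates how the new track_api_files signature (a plain List[BinaryIO] plus a settings dict, replacing the removed FastAPI UploadFile/Form parameters) might be called. Only the settings keys, the geometry payload shape expected by _deserialize_geometry, and the return format come from the diff above; the concrete subclass MyPointTracker, its constructor arguments, and the frame file paths are illustrative assumptions.

# Hedged usage sketch (not part of the package) for the new track_api_files path.
from supervisely.geometry.point import Point

tracker = MyPointTracker(model_dir="./model")  # hypothetical subclass of PointTracking

settings = {
    "frame_index": 0,   # first frame of the tracked range
    "frames": 2,        # number of frames to track after frame_index
    "input_geometries": [
        # each entry matches what _deserialize_geometry expects: {"type": ..., "data": ...}
        {"type": Point.geometry_name(), "data": Point(row=100, col=150).to_json()},
    ],
}

# one binary stream per frame index in range(frame_index, frame_index + frames + 1)
frame_files = [open(f"frames/{i:05d}.jpg", "rb") for i in range(3)]
try:
    results = tracker.track_api_files(frame_files, settings)
finally:
    for f in frame_files:
        f.close()

# `results` has one entry per predicted frame; each entry is a list of
# {"type": ..., "data": ...} dicts mirroring the input geometry format.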