supervisely 6.73.357__py3-none-any.whl → 6.73.358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/_utils.py +12 -0
- supervisely/api/annotation_api.py +3 -0
- supervisely/api/api.py +2 -2
- supervisely/api/app_api.py +27 -2
- supervisely/api/entity_annotation/tag_api.py +0 -1
- supervisely/api/nn/__init__.py +0 -0
- supervisely/api/nn/deploy_api.py +821 -0
- supervisely/api/nn/neural_network_api.py +248 -0
- supervisely/api/task_api.py +26 -467
- supervisely/app/fastapi/subapp.py +1 -0
- supervisely/nn/__init__.py +2 -1
- supervisely/nn/artifacts/artifacts.py +5 -5
- supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
- supervisely/nn/experiments.py +28 -5
- supervisely/nn/inference/cache.py +178 -114
- supervisely/nn/inference/gui/gui.py +18 -35
- supervisely/nn/inference/gui/serving_gui.py +3 -1
- supervisely/nn/inference/inference.py +1421 -1265
- supervisely/nn/inference/inference_request.py +412 -0
- supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
- supervisely/nn/inference/session.py +2 -2
- supervisely/nn/inference/tracking/base_tracking.py +45 -79
- supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
- supervisely/nn/inference/tracking/mask_tracking.py +274 -250
- supervisely/nn/inference/tracking/tracker_interface.py +23 -0
- supervisely/nn/inference/uploader.py +164 -0
- supervisely/nn/model/__init__.py +0 -0
- supervisely/nn/model/model_api.py +259 -0
- supervisely/nn/model/prediction.py +311 -0
- supervisely/nn/model/prediction_session.py +632 -0
- supervisely/nn/tracking/__init__.py +1 -0
- supervisely/nn/tracking/boxmot.py +114 -0
- supervisely/nn/tracking/tracking.py +24 -0
- supervisely/nn/training/train_app.py +61 -19
- supervisely/nn/utils.py +43 -3
- supervisely/task/progress.py +12 -2
- supervisely/video/video.py +107 -1
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/METADATA +2 -1
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/RECORD +43 -32
- supervisely/api/neural_network_api.py +0 -202
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/LICENSE +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/WHEEL +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import threading
|
|
4
|
+
import time
|
|
5
|
+
import traceback
|
|
6
|
+
import uuid
|
|
7
|
+
from concurrent.futures import Future, ThreadPoolExecutor
|
|
8
|
+
from functools import partial, wraps
|
|
9
|
+
from typing import Any, Dict, List, Tuple, Union
|
|
10
|
+
|
|
11
|
+
from supervisely._utils import rand_str
|
|
12
|
+
from supervisely.nn.utils import get_gpu_usage, get_ram_usage
|
|
13
|
+
from supervisely.sly_logger import logger
|
|
14
|
+
from supervisely.task.progress import Progress
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def generate_uuid(self=None) -> str:
    """
    Generate a unique identifier for an inference request.

    Returns a 32-character lowercase hex string (a random UUID4).

    :param self: Ignored. Kept (now optional) only for backward compatibility: this is a
        module-level function, and the previously mandatory ``self`` parameter made the
        natural call ``generate_uuid()`` fail with a TypeError.
    :return: Hex representation of a new random UUID.
    """
    return uuid.uuid4().hex
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class InferenceRequest:
    """
    Mutable state of one inference run, shared between the worker and API handlers.

    Holds the current stage, a ``Progress`` object, results that were produced but not
    consumed yet (pending results), an optional final result and the exception that
    terminated the run, if any. When ``manager`` is set, inference progress is also
    reported to the manager's aggregate ``global_progress``.
    """

    class Stage:
        # Stage labels; the current stage is also used as the progress message.
        PREPARING = "Preparing model for inference..."
        INFERENCE = "Running inference..."
        FINISHED = "Finished"
        CANCELLED = "Cancelled"
        ERROR = "Error"

    def __init__(
        self,
        uuid_: str = None,
        ttl: Union[int, None] = 60 * 60,
        manager: "InferenceRequestsManager" = None,
    ):
        """
        :param uuid_: Request identifier. A random UUID4 hex string is generated when omitted.
        :param ttl: Seconds since the last update after which :meth:`is_expired` returns
            True. ``None`` disables expiration.
        :param manager: Owning ``InferenceRequestsManager`` or None.
        """
        if uuid_ is None:
            # Random UUID: deriving the id from time.time() alone gave identical ids to
            # requests created within the same clock tick, so one request silently
            # replaced the other in the manager's registry.
            uuid_ = uuid.uuid4().hex
        self._uuid = uuid_
        self._ttl = ttl
        self.manager = manager
        self.context = {}
        self._lock = threading.Lock()
        self._stage = InferenceRequest.Stage.PREPARING
        self._pending_results = []
        self._final_result = None
        self._exception = None
        self._stopped = threading.Event()
        self.progress = Progress(
            message=self._stage,
            total_cnt=1,
            need_info_log=True,
            update_task_progress=False,
            log_extra={"inference_request_uuid": self._uuid},
        )
        self._created_at = time.monotonic()
        self._updated_at = self._created_at
        self._finished = False

        # This request's contribution to the manager's GlobalProgress.
        self.global_progress = None
        self.global_progress_total = 1
        self.global_progress_current = 0

    def _updated(self):
        """Refresh the last-activity timestamp used for TTL expiration."""
        self._updated_at = time.monotonic()

    @property
    def uuid(self):
        return self._uuid

    @property
    def created_at(self):
        """``time.monotonic()`` value at creation."""
        return self._created_at

    @property
    def updated_at(self):
        """``time.monotonic()`` value of the last update."""
        return self._updated_at

    @property
    def stage(self):
        return self._stage

    @property
    def final_result(self):
        with self._lock:
            return self._final_result

    @final_result.setter
    def final_result(self, result: Any):
        with self._lock:
            self._final_result = result
            self._updated()

    def add_results(self, results: List[Dict]):
        """Append produced results to the pending queue."""
        with self._lock:
            self._pending_results.extend(results)
            self._updated()

    def pop_pending_results(self, n: int = None):
        """
        Remove and return up to ``n`` pending results in FIFO order.

        :param n: Maximum number of results to pop; all pending results when None.
            Values larger than the queue are clamped; negative values pop nothing.
        :return: List of popped results (possibly empty).
        """
        with self._lock:
            if len(self._pending_results) == 0:
                return []
            available = len(self._pending_results)
            n = available if n is None else max(0, min(n, available))
            results = self._pending_results[:n]
            self._pending_results = self._pending_results[n:]
            self._updated()
            return results

    def pending_num(self):
        return len(self._pending_results)

    def set_stage(self, stage: str, current: int = None, total: int = None, is_size: bool = False):
        """
        Switch to ``stage`` and optionally update the progress counters.

        Entering ``Stage.INFERENCE`` registers this request's progress with the
        manager's global progress (when a manager is attached).
        """
        with self._lock:
            self._stage = stage
            self.progress.message = self._stage
            if current is not None:
                self.progress.current = current
            if total is not None:
                logger.debug("setting total = %s", total)
                self.progress.total = total
            if is_size:
                self.progress.is_size = True
            self.progress._refresh_labels()
            self.progress.report_progress()
            if self._stage == InferenceRequest.Stage.INFERENCE:
                # Read the counters back from Progress so that omitted arguments keep the
                # existing values instead of pushing None into the global progress.
                self.global_progress_total = self.progress.total
                self.global_progress_current = self.progress.current
                if self.manager is not None:
                    self.manager.global_progress.inference_started(
                        current=self.global_progress_current,
                        total=self.global_progress_total,
                    )
            self._updated()

    def done(self, n=1):
        """Advance progress by ``n``; during inference also advance the global progress."""
        with self._lock:
            self.progress.iters_done_report(n)
            if self._stage == InferenceRequest.Stage.INFERENCE:
                self.global_progress_current += n
                if self.manager is not None:
                    self.manager.done(n)
            self._updated()

    @property
    def exception(self):
        return self._exception

    @exception.setter
    def exception(self, exc: Exception):
        """Store the terminating exception and switch to ``Stage.ERROR``."""
        self._exception = exc
        self.set_stage(InferenceRequest.Stage.ERROR)
        self._updated()

    def is_inferring(self):
        return self.stage == InferenceRequest.Stage.INFERENCE

    def stop(self):
        """Request cancellation; workers are expected to poll :meth:`is_stopped`."""
        self._stopped.set()
        self._updated()

    def is_stopped(self):
        return self._stopped.is_set()

    def is_finished(self):
        return self._finished

    def is_expired(self):
        """True when more than ``ttl`` seconds passed since the last update."""
        if self._ttl is None:
            return False
        return time.monotonic() - self._updated_at > self._ttl

    def progress_json(self):
        return {
            "message": self.progress.message,
            "status": self.progress.message,
            "current": self.progress.current,
            "total": self.progress.total,
            "is_size": self.progress.is_size,
        }

    def exception_json(self):
        """
        Serialize the stored exception, or return None when there is none.

        The traceback is formatted from the stored exception object itself;
        ``traceback.format_exc()`` would only describe the exception currently being
        handled, which is nothing when this is called later (e.g. from a status request).
        """
        exc = self._exception
        if exc is None:
            return None
        return {
            "type": type(exc).__name__,
            "message": str(exc),
            "traceback": "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
        }

    def to_json(self):
        return {
            "uuid": self._uuid,
            "stage": str(self._stage),
            "progress": self.progress_json(),
            "pending_results": self.pending_num(),
            "final_result": self._final_result is not None,
            "exception": self.exception_json(),
            "is_inferring": self.is_inferring(),
            "stopped": self.is_stopped(),
            "finished": self._finished,
            "created_at": self._created_at,
            "updated_at": self._updated_at,
        }

    def on_inference_end(self):
        """Move to a terminal stage (CANCELLED if stopped, else FINISHED) and mark finished."""
        if self._stage not in (
            InferenceRequest.Stage.FINISHED,
            InferenceRequest.Stage.CANCELLED,
            InferenceRequest.Stage.ERROR,
        ):
            if self.is_stopped():
                self.set_stage(InferenceRequest.Stage.CANCELLED)
            else:
                self.set_stage(InferenceRequest.Stage.FINISHED)
        self._finished = True
        self._updated()

    def get_usage(self):
        """Return current GPU and RAM usage as reported by ``supervisely.nn.utils``."""
        ram_allocated, ram_total = get_ram_usage()
        gpu_allocated, gpu_total = get_gpu_usage()
        return {
            "gpu_memory": {
                "allocated": gpu_allocated,
                "total": gpu_total,
            },
            "ram_memory": {
                "allocated": ram_allocated,
                "total": ram_total,
            },
        }

    def status(self):
        """Compact status: :meth:`to_json` without result/timestamp fields, plus resource usage."""
        status_data = self.to_json()
        for key in ["pending_results", "final_result", "created_at", "updated_at"]:
            status_data.pop(key, None)
        status_data.update(self.get_usage())
        return status_data
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class GlobalProgress:
    """
    Aggregate progress over all inference requests of one manager.

    The message is "Ready" while idle. Each request entering inference adds its
    current/total, and each finished request subtracts them again.

    Locking: public methods take ``self._lock`` exactly once. Work that needs to happen
    while the lock is already held goes through the ``_..._locked`` helpers. The lock is
    a non-reentrant ``threading.Lock``, so calling a public locked method (``set_ready``,
    ``set_message``) from inside another one would self-deadlock.
    """

    def __init__(self):
        self.progress = Progress(message="Ready", total_cnt=1)
        self._lock = threading.Lock()

    def _set_message_locked(self, message: str):
        # Caller must hold self._lock.
        if self.progress.message != message:
            self.progress.message = message
            self.progress.report_progress()

    def _set_ready_locked(self):
        # Caller must hold self._lock.
        self.progress.message = "Ready"
        self.progress.current = 0
        self.progress.total = 1
        self.progress.report_progress()

    def set_message(self, message: str):
        """Change the progress message, reporting only when it actually changes."""
        with self._lock:
            self._set_message_locked(message)

    def set_ready(self):
        """Reset to the idle state ("Ready", 0/1)."""
        with self._lock:
            self._set_ready_locked()

    def done(self, n=1):
        """Advance by ``n``; return to "Ready" once current reaches total."""
        with self._lock:
            self.progress.iters_done_report(n)
            if self.progress.current >= self.progress.total:
                self._set_ready_locked()

    def inference_started(self, current: int, total: int):
        """Register a request's counters: replace them when idle, otherwise add them."""
        with self._lock:
            if self.progress.message == "Ready":
                self.progress.total = total
                self.progress.current = current
            else:
                self.progress.total += total
                self.progress.current += current
            self._set_message_locked("Inferring model...")

    def inference_finished(self, current: int, total: int):
        """
        Remove a finished request's counters and return to "Ready" when nothing remains.

        Counters are clamped at 0. Inconsistent accounting would otherwise make them
        negative and leave the progress stuck in "Inferring model..." forever.
        """
        with self._lock:
            if self.progress.message == "Ready":
                return
            self.progress.current = max(0, self.progress.current - current)
            self.progress.total = max(0, self.progress.total - total)
            if self.progress.current >= self.progress.total:
                self._set_ready_locked()

    def to_json(self):
        """Consistent snapshot of message/current/total."""
        with self._lock:
            return {
                "message": self.progress.message,
                "current": self.progress.current,
                "total": self.progress.total,
            }
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class InferenceRequestsManager:
    """
    Registry and scheduler for ``InferenceRequest`` objects.

    Submits functions to an executor on behalf of a request and tracks the requests by
    uuid. It aggregates their progress in ``global_progress``, and a background thread
    (:meth:`monitor`) periodically removes requests whose TTL has expired.
    """

    # Seconds between two scans for expired requests.
    _MONITOR_INTERVAL = 30

    def __init__(self, executor: ThreadPoolExecutor = None):
        """
        :param executor: Executor used for scheduled tasks. When omitted, a single-worker
            ``ThreadPoolExecutor`` is created, so tasks run one at a time.
        """
        if executor is None:
            executor = ThreadPoolExecutor(max_workers=1)
        self._executor = executor
        self._inference_requests: Dict[str, "InferenceRequest"] = {}
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self.global_progress = GlobalProgress()
        # Start the background thread last, once every attribute exists.
        # NOTE(review): the thread holds a bound method, i.e. a strong reference to self,
        # so __del__ cannot run while the thread is alive.
        self._monitor_thread = threading.Thread(target=self.monitor, daemon=True)
        self._monitor_thread.start()

    def __del__(self):
        try:
            self._executor.shutdown(wait=False)
            self._stop_event.set()
            self._monitor_thread.join(timeout=5)
        finally:
            logger.debug("InferenceRequestsManager was deleted")

    def add(self, inference_request: "InferenceRequest"):
        """Register (or replace) a request under its uuid."""
        with self._lock:
            self._inference_requests[inference_request.uuid] = inference_request

    def remove(self, inference_request_uuid: str):
        """Stop and unregister a request; unknown uuids are ignored."""
        with self._lock:
            inference_request = self._inference_requests.get(inference_request_uuid)
            if inference_request is not None:
                inference_request.stop()
                del self._inference_requests[inference_request_uuid]

    def remove_after(self, inference_request_uuid, wait_time=0):
        """Stop a request and let the monitor drop it ``wait_time`` seconds after now."""
        with self._lock:
            inference_request = self._inference_requests.get(inference_request_uuid)
            if inference_request is not None:
                inference_request.stop()
                inference_request._ttl = wait_time
                inference_request._updated()

    def get(self, inference_request_uuid: str):
        """
        Return the request for ``inference_request_uuid`` (None for a None uuid).

        :raises RuntimeError: if the uuid is not registered.
        """
        if inference_request_uuid is None:
            return None
        try:
            return self._inference_requests[inference_request_uuid]
        except KeyError as ex:
            raise RuntimeError(
                f"inference_request_uuid {inference_request_uuid} was given, "
                f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
            ) from ex

    def create(self, inference_request_uuid: str = None) -> "InferenceRequest":
        """Create a request owned by this manager and register it."""
        inference_request = InferenceRequest(uuid_=inference_request_uuid, manager=self)
        self.add(inference_request)
        return inference_request

    def monitor(self):
        """
        Remove expired requests every ``_MONITOR_INTERVAL`` seconds until stopped.

        Waits on the stop event rather than sleeping, so setting the event (as ``__del__``
        does) ends the loop promptly. Errors are logged rather than letting them kill
        the thread, which would stop cleanup for the lifetime of the manager.
        """
        while not self._stop_event.is_set():
            try:
                with self._lock:
                    snapshot = list(self._inference_requests.items())
                for inference_request_uuid, inference_request in snapshot:
                    if inference_request.is_expired():
                        self.remove(inference_request_uuid)
                        logger.debug(f"Expired inference request {inference_request_uuid} was deleted")
            except Exception:
                logger.warning("Failed to clean up expired inference requests", exc_info=True)
            self._stop_event.wait(self._MONITOR_INTERVAL)

    def done(self, n=1):
        """Advance the global progress by ``n``."""
        with self._lock:
            self.global_progress.done(n)

    def _on_inference_start(self, inference_request: "InferenceRequest"):
        if inference_request.uuid not in self._inference_requests:
            self.add(inference_request)

    def _on_inference_end(self, future, inference_request_uuid: str, inference_request=None):
        """
        Done-callback of a scheduled task: finalize the request and release its share
        of the global progress.

        ``inference_request`` may be passed directly so a request that was already
        removed from the registry (e.g. cancelled via :meth:`remove`) is still
        finalized. Without it, the request is looked up by uuid and nothing is done
        if it is gone.
        """
        logger.debug("callback: on_inference_end()")
        if inference_request is None:
            inference_request = self._inference_requests.get(inference_request_uuid)
        if inference_request is None:
            return
        inference_request.on_inference_end()
        self.global_progress.inference_finished(
            current=inference_request.global_progress_current,
            total=inference_request.global_progress_total,
        )

    def _handle_error_in_async(self, inference_request_uuid: str, func, args, kwargs):
        """Run ``func``; on error, store the exception on the request and log it (no re-raise)."""
        try:
            return func(*args, **kwargs)
        except Exception as e:
            inference_request = self._inference_requests.get(inference_request_uuid, None)
            if inference_request is not None:
                inference_request.exception = e
            logger.error(f"Error in {func.__name__} function: {e}", exc_info=True)

    def schedule_task(self, func, *args, **kwargs) -> Tuple["InferenceRequest", Future]:
        """
        Submit ``func(*args, **kwargs)`` to the executor.

        ``kwargs["inference_request"]`` is used as the request when given. Otherwise a new
        request is created and injected into ``kwargs``.

        :return: ``(inference_request, future)``.
        """
        inference_request = kwargs.get("inference_request", None)
        if inference_request is None:
            inference_request = self.create()
            kwargs["inference_request"] = inference_request
        self._on_inference_start(inference_request)
        future = self._executor.submit(
            self._handle_error_in_async,
            inference_request.uuid,
            func,
            args,
            kwargs,
        )
        end_callback = partial(
            self._on_inference_end,
            inference_request_uuid=inference_request.uuid,
            inference_request=inference_request,
        )
        future.add_done_callback(end_callback)
        logger.debug("Scheduled task.", extra={"inference_request_uuid": inference_request.uuid})
        return inference_request, future

    def run(self, func, *args, **kwargs):
        """
        Schedule ``func``, wait for it and return all pending results.

        :raises Exception: the exception raised by ``func``. Errors are captured on the
            request by ``_handle_error_in_async``, so they would otherwise be silently
            swallowed and the caller would receive empty or partial results.
        """
        inference_request, future = self.schedule_task(func, *args, **kwargs)
        future.result()
        if inference_request.exception is not None:
            raise inference_request.exception
        return inference_request.pop_pending_results()
|
|
@@ -1,19 +1,22 @@
|
|
|
1
|
-
from typing import Dict, List, Any
|
|
2
|
-
from supervisely.geometry.cuboid_3d import Cuboid3d
|
|
3
|
-
from supervisely.nn.prediction_dto import PredictionCuboid3d
|
|
4
|
-
from supervisely.annotation.label import Label
|
|
5
|
-
from supervisely.annotation.tag import Tag
|
|
6
|
-
from supervisely.nn.inference.inference import Inference
|
|
7
|
-
from fastapi import Response, Request, status
|
|
8
|
-
from supervisely.sly_logger import logger
|
|
9
1
|
import os
|
|
10
|
-
from
|
|
11
|
-
|
|
2
|
+
from typing import Any, Dict, List
|
|
3
|
+
|
|
4
|
+
from fastapi import Request, Response, status
|
|
5
|
+
|
|
6
|
+
from supervisely import Api, PointcloudAnnotation, PointcloudFigure, PointcloudObject
|
|
12
7
|
from supervisely._utils import rand_str
|
|
8
|
+
from supervisely.annotation.label import Label
|
|
9
|
+
from supervisely.annotation.tag import Tag
|
|
13
10
|
from supervisely.app.content import get_data_dir
|
|
11
|
+
from supervisely.geometry.cuboid_3d import Cuboid3d
|
|
12
|
+
from supervisely.io.fs import silent_remove
|
|
13
|
+
from supervisely.nn.inference.inference import Inference
|
|
14
|
+
from supervisely.nn.prediction_dto import PredictionCuboid3d
|
|
14
15
|
from supervisely.pointcloud_annotation.pointcloud_object_collection import (
|
|
15
16
|
PointcloudObjectCollection,
|
|
16
17
|
)
|
|
18
|
+
from supervisely.sly_logger import logger
|
|
19
|
+
|
|
17
20
|
|
|
18
21
|
class ObjectDetection3D(Inference):
|
|
19
22
|
def get_info(self) -> dict:
|
|
@@ -23,7 +26,7 @@ class ObjectDetection3D(Inference):
|
|
|
23
26
|
info["async_video_inference_support"] = False
|
|
24
27
|
info["tracking_on_videos_support"] = False
|
|
25
28
|
info["async_image_inference_support"] = False
|
|
26
|
-
|
|
29
|
+
|
|
27
30
|
# recommended parameters:
|
|
28
31
|
# info["model_name"] = ""
|
|
29
32
|
# info["checkpoint_name"] = ""
|
|
@@ -44,7 +47,7 @@ class ObjectDetection3D(Inference):
|
|
|
44
47
|
raise NotImplementedError(
|
|
45
48
|
"Have to be implemented in child class If sliding_window_mode is 'advanced'."
|
|
46
49
|
)
|
|
47
|
-
|
|
50
|
+
|
|
48
51
|
def _inference_pointcloud_id(self, api: Api, pointcloud_id: int, settings: Dict[str, Any]):
|
|
49
52
|
# 1. download pointcloud
|
|
50
53
|
pcd_path = os.path.join(get_data_dir(), rand_str(10) + ".pcd")
|
|
@@ -73,7 +76,9 @@ class ObjectDetection3D(Inference):
|
|
|
73
76
|
annotation = PointcloudAnnotation(objects, figures)
|
|
74
77
|
return annotation
|
|
75
78
|
|
|
76
|
-
def raw_results_from_prediction(
|
|
79
|
+
def raw_results_from_prediction(
|
|
80
|
+
self, prediction: List[PredictionCuboid3d]
|
|
81
|
+
) -> List[Dict[str, Any]]:
|
|
77
82
|
results = []
|
|
78
83
|
for pred in prediction:
|
|
79
84
|
detection_name = pred.class_name
|
|
@@ -82,14 +87,16 @@ class ObjectDetection3D(Inference):
|
|
|
82
87
|
rotation_z = pred.cuboid_3d.rotation.z
|
|
83
88
|
velocity = [0, 0] # Is not supported now
|
|
84
89
|
detection_score = pred.score
|
|
85
|
-
results.append(
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
90
|
+
results.append(
|
|
91
|
+
{
|
|
92
|
+
"detection_name": detection_name,
|
|
93
|
+
"translation": translation,
|
|
94
|
+
"size": size,
|
|
95
|
+
"rotation": rotation_z,
|
|
96
|
+
"velocity": velocity,
|
|
97
|
+
"detection_score": detection_score,
|
|
98
|
+
}
|
|
99
|
+
)
|
|
93
100
|
return results
|
|
94
101
|
|
|
95
102
|
def serve(self):
|
|
@@ -103,7 +110,7 @@ class ObjectDetection3D(Inference):
|
|
|
103
110
|
extra={**request.state.state, "api_token": "***"},
|
|
104
111
|
)
|
|
105
112
|
state = request.state.state
|
|
106
|
-
api
|
|
113
|
+
api: Api = request.state.api
|
|
107
114
|
settings = self._get_inference_settings(state)
|
|
108
115
|
prediction = self._inference_pointcloud_id(api, state["pointcloud_id"], settings)
|
|
109
116
|
annotation = self.annotation_from_prediction(prediction)
|
|
@@ -123,11 +130,11 @@ class ObjectDetection3D(Inference):
|
|
|
123
130
|
extra={**request.state.state, "api_token": "***"},
|
|
124
131
|
)
|
|
125
132
|
state = request.state.state
|
|
126
|
-
api
|
|
133
|
+
api: Api = request.state.api
|
|
127
134
|
settings = self._get_inference_settings(state)
|
|
128
135
|
annotations = []
|
|
129
136
|
for pcd_id in state["pointcloud_ids"]:
|
|
130
137
|
prediction = self._inference_pointcloud_id(api, pcd_id, settings)
|
|
131
138
|
annotation = self.annotation_from_prediction(prediction)
|
|
132
139
|
annotations.append(annotation.to_json())
|
|
133
|
-
return annotations
|
|
140
|
+
return annotations
|
|
@@ -459,9 +459,9 @@ class SessionJSON:
|
|
|
459
459
|
progress_widget = preparing_cb(
|
|
460
460
|
message="Downloading infos", total=resp["total"], unit="it"
|
|
461
461
|
)
|
|
462
|
-
|
|
463
462
|
while resp["status"] == "download_info":
|
|
464
463
|
current = resp["current"]
|
|
464
|
+
# pylint: disable=possibly-used-before-assignment
|
|
465
465
|
progress_widget.update(current - prev_current)
|
|
466
466
|
prev_current = current
|
|
467
467
|
resp = self._get_preparing_progress()
|
|
@@ -813,7 +813,7 @@ class Session(SessionJSON):
|
|
|
813
813
|
frames_direction: Literal["forward", "backward"] = None,
|
|
814
814
|
tracker: Literal["bot", "deepsort"] = None,
|
|
815
815
|
batch_size: int = None,
|
|
816
|
-
preparing_cb
|
|
816
|
+
preparing_cb=None,
|
|
817
817
|
) -> AsyncInferenceIterator:
|
|
818
818
|
frame_iterator = super().inference_video_id_async(
|
|
819
819
|
video_id,
|