supervisely 6.73.357__py3-none-any.whl → 6.73.358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/_utils.py +12 -0
- supervisely/api/annotation_api.py +3 -0
- supervisely/api/api.py +2 -2
- supervisely/api/app_api.py +27 -2
- supervisely/api/entity_annotation/tag_api.py +0 -1
- supervisely/api/nn/__init__.py +0 -0
- supervisely/api/nn/deploy_api.py +821 -0
- supervisely/api/nn/neural_network_api.py +248 -0
- supervisely/api/task_api.py +26 -467
- supervisely/app/fastapi/subapp.py +1 -0
- supervisely/nn/__init__.py +2 -1
- supervisely/nn/artifacts/artifacts.py +5 -5
- supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
- supervisely/nn/experiments.py +28 -5
- supervisely/nn/inference/cache.py +178 -114
- supervisely/nn/inference/gui/gui.py +18 -35
- supervisely/nn/inference/gui/serving_gui.py +3 -1
- supervisely/nn/inference/inference.py +1421 -1265
- supervisely/nn/inference/inference_request.py +412 -0
- supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
- supervisely/nn/inference/session.py +2 -2
- supervisely/nn/inference/tracking/base_tracking.py +45 -79
- supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
- supervisely/nn/inference/tracking/mask_tracking.py +274 -250
- supervisely/nn/inference/tracking/tracker_interface.py +23 -0
- supervisely/nn/inference/uploader.py +164 -0
- supervisely/nn/model/__init__.py +0 -0
- supervisely/nn/model/model_api.py +259 -0
- supervisely/nn/model/prediction.py +311 -0
- supervisely/nn/model/prediction_session.py +632 -0
- supervisely/nn/tracking/__init__.py +1 -0
- supervisely/nn/tracking/boxmot.py +114 -0
- supervisely/nn/tracking/tracking.py +24 -0
- supervisely/nn/training/train_app.py +61 -19
- supervisely/nn/utils.py +43 -3
- supervisely/task/progress.py +12 -2
- supervisely/video/video.py +107 -1
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/METADATA +2 -1
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/RECORD +43 -32
- supervisely/api/neural_network_api.py +0 -202
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/LICENSE +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/WHEEL +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,6 @@ import functools
|
|
|
2
2
|
import inspect
|
|
3
3
|
import json
|
|
4
4
|
import traceback
|
|
5
|
-
from threading import Lock
|
|
6
5
|
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
|
7
6
|
|
|
8
7
|
from fastapi import Form, Request, Response, UploadFile, status
|
|
@@ -14,7 +13,6 @@ from supervisely.api.module_api import ApiField
|
|
|
14
13
|
from supervisely.io import env
|
|
15
14
|
from supervisely.nn.inference.inference import (
|
|
16
15
|
Inference,
|
|
17
|
-
_convert_sly_progress_to_dict,
|
|
18
16
|
_get_log_extra_for_inference_request,
|
|
19
17
|
)
|
|
20
18
|
from supervisely.sly_logger import logger
|
|
@@ -97,10 +95,6 @@ class BaseTracking(Inference):
|
|
|
97
95
|
info["task type"] = "tracking"
|
|
98
96
|
return info
|
|
99
97
|
|
|
100
|
-
def _on_inference_start(self, inference_request_uuid: str):
|
|
101
|
-
super()._on_inference_start(inference_request_uuid)
|
|
102
|
-
self._inference_requests[inference_request_uuid]["lock"] = Lock()
|
|
103
|
-
|
|
104
98
|
@staticmethod
|
|
105
99
|
def _notify_error_default(
|
|
106
100
|
api: Api, track_id: str, exception: Exception, with_traceback: bool = False
|
|
@@ -131,23 +125,6 @@ class BaseTracking(Inference):
|
|
|
131
125
|
message=f"{error_name}: {message}",
|
|
132
126
|
)
|
|
133
127
|
|
|
134
|
-
def _handle_error_in_async(self, uuid):
|
|
135
|
-
def decorator(func):
|
|
136
|
-
@functools.wraps(func)
|
|
137
|
-
def wrapper(*args, **kwargs):
|
|
138
|
-
try:
|
|
139
|
-
return func(*args, **kwargs)
|
|
140
|
-
except Exception as e:
|
|
141
|
-
inf_request = self._inference_requests.get(uuid, None)
|
|
142
|
-
if inf_request is not None:
|
|
143
|
-
inf_request["exception"] = str(e)
|
|
144
|
-
logger.error(f"Error in {func.__name__} function: {e}", exc_info=True)
|
|
145
|
-
raise e
|
|
146
|
-
|
|
147
|
-
return wrapper
|
|
148
|
-
|
|
149
|
-
return decorator
|
|
150
|
-
|
|
151
128
|
@staticmethod
|
|
152
129
|
def send_error_data(api, context):
|
|
153
130
|
def decorator(func):
|
|
@@ -181,72 +158,58 @@ class BaseTracking(Inference):
|
|
|
181
158
|
|
|
182
159
|
return decorator
|
|
183
160
|
|
|
184
|
-
def schedule_task(self, func, *args, **kwargs):
|
|
185
|
-
inference_request_uuid = kwargs.get("inference_request_uuid", None)
|
|
186
|
-
if inference_request_uuid is None:
|
|
187
|
-
self._executor.submit(func, *args, **kwargs)
|
|
188
|
-
else:
|
|
189
|
-
self._on_inference_start(inference_request_uuid)
|
|
190
|
-
fn = self._handle_error_in_async(inference_request_uuid)(func)
|
|
191
|
-
future = self._executor.submit(
|
|
192
|
-
fn,
|
|
193
|
-
*args,
|
|
194
|
-
**kwargs,
|
|
195
|
-
)
|
|
196
|
-
end_callback = functools.partial(
|
|
197
|
-
self._on_inference_end, inference_request_uuid=inference_request_uuid
|
|
198
|
-
)
|
|
199
|
-
future.add_done_callback(end_callback)
|
|
200
|
-
logger.debug("Scheduled task.", extra={"inference_request_uuid": inference_request_uuid})
|
|
201
|
-
|
|
202
161
|
def _pop_tracking_results(self, inference_request_uuid: str, frame_range: Tuple = None):
|
|
203
|
-
inference_request = self.
|
|
162
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
204
163
|
logger.debug(
|
|
205
164
|
"Pop tracking results",
|
|
206
165
|
extra={
|
|
207
166
|
"inference_request_uuid": inference_request_uuid,
|
|
208
|
-
"pending_results_len":
|
|
167
|
+
"pending_results_len": inference_request.pending_num(),
|
|
209
168
|
"frame_range": frame_range,
|
|
210
169
|
},
|
|
211
170
|
)
|
|
212
|
-
with inference_request["lock"]:
|
|
213
|
-
inference_request_copy = inference_request.copy()
|
|
214
|
-
|
|
215
|
-
if frame_range is not None:
|
|
216
|
-
|
|
217
|
-
def _in_range(figure):
|
|
218
|
-
return (
|
|
219
|
-
figure.frame_index >= frame_range[0]
|
|
220
|
-
and figure.frame_index <= frame_range[1]
|
|
221
|
-
)
|
|
222
|
-
|
|
223
|
-
inference_request_copy["pending_results"] = list(
|
|
224
|
-
filter(_in_range, inference_request_copy["pending_results"])
|
|
225
|
-
)
|
|
226
|
-
inference_request["pending_results"] = list(
|
|
227
|
-
filter(lambda x: not _in_range(x), inference_request["pending_results"])
|
|
228
|
-
)
|
|
229
|
-
else:
|
|
230
|
-
inference_request["pending_results"] = []
|
|
231
|
-
inference_request_copy.pop("lock")
|
|
232
|
-
inference_request_copy["progress"] = _convert_sly_progress_to_dict(
|
|
233
|
-
inference_request_copy["progress"]
|
|
234
|
-
)
|
|
235
171
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
172
|
+
data = {}
|
|
173
|
+
if frame_range is not None:
|
|
174
|
+
|
|
175
|
+
def _in_range(figure):
|
|
176
|
+
return figure.frame_index >= frame_range[0] and figure.frame_index <= frame_range[1]
|
|
177
|
+
|
|
178
|
+
with inference_request._lock:
|
|
179
|
+
data["pending_results"] = [
|
|
180
|
+
x for x in inference_request._pending_results if _in_range(x)
|
|
181
|
+
]
|
|
182
|
+
inference_request._pending_results = [
|
|
183
|
+
x for x in inference_request._pending_results if not _in_range(x)
|
|
184
|
+
]
|
|
185
|
+
else:
|
|
186
|
+
data["pending_results"] = inference_request.pop_pending_results()
|
|
187
|
+
|
|
188
|
+
data = {
|
|
189
|
+
**inference_request.to_json(),
|
|
190
|
+
**_get_log_extra_for_inference_request(inference_request.uuid, inference_request),
|
|
191
|
+
"pending_results": data["pending_results"],
|
|
192
|
+
}
|
|
239
193
|
|
|
240
|
-
return
|
|
194
|
+
return data
|
|
241
195
|
|
|
242
196
|
def _clear_tracking_results(self, inference_request_uuid):
|
|
243
|
-
|
|
197
|
+
if inference_request_uuid is None:
|
|
198
|
+
raise ValueError("'inference_request_uuid' is required.")
|
|
199
|
+
self.inference_requests_manager.remove_after(inference_request_uuid, 60)
|
|
244
200
|
logger.debug("Removed an inference request:", extra={"uuid": inference_request_uuid})
|
|
201
|
+
return {"success": True}
|
|
245
202
|
|
|
246
203
|
def _stop_tracking(self, inference_request_uuid: str):
|
|
247
|
-
inference_request = self.
|
|
248
|
-
inference_request
|
|
249
|
-
logger.debug(
|
|
204
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
205
|
+
inference_request.stop()
|
|
206
|
+
logger.debug(
|
|
207
|
+
"Stopped tracking:",
|
|
208
|
+
extra={
|
|
209
|
+
"uuid": inference_request_uuid,
|
|
210
|
+
"inference_request_uuid": inference_request_uuid,
|
|
211
|
+
},
|
|
212
|
+
)
|
|
250
213
|
|
|
251
214
|
# Implement the following methods in the derived class
|
|
252
215
|
def track(self, api: Api, state: Dict, context: Dict):
|
|
@@ -274,14 +237,17 @@ class BaseTracking(Inference):
|
|
|
274
237
|
def pop_tracking_results(self, state: Dict, context: Dict):
|
|
275
238
|
validate_key(context, "inference_request_uuid", str)
|
|
276
239
|
inference_request_uuid = context["inference_request_uuid"]
|
|
240
|
+
|
|
241
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
242
|
+
log_extra = _get_log_extra_for_inference_request(inference_request.uuid, inference_request)
|
|
277
243
|
frame_range = find_value_by_keys(context, ["frameRange", "frame_range", "frames"])
|
|
278
244
|
tracking_results = self._pop_tracking_results(inference_request_uuid, frame_range)
|
|
279
|
-
log_extra = _get_log_extra_for_inference_request(inference_request_uuid, tracking_results)
|
|
280
245
|
logger.debug(f"Sending inference delta results with uuid:", extra=log_extra)
|
|
281
246
|
return tracking_results
|
|
282
247
|
|
|
283
248
|
def clear_tracking_results(self, state: Dict, context: Dict):
|
|
284
|
-
|
|
249
|
+
inference_request_uuid = context.get("inference_request_uuid", None)
|
|
250
|
+
self._clear_tracking_results(inference_request_uuid)
|
|
285
251
|
return {"message": "Inference results cleared.", "success": True}
|
|
286
252
|
|
|
287
253
|
def _register_endpoints(self):
|
|
@@ -290,7 +256,7 @@ class BaseTracking(Inference):
|
|
|
290
256
|
@server.post("/track")
|
|
291
257
|
@handle_validation
|
|
292
258
|
def track_handler(request: Request):
|
|
293
|
-
api = request
|
|
259
|
+
api = self.api_from_request(request)
|
|
294
260
|
state = request.state.state
|
|
295
261
|
context = request.state.context
|
|
296
262
|
logger.info("Received track request.", extra={"context": context, "state": state})
|
|
@@ -299,7 +265,7 @@ class BaseTracking(Inference):
|
|
|
299
265
|
@server.post("/track-api")
|
|
300
266
|
@handle_validation
|
|
301
267
|
async def track_api_handler(request: Request):
|
|
302
|
-
api = request
|
|
268
|
+
api = self.api_from_request(request)
|
|
303
269
|
state = request.state.state
|
|
304
270
|
context = request.state.context
|
|
305
271
|
logger.info("Received track-api request.", extra={"context": context, "state": state})
|
|
@@ -320,7 +286,7 @@ class BaseTracking(Inference):
|
|
|
320
286
|
@server.post("/track_async")
|
|
321
287
|
@handle_validation
|
|
322
288
|
def track_async_handler(request: Request):
|
|
323
|
-
api = request
|
|
289
|
+
api = self.api_from_request(request)
|
|
324
290
|
state = request.state.state
|
|
325
291
|
context = request.state.context
|
|
326
292
|
logger.info("Received track_async request.", extra={"context": context, "state": state})
|