supervisely 6.73.225__py3-none-any.whl → 6.73.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/__init__.py +1 -1
- supervisely/_utils.py +23 -0
- supervisely/api/annotation_api.py +184 -14
- supervisely/api/api.py +2 -2
- supervisely/api/app_api.py +14 -8
- supervisely/api/entity_annotation/figure_api.py +11 -2
- supervisely/api/file_api.py +144 -8
- supervisely/api/image_api.py +5 -10
- supervisely/api/pointcloud/pointcloud_api.py +4 -8
- supervisely/api/video/video_annotation_api.py +45 -0
- supervisely/api/video/video_api.py +2 -4
- supervisely/api/volume/volume_api.py +2 -4
- supervisely/convert/base_converter.py +14 -10
- supervisely/io/fs.py +55 -8
- supervisely/io/json.py +32 -0
- supervisely/nn/inference/inference.py +45 -4
- supervisely/nn/inference/semantic_segmentation/semantic_segmentation.py +38 -8
- supervisely/project/download.py +176 -64
- supervisely/project/project.py +676 -35
- supervisely/project/video_project.py +293 -3
- {supervisely-6.73.225.dist-info → supervisely-6.73.227.dist-info}/METADATA +1 -1
- {supervisely-6.73.225.dist-info → supervisely-6.73.227.dist-info}/RECORD +26 -26
- {supervisely-6.73.225.dist-info → supervisely-6.73.227.dist-info}/LICENSE +0 -0
- {supervisely-6.73.225.dist-info → supervisely-6.73.227.dist-info}/WHEEL +0 -0
- {supervisely-6.73.225.dist-info → supervisely-6.73.227.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.225.dist-info → supervisely-6.73.227.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
# coding: utf-8
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
|
+
import asyncio
|
|
4
5
|
import json
|
|
5
6
|
from typing import Callable, Dict, List, Optional, Union
|
|
6
7
|
|
|
@@ -247,3 +248,47 @@ class VideoAnnotationAPI(EntityAnnotationAPI):
|
|
|
247
248
|
self.append(dst_id, ann)
|
|
248
249
|
if progress_cb is not None:
|
|
249
250
|
progress_cb(1)
|
|
251
|
+
|
|
252
|
+
async def download_async(
|
|
253
|
+
self,
|
|
254
|
+
video_id: int,
|
|
255
|
+
video_info=None,
|
|
256
|
+
semaphore: Optional[asyncio.Semaphore] = None,
|
|
257
|
+
) -> Dict:
|
|
258
|
+
"""
|
|
259
|
+
Download information about VideoAnnotation by video ID from API asynchronously.
|
|
260
|
+
|
|
261
|
+
:param video_id: Video ID in Supervisely.
|
|
262
|
+
:type video_id: int
|
|
263
|
+
:param video_info: VideoInfo object. Use it to avoid additional request to the server.
|
|
264
|
+
:type video_info: VideoInfo, optional
|
|
265
|
+
:param semaphore: Semaphore to limit the number of parallel downloads.
|
|
266
|
+
:type semaphore: asyncio.Semaphore, optional
|
|
267
|
+
:return: Information about VideoAnnotation in json format
|
|
268
|
+
:rtype: :class:`dict`
|
|
269
|
+
:Usage example:
|
|
270
|
+
|
|
271
|
+
.. code-block:: python
|
|
272
|
+
|
|
273
|
+
import supervisely as sly
|
|
274
|
+
|
|
275
|
+
os.environ['SERVER_ADDRESS'] = 'https://app.supervisely.com'
|
|
276
|
+
os.environ['API_TOKEN'] = 'Your Supervisely API Token'
|
|
277
|
+
api = sly.Api.from_env()
|
|
278
|
+
|
|
279
|
+
video_id = 198702499
|
|
280
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
281
|
+
ann_info = loop.run_until_complete(api.video.annotation.download_async(video_id))
|
|
282
|
+
"""
|
|
283
|
+
if video_info is None:
|
|
284
|
+
video_info = self._api.video.get_info_by_id(video_id)
|
|
285
|
+
|
|
286
|
+
if semaphore is None:
|
|
287
|
+
semaphore = self._api._get_default_semaphore()
|
|
288
|
+
|
|
289
|
+
async with semaphore:
|
|
290
|
+
response = await self._api.post_async(
|
|
291
|
+
self._method_download_bulk,
|
|
292
|
+
{ApiField.DATASET_ID: video_info.dataset_id, self._entity_ids_str: [video_info.id]},
|
|
293
|
+
)
|
|
294
|
+
return response.json()
|
|
@@ -2466,8 +2466,7 @@ class VideoApi(RemoveableBulkModuleApi):
|
|
|
2466
2466
|
save_path = os.path.join("/path/to/save/", video_info.name)
|
|
2467
2467
|
|
|
2468
2468
|
semaphore = asyncio.Semaphore(100)
|
|
2469
|
-
loop =
|
|
2470
|
-
asyncio.set_event_loop(loop)
|
|
2469
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
2471
2470
|
loop.run_until_complete(
|
|
2472
2471
|
api.video.download_path_async(video_info.id, save_path, semaphore)
|
|
2473
2472
|
)
|
|
@@ -2555,8 +2554,7 @@ class VideoApi(RemoveableBulkModuleApi):
|
|
|
2555
2554
|
|
|
2556
2555
|
ids = [770914, 770915]
|
|
2557
2556
|
paths = ["/path/to/save/video1.mp4", "/path/to/save/video2.mp4"]
|
|
2558
|
-
loop =
|
|
2559
|
-
asyncio.set_event_loop(loop)
|
|
2557
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
2560
2558
|
loop.run_until_complete(api.video.download_paths_async(ids, paths))
|
|
2561
2559
|
"""
|
|
2562
2560
|
if len(ids) == 0:
|
|
@@ -1343,8 +1343,7 @@ class VolumeApi(RemoveableBulkModuleApi):
|
|
|
1343
1343
|
save_path = os.path.join("/path/to/save/", volume_info.name)
|
|
1344
1344
|
|
|
1345
1345
|
semaphore = asyncio.Semaphore(100)
|
|
1346
|
-
loop =
|
|
1347
|
-
asyncio.set_event_loop(loop)
|
|
1346
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
1348
1347
|
loop.run_until_complete(
|
|
1349
1348
|
api.volume.download_path_async(volume_info.id, save_path, semaphore)
|
|
1350
1349
|
)
|
|
@@ -1433,8 +1432,7 @@ class VolumeApi(RemoveableBulkModuleApi):
|
|
|
1433
1432
|
|
|
1434
1433
|
ids = [770914, 770915]
|
|
1435
1434
|
paths = ["/path/to/save/volume1.nrrd", "/path/to/save/volume2.nrrd"]
|
|
1436
|
-
loop =
|
|
1437
|
-
asyncio.set_event_loop(loop)
|
|
1435
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
1438
1436
|
loop.run_until_complete(api.volume.download_paths_async(ids, paths))
|
|
1439
1437
|
"""
|
|
1440
1438
|
if len(ids) == 0:
|
|
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
|
|
6
6
|
|
|
7
7
|
from tqdm import tqdm
|
|
8
8
|
|
|
9
|
-
from supervisely._utils import batched, is_production
|
|
9
|
+
from supervisely._utils import batched, get_or_create_event_loop, is_production
|
|
10
10
|
from supervisely.annotation.annotation import Annotation
|
|
11
11
|
from supervisely.annotation.tag_meta import TagValueType
|
|
12
12
|
from supervisely.api.api import Api
|
|
@@ -468,7 +468,7 @@ class BaseConverter:
|
|
|
468
468
|
for remote_path in files.values()
|
|
469
469
|
)
|
|
470
470
|
|
|
471
|
-
loop =
|
|
471
|
+
loop = get_or_create_event_loop()
|
|
472
472
|
_, progress_cb = self.get_progress(
|
|
473
473
|
len(files) if not is_archive_type else file_size,
|
|
474
474
|
f"Downloading {files_type} from remote storage",
|
|
@@ -479,15 +479,19 @@ class BaseConverter:
|
|
|
479
479
|
silent_remove(local_path)
|
|
480
480
|
|
|
481
481
|
logger.info(f"Downloading {files_type} from remote storage...")
|
|
482
|
-
|
|
483
|
-
self.
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
progress_cb_type="number" if not is_archive_type else "size",
|
|
489
|
-
)
|
|
482
|
+
download_coro = self._api.storage.download_bulk_async(
|
|
483
|
+
team_id=self._team_id,
|
|
484
|
+
remote_paths=list(files.values()),
|
|
485
|
+
local_save_paths=list(files.keys()),
|
|
486
|
+
progress_cb=progress_cb,
|
|
487
|
+
progress_cb_type="number" if not is_archive_type else "size",
|
|
490
488
|
)
|
|
489
|
+
|
|
490
|
+
if loop.is_running():
|
|
491
|
+
future = asyncio.run_coroutine_threadsafe(download_coro, loop=loop)
|
|
492
|
+
future.result()
|
|
493
|
+
else:
|
|
494
|
+
loop.run_until_complete(download_coro)
|
|
491
495
|
logger.info("Possible annotations downloaded successfully.")
|
|
492
496
|
|
|
493
497
|
if is_archive_type:
|
supervisely/io/fs.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# coding: utf-8
|
|
2
2
|
|
|
3
3
|
# docs
|
|
4
|
-
import asyncio
|
|
5
4
|
import errno
|
|
6
5
|
import mimetypes
|
|
7
6
|
import os
|
|
@@ -16,7 +15,7 @@ import requests
|
|
|
16
15
|
from requests.structures import CaseInsensitiveDict
|
|
17
16
|
from tqdm import tqdm
|
|
18
17
|
|
|
19
|
-
from supervisely._utils import get_bytes_hash, get_string_hash
|
|
18
|
+
from supervisely._utils import get_bytes_hash, get_or_create_event_loop, get_string_hash
|
|
20
19
|
from supervisely.io.fs_cache import FileCache
|
|
21
20
|
from supervisely.sly_logger import logger
|
|
22
21
|
from supervisely.task.progress import Progress
|
|
@@ -1375,8 +1374,15 @@ async def copy_file_async(
|
|
|
1375
1374
|
|
|
1376
1375
|
.. code-block:: python
|
|
1377
1376
|
|
|
1378
|
-
|
|
1379
|
-
|
|
1377
|
+
import supervisely as sly
|
|
1378
|
+
|
|
1379
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
1380
|
+
coro = sly.fs.copy_file_async('/home/admin/work/projects/example/1.png', '/home/admin/work/tests/2.png')
|
|
1381
|
+
if loop.is_running():
|
|
1382
|
+
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
|
1383
|
+
future.result()
|
|
1384
|
+
else:
|
|
1385
|
+
loop.run_until_complete(coro)
|
|
1380
1386
|
"""
|
|
1381
1387
|
ensure_base_path(dst)
|
|
1382
1388
|
async with aiofiles.open(dst, "wb") as out_f:
|
|
@@ -1404,8 +1410,15 @@ async def get_file_hash_async(path: str) -> str:
|
|
|
1404
1410
|
|
|
1405
1411
|
.. code-block:: python
|
|
1406
1412
|
|
|
1407
|
-
|
|
1408
|
-
|
|
1413
|
+
import supervisely as sly
|
|
1414
|
+
|
|
1415
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
1416
|
+
coro = sly.fs.get_file_hash_async('/home/admin/work/projects/examples/1.jpeg')
|
|
1417
|
+
if loop.is_running():
|
|
1418
|
+
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
|
1419
|
+
hash = future.result()
|
|
1420
|
+
else:
|
|
1421
|
+
hash = loop.run_until_complete(coro)
|
|
1409
1422
|
"""
|
|
1410
1423
|
async with aiofiles.open(path, "rb") as file:
|
|
1411
1424
|
file_bytes = await file.read()
|
|
@@ -1442,7 +1455,13 @@ async def unpack_archive_async(
|
|
|
1442
1455
|
archive_path = '/home/admin/work/examples.tar'
|
|
1443
1456
|
target_dir = '/home/admin/work/projects'
|
|
1444
1457
|
|
|
1445
|
-
|
|
1458
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
1459
|
+
coro = sly.fs.unpack_archive_async(archive_path, target_dir)
|
|
1460
|
+
if loop.is_running():
|
|
1461
|
+
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
|
1462
|
+
future.result()
|
|
1463
|
+
else:
|
|
1464
|
+
loop.run_until_complete(coro)
|
|
1446
1465
|
"""
|
|
1447
1466
|
if is_split:
|
|
1448
1467
|
chunk = chunk_size_mb * 1024 * 1024
|
|
@@ -1467,9 +1486,37 @@ async def unpack_archive_async(
|
|
|
1467
1486
|
await output_file.write(data)
|
|
1468
1487
|
archive_path = combined
|
|
1469
1488
|
|
|
1470
|
-
loop =
|
|
1489
|
+
loop = get_or_create_event_loop()
|
|
1471
1490
|
await loop.run_in_executor(None, shutil.unpack_archive, archive_path, target_dir)
|
|
1472
1491
|
if is_split:
|
|
1473
1492
|
silent_remove(archive_path)
|
|
1474
1493
|
if remove_junk:
|
|
1475
1494
|
remove_junk_from_dir(target_dir)
|
|
1495
|
+
|
|
1496
|
+
|
|
1497
|
+
async def touch_async(path: str) -> None:
|
|
1498
|
+
"""
|
|
1499
|
+
Sets access and modification times for a file asynchronously.
|
|
1500
|
+
|
|
1501
|
+
:param path: Target file path.
|
|
1502
|
+
:type path: str
|
|
1503
|
+
:returns: None
|
|
1504
|
+
:rtype: :class:`NoneType`
|
|
1505
|
+
:Usage example:
|
|
1506
|
+
|
|
1507
|
+
.. code-block:: python
|
|
1508
|
+
|
|
1509
|
+
import supervisely as sly
|
|
1510
|
+
|
|
1511
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
1512
|
+
coro = sly.fs.touch_async('/home/admin/work/projects/examples/1.jpeg')
|
|
1513
|
+
if loop.is_running():
|
|
1514
|
+
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
|
1515
|
+
future.result()
|
|
1516
|
+
else:
|
|
1517
|
+
loop.run_until_complete(coro)
|
|
1518
|
+
"""
|
|
1519
|
+
ensure_base_path(path)
|
|
1520
|
+
async with aiofiles.open(path, "a"):
|
|
1521
|
+
loop = get_or_create_event_loop()
|
|
1522
|
+
await loop.run_in_executor(None, os.utime, path, None)
|
supervisely/io/json.py
CHANGED
|
@@ -3,6 +3,7 @@ import json
|
|
|
3
3
|
import os
|
|
4
4
|
from typing import Dict, Optional
|
|
5
5
|
|
|
6
|
+
import aiofiles
|
|
6
7
|
import jsonschema
|
|
7
8
|
|
|
8
9
|
|
|
@@ -230,3 +231,34 @@ def validate_json(data: Dict, schema: Dict, raise_error: bool = False) -> bool:
|
|
|
230
231
|
if raise_error:
|
|
231
232
|
raise ValueError("JSON data is invalid. See error message for more details.") from err
|
|
232
233
|
return False
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
async def dump_json_file_async(data: Dict, filename: str, indent: Optional[int] = 4) -> None:
|
|
237
|
+
"""
|
|
238
|
+
Write given data in json format in file with given name asynchronously.
|
|
239
|
+
|
|
240
|
+
:param data: Data in json format as a dict.
|
|
241
|
+
:type data: dict
|
|
242
|
+
:param filename: Target file path to write data.
|
|
243
|
+
:type filename: str
|
|
244
|
+
:param indent: Json array elements and object members will be pretty-printed with that indent level.
|
|
245
|
+
:type indent: int, optional
|
|
246
|
+
:returns: None
|
|
247
|
+
:rtype: :class:`NoneType`
|
|
248
|
+
:Usage example:
|
|
249
|
+
|
|
250
|
+
.. code-block:: python
|
|
251
|
+
|
|
252
|
+
import supervisely as sly
|
|
253
|
+
|
|
254
|
+
data = {1: 'example'}
|
|
255
|
+
loop = sly.utils.get_or_create_event_loop()
|
|
256
|
+
coro = sly.json.dump_json_file_async(data, '/home/admin/work/projects/examples/1.json')
|
|
257
|
+
if loop.is_running():
|
|
258
|
+
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
|
259
|
+
future.result()
|
|
260
|
+
else:
|
|
261
|
+
loop.run_until_complete(coro)
|
|
262
|
+
"""
|
|
263
|
+
async with aiofiles.open(filename, "w") as fout:
|
|
264
|
+
await fout.write(json.dumps(data, indent=indent))
|
|
@@ -574,10 +574,15 @@ class Inference:
|
|
|
574
574
|
for prediction in predictions:
|
|
575
575
|
if (
|
|
576
576
|
not classes_whitelist in (None, "all")
|
|
577
|
+
and hasattr(prediction, "class_name")
|
|
577
578
|
and prediction.class_name not in classes_whitelist
|
|
578
579
|
):
|
|
579
580
|
continue
|
|
580
|
-
|
|
581
|
+
if "classes_whitelist" in inspect.signature(self._create_label).parameters:
|
|
582
|
+
# pylint: disable=unexpected-keyword-arg
|
|
583
|
+
label = self._create_label(prediction, classes_whitelist) # pylint: disable=too-many-function-args
|
|
584
|
+
else:
|
|
585
|
+
label = self._create_label(prediction)
|
|
581
586
|
if label is None:
|
|
582
587
|
# for example empty mask
|
|
583
588
|
continue
|
|
@@ -2543,17 +2548,52 @@ def clean_up_cuda():
|
|
|
2543
2548
|
logger.debug("Error in clean_up_cuda.", exc_info=True)
|
|
2544
2549
|
|
|
2545
2550
|
|
|
2551
|
+
def _fix_classes_names(meta: ProjectMeta, ann: Annotation):
|
|
2552
|
+
def _replace_strip(s, chars: str, replacement: str = "_") -> str:
|
|
2553
|
+
replace_pattern = f"^[{re.escape(chars)}]+|[{re.escape(chars)}]+$"
|
|
2554
|
+
return re.sub(replace_pattern, replacement, s)
|
|
2555
|
+
|
|
2556
|
+
replaced_classes_in_meta = []
|
|
2557
|
+
for obj_class in meta.obj_classes:
|
|
2558
|
+
obj_class_name = _replace_strip(obj_class.name, " ", "")
|
|
2559
|
+
if obj_class_name != obj_class.name:
|
|
2560
|
+
new_obj_class = obj_class.clone(name=obj_class_name)
|
|
2561
|
+
meta = meta.delete_obj_class(obj_class.name)
|
|
2562
|
+
meta = meta.add_obj_class(new_obj_class)
|
|
2563
|
+
replaced_classes_in_meta.append((obj_class.name, obj_class_name))
|
|
2564
|
+
replaced_classes_in_ann = set()
|
|
2565
|
+
new_labels = []
|
|
2566
|
+
for label in ann.labels:
|
|
2567
|
+
obj_class = label.obj_class
|
|
2568
|
+
obj_class_name = _replace_strip(obj_class.name, " ", "")
|
|
2569
|
+
if obj_class_name != obj_class.name:
|
|
2570
|
+
new_obj_class = obj_class.clone(name=obj_class_name)
|
|
2571
|
+
label = label.clone(obj_class=new_obj_class)
|
|
2572
|
+
replaced_classes_in_ann.add((obj_class.name, obj_class_name))
|
|
2573
|
+
new_labels.append(label)
|
|
2574
|
+
ann = ann.clone(labels=new_labels)
|
|
2575
|
+
return meta, ann, replaced_classes_in_meta, list(replaced_classes_in_ann)
|
|
2576
|
+
|
|
2577
|
+
|
|
2546
2578
|
def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
|
|
2547
2579
|
"""Update project meta and annotation to match each other
|
|
2548
2580
|
If obj class or tag meta from annotation conflicts with project meta
|
|
2549
2581
|
add suffix to obj class or tag meta.
|
|
2550
2582
|
Return tuple of updated project meta, annotation and boolean flag if meta was changed."""
|
|
2551
|
-
obj_classes_suffixes =
|
|
2552
|
-
tag_meta_suffixes =
|
|
2583
|
+
obj_classes_suffixes = ["_nn"]
|
|
2584
|
+
tag_meta_suffixes = ["_nn"]
|
|
2553
2585
|
ann_obj_classes = {}
|
|
2554
2586
|
ann_tag_metas = {}
|
|
2555
2587
|
meta_changed = False
|
|
2556
2588
|
|
|
2589
|
+
meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
|
|
2590
|
+
if replaced_classes_in_meta:
|
|
2591
|
+
meta_changed = True
|
|
2592
|
+
logger.warning(
|
|
2593
|
+
"Some classes names were fixed in project meta",
|
|
2594
|
+
extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
|
|
2595
|
+
)
|
|
2596
|
+
|
|
2557
2597
|
# get all obj classes and tag metas from annotation
|
|
2558
2598
|
for label in ann.labels:
|
|
2559
2599
|
ann_obj_classes[label.obj_class.name] = label.obj_class
|
|
@@ -2563,7 +2603,8 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
|
|
|
2563
2603
|
ann_tag_metas[tag.meta.name] = tag.meta
|
|
2564
2604
|
|
|
2565
2605
|
# check if obj classes are in project meta
|
|
2566
|
-
# if not, add them
|
|
2606
|
+
# if not, add them.
|
|
2607
|
+
# if shape is different, add them with suffix
|
|
2567
2608
|
changed_obj_classes = {}
|
|
2568
2609
|
for ann_obj_class in ann_obj_classes.values():
|
|
2569
2610
|
if meta.get_obj_class(ann_obj_class.name) is None:
|
|
@@ -3,9 +3,11 @@ from typing import Any, Dict, List
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
5
|
from supervisely.annotation.label import Label
|
|
6
|
+
from supervisely.annotation.obj_class import ObjClass
|
|
6
7
|
from supervisely.geometry.bitmap import Bitmap
|
|
7
8
|
from supervisely.nn.inference.inference import Inference
|
|
8
9
|
from supervisely.nn.prediction_dto import PredictionSegmentation
|
|
10
|
+
from supervisely.project.project_meta import ProjectMeta
|
|
9
11
|
from supervisely.sly_logger import logger
|
|
10
12
|
|
|
11
13
|
|
|
@@ -24,20 +26,48 @@ class SemanticSegmentation(Inference):
|
|
|
24
26
|
def _get_obj_class_shape(self):
|
|
25
27
|
return Bitmap
|
|
26
28
|
|
|
27
|
-
def
|
|
28
|
-
|
|
29
|
+
def _find_bg_class_index(self, class_names: List[str]):
|
|
30
|
+
possible_bg_names = ["background", "bg", "unlabeled", "neutral", "__bg__"]
|
|
31
|
+
bg_class_index = None
|
|
32
|
+
for i, name in enumerate(class_names):
|
|
33
|
+
if name in possible_bg_names:
|
|
34
|
+
bg_class_index = i
|
|
35
|
+
break
|
|
36
|
+
return bg_class_index
|
|
37
|
+
|
|
38
|
+
def _add_default_bg_class(self, meta: ProjectMeta):
|
|
39
|
+
default_bg_class_name = "__bg__"
|
|
40
|
+
obj_class = meta.get_obj_class(default_bg_class_name)
|
|
41
|
+
if obj_class is None:
|
|
42
|
+
obj_class = ObjClass(default_bg_class_name, self._get_obj_class_shape())
|
|
43
|
+
meta = meta.add_obj_class(obj_class)
|
|
44
|
+
return meta, obj_class
|
|
45
|
+
|
|
46
|
+
def _get_or_create_bg_obj_class(self, classes):
|
|
47
|
+
bg_class_index = self._find_bg_class_index(classes)
|
|
48
|
+
if bg_class_index is None:
|
|
49
|
+
self._model_meta, bg_obj_class = self._add_default_bg_class(self.model_meta)
|
|
50
|
+
else:
|
|
51
|
+
bg_class_name = classes[bg_class_index]
|
|
52
|
+
bg_obj_class = self.model_meta.get_obj_class(bg_class_name)
|
|
53
|
+
return bg_obj_class
|
|
54
|
+
|
|
55
|
+
def _create_label(self, dto: PredictionSegmentation, classes_whitelist: List[str] = None):
|
|
56
|
+
classes = self.get_classes()
|
|
57
|
+
|
|
58
|
+
image_classes_indexes = np.unique(dto.mask)
|
|
29
59
|
labels = []
|
|
30
|
-
for class_idx in
|
|
60
|
+
for class_idx in image_classes_indexes:
|
|
31
61
|
class_mask = dto.mask == class_idx
|
|
32
|
-
class_name =
|
|
33
|
-
|
|
62
|
+
class_name = classes[class_idx]
|
|
63
|
+
if classes_whitelist not in (None, "all") and class_name not in classes_whitelist:
|
|
64
|
+
obj_class = self._get_or_create_bg_obj_class(classes)
|
|
65
|
+
else:
|
|
66
|
+
obj_class = self.model_meta.get_obj_class(class_name)
|
|
34
67
|
if obj_class is None:
|
|
35
68
|
raise KeyError(
|
|
36
69
|
f"Class {class_name} not found in model classes {self.get_classes()}"
|
|
37
70
|
)
|
|
38
|
-
if not class_mask.any(): # skip empty masks
|
|
39
|
-
logger.debug(f"Mask of class {class_name} is empty and will be sklipped")
|
|
40
|
-
return None
|
|
41
71
|
geometry = Bitmap(class_mask, extra_validation=False)
|
|
42
72
|
label = Label(geometry, obj_class)
|
|
43
73
|
labels.append(label)
|