supervisely 6.73.243__py3-none-any.whl → 6.73.244__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +1 -1
- supervisely/_utils.py +18 -0
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/card/card.py +3 -0
- supervisely/app/widgets/classes_table/classes_table.py +15 -1
- supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
- supervisely/app/widgets/custom_models_selector/template.html +1 -1
- supervisely/app/widgets/experiment_selector/__init__.py +0 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
- supervisely/app/widgets/experiment_selector/style.css +27 -0
- supervisely/app/widgets/experiment_selector/template.html +82 -0
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
- supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
- supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/__init__.py +3 -1
- supervisely/nn/artifacts/artifacts.py +10 -0
- supervisely/nn/artifacts/detectron2.py +2 -0
- supervisely/nn/artifacts/hrda.py +3 -0
- supervisely/nn/artifacts/mmclassification.py +2 -0
- supervisely/nn/artifacts/mmdetection.py +6 -3
- supervisely/nn/artifacts/mmsegmentation.py +2 -0
- supervisely/nn/artifacts/ritm.py +3 -1
- supervisely/nn/artifacts/rtdetr.py +2 -0
- supervisely/nn/artifacts/unet.py +2 -0
- supervisely/nn/artifacts/yolov5.py +3 -0
- supervisely/nn/artifacts/yolov8.py +7 -1
- supervisely/nn/experiments.py +113 -0
- supervisely/nn/inference/gui/__init__.py +3 -1
- supervisely/nn/inference/gui/gui.py +31 -232
- supervisely/nn/inference/gui/serving_gui.py +223 -0
- supervisely/nn/inference/gui/serving_gui_template.py +240 -0
- supervisely/nn/inference/inference.py +225 -24
- supervisely/nn/training/__init__.py +0 -0
- supervisely/nn/training/gui/__init__.py +1 -0
- supervisely/nn/training/gui/classes_selector.py +100 -0
- supervisely/nn/training/gui/gui.py +539 -0
- supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
- supervisely/nn/training/gui/input_selector.py +70 -0
- supervisely/nn/training/gui/model_selector.py +95 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
- supervisely/nn/training/gui/training_logs.py +93 -0
- supervisely/nn/training/gui/training_process.py +114 -0
- supervisely/nn/training/gui/utils.py +128 -0
- supervisely/nn/training/loggers/__init__.py +0 -0
- supervisely/nn/training/loggers/base_train_logger.py +58 -0
- supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
- supervisely/nn/training/train_app.py +2038 -0
- supervisely/nn/utils.py +5 -0
- {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/METADATA +3 -1
- {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/RECORD +56 -34
- {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/LICENSE +0 -0
- {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/WHEEL +0 -0
- {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/top_level.txt +0 -0
supervisely/__init__.py
CHANGED
|
@@ -309,4 +309,4 @@ except Exception as e:
|
|
|
309
309
|
# If new changes in Supervisely Python SDK require upgrade of the Supervisely instance
|
|
310
310
|
# set a new value for the environment variable MINIMUM_INSTANCE_VERSION_FOR_SDK, otherwise
|
|
311
311
|
# users can face compatibility issues, if the instance version is lower than the SDK version.
|
|
312
|
-
os.environ["MINIMUM_INSTANCE_VERSION_FOR_SDK"] = "6.12.
|
|
312
|
+
os.environ["MINIMUM_INSTANCE_VERSION_FOR_SDK"] = "6.12.12"
|
supervisely/_utils.py
CHANGED
|
@@ -17,6 +17,7 @@ from tempfile import gettempdir
|
|
|
17
17
|
from typing import Any, Dict, List, Literal, Optional, Tuple
|
|
18
18
|
|
|
19
19
|
import numpy as np
|
|
20
|
+
import requests
|
|
20
21
|
from requests.utils import DEFAULT_CA_BUNDLE_PATH
|
|
21
22
|
|
|
22
23
|
from supervisely.io import env as sly_env
|
|
@@ -459,3 +460,20 @@ def get_or_create_event_loop() -> asyncio.AbstractEventLoop:
|
|
|
459
460
|
loop = asyncio.new_event_loop()
|
|
460
461
|
asyncio.set_event_loop(loop)
|
|
461
462
|
return loop
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def get_filename_from_headers(url):
|
|
466
|
+
try:
|
|
467
|
+
response = requests.head(url, allow_redirects=True)
|
|
468
|
+
if response.status_code >= 400 or "Content-Disposition" not in response.headers:
|
|
469
|
+
response = requests.get(url, stream=True)
|
|
470
|
+
content_disposition = response.headers.get("Content-Disposition")
|
|
471
|
+
if content_disposition:
|
|
472
|
+
filename = re.findall('filename="?([^"]+)"?', content_disposition)
|
|
473
|
+
if filename:
|
|
474
|
+
return filename[0]
|
|
475
|
+
filename = url.split("/")[-1] or "downloaded_file"
|
|
476
|
+
return filename
|
|
477
|
+
except Exception as e:
|
|
478
|
+
print(f"Error retrieving file name from headers: {e}")
|
|
479
|
+
return None
|
|
@@ -147,3 +147,4 @@ from supervisely.app.widgets.tree_select.tree_select import TreeSelect
|
|
|
147
147
|
from supervisely.app.widgets.select_dataset_tree.select_dataset_tree import SelectDatasetTree
|
|
148
148
|
from supervisely.app.widgets.grid_gallery_v2.grid_gallery_v2 import GridGalleryV2
|
|
149
149
|
from supervisely.app.widgets.report_thumbnail.report_thumbnail import ReportThumbnail
|
|
150
|
+
from supervisely.app.widgets.experiment_selector.experiment_selector import ExperimentSelector
|
|
@@ -125,6 +125,9 @@ class Card(Widget):
|
|
|
125
125
|
StateJson()[self.widget_id]["collapsed"] = self._collapsed
|
|
126
126
|
StateJson().send_changes()
|
|
127
127
|
|
|
128
|
+
def is_collapsed(self) -> bool:
    """Report whether the card is collapsed, read from the widget's state entry."""
    card_state = StateJson()[self.widget_id]
    return card_state["collapsed"]
|
|
130
|
+
|
|
128
131
|
def lock(self, message: Optional[str] = None) -> None:
|
|
129
132
|
"""Locks the card, changes the lock message if specified.
|
|
130
133
|
|
|
@@ -283,7 +283,9 @@ class ClassesTable(Widget):
|
|
|
283
283
|
StateJson().send_changes()
|
|
284
284
|
self.loading = False
|
|
285
285
|
|
|
286
|
-
def read_project_from_id(
|
|
286
|
+
def read_project_from_id(
|
|
287
|
+
self, project_id: int, dataset_ids: Optional[List[int]] = None
|
|
288
|
+
) -> None:
|
|
287
289
|
"""Read remote project by id and update table data.
|
|
288
290
|
|
|
289
291
|
:param project_id: Project id from which classes will be taken.
|
|
@@ -469,3 +471,15 @@ class ClassesTable(Widget):
|
|
|
469
471
|
StateJson()[self.widget_id]["global_checkbox"] = self._global_checkbox
|
|
470
472
|
StateJson()[self.widget_id]["checkboxes"] = self._checkboxes
|
|
471
473
|
StateJson().send_changes()
|
|
474
|
+
|
|
475
|
+
def set_dataset_ids(self, dataset_ids: List[int]) -> None:
    """Restrict the table to classes from the given datasets.

    The current class selection is captured first. The table meta and data are then
    rebuilt for the new dataset ids, and the captured selection is re-applied.

    :param dataset_ids: List of dataset ids to filter classes.
    :type dataset_ids: List[int]
    """
    # Capture the selection before the rebuild, which replaces the table contents.
    previously_selected = self.get_selected_classes()

    self._dataset_ids = dataset_ids
    self._update_meta(self._project_meta)
    self.update_data()

    self.select_classes(previously_selected)
|
|
@@ -209,11 +209,15 @@ class CustomModelsSelector(Widget):
|
|
|
209
209
|
for checkpoint_info in self._checkpoints:
|
|
210
210
|
if isinstance(checkpoint_info, dict):
|
|
211
211
|
checkpoint_selector_items.append(
|
|
212
|
-
Select.Item(
|
|
212
|
+
Select.Item(
|
|
213
|
+
value=checkpoint_info["path"], label=checkpoint_info["name"]
|
|
214
|
+
)
|
|
213
215
|
)
|
|
214
216
|
elif isinstance(checkpoint_info, FileInfo):
|
|
215
217
|
checkpoint_selector_items.append(
|
|
216
|
-
Select.Item(
|
|
218
|
+
Select.Item(
|
|
219
|
+
value=checkpoint_info.path, label=checkpoint_info.name
|
|
220
|
+
)
|
|
217
221
|
)
|
|
218
222
|
|
|
219
223
|
checkpoint_selector = Select(items=checkpoint_selector_items)
|
|
@@ -278,7 +282,9 @@ class CustomModelsSelector(Widget):
|
|
|
278
282
|
)
|
|
279
283
|
|
|
280
284
|
file_api = FileApi(self._api)
|
|
281
|
-
self._model_path_input = Input(
|
|
285
|
+
self._model_path_input = Input(
|
|
286
|
+
placeholder="Path to model file in Team Files"
|
|
287
|
+
)
|
|
282
288
|
|
|
283
289
|
@self._model_path_input.value_changed
|
|
284
290
|
def change_folder(value):
|
|
@@ -316,7 +322,9 @@ class CustomModelsSelector(Widget):
|
|
|
316
322
|
|
|
317
323
|
self.custom_tab_widgets.hide()
|
|
318
324
|
|
|
319
|
-
self.show_custom_checkpoint_path_checkbox = Checkbox(
|
|
325
|
+
self.show_custom_checkpoint_path_checkbox = Checkbox(
|
|
326
|
+
"Use custom checkpoint", False
|
|
327
|
+
)
|
|
320
328
|
|
|
321
329
|
@self.show_custom_checkpoint_path_checkbox.value_changed
|
|
322
330
|
def show_custom_checkpoint_path_checkbox_changed(is_checked):
|
|
@@ -391,7 +399,9 @@ class CustomModelsSelector(Widget):
|
|
|
391
399
|
self.disable_table()
|
|
392
400
|
super().disable()
|
|
393
401
|
|
|
394
|
-
def _generate_table_rows(
|
|
402
|
+
def _generate_table_rows(
|
|
403
|
+
self, train_infos: List[TrainInfo]
|
|
404
|
+
) -> Dict[str, List[ModelRow]]:
|
|
395
405
|
"""Method to generate table rows from remote path to training app save directory"""
|
|
396
406
|
|
|
397
407
|
def process_train_info(train_info):
|
|
@@ -438,7 +448,8 @@ class CustomModelsSelector(Widget):
|
|
|
438
448
|
if "pose estimation" in task_types:
|
|
439
449
|
sorted_tt.append("pose estimation")
|
|
440
450
|
other_tasks = sorted(
|
|
441
|
-
set(task_types)
|
|
451
|
+
set(task_types)
|
|
452
|
+
- set(["object detection", "instance segmentation", "pose estimation"])
|
|
442
453
|
)
|
|
443
454
|
sorted_tt.extend(other_tasks)
|
|
444
455
|
return sorted_tt
|
|
@@ -484,11 +495,16 @@ class CustomModelsSelector(Widget):
|
|
|
484
495
|
"checkpoint_url": checkpoint_url,
|
|
485
496
|
}
|
|
486
497
|
|
|
498
|
+
# if model_name is not None:
|
|
499
|
+
# model_params["model_name"] = model_name
|
|
500
|
+
|
|
487
501
|
if config_path is not None:
|
|
488
502
|
model_params["config_url"] = config_path
|
|
489
503
|
|
|
490
504
|
return model_params
|
|
491
505
|
|
|
506
|
+
# def get_selected_model_params_v2(self) -> Union[Dict, None]:
|
|
507
|
+
|
|
492
508
|
def set_active_row(self, row_index: int) -> None:
|
|
493
509
|
if row_index < 0 or row_index > len(self._rows) - 1:
|
|
494
510
|
raise ValueError(f'Row with index "{row_index}" does not exist')
|
|
@@ -520,7 +536,9 @@ class CustomModelsSelector(Widget):
|
|
|
520
536
|
|
|
521
537
|
def set_custom_checkpoint_task_type(self, task_type: str) -> None:
|
|
522
538
|
if self.use_custom_checkpoint_path():
|
|
523
|
-
available_task_types =
|
|
539
|
+
available_task_types = (
|
|
540
|
+
self.custom_checkpoint_task_type_selector.get_labels()
|
|
541
|
+
)
|
|
524
542
|
if task_type not in available_task_types:
|
|
525
543
|
raise ValueError(f'"{task_type}" is not available task type')
|
|
526
544
|
self.custom_checkpoint_task_type_selector.set_value(task_type)
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
{% endif %}
|
|
7
7
|
>
|
|
8
8
|
|
|
9
|
-
<div v-if="data.{{{widget.widget_id}}}.rowsHtml.length === 0"> You don't have any custom models</div>
|
|
9
|
+
<div v-if="Object.keys(data.{{{widget.widget_id}}}.rowsHtml).length === 0"> You don't have any custom models</div>
|
|
10
10
|
<div v-else>
|
|
11
11
|
|
|
12
12
|
<div v-if="data.{{{widget.widget_id}}}.taskTypes.length > 1">
|
|
File without changes
|
|
@@ -0,0 +1,500 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
4
|
+
from typing import Any, Callable, Dict, List, Union
|
|
5
|
+
|
|
6
|
+
from supervisely import env, logger
|
|
7
|
+
from supervisely._utils import abs_url, is_development
|
|
8
|
+
from supervisely.api.api import Api
|
|
9
|
+
from supervisely.api.project_api import ProjectInfo
|
|
10
|
+
from supervisely.app.content import DataJson, StateJson
|
|
11
|
+
from supervisely.app.widgets import (
|
|
12
|
+
Container,
|
|
13
|
+
Flexbox,
|
|
14
|
+
ProjectThumbnail,
|
|
15
|
+
Select,
|
|
16
|
+
Text,
|
|
17
|
+
Widget,
|
|
18
|
+
)
|
|
19
|
+
from supervisely.io.fs import get_file_name_with_ext
|
|
20
|
+
from supervisely.nn.experiments import ExperimentInfo
|
|
21
|
+
|
|
22
|
+
WEIGHTS_DIR = "weights"

# Table column headers, shown upper-cased in the experiments table.
COL_ID, COL_MODEL, COL_PROJECT, COL_CHECKPOINTS, COL_SESSION, COL_BENCHMARK = (
    title.upper()
    for title in ("task id", "model", "training data", "checkpoints", "session", "benchmark")
)

# Column order used by ExperimentSelector.
columns = [COL_ID, COL_MODEL, COL_PROJECT, COL_CHECKPOINTS, COL_SESSION, COL_BENCHMARK]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ExperimentSelector(Widget):
    """Table widget for choosing a training experiment and one of its checkpoints.

    Rows are grouped by ``ExperimentInfo.task_type``. The widget state stores the
    selected task type (``selectedTaskType``) and a row index (``selectedRow``) that
    is relative to the rows of that task type.
    """

    class Routes:
        TASK_TYPE_CHANGED = "task_type_changed"
        VALUE_CHANGED = "value_changed"

    class ModelRow:
        """One table row built from a single ``ExperimentInfo``.

        Construction performs API calls: a Team Files lookup for ``open_app.lnk`` and
        a project info request.
        """

        def __init__(
            self,
            api: Api,
            team_id: int,
            task_type: str,
            experiment_info: ExperimentInfo,
        ):
            self._api = api
            self._team_id = team_id
            self._task_type = task_type
            self._experiment_info = experiment_info

            task_id = experiment_info.task_id
            # "debug-session" is kept as a string (ExperimentSelector shows such a row
            # only in development mode); numeric strings are converted to int.
            if task_id != "debug-session" and isinstance(task_id, str):
                if not task_id.isdigit():
                    raise ValueError(f"Task id {task_id} is not a number")
                task_id = int(task_id)

            # col 1 task
            self._task_id = task_id
            self._task_path = experiment_info.artifacts_dir
            self._task_date = experiment_info.datetime
            self._task_link = self._create_task_link()
            self._config_path = experiment_info.model_files.get("config")
            if self._config_path is not None:
                self._config_path = os.path.join(experiment_info.artifacts_dir, self._config_path)

            # col 2 model
            self._model_name = experiment_info.model_name

            # col 3 project
            self._training_project_id = experiment_info.project_id
            self._training_project_info = self._api.project.get_info_by_id(
                self._training_project_id
            )

            # col 4 checkpoints (paths in ExperimentInfo are relative to artifacts_dir)
            self._checkpoints = experiment_info.checkpoints
            self._checkpoints_names = [
                get_file_name_with_ext(checkpoint) for checkpoint in self._checkpoints
            ]
            self._checkpoints_paths = [
                os.path.join(experiment_info.artifacts_dir, checkpoint)
                for checkpoint in self._checkpoints
            ]

            # col 5 session
            self._session_link = self._generate_session_link()

            # col 6 benchmark report (not provided by ExperimentInfo yet)
            self._benchmark_report = None

            # widgets
            self._task_widget = self._create_task_widget()
            self._model_widget = self._create_model_widget()
            self._training_project_widget = self._create_training_project_widget()
            self._checkpoints_widget = self._create_checkpoints_widget()
            self._session_widget = self._create_session_widget()
            self._benchmark_widget = self._create_benchmark_widget()

        @property
        def task_id(self) -> int:
            return self._task_id

        @property
        def task_date(self) -> str:
            return self._task_date

        @property
        def task_link(self) -> str:
            return self._task_link

        @property
        def task_type(self) -> str:
            return self._task_type

        @property
        def training_project_info(self) -> ProjectInfo:
            return self._training_project_info

        @property
        def checkpoints_names(self) -> List[str]:
            return self._checkpoints_names

        @property
        def checkpoints_paths(self) -> List[str]:
            return self._checkpoints_paths

        @property
        def checkpoints_selector(self) -> Select:
            return self._checkpoints_widget

        @property
        def session_link(self) -> str:
            return self._session_link

        @property
        def config_path(self) -> str:
            return self._config_path

        def get_selected_checkpoint_path(self) -> str:
            return self._checkpoints_widget.get_value()

        def get_selected_checkpoint_name(self) -> str:
            return self._checkpoints_widget.get_label()

        def set_selected_checkpoint_by_name(self, checkpoint_name: str):
            """Select the first checkpoint whose file name equals *checkpoint_name* (no-op if absent)."""
            for name, path in zip(self._checkpoints_names, self._checkpoints_paths):
                if name == checkpoint_name:
                    self._checkpoints_widget.set_value(path)
                    return

        def set_selected_checkpoint_by_path(self, checkpoint_path: str):
            """Select the checkpoint with the given full Team Files path (no-op if absent)."""
            if checkpoint_path in self._checkpoints_paths:
                self._checkpoints_widget.set_value(checkpoint_path)

        def to_html(self) -> List[str]:
            return [
                f"<div> {self._task_widget.to_html()} </div>",
                f"<div> {self._model_widget.to_html()} </div>",
                f"<div> {self._training_project_widget.to_html()} </div>",
                f"<div> {self._checkpoints_widget.to_html()} </div>",
                f"<div> {self._session_widget.to_html()} </div>",
                f"<div> {self._benchmark_widget.to_html()} </div>",
            ]

        def _create_task_link(self) -> str:
            """Return a link to ``open_app.lnk`` in the artifacts dir, or "" if it is missing."""
            remote_path = os.path.join(self._task_path, "open_app.lnk")
            task_file = self._api.file.get_info_by_path(self._team_id, remote_path)
            if task_file is None:
                return ""
            if is_development():
                return abs_url(f"/files/{task_file.id}")
            return f"/files/{task_file.id}"

        def _generate_session_link(self) -> str:
            if is_development():
                return abs_url(f"/apps/sessions/{self._task_id}")
            return f"/apps/sessions/{self._task_id}"

        def _create_task_widget(self) -> Container:
            return Container(
                [
                    Text(
                        f"<i class='zmdi zmdi-folder' style='color: #7f858e'></i> <a href='{self._task_link}'>{self._task_id}</a>",
                        "text",
                    ),
                    Text(
                        f"<span class='field-description text-muted' style='color: #7f858e'>{self._task_date}</span>",
                        "text",
                        font_size=13,
                    ),
                ],
                gap=0,
            )

        def _create_model_widget(self) -> Text:
            if self._model_name is None:
                self._model_name = "Unknown model"
            return Text(
                f"<span class='field-description text-muted' style='color: #7f858e'>{self._model_name}</span>",
                "text",
                font_size=13,
            )

        def _create_training_project_widget(self) -> Union[ProjectThumbnail, Text]:
            if self._training_project_info is not None:
                return ProjectThumbnail(self._training_project_info, remove_margins=True)
            return Text(
                "<span class='field-description text-muted' style='color: #7f858e'>Project was deleted</span>",
                "text",
                font_size=13,
            )

        def _create_checkpoints_widget(self) -> Select:
            items = [
                Select.Item(value=path, label=name)
                for path, name in zip(self._checkpoints_paths, self._checkpoints_names)
            ]
            return Select(items=items)

        def _create_session_widget(self) -> Text:
            return Text(
                f"<a href='{self._session_link}'>Preview</a> <i class='zmdi zmdi-open-in-new'></i>",
                "text",
            )

        def _create_benchmark_widget(self) -> Text:
            if self._benchmark_report is None:
                return Text(
                    "<span class='field-description text-muted' style='color: #7f858e'>No benchmark report available</span>",
                    "text",
                    font_size=13,
                )
            return Text(
                f"<a href='{self._benchmark_report}'>Benchmark Report</a> <i class='zmdi zmdi-chart'></i>",
                "text",
            )

    def __init__(
        self,
        team_id: int,
        experiment_infos: Union[List[ExperimentInfo], None] = None,
        widget_id: str = None,
    ):
        """
        :param team_id: Team whose Team Files contain the experiment artifacts.
        :param experiment_infos: Experiments to show; ``None`` means no experiments.
        :param widget_id: Optional widget id.
        """
        if experiment_infos is None:
            experiment_infos = []
        self._api = Api.from_env()

        self._team_id = team_id
        self.__debug_row = None

        # _generate_table_rows already builds rows in its own thread pool.
        table_rows = self._generate_table_rows(experiment_infos)

        self._columns = columns
        self._rows = table_rows

        self._rows_html = defaultdict(list)
        for task_type, rows in table_rows.items():
            self._rows_html[task_type].extend(row.to_html() for row in rows)

        self._task_types = self._filter_task_types(list(table_rows))
        self.__default_selected_task_type = self._task_types[0] if self._task_types else None

        self._changes_handled = False
        self._task_type_changes_handled = False
        super().__init__(widget_id=widget_id, file_path=__file__)

    @property
    def columns(self) -> List[str]:
        return self._columns

    @property
    def rows(self) -> Dict[str, List[ModelRow]]:
        return self._rows

    def get_json_data(self) -> Dict:
        return {
            "columns": self._columns,
            "rowsHtml": dict(self._rows_html),
            "taskTypes": self._task_types,
        }

    def get_json_state(self) -> Dict:
        return {
            "selectedRow": 0,
            "selectedTaskType": self.__default_selected_task_type,
        }

    def set_active_task_type(self, task_type: str):
        if task_type not in self._task_types:
            raise ValueError(f'Task Type "{task_type}" does not exist')
        StateJson()[self.widget_id]["selectedTaskType"] = task_type
        StateJson().send_changes()

    def get_available_task_types(self) -> List[str]:
        return self._task_types

    def disable_table(self) -> None:
        for rows in self._rows.values():
            for row in rows:
                row.checkpoints_selector.disable()
        super().disable()

    def enable_table(self) -> None:
        for rows in self._rows.values():
            for row in rows:
                row.checkpoints_selector.enable()
        super().enable()

    def enable(self):
        # enable_table() also enables the widget itself.
        self.enable_table()

    def disable(self) -> None:
        # disable_table() also disables the widget itself.
        self.disable_table()

    def _generate_table_rows(
        self, experiment_infos: List[ExperimentInfo]
    ) -> Dict[str, List[ModelRow]]:
        """Build a ModelRow per experiment in a thread pool, grouped by task type.

        Experiments that fail to convert are logged and skipped. Rows are sorted by
        task id, newest first. A "debug-session" row is kept aside and prepended
        only in development mode.
        """

        def process_experiment_info(experiment_info: ExperimentInfo):
            try:
                model_row = ExperimentSelector.ModelRow(
                    api=self._api,
                    team_id=self._team_id,
                    task_type=experiment_info.task_type,
                    experiment_info=experiment_info,
                )
                return experiment_info.task_type, model_row
            except Exception as e:
                logger.warning(
                    f"Failed to process experiment info: {experiment_info}. Error: {repr(e)}"
                )
                return None, None

        table_rows = defaultdict(list)
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(process_experiment_info, experiment_info)
                for experiment_info in experiment_infos
            ]
            for future in as_completed(futures):
                task_type, model_row = future.result()
                if task_type is None or model_row is None:
                    continue
                if model_row.task_id == "debug-session":
                    self.__debug_row = (task_type, model_row)
                    continue
                table_rows[task_type].append(model_row)

        self._sort_table_rows(table_rows)
        if self.__debug_row and is_development():
            task_type, model_row = self.__debug_row
            table_rows[task_type].insert(0, model_row)
        return table_rows

    def _sort_table_rows(self, table_rows: Dict[str, List[ModelRow]]) -> None:
        for rows in table_rows.values():
            rows.sort(key=lambda row: int(row.task_id), reverse=True)

    def _filter_task_types(self, task_types: List[str]):
        """Order task types: the well-known ones first (fixed order), then the rest alphabetically."""
        priority = [
            "object detection",
            "instance segmentation",
            "semantic segmentation",
            "pose estimation",
        ]
        sorted_tt = [task_type for task_type in priority if task_type in task_types]
        sorted_tt.extend(sorted(set(task_types) - set(priority)))
        return sorted_tt

    def get_selected_row(self, state=None) -> Union[ModelRow, None]:
        """Return the selected row, or ``None`` if nothing is (validly) selected.

        :param state: Widget state mapping to read from; defaults to ``StateJson()``.
        """
        if len(self._rows) == 0:
            return None
        if state is None:
            state = StateJson()
        widget_actual_state = state[self.widget_id]
        widget_actual_data = DataJson()[self.widget_id]
        if widget_actual_state is None or widget_actual_data is None:
            return None
        # .get avoids inserting keys into the defaultdict for unknown task types.
        rows = self._rows.get(widget_actual_state["selectedTaskType"])
        if not rows:
            return None
        selected_row_index = int(widget_actual_state["selectedRow"])
        if not 0 <= selected_row_index < len(rows):
            return None
        return rows[selected_row_index]

    def get_selected_row_index(self, state=None) -> Union[int, None]:
        """Return the selected row index (relative to the selected task type).

        :param state: Widget state mapping to read from; defaults to ``StateJson()``.
        """
        if state is None:
            state = StateJson()
        widget_actual_state = state[self.widget_id]
        widget_actual_data = DataJson()[self.widget_id]
        if widget_actual_state is not None and widget_actual_data is not None:
            return widget_actual_state["selectedRow"]

    def get_selected_task_type(self) -> str:
        return StateJson()[self.widget_id]["selectedTaskType"]

    def get_selected_experiment_info(self) -> Union[Dict[str, Any], None]:
        """Return the selected ExperimentInfo as a dict, or ``None`` if no row is selected."""
        selected_row = self.get_selected_row()
        if selected_row is None:
            return None
        return selected_row._experiment_info._asdict()

    def get_selected_checkpoint_path(self) -> str:
        """Return the full path of the checkpoint chosen in the selected row, or ``None``."""
        selected_row = self.get_selected_row()
        if selected_row is None:
            return None
        return selected_row.get_selected_checkpoint_path()

    def get_model_files(self) -> Dict[str, str]:
        """
        Returns a dictionary with full paths to model files in Supervisely Team Files.

        The ``"checkpoint"`` key holds the checkpoint chosen in the selected row.
        Returns an empty dict when no experiment is selected.
        """
        experiment_info = self.get_selected_experiment_info()
        if experiment_info is None:
            return {}
        artifacts_dir = experiment_info.get("artifacts_dir")
        model_files = experiment_info.get("model_files") or {}

        full_model_files = {
            name: os.path.join(artifacts_dir, file) for name, file in model_files.items()
        }
        full_model_files["checkpoint"] = self.get_selected_checkpoint_path()
        return full_model_files

    def set_active_row(self, row_index: int) -> None:
        """Select a row by index within the currently selected task type.

        :raises ValueError: If the index is out of range for that task type.
        """
        rows = self._rows.get(self.get_selected_task_type(), [])
        if row_index < 0 or row_index > len(rows) - 1:
            raise ValueError(f'Row with index "{row_index}" does not exist')
        StateJson()[self.widget_id]["selectedRow"] = row_index
        StateJson().send_changes()

    def set_by_task_id(self, task_id: int) -> None:
        """Select the row for *task_id*, switching to its task type first (no-op if absent)."""
        for task_type, rows in self._rows.items():
            for i, row in enumerate(rows):
                if row.task_id == task_id:
                    # The row index is relative to its task type, so switch the tab first.
                    self.set_active_task_type(task_type)
                    self.set_active_row(i)
                    return

    def get_by_task_id(self, task_id: int) -> Union[ModelRow, None]:
        for rows in self._rows.values():
            for row in rows:
                if row.task_id == task_id:
                    return row
        return None

    def task_type_changed(self, func: Callable):
        """Register *func(task_type)* to be called when the selected task type changes."""
        route_path = self.get_route_path(ExperimentSelector.Routes.TASK_TYPE_CHANGED)
        server = self._sly_app.get_server()
        self._task_type_changes_handled = True

        @server.post(route_path)
        def _task_type_changed():
            res = self.get_selected_task_type()
            func(res)

        return _task_type_changed

    def value_changed(self, func: Callable):
        """Register *func(row)* to be called when the selected row changes."""
        route_path = self.get_route_path(ExperimentSelector.Routes.VALUE_CHANGED)
        server = self._sly_app.get_server()
        self._changes_handled = True

        @server.post(route_path)
        def _value_changed():
            res = self.get_selected_row()
            func(res)

        return _value_changed
|