supervisely-6.73.242-py3-none-any.whl → supervisely-6.73.244-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely has been flagged as potentially problematic; see the registry's advisory page for details.
- supervisely/__init__.py +1 -1
- supervisely/_utils.py +18 -0
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/card/card.py +3 -0
- supervisely/app/widgets/classes_table/classes_table.py +15 -1
- supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
- supervisely/app/widgets/custom_models_selector/template.html +1 -1
- supervisely/app/widgets/experiment_selector/__init__.py +0 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
- supervisely/app/widgets/experiment_selector/style.css +27 -0
- supervisely/app/widgets/experiment_selector/template.html +82 -0
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
- supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
- supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/__init__.py +3 -1
- supervisely/nn/artifacts/artifacts.py +10 -0
- supervisely/nn/artifacts/detectron2.py +2 -0
- supervisely/nn/artifacts/hrda.py +3 -0
- supervisely/nn/artifacts/mmclassification.py +2 -0
- supervisely/nn/artifacts/mmdetection.py +6 -3
- supervisely/nn/artifacts/mmsegmentation.py +2 -0
- supervisely/nn/artifacts/ritm.py +3 -1
- supervisely/nn/artifacts/rtdetr.py +2 -0
- supervisely/nn/artifacts/unet.py +2 -0
- supervisely/nn/artifacts/yolov5.py +3 -0
- supervisely/nn/artifacts/yolov8.py +7 -1
- supervisely/nn/experiments.py +113 -0
- supervisely/nn/inference/gui/__init__.py +3 -1
- supervisely/nn/inference/gui/gui.py +31 -232
- supervisely/nn/inference/gui/serving_gui.py +223 -0
- supervisely/nn/inference/gui/serving_gui_template.py +240 -0
- supervisely/nn/inference/inference.py +225 -24
- supervisely/nn/training/__init__.py +0 -0
- supervisely/nn/training/gui/__init__.py +1 -0
- supervisely/nn/training/gui/classes_selector.py +100 -0
- supervisely/nn/training/gui/gui.py +539 -0
- supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
- supervisely/nn/training/gui/input_selector.py +70 -0
- supervisely/nn/training/gui/model_selector.py +95 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
- supervisely/nn/training/gui/training_logs.py +93 -0
- supervisely/nn/training/gui/training_process.py +114 -0
- supervisely/nn/training/gui/utils.py +128 -0
- supervisely/nn/training/loggers/__init__.py +0 -0
- supervisely/nn/training/loggers/base_train_logger.py +58 -0
- supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
- supervisely/nn/training/train_app.py +2038 -0
- supervisely/nn/utils.py +5 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/METADATA +3 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/RECORD +57 -35
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/LICENSE +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/WHEEL +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
/* Experiment table: collapsed borders so adjacent cells share one grid line. */
.custom-models-selector-table {
  border-collapse: collapse;
}

/* Zebra striping on every even row. */
.custom-models-selector-table tr:nth-child(2n) {
  background-color: #f6f8fa;
}

/* Shared cell look for header and body cells; centered by default. */
.custom-models-selector-table td,
.custom-models-selector-table th {
  padding: 6px 13px;
  line-height: 20px;
  border: 1px solid #dfe2e5;
  text-align: center;
}

/* Body cells are left-aligned (overrides the default above)... */
.custom-models-selector-table td {
  text-align: left;
}

/* ...except the fourth column, which stays centered. */
.custom-models-selector-table tr td:nth-child(4) {
  text-align: center;
}

/* Radio groups tagged "multi-line" stack their options vertically. */
.el-radio-group.multi-line label.el-radio {
  display: block;
  margin-bottom: 5px;
  margin-left: 0;
}
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
<!-- Load this widget's own stylesheet (experiment_selector/style.css). It defines the
     table classes used below and the "multi-line" radio-group layout. -->
<link rel="stylesheet" href="./sly/css/app/widgets/experiment_selector/style.css"/>

<div
  {% if widget._changes_handled == true %}
  @change="post('/{{{widget.widget_id}}}/value_changed')"
  {% endif %}
>

  <!-- Empty state: no rows for any task type. -->
  <div v-if="Object.keys(data.{{{widget.widget_id}}}.rowsHtml).length === 0"> You don't have any custom models</div>
  <div v-else>

    <!-- Task-type switcher, shown only when there is more than one task type.
         Switching resets the selected row to the first one. -->
    <div v-if="data.{{{widget.widget_id}}}.taskTypes.length > 1">
      <sly-field
        title="Task Type"
      >
        <el-radio-group
          class="multi-line mt10"
          :value="state.{{{widget.widget_id}}}.selectedTaskType"

          {% if widget._task_type_changes_handled == true %}
          @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0; post('/{{{widget.widget_id}}}/task_type_changed')}"
          {% else %}
          @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0;}"
          {% endif %}
        >

          <el-radio
            v-for="(item, idx) in {{{widget._task_types}}}"
            :key="item"
            :label="item"
          >
            {{ item }}
          </el-radio>
        </el-radio-group>
      </sly-field>
    </div>

    <div>

      <!-- Rows of the currently selected task type; the first cell of each row carries the row radio. -->
      <table class="custom-models-selector-table">
        <thead>
          <tr>
            <th v-for="col in data.{{{widget.widget_id}}}.columns">
              <div> {{col}} </div>
            </th>
          </tr>
        </thead>
        <tbody>
          <tr v-for="row, ridx in data.{{{widget.widget_id}}}.rowsHtml[state.{{{widget.widget_id}}}.selectedTaskType]">
            <td v-for="col, vidx in row">
              <div v-if="vidx === 0" style="display: flex;">
                <el-radio
                  style="display: flex;"
                  v-model="state.{{{widget.widget_id}}}.selectedRow"
                  :label="ridx"
                >‍</el-radio>

                <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data" :state="state"></sly-html-compiler>

              </div>

              <div v-else>

                <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data" :state="state">
                </sly-html-compiler>

              </div>

            </td>
          </tr>
        </tbody>
      </table>
    </div>
    <!-- NOTE(review): the v-if below interpolates a server-side value into a JS expression;
         confirm it is rendered as a JS boolean (true/false), not a Python literal. -->
    <div class="mt10" v-if="{{{widget.show_custom_checkpoint_path}}}"
    >
      {{{widget.show_custom_checkpoint_path_checkbox}}}
      <div class="mt10">
        {{{widget.custom_tab_widgets}}}
      </div>
    </div>
  </div>
</div>
|
|
@@ -136,6 +136,9 @@ class PretrainedModelsSelector(Widget):
|
|
|
136
136
|
|
|
137
137
|
def get_selected_model_params(self, model_name_column: str = "Model") -> Union[Dict, None]:
|
|
138
138
|
selected_model = self.get_selected_row()
|
|
139
|
+
if selected_model is None:
|
|
140
|
+
return {}
|
|
141
|
+
|
|
139
142
|
model_name = selected_model.get(model_name_column)
|
|
140
143
|
if model_name is None:
|
|
141
144
|
raise ValueError(
|
|
@@ -172,24 +175,43 @@ class PretrainedModelsSelector(Widget):
|
|
|
172
175
|
if widget_actual_state is not None and widget_actual_data is not None:
|
|
173
176
|
return widget_actual_state["selectedRow"]
|
|
174
177
|
|
|
175
|
-
def set_active_arch_type(self, arch_type: str) -> None:
    """Make *arch_type* the selected architecture type and push the state.

    :param arch_type: Must be one of ``self._arch_types``.
    :raises ValueError: If *arch_type* is not in ``self._arch_types``.
    """
    is_known = arch_type in self._arch_types
    if not is_known:
        raise ValueError(f'Architecture type "{arch_type}" does not exist')
    widget_state = StateJson()[self.widget_id]
    widget_state["selectedArchType"] = arch_type
    StateJson().send_changes()
|
|
180
183
|
|
|
181
|
-
def set_active_task_type(self, task_type: str) -> None:
    """Make *task_type* the selected task type and push the state.

    :param task_type: Must be one of ``self._task_types``.
    :raises ValueError: If *task_type* is not in ``self._task_types``.
    """
    is_known = task_type in self._task_types
    if not is_known:
        raise ValueError(f'Task type "{task_type}" does not exist')
    widget_state = StateJson()[self.widget_id]
    widget_state["selectedTaskType"] = task_type
    StateJson().send_changes()
|
|
186
189
|
|
|
187
|
-
def set_active_row(self, row_index: int) -> None:
    """Select the table row at *row_index* and push the state.

    Only negative indices are rejected here; the upper bound is not checked.

    :param row_index: Zero-based row index.
    :raises ValueError: If *row_index* is negative.
    """
    is_negative = row_index < 0
    if is_negative:
        raise ValueError(f'Row with index "{row_index}" does not exist')
    widget_state = StateJson()[self.widget_id]
    widget_state["selectedRow"] = row_index
    StateJson().send_changes()
|
|
192
195
|
|
|
196
|
+
def set_by_model_name(self, model_name: str) -> None:
    """Select the first table row whose ``meta["model_name"]`` equals *model_name*.

    The search goes task type -> architecture type -> row, in table order. On
    the first match, that task type, architecture type and row index are made
    active (in that order). If nothing matches, the widget is left unchanged.

    :param model_name: Value compared against ``model["meta"]["model_name"]``.
    """
    for task_type, models_by_arch in self._table_data.items():
        for arch_type, models in models_by_arch.items():
            for row_index, model in enumerate(models):
                # Rows without a "meta" dict simply never match.
                if model.get("meta", {}).get("model_name") != model_name:
                    continue
                self.set_active_task_type(task_type)
                self.set_active_arch_type(arch_type)
                self.set_active_row(row_index)
                return
|
|
206
|
+
|
|
207
|
+
def get_by_model_name(self, model_name: str) -> Union[Dict, None]:
    """Return the first model entry whose ``meta["model_name"]`` equals *model_name*.

    The search goes task type -> architecture type -> row, in table order.

    :param model_name: Value compared against ``model["meta"]["model_name"]``.
    :return: The matching model dict (the object stored in the table, not a
        copy), or ``None`` when no model matches.
    """
    for models_by_arch in self._table_data.values():
        for models in models_by_arch.values():
            for model in models:
                # Rows without a "meta" dict simply never match.
                if model.get("meta", {}).get("model_name") == model_name:
                    return model
    # Explicit "not found" result instead of falling off the end of the function.
    return None
|
|
214
|
+
|
|
193
215
|
def _filter_and_sort_models(self, models: List[Dict], sort_models: bool = True) -> Dict:
|
|
194
216
|
filtered_models = {}
|
|
195
217
|
|
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
from typing import Dict, Optional
|
|
2
2
|
|
|
3
|
-
from supervisely.app import StateJson
|
|
3
|
+
from supervisely.app import DataJson, StateJson
|
|
4
4
|
from supervisely.app.widgets import Widget
|
|
5
5
|
|
|
6
|
+
|
|
6
7
|
class RandomSplitsTable(Widget):
|
|
7
8
|
def __init__(
|
|
8
|
-
self,
|
|
9
|
-
items_count: int,
|
|
9
|
+
self,
|
|
10
|
+
items_count: int,
|
|
10
11
|
start_train_percent: Optional[int] = 80,
|
|
11
12
|
disabled: Optional[bool] = False,
|
|
12
|
-
widget_id: Optional[int] = None
|
|
13
|
+
widget_id: Optional[int] = None,
|
|
13
14
|
):
|
|
14
15
|
self._disabled = disabled
|
|
15
16
|
if 1 <= start_train_percent <= 99:
|
|
@@ -23,33 +24,56 @@ class RandomSplitsTable(Widget):
|
|
|
23
24
|
]
|
|
24
25
|
self._items_count = items_count
|
|
25
26
|
train_count = int(items_count / 100 * start_train_percent)
|
|
26
|
-
self._count = {
|
|
27
|
-
"total": items_count,
|
|
28
|
-
"train": train_count,
|
|
29
|
-
"val": items_count - train_count
|
|
30
|
-
}
|
|
27
|
+
self._count = {"total": items_count, "train": train_count, "val": items_count - train_count}
|
|
31
28
|
|
|
32
29
|
self._percent = {
|
|
33
30
|
"total": 100,
|
|
34
31
|
"train": start_train_percent,
|
|
35
|
-
"val": 100 - start_train_percent
|
|
32
|
+
"val": 100 - start_train_percent,
|
|
36
33
|
}
|
|
37
34
|
|
|
38
35
|
super().__init__(widget_id=widget_id, file_path=__file__)
|
|
39
36
|
|
|
40
|
-
|
|
41
37
|
def get_json_data(self):
    """Return the widget's data payload: table rows, total item count and disabled flag."""
    payload = dict(table_data=self._table_data, items_count=self._items_count)
    payload["disabled"] = self._disabled
    return payload
|
|
47
43
|
|
|
48
44
|
def get_json_state(self):
    """Return the widget's state payload: the item counts and percentages per split."""
    state = dict(count=self._count)
    state["percent"] = self._percent
    return state
|
|
53
46
|
|
|
54
47
|
def get_splits_counts(self) -> Dict[str, int]:
    """Return the item counts stored in the widget state.

    The dict has the keys ``"total"``, ``"train"`` and ``"val"``.
    """
    widget_state = StateJson()[self.widget_id]
    return widget_state["count"]
|
|
49
|
+
|
|
50
|
+
def set_train_split_percent(self, percent: int):
    """Set the training share of the random split and recompute the item counts.

    The validation share becomes ``100 - percent``. The train count is
    ``floor(items_count * percent / 100)`` and the validation split receives
    the remaining items. The new counts/percentages are pushed to StateJson.

    :param percent: Training share in percent, 1..99 inclusive.
    :raises ValueError: If *percent* is outside ``[1; 99]``.
    """
    if not 1 <= percent <= 99:
        raise ValueError("percent must be in range [1; 99].")
    self._percent["train"] = percent
    self._percent["val"] = 100 - percent
    # Integer arithmetic on purpose: the float form ``items / 100 * percent``
    # can land just below a whole number (58 items at 50% -> 28.999...), which
    # int() then truncates to one item too few.
    # NOTE(review): __init__ still uses the float form for the start percent.
    self._count["train"] = int(self._items_count * percent // 100)
    self._count["val"] = self._items_count - self._count["train"]

    StateJson()[self.widget_id]["count"] = self._count
    StateJson()[self.widget_id]["percent"] = self._percent
    StateJson().send_changes()

def get_train_split_percent(self) -> int:
    """Return the training share (in percent) stored in the widget state."""
    return StateJson()[self.widget_id]["percent"]["train"]

def set_val_split_percent(self, percent: int):
    """Set the validation share of the random split and recompute the item counts.

    Delegates to :meth:`set_train_split_percent` with ``100 - percent``, so a
    given train/val ratio always yields the same counts regardless of which
    setter is used (the train count is floored, validation gets the rest).

    :param percent: Validation share in percent, 1..99 inclusive.
    :raises ValueError: If *percent* is outside ``[1; 99]``.
    """
    if not 1 <= percent <= 99:
        raise ValueError("percent must be in range [1; 99].")
    self.set_train_split_percent(100 - percent)

def get_val_split_percent(self) -> int:
    """Return the validation share (in percent) stored in the widget state."""
    return StateJson()[self.widget_id]["percent"]["val"]
|
|
@@ -119,11 +119,18 @@ class SelectDatasetTree(Widget):
|
|
|
119
119
|
self._append_to_body = append_to_body
|
|
120
120
|
|
|
121
121
|
# Extract values from Enum to match the .type property of the ProjectInfo object.
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
if allowed_project_types is
|
|
125
|
-
|
|
126
|
-
|
|
122
|
+
|
|
123
|
+
if allowed_project_types is not None:
|
|
124
|
+
if all(allowed_project_types) is isinstance(allowed_project_types, ProjectType):
|
|
125
|
+
self._project_types = (
|
|
126
|
+
[project_type.value for project_type in allowed_project_types]
|
|
127
|
+
if allowed_project_types is not None
|
|
128
|
+
else None
|
|
129
|
+
)
|
|
130
|
+
elif all(allowed_project_types) is isinstance(allowed_project_types, str):
|
|
131
|
+
self._project_types = allowed_project_types
|
|
132
|
+
else:
|
|
133
|
+
self._project_types = None
|
|
127
134
|
|
|
128
135
|
# Widget components.
|
|
129
136
|
self._select_team = None
|
|
@@ -10,7 +10,6 @@ from supervisely.app.widgets import (
|
|
|
10
10
|
Field,
|
|
11
11
|
NotificationBox,
|
|
12
12
|
RadioTabs,
|
|
13
|
-
SelectDataset,
|
|
14
13
|
SelectString,
|
|
15
14
|
SelectTagMeta,
|
|
16
15
|
Widget,
|
|
@@ -18,11 +17,15 @@ from supervisely.app.widgets import (
|
|
|
18
17
|
from supervisely.app.widgets.random_splits_table.random_splits_table import (
|
|
19
18
|
RandomSplitsTable,
|
|
20
19
|
)
|
|
20
|
+
from supervisely.app.widgets.select_dataset_tree.select_dataset_tree import (
|
|
21
|
+
SelectDatasetTree,
|
|
22
|
+
)
|
|
21
23
|
from supervisely.io.fs import remove_dir
|
|
22
24
|
from supervisely.project import get_project_class
|
|
23
25
|
from supervisely.project.pointcloud_episode_project import PointcloudEpisodeProject
|
|
24
26
|
from supervisely.project.pointcloud_project import PointcloudProject
|
|
25
27
|
from supervisely.project.project import ItemInfo, Project
|
|
28
|
+
from supervisely.project.project_type import ProjectType
|
|
26
29
|
from supervisely.project.video_project import VideoProject
|
|
27
30
|
from supervisely.project.volume_project import VolumeProject
|
|
28
31
|
|
|
@@ -65,8 +68,8 @@ class TrainValSplits(Widget):
|
|
|
65
68
|
self._train_tag_select: SelectTagMeta = None
|
|
66
69
|
self._val_tag_select: SelectTagMeta = None
|
|
67
70
|
self._untagged_select: SelectString = None
|
|
68
|
-
self._train_ds_select: Union[
|
|
69
|
-
self._val_ds_select: Union[
|
|
71
|
+
self._train_ds_select: Union[SelectDatasetTree, SelectString] = None
|
|
72
|
+
self._val_ds_select: Union[SelectDatasetTree, SelectString] = None
|
|
70
73
|
self._split_methods = []
|
|
71
74
|
|
|
72
75
|
contents = []
|
|
@@ -97,6 +100,7 @@ class TrainValSplits(Widget):
|
|
|
97
100
|
super().__init__(widget_id=widget_id, file_path=__file__)
|
|
98
101
|
|
|
99
102
|
def _get_random_content(self):
|
|
103
|
+
items_count = 0
|
|
100
104
|
if self._project_id is not None:
|
|
101
105
|
items_count = self._project_info.items_count
|
|
102
106
|
elif self._project_fs is not None:
|
|
@@ -163,12 +167,37 @@ class TrainValSplits(Widget):
|
|
|
163
167
|
box_type="info",
|
|
164
168
|
)
|
|
165
169
|
if self._project_id is not None:
|
|
166
|
-
self._train_ds_select =
|
|
167
|
-
|
|
170
|
+
self._train_ds_select = SelectDatasetTree(
|
|
171
|
+
multiselect=True,
|
|
172
|
+
flat=True,
|
|
173
|
+
select_all_datasets=False,
|
|
174
|
+
allowed_project_types=[self._project_type],
|
|
175
|
+
always_open=False,
|
|
176
|
+
compact=True,
|
|
177
|
+
team_is_selectable=False,
|
|
178
|
+
workspace_is_selectable=False,
|
|
179
|
+
append_to_body=True,
|
|
168
180
|
)
|
|
169
|
-
|
|
170
|
-
|
|
181
|
+
|
|
182
|
+
self._val_ds_select = SelectDatasetTree(
|
|
183
|
+
multiselect=True,
|
|
184
|
+
flat=True,
|
|
185
|
+
select_all_datasets=False,
|
|
186
|
+
allowed_project_types=[self._project_type],
|
|
187
|
+
always_open=False,
|
|
188
|
+
compact=True,
|
|
189
|
+
team_is_selectable=False,
|
|
190
|
+
workspace_is_selectable=False,
|
|
191
|
+
append_to_body=True,
|
|
171
192
|
)
|
|
193
|
+
|
|
194
|
+
# old implementation
|
|
195
|
+
# self._train_ds_select = SelectDataset(
|
|
196
|
+
# project_id=self._project_id, multiselect=True, compact=True, show_label=False
|
|
197
|
+
# )
|
|
198
|
+
# self._val_ds_select = SelectDataset(
|
|
199
|
+
# project_id=self._project_id, multiselect=True, compact=True, show_label=False
|
|
200
|
+
# )
|
|
172
201
|
elif self._project_fs is not None:
|
|
173
202
|
ds_names = [ds.name for ds in self._project_fs.datasets]
|
|
174
203
|
self._train_ds_select = SelectString(ds_names, multiple=True)
|
|
@@ -196,6 +225,7 @@ class TrainValSplits(Widget):
|
|
|
196
225
|
def get_splits(self) -> Tuple[List[ItemInfo], List[ItemInfo]]:
|
|
197
226
|
split_method = self._content.get_active_tab()
|
|
198
227
|
tmp_project_dir = None
|
|
228
|
+
train_set, val_set = [], []
|
|
199
229
|
if self._project_fs is None:
|
|
200
230
|
tmp_project_dir = os.path.join(get_data_dir(), rand_str(15))
|
|
201
231
|
self._project_class.download(self._api, self._project_id, tmp_project_dir)
|
|
@@ -226,11 +256,11 @@ class TrainValSplits(Widget):
|
|
|
226
256
|
|
|
227
257
|
elif split_method == "Based on datasets":
|
|
228
258
|
if self._project_id is not None:
|
|
229
|
-
self._train_ds_select:
|
|
230
|
-
self._val_ds_select:
|
|
259
|
+
self._train_ds_select: SelectDatasetTree
|
|
260
|
+
self._val_ds_select: SelectDatasetTree
|
|
231
261
|
train_ds_ids = self._train_ds_select.get_selected_ids()
|
|
232
262
|
val_ds_ids = self._val_ds_select.get_selected_ids()
|
|
233
|
-
ds_infos = self._api.dataset.
|
|
263
|
+
ds_infos = [dataset for _, dataset in self._api.dataset.tree(self._project_id)]
|
|
234
264
|
train_ds_names, val_ds_names = [], []
|
|
235
265
|
for ds_info in ds_infos:
|
|
236
266
|
if ds_info.id in train_ds_ids:
|
|
@@ -251,6 +281,65 @@ class TrainValSplits(Widget):
|
|
|
251
281
|
remove_dir(tmp_project_dir)
|
|
252
282
|
return train_set, val_set
|
|
253
283
|
|
|
284
|
+
def set_split_method(self, split_method: Literal["random", "tags", "datasets"]):
    """Switch the widget to the tab of the given split method and push changes.

    :param split_method: ``"random"``, ``"tags"`` or ``"datasets"``. The tab
        titles themselves (``"Random"``, ``"Based on item tags"``,
        ``"Based on datasets"``) are accepted too, as they were before.
    :raises ValueError: If *split_method* is neither a known key nor a tab title.
    """
    tab_titles = {
        "random": "Random",
        "tags": "Based on item tags",
        "datasets": "Based on datasets",
    }
    if split_method in tab_titles:
        tab_title = tab_titles[split_method]
    elif split_method in tab_titles.values():
        tab_title = split_method
    else:
        # Previously an unknown value was passed straight to set_active_tab.
        raise ValueError(
            f"Unknown split method {split_method!r}; expected 'random', 'tags' or 'datasets'"
        )
    self._content.set_active_tab(tab_title)
    StateJson().send_changes()
    DataJson().send_changes()
|
|
294
|
+
|
|
295
|
+
def get_split_method(self) -> str:
    """Return the title of the active split tab (e.g. ``"Random"``)."""
    tabs = self._content
    return tabs.get_active_tab()
|
|
297
|
+
|
|
298
|
+
def set_random_splits(
    self, split: Literal["train", "training", "val", "validation"], percent: int
):
    """Activate the "Random" tab and set the share of one split.

    :param split: ``"train"``/``"training"`` or ``"val"``/``"validation"``.
    :param percent: Share of that split in percent; range checking is done by
        the random-splits table.
    :raises ValueError: If *split* is not one of the accepted names; the tab is
        left untouched in that case.
    """
    if split in ("train", "training"):
        apply_percent = self._random_splits_table.set_train_split_percent
    elif split in ("val", "validation"):
        apply_percent = self._random_splits_table.set_val_split_percent
    else:
        # Validate before switching tabs so a bad call has no side effect.
        raise ValueError("Split value must be 'train', 'training', 'val' or 'validation'")
    self._content.set_active_tab("Random")
    apply_percent(percent)
|
|
308
|
+
|
|
309
|
+
def get_train_split_percent(self) -> int:
    """Return the training share (in percent) of the random split."""
    return self._random_splits_table.get_train_split_percent()

def get_val_split_percent(self) -> int:
    """Return the validation share (in percent) of the random split.

    Derived as ``100 - training share``.
    """
    return 100 - self._random_splits_table.get_train_split_percent()
|
|
314
|
+
|
|
315
|
+
def set_tags_splits(
    self, train_tag: str, val_tag: str, untagged_action: Literal["train", "val", "ignore"]
):
    """Activate the tag-based split tab and configure its selectors.

    :param train_tag: Name of the tag marking training items.
    :param val_tag: Name of the tag marking validation items.
    :param untagged_action: What to do with items carrying neither tag.
    """
    self._content.set_active_tab("Based on item tags")
    tag_assignments = (
        (self._train_tag_select, train_tag),
        (self._val_tag_select, val_tag),
    )
    for selector, tag_name in tag_assignments:
        selector.set_name(tag_name)
    self._untagged_select.set_value(untagged_action)
|
|
322
|
+
|
|
323
|
+
def get_train_tag(self) -> str:
    """Return the tag name currently chosen for training items."""
    selector = self._train_tag_select
    return selector.get_selected_name()

def get_val_tag(self) -> str:
    """Return the tag name currently chosen for validation items."""
    selector = self._val_tag_select
    return selector.get_selected_name()
|
|
328
|
+
|
|
329
|
+
def set_datasets_splits(self, train_datasets: List[int], val_datasets: List[int]):
    """Activate the dataset-based split tab and preselect datasets by ID.

    :param train_datasets: IDs of datasets used for training.
    :param val_datasets: IDs of datasets used for validation.

    NOTE(review): when the widget was built from a local project
    (``self._project_fs``), the dataset selectors are ``SelectString`` widgets
    and may not provide ``set_dataset_ids`` — confirm before using this path.
    """
    self._content.set_active_tab("Based on datasets")
    for selector, dataset_ids in (
        (self._train_ds_select, train_datasets),
        (self._val_ds_select, val_datasets),
    ):
        selector.set_dataset_ids(dataset_ids)
|
|
333
|
+
|
|
334
|
+
def get_train_dataset_ids(self) -> List[int]:
    """Return the IDs of the datasets selected for training.

    NOTE(review): relies on ``get_selected_ids``, which the ``SelectString``
    selectors used for local projects may not provide — confirm.
    """
    selector = self._train_ds_select
    return selector.get_selected_ids()

def get_val_dataset_ids(self) -> List[int]:
    """Return the IDs of the datasets selected for validation.

    NOTE(review): same ``SelectString`` caveat as ``get_train_dataset_ids``.
    """
    selector = self._val_ds_select
    return selector.get_selected_ids()
|
|
339
|
+
|
|
340
|
+
def get_untagged_action(self) -> str:
    """Return the action chosen for items that carry neither split tag."""
    selector = self._untagged_select
    return selector.get_value()
|
|
342
|
+
|
|
254
343
|
def disable(self):
|
|
255
344
|
self._content.disable()
|
|
256
345
|
self._random_splits_table.disable()
|
|
@@ -200,6 +200,8 @@ class TreeSelect(Widget):
|
|
|
200
200
|
:rtype: Union[List[TreeSelect.Item], TreeSelect.Item]
|
|
201
201
|
"""
|
|
202
202
|
res = StateJson()[self.widget_id]["value"]
|
|
203
|
+
if res is None:
|
|
204
|
+
return None
|
|
203
205
|
if isinstance(res, list):
|
|
204
206
|
return [TreeSelect.Item.from_json(item) for item in res]
|
|
205
207
|
return TreeSelect.Item.from_json(res)
|
supervisely/nn/__init__.py
CHANGED
|
@@ -10,4 +10,6 @@ from supervisely.nn.prediction_dto import (
|
|
|
10
10
|
PredictionMask,
|
|
11
11
|
PredictionSegmentation,
|
|
12
12
|
)
|
|
13
|
-
from supervisely.nn.task_type import TaskType
|
|
13
|
+
from supervisely.nn.task_type import TaskType
|
|
14
|
+
from supervisely.nn.utils import ModelSource, RuntimeType
|
|
15
|
+
from supervisely.nn.experiments import ExperimentInfo, get_experiment_infos
|
|
@@ -59,6 +59,7 @@ class BaseTrainArtifacts:
|
|
|
59
59
|
self._weights_ext: str = None
|
|
60
60
|
self._config_file: str = None
|
|
61
61
|
self._pattern: str = None
|
|
62
|
+
self._available_task_types: List[str] = []
|
|
62
63
|
|
|
63
64
|
@property
|
|
64
65
|
def team_id(self) -> int:
|
|
@@ -516,3 +517,12 @@ class BaseTrainArtifacts:
|
|
|
516
517
|
train_infos = self.sort_train_infos(train_infos, sort)
|
|
517
518
|
logger.debug(f"Listing time: '{format(end_time - start_time, '.6f')}' sec")
|
|
518
519
|
return train_infos
|
|
520
|
+
|
|
521
|
+
def get_available_task_types(self) -> List[str]:
    """
    Return the task types declared for this training framework.

    The base class declares none (an empty list); subclasses fill
    ``self._available_task_types`` in their constructors.

    :return: The stored list itself (not a copy).
    :rtype: List[str]
    """
    declared_task_types = self._available_task_types
    return declared_task_types
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from os.path import join
|
|
2
2
|
from re import compile as re_compile
|
|
3
|
+
from typing import List
|
|
3
4
|
|
|
4
5
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
5
6
|
|
|
@@ -16,6 +17,7 @@ class Detectron2(BaseTrainArtifacts):
|
|
|
16
17
|
self._weights_ext = ".pth"
|
|
17
18
|
self._config_file = "model_config.yaml"
|
|
18
19
|
self._pattern = re_compile(r"^/detectron2/\d+_[^/]+/?$")
|
|
20
|
+
self._available_task_types: List[str] = ["instance segmentation"]
|
|
19
21
|
|
|
20
22
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
21
23
|
parts = artifacts_folder.split("/")
|
supervisely/nn/artifacts/hrda.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
1
3
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
2
4
|
|
|
3
5
|
|
|
@@ -13,6 +15,7 @@ class HRDA(BaseTrainArtifacts):
|
|
|
13
15
|
# self._task_type = "semantic segmentation"
|
|
14
16
|
# self._weights_ext = ".pth"
|
|
15
17
|
# self._config_file = "config.py"
|
|
18
|
+
# self._available_task_types: List[str] = ["semantic segmentation"]
|
|
16
19
|
|
|
17
20
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
18
21
|
raise NotImplementedError
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from os.path import join
|
|
2
2
|
from re import compile as re_compile
|
|
3
|
+
from typing import List
|
|
3
4
|
|
|
4
5
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
5
6
|
|
|
@@ -14,6 +15,7 @@ class MMClassification(BaseTrainArtifacts):
|
|
|
14
15
|
self._task_type = "classification"
|
|
15
16
|
self._weights_ext = ".pth"
|
|
16
17
|
self._pattern = re_compile(r"^/mmclassification/\d+_[^/]+/?$")
|
|
18
|
+
self._available_task_types: List[str] = ["classification"]
|
|
17
19
|
|
|
18
20
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
19
21
|
parts = artifacts_folder.split("/")
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import string
|
|
1
3
|
from os.path import join
|
|
2
4
|
from re import compile as re_compile
|
|
5
|
+
from typing import List
|
|
3
6
|
|
|
4
7
|
from supervisely.io.fs import silent_remove
|
|
5
8
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
6
|
-
import string
|
|
7
|
-
import random
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
class MMDetection(BaseTrainArtifacts):
|
|
@@ -19,6 +20,7 @@ class MMDetection(BaseTrainArtifacts):
|
|
|
19
20
|
self._info_file = "info/ui_state.json"
|
|
20
21
|
self._config_file = "config.py"
|
|
21
22
|
self._pattern = re_compile(r"^/mmdetection/\d+_[^/]+/?$")
|
|
23
|
+
self._available_task_types: List[str] = ["object detection", "instance segmentation"]
|
|
22
24
|
|
|
23
25
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
24
26
|
parts = artifacts_folder.split("/")
|
|
@@ -39,7 +41,7 @@ class MMDetection(BaseTrainArtifacts):
|
|
|
39
41
|
task_type = "undefined"
|
|
40
42
|
for file_info in self._get_file_infos():
|
|
41
43
|
if file_info.path == info_path:
|
|
42
|
-
json_data = self.
|
|
44
|
+
json_data = self._fetch_json_from_path(file_info.path)
|
|
43
45
|
task_type = json_data.get("task", "undefined")
|
|
44
46
|
break
|
|
45
47
|
return task_type
|
|
@@ -62,6 +64,7 @@ class MMDetection3(BaseTrainArtifacts):
|
|
|
62
64
|
self._weights_ext = ".pth"
|
|
63
65
|
self._config_file = "config.py"
|
|
64
66
|
self._pattern = re_compile(r"^/mmdetection-3/\d+_[^/]+/?$")
|
|
67
|
+
self._available_task_types: List[str] = ["object detection", "instance segmentation"]
|
|
65
68
|
|
|
66
69
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
67
70
|
parts = artifacts_folder.split("/")
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from os.path import join
|
|
2
2
|
from re import compile as re_compile
|
|
3
|
+
from typing import List
|
|
3
4
|
|
|
4
5
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
5
6
|
|
|
@@ -15,6 +16,7 @@ class MMSegmentation(BaseTrainArtifacts):
|
|
|
15
16
|
self._weights_ext = ".pth"
|
|
16
17
|
self._config_file = "config.py"
|
|
17
18
|
self._pattern = re_compile(r"^/mmsegmentation/\d+_[^/]+/?$")
|
|
19
|
+
self._available_task_types: List[str] = ["instance segmentation"]
|
|
18
20
|
|
|
19
21
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
20
22
|
return artifacts_folder.split("/")[2].split("_")[0]
|
supervisely/nn/artifacts/ritm.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from os.path import join
|
|
2
2
|
from re import compile as re_compile
|
|
3
|
+
from typing import List
|
|
3
4
|
|
|
4
5
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
5
6
|
|
|
@@ -15,6 +16,7 @@ class RITM(BaseTrainArtifacts):
|
|
|
15
16
|
self._info_file = "info/ui_state.json"
|
|
16
17
|
self._weights_ext = ".pth"
|
|
17
18
|
self._pattern = re_compile(r"^/RITM_training/\d+_[^/]+/?$")
|
|
19
|
+
self._available_task_types: List[str] = ["interactive segmentation"]
|
|
18
20
|
|
|
19
21
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
20
22
|
parts = artifacts_folder.split("/")
|
|
@@ -35,7 +37,7 @@ class RITM(BaseTrainArtifacts):
|
|
|
35
37
|
task_type = "undefined"
|
|
36
38
|
for file_info in self._get_file_infos():
|
|
37
39
|
if file_info.path == info_path:
|
|
38
|
-
json_data = self.
|
|
40
|
+
json_data = self._fetch_json_from_path(file_info.path)
|
|
39
41
|
task_type = json_data.get("segmentationType", "undefined")
|
|
40
42
|
if task_type is not None:
|
|
41
43
|
task_type = task_type.lower()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from os.path import join
|
|
2
2
|
from re import compile as re_compile
|
|
3
|
+
from typing import List
|
|
3
4
|
|
|
4
5
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
5
6
|
|
|
@@ -15,6 +16,7 @@ class RTDETR(BaseTrainArtifacts):
|
|
|
15
16
|
self._weights_ext = ".pth"
|
|
16
17
|
self._config_file = "config.yml"
|
|
17
18
|
self._pattern = re_compile(r"^/RT-DETR/[^/]+/\d+/?$")
|
|
19
|
+
self._available_task_types: List[str] = ["object detection"]
|
|
18
20
|
|
|
19
21
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
20
22
|
return artifacts_folder.split("/")[-1]
|
supervisely/nn/artifacts/unet.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from os.path import join
|
|
2
2
|
from re import compile as re_compile
|
|
3
|
+
from typing import List
|
|
3
4
|
|
|
4
5
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
5
6
|
|
|
@@ -15,6 +16,7 @@ class UNet(BaseTrainArtifacts):
|
|
|
15
16
|
self._weights_ext = ".pth"
|
|
16
17
|
self._config_file = "train_args.json"
|
|
17
18
|
self._pattern = re_compile(r"^/unet/\d+_[^/]+/?$")
|
|
19
|
+
self._available_task_types: List[str] = ["semantic segmentation"]
|
|
18
20
|
|
|
19
21
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
20
22
|
parts = artifacts_folder.split("/")
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from os.path import join
|
|
2
2
|
from re import compile as re_compile
|
|
3
|
+
from typing import List
|
|
3
4
|
|
|
4
5
|
from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
|
|
5
6
|
|
|
@@ -15,6 +16,7 @@ class YOLOv5(BaseTrainArtifacts):
|
|
|
15
16
|
self._weights_ext = ".pt"
|
|
16
17
|
self._config_file = None
|
|
17
18
|
self._pattern = re_compile(r"^/yolov5_train/[^/]+/\d+/?$")
|
|
19
|
+
self._available_task_types: List[str] = ["object detection"]
|
|
18
20
|
|
|
19
21
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
20
22
|
return artifacts_folder.split("/")[-1]
|
|
@@ -43,3 +45,4 @@ class YOLOv5v2(YOLOv5):
|
|
|
43
45
|
self._weights_ext = ".pt"
|
|
44
46
|
self._config_file = None
|
|
45
47
|
self._pattern = re_compile(r"^/yolov5_2.0_train/[^/]+/\d+/?$")
|
|
48
|
+
self._available_task_types: List[str] = ["object detection"]
|