supervisely 6.73.243__py3-none-any.whl → 6.73.245__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (56) hide show
  1. supervisely/__init__.py +1 -1
  2. supervisely/_utils.py +18 -0
  3. supervisely/app/widgets/__init__.py +1 -0
  4. supervisely/app/widgets/card/card.py +3 -0
  5. supervisely/app/widgets/classes_table/classes_table.py +15 -1
  6. supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
  7. supervisely/app/widgets/custom_models_selector/template.html +1 -1
  8. supervisely/app/widgets/experiment_selector/__init__.py +0 -0
  9. supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
  10. supervisely/app/widgets/experiment_selector/style.css +27 -0
  11. supervisely/app/widgets/experiment_selector/template.html +82 -0
  12. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
  13. supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
  14. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
  15. supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
  16. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  17. supervisely/nn/__init__.py +3 -1
  18. supervisely/nn/artifacts/artifacts.py +10 -0
  19. supervisely/nn/artifacts/detectron2.py +2 -0
  20. supervisely/nn/artifacts/hrda.py +3 -0
  21. supervisely/nn/artifacts/mmclassification.py +2 -0
  22. supervisely/nn/artifacts/mmdetection.py +6 -3
  23. supervisely/nn/artifacts/mmsegmentation.py +2 -0
  24. supervisely/nn/artifacts/ritm.py +3 -1
  25. supervisely/nn/artifacts/rtdetr.py +2 -0
  26. supervisely/nn/artifacts/unet.py +2 -0
  27. supervisely/nn/artifacts/yolov5.py +3 -0
  28. supervisely/nn/artifacts/yolov8.py +7 -1
  29. supervisely/nn/experiments.py +113 -0
  30. supervisely/nn/inference/gui/__init__.py +3 -1
  31. supervisely/nn/inference/gui/gui.py +31 -232
  32. supervisely/nn/inference/gui/serving_gui.py +223 -0
  33. supervisely/nn/inference/gui/serving_gui_template.py +240 -0
  34. supervisely/nn/inference/inference.py +225 -24
  35. supervisely/nn/training/__init__.py +0 -0
  36. supervisely/nn/training/gui/__init__.py +1 -0
  37. supervisely/nn/training/gui/classes_selector.py +100 -0
  38. supervisely/nn/training/gui/gui.py +539 -0
  39. supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
  40. supervisely/nn/training/gui/input_selector.py +70 -0
  41. supervisely/nn/training/gui/model_selector.py +95 -0
  42. supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
  43. supervisely/nn/training/gui/training_logs.py +93 -0
  44. supervisely/nn/training/gui/training_process.py +114 -0
  45. supervisely/nn/training/gui/utils.py +128 -0
  46. supervisely/nn/training/loggers/__init__.py +0 -0
  47. supervisely/nn/training/loggers/base_train_logger.py +58 -0
  48. supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
  49. supervisely/nn/training/train_app.py +2038 -0
  50. supervisely/nn/utils.py +5 -0
  51. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/METADATA +3 -1
  52. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/RECORD +56 -34
  53. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/LICENSE +0 -0
  54. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/WHEEL +0 -0
  55. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/entry_points.txt +0 -0
  56. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,27 @@
1
+ .custom-models-selector-table {
2
+ border-collapse: collapse;
3
+ }
4
+ .custom-models-selector-table tr:nth-child(2n) {
5
+ background-color: #f6f8fa;
6
+ }
7
+ .custom-models-selector-table td,
8
+ .custom-models-selector-table th {
9
+ border: 1px solid #dfe2e5;
10
+ padding: 6px 13px;
11
+ text-align: center;
12
+ line-height: 20px;
13
+ }
14
+
15
+ .custom-models-selector-table td {
16
+ text-align: left;
17
+ }
18
+
19
+ .custom-models-selector-table tr td:nth-child(4) {
20
+ text-align: center;
21
+ }
22
+
23
+ .el-radio-group.multi-line label.el-radio {
24
+ display: block;
25
+ margin-left: 0px;
26
+ margin-bottom: 5px;
27
+ }
@@ -0,0 +1,82 @@
1
+ <link rel="stylesheet" href="./sly/css/app/widgets/custom_models_selector/style.css"/>
2
+
3
+ <div
4
+ {% if widget._changes_handled == true %}
5
+ @change="post('/{{{widget.widget_id}}}/value_changed')"
6
+ {% endif %}
7
+ >
8
+
9
+ <div v-if="Object.keys(data.{{{widget.widget_id}}}.rowsHtml).length === 0"> You don't have any custom models</div>
10
+ <div v-else>
11
+
12
+ <div v-if="data.{{{widget.widget_id}}}.taskTypes.length > 1">
13
+ <sly-field
14
+ title="Task Type"
15
+ >
16
+ <el-radio-group
17
+ class="multi-line mt10"
18
+ :value="state.{{{widget.widget_id}}}.selectedTaskType"
19
+
20
+ {% if widget._task_type_changes_handled == true %}
21
+ @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0; post('/{{{widget.widget_id}}}/task_type_changed')}"
22
+ {% else %}
23
+ @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0;}"
24
+ {% endif %}
25
+ >
26
+
27
+ <el-radio
28
+ v-for="(item, idx) in {{{widget._task_types}}}"
29
+ :key="item"
30
+ :label="item"
31
+ >
32
+ {{ item }}
33
+ </el-radio>
34
+ </el-radio-group>
35
+ </sly-field>
36
+ </div>
37
+
38
+ <div>
39
+
40
+ <table class="custom-models-selector-table">
41
+ <thead>
42
+ <tr>
43
+ <th v-for="col in data.{{{widget.widget_id}}}.columns">
44
+ <div> {{col}} </div>
45
+ </th>
46
+ </tr>
47
+ </thead>
48
+ <tbody>
49
+ <tr v-for="row, ridx in data.{{{widget.widget_id}}}.rowsHtml[state.{{{widget.widget_id}}}.selectedTaskType]">
50
+ <td v-for="col, vidx in row">
51
+ <div v-if="vidx === 0" style="display: flex;">
52
+ <el-radio
53
+ style="display: flex;"
54
+ v-model="state.{{{widget.widget_id}}}.selectedRow"
55
+ :label="ridx"
56
+ >&#8205;</el-radio>
57
+
58
+ <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data" :state="state"></sly-html-compiler>
59
+
60
+ </div>
61
+
62
+ <div v-else>
63
+
64
+ <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data" :state="state">
65
+ </sly-html-compiler>
66
+
67
+ </div>
68
+
69
+ </td>
70
+ </tr>
71
+ </tbody>
72
+ </table>
73
+ </div>
74
+ <div class="mt10" v-if="{{{widget.show_custom_checkpoint_path}}}"
75
+ >
76
+ {{{widget.show_custom_checkpoint_path_checkbox}}}
77
+ <div class="mt10">
78
+ {{{widget.custom_tab_widgets}}}
79
+ </div>
80
+ </div>
81
+ </div>
82
+ </div>
@@ -136,6 +136,9 @@ class PretrainedModelsSelector(Widget):
136
136
 
137
137
  def get_selected_model_params(self, model_name_column: str = "Model") -> Union[Dict, None]:
138
138
  selected_model = self.get_selected_row()
139
+ if selected_model is None:
140
+ return {}
141
+
139
142
  model_name = selected_model.get(model_name_column)
140
143
  if model_name is None:
141
144
  raise ValueError(
@@ -172,24 +175,43 @@ class PretrainedModelsSelector(Widget):
172
175
  if widget_actual_state is not None and widget_actual_data is not None:
173
176
  return widget_actual_state["selectedRow"]
174
177
 
175
- def set_active_arch_type(self, arch_type: str):
178
+ def set_active_arch_type(self, arch_type: str) -> None:
176
179
  if arch_type not in self._arch_types:
177
180
  raise ValueError(f'Architecture type "{arch_type}" does not exist')
178
181
  StateJson()[self.widget_id]["selectedArchType"] = arch_type
179
182
  StateJson().send_changes()
180
183
 
181
- def set_active_task_type(self, task_type: str):
184
+ def set_active_task_type(self, task_type: str) -> None:
182
185
  if task_type not in self._task_types:
183
186
  raise ValueError(f'Task type "{task_type}" does not exist')
184
187
  StateJson()[self.widget_id]["selectedTaskType"] = task_type
185
188
  StateJson().send_changes()
186
189
 
187
- def set_active_row(self, row_index: int):
190
+ def set_active_row(self, row_index: int) -> None:
188
191
  if row_index < 0:
189
192
  raise ValueError(f'Row with index "{row_index}" does not exist')
190
193
  StateJson()[self.widget_id]["selectedRow"] = row_index
191
194
  StateJson().send_changes()
192
195
 
196
+ def set_by_model_name(self, model_name: str) -> None:
197
+ for task_type in self._table_data:
198
+ for arch_type in self._table_data[task_type]:
199
+ for idx, model in enumerate(self._table_data[task_type][arch_type]):
200
+ model_meta = model.get("meta", {})
201
+ if model_meta.get("model_name") == model_name:
202
+ self.set_active_task_type(task_type)
203
+ self.set_active_arch_type(arch_type)
204
+ self.set_active_row(idx)
205
+ return
206
+
207
+ def get_by_model_name(self, model_name: str) -> Union[Dict, None]:
208
+ for task_type in self._table_data:
209
+ for arch_type in self._table_data[task_type]:
210
+ for idx, model in enumerate(self._table_data[task_type][arch_type]):
211
+ model_meta = model.get("meta", {})
212
+ if model_meta.get("model_name") == model_name:
213
+ return model
214
+
193
215
  def _filter_and_sort_models(self, models: List[Dict], sort_models: bool = True) -> Dict:
194
216
  filtered_models = {}
195
217
 
@@ -1,15 +1,16 @@
1
1
  from typing import Dict, Optional
2
2
 
3
- from supervisely.app import StateJson
3
+ from supervisely.app import DataJson, StateJson
4
4
  from supervisely.app.widgets import Widget
5
5
 
6
+
6
7
  class RandomSplitsTable(Widget):
7
8
  def __init__(
8
- self,
9
- items_count: int,
9
+ self,
10
+ items_count: int,
10
11
  start_train_percent: Optional[int] = 80,
11
12
  disabled: Optional[bool] = False,
12
- widget_id: Optional[int] = None
13
+ widget_id: Optional[int] = None,
13
14
  ):
14
15
  self._disabled = disabled
15
16
  if 1 <= start_train_percent <= 99:
@@ -23,33 +24,56 @@ class RandomSplitsTable(Widget):
23
24
  ]
24
25
  self._items_count = items_count
25
26
  train_count = int(items_count / 100 * start_train_percent)
26
- self._count = {
27
- "total": items_count,
28
- "train": train_count,
29
- "val": items_count - train_count
30
- }
27
+ self._count = {"total": items_count, "train": train_count, "val": items_count - train_count}
31
28
 
32
29
  self._percent = {
33
30
  "total": 100,
34
31
  "train": start_train_percent,
35
- "val": 100 - start_train_percent
32
+ "val": 100 - start_train_percent,
36
33
  }
37
34
 
38
35
  super().__init__(widget_id=widget_id, file_path=__file__)
39
36
 
40
-
41
37
  def get_json_data(self):
42
38
  return {
43
39
  "table_data": self._table_data,
44
40
  "items_count": self._items_count,
45
- "disabled": self._disabled
41
+ "disabled": self._disabled,
46
42
  }
47
43
 
48
44
  def get_json_state(self):
49
- return {
50
- "count": self._count,
51
- "percent": self._percent
52
- }
45
+ return {"count": self._count, "percent": self._percent}
53
46
 
54
47
  def get_splits_counts(self) -> Dict[str, int]:
55
- return StateJson()[self.widget_id]["count"]
48
+ return StateJson()[self.widget_id]["count"]
49
+
50
+ def set_train_split_percent(self, percent: int):
51
+ if 1 <= percent <= 99:
52
+ self._percent["train"] = percent
53
+ self._percent["val"] = 100 - percent
54
+ self._count["train"] = int(self._items_count / 100 * percent)
55
+ self._count["val"] = self._items_count - self._count["train"]
56
+
57
+ StateJson()[self.widget_id]["count"] = self._count
58
+ StateJson()[self.widget_id]["percent"] = self._percent
59
+ StateJson().send_changes()
60
+ else:
61
+ raise ValueError("percent must be in range [1; 99].")
62
+
63
+ def get_train_split_percent(self) -> Dict[str, int]:
64
+ return StateJson()[self.widget_id]["percent"]["train"]
65
+
66
+ def set_val_split_percent(self, percent: int):
67
+ if 1 <= percent <= 99:
68
+ self._percent["val"] = percent
69
+ self._percent["train"] = 100 - percent
70
+ self._count["val"] = int(self._items_count / 100 * percent)
71
+ self._count["train"] = self._items_count - self._count["val"]
72
+ StateJson()[self.widget_id]["count"] = self._count
73
+ StateJson()[self.widget_id]["percent"] = self._percent
74
+ StateJson().send_changes()
75
+ else:
76
+ raise ValueError("percent must be in range [1; 99].")
77
+
78
+ def get_val_split_percent(self) -> Dict[str, int]:
79
+ return StateJson()[self.widget_id]["percent"]["val"]
@@ -119,11 +119,18 @@ class SelectDatasetTree(Widget):
119
119
  self._append_to_body = append_to_body
120
120
 
121
121
  # Extract values from Enum to match the .type property of the ProjectInfo object.
122
- self._project_types = (
123
- [project_type.value for project_type in allowed_project_types]
124
- if allowed_project_types is not None
125
- else None
126
- )
122
+
123
+ if allowed_project_types is not None:
124
+ if all(allowed_project_types) is isinstance(allowed_project_types, ProjectType):
125
+ self._project_types = (
126
+ [project_type.value for project_type in allowed_project_types]
127
+ if allowed_project_types is not None
128
+ else None
129
+ )
130
+ elif all(allowed_project_types) is isinstance(allowed_project_types, str):
131
+ self._project_types = allowed_project_types
132
+ else:
133
+ self._project_types = None
127
134
 
128
135
  # Widget components.
129
136
  self._select_team = None
@@ -10,7 +10,6 @@ from supervisely.app.widgets import (
10
10
  Field,
11
11
  NotificationBox,
12
12
  RadioTabs,
13
- SelectDataset,
14
13
  SelectString,
15
14
  SelectTagMeta,
16
15
  Widget,
@@ -18,11 +17,15 @@ from supervisely.app.widgets import (
18
17
  from supervisely.app.widgets.random_splits_table.random_splits_table import (
19
18
  RandomSplitsTable,
20
19
  )
20
+ from supervisely.app.widgets.select_dataset_tree.select_dataset_tree import (
21
+ SelectDatasetTree,
22
+ )
21
23
  from supervisely.io.fs import remove_dir
22
24
  from supervisely.project import get_project_class
23
25
  from supervisely.project.pointcloud_episode_project import PointcloudEpisodeProject
24
26
  from supervisely.project.pointcloud_project import PointcloudProject
25
27
  from supervisely.project.project import ItemInfo, Project
28
+ from supervisely.project.project_type import ProjectType
26
29
  from supervisely.project.video_project import VideoProject
27
30
  from supervisely.project.volume_project import VolumeProject
28
31
 
@@ -65,8 +68,8 @@ class TrainValSplits(Widget):
65
68
  self._train_tag_select: SelectTagMeta = None
66
69
  self._val_tag_select: SelectTagMeta = None
67
70
  self._untagged_select: SelectString = None
68
- self._train_ds_select: Union[SelectDataset, SelectString] = None
69
- self._val_ds_select: Union[SelectDataset, SelectString] = None
71
+ self._train_ds_select: Union[SelectDatasetTree, SelectString] = None
72
+ self._val_ds_select: Union[SelectDatasetTree, SelectString] = None
70
73
  self._split_methods = []
71
74
 
72
75
  contents = []
@@ -97,6 +100,7 @@ class TrainValSplits(Widget):
97
100
  super().__init__(widget_id=widget_id, file_path=__file__)
98
101
 
99
102
  def _get_random_content(self):
103
+ items_count = 0
100
104
  if self._project_id is not None:
101
105
  items_count = self._project_info.items_count
102
106
  elif self._project_fs is not None:
@@ -163,12 +167,37 @@ class TrainValSplits(Widget):
163
167
  box_type="info",
164
168
  )
165
169
  if self._project_id is not None:
166
- self._train_ds_select = SelectDataset(
167
- project_id=self._project_id, multiselect=True, compact=True, show_label=False
170
+ self._train_ds_select = SelectDatasetTree(
171
+ multiselect=True,
172
+ flat=True,
173
+ select_all_datasets=False,
174
+ allowed_project_types=[self._project_type],
175
+ always_open=False,
176
+ compact=True,
177
+ team_is_selectable=False,
178
+ workspace_is_selectable=False,
179
+ append_to_body=True,
168
180
  )
169
- self._val_ds_select = SelectDataset(
170
- project_id=self._project_id, multiselect=True, compact=True, show_label=False
181
+
182
+ self._val_ds_select = SelectDatasetTree(
183
+ multiselect=True,
184
+ flat=True,
185
+ select_all_datasets=False,
186
+ allowed_project_types=[self._project_type],
187
+ always_open=False,
188
+ compact=True,
189
+ team_is_selectable=False,
190
+ workspace_is_selectable=False,
191
+ append_to_body=True,
171
192
  )
193
+
194
+ # old implementation
195
+ # self._train_ds_select = SelectDataset(
196
+ # project_id=self._project_id, multiselect=True, compact=True, show_label=False
197
+ # )
198
+ # self._val_ds_select = SelectDataset(
199
+ # project_id=self._project_id, multiselect=True, compact=True, show_label=False
200
+ # )
172
201
  elif self._project_fs is not None:
173
202
  ds_names = [ds.name for ds in self._project_fs.datasets]
174
203
  self._train_ds_select = SelectString(ds_names, multiple=True)
@@ -196,6 +225,7 @@ class TrainValSplits(Widget):
196
225
  def get_splits(self) -> Tuple[List[ItemInfo], List[ItemInfo]]:
197
226
  split_method = self._content.get_active_tab()
198
227
  tmp_project_dir = None
228
+ train_set, val_set = [], []
199
229
  if self._project_fs is None:
200
230
  tmp_project_dir = os.path.join(get_data_dir(), rand_str(15))
201
231
  self._project_class.download(self._api, self._project_id, tmp_project_dir)
@@ -226,11 +256,11 @@ class TrainValSplits(Widget):
226
256
 
227
257
  elif split_method == "Based on datasets":
228
258
  if self._project_id is not None:
229
- self._train_ds_select: SelectDataset
230
- self._val_ds_select: SelectDataset
259
+ self._train_ds_select: SelectDatasetTree
260
+ self._val_ds_select: SelectDatasetTree
231
261
  train_ds_ids = self._train_ds_select.get_selected_ids()
232
262
  val_ds_ids = self._val_ds_select.get_selected_ids()
233
- ds_infos = self._api.dataset.get_list(self._project_id)
263
+ ds_infos = [dataset for _, dataset in self._api.dataset.tree(self._project_id)]
234
264
  train_ds_names, val_ds_names = [], []
235
265
  for ds_info in ds_infos:
236
266
  if ds_info.id in train_ds_ids:
@@ -251,6 +281,65 @@ class TrainValSplits(Widget):
251
281
  remove_dir(tmp_project_dir)
252
282
  return train_set, val_set
253
283
 
284
+ def set_split_method(self, split_method: Literal["random", "tags", "datasets"]):
285
+ if split_method == "random":
286
+ split_method = "Random"
287
+ elif split_method == "tags":
288
+ split_method = "Based on item tags"
289
+ elif split_method == "datasets":
290
+ split_method = "Based on datasets"
291
+ self._content.set_active_tab(split_method)
292
+ StateJson().send_changes()
293
+ DataJson().send_changes()
294
+
295
+ def get_split_method(self) -> str:
296
+ return self._content.get_active_tab()
297
+
298
+ def set_random_splits(
299
+ self, split: Literal["train", "training", "val", "validation"], percent: int
300
+ ):
301
+ self._content.set_active_tab("Random")
302
+ if split == "train" or split == "training":
303
+ self._random_splits_table.set_train_split_percent(percent)
304
+ elif split == "val" or split == "validation":
305
+ self._random_splits_table.set_val_split_percent(percent)
306
+ else:
307
+ raise ValueError("Split value must be 'train', 'training', 'val' or 'validation'")
308
+
309
+ def get_train_split_percent(self) -> List[int]:
310
+ return self._random_splits_table.get_train_split_percent()
311
+
312
+ def get_val_split_percent(self) -> List[int]:
313
+ return 100 - self._random_splits_table.get_train_split_percent()
314
+
315
+ def set_tags_splits(
316
+ self, train_tag: str, val_tag: str, untagged_action: Literal["train", "val", "ignore"]
317
+ ):
318
+ self._content.set_active_tab("Based on item tags")
319
+ self._train_tag_select.set_name(train_tag)
320
+ self._val_tag_select.set_name(val_tag)
321
+ self._untagged_select.set_value(untagged_action)
322
+
323
+ def get_train_tag(self) -> str:
324
+ return self._train_tag_select.get_selected_name()
325
+
326
+ def get_val_tag(self) -> str:
327
+ return self._val_tag_select.get_selected_name()
328
+
329
+ def set_datasets_splits(self, train_datasets: List[int], val_datasets: List[int]):
330
+ self._content.set_active_tab("Based on datasets")
331
+ self._train_ds_select.set_dataset_ids(train_datasets)
332
+ self._val_ds_select.set_dataset_ids(val_datasets)
333
+
334
+ def get_train_dataset_ids(self) -> List[int]:
335
+ return self._train_ds_select.get_selected_ids()
336
+
337
+ def get_val_dataset_ids(self) -> List[int]:
338
+ return self._val_ds_select.get_selected_ids()
339
+
340
+ def get_untagged_action(self) -> str:
341
+ return self._untagged_select.get_value()
342
+
254
343
  def disable(self):
255
344
  self._content.disable()
256
345
  self._random_splits_table.disable()
@@ -200,6 +200,8 @@ class TreeSelect(Widget):
200
200
  :rtype: Union[List[TreeSelect.Item], TreeSelect.Item]
201
201
  """
202
202
  res = StateJson()[self.widget_id]["value"]
203
+ if res is None:
204
+ return None
203
205
  if isinstance(res, list):
204
206
  return [TreeSelect.Item.from_json(item) for item in res]
205
207
  return TreeSelect.Item.from_json(res)
@@ -10,4 +10,6 @@ from supervisely.nn.prediction_dto import (
10
10
  PredictionMask,
11
11
  PredictionSegmentation,
12
12
  )
13
- from supervisely.nn.task_type import TaskType
13
+ from supervisely.nn.task_type import TaskType
14
+ from supervisely.nn.utils import ModelSource, RuntimeType
15
+ from supervisely.nn.experiments import ExperimentInfo, get_experiment_infos
@@ -59,6 +59,7 @@ class BaseTrainArtifacts:
59
59
  self._weights_ext: str = None
60
60
  self._config_file: str = None
61
61
  self._pattern: str = None
62
+ self._available_task_types: List[str] = []
62
63
 
63
64
  @property
64
65
  def team_id(self) -> int:
@@ -516,3 +517,12 @@ class BaseTrainArtifacts:
516
517
  train_infos = self.sort_train_infos(train_infos, sort)
517
518
  logger.debug(f"Listing time: '{format(end_time - start_time, '.6f')}' sec")
518
519
  return train_infos
520
+
521
+ def get_available_task_types(self) -> List[str]:
522
+ """
523
+ Get available task types.
524
+
525
+ :return: The list of available task types.
526
+ :rtype: List[str]
527
+ """
528
+ return self._available_task_types
@@ -1,5 +1,6 @@
1
1
  from os.path import join
2
2
  from re import compile as re_compile
3
+ from typing import List
3
4
 
4
5
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
5
6
 
@@ -16,6 +17,7 @@ class Detectron2(BaseTrainArtifacts):
16
17
  self._weights_ext = ".pth"
17
18
  self._config_file = "model_config.yaml"
18
19
  self._pattern = re_compile(r"^/detectron2/\d+_[^/]+/?$")
20
+ self._available_task_types: List[str] = ["instance segmentation"]
19
21
 
20
22
  def get_task_id(self, artifacts_folder: str) -> str:
21
23
  parts = artifacts_folder.split("/")
@@ -1,3 +1,5 @@
1
+ from typing import List
2
+
1
3
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
2
4
 
3
5
 
@@ -13,6 +15,7 @@ class HRDA(BaseTrainArtifacts):
13
15
  # self._task_type = "semantic segmentation"
14
16
  # self._weights_ext = ".pth"
15
17
  # self._config_file = "config.py"
18
+ # self._available_task_types: List[str] = ["semantic segmentation"]
16
19
 
17
20
  def get_task_id(self, artifacts_folder: str) -> str:
18
21
  raise NotImplementedError
@@ -1,5 +1,6 @@
1
1
  from os.path import join
2
2
  from re import compile as re_compile
3
+ from typing import List
3
4
 
4
5
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
5
6
 
@@ -14,6 +15,7 @@ class MMClassification(BaseTrainArtifacts):
14
15
  self._task_type = "classification"
15
16
  self._weights_ext = ".pth"
16
17
  self._pattern = re_compile(r"^/mmclassification/\d+_[^/]+/?$")
18
+ self._available_task_types: List[str] = ["classification"]
17
19
 
18
20
  def get_task_id(self, artifacts_folder: str) -> str:
19
21
  parts = artifacts_folder.split("/")
@@ -1,10 +1,11 @@
1
+ import random
2
+ import string
1
3
  from os.path import join
2
4
  from re import compile as re_compile
5
+ from typing import List
3
6
 
4
7
  from supervisely.io.fs import silent_remove
5
8
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
6
- import string
7
- import random
8
9
 
9
10
 
10
11
  class MMDetection(BaseTrainArtifacts):
@@ -19,6 +20,7 @@ class MMDetection(BaseTrainArtifacts):
19
20
  self._info_file = "info/ui_state.json"
20
21
  self._config_file = "config.py"
21
22
  self._pattern = re_compile(r"^/mmdetection/\d+_[^/]+/?$")
23
+ self._available_task_types: List[str] = ["object detection", "instance segmentation"]
22
24
 
23
25
  def get_task_id(self, artifacts_folder: str) -> str:
24
26
  parts = artifacts_folder.split("/")
@@ -39,7 +41,7 @@ class MMDetection(BaseTrainArtifacts):
39
41
  task_type = "undefined"
40
42
  for file_info in self._get_file_infos():
41
43
  if file_info.path == info_path:
42
- json_data = self._fetch_json_from_url(file_info.full_storage_url)
44
+ json_data = self._fetch_json_from_path(file_info.path)
43
45
  task_type = json_data.get("task", "undefined")
44
46
  break
45
47
  return task_type
@@ -62,6 +64,7 @@ class MMDetection3(BaseTrainArtifacts):
62
64
  self._weights_ext = ".pth"
63
65
  self._config_file = "config.py"
64
66
  self._pattern = re_compile(r"^/mmdetection-3/\d+_[^/]+/?$")
67
+ self._available_task_types: List[str] = ["object detection", "instance segmentation"]
65
68
 
66
69
  def get_task_id(self, artifacts_folder: str) -> str:
67
70
  parts = artifacts_folder.split("/")
@@ -1,5 +1,6 @@
1
1
  from os.path import join
2
2
  from re import compile as re_compile
3
+ from typing import List
3
4
 
4
5
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
5
6
 
@@ -15,6 +16,7 @@ class MMSegmentation(BaseTrainArtifacts):
15
16
  self._weights_ext = ".pth"
16
17
  self._config_file = "config.py"
17
18
  self._pattern = re_compile(r"^/mmsegmentation/\d+_[^/]+/?$")
19
+ self._available_task_types: List[str] = ["instance segmentation"]
18
20
 
19
21
  def get_task_id(self, artifacts_folder: str) -> str:
20
22
  return artifacts_folder.split("/")[2].split("_")[0]
@@ -1,5 +1,6 @@
1
1
  from os.path import join
2
2
  from re import compile as re_compile
3
+ from typing import List
3
4
 
4
5
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
5
6
 
@@ -15,6 +16,7 @@ class RITM(BaseTrainArtifacts):
15
16
  self._info_file = "info/ui_state.json"
16
17
  self._weights_ext = ".pth"
17
18
  self._pattern = re_compile(r"^/RITM_training/\d+_[^/]+/?$")
19
+ self._available_task_types: List[str] = ["interactive segmentation"]
18
20
 
19
21
  def get_task_id(self, artifacts_folder: str) -> str:
20
22
  parts = artifacts_folder.split("/")
@@ -35,7 +37,7 @@ class RITM(BaseTrainArtifacts):
35
37
  task_type = "undefined"
36
38
  for file_info in self._get_file_infos():
37
39
  if file_info.path == info_path:
38
- json_data = self._fetch_json_from_url(file_info.full_storage_url)
40
+ json_data = self._fetch_json_from_path(file_info.path)
39
41
  task_type = json_data.get("segmentationType", "undefined")
40
42
  if task_type is not None:
41
43
  task_type = task_type.lower()
@@ -1,5 +1,6 @@
1
1
  from os.path import join
2
2
  from re import compile as re_compile
3
+ from typing import List
3
4
 
4
5
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
5
6
 
@@ -15,6 +16,7 @@ class RTDETR(BaseTrainArtifacts):
15
16
  self._weights_ext = ".pth"
16
17
  self._config_file = "config.yml"
17
18
  self._pattern = re_compile(r"^/RT-DETR/[^/]+/\d+/?$")
19
+ self._available_task_types: List[str] = ["object detection"]
18
20
 
19
21
  def get_task_id(self, artifacts_folder: str) -> str:
20
22
  return artifacts_folder.split("/")[-1]
@@ -1,5 +1,6 @@
1
1
  from os.path import join
2
2
  from re import compile as re_compile
3
+ from typing import List
3
4
 
4
5
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
5
6
 
@@ -15,6 +16,7 @@ class UNet(BaseTrainArtifacts):
15
16
  self._weights_ext = ".pth"
16
17
  self._config_file = "train_args.json"
17
18
  self._pattern = re_compile(r"^/unet/\d+_[^/]+/?$")
19
+ self._available_task_types: List[str] = ["semantic segmentation"]
18
20
 
19
21
  def get_task_id(self, artifacts_folder: str) -> str:
20
22
  parts = artifacts_folder.split("/")
@@ -1,5 +1,6 @@
1
1
  from os.path import join
2
2
  from re import compile as re_compile
3
+ from typing import List
3
4
 
4
5
  from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
5
6
 
@@ -15,6 +16,7 @@ class YOLOv5(BaseTrainArtifacts):
15
16
  self._weights_ext = ".pt"
16
17
  self._config_file = None
17
18
  self._pattern = re_compile(r"^/yolov5_train/[^/]+/\d+/?$")
19
+ self._available_task_types: List[str] = ["object detection"]
18
20
 
19
21
  def get_task_id(self, artifacts_folder: str) -> str:
20
22
  return artifacts_folder.split("/")[-1]
@@ -43,3 +45,4 @@ class YOLOv5v2(YOLOv5):
43
45
  self._weights_ext = ".pt"
44
46
  self._config_file = None
45
47
  self._pattern = re_compile(r"^/yolov5_2.0_train/[^/]+/\d+/?$")
48
+ self._available_task_types: List[str] = ["object detection"]