supervisely 6.73.243__py3-none-any.whl → 6.73.244__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. supervisely/__init__.py +1 -1
  2. supervisely/_utils.py +18 -0
  3. supervisely/app/widgets/__init__.py +1 -0
  4. supervisely/app/widgets/card/card.py +3 -0
  5. supervisely/app/widgets/classes_table/classes_table.py +15 -1
  6. supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
  7. supervisely/app/widgets/custom_models_selector/template.html +1 -1
  8. supervisely/app/widgets/experiment_selector/__init__.py +0 -0
  9. supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
  10. supervisely/app/widgets/experiment_selector/style.css +27 -0
  11. supervisely/app/widgets/experiment_selector/template.html +82 -0
  12. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
  13. supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
  14. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
  15. supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
  16. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  17. supervisely/nn/__init__.py +3 -1
  18. supervisely/nn/artifacts/artifacts.py +10 -0
  19. supervisely/nn/artifacts/detectron2.py +2 -0
  20. supervisely/nn/artifacts/hrda.py +3 -0
  21. supervisely/nn/artifacts/mmclassification.py +2 -0
  22. supervisely/nn/artifacts/mmdetection.py +6 -3
  23. supervisely/nn/artifacts/mmsegmentation.py +2 -0
  24. supervisely/nn/artifacts/ritm.py +3 -1
  25. supervisely/nn/artifacts/rtdetr.py +2 -0
  26. supervisely/nn/artifacts/unet.py +2 -0
  27. supervisely/nn/artifacts/yolov5.py +3 -0
  28. supervisely/nn/artifacts/yolov8.py +7 -1
  29. supervisely/nn/experiments.py +113 -0
  30. supervisely/nn/inference/gui/__init__.py +3 -1
  31. supervisely/nn/inference/gui/gui.py +31 -232
  32. supervisely/nn/inference/gui/serving_gui.py +223 -0
  33. supervisely/nn/inference/gui/serving_gui_template.py +240 -0
  34. supervisely/nn/inference/inference.py +225 -24
  35. supervisely/nn/training/__init__.py +0 -0
  36. supervisely/nn/training/gui/__init__.py +1 -0
  37. supervisely/nn/training/gui/classes_selector.py +100 -0
  38. supervisely/nn/training/gui/gui.py +539 -0
  39. supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
  40. supervisely/nn/training/gui/input_selector.py +70 -0
  41. supervisely/nn/training/gui/model_selector.py +95 -0
  42. supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
  43. supervisely/nn/training/gui/training_logs.py +93 -0
  44. supervisely/nn/training/gui/training_process.py +114 -0
  45. supervisely/nn/training/gui/utils.py +128 -0
  46. supervisely/nn/training/loggers/__init__.py +0 -0
  47. supervisely/nn/training/loggers/base_train_logger.py +58 -0
  48. supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
  49. supervisely/nn/training/train_app.py +2038 -0
  50. supervisely/nn/utils.py +5 -0
  51. {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/METADATA +3 -1
  52. {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/RECORD +56 -34
  53. {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/LICENSE +0 -0
  54. {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/WHEEL +0 -0
  55. {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/entry_points.txt +0 -0
  56. {supervisely-6.73.243.dist-info → supervisely-6.73.244.dist-info}/top_level.txt +0 -0
supervisely/__init__.py CHANGED
@@ -309,4 +309,4 @@ except Exception as e:
309
309
  # If new changes in Supervisely Python SDK require upgrade of the Supervisely instance
310
310
  # set a new value for the environment variable MINIMUM_INSTANCE_VERSION_FOR_SDK, otherwise
311
311
  # users can face compatibility issues, if the instance version is lower than the SDK version.
312
- os.environ["MINIMUM_INSTANCE_VERSION_FOR_SDK"] = "6.12.5"
312
+ os.environ["MINIMUM_INSTANCE_VERSION_FOR_SDK"] = "6.12.12"
supervisely/_utils.py CHANGED
@@ -17,6 +17,7 @@ from tempfile import gettempdir
17
17
  from typing import Any, Dict, List, Literal, Optional, Tuple
18
18
 
19
19
  import numpy as np
20
+ import requests
20
21
  from requests.utils import DEFAULT_CA_BUNDLE_PATH
21
22
 
22
23
  from supervisely.io import env as sly_env
@@ -459,3 +460,20 @@ def get_or_create_event_loop() -> asyncio.AbstractEventLoop:
459
460
  loop = asyncio.new_event_loop()
460
461
  asyncio.set_event_loop(loop)
461
462
  return loop
463
+
464
+
465
def get_filename_from_headers(url: str) -> Optional[str]:
    """Guess the file name of a downloadable resource.

    A ``HEAD`` request is tried first. If it fails (status >= 400) or carries no
    ``Content-Disposition`` header, a streaming ``GET`` is issued instead, so only
    the headers are fetched. If the header yields no file name, the last segment
    of the URL path is used, and ``"downloaded_file"`` as a last resort.

    :param url: URL of the resource.
    :type url: str
    :return: Detected file name, or None if any error occurred while requesting the URL.
    :rtype: Optional[str]
    """
    # Function-scope import: keeps this helper independent of the module's import block.
    from urllib.parse import urlsplit

    try:
        response = requests.head(url, allow_redirects=True)
        if response.status_code >= 400 or "Content-Disposition" not in response.headers:
            response.close()
            response = requests.get(url, stream=True)
        try:
            content_disposition = response.headers.get("Content-Disposition")
        finally:
            # With stream=True the connection stays checked out until the body is
            # consumed or the response is closed; we only need the headers.
            response.close()

        if content_disposition:
            # Quoted value: take everything between the quotes.
            # Unquoted value: stop at ';' so trailing parameters are not captured.
            match = re.search(
                r'filename=(?:"([^"]+)"|"?([^";]+))', content_disposition
            )
            if match:
                quoted, unquoted = match.groups()
                filename = quoted if quoted is not None else unquoted.strip()
                if filename:
                    return filename

        # Use only the path component so "?query" / "#fragment" never end up in the name.
        filename = urlsplit(url).path.split("/")[-1] or "downloaded_file"
        return filename
    except Exception as e:
        # Best-effort helper: callers receive None instead of an exception.
        print(f"Error retrieving file name from headers: {e}")
        return None
@@ -147,3 +147,4 @@ from supervisely.app.widgets.tree_select.tree_select import TreeSelect
147
147
  from supervisely.app.widgets.select_dataset_tree.select_dataset_tree import SelectDatasetTree
148
148
  from supervisely.app.widgets.grid_gallery_v2.grid_gallery_v2 import GridGalleryV2
149
149
  from supervisely.app.widgets.report_thumbnail.report_thumbnail import ReportThumbnail
150
+ from supervisely.app.widgets.experiment_selector.experiment_selector import ExperimentSelector
@@ -125,6 +125,9 @@ class Card(Widget):
125
125
  StateJson()[self.widget_id]["collapsed"] = self._collapsed
126
126
  StateJson().send_changes()
127
127
 
128
def is_collapsed(self) -> bool:
    """Return the ``collapsed`` flag currently stored in this card's widget state.

    :return: Value of ``StateJson()[widget_id]["collapsed"]``.
    :rtype: bool
    """
    card_state = StateJson()[self.widget_id]
    return card_state["collapsed"]
130
+
128
131
  def lock(self, message: Optional[str] = None) -> None:
129
132
  """Locks the card, changes the lock message if specified.
130
133
 
@@ -283,7 +283,9 @@ class ClassesTable(Widget):
283
283
  StateJson().send_changes()
284
284
  self.loading = False
285
285
 
286
- def read_project_from_id(self, project_id: int, dataset_ids: Optional[List[int]] = None) -> None:
286
+ def read_project_from_id(
287
+ self, project_id: int, dataset_ids: Optional[List[int]] = None
288
+ ) -> None:
287
289
  """Read remote project by id and update table data.
288
290
 
289
291
  :param project_id: Project id from which classes will be taken.
@@ -469,3 +471,15 @@ class ClassesTable(Widget):
469
471
  StateJson()[self.widget_id]["global_checkbox"] = self._global_checkbox
470
472
  StateJson()[self.widget_id]["checkboxes"] = self._checkboxes
471
473
  StateJson().send_changes()
474
+
475
def set_dataset_ids(self, dataset_ids: List[int]) -> None:
    """Replace the dataset filter and rebuild the table, keeping the current selection.

    The currently selected classes are remembered, the table is refreshed from
    the stored project meta, and the remembered classes are selected again.

    :param dataset_ids: List of dataset ids to filter classes.
    :type dataset_ids: List[int]
    """
    # Capture the selection first: the refresh below rebuilds the table contents.
    classes_to_restore = self.get_selected_classes()

    self._dataset_ids = dataset_ids
    self._update_meta(self._project_meta)
    self.update_data()

    self.select_classes(classes_to_restore)
@@ -209,11 +209,15 @@ class CustomModelsSelector(Widget):
209
209
  for checkpoint_info in self._checkpoints:
210
210
  if isinstance(checkpoint_info, dict):
211
211
  checkpoint_selector_items.append(
212
- Select.Item(value=checkpoint_info["path"], label=checkpoint_info["name"])
212
+ Select.Item(
213
+ value=checkpoint_info["path"], label=checkpoint_info["name"]
214
+ )
213
215
  )
214
216
  elif isinstance(checkpoint_info, FileInfo):
215
217
  checkpoint_selector_items.append(
216
- Select.Item(value=checkpoint_info.path, label=checkpoint_info.name)
218
+ Select.Item(
219
+ value=checkpoint_info.path, label=checkpoint_info.name
220
+ )
217
221
  )
218
222
 
219
223
  checkpoint_selector = Select(items=checkpoint_selector_items)
@@ -278,7 +282,9 @@ class CustomModelsSelector(Widget):
278
282
  )
279
283
 
280
284
  file_api = FileApi(self._api)
281
- self._model_path_input = Input(placeholder="Path to model file in Team Files")
285
+ self._model_path_input = Input(
286
+ placeholder="Path to model file in Team Files"
287
+ )
282
288
 
283
289
  @self._model_path_input.value_changed
284
290
  def change_folder(value):
@@ -316,7 +322,9 @@ class CustomModelsSelector(Widget):
316
322
 
317
323
  self.custom_tab_widgets.hide()
318
324
 
319
- self.show_custom_checkpoint_path_checkbox = Checkbox("Use custom checkpoint", False)
325
+ self.show_custom_checkpoint_path_checkbox = Checkbox(
326
+ "Use custom checkpoint", False
327
+ )
320
328
 
321
329
  @self.show_custom_checkpoint_path_checkbox.value_changed
322
330
  def show_custom_checkpoint_path_checkbox_changed(is_checked):
@@ -391,7 +399,9 @@ class CustomModelsSelector(Widget):
391
399
  self.disable_table()
392
400
  super().disable()
393
401
 
394
- def _generate_table_rows(self, train_infos: List[TrainInfo]) -> Dict[str, List[ModelRow]]:
402
+ def _generate_table_rows(
403
+ self, train_infos: List[TrainInfo]
404
+ ) -> Dict[str, List[ModelRow]]:
395
405
  """Method to generate table rows from remote path to training app save directory"""
396
406
 
397
407
  def process_train_info(train_info):
@@ -438,7 +448,8 @@ class CustomModelsSelector(Widget):
438
448
  if "pose estimation" in task_types:
439
449
  sorted_tt.append("pose estimation")
440
450
  other_tasks = sorted(
441
- set(task_types) - set(["object detection", "instance segmentation", "pose estimation"])
451
+ set(task_types)
452
+ - set(["object detection", "instance segmentation", "pose estimation"])
442
453
  )
443
454
  sorted_tt.extend(other_tasks)
444
455
  return sorted_tt
@@ -484,11 +495,16 @@ class CustomModelsSelector(Widget):
484
495
  "checkpoint_url": checkpoint_url,
485
496
  }
486
497
 
498
+ # if model_name is not None:
499
+ # model_params["model_name"] = model_name
500
+
487
501
  if config_path is not None:
488
502
  model_params["config_url"] = config_path
489
503
 
490
504
  return model_params
491
505
 
506
+ # def get_selected_model_params_v2(self) -> Union[Dict, None]:
507
+
492
508
  def set_active_row(self, row_index: int) -> None:
493
509
  if row_index < 0 or row_index > len(self._rows) - 1:
494
510
  raise ValueError(f'Row with index "{row_index}" does not exist')
@@ -520,7 +536,9 @@ class CustomModelsSelector(Widget):
520
536
 
521
537
  def set_custom_checkpoint_task_type(self, task_type: str) -> None:
522
538
  if self.use_custom_checkpoint_path():
523
- available_task_types = self.custom_checkpoint_task_type_selector.get_labels()
539
+ available_task_types = (
540
+ self.custom_checkpoint_task_type_selector.get_labels()
541
+ )
524
542
  if task_type not in available_task_types:
525
543
  raise ValueError(f'"{task_type}" is not available task type')
526
544
  self.custom_checkpoint_task_type_selector.set_value(task_type)
@@ -6,7 +6,7 @@
6
6
  {% endif %}
7
7
  >
8
8
 
9
- <div v-if="data.{{{widget.widget_id}}}.rowsHtml.length === 0"> You don't have any custom models</div>
9
+ <div v-if="Object.keys(data.{{{widget.widget_id}}}.rowsHtml).length === 0"> You don't have any custom models</div>
10
10
  <div v-else>
11
11
 
12
12
  <div v-if="data.{{{widget.widget_id}}}.taskTypes.length > 1">
@@ -0,0 +1,500 @@
1
+ import os
2
+ from collections import defaultdict
3
+ from concurrent.futures import ThreadPoolExecutor, as_completed
4
+ from typing import Any, Callable, Dict, List, Union
5
+
6
+ from supervisely import env, logger
7
+ from supervisely._utils import abs_url, is_development
8
+ from supervisely.api.api import Api
9
+ from supervisely.api.project_api import ProjectInfo
10
+ from supervisely.app.content import DataJson, StateJson
11
+ from supervisely.app.widgets import (
12
+ Container,
13
+ Flexbox,
14
+ ProjectThumbnail,
15
+ Select,
16
+ Text,
17
+ Widget,
18
+ )
19
+ from supervisely.io.fs import get_file_name_with_ext
20
+ from supervisely.nn.experiments import ExperimentInfo
21
+
22
# Directory name constant; not referenced elsewhere in the visible part of this module.
WEIGHTS_DIR = "weights"

# Table column headers, upper-cased for display.
COL_ID = "task id".upper()
COL_MODEL = "model".upper()
COL_PROJECT = "training data".upper()
COL_CHECKPOINTS = "checkpoints".upper()
COL_SESSION = "session".upper()
COL_BENCHMARK = "benchmark".upper()

# Left-to-right column order of the experiments table.
columns = [
    COL_ID,
    COL_MODEL,
    COL_PROJECT,
    COL_CHECKPOINTS,
    COL_SESSION,
    COL_BENCHMARK,
]
32
+
33
+
34
class ExperimentSelector(Widget):
    """Table widget for picking a training experiment and one of its checkpoints.

    Rows are grouped by task type. Each row is an :class:`ExperimentSelector.ModelRow`
    built from an ``ExperimentInfo``; the widget state stores the selected task type
    and the index of the selected row within that task type.
    """

    class Routes:
        TASK_TYPE_CHANGED = "task_type_changed"
        VALUE_CHANGED = "value_changed"

    class ModelRow:
        """One table row: the experiment's data plus the widgets rendered in each column."""

        def __init__(
            self,
            api: Api,
            team_id: int,
            task_type: str,
            experiment_info: ExperimentInfo,
        ):
            """
            :param api: API client used to look up the task link file and the training project.
            :type api: Api
            :param team_id: Team id whose Team Files contain the experiment artifacts.
            :type team_id: int
            :param task_type: Task type this row is grouped under.
            :type task_type: str
            :param experiment_info: Experiment description the row is built from.
            :type experiment_info: ExperimentInfo
            :raises ValueError: if ``experiment_info.task_id`` is a non-numeric string
                other than ``"debug-session"``.
            """
            self._api = api
            self._team_id = team_id
            self._task_type = task_type
            self._experiment_info = experiment_info

            # "debug-session" is kept verbatim; any other string id must be numeric.
            task_id = experiment_info.task_id
            if task_id != "debug-session" and isinstance(task_id, str):
                if not task_id.isdigit():
                    raise ValueError(f"Task id {task_id} is not a number")
                task_id = int(task_id)

            # col 1 task
            self._task_id = task_id
            self._task_path = experiment_info.artifacts_dir
            self._task_date = experiment_info.datetime
            self._task_link = self._create_task_link()
            self._config_path = experiment_info.model_files.get("config")
            if self._config_path is not None:
                self._config_path = os.path.join(experiment_info.artifacts_dir, self._config_path)

            # col 2 model
            self._model_name = experiment_info.model_name

            # col 3 project
            self._training_project_id = experiment_info.project_id
            self._training_project_info = self._api.project.get_info_by_id(
                self._training_project_id
            )

            # col 4 checkpoints
            self._checkpoints = experiment_info.checkpoints

            # Checkpoint paths are joined with the artifacts dir; names are the bare file names.
            self._checkpoints_names = []
            self._checkpoints_paths = []
            for checkpoint_path in self._checkpoints:
                self._checkpoints_names.append(get_file_name_with_ext(checkpoint_path))
                self._checkpoints_paths.append(
                    os.path.join(experiment_info.artifacts_dir, checkpoint_path)
                )

            # col 5 session
            self._session_link = self._generate_session_link()

            # col 6 benchmark report
            self._benchmark_report = None  # experiment_infos.benchmark_report_path

            # widgets
            self._task_widget = self._create_task_widget()
            self._model_widget = self._create_model_widget()
            self._training_project_widget = self._create_training_project_widget()
            self._checkpoints_widget = self._create_checkpoints_widget()
            self._session_widget = self._create_session_widget()
            self._benchmark_widget = self._create_benchmark_widget()

        @property
        def task_id(self) -> Union[int, str]:
            """Numeric task id, or the literal ``"debug-session"``."""
            return self._task_id

        @property
        def task_date(self) -> str:
            return self._task_date

        @property
        def task_link(self) -> str:
            return self._task_link

        @property
        def task_type(self) -> str:
            return self._task_type

        @property
        def training_project_info(self) -> ProjectInfo:
            return self._training_project_info

        @property
        def checkpoints_names(self) -> List[str]:
            return self._checkpoints_names

        @property
        def checkpoints_paths(self) -> List[str]:
            return self._checkpoints_paths

        @property
        def checkpoints_selector(self) -> Select:
            return self._checkpoints_widget

        @property
        def session_link(self) -> str:
            return self._session_link

        @property
        def config_path(self) -> str:
            return self._config_path

        def get_selected_checkpoint_path(self) -> str:
            """Return the path of the checkpoint chosen in this row's selector."""
            return self._checkpoints_widget.get_value()

        def get_selected_checkpoint_name(self) -> str:
            """Return the label of the checkpoint chosen in this row's selector."""
            return self._checkpoints_widget.get_label()

        def set_selected_checkpoint_by_name(self, checkpoint_name: str):
            """Select the first checkpoint whose file name equals ``checkpoint_name``.

            Does nothing if no checkpoint has that name.
            """
            for name, path in zip(self._checkpoints_names, self._checkpoints_paths):
                if name == checkpoint_name:
                    self._checkpoints_widget.set_value(path)
                    return

        def set_selected_checkpoint_by_path(self, checkpoint_path: str):
            """Select the checkpoint with the given full path.

            Does nothing if the path is not one of this row's checkpoints.
            """
            if checkpoint_path in self._checkpoints_paths:
                self._checkpoints_widget.set_value(checkpoint_path)

        def to_html(self) -> List[str]:
            """Return the rendered HTML of each column's widget, in column order."""
            return [
                f"<div> {self._task_widget.to_html()} </div>",
                f"<div> {self._model_widget.to_html()} </div>",
                f"<div> {self._training_project_widget.to_html()} </div>",
                f"<div> {self._checkpoints_widget.to_html()} </div>",
                f"<div> {self._session_widget.to_html()} </div>",
                f"<div> {self._benchmark_widget.to_html()} </div>",
            ]

        def _create_task_link(self) -> str:
            """Build a link to the ``open_app.lnk`` file in the artifacts dir, or "" if missing."""
            remote_path = os.path.join(self._task_path, "open_app.lnk")
            task_file = self._api.file.get_info_by_path(self._team_id, remote_path)
            if task_file is None:
                return ""
            if is_development():
                return abs_url(f"/files/{task_file.id}")
            return f"/files/{task_file.id}"

        def _generate_session_link(self) -> str:
            """Build the app session link (absolute in development mode, relative otherwise)."""
            if is_development():
                return abs_url(f"/apps/sessions/{self._task_id}")
            return f"/apps/sessions/{self._task_id}"

        def _create_task_widget(self) -> Container:
            task_widget = Container(
                [
                    Text(
                        f"<i class='zmdi zmdi-folder' style='color: #7f858e'></i> <a href='{self._task_link}'>{self._task_id}</a>",
                        "text",
                    ),
                    Text(
                        f"<span class='field-description text-muted' style='color: #7f858e'>{self._task_date}</span>",
                        "text",
                        font_size=13,
                    ),
                ],
                gap=0,
            )
            return task_widget

        def _create_model_widget(self) -> Text:
            if self._model_name is None:
                self._model_name = "Unknown model"

            model_widget = Text(
                f"<span class='field-description text-muted' style='color: #7f858e'>{self._model_name}</span>",
                "text",
                font_size=13,
            )
            return model_widget

        def _create_training_project_widget(self) -> Union[ProjectThumbnail, Text]:
            # The project lookup yields None when the project is no longer available.
            if self.training_project_info is not None:
                training_project_widget = ProjectThumbnail(
                    self._training_project_info, remove_margins=True
                )
            else:
                training_project_widget = Text(
                    "<span class='field-description text-muted' style='color: #7f858e'>Project was deleted</span>",
                    "text",
                    font_size=13,
                )
            return training_project_widget

        def _create_checkpoints_widget(self) -> Select:
            checkpoint_selector_items = [
                Select.Item(value=path, label=name)
                for path, name in zip(self._checkpoints_paths, self._checkpoints_names)
            ]
            return Select(items=checkpoint_selector_items)

        def _create_session_widget(self) -> Text:
            session_link_widget = Text(
                f"<a href='{self._session_link}'>Preview</a> <i class='zmdi zmdi-open-in-new'></i>",
                "text",
            )
            return session_link_widget

        def _create_benchmark_widget(self) -> Text:
            if self._benchmark_report is None:
                self._benchmark_report = "No benchmark report available"
                benchmark_widget = Text(
                    "<span class='field-description text-muted' style='color: #7f858e'>No benchmark report available</span>",
                    "text",
                    font_size=13,
                )
            else:
                benchmark_widget = Text(
                    f"<a href='{self._benchmark_report}'>Benchmark Report</a> <i class='zmdi zmdi-chart'></i>",
                    "text",
                )
            return benchmark_widget

    def __init__(
        self,
        team_id: int,
        experiment_infos: Union[List[ExperimentInfo], None] = None,
        widget_id: str = None,
    ):
        """
        :param team_id: Team id whose Team Files contain the experiments' artifacts.
        :type team_id: int
        :param experiment_infos: Experiments to show; None or an empty list gives an empty table.
        :type experiment_infos: List[ExperimentInfo], optional
        :param widget_id: Explicit widget id.
        :type widget_id: str, optional
        """
        self._api = Api.from_env()

        self._team_id = team_id
        self.__debug_row = None

        # None instead of a shared mutable [] default.
        if experiment_infos is None:
            experiment_infos = []

        # NOTE(review): rows are built in a worker thread and awaited immediately;
        # presumably intentional (kept as in the original) - confirm.
        with ThreadPoolExecutor() as executor:
            future = executor.submit(self._generate_table_rows, experiment_infos)
            table_rows = future.result()

        self._columns = columns
        self._rows = table_rows

        task_types = list(table_rows)
        self._rows_html = defaultdict(list)
        for task_type, model_rows in table_rows.items():
            self._rows_html[task_type].extend(model_row.to_html() for model_row in model_rows)

        self._task_types = self._filter_task_types(task_types)
        if len(self._task_types) == 0:
            self.__default_selected_task_type = None
        else:
            self.__default_selected_task_type = self._task_types[0]

        self._changes_handled = False
        self._task_type_changes_handled = False
        super().__init__(widget_id=widget_id, file_path=__file__)

    @property
    def columns(self) -> List[str]:
        return self._columns

    @property
    def rows(self) -> Dict[str, List[ModelRow]]:
        return self._rows

    def get_json_data(self) -> Dict:
        return {
            "columns": self._columns,
            "rowsHtml": self._rows_html,
            "taskTypes": self._task_types,
        }

    def get_json_state(self) -> Dict:
        return {
            "selectedRow": 0,
            "selectedTaskType": self.__default_selected_task_type,
        }

    def set_active_task_type(self, task_type: str):
        """Switch the table to ``task_type``.

        :raises ValueError: if ``task_type`` is not one of the available task types.
        """
        if task_type not in self._task_types:
            raise ValueError(f'Task Type "{task_type}" does not exist')
        StateJson()[self.widget_id]["selectedTaskType"] = task_type
        StateJson().send_changes()

    def get_available_task_types(self) -> List[str]:
        return self._task_types

    def disable_table(self) -> None:
        """Disable every row's checkpoint selector and the widget itself."""
        for task_type in self._rows:
            for row in self._rows[task_type]:
                row.checkpoints_selector.disable()
        super().disable()

    def enable_table(self) -> None:
        """Enable every row's checkpoint selector and the widget itself."""
        for task_type in self._rows:
            for row in self._rows[task_type]:
                row.checkpoints_selector.enable()
        super().enable()

    def enable(self):
        self.enable_table()
        super().enable()

    def disable(self) -> None:
        self.disable_table()
        super().disable()

    def _generate_table_rows(
        self, experiment_infos: List[ExperimentInfo]
    ) -> Dict[str, List[ModelRow]]:
        """Build table rows from experiment infos, grouped by task type.

        Rows are created concurrently. Experiments that fail to process are logged
        and skipped. A "debug-session" row is excluded from sorting and is only
        inserted (at the top of its task type) in development mode.
        """

        def process_experiment_info(experiment_info: ExperimentInfo):
            try:
                model_row = ExperimentSelector.ModelRow(
                    api=self._api,
                    team_id=self._team_id,
                    task_type=experiment_info.task_type,
                    experiment_info=experiment_info,
                )
                return experiment_info.task_type, model_row
            except Exception:
                # Best effort: one broken experiment must not hide the others, but the
                # cause is logged with its traceback instead of being discarded.
                logger.warning(
                    f"Failed to process experiment info: {experiment_info}", exc_info=True
                )
                return None, None

        table_rows = defaultdict(list)
        with ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(process_experiment_info, experiment_info): experiment_info
                for experiment_info in experiment_infos
            }

            for future in as_completed(futures):
                result = future.result()
                if result:
                    task_type, model_row = result
                    if task_type is not None and model_row is not None:
                        if model_row.task_id == "debug-session":
                            # Kept aside: its id is not numeric, so it cannot be sorted.
                            self.__debug_row = (task_type, model_row)
                            continue
                        table_rows[task_type].append(model_row)
        self._sort_table_rows(table_rows)
        if self.__debug_row and is_development():
            task_type, model_row = self.__debug_row
            table_rows[task_type].insert(0, model_row)
        return table_rows

    def _sort_table_rows(self, table_rows: Dict[str, List[ModelRow]]) -> None:
        """Sort each task type's rows in place by task id, highest id first."""
        for task_type in table_rows:
            table_rows[task_type].sort(key=lambda row: int(row.task_id), reverse=True)

    def _filter_task_types(self, task_types: List[str]):
        """Order task types: well-known ones first in a fixed order, the rest alphabetically.

        Every distinct input task type appears in the result exactly once.
        """
        priority = [
            "object detection",
            "instance segmentation",
            "semantic segmentation",
            "pose estimation",
        ]
        # "semantic segmentation" used to be removed from the remainder without ever
        # being added to the result, which dropped such experiments from the widget.
        sorted_tt = [task_type for task_type in priority if task_type in task_types]
        other_tasks = sorted(set(task_types) - set(priority))
        sorted_tt.extend(other_tasks)
        return sorted_tt

    def get_selected_row(self, state=None) -> Union[ModelRow, None]:
        """Return the row selected in ``state``.

        :param state: State to read the selection from; defaults to the current ``StateJson()``.
        :return: Selected row, or None if the table is empty, the widget has no
            state/data, or the stored selection does not point to an existing row.
        :rtype: Union[ModelRow, None]
        """
        if len(self._rows) == 0:
            return None
        # Resolved at call time; a StateJson() default would be created at import time.
        if state is None:
            state = StateJson()
        widget_actual_state = state[self.widget_id]
        widget_actual_data = DataJson()[self.widget_id]
        if widget_actual_state is None or widget_actual_data is None:
            return None
        task_type = widget_actual_state["selectedTaskType"]
        # .get() on purpose: indexing the defaultdict would insert an empty task type.
        task_rows = self._rows.get(task_type)
        if not task_rows:
            return None
        selected_row_index = int(widget_actual_state["selectedRow"])
        if not 0 <= selected_row_index < len(task_rows):
            return None
        return task_rows[selected_row_index]

    def get_selected_row_index(self, state=None) -> Union[int, None]:
        """Return the stored selected row index, or None if the widget has no state/data."""
        if state is None:
            state = StateJson()
        widget_actual_state = state[self.widget_id]
        widget_actual_data = DataJson()[self.widget_id]
        if widget_actual_state is not None and widget_actual_data is not None:
            return widget_actual_state["selectedRow"]
        return None

    def get_selected_task_type(self) -> str:
        return StateJson()[self.widget_id]["selectedTaskType"]

    def get_selected_experiment_info(self) -> Dict[str, Any]:
        """Return the selected experiment's info as a dict, or None if nothing is selected."""
        if len(self._rows) == 0:
            return None
        selected_row = self.get_selected_row()
        if selected_row is None:
            return None
        return selected_row._experiment_info._asdict()

    def get_selected_checkpoint_path(self) -> str:
        """Return the checkpoint path chosen in the selected row, or None if nothing is selected."""
        if len(self._rows) == 0:
            return None
        selected_row = self.get_selected_row()
        if selected_row is None:
            return None
        return selected_row.get_selected_checkpoint_path()

    def get_model_files(self) -> Dict[str, str]:
        """
        Returns a dictionary with full paths to model files in Supervisely Team Files.

        Entries of the experiment's ``model_files`` are joined with its ``artifacts_dir``;
        the ``"checkpoint"`` key is set to the currently selected checkpoint path.
        """
        experiment_info = self.get_selected_experiment_info()
        artifacts_dir = experiment_info.get("artifacts_dir")
        model_files = experiment_info.get("model_files", {})

        full_model_files = {
            name: os.path.join(artifacts_dir, file) for name, file in model_files.items()
        }
        full_model_files["checkpoint"] = self.get_selected_checkpoint_path()
        return full_model_files

    def set_active_row(self, row_index: int) -> None:
        """Select a row by its index within the currently selected task type.

        :raises ValueError: if the selected task type has no row with that index.
        """
        # Validate against the rows of the selected task type; len(self._rows) would
        # be the number of task types, not the number of rows.
        task_rows = self._rows.get(self.get_selected_task_type(), [])
        if row_index < 0 or row_index > len(task_rows) - 1:
            raise ValueError(f'Row with index "{row_index}" does not exist')
        StateJson()[self.widget_id]["selectedRow"] = row_index
        StateJson().send_changes()

    def set_by_task_id(self, task_id: int) -> None:
        """Select the row with the given task id; does nothing if no such row exists."""
        for task_type, task_rows in self._rows.items():
            for i, row in enumerate(task_rows):
                if row.task_id == task_id:
                    # The index is relative to the row's own task type, so switch to it first.
                    if task_type != self.get_selected_task_type():
                        self.set_active_task_type(task_type)
                    self.set_active_row(i)
                    return

    def get_by_task_id(self, task_id: int) -> Union[ModelRow, None]:
        """Return the row with the given task id, or None if there is none."""
        for task_rows in self._rows.values():
            for row in task_rows:
                if row.task_id == task_id:
                    return row
        return None

    def task_type_changed(self, func: Callable):
        """Register ``func`` to be called with the selected task type when it changes."""
        route_path = self.get_route_path(ExperimentSelector.Routes.TASK_TYPE_CHANGED)
        server = self._sly_app.get_server()
        self._task_type_changes_handled = True

        @server.post(route_path)
        def _task_type_changed():
            res = self.get_selected_task_type()
            func(res)

        return _task_type_changed

    def value_changed(self, func: Callable):
        """Register ``func`` to be called with the selected row when the selection changes."""
        route_path = self.get_route_path(ExperimentSelector.Routes.VALUE_CHANGED)
        server = self._sly_app.get_server()
        self._changes_handled = True

        @server.post(route_path)
        def _value_changed():
            res = self.get_selected_row()
            func(res)

        return _value_changed