supervisely 6.73.420__py3-none-any.whl → 6.73.422__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/api.py +10 -5
- supervisely/api/app_api.py +71 -4
- supervisely/api/module_api.py +4 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/project_api.py +35 -6
- supervisely/api/task_api.py +5 -1
- supervisely/app/widgets/__init__.py +8 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/deploy_model/__init__.py +0 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
- supervisely/app/widgets/fast_table/fast_table.py +402 -74
- supervisely/app/widgets/fast_table/script.js +364 -96
- supervisely/app/widgets/fast_table/style.css +24 -0
- supervisely/app/widgets/fast_table/template.html +43 -3
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +160 -94
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
- supervisely/nn/inference/predict_app/gui/gui.py +710 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +282 -0
- supervisely/nn/inference/predict_app/predict_app.py +184 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/prediction.py +2 -0
- supervisely/nn/model/prediction_session.py +20 -3
- supervisely/nn/training/gui/gui.py +131 -44
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
- supervisely/nn/training/gui/training_artifacts.py +0 -5
- supervisely/nn/training/train_app.py +161 -44
- supervisely/template/experiment/experiment.html.jinja +74 -17
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/METADATA +3 -1
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/RECORD +74 -56
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/LICENSE +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/WHEEL +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/top_level.txt +0 -0
|
@@ -1,42 +1,52 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import os
|
|
2
|
-
from
|
|
3
|
-
from
|
|
4
|
-
from
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
from supervisely
|
|
9
|
-
from supervisely.
|
|
3
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
4
|
+
from functools import partial
|
|
5
|
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from supervisely import batched
|
|
10
|
+
from supervisely._utils import abs_url, is_development, logger
|
|
11
|
+
from supervisely.api.api import Api, ApiField
|
|
10
12
|
from supervisely.api.project_api import ProjectInfo
|
|
11
|
-
from supervisely.app.
|
|
12
|
-
from supervisely.app.widgets import
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
ProjectThumbnail,
|
|
16
|
-
Select,
|
|
17
|
-
Text,
|
|
18
|
-
Widget,
|
|
13
|
+
from supervisely.app.exceptions import show_dialog
|
|
14
|
+
from supervisely.app.widgets.container.container import Container
|
|
15
|
+
from supervisely.app.widgets.dropdown_checkbox_selector.dropdown_checkbox_selector import (
|
|
16
|
+
DropdownCheckboxSelector,
|
|
19
17
|
)
|
|
18
|
+
from supervisely.app.widgets.fast_table.fast_table import FastTable
|
|
19
|
+
from supervisely.app.widgets.flexbox.flexbox import Flexbox
|
|
20
|
+
from supervisely.app.widgets.project_thumbnail.project_thumbnail import ProjectThumbnail
|
|
21
|
+
from supervisely.app.widgets.select.select import Select
|
|
22
|
+
from supervisely.app.widgets.text.text import Text
|
|
23
|
+
from supervisely.app.widgets.widget import Widget
|
|
24
|
+
from supervisely.io import env
|
|
20
25
|
from supervisely.io.fs import get_file_name_with_ext
|
|
21
26
|
from supervisely.nn.experiments import ExperimentInfo
|
|
22
|
-
from supervisely.nn.utils import ModelSource
|
|
23
|
-
|
|
24
|
-
WEIGHTS_DIR = "weights"
|
|
25
|
-
|
|
26
|
-
COL_ID = "task id".upper()
|
|
27
|
-
COL_MODEL = "model".upper()
|
|
28
|
-
COL_PROJECT = "training data".upper()
|
|
29
|
-
COL_CHECKPOINTS = "checkpoints".upper()
|
|
30
|
-
COL_SESSION = "session".upper()
|
|
31
|
-
COL_BENCHMARK = "benchmark".upper()
|
|
32
|
-
|
|
33
|
-
columns = [COL_ID, COL_MODEL, COL_PROJECT, COL_CHECKPOINTS, COL_SESSION, COL_BENCHMARK]
|
|
34
27
|
|
|
35
28
|
|
|
36
29
|
class ExperimentSelector(Widget):
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
30
|
+
"""
|
|
31
|
+
Widget for selecting experiments from a team.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
class COLUMN:
|
|
35
|
+
NAME = "TASK ID"
|
|
36
|
+
MODEL = "MODEL"
|
|
37
|
+
TRAINING_DATA = "TRAINING DATA"
|
|
38
|
+
CHECKPOINTS = "CHECKPOINTS"
|
|
39
|
+
SESSION = "SESSION"
|
|
40
|
+
BENCHMARK = "BENCHMARK"
|
|
41
|
+
|
|
42
|
+
COLUMNS = [
|
|
43
|
+
COLUMN.NAME,
|
|
44
|
+
COLUMN.MODEL,
|
|
45
|
+
COLUMN.TRAINING_DATA,
|
|
46
|
+
COLUMN.CHECKPOINTS,
|
|
47
|
+
COLUMN.SESSION,
|
|
48
|
+
COLUMN.BENCHMARK,
|
|
49
|
+
]
|
|
40
50
|
|
|
41
51
|
class ModelRow:
|
|
42
52
|
def __init__(
|
|
@@ -45,15 +55,19 @@ class ExperimentSelector(Widget):
|
|
|
45
55
|
team_id: int,
|
|
46
56
|
task_type: str,
|
|
47
57
|
experiment_info: ExperimentInfo,
|
|
58
|
+
project_info: Optional[ProjectInfo] = None,
|
|
48
59
|
):
|
|
49
60
|
self._api = api
|
|
50
61
|
self._team_id = team_id
|
|
51
62
|
self._task_type = task_type
|
|
52
63
|
self._experiment_info = experiment_info
|
|
64
|
+
self._project_info = project_info
|
|
53
65
|
|
|
54
66
|
task_id = experiment_info.task_id
|
|
55
|
-
if task_id ==
|
|
67
|
+
if task_id == -1:
|
|
56
68
|
pass
|
|
69
|
+
elif task_id == "debug-session":
|
|
70
|
+
task_id = -1
|
|
57
71
|
elif type(task_id) is str:
|
|
58
72
|
if task_id.isdigit():
|
|
59
73
|
task_id = int(task_id)
|
|
@@ -77,9 +91,7 @@ class ExperimentSelector(Widget):
|
|
|
77
91
|
if self._training_project_id is None:
|
|
78
92
|
self._training_project_info = None
|
|
79
93
|
else:
|
|
80
|
-
self._training_project_info = self.
|
|
81
|
-
self._training_project_id
|
|
82
|
-
)
|
|
94
|
+
self._training_project_info = self._project_info
|
|
83
95
|
|
|
84
96
|
# col 4 checkpoints
|
|
85
97
|
self._checkpoints = experiment_info.checkpoints
|
|
@@ -143,14 +155,6 @@ class ExperimentSelector(Widget):
|
|
|
143
155
|
def checkpoints_selector(self) -> Select:
|
|
144
156
|
return self._checkpoints_widget
|
|
145
157
|
|
|
146
|
-
@property
|
|
147
|
-
def experiment_info(self) -> ExperimentInfo:
|
|
148
|
-
return self._experiment_info
|
|
149
|
-
|
|
150
|
-
@property
|
|
151
|
-
def best_checkpoint(self) -> str:
|
|
152
|
-
return self.experiment_info.best_checkpoint
|
|
153
|
-
|
|
154
158
|
@property
|
|
155
159
|
def session_link(self) -> str:
|
|
156
160
|
return self._session_link
|
|
@@ -234,17 +238,24 @@ class ExperimentSelector(Widget):
|
|
|
234
238
|
return model_widget
|
|
235
239
|
|
|
236
240
|
def _create_training_project_widget(self) -> Union[ProjectThumbnail, Text]:
|
|
241
|
+
training_project_thumbnail = ProjectThumbnail(
|
|
242
|
+
self._training_project_info, remove_margins=True
|
|
243
|
+
)
|
|
244
|
+
training_project_text = Text(
|
|
245
|
+
f"<span class='field-description text-muted' style='color: #7f858e'>Project was deleted</span>",
|
|
246
|
+
"text",
|
|
247
|
+
font_size=13,
|
|
248
|
+
)
|
|
237
249
|
if self.training_project_info is not None:
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
)
|
|
250
|
+
training_project_thumbnail.show()
|
|
251
|
+
training_project_text.hide()
|
|
241
252
|
else:
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
return
|
|
253
|
+
training_project_thumbnail.hide()
|
|
254
|
+
training_project_text.show()
|
|
255
|
+
return Container(widgets=[training_project_thumbnail, training_project_text], gap=0)
|
|
256
|
+
|
|
257
|
+
def checkpoint_changed(self, checkpoint_value: str):
|
|
258
|
+
return
|
|
248
259
|
|
|
249
260
|
def _create_checkpoints_widget(self) -> Select:
|
|
250
261
|
checkpoint_selector_items = []
|
|
@@ -253,6 +264,11 @@ class ExperimentSelector(Widget):
|
|
|
253
264
|
checkpoint_selector = Select(items=checkpoint_selector_items)
|
|
254
265
|
if self._best_checkpoint_value is not None:
|
|
255
266
|
checkpoint_selector.set_value(self._best_checkpoint)
|
|
267
|
+
|
|
268
|
+
@checkpoint_selector.value_changed
|
|
269
|
+
def on_checkpoint_changed(checkpoint_value: str):
|
|
270
|
+
self.checkpoint_changed(checkpoint_value)
|
|
271
|
+
|
|
256
272
|
return checkpoint_selector
|
|
257
273
|
|
|
258
274
|
def _create_session_widget(self) -> Text:
|
|
@@ -284,265 +300,432 @@ class ExperimentSelector(Widget):
|
|
|
284
300
|
)
|
|
285
301
|
return benchmark_widget
|
|
286
302
|
|
|
303
|
+
def _widget_to_cell_value(self, widget: Widget) -> str:
|
|
304
|
+
if isinstance(widget, Container):
|
|
305
|
+
return json.dumps(
|
|
306
|
+
{
|
|
307
|
+
"widget_id": widget.widget_id,
|
|
308
|
+
"widgets": [w.widget_id for w in widget._widgets],
|
|
309
|
+
}
|
|
310
|
+
)
|
|
311
|
+
else:
|
|
312
|
+
return json.dumps({"widget_id": widget.widget_id, "widgets": []})
|
|
313
|
+
|
|
314
|
+
def to_table_row(self):
|
|
315
|
+
return [
|
|
316
|
+
self._widget_to_cell_value(w)
|
|
317
|
+
for w in [
|
|
318
|
+
self._task_widget,
|
|
319
|
+
self._model_wiget,
|
|
320
|
+
self._training_project_widget,
|
|
321
|
+
self._checkpoints_widget,
|
|
322
|
+
self._session_widget,
|
|
323
|
+
self._benchmark_widget,
|
|
324
|
+
]
|
|
325
|
+
]
|
|
326
|
+
|
|
327
|
+
@classmethod
|
|
328
|
+
def widgets_templates(cls):
|
|
329
|
+
checkpoints_template_widget = Select(items=[])
|
|
330
|
+
checkpoints_template_widget.value_changed(lambda _: None)
|
|
331
|
+
|
|
332
|
+
return [
|
|
333
|
+
# _task_widget
|
|
334
|
+
Container(widgets=[Text(""), Text("")], gap=0),
|
|
335
|
+
# _model_wiget
|
|
336
|
+
Text(""),
|
|
337
|
+
# _training_project_widget
|
|
338
|
+
Container(widgets=[ProjectThumbnail(remove_margins=True), Text("")], gap=0),
|
|
339
|
+
# _checkpoints_widget
|
|
340
|
+
checkpoints_template_widget,
|
|
341
|
+
# _session_widget
|
|
342
|
+
Text(""),
|
|
343
|
+
# _benchmark_widget
|
|
344
|
+
Text(""),
|
|
345
|
+
]
|
|
346
|
+
|
|
347
|
+
def search_text(self) -> str:
|
|
348
|
+
text = ""
|
|
349
|
+
text += str(self._task_id)
|
|
350
|
+
text += str(self._task_date)
|
|
351
|
+
text += str(self._model_name)
|
|
352
|
+
if self._training_project_info is not None:
|
|
353
|
+
text += str(self._training_project_info.name)
|
|
354
|
+
else:
|
|
355
|
+
text += "Project was deleted"
|
|
356
|
+
return text
|
|
357
|
+
|
|
358
|
+
def sort_values(self) -> List[int]:
|
|
359
|
+
# Sort by training project name: real names first (A->Z), deleted projects last
|
|
360
|
+
if self._training_project_info is not None:
|
|
361
|
+
training_project_name = (0, self._training_project_info.name.lower())
|
|
362
|
+
else:
|
|
363
|
+
training_project_name = (1, "")
|
|
364
|
+
|
|
365
|
+
if self._benchmark_report_id == "No evaluation report available":
|
|
366
|
+
benchmark_report_id = 0
|
|
367
|
+
else:
|
|
368
|
+
benchmark_report_id = 1
|
|
369
|
+
|
|
370
|
+
return [
|
|
371
|
+
self._task_id,
|
|
372
|
+
self._model_name.capitalize(),
|
|
373
|
+
training_project_name,
|
|
374
|
+
0,
|
|
375
|
+
0,
|
|
376
|
+
benchmark_report_id,
|
|
377
|
+
]
|
|
378
|
+
|
|
287
379
|
def __init__(
|
|
288
380
|
self,
|
|
289
|
-
|
|
381
|
+
api: Api = None,
|
|
382
|
+
team_id: int = None,
|
|
290
383
|
experiment_infos: List[ExperimentInfo] = [],
|
|
291
384
|
widget_id: str = None,
|
|
292
385
|
):
|
|
293
|
-
|
|
386
|
+
if team_id is None:
|
|
387
|
+
team_id = env.team_id()
|
|
388
|
+
self.team_id = team_id
|
|
389
|
+
if api is None:
|
|
390
|
+
api = Api()
|
|
391
|
+
self.api = api
|
|
392
|
+
self._experiment_infos = experiment_infos
|
|
393
|
+
self._checkpoint_changed_func = None
|
|
394
|
+
|
|
395
|
+
self._rows = []
|
|
396
|
+
self.table = self._create_table()
|
|
397
|
+
self._rows_search_texts = []
|
|
398
|
+
self._rows_sort_values = []
|
|
399
|
+
|
|
400
|
+
self._project_infos_map = self._get_project_infos_map(experiment_infos)
|
|
401
|
+
self.set_experiment_infos(experiment_infos)
|
|
402
|
+
super().__init__(widget_id=widget_id)
|
|
403
|
+
|
|
404
|
+
def _search_function(self, data: pd.DataFrame, search_value: str) -> List[ModelRow]:
|
|
405
|
+
search_texts = []
|
|
406
|
+
for idx in data.index:
|
|
407
|
+
first_col_value = data.loc[idx, self.COLUMNS[0]]
|
|
408
|
+
if isinstance(first_col_value, pd.Series):
|
|
409
|
+
first_col_value = first_col_value.iloc[0]
|
|
410
|
+
original_idx = self._first_column_value_to_index[first_col_value]
|
|
411
|
+
search_texts.append(self._rows_search_texts[original_idx])
|
|
412
|
+
|
|
413
|
+
search_series = pd.Series(search_texts, index=data.index)
|
|
414
|
+
mask = search_series.str.contains(search_value, case=False, na=False)
|
|
415
|
+
return data[mask]
|
|
416
|
+
|
|
417
|
+
def _sort_function(
|
|
418
|
+
self, data: pd.DataFrame, column_idx: int, order: str = "asc"
|
|
419
|
+
) -> List[ModelRow]:
|
|
420
|
+
data = data.copy()
|
|
421
|
+
if column_idx >= len(self._rows_sort_values[0]) if self._rows_sort_values else True:
|
|
422
|
+
raise IndexError(
|
|
423
|
+
f"Sorting by column idx = {column_idx} is not possible, your sort values have only {len(self._rows_sort_values[0]) if self._rows_sort_values else 0} columns with idx from 0 to {len(self._rows_sort_values[0]) - 1 if self._rows_sort_values else -1}"
|
|
424
|
+
)
|
|
294
425
|
|
|
295
|
-
|
|
296
|
-
|
|
426
|
+
if order == "asc":
|
|
427
|
+
ascending = True
|
|
428
|
+
else:
|
|
429
|
+
ascending = False
|
|
430
|
+
|
|
431
|
+
try:
|
|
432
|
+
sort_values = []
|
|
433
|
+
for idx in data.index:
|
|
434
|
+
first_col_value = data.loc[idx, self.COLUMNS[0]]
|
|
435
|
+
if isinstance(first_col_value, pd.Series):
|
|
436
|
+
first_col_value = first_col_value.iloc[0]
|
|
437
|
+
original_idx = self._first_column_value_to_index[first_col_value]
|
|
438
|
+
sort_values.append(self._rows_sort_values[original_idx][column_idx])
|
|
439
|
+
|
|
440
|
+
sort_series = pd.Series(sort_values, index=data.index)
|
|
441
|
+
sorted_indices = sort_series.sort_values(ascending=ascending).index
|
|
442
|
+
data = data.loc[sorted_indices]
|
|
443
|
+
data.reset_index(inplace=True, drop=True)
|
|
444
|
+
|
|
445
|
+
except IndexError as e:
|
|
446
|
+
e.args = (
|
|
447
|
+
f"Sorting by column idx = {column_idx} is not possible, your sort values have only {len(self._rows_sort_values[0]) if self._rows_sort_values else 0} columns with idx from 0 to {len(self._rows_sort_values[0]) - 1 if self._rows_sort_values else -1}",
|
|
448
|
+
)
|
|
449
|
+
raise e
|
|
450
|
+
|
|
451
|
+
return data
|
|
452
|
+
|
|
453
|
+
def _filter_function(
|
|
454
|
+
self, data: pd.DataFrame, filter_value: Tuple[List[str], List[str]]
|
|
455
|
+
) -> pd.DataFrame:
|
|
456
|
+
try:
|
|
457
|
+
frameworks, task_types = filter_value
|
|
458
|
+
|
|
459
|
+
filtered_experiments_idxs = set()
|
|
460
|
+
if not frameworks and not task_types:
|
|
461
|
+
return data
|
|
462
|
+
|
|
463
|
+
for idx, experiment_info in enumerate(self._experiment_infos):
|
|
464
|
+
should_add = True
|
|
465
|
+
if frameworks and experiment_info.framework_name not in frameworks:
|
|
466
|
+
should_add = False
|
|
467
|
+
if task_types and experiment_info.task_type not in task_types:
|
|
468
|
+
should_add = False
|
|
469
|
+
if should_add:
|
|
470
|
+
filtered_experiments_idxs.add(idx)
|
|
471
|
+
|
|
472
|
+
filtered_data = data.iloc[sorted(filtered_experiments_idxs)]
|
|
473
|
+
filtered_data.reset_index(inplace=True, drop=True)
|
|
474
|
+
return filtered_data
|
|
475
|
+
except Exception as e:
|
|
476
|
+
logger.error(f"Error during filtering: {e}", exc_info=True)
|
|
477
|
+
show_dialog(title="Filtering Error", description=str(e), status="error")
|
|
478
|
+
return data
|
|
479
|
+
|
|
480
|
+
def _get_frameworks(self):
|
|
481
|
+
frameworks = set()
|
|
482
|
+
for experiment_info in self._experiment_infos:
|
|
483
|
+
frameworks.add(experiment_info.framework_name)
|
|
484
|
+
return sorted(frameworks)
|
|
485
|
+
|
|
486
|
+
def _get_task_types(self):
|
|
487
|
+
task_types = set()
|
|
488
|
+
for experiment_info in self._experiment_infos:
|
|
489
|
+
task_types.add(experiment_info.task_type)
|
|
490
|
+
return sorted(task_types)
|
|
491
|
+
|
|
492
|
+
def _create_table(self) -> FastTable:
|
|
493
|
+
widgets = self.ModelRow.widgets_templates()
|
|
494
|
+
columns = []
|
|
495
|
+
columns_options = []
|
|
496
|
+
for column_name, widget in zip(self.COLUMNS, widgets):
|
|
497
|
+
columns.append((column_name, widget))
|
|
498
|
+
columns_options.append({"customCell": True})
|
|
499
|
+
columns_options[3].update({"classes": "border border-gray-200 px-2"})
|
|
500
|
+
columns_options[3].update({"disableSort": True})
|
|
501
|
+
columns_options[4].update({"disableSort": True})
|
|
502
|
+
self.framework_filter = DropdownCheckboxSelector(
|
|
503
|
+
label="Framework",
|
|
504
|
+
items=[
|
|
505
|
+
DropdownCheckboxSelector.Item(framework) for framework in self._get_frameworks()
|
|
506
|
+
],
|
|
507
|
+
)
|
|
508
|
+
self.task_type_filter = DropdownCheckboxSelector(
|
|
509
|
+
label="Task Type",
|
|
510
|
+
items=[
|
|
511
|
+
DropdownCheckboxSelector.Item(task_type) for task_type in self._get_task_types()
|
|
512
|
+
],
|
|
513
|
+
)
|
|
514
|
+
table = FastTable(
|
|
515
|
+
columns=columns,
|
|
516
|
+
columns_options=columns_options,
|
|
517
|
+
is_radio=True,
|
|
518
|
+
page_size=10,
|
|
519
|
+
header_right_content=Container(
|
|
520
|
+
widgets=[self.framework_filter, self.task_type_filter],
|
|
521
|
+
gap=10,
|
|
522
|
+
direction="horizontal",
|
|
523
|
+
),
|
|
524
|
+
)
|
|
525
|
+
table.set_search(self._search_function)
|
|
526
|
+
table.set_sort(self._sort_function)
|
|
527
|
+
table.set_filter(self._filter_function)
|
|
297
528
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
529
|
+
@self.framework_filter.value_changed
|
|
530
|
+
def on_framework_filter_change(
|
|
531
|
+
selected_frameworks: List[DropdownCheckboxSelector.Item],
|
|
532
|
+
):
|
|
533
|
+
selected_frameworks = [item.id for item in selected_frameworks]
|
|
534
|
+
selected_task_types = self.task_type_filter.get_selected()
|
|
535
|
+
self.table.filter((selected_frameworks, selected_task_types))
|
|
301
536
|
|
|
302
|
-
self.
|
|
303
|
-
|
|
304
|
-
|
|
537
|
+
@self.task_type_filter.value_changed
|
|
538
|
+
def on_task_type_filter_change(
|
|
539
|
+
selected_task_types: List[DropdownCheckboxSelector.Item],
|
|
540
|
+
):
|
|
541
|
+
selected_task_types = [item.id for item in selected_task_types]
|
|
542
|
+
selected_frameworks = self.framework_filter.get_selected()
|
|
543
|
+
self.table.filter((selected_frameworks, selected_task_types))
|
|
305
544
|
|
|
306
|
-
|
|
307
|
-
self._rows_html = defaultdict(list)
|
|
308
|
-
for task_type in table_rows:
|
|
309
|
-
self._rows_html[task_type].extend(
|
|
310
|
-
[model_row.to_html() for model_row in table_rows[task_type]]
|
|
311
|
-
)
|
|
545
|
+
return table
|
|
312
546
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
"taskTypes": self._task_types,
|
|
336
|
-
}
|
|
337
|
-
|
|
338
|
-
def get_json_state(self) -> Dict:
|
|
339
|
-
return {
|
|
340
|
-
"selectedRow": 0,
|
|
341
|
-
"selectedTaskType": self.__default_selected_task_type,
|
|
342
|
-
}
|
|
343
|
-
|
|
344
|
-
def set_active_task_type(self, task_type: str):
|
|
345
|
-
if task_type not in self._task_types:
|
|
346
|
-
raise ValueError(f'Task Type "{task_type}" does not exist')
|
|
347
|
-
StateJson()[self.widget_id]["selectedTaskType"] = task_type
|
|
348
|
-
StateJson().send_changes()
|
|
349
|
-
|
|
350
|
-
def get_available_task_types(self) -> List[str]:
|
|
351
|
-
return self._task_types
|
|
352
|
-
|
|
353
|
-
def disable_table(self) -> None:
|
|
354
|
-
for task_type in self._rows:
|
|
355
|
-
for row in self._rows[task_type]:
|
|
356
|
-
row.checkpoints_selector.disable()
|
|
357
|
-
super().disable()
|
|
358
|
-
|
|
359
|
-
def enable_table(self) -> None:
|
|
360
|
-
for task_type in self._rows:
|
|
361
|
-
for row in self._rows[task_type]:
|
|
362
|
-
row.checkpoints_selector.enable()
|
|
363
|
-
super().enable()
|
|
547
|
+
def _get_project_infos_map(
|
|
548
|
+
self, experiment_infos: List[ExperimentInfo]
|
|
549
|
+
) -> Dict[int, ProjectInfo]:
|
|
550
|
+
"""
|
|
551
|
+
Returns a map of project IDs to project infos used in the experiment infos.
|
|
552
|
+
"""
|
|
553
|
+
project_ids = set()
|
|
554
|
+
for experiment_info in experiment_infos:
|
|
555
|
+
if experiment_info.project_id is not None:
|
|
556
|
+
project_ids.add(experiment_info.project_id)
|
|
557
|
+
project_ids = list(project_ids)
|
|
558
|
+
|
|
559
|
+
project_infos_map = {}
|
|
560
|
+
if project_ids is not None:
|
|
561
|
+
for batch in batched(project_ids):
|
|
562
|
+
filters = [
|
|
563
|
+
{
|
|
564
|
+
ApiField.FIELD: ApiField.ID,
|
|
565
|
+
ApiField.OPERATOR: "in",
|
|
566
|
+
ApiField.VALUE: batch,
|
|
567
|
+
},
|
|
568
|
+
]
|
|
364
569
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
570
|
+
fields = [ApiField.IMAGES_COUNT, ApiField.REFERENCE_IMAGE_URL]
|
|
571
|
+
batch_infos = self.api.project.get_list(
|
|
572
|
+
team_id=self.team_id,
|
|
573
|
+
filters=filters,
|
|
574
|
+
fields=fields,
|
|
575
|
+
)
|
|
576
|
+
for info in batch_infos:
|
|
577
|
+
project_infos_map[info.id] = info
|
|
368
578
|
|
|
369
|
-
|
|
370
|
-
self.disable_table()
|
|
371
|
-
super().disable()
|
|
579
|
+
return project_infos_map
|
|
372
580
|
|
|
373
|
-
def _generate_table_rows(
|
|
374
|
-
self, experiment_infos: List[ExperimentInfo]
|
|
375
|
-
) -> Dict[str, List[ModelRow]]:
|
|
581
|
+
def _generate_table_rows(self, experiment_infos: List[ExperimentInfo]) -> List[ModelRow]:
|
|
376
582
|
"""Method to generate table rows from remote path to training app save directory"""
|
|
377
583
|
|
|
378
584
|
def process_experiment_info(experiment_info: ExperimentInfo):
|
|
379
585
|
try:
|
|
586
|
+
logger.debug(f"Processing experiment info: {experiment_info.task_id}")
|
|
587
|
+
project_info = self._project_infos_map.get(experiment_info.project_id)
|
|
380
588
|
model_row = ExperimentSelector.ModelRow(
|
|
381
|
-
api=self.
|
|
382
|
-
team_id=self.
|
|
589
|
+
api=self.api,
|
|
590
|
+
team_id=self.team_id,
|
|
383
591
|
task_type=experiment_info.task_type,
|
|
384
592
|
experiment_info=experiment_info,
|
|
593
|
+
project_info=project_info,
|
|
385
594
|
)
|
|
595
|
+
|
|
596
|
+
def this_row_checkpoint_changed(checkpoint_value: str):
|
|
597
|
+
self._checkpoint_changed(model_row, checkpoint_value)
|
|
598
|
+
|
|
599
|
+
model_row.checkpoint_changed = this_row_checkpoint_changed
|
|
386
600
|
return experiment_info.task_type, model_row
|
|
387
601
|
except Exception as e:
|
|
388
602
|
logger.debug(f"Failed to process experiment info: {experiment_info}")
|
|
389
603
|
return None, None
|
|
390
604
|
|
|
391
|
-
table_rows =
|
|
392
|
-
with ThreadPoolExecutor() as executor:
|
|
393
|
-
futures =
|
|
394
|
-
executor.submit(process_experiment_info, experiment_info)
|
|
605
|
+
table_rows = []
|
|
606
|
+
with ThreadPoolExecutor(max_workers=10) as executor:
|
|
607
|
+
futures = [
|
|
608
|
+
executor.submit(process_experiment_info, experiment_info)
|
|
395
609
|
for experiment_info in experiment_infos
|
|
396
|
-
|
|
610
|
+
]
|
|
397
611
|
|
|
398
|
-
for future in
|
|
612
|
+
for future in futures:
|
|
399
613
|
result = future.result()
|
|
400
614
|
if result:
|
|
401
615
|
task_type, model_row = result
|
|
402
616
|
if task_type is not None and model_row is not None:
|
|
403
|
-
|
|
404
|
-
self.__debug_row = (task_type, model_row)
|
|
405
|
-
continue
|
|
406
|
-
table_rows[task_type].append(model_row)
|
|
407
|
-
self._sort_table_rows(table_rows)
|
|
408
|
-
if self.__debug_row and is_development():
|
|
409
|
-
task_type, model_row = self.__debug_row
|
|
410
|
-
table_rows[task_type].insert(0, model_row)
|
|
411
|
-
return table_rows
|
|
617
|
+
table_rows.append(model_row)
|
|
412
618
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
table_rows[task_type].sort(key=lambda row: int(row.task_id), reverse=True)
|
|
416
|
-
|
|
417
|
-
def _filter_task_types(self, task_types: List[str]):
|
|
418
|
-
sorted_tt = []
|
|
419
|
-
if "object detection" in task_types:
|
|
420
|
-
sorted_tt.append("object detection")
|
|
421
|
-
if "instance segmentation" in task_types:
|
|
422
|
-
sorted_tt.append("instance segmentation")
|
|
423
|
-
if "pose estimation" in task_types:
|
|
424
|
-
sorted_tt.append("pose estimation")
|
|
425
|
-
other_tasks = sorted(
|
|
426
|
-
set(task_types)
|
|
427
|
-
- set(
|
|
428
|
-
[
|
|
429
|
-
"object detection",
|
|
430
|
-
"instance segmentation",
|
|
431
|
-
"semantic segmentation",
|
|
432
|
-
"pose estimation",
|
|
433
|
-
]
|
|
434
|
-
)
|
|
435
|
-
)
|
|
436
|
-
sorted_tt.extend(other_tasks)
|
|
437
|
-
return sorted_tt
|
|
619
|
+
table_rows.sort(key=lambda x: x.task_id, reverse=True)
|
|
620
|
+
return table_rows
|
|
438
621
|
|
|
439
|
-
def
|
|
440
|
-
|
|
441
|
-
return
|
|
442
|
-
widget_actual_state = state[self.widget_id]
|
|
443
|
-
widget_actual_data = DataJson()[self.widget_id]
|
|
444
|
-
task_type = widget_actual_state["selectedTaskType"]
|
|
445
|
-
if widget_actual_state is not None and widget_actual_data is not None:
|
|
446
|
-
selected_row_index = int(widget_actual_state["selectedRow"])
|
|
447
|
-
return self._rows[task_type][selected_row_index]
|
|
448
|
-
|
|
449
|
-
def get_selected_row_index(self, state=StateJson()) -> Union[int, None]:
|
|
450
|
-
widget_actual_state = state[self.widget_id]
|
|
451
|
-
widget_actual_data = DataJson()[self.widget_id]
|
|
452
|
-
if widget_actual_state is not None and widget_actual_data is not None:
|
|
453
|
-
return widget_actual_state["selectedRow"]
|
|
454
|
-
|
|
455
|
-
def get_selected_task_type(self) -> str:
|
|
456
|
-
return StateJson()[self.widget_id]["selectedTaskType"]
|
|
457
|
-
|
|
458
|
-
def get_selected_experiment_info(self) -> Dict[str, Any]:
|
|
459
|
-
if len(self._rows) == 0:
|
|
460
|
-
return
|
|
461
|
-
selected_row = self.get_selected_row()
|
|
462
|
-
selected_row_json = asdict(selected_row._experiment_info)
|
|
463
|
-
return selected_row_json
|
|
622
|
+
def _update_search_text(self):
|
|
623
|
+
self._rows_search_texts = [row.search_text() for row in self._rows]
|
|
464
624
|
|
|
465
|
-
def
|
|
466
|
-
|
|
467
|
-
return
|
|
468
|
-
selected_row = self.get_selected_row()
|
|
469
|
-
return selected_row.get_selected_checkpoint_path()
|
|
625
|
+
def _update_sort_values(self):
|
|
626
|
+
self._rows_sort_values = [row.sort_values() for row in self._rows]
|
|
470
627
|
|
|
471
|
-
def
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
628
|
+
def _update_value_index_map(self):
|
|
629
|
+
self._first_column_value_to_index = {}
|
|
630
|
+
for i, row in self.table._source_data.iterrows():
|
|
631
|
+
value = row.iloc[0]
|
|
632
|
+
self._first_column_value_to_index[value] = i
|
|
476
633
|
|
|
477
|
-
def
|
|
634
|
+
def set_experiment_infos(self, experiment_infos: List[ExperimentInfo]) -> None:
|
|
478
635
|
"""
|
|
479
|
-
|
|
636
|
+
Updates the experiment infos and regenerates the table rows.
|
|
480
637
|
"""
|
|
638
|
+
table_rows = self._generate_table_rows(experiment_infos)
|
|
639
|
+
self._rows = table_rows
|
|
640
|
+
for row in table_rows:
|
|
641
|
+
self.table.insert_row(row.to_table_row())
|
|
642
|
+
self._update_value_index_map()
|
|
643
|
+
self._update_search_text()
|
|
644
|
+
self._update_sort_values()
|
|
645
|
+
|
|
646
|
+
def get_selected_experiment_info(self) -> Union[ExperimentInfo, None]:
|
|
647
|
+
selected_row = self.table.get_selected_row()
|
|
648
|
+
if selected_row is None:
|
|
649
|
+
return None
|
|
650
|
+
return self._rows[selected_row.row_index]._experiment_info
|
|
651
|
+
|
|
652
|
+
def get_selected_experiment_info_json(self) -> Union[dict, None]:
|
|
481
653
|
experiment_info = self.get_selected_experiment_info()
|
|
482
|
-
|
|
483
|
-
|
|
654
|
+
if experiment_info is None:
|
|
655
|
+
return None
|
|
656
|
+
return experiment_info.to_json()
|
|
657
|
+
|
|
658
|
+
def get_selected_checkpoint_name(self) -> Union[str, None]:
|
|
659
|
+
selected_row = self.table.get_selected_row()
|
|
660
|
+
if selected_row is None:
|
|
661
|
+
return None
|
|
662
|
+
return self._rows[selected_row.row_index].get_selected_checkpoint_name()
|
|
663
|
+
|
|
664
|
+
def get_selected_checkpoint_path(self) -> Union[str, None]:
|
|
665
|
+
selected_row = self.table.get_selected_row()
|
|
666
|
+
if selected_row is None:
|
|
667
|
+
return None
|
|
668
|
+
return self._rows[selected_row.row_index].get_selected_checkpoint_path()
|
|
669
|
+
|
|
670
|
+
def set_selected_row_by_experiment_info(self, experiment_info: ExperimentInfo) -> None:
|
|
671
|
+
for idx, row in enumerate(self._rows):
|
|
672
|
+
if row._experiment_info.task_id == experiment_info.task_id:
|
|
673
|
+
self.table.select_row(idx)
|
|
674
|
+
return
|
|
675
|
+
raise ValueError(f"Experiment info {experiment_info} not found in the table rows.")
|
|
676
|
+
|
|
677
|
+
def _checkpoint_changed(self, row: ModelRow, checkpoint_value: str):
|
|
678
|
+
if self._checkpoint_changed_func is None:
|
|
679
|
+
return
|
|
680
|
+
return self._checkpoint_changed_func(row, checkpoint_value)
|
|
484
681
|
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
full_model_files["checkpoint"] = self.get_selected_checkpoint_path()
|
|
489
|
-
return full_model_files
|
|
682
|
+
def checkpoint_changed(self, func: Callable[[ModelRow, str], None]):
|
|
683
|
+
self._checkpoint_changed_func = func
|
|
684
|
+
return self._checkpoint_changed_func
|
|
490
685
|
|
|
491
|
-
def
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
return deploy_params
|
|
501
|
-
|
|
502
|
-
def set_active_row(self, row_index: int, task_type: str = None) -> None:
|
|
503
|
-
if task_type is None:
|
|
504
|
-
task_type = self.get_selected_task_type()
|
|
505
|
-
self.set_active_task_type(task_type)
|
|
506
|
-
if row_index < 0 or row_index > len(self._rows[task_type]) - 1:
|
|
507
|
-
raise ValueError(f'Row with index "{row_index}" does not exist')
|
|
508
|
-
StateJson()[self.widget_id]["selectedRow"] = row_index
|
|
509
|
-
StateJson().send_changes()
|
|
510
|
-
|
|
511
|
-
def set_by_task_id(self, task_id: int) -> None:
|
|
512
|
-
for task_type in self._rows:
|
|
513
|
-
for i, row in enumerate(self._rows[task_type]):
|
|
514
|
-
if row.task_id == task_id:
|
|
515
|
-
self.set_active_task_type(task_type)
|
|
516
|
-
self.set_active_row(i, task_type)
|
|
517
|
-
return
|
|
686
|
+
def selection_changed(self, func):
|
|
687
|
+
def f(selected_row: FastTable.ClickedRow):
|
|
688
|
+
if selected_row is None:
|
|
689
|
+
return
|
|
690
|
+
idx = selected_row.row_index
|
|
691
|
+
experiment_info = self._rows[idx]._experiment_info
|
|
692
|
+
func(experiment_info)
|
|
693
|
+
|
|
694
|
+
return self.table.selection_changed(f)
|
|
518
695
|
|
|
519
|
-
def
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
696
|
+
def set_selected_checkpoint_by_name(self, checkpoint_name: str):
|
|
697
|
+
selected_row = self.table.get_selected_row()
|
|
698
|
+
if selected_row is None:
|
|
699
|
+
return
|
|
700
|
+
self._rows[selected_row.row_index].set_selected_checkpoint_by_name(checkpoint_name)
|
|
701
|
+
|
|
702
|
+
def set_selected_row_by_task_id(self, task_id: int):
|
|
703
|
+
for idx, row in enumerate(self._rows):
|
|
704
|
+
if row._experiment_info.task_id == task_id:
|
|
705
|
+
self.table.select_row(idx)
|
|
706
|
+
return
|
|
707
|
+
raise ValueError(f"Experiment info with task id {task_id} not found in the table rows.")
|
|
708
|
+
|
|
709
|
+
def get_selected_row_by_task_id(self, task_id: int):
|
|
710
|
+
for idx, row in enumerate(self._rows):
|
|
711
|
+
if row._experiment_info.task_id == task_id:
|
|
712
|
+
return row
|
|
524
713
|
return None
|
|
525
714
|
|
|
526
|
-
def
|
|
527
|
-
|
|
528
|
-
server = self._sly_app.get_server()
|
|
529
|
-
self._task_type_changes_handled = True
|
|
715
|
+
def search(self, search_value: str):
|
|
716
|
+
self.table.search(search_value)
|
|
530
717
|
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
res = self.get_selected_task_type()
|
|
534
|
-
func(res)
|
|
718
|
+
def disable(self):
|
|
719
|
+
return self.table.disable()
|
|
535
720
|
|
|
536
|
-
|
|
721
|
+
def enable(self):
|
|
722
|
+
return self.table.enable()
|
|
537
723
|
|
|
538
|
-
def
|
|
539
|
-
|
|
540
|
-
server = self._sly_app.get_server()
|
|
541
|
-
self._changes_handled = True
|
|
724
|
+
def get_json_data(self):
|
|
725
|
+
return {}
|
|
542
726
|
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
res = self.get_selected_row()
|
|
546
|
-
func(res)
|
|
727
|
+
def get_json_state(self):
|
|
728
|
+
return {}
|
|
547
729
|
|
|
548
|
-
|
|
730
|
+
def to_html(self):
|
|
731
|
+
return self.table.to_html()
|