supervisely 6.73.420__py3-none-any.whl → 6.73.422__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. supervisely/api/api.py +10 -5
  2. supervisely/api/app_api.py +71 -4
  3. supervisely/api/module_api.py +4 -0
  4. supervisely/api/nn/deploy_api.py +15 -9
  5. supervisely/api/nn/ecosystem_models_api.py +201 -0
  6. supervisely/api/nn/neural_network_api.py +12 -3
  7. supervisely/api/project_api.py +35 -6
  8. supervisely/api/task_api.py +5 -1
  9. supervisely/app/widgets/__init__.py +8 -1
  10. supervisely/app/widgets/agent_selector/template.html +1 -0
  11. supervisely/app/widgets/deploy_model/__init__.py +0 -0
  12. supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
  13. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  14. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  15. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  16. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  17. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
  18. supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
  19. supervisely/app/widgets/fast_table/fast_table.py +402 -74
  20. supervisely/app/widgets/fast_table/script.js +364 -96
  21. supervisely/app/widgets/fast_table/style.css +24 -0
  22. supervisely/app/widgets/fast_table/template.html +43 -3
  23. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  24. supervisely/app/widgets/select/select.py +6 -4
  25. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
  26. supervisely/app/widgets/tabs/tabs.py +22 -6
  27. supervisely/app/widgets/tabs/template.html +5 -1
  28. supervisely/nn/artifacts/__init__.py +1 -1
  29. supervisely/nn/artifacts/artifacts.py +10 -2
  30. supervisely/nn/artifacts/detectron2.py +1 -0
  31. supervisely/nn/artifacts/hrda.py +1 -0
  32. supervisely/nn/artifacts/mmclassification.py +20 -0
  33. supervisely/nn/artifacts/mmdetection.py +5 -3
  34. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  35. supervisely/nn/artifacts/ritm.py +1 -0
  36. supervisely/nn/artifacts/rtdetr.py +1 -0
  37. supervisely/nn/artifacts/unet.py +1 -0
  38. supervisely/nn/artifacts/utils.py +3 -0
  39. supervisely/nn/artifacts/yolov5.py +2 -0
  40. supervisely/nn/artifacts/yolov8.py +1 -0
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  42. supervisely/nn/experiments.py +9 -0
  43. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  44. supervisely/nn/inference/inference.py +160 -94
  45. supervisely/nn/inference/predict_app/__init__.py +0 -0
  46. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  47. supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
  48. supervisely/nn/inference/predict_app/gui/gui.py +710 -0
  49. supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
  50. supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
  51. supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
  52. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  53. supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
  54. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  55. supervisely/nn/inference/predict_app/gui/utils.py +282 -0
  56. supervisely/nn/inference/predict_app/predict_app.py +184 -0
  57. supervisely/nn/inference/uploader.py +9 -5
  58. supervisely/nn/model/prediction.py +2 -0
  59. supervisely/nn/model/prediction_session.py +20 -3
  60. supervisely/nn/training/gui/gui.py +131 -44
  61. supervisely/nn/training/gui/model_selector.py +8 -6
  62. supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
  63. supervisely/nn/training/gui/training_artifacts.py +0 -5
  64. supervisely/nn/training/train_app.py +161 -44
  65. supervisely/template/experiment/experiment.html.jinja +74 -17
  66. supervisely/template/experiment/experiment_generator.py +258 -112
  67. supervisely/template/experiment/header.html.jinja +31 -13
  68. supervisely/template/experiment/sly-style.css +7 -2
  69. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/METADATA +3 -1
  70. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/RECORD +74 -56
  71. supervisely/app/widgets/experiment_selector/style.css +0 -27
  72. supervisely/app/widgets/experiment_selector/template.html +0 -61
  73. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/LICENSE +0 -0
  74. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/WHEEL +0 -0
  75. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/entry_points.txt +0 -0
  76. {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/top_level.txt +0 -0
@@ -1,42 +1,52 @@
1
+ import json
1
2
  import os
2
- from collections import defaultdict
3
- from concurrent.futures import ThreadPoolExecutor, as_completed
4
- from dataclasses import asdict
5
- from typing import Any, Callable, Dict, List, Union
6
-
7
- from supervisely import env, logger
8
- from supervisely._utils import abs_url, is_development
9
- from supervisely.api.api import Api
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from functools import partial
5
+ from typing import Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import pandas as pd
8
+
9
+ from supervisely import batched
10
+ from supervisely._utils import abs_url, is_development, logger
11
+ from supervisely.api.api import Api, ApiField
10
12
  from supervisely.api.project_api import ProjectInfo
11
- from supervisely.app.content import DataJson, StateJson
12
- from supervisely.app.widgets import (
13
- Container,
14
- Flexbox,
15
- ProjectThumbnail,
16
- Select,
17
- Text,
18
- Widget,
13
+ from supervisely.app.exceptions import show_dialog
14
+ from supervisely.app.widgets.container.container import Container
15
+ from supervisely.app.widgets.dropdown_checkbox_selector.dropdown_checkbox_selector import (
16
+ DropdownCheckboxSelector,
19
17
  )
18
+ from supervisely.app.widgets.fast_table.fast_table import FastTable
19
+ from supervisely.app.widgets.flexbox.flexbox import Flexbox
20
+ from supervisely.app.widgets.project_thumbnail.project_thumbnail import ProjectThumbnail
21
+ from supervisely.app.widgets.select.select import Select
22
+ from supervisely.app.widgets.text.text import Text
23
+ from supervisely.app.widgets.widget import Widget
24
+ from supervisely.io import env
20
25
  from supervisely.io.fs import get_file_name_with_ext
21
26
  from supervisely.nn.experiments import ExperimentInfo
22
- from supervisely.nn.utils import ModelSource
23
-
24
- WEIGHTS_DIR = "weights"
25
-
26
- COL_ID = "task id".upper()
27
- COL_MODEL = "model".upper()
28
- COL_PROJECT = "training data".upper()
29
- COL_CHECKPOINTS = "checkpoints".upper()
30
- COL_SESSION = "session".upper()
31
- COL_BENCHMARK = "benchmark".upper()
32
-
33
- columns = [COL_ID, COL_MODEL, COL_PROJECT, COL_CHECKPOINTS, COL_SESSION, COL_BENCHMARK]
34
27
 
35
28
 
36
29
  class ExperimentSelector(Widget):
37
- class Routes:
38
- TASK_TYPE_CHANGED = "task_type_changed"
39
- VALUE_CHANGED = "value_changed"
30
+ """
31
+ Widget for selecting experiments from a team.
32
+ """
33
+
34
+ class COLUMN:
35
+ NAME = "TASK ID"
36
+ MODEL = "MODEL"
37
+ TRAINING_DATA = "TRAINING DATA"
38
+ CHECKPOINTS = "CHECKPOINTS"
39
+ SESSION = "SESSION"
40
+ BENCHMARK = "BENCHMARK"
41
+
42
+ COLUMNS = [
43
+ COLUMN.NAME,
44
+ COLUMN.MODEL,
45
+ COLUMN.TRAINING_DATA,
46
+ COLUMN.CHECKPOINTS,
47
+ COLUMN.SESSION,
48
+ COLUMN.BENCHMARK,
49
+ ]
40
50
 
41
51
  class ModelRow:
42
52
  def __init__(
@@ -45,15 +55,19 @@ class ExperimentSelector(Widget):
45
55
  team_id: int,
46
56
  task_type: str,
47
57
  experiment_info: ExperimentInfo,
58
+ project_info: Optional[ProjectInfo] = None,
48
59
  ):
49
60
  self._api = api
50
61
  self._team_id = team_id
51
62
  self._task_type = task_type
52
63
  self._experiment_info = experiment_info
64
+ self._project_info = project_info
53
65
 
54
66
  task_id = experiment_info.task_id
55
- if task_id == "debug-session" or task_id == -1:
67
+ if task_id == -1:
56
68
  pass
69
+ elif task_id == "debug-session":
70
+ task_id = -1
57
71
  elif type(task_id) is str:
58
72
  if task_id.isdigit():
59
73
  task_id = int(task_id)
@@ -77,9 +91,7 @@ class ExperimentSelector(Widget):
77
91
  if self._training_project_id is None:
78
92
  self._training_project_info = None
79
93
  else:
80
- self._training_project_info = self._api.project.get_info_by_id(
81
- self._training_project_id
82
- )
94
+ self._training_project_info = self._project_info
83
95
 
84
96
  # col 4 checkpoints
85
97
  self._checkpoints = experiment_info.checkpoints
@@ -143,14 +155,6 @@ class ExperimentSelector(Widget):
143
155
  def checkpoints_selector(self) -> Select:
144
156
  return self._checkpoints_widget
145
157
 
146
- @property
147
- def experiment_info(self) -> ExperimentInfo:
148
- return self._experiment_info
149
-
150
- @property
151
- def best_checkpoint(self) -> str:
152
- return self.experiment_info.best_checkpoint
153
-
154
158
  @property
155
159
  def session_link(self) -> str:
156
160
  return self._session_link
@@ -234,17 +238,24 @@ class ExperimentSelector(Widget):
234
238
  return model_widget
235
239
 
236
240
  def _create_training_project_widget(self) -> Union[ProjectThumbnail, Text]:
241
+ training_project_thumbnail = ProjectThumbnail(
242
+ self._training_project_info, remove_margins=True
243
+ )
244
+ training_project_text = Text(
245
+ f"<span class='field-description text-muted' style='color: #7f858e'>Project was deleted</span>",
246
+ "text",
247
+ font_size=13,
248
+ )
237
249
  if self.training_project_info is not None:
238
- training_project_widget = ProjectThumbnail(
239
- self._training_project_info, remove_margins=True
240
- )
250
+ training_project_thumbnail.show()
251
+ training_project_text.hide()
241
252
  else:
242
- training_project_widget = Text(
243
- f"<span class='field-description text-muted' style='color: #7f858e'>Project was deleted</span>",
244
- "text",
245
- font_size=13,
246
- )
247
- return training_project_widget
253
+ training_project_thumbnail.hide()
254
+ training_project_text.show()
255
+ return Container(widgets=[training_project_thumbnail, training_project_text], gap=0)
256
+
257
+ def checkpoint_changed(self, checkpoint_value: str):
258
+ return
248
259
 
249
260
  def _create_checkpoints_widget(self) -> Select:
250
261
  checkpoint_selector_items = []
@@ -253,6 +264,11 @@ class ExperimentSelector(Widget):
253
264
  checkpoint_selector = Select(items=checkpoint_selector_items)
254
265
  if self._best_checkpoint_value is not None:
255
266
  checkpoint_selector.set_value(self._best_checkpoint)
267
+
268
+ @checkpoint_selector.value_changed
269
+ def on_checkpoint_changed(checkpoint_value: str):
270
+ self.checkpoint_changed(checkpoint_value)
271
+
256
272
  return checkpoint_selector
257
273
 
258
274
  def _create_session_widget(self) -> Text:
@@ -284,265 +300,432 @@ class ExperimentSelector(Widget):
284
300
  )
285
301
  return benchmark_widget
286
302
 
303
+ def _widget_to_cell_value(self, widget: Widget) -> str:
304
+ if isinstance(widget, Container):
305
+ return json.dumps(
306
+ {
307
+ "widget_id": widget.widget_id,
308
+ "widgets": [w.widget_id for w in widget._widgets],
309
+ }
310
+ )
311
+ else:
312
+ return json.dumps({"widget_id": widget.widget_id, "widgets": []})
313
+
314
+ def to_table_row(self):
315
+ return [
316
+ self._widget_to_cell_value(w)
317
+ for w in [
318
+ self._task_widget,
319
+ self._model_wiget,
320
+ self._training_project_widget,
321
+ self._checkpoints_widget,
322
+ self._session_widget,
323
+ self._benchmark_widget,
324
+ ]
325
+ ]
326
+
327
+ @classmethod
328
+ def widgets_templates(cls):
329
+ checkpoints_template_widget = Select(items=[])
330
+ checkpoints_template_widget.value_changed(lambda _: None)
331
+
332
+ return [
333
+ # _task_widget
334
+ Container(widgets=[Text(""), Text("")], gap=0),
335
+ # _model_wiget
336
+ Text(""),
337
+ # _training_project_widget
338
+ Container(widgets=[ProjectThumbnail(remove_margins=True), Text("")], gap=0),
339
+ # _checkpoints_widget
340
+ checkpoints_template_widget,
341
+ # _session_widget
342
+ Text(""),
343
+ # _benchmark_widget
344
+ Text(""),
345
+ ]
346
+
347
+ def search_text(self) -> str:
348
+ text = ""
349
+ text += str(self._task_id)
350
+ text += str(self._task_date)
351
+ text += str(self._model_name)
352
+ if self._training_project_info is not None:
353
+ text += str(self._training_project_info.name)
354
+ else:
355
+ text += "Project was deleted"
356
+ return text
357
+
358
+ def sort_values(self) -> List[int]:
359
+ # Sort by training project name: real names first (A->Z), deleted projects last
360
+ if self._training_project_info is not None:
361
+ training_project_name = (0, self._training_project_info.name.lower())
362
+ else:
363
+ training_project_name = (1, "")
364
+
365
+ if self._benchmark_report_id == "No evaluation report available":
366
+ benchmark_report_id = 0
367
+ else:
368
+ benchmark_report_id = 1
369
+
370
+ return [
371
+ self._task_id,
372
+ self._model_name.capitalize(),
373
+ training_project_name,
374
+ 0,
375
+ 0,
376
+ benchmark_report_id,
377
+ ]
378
+
287
379
  def __init__(
288
380
  self,
289
- team_id: int,
381
+ api: Api = None,
382
+ team_id: int = None,
290
383
  experiment_infos: List[ExperimentInfo] = [],
291
384
  widget_id: str = None,
292
385
  ):
293
- self._api = Api.from_env()
386
+ if team_id is None:
387
+ team_id = env.team_id()
388
+ self.team_id = team_id
389
+ if api is None:
390
+ api = Api()
391
+ self.api = api
392
+ self._experiment_infos = experiment_infos
393
+ self._checkpoint_changed_func = None
394
+
395
+ self._rows = []
396
+ self.table = self._create_table()
397
+ self._rows_search_texts = []
398
+ self._rows_sort_values = []
399
+
400
+ self._project_infos_map = self._get_project_infos_map(experiment_infos)
401
+ self.set_experiment_infos(experiment_infos)
402
+ super().__init__(widget_id=widget_id)
403
+
404
+ def _search_function(self, data: pd.DataFrame, search_value: str) -> List[ModelRow]:
405
+ search_texts = []
406
+ for idx in data.index:
407
+ first_col_value = data.loc[idx, self.COLUMNS[0]]
408
+ if isinstance(first_col_value, pd.Series):
409
+ first_col_value = first_col_value.iloc[0]
410
+ original_idx = self._first_column_value_to_index[first_col_value]
411
+ search_texts.append(self._rows_search_texts[original_idx])
412
+
413
+ search_series = pd.Series(search_texts, index=data.index)
414
+ mask = search_series.str.contains(search_value, case=False, na=False)
415
+ return data[mask]
416
+
417
+ def _sort_function(
418
+ self, data: pd.DataFrame, column_idx: int, order: str = "asc"
419
+ ) -> List[ModelRow]:
420
+ data = data.copy()
421
+ if column_idx >= len(self._rows_sort_values[0]) if self._rows_sort_values else True:
422
+ raise IndexError(
423
+ f"Sorting by column idx = {column_idx} is not possible, your sort values have only {len(self._rows_sort_values[0]) if self._rows_sort_values else 0} columns with idx from 0 to {len(self._rows_sort_values[0]) - 1 if self._rows_sort_values else -1}"
424
+ )
294
425
 
295
- self._team_id = team_id
296
- self.__debug_row = None
426
+ if order == "asc":
427
+ ascending = True
428
+ else:
429
+ ascending = False
430
+
431
+ try:
432
+ sort_values = []
433
+ for idx in data.index:
434
+ first_col_value = data.loc[idx, self.COLUMNS[0]]
435
+ if isinstance(first_col_value, pd.Series):
436
+ first_col_value = first_col_value.iloc[0]
437
+ original_idx = self._first_column_value_to_index[first_col_value]
438
+ sort_values.append(self._rows_sort_values[original_idx][column_idx])
439
+
440
+ sort_series = pd.Series(sort_values, index=data.index)
441
+ sorted_indices = sort_series.sort_values(ascending=ascending).index
442
+ data = data.loc[sorted_indices]
443
+ data.reset_index(inplace=True, drop=True)
444
+
445
+ except IndexError as e:
446
+ e.args = (
447
+ f"Sorting by column idx = {column_idx} is not possible, your sort values have only {len(self._rows_sort_values[0]) if self._rows_sort_values else 0} columns with idx from 0 to {len(self._rows_sort_values[0]) - 1 if self._rows_sort_values else -1}",
448
+ )
449
+ raise e
450
+
451
+ return data
452
+
453
+ def _filter_function(
454
+ self, data: pd.DataFrame, filter_value: Tuple[List[str], List[str]]
455
+ ) -> pd.DataFrame:
456
+ try:
457
+ frameworks, task_types = filter_value
458
+
459
+ filtered_experiments_idxs = set()
460
+ if not frameworks and not task_types:
461
+ return data
462
+
463
+ for idx, experiment_info in enumerate(self._experiment_infos):
464
+ should_add = True
465
+ if frameworks and experiment_info.framework_name not in frameworks:
466
+ should_add = False
467
+ if task_types and experiment_info.task_type not in task_types:
468
+ should_add = False
469
+ if should_add:
470
+ filtered_experiments_idxs.add(idx)
471
+
472
+ filtered_data = data.iloc[sorted(filtered_experiments_idxs)]
473
+ filtered_data.reset_index(inplace=True, drop=True)
474
+ return filtered_data
475
+ except Exception as e:
476
+ logger.error(f"Error during filtering: {e}", exc_info=True)
477
+ show_dialog(title="Filtering Error", description=str(e), status="error")
478
+ return data
479
+
480
+ def _get_frameworks(self):
481
+ frameworks = set()
482
+ for experiment_info in self._experiment_infos:
483
+ frameworks.add(experiment_info.framework_name)
484
+ return sorted(frameworks)
485
+
486
+ def _get_task_types(self):
487
+ task_types = set()
488
+ for experiment_info in self._experiment_infos:
489
+ task_types.add(experiment_info.task_type)
490
+ return sorted(task_types)
491
+
492
+ def _create_table(self) -> FastTable:
493
+ widgets = self.ModelRow.widgets_templates()
494
+ columns = []
495
+ columns_options = []
496
+ for column_name, widget in zip(self.COLUMNS, widgets):
497
+ columns.append((column_name, widget))
498
+ columns_options.append({"customCell": True})
499
+ columns_options[3].update({"classes": "border border-gray-200 px-2"})
500
+ columns_options[3].update({"disableSort": True})
501
+ columns_options[4].update({"disableSort": True})
502
+ self.framework_filter = DropdownCheckboxSelector(
503
+ label="Framework",
504
+ items=[
505
+ DropdownCheckboxSelector.Item(framework) for framework in self._get_frameworks()
506
+ ],
507
+ )
508
+ self.task_type_filter = DropdownCheckboxSelector(
509
+ label="Task Type",
510
+ items=[
511
+ DropdownCheckboxSelector.Item(task_type) for task_type in self._get_task_types()
512
+ ],
513
+ )
514
+ table = FastTable(
515
+ columns=columns,
516
+ columns_options=columns_options,
517
+ is_radio=True,
518
+ page_size=10,
519
+ header_right_content=Container(
520
+ widgets=[self.framework_filter, self.task_type_filter],
521
+ gap=10,
522
+ direction="horizontal",
523
+ ),
524
+ )
525
+ table.set_search(self._search_function)
526
+ table.set_sort(self._sort_function)
527
+ table.set_filter(self._filter_function)
297
528
 
298
- with ThreadPoolExecutor() as executor:
299
- future = executor.submit(self._generate_table_rows, experiment_infos)
300
- table_rows = future.result()
529
+ @self.framework_filter.value_changed
530
+ def on_framework_filter_change(
531
+ selected_frameworks: List[DropdownCheckboxSelector.Item],
532
+ ):
533
+ selected_frameworks = [item.id for item in selected_frameworks]
534
+ selected_task_types = self.task_type_filter.get_selected()
535
+ self.table.filter((selected_frameworks, selected_task_types))
301
536
 
302
- self._columns = columns
303
- self._rows = table_rows
304
- # self._rows_html = #[row.to_html() for row in self._rows]
537
+ @self.task_type_filter.value_changed
538
+ def on_task_type_filter_change(
539
+ selected_task_types: List[DropdownCheckboxSelector.Item],
540
+ ):
541
+ selected_task_types = [item.id for item in selected_task_types]
542
+ selected_frameworks = self.framework_filter.get_selected()
543
+ self.table.filter((selected_frameworks, selected_task_types))
305
544
 
306
- task_types = [task_type for task_type in table_rows]
307
- self._rows_html = defaultdict(list)
308
- for task_type in table_rows:
309
- self._rows_html[task_type].extend(
310
- [model_row.to_html() for model_row in table_rows[task_type]]
311
- )
545
+ return table
312
546
 
313
- self._task_types = self._filter_task_types(task_types)
314
- if len(self._task_types) == 0:
315
- self.__default_selected_task_type = None
316
- else:
317
- self.__default_selected_task_type = self._task_types[0]
318
-
319
- self._changes_handled = False
320
- self._task_type_changes_handled = False
321
- super().__init__(widget_id=widget_id, file_path=__file__)
322
-
323
- @property
324
- def columns(self) -> List[str]:
325
- return self._columns
326
-
327
- @property
328
- def rows(self) -> Dict[str, List[ModelRow]]:
329
- return self._rows
330
-
331
- def get_json_data(self) -> Dict:
332
- return {
333
- "columns": self._columns,
334
- "rowsHtml": self._rows_html,
335
- "taskTypes": self._task_types,
336
- }
337
-
338
- def get_json_state(self) -> Dict:
339
- return {
340
- "selectedRow": 0,
341
- "selectedTaskType": self.__default_selected_task_type,
342
- }
343
-
344
- def set_active_task_type(self, task_type: str):
345
- if task_type not in self._task_types:
346
- raise ValueError(f'Task Type "{task_type}" does not exist')
347
- StateJson()[self.widget_id]["selectedTaskType"] = task_type
348
- StateJson().send_changes()
349
-
350
- def get_available_task_types(self) -> List[str]:
351
- return self._task_types
352
-
353
- def disable_table(self) -> None:
354
- for task_type in self._rows:
355
- for row in self._rows[task_type]:
356
- row.checkpoints_selector.disable()
357
- super().disable()
358
-
359
- def enable_table(self) -> None:
360
- for task_type in self._rows:
361
- for row in self._rows[task_type]:
362
- row.checkpoints_selector.enable()
363
- super().enable()
547
+ def _get_project_infos_map(
548
+ self, experiment_infos: List[ExperimentInfo]
549
+ ) -> Dict[int, ProjectInfo]:
550
+ """
551
+ Returns a map of project IDs to project infos used in the experiment infos.
552
+ """
553
+ project_ids = set()
554
+ for experiment_info in experiment_infos:
555
+ if experiment_info.project_id is not None:
556
+ project_ids.add(experiment_info.project_id)
557
+ project_ids = list(project_ids)
558
+
559
+ project_infos_map = {}
560
+ if project_ids is not None:
561
+ for batch in batched(project_ids):
562
+ filters = [
563
+ {
564
+ ApiField.FIELD: ApiField.ID,
565
+ ApiField.OPERATOR: "in",
566
+ ApiField.VALUE: batch,
567
+ },
568
+ ]
364
569
 
365
- def enable(self):
366
- self.enable_table()
367
- super().enable()
570
+ fields = [ApiField.IMAGES_COUNT, ApiField.REFERENCE_IMAGE_URL]
571
+ batch_infos = self.api.project.get_list(
572
+ team_id=self.team_id,
573
+ filters=filters,
574
+ fields=fields,
575
+ )
576
+ for info in batch_infos:
577
+ project_infos_map[info.id] = info
368
578
 
369
- def disable(self) -> None:
370
- self.disable_table()
371
- super().disable()
579
+ return project_infos_map
372
580
 
373
- def _generate_table_rows(
374
- self, experiment_infos: List[ExperimentInfo]
375
- ) -> Dict[str, List[ModelRow]]:
581
+ def _generate_table_rows(self, experiment_infos: List[ExperimentInfo]) -> List[ModelRow]:
376
582
  """Method to generate table rows from remote path to training app save directory"""
377
583
 
378
584
  def process_experiment_info(experiment_info: ExperimentInfo):
379
585
  try:
586
+ logger.debug(f"Processing experiment info: {experiment_info.task_id}")
587
+ project_info = self._project_infos_map.get(experiment_info.project_id)
380
588
  model_row = ExperimentSelector.ModelRow(
381
- api=self._api,
382
- team_id=self._team_id,
589
+ api=self.api,
590
+ team_id=self.team_id,
383
591
  task_type=experiment_info.task_type,
384
592
  experiment_info=experiment_info,
593
+ project_info=project_info,
385
594
  )
595
+
596
+ def this_row_checkpoint_changed(checkpoint_value: str):
597
+ self._checkpoint_changed(model_row, checkpoint_value)
598
+
599
+ model_row.checkpoint_changed = this_row_checkpoint_changed
386
600
  return experiment_info.task_type, model_row
387
601
  except Exception as e:
388
602
  logger.debug(f"Failed to process experiment info: {experiment_info}")
389
603
  return None, None
390
604
 
391
- table_rows = defaultdict(list)
392
- with ThreadPoolExecutor() as executor:
393
- futures = {
394
- executor.submit(process_experiment_info, experiment_info): experiment_info
605
+ table_rows = []
606
+ with ThreadPoolExecutor(max_workers=10) as executor:
607
+ futures = [
608
+ executor.submit(process_experiment_info, experiment_info)
395
609
  for experiment_info in experiment_infos
396
- }
610
+ ]
397
611
 
398
- for future in as_completed(futures):
612
+ for future in futures:
399
613
  result = future.result()
400
614
  if result:
401
615
  task_type, model_row = result
402
616
  if task_type is not None and model_row is not None:
403
- if model_row.task_id == "debug-session" or model_row.task_id == -1:
404
- self.__debug_row = (task_type, model_row)
405
- continue
406
- table_rows[task_type].append(model_row)
407
- self._sort_table_rows(table_rows)
408
- if self.__debug_row and is_development():
409
- task_type, model_row = self.__debug_row
410
- table_rows[task_type].insert(0, model_row)
411
- return table_rows
617
+ table_rows.append(model_row)
412
618
 
413
- def _sort_table_rows(self, table_rows: Dict[str, List[ModelRow]]) -> None:
414
- for task_type in table_rows:
415
- table_rows[task_type].sort(key=lambda row: int(row.task_id), reverse=True)
416
-
417
- def _filter_task_types(self, task_types: List[str]):
418
- sorted_tt = []
419
- if "object detection" in task_types:
420
- sorted_tt.append("object detection")
421
- if "instance segmentation" in task_types:
422
- sorted_tt.append("instance segmentation")
423
- if "pose estimation" in task_types:
424
- sorted_tt.append("pose estimation")
425
- other_tasks = sorted(
426
- set(task_types)
427
- - set(
428
- [
429
- "object detection",
430
- "instance segmentation",
431
- "semantic segmentation",
432
- "pose estimation",
433
- ]
434
- )
435
- )
436
- sorted_tt.extend(other_tasks)
437
- return sorted_tt
619
+ table_rows.sort(key=lambda x: x.task_id, reverse=True)
620
+ return table_rows
438
621
 
439
- def get_selected_row(self, state=StateJson()) -> Union[ModelRow, None]:
440
- if len(self._rows) == 0:
441
- return
442
- widget_actual_state = state[self.widget_id]
443
- widget_actual_data = DataJson()[self.widget_id]
444
- task_type = widget_actual_state["selectedTaskType"]
445
- if widget_actual_state is not None and widget_actual_data is not None:
446
- selected_row_index = int(widget_actual_state["selectedRow"])
447
- return self._rows[task_type][selected_row_index]
448
-
449
- def get_selected_row_index(self, state=StateJson()) -> Union[int, None]:
450
- widget_actual_state = state[self.widget_id]
451
- widget_actual_data = DataJson()[self.widget_id]
452
- if widget_actual_state is not None and widget_actual_data is not None:
453
- return widget_actual_state["selectedRow"]
454
-
455
- def get_selected_task_type(self) -> str:
456
- return StateJson()[self.widget_id]["selectedTaskType"]
457
-
458
- def get_selected_experiment_info(self) -> Dict[str, Any]:
459
- if len(self._rows) == 0:
460
- return
461
- selected_row = self.get_selected_row()
462
- selected_row_json = asdict(selected_row._experiment_info)
463
- return selected_row_json
622
+ def _update_search_text(self):
623
+ self._rows_search_texts = [row.search_text() for row in self._rows]
464
624
 
465
- def get_selected_checkpoint_path(self) -> str:
466
- if len(self._rows) == 0:
467
- return
468
- selected_row = self.get_selected_row()
469
- return selected_row.get_selected_checkpoint_path()
625
+ def _update_sort_values(self):
626
+ self._rows_sort_values = [row.sort_values() for row in self._rows]
470
627
 
471
- def get_selected_checkpoint_name(self) -> str:
472
- if len(self._rows) == 0:
473
- return
474
- selected_row = self.get_selected_row()
475
- return selected_row.get_selected_checkpoint_name()
628
+ def _update_value_index_map(self):
629
+ self._first_column_value_to_index = {}
630
+ for i, row in self.table._source_data.iterrows():
631
+ value = row.iloc[0]
632
+ self._first_column_value_to_index[value] = i
476
633
 
477
- def get_model_files(self) -> Dict[str, str]:
634
+ def set_experiment_infos(self, experiment_infos: List[ExperimentInfo]) -> None:
478
635
  """
479
- Returns a dictionary with full paths to model files in Supervisely Team Files.
636
+ Updates the experiment infos and regenerates the table rows.
480
637
  """
638
+ table_rows = self._generate_table_rows(experiment_infos)
639
+ self._rows = table_rows
640
+ for row in table_rows:
641
+ self.table.insert_row(row.to_table_row())
642
+ self._update_value_index_map()
643
+ self._update_search_text()
644
+ self._update_sort_values()
645
+
646
+ def get_selected_experiment_info(self) -> Union[ExperimentInfo, None]:
647
+ selected_row = self.table.get_selected_row()
648
+ if selected_row is None:
649
+ return None
650
+ return self._rows[selected_row.row_index]._experiment_info
651
+
652
+ def get_selected_experiment_info_json(self) -> Union[dict, None]:
481
653
  experiment_info = self.get_selected_experiment_info()
482
- artifacts_dir = experiment_info.get("artifacts_dir")
483
- model_files = experiment_info.get("model_files", {})
654
+ if experiment_info is None:
655
+ return None
656
+ return experiment_info.to_json()
657
+
658
+ def get_selected_checkpoint_name(self) -> Union[str, None]:
659
+ selected_row = self.table.get_selected_row()
660
+ if selected_row is None:
661
+ return None
662
+ return self._rows[selected_row.row_index].get_selected_checkpoint_name()
663
+
664
+ def get_selected_checkpoint_path(self) -> Union[str, None]:
665
+ selected_row = self.table.get_selected_row()
666
+ if selected_row is None:
667
+ return None
668
+ return self._rows[selected_row.row_index].get_selected_checkpoint_path()
669
+
670
+ def set_selected_row_by_experiment_info(self, experiment_info: ExperimentInfo) -> None:
671
+ for idx, row in enumerate(self._rows):
672
+ if row._experiment_info.task_id == experiment_info.task_id:
673
+ self.table.select_row(idx)
674
+ return
675
+ raise ValueError(f"Experiment info {experiment_info} not found in the table rows.")
676
+
677
+ def _checkpoint_changed(self, row: ModelRow, checkpoint_value: str):
678
+ if self._checkpoint_changed_func is None:
679
+ return
680
+ return self._checkpoint_changed_func(row, checkpoint_value)
484
681
 
485
- full_model_files = {
486
- name: os.path.join(artifacts_dir, file) for name, file in model_files.items()
487
- }
488
- full_model_files["checkpoint"] = self.get_selected_checkpoint_path()
489
- return full_model_files
682
+ def checkpoint_changed(self, func: Callable[[ModelRow, str], None]):
683
+ self._checkpoint_changed_func = func
684
+ return self._checkpoint_changed_func
490
685
 
491
- def get_deploy_params(self) -> Dict[str, Any]:
492
- """
493
- Returns a dictionary with deploy parameters except runtime and device keys.
494
- """
495
- deploy_params = {
496
- "model_source": ModelSource.CUSTOM,
497
- "model_files": self.get_model_files(),
498
- "model_info": self.get_selected_experiment_info(),
499
- }
500
- return deploy_params
501
-
502
- def set_active_row(self, row_index: int, task_type: str = None) -> None:
503
- if task_type is None:
504
- task_type = self.get_selected_task_type()
505
- self.set_active_task_type(task_type)
506
- if row_index < 0 or row_index > len(self._rows[task_type]) - 1:
507
- raise ValueError(f'Row with index "{row_index}" does not exist')
508
- StateJson()[self.widget_id]["selectedRow"] = row_index
509
- StateJson().send_changes()
510
-
511
- def set_by_task_id(self, task_id: int) -> None:
512
- for task_type in self._rows:
513
- for i, row in enumerate(self._rows[task_type]):
514
- if row.task_id == task_id:
515
- self.set_active_task_type(task_type)
516
- self.set_active_row(i, task_type)
517
- return
686
+ def selection_changed(self, func):
687
+ def f(selected_row: FastTable.ClickedRow):
688
+ if selected_row is None:
689
+ return
690
+ idx = selected_row.row_index
691
+ experiment_info = self._rows[idx]._experiment_info
692
+ func(experiment_info)
693
+
694
+ return self.table.selection_changed(f)
518
695
 
519
- def get_by_task_id(self, task_id: int) -> Union[ModelRow, None]:
520
- for task_type in self._rows:
521
- for row in self._rows[task_type]:
522
- if row.task_id == task_id:
523
- return row
696
+ def set_selected_checkpoint_by_name(self, checkpoint_name: str):
697
+ selected_row = self.table.get_selected_row()
698
+ if selected_row is None:
699
+ return
700
+ self._rows[selected_row.row_index].set_selected_checkpoint_by_name(checkpoint_name)
701
+
702
+ def set_selected_row_by_task_id(self, task_id: int):
703
+ for idx, row in enumerate(self._rows):
704
+ if row._experiment_info.task_id == task_id:
705
+ self.table.select_row(idx)
706
+ return
707
+ raise ValueError(f"Experiment info with task id {task_id} not found in the table rows.")
708
+
709
+ def get_selected_row_by_task_id(self, task_id: int):
710
+ for idx, row in enumerate(self._rows):
711
+ if row._experiment_info.task_id == task_id:
712
+ return row
524
713
  return None
525
714
 
526
- def task_type_changed(self, func: Callable):
527
- route_path = self.get_route_path(ExperimentSelector.Routes.TASK_TYPE_CHANGED)
528
- server = self._sly_app.get_server()
529
- self._task_type_changes_handled = True
715
+ def search(self, search_value: str):
716
+ self.table.search(search_value)
530
717
 
531
- @server.post(route_path)
532
- def _task_type_changed():
533
- res = self.get_selected_task_type()
534
- func(res)
718
+ def disable(self):
719
+ return self.table.disable()
535
720
 
536
- return _task_type_changed
721
+ def enable(self):
722
+ return self.table.enable()
537
723
 
538
- def value_changed(self, func: Callable):
539
- route_path = self.get_route_path(ExperimentSelector.Routes.VALUE_CHANGED)
540
- server = self._sly_app.get_server()
541
- self._changes_handled = True
724
+ def get_json_data(self):
725
+ return {}
542
726
 
543
- @server.post(route_path)
544
- def _value_changed():
545
- res = self.get_selected_row()
546
- func(res)
727
+ def get_json_state(self):
728
+ return {}
547
729
 
548
- return _value_changed
730
+ def to_html(self):
731
+ return self.table.to_html()