supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (190)
  1. supervisely/__init__.py +136 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/annotation/label.py +80 -3
  5. supervisely/api/annotation_api.py +9 -9
  6. supervisely/api/api.py +67 -43
  7. supervisely/api/app_api.py +72 -5
  8. supervisely/api/dataset_api.py +108 -33
  9. supervisely/api/entity_annotation/figure_api.py +113 -49
  10. supervisely/api/image_api.py +82 -0
  11. supervisely/api/module_api.py +10 -0
  12. supervisely/api/nn/deploy_api.py +15 -9
  13. supervisely/api/nn/ecosystem_models_api.py +201 -0
  14. supervisely/api/nn/neural_network_api.py +12 -3
  15. supervisely/api/pointcloud/pointcloud_api.py +38 -0
  16. supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
  17. supervisely/api/project_api.py +213 -6
  18. supervisely/api/task_api.py +11 -1
  19. supervisely/api/video/video_annotation_api.py +4 -2
  20. supervisely/api/video/video_api.py +79 -1
  21. supervisely/api/video/video_figure_api.py +24 -11
  22. supervisely/api/volume/volume_api.py +38 -0
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +14 -6
  25. supervisely/app/fastapi/__init__.py +1 -0
  26. supervisely/app/fastapi/custom_static_files.py +1 -1
  27. supervisely/app/fastapi/multi_user.py +88 -0
  28. supervisely/app/fastapi/subapp.py +175 -42
  29. supervisely/app/fastapi/templating.py +1 -1
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +11 -1
  35. supervisely/app/widgets/agent_selector/template.html +1 -0
  36. supervisely/app/widgets/card/card.py +20 -0
  37. supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
  38. supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
  39. supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
  40. supervisely/app/widgets/dialog/dialog.py +12 -0
  41. supervisely/app/widgets/dialog/template.html +2 -1
  42. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  43. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  44. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  45. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  46. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
  47. supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
  48. supervisely/app/widgets/fast_table/fast_table.py +713 -126
  49. supervisely/app/widgets/fast_table/script.js +492 -95
  50. supervisely/app/widgets/fast_table/style.css +54 -0
  51. supervisely/app/widgets/fast_table/template.html +45 -5
  52. supervisely/app/widgets/heatmap/__init__.py +0 -0
  53. supervisely/app/widgets/heatmap/heatmap.py +523 -0
  54. supervisely/app/widgets/heatmap/script.js +378 -0
  55. supervisely/app/widgets/heatmap/style.css +227 -0
  56. supervisely/app/widgets/heatmap/template.html +21 -0
  57. supervisely/app/widgets/input_tag/input_tag.py +102 -15
  58. supervisely/app/widgets/input_tag_list/__init__.py +0 -0
  59. supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
  60. supervisely/app/widgets/input_tag_list/template.html +70 -0
  61. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  62. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  63. supervisely/app/widgets/radio_tabs/template.html +1 -0
  64. supervisely/app/widgets/select/select.py +6 -4
  65. supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
  66. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
  67. supervisely/app/widgets/table/table.py +68 -13
  68. supervisely/app/widgets/tabs/tabs.py +22 -6
  69. supervisely/app/widgets/tabs/template.html +5 -1
  70. supervisely/app/widgets/transfer/style.css +3 -0
  71. supervisely/app/widgets/transfer/template.html +3 -1
  72. supervisely/app/widgets/transfer/transfer.py +48 -45
  73. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  74. supervisely/convert/image/csv/csv_converter.py +24 -15
  75. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
  76. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
  77. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
  78. supervisely/convert/video/video_converter.py +2 -2
  79. supervisely/geometry/polyline_3d.py +110 -0
  80. supervisely/io/env.py +161 -1
  81. supervisely/nn/artifacts/__init__.py +1 -1
  82. supervisely/nn/artifacts/artifacts.py +10 -2
  83. supervisely/nn/artifacts/detectron2.py +1 -0
  84. supervisely/nn/artifacts/hrda.py +1 -0
  85. supervisely/nn/artifacts/mmclassification.py +20 -0
  86. supervisely/nn/artifacts/mmdetection.py +5 -3
  87. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  88. supervisely/nn/artifacts/ritm.py +1 -0
  89. supervisely/nn/artifacts/rtdetr.py +1 -0
  90. supervisely/nn/artifacts/unet.py +1 -0
  91. supervisely/nn/artifacts/utils.py +3 -0
  92. supervisely/nn/artifacts/yolov5.py +2 -0
  93. supervisely/nn/artifacts/yolov8.py +1 -0
  94. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  95. supervisely/nn/experiments.py +9 -0
  96. supervisely/nn/inference/cache.py +37 -17
  97. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  98. supervisely/nn/inference/inference.py +953 -211
  99. supervisely/nn/inference/inference_request.py +15 -8
  100. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  101. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  102. supervisely/nn/inference/predict_app/__init__.py +0 -0
  103. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  104. supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
  105. supervisely/nn/inference/predict_app/gui/gui.py +915 -0
  106. supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
  107. supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
  108. supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
  109. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  110. supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
  111. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  112. supervisely/nn/inference/predict_app/gui/utils.py +399 -0
  113. supervisely/nn/inference/predict_app/predict_app.py +176 -0
  114. supervisely/nn/inference/session.py +47 -39
  115. supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
  116. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  117. supervisely/nn/inference/tracking/tracker_interface.py +4 -0
  118. supervisely/nn/inference/uploader.py +9 -5
  119. supervisely/nn/model/model_api.py +44 -22
  120. supervisely/nn/model/prediction.py +15 -1
  121. supervisely/nn/model/prediction_session.py +70 -14
  122. supervisely/nn/prediction_dto.py +7 -0
  123. supervisely/nn/tracker/__init__.py +6 -8
  124. supervisely/nn/tracker/base_tracker.py +54 -0
  125. supervisely/nn/tracker/botsort/__init__.py +1 -0
  126. supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
  127. supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
  128. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  129. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  130. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  131. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  132. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  133. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  134. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  135. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  136. supervisely/nn/tracker/botsort_tracker.py +273 -0
  137. supervisely/nn/tracker/calculate_metrics.py +264 -0
  138. supervisely/nn/tracker/utils.py +273 -0
  139. supervisely/nn/tracker/visualize.py +520 -0
  140. supervisely/nn/training/gui/gui.py +152 -49
  141. supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
  142. supervisely/nn/training/gui/model_selector.py +8 -6
  143. supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
  144. supervisely/nn/training/gui/training_artifacts.py +3 -1
  145. supervisely/nn/training/train_app.py +225 -46
  146. supervisely/project/pointcloud_episode_project.py +12 -8
  147. supervisely/project/pointcloud_project.py +12 -8
  148. supervisely/project/project.py +221 -75
  149. supervisely/template/experiment/experiment.html.jinja +105 -55
  150. supervisely/template/experiment/experiment_generator.py +258 -112
  151. supervisely/template/experiment/header.html.jinja +31 -13
  152. supervisely/template/experiment/sly-style.css +7 -2
  153. supervisely/versions.json +3 -1
  154. supervisely/video/sampling.py +42 -20
  155. supervisely/video/video.py +41 -12
  156. supervisely/video_annotation/video_figure.py +38 -4
  157. supervisely/volume/stl_converter.py +2 -0
  158. supervisely/worker_api/agent_rpc.py +24 -1
  159. supervisely/worker_api/rpc_servicer.py +31 -7
  160. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
  161. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
  162. supervisely_lib/__init__.py +6 -1
  163. supervisely/app/widgets/experiment_selector/style.css +0 -27
  164. supervisely/app/widgets/experiment_selector/template.html +0 -61
  165. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  166. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  167. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  168. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  169. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  170. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  171. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  172. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  173. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  174. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  175. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  176. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  177. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  178. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  179. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  180. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  181. supervisely/nn/tracker/tracker.py +0 -285
  182. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  183. supervisely/nn/tracking/__init__.py +0 -1
  184. supervisely/nn/tracking/boxmot.py +0 -114
  185. supervisely/nn/tracking/tracking.py +0 -24
  186. /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
  187. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
  188. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
  189. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
  190. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
@@ -6,20 +6,21 @@ training workflows in Supervisely.
6
6
  """
7
7
 
8
8
  import os
9
+ import json
9
10
  from os import environ, getenv
10
11
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
11
12
 
12
- from supervisely import logger
13
- import supervisely.io.fs as sly_fs
14
13
  import supervisely.io.env as sly_env
14
+ import supervisely.io.fs as sly_fs
15
15
  import supervisely.io.json as sly_json
16
- from supervisely import Api, ProjectMeta
16
+ from supervisely import Api, ProjectMeta, logger
17
17
  from supervisely._utils import is_production
18
18
  from supervisely.app.widgets import Button, Card, Stepper, Widget
19
19
  from supervisely.geometry.bitmap import Bitmap
20
20
  from supervisely.geometry.graph import GraphNodes
21
21
  from supervisely.geometry.polygon import Polygon
22
22
  from supervisely.geometry.rectangle import Rectangle
23
+ from supervisely.nn.experiments import ExperimentInfo
23
24
  from supervisely.nn.task_type import TaskType
24
25
  from supervisely.nn.training.gui.classes_selector import ClassesSelector
25
26
  from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
@@ -32,7 +33,6 @@ from supervisely.nn.training.gui.training_logs import TrainingLogs
32
33
  from supervisely.nn.training.gui.training_process import TrainingProcess
33
34
  from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
34
35
  from supervisely.nn.utils import ModelSource, RuntimeType
35
- from supervisely.nn.experiments import ExperimentInfo
36
36
 
37
37
 
38
38
  class StepFlow:
@@ -254,6 +254,7 @@ class TrainGUI:
254
254
  self.app_options = app_options
255
255
  self.collapsable = self.app_options.get("collapsable", False)
256
256
  self.need_convert_shapes = False
257
+ self._start_training = False
257
258
 
258
259
  self.team_id = sly_env.team_id(raise_not_found=False)
259
260
  self.workspace_id = sly_env.workspace_id(raise_not_found=False)
@@ -303,7 +304,9 @@ class TrainGUI:
303
304
  # 3. Classes selector
304
305
  self.classes_selector = None
305
306
  if self.show_classes_selector:
306
- self.classes_selector = ClassesSelector(self.project_id, [], self.model_selector, self.app_options)
307
+ self.classes_selector = ClassesSelector(
308
+ self.project_id, [], self.model_selector, self.app_options
309
+ )
307
310
  self.steps.append(self.classes_selector.card)
308
311
 
309
312
  # 4. Tags selector
@@ -355,16 +358,19 @@ class TrainGUI:
355
358
  experiment_name = "Enter experiment name"
356
359
  else:
357
360
  if self.task_id == -1:
358
- experiment_name = f"debug_{self.project_info.name}_{model_name}"
361
+ experiment_name = f"debug {self.project_info.name} {model_name}"
359
362
  else:
360
- experiment_name = f"{self.task_id}_{self.project_info.name}_{model_name}"
363
+ experiment_name = f"{self.task_id} {self.project_info.name} {model_name}"
361
364
 
362
365
  if experiment_name == self.training_process.get_experiment_name():
363
366
  return
364
367
  self.training_process.set_experiment_name(experiment_name)
365
368
 
366
369
  def need_convert_class_shapes() -> bool:
367
- if self.hyperparameters_selector.run_model_benchmark_checkbox is None or not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
370
+ if (
371
+ self.hyperparameters_selector.run_model_benchmark_checkbox is None
372
+ or not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked()
373
+ ):
368
374
  self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
369
375
  self.need_convert_shapes = False
370
376
  return False
@@ -376,14 +382,22 @@ class TrainGUI:
376
382
 
377
383
  # Exclude classes with no annotations to avoid unnecessary conversion
378
384
  data = self.classes_selector.classes_table._table_data
379
- empty_classes = {r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0}
385
+ empty_classes = {
386
+ r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
387
+ }
380
388
  need_conversion = bool(wrong_shapes - empty_classes)
381
389
  else:
382
390
  # Classes selector disabled – check entire project meta
383
391
  if task_type == TaskType.OBJECT_DETECTION:
384
- need_conversion = any(obj_cls.geometry_type != Rectangle for obj_cls in self.project_meta.obj_classes)
392
+ need_conversion = any(
393
+ obj_cls.geometry_type != Rectangle
394
+ for obj_cls in self.project_meta.obj_classes
395
+ )
385
396
  elif task_type in [TaskType.INSTANCE_SEGMENTATION, TaskType.SEMANTIC_SEGMENTATION]:
386
- need_conversion = any(obj_cls.geometry_type == Polygon for obj_cls in self.project_meta.obj_classes)
397
+ need_conversion = any(
398
+ obj_cls.geometry_type == Polygon
399
+ for obj_cls in self.project_meta.obj_classes
400
+ )
387
401
  else:
388
402
  need_conversion = False
389
403
 
@@ -394,6 +408,7 @@ class TrainGUI:
394
408
 
395
409
  self.need_convert_shapes = need_conversion
396
410
  return need_conversion
411
+
397
412
  # ------------------------------------------------- #
398
413
 
399
414
  self.step_flow = StepFlow(self.stepper, self.app_options)
@@ -420,7 +435,7 @@ class TrainGUI:
420
435
  self.model_selector.widgets_to_disable,
421
436
  self.model_selector.validator_text,
422
437
  self.model_selector.validate_step,
423
- position=position
438
+ position=position,
424
439
  ).add_on_select_actions("model_selector", [set_experiment_name])
425
440
  position += 1
426
441
 
@@ -517,7 +532,9 @@ class TrainGUI:
517
532
  has_model_selector = self.show_model_selector and self.model_selector is not None
518
533
  has_classes_selector = self.show_classes_selector and self.classes_selector is not None
519
534
  has_tags_selector = self.show_tags_selector and self.tags_selector is not None
520
- has_train_val_splits = self.show_train_val_splits_selector and self.train_val_splits_selector is not None
535
+ has_train_val_splits = (
536
+ self.show_train_val_splits_selector and self.train_val_splits_selector is not None
537
+ )
521
538
 
522
539
  # Set step dependency chain
523
540
  prev_step = "input_selector"
@@ -571,11 +588,13 @@ class TrainGUI:
571
588
  @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
572
589
  def show_mb_speedtest(is_checked: bool):
573
590
  self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
591
+
574
592
  # ------------------------------------------------- #
575
593
 
576
594
  self.layout: Widget = self.stepper
577
595
 
578
596
  # Run from experiment page
597
+
579
598
  train_task_id = getenv("modal.state.trainTaskId", None)
580
599
  if train_task_id is not None:
581
600
  train_task_id = int(train_task_id)
@@ -584,7 +603,6 @@ class TrainGUI:
584
603
  self._run_from_experiment(train_task_id, train_mode)
585
604
  # ----------------------------------------- #
586
605
 
587
-
588
606
  def set_next_step(self):
589
607
  current_step = self.stepper.get_active_step()
590
608
  self.stepper.set_active_step(current_step + 1)
@@ -605,6 +623,8 @@ class TrainGUI:
605
623
  """
606
624
  if self.input_selector is not None:
607
625
  self.input_selector.button.enable()
626
+ if self.model_selector is not None:
627
+ self.model_selector.button.enable()
608
628
  if self.train_val_splits_selector is not None:
609
629
  self.train_val_splits_selector.button.enable()
610
630
  if self.classes_selector is not None:
@@ -622,6 +642,8 @@ class TrainGUI:
622
642
  """
623
643
  if self.input_selector is not None:
624
644
  self.input_selector.button.disable()
645
+ if self.model_selector is not None:
646
+ self.model_selector.button.disable()
625
647
  if self.train_val_splits_selector is not None:
626
648
  self.train_val_splits_selector.button.disable()
627
649
  if self.classes_selector is not None:
@@ -775,7 +797,9 @@ class TrainGUI:
775
797
  )
776
798
  return app_state
777
799
 
778
- def load_from_app_state(self, app_state: Union[str, dict], click_cb: bool = True, validate_steps: bool = True) -> None:
800
+ def load_from_app_state(
801
+ self, app_state: Union[str, dict], click_cb: bool = True, validate_steps: bool = True
802
+ ) -> None:
779
803
  """
780
804
  Load the GUI state from app state dictionary or path to the state file.
781
805
 
@@ -820,26 +844,25 @@ class TrainGUI:
820
844
  "TensorRT": True
821
845
  },
822
846
  },
823
- "experiment_name": "my_experiment",
847
+ "experiment_name": "My Experiment",
848
+ "start_training": False,
824
849
  }
825
850
  """
826
851
  if isinstance(app_state, str):
827
- app_state = sly_json.load_json_file(app_state)
828
-
852
+ if os.path.isfile(app_state):
853
+ app_state = sly_json.load_json_file(app_state)
854
+ else:
855
+ app_state = json.loads(app_state)
856
+
829
857
  app_state = self.validate_app_state(app_state)
830
858
  options = app_state.get("options", {})
831
-
832
- # Set experiment name
833
- experiment_name = app_state.get("experiment_name")
834
- if experiment_name is not None:
835
- self.training_process.set_experiment_name(experiment_name)
836
859
 
837
860
  # Run init-steps and stop on validation failure
838
861
  def _run_step(init_fn, settings) -> bool:
839
862
  if not init_fn(settings, options, click_cb, validate_steps):
840
863
  return False
841
864
  return True
842
-
865
+
843
866
  # GUI init steps
844
867
  _steps = [
845
868
  (self._init_input, app_state.get("input"), "Input project"),
@@ -856,12 +879,28 @@ class TrainGUI:
856
879
  logger.warning(f"Step '{step_name}' {idx}/{len(_steps)} failed to validate")
857
880
  return
858
881
  if validate_steps:
859
- logger.info(f"Step '{step_name}' {idx}/{len(_steps)} has been validated successfully")
882
+ logger.info(
883
+ f"Step '{step_name}' {idx}/{len(_steps)} has been validated successfully"
884
+ )
885
+
886
+ # Set experiment name
887
+ experiment_name = app_state.get("experiment_name")
888
+ if experiment_name is not None and experiment_name != "":
889
+ self.training_process.set_experiment_name(experiment_name)
890
+
860
891
  if validate_steps:
861
892
  logger.info(f"All steps have been validated successfully")
893
+
894
+ self._start_training = app_state.get("start_training", False)
862
895
  # ------------------------------------------------------------------ #
863
896
 
864
- def _init_input(self, input_settings: Union[dict, None], options: dict, click_cb: bool = True, validate: bool = True) -> bool:
897
+ def _init_input(
898
+ self,
899
+ input_settings: Union[dict, None],
900
+ options: dict,
901
+ click_cb: bool = True,
902
+ validate: bool = True,
903
+ ) -> bool:
865
904
  """
866
905
  Initialize the input selector with the given settings.
867
906
 
@@ -885,7 +924,13 @@ class TrainGUI:
885
924
  return is_valid
886
925
  # ----------------------------------------- #
887
926
 
888
- def _init_model(self, model_settings: dict, options: dict = None, click_cb: bool = True, validate: bool = True) -> bool:
927
+ def _init_model(
928
+ self,
929
+ model_settings: dict,
930
+ options: dict = None,
931
+ click_cb: bool = True,
932
+ validate: bool = True,
933
+ ) -> bool:
889
934
  """
890
935
  Initialize the model selector with the given settings.
891
936
 
@@ -909,14 +954,18 @@ class TrainGUI:
909
954
  # Custom
910
955
  elif model_settings["source"] == ModelSource.CUSTOM:
911
956
  self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
912
- self.model_selector.experiment_selector.set_by_task_id(model_settings["task_id"])
913
- active_row = self.model_selector.experiment_selector.get_selected_row()
914
- if model_settings["checkpoint"] not in active_row.checkpoints_names:
915
- raise ValueError(
916
- f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
917
- )
918
-
919
- active_row.set_selected_checkpoint_by_name(model_settings["checkpoint"])
957
+ self.model_selector.experiment_selector.set_selected_row_by_task_id(
958
+ model_settings["task_id"]
959
+ )
960
+ experiment_info = self.model_selector.experiment_selector.get_selected_experiment_info()
961
+ if model_settings["checkpoint"] not in experiment_info.checkpoints:
962
+ if f"checkpoints/{model_settings['checkpoint']}" not in experiment_info.checkpoints:
963
+ raise ValueError(
964
+ f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
965
+ )
966
+ self.model_selector.experiment_selector.set_selected_checkpoint_by_name(
967
+ model_settings["checkpoint"]
968
+ )
920
969
 
921
970
  is_valid = True
922
971
  if validate:
@@ -926,8 +975,10 @@ class TrainGUI:
926
975
  self.set_next_step()
927
976
  return is_valid
928
977
  # ----------------------------------------- #
929
-
930
- def _init_classes(self, classes_settings: list, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
978
+
979
+ def _init_classes(
980
+ self, classes_settings: list, options: dict, click_cb: bool = True, validate: bool = True
981
+ ) -> bool:
931
982
  """
932
983
  Initialize the classes selector with the given settings.
933
984
 
@@ -941,13 +992,20 @@ class TrainGUI:
941
992
  :type validate: bool
942
993
  """
943
994
  if self.classes_selector is None:
944
- return True # Selector disabled by app options
995
+ return True # Selector disabled by app options
945
996
 
946
997
  convert_class_shapes = options.get("convert_class_shapes", True)
947
998
  if convert_class_shapes:
948
999
  self.classes_selector.convert_class_shapes_checkbox.check()
949
1000
 
950
1001
  # Set Classes
1002
+ if all(isinstance(c, int) for c in classes_settings):
1003
+ project_classes = []
1004
+ for obj_class in self.project_meta.obj_classes:
1005
+ if obj_class.sly_id in classes_settings:
1006
+ project_classes.append(obj_class.name)
1007
+ classes_settings = project_classes
1008
+
951
1009
  self.classes_selector.set_classes(classes_settings)
952
1010
  is_valid = True
953
1011
  if validate:
@@ -958,7 +1016,9 @@ class TrainGUI:
958
1016
  return is_valid
959
1017
  # ----------------------------------------- #
960
1018
 
961
- def _init_tags(self, tags_settings: list, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
1019
+ def _init_tags(
1020
+ self, tags_settings: list, options: dict, click_cb: bool = True, validate: bool = True
1021
+ ) -> bool:
962
1022
  """
963
1023
  Initialize the tags selector with the given settings.
964
1024
 
@@ -972,7 +1032,7 @@ class TrainGUI:
972
1032
  :type validate: bool
973
1033
  """
974
1034
  if self.tags_selector is None:
975
- return True # Selector disabled by app options
1035
+ return True # Selector disabled by app options
976
1036
 
977
1037
  # Set Tags
978
1038
  self.tags_selector.set_tags(tags_settings)
@@ -985,7 +1045,13 @@ class TrainGUI:
985
1045
  return is_valid
986
1046
  # ----------------------------------------- #
987
1047
 
988
- def _init_train_val_splits(self, train_val_splits_settings: dict, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
1048
+ def _init_train_val_splits(
1049
+ self,
1050
+ train_val_splits_settings: dict,
1051
+ options: dict,
1052
+ click_cb: bool = True,
1053
+ validate: bool = True,
1054
+ ) -> bool:
989
1055
  """
990
1056
  Initialize the train/val splits selector with the given settings.
991
1057
 
@@ -999,7 +1065,7 @@ class TrainGUI:
999
1065
  :type validate: bool
1000
1066
  """
1001
1067
  if self.train_val_splits_selector is None:
1002
- return True # Selector disabled by app options
1068
+ return True # Selector disabled by app options
1003
1069
 
1004
1070
  if train_val_splits_settings == {}:
1005
1071
  available_methods = self.app_options.get("train_val_splits_methods", [])
@@ -1059,8 +1125,14 @@ class TrainGUI:
1059
1125
  self.train_val_splits_selector_cb()
1060
1126
  self.set_next_step()
1061
1127
  return is_valid
1062
-
1063
- def _init_hyperparameters(self, hyperparameters_settings: dict, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
1128
+
1129
+ def _init_hyperparameters(
1130
+ self,
1131
+ hyperparameters_settings: dict,
1132
+ options: dict,
1133
+ click_cb: bool = True,
1134
+ validate: bool = True,
1135
+ ) -> bool:
1064
1136
  """
1065
1137
  Initialize the hyperparameters selector with the given settings.
1066
1138
 
@@ -1101,6 +1173,7 @@ class TrainGUI:
1101
1173
  self.hyperparameters_selector_cb()
1102
1174
  self.set_next_step()
1103
1175
  return is_valid
1176
+
1104
1177
  # ----------------------------------------- #
1105
1178
 
1106
1179
  # Run from experiment page
@@ -1111,10 +1184,12 @@ class TrainGUI:
1111
1184
  app_state = sly_json.load_json_file(local_app_state_path)
1112
1185
  sly_fs.silent_remove(local_app_state_path)
1113
1186
  return app_state
1114
-
1187
+
1115
1188
  def _download_experiment_hparams(self, experiment_info: ExperimentInfo) -> dict:
1116
1189
  local_hparams_path = f"./{experiment_info.hyperparameters}"
1117
- remote_hparams_path = os.path.join(experiment_info.artifacts_dir, experiment_info.hyperparameters)
1190
+ remote_hparams_path = os.path.join(
1191
+ experiment_info.artifacts_dir, experiment_info.hyperparameters
1192
+ )
1118
1193
  self._api.file.download(self.team_id, remote_hparams_path, local_hparams_path)
1119
1194
  with open(local_hparams_path, "r") as f:
1120
1195
  hparams = f.read()
@@ -1129,11 +1204,14 @@ class TrainGUI:
1129
1204
  model_settings = {
1130
1205
  "source": ModelSource.CUSTOM,
1131
1206
  "task_id": train_task_id,
1132
- "checkpoint": experiment_info.best_checkpoint
1207
+ "checkpoint": experiment_info.best_checkpoint,
1133
1208
  }
1134
1209
 
1135
1210
  if experiment_state is not None:
1136
- self.input_selector.validator_text.set(f"Training configuration is loaded from the experiment: {experiment_info.experiment_name}.", "success")
1211
+ self.input_selector.validator_text.set(
1212
+ f"Training configuration is loaded from the experiment: {experiment_info.experiment_name}.",
1213
+ "success",
1214
+ )
1137
1215
  self.input_selector.validator_text.show()
1138
1216
  experiment_state = self._download_experiment_state(experiment_info)
1139
1217
  if train_mode == "continue":
@@ -1142,7 +1220,7 @@ class TrainGUI:
1142
1220
  else:
1143
1221
  self.input_selector.validator_text.set(
1144
1222
  f"Couldn't load full training configuration from the experiment: {experiment_info.experiment_name}. Only model and hyperparameters are loaded.",
1145
- "warning"
1223
+ "warning",
1146
1224
  )
1147
1225
  self.input_selector.validator_text.show()
1148
1226
  hparams = self._download_experiment_hparams(experiment_info)
@@ -1150,3 +1228,28 @@ class TrainGUI:
1150
1228
  if train_mode == "continue":
1151
1229
  self._init_model(model_settings, {}, click_cb=False, validate=False)
1152
1230
  # ----------------------------------------- #
1231
+
1232
+ def _extract_state_from_env(self):
1233
+ import ast
1234
+ import os
1235
+
1236
+ base = "modal.state"
1237
+ state = {}
1238
+ for key, value in os.environ.items():
1239
+ state_part = state
1240
+ if key.startswith(base):
1241
+ key = key.replace(base + ".", "")
1242
+ parts = key.split(".")
1243
+ while len(parts) > 1:
1244
+ part = parts.pop(0)
1245
+ state_part.setdefault(part, {})
1246
+ state_part = state_part[part]
1247
+ part = parts.pop(0)
1248
+ if value and (value[0] == "[" or value.isdigit()):
1249
+ state_part[part] = ast.literal_eval(value)
1250
+ elif value in ["True", "true", "False", "false"]:
1251
+ state_part[part] = value in ["True", "true"]
1252
+ else:
1253
+ state_part[part] = value
1254
+ return state
1255
+ # ----------------------------------------- #
@@ -48,7 +48,7 @@ class HyperparametersSelector:
48
48
  self.run_model_benchmark_checkbox = Checkbox(
49
49
  content="Run Model Benchmark evaluation", checked=True
50
50
  )
51
- self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=True)
51
+ self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=False)
52
52
 
53
53
  self.model_benchmark_field = Field(
54
54
  title="Model Evaluation Benchmark",
@@ -26,6 +26,7 @@ class ModelSelector:
26
26
 
27
27
  def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
28
28
  # Init widgets
29
+ self.api = api
29
30
  self.pretrained_models_table = None
30
31
  self.experiment_selector = None
31
32
  self.model_source_tabs = None
@@ -50,7 +51,7 @@ class ModelSelector:
50
51
 
51
52
  # GUI Components
52
53
  self.pretrained_models_table = PretrainedModelsSelector(self.models)
53
- experiment_infos = get_experiment_infos(api, self.team_id, framework)
54
+ experiment_infos = get_experiment_infos(self.api, self.team_id, framework)
54
55
  if self.app_options.get("legacy_checkpoints", False):
55
56
  try:
56
57
  framework_cls = FrameworkMapper.get_framework_cls(framework, self.team_id)
@@ -59,7 +60,7 @@ class ModelSelector:
59
60
  except:
60
61
  logger.warning(f"Legacy checkpoints are not available for '{framework}'")
61
62
 
62
- self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
63
+ self.experiment_selector = ExperimentSelector(self.api, self.team_id, experiment_infos)
63
64
 
64
65
  tab_titles = []
65
66
  tab_descriptions = []
@@ -85,6 +86,7 @@ class ModelSelector:
85
86
  self.validator_text = Text("")
86
87
  self.validator_text.hide()
87
88
  self.button = Button("Select")
89
+
88
90
  self.display_widgets.extend([self.model_source_tabs, self.validator_text, self.button])
89
91
  # -------------------------------- #
90
92
 
@@ -118,14 +120,14 @@ class ModelSelector:
118
120
  model_name = _get_model_name(selected_row)
119
121
  else:
120
122
  selected_row = self.experiment_selector.get_selected_experiment_info()
121
- model_name = selected_row.get("model_name", None)
123
+ model_name = selected_row.model_name
122
124
  return model_name
123
125
 
124
126
  def get_model_info(self) -> dict:
125
127
  if self.get_model_source() == ModelSource.PRETRAINED:
126
128
  return self.pretrained_models_table.get_selected_row()
127
129
  else:
128
- return self.experiment_selector.get_selected_experiment_info()
130
+ return self.experiment_selector.get_selected_experiment_info().to_json()
129
131
 
130
132
  def get_checkpoint_name(self) -> str:
131
133
  if self.get_model_source() == ModelSource.PRETRAINED:
@@ -146,7 +148,7 @@ class ModelSelector:
146
148
  else:
147
149
  checkpoint_name = self.experiment_selector.get_selected_checkpoint_name()
148
150
  return checkpoint_name
149
-
151
+
150
152
  def get_checkpoint_link(self) -> str:
151
153
  if self.get_model_source() == ModelSource.PRETRAINED:
152
154
  selected_row = self.pretrained_models_table.get_selected_row()
@@ -182,4 +184,4 @@ class ModelSelector:
182
184
  if self.get_model_source() == ModelSource.PRETRAINED:
183
185
  return self.pretrained_models_table.get_selected_task_type()
184
186
  else:
185
- return self.experiment_selector.get_selected_task_type()
187
+ return self.experiment_selector.get_selected_experiment_info().task_type