supervisely 6.73.390__py3-none-any.whl → 6.73.392__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. supervisely/app/widgets/experiment_selector/experiment_selector.py +21 -3
  2. supervisely/app/widgets/experiment_selector/template.html +49 -70
  3. supervisely/app/widgets/report_thumbnail/report_thumbnail.py +19 -4
  4. supervisely/decorators/profile.py +20 -0
  5. supervisely/nn/benchmark/utils/detection/utlis.py +7 -0
  6. supervisely/nn/experiments.py +4 -0
  7. supervisely/nn/inference/gui/serving_gui_template.py +71 -11
  8. supervisely/nn/inference/inference.py +108 -6
  9. supervisely/nn/training/gui/classes_selector.py +246 -27
  10. supervisely/nn/training/gui/gui.py +318 -234
  11. supervisely/nn/training/gui/hyperparameters_selector.py +2 -2
  12. supervisely/nn/training/gui/model_selector.py +42 -1
  13. supervisely/nn/training/gui/tags_selector.py +1 -1
  14. supervisely/nn/training/gui/train_val_splits_selector.py +8 -7
  15. supervisely/nn/training/gui/training_artifacts.py +10 -1
  16. supervisely/nn/training/gui/training_process.py +17 -1
  17. supervisely/nn/training/train_app.py +227 -72
  18. supervisely/template/__init__.py +2 -0
  19. supervisely/template/base_generator.py +90 -0
  20. supervisely/template/experiment/__init__.py +0 -0
  21. supervisely/template/experiment/experiment.html.jinja +537 -0
  22. supervisely/template/experiment/experiment_generator.py +996 -0
  23. supervisely/template/experiment/header.html.jinja +154 -0
  24. supervisely/template/experiment/sidebar.html.jinja +240 -0
  25. supervisely/template/experiment/sly-style.css +397 -0
  26. supervisely/template/experiment/template.html.jinja +18 -0
  27. supervisely/template/extensions.py +172 -0
  28. supervisely/template/template_renderer.py +253 -0
  29. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/METADATA +3 -1
  30. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/RECORD +34 -23
  31. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/LICENSE +0 -0
  32. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/WHEEL +0 -0
  33. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/entry_points.txt +0 -0
  34. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/top_level.txt +0 -0
@@ -143,6 +143,14 @@ class ExperimentSelector(Widget):
143
143
  def checkpoints_selector(self) -> Select:
144
144
  return self._checkpoints_widget
145
145
 
146
@property
def experiment_info(self) -> ExperimentInfo:
    """Experiment metadata held by this object (the stored ``_experiment_info``)."""
    info = self._experiment_info
    return info
149
+
150
@property
def best_checkpoint(self) -> str:
    """Best checkpoint name, read from ``experiment_info.best_checkpoint``."""
    info = self.experiment_info
    return info.best_checkpoint
153
+
146
154
  @property
147
155
  def session_link(self) -> str:
148
156
  return self._session_link
@@ -460,6 +468,12 @@ class ExperimentSelector(Widget):
460
468
  selected_row = self.get_selected_row()
461
469
  return selected_row.get_selected_checkpoint_path()
462
470
 
471
def get_selected_checkpoint_name(self) -> str:
    """Return the checkpoint name chosen in the currently selected row.

    Returns ``None`` when the selector holds no rows at all; otherwise the
    result of the selected row's ``get_selected_checkpoint_name()``.
    """
    if len(self._rows) > 0:
        row = self.get_selected_row()
        return row.get_selected_checkpoint_name()
    return None
476
+
463
477
  def get_model_files(self) -> Dict[str, str]:
464
478
  """
465
479
  Returns a dictionary with full paths to model files in Supervisely Team Files.
@@ -485,8 +499,11 @@ class ExperimentSelector(Widget):
485
499
  }
486
500
  return deploy_params
487
501
 
488
- def set_active_row(self, row_index: int) -> None:
489
- if row_index < 0 or row_index > len(self._rows) - 1:
502
+ def set_active_row(self, row_index: int, task_type: str = None) -> None:
503
+ if task_type is None:
504
+ task_type = self.get_selected_task_type()
505
+ self.set_active_task_type(task_type)
506
+ if row_index < 0 or row_index > len(self._rows[task_type]) - 1:
490
507
  raise ValueError(f'Row with index "{row_index}" does not exist')
491
508
  StateJson()[self.widget_id]["selectedRow"] = row_index
492
509
  StateJson().send_changes()
@@ -495,7 +512,8 @@ class ExperimentSelector(Widget):
495
512
  for task_type in self._rows:
496
513
  for i, row in enumerate(self._rows[task_type]):
497
514
  if row.task_id == task_id:
498
- self.set_active_row(i)
515
+ self.set_active_task_type(task_type)
516
+ self.set_active_row(i, task_type)
499
517
  return
500
518
 
501
519
  def get_by_task_id(self, task_id: int) -> Union[ModelRow, None]:
@@ -1,82 +1,61 @@
1
- <link rel="stylesheet" href="./sly/css/app/widgets/custom_models_selector/style.css"/>
1
+ <link rel="stylesheet" href="./sly/css/app/widgets/custom_models_selector/style.css" />
2
2
 
3
- <div
4
- {% if widget._changes_handled == true %}
5
- @change="post('/{{{widget.widget_id}}}/value_changed')"
6
- {% endif %}
7
- >
3
+ <div {% if widget._changes_handled==true %} @change="post('/{{{widget.widget_id}}}/value_changed')" {% endif %}>
8
4
 
9
5
  <div v-if="Object.keys(data.{{{widget.widget_id}}}.rowsHtml).length === 0"> You don't have any custom models</div>
10
6
  <div v-else>
11
7
 
12
8
  <div v-if="data.{{{widget.widget_id}}}.taskTypes.length > 1">
13
- <sly-field
14
- title="Task Type"
15
- >
16
- <el-radio-group
17
- class="multi-line mt10"
18
- :value="state.{{{widget.widget_id}}}.selectedTaskType"
19
-
20
- {% if widget._task_type_changes_handled == true %}
21
- @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0; post('/{{{widget.widget_id}}}/task_type_changed')}"
9
+ <sly-field title="Task Type">
10
+ <el-radio-group class="multi-line mt10" :value="state.{{{widget.widget_id}}}.selectedTaskType" {% if
11
+ widget._task_type_changes_handled==true %}
12
+ @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0; post('/{{{widget.widget_id}}}/task_type_changed')}"
22
13
  {% else %}
23
- @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0;}"
24
- {% endif %}
25
- >
26
-
27
- <el-radio
28
- v-for="(item, idx) in {{{widget._task_types}}}"
29
- :key="item"
30
- :label="item"
31
- >
32
- {{ item }}
14
+ @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0;}"
15
+ {% endif %}>
16
+
17
+ <el-radio v-for="(item, idx) in {{{widget._task_types}}}" :key="item" :label="item">
18
+ {{ item }}
33
19
  </el-radio>
34
20
  </el-radio-group>
35
- </sly-field>
36
- </div>
21
+ </sly-field>
22
+ </div>
37
23
 
38
24
  <div>
39
-
40
- <table class="custom-models-selector-table">
41
- <thead>
42
- <tr>
43
- <th v-for="col in data.{{{widget.widget_id}}}.columns">
44
- <div> {{col}} </div>
45
- </th>
46
- </tr>
47
- </thead>
48
- <tbody>
49
- <tr v-for="row, ridx in data.{{{widget.widget_id}}}.rowsHtml[state.{{{widget.widget_id}}}.selectedTaskType]">
50
- <td v-for="col, vidx in row">
51
- <div v-if="vidx === 0" style="display: flex;">
52
- <el-radio
53
- style="display: flex;"
54
- v-model="state.{{{widget.widget_id}}}.selectedRow"
55
- :label="ridx"
56
- >&#8205;</el-radio>
57
-
58
- <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data" :state="state"></sly-html-compiler>
59
-
60
- </div>
61
-
62
- <div v-else>
63
25
 
64
- <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data" :state="state">
65
- </sly-html-compiler>
66
-
67
- </div>
68
-
69
- </td>
70
- </tr>
71
- </tbody>
72
- </table>
73
- </div>
74
- <div class="mt10" v-if="{{{widget.show_custom_checkpoint_path}}}"
75
- >
76
- {{{widget.show_custom_checkpoint_path_checkbox}}}
77
- <div class="mt10">
78
- {{{widget.custom_tab_widgets}}}
79
- </div>
80
- </div>
81
- </div>
82
- </div>
26
+ <table class="custom-models-selector-table">
27
+ <thead>
28
+ <tr>
29
+ <th v-for="col in data.{{{widget.widget_id}}}.columns">
30
+ <div> {{col}} </div>
31
+ </th>
32
+ </tr>
33
+ </thead>
34
+ <tbody>
35
+ <tr
36
+ v-for="row, ridx in data.{{{widget.widget_id}}}.rowsHtml[state.{{{widget.widget_id}}}.selectedTaskType]">
37
+ <td v-for="col, vidx in row">
38
+ <div v-if="vidx === 0" style="display: flex;">
39
+ <el-radio style="display: flex;" v-model="state.{{{widget.widget_id}}}.selectedRow"
40
+ :label="ridx">&#8205;</el-radio>
41
+
42
+ <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data"
43
+ :state="state"></sly-html-compiler>
44
+
45
+ </div>
46
+
47
+ <div v-else>
48
+
49
+ <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data"
50
+ :state="state">
51
+ </sly-html-compiler>
52
+
53
+ </div>
54
+
55
+ </td>
56
+ </tr>
57
+ </tbody>
58
+ </table>
59
+ </div>
60
+ </div>
61
+ </div>
@@ -1,4 +1,4 @@
1
- from typing import Optional
1
+ from typing import Literal, Optional
2
2
 
3
3
  from supervisely._utils import abs_url, is_debug_with_sly_net, is_development
4
4
  from supervisely.api.file_api import FileInfo
@@ -16,11 +16,13 @@ class ReportThumbnail(Widget):
16
16
  title: Optional[str] = None,
17
17
  color: Optional[str] = "#dcb0ff",
18
18
  bg_color: Optional[str] = "#faebff",
19
+ report_type: Literal["model_benchmark", "experiment"] = "model_benchmark",
19
20
  ):
20
21
  self._id: int = None
21
22
  self._info: FileInfo = None
22
23
  self._description: str = None
23
24
  self._url: str = None
25
+ self._report_type = report_type
24
26
  self._set_info(info)
25
27
  self._title = title
26
28
  self._color = color if _validate_hex_color(color) else "#dcb0ff"
@@ -54,13 +56,26 @@ class ReportThumbnail(Widget):
54
56
  return
55
57
  self._id = info.id
56
58
  self._info = info
57
- self._description = "Open the Model Benchmark evaluation report."
58
- lnk = f"/model-benchmark?id={info.id}"
59
+ if self._report_type == "model_benchmark":
60
+ self._description = "Open the Model Benchmark evaluation report."
61
+ lnk = f"/model-benchmark?id={info.id}"
62
+ elif self._report_type == "experiment":
63
+ self._description = "Open the Experiment report."
64
+ lnk = f"/nn/experiments/{info.id}"
65
+ else:
66
+ raise ValueError(f"Invalid report type: {self._report_type}")
67
+
59
68
  lnk = abs_url(lnk) if is_development() or is_debug_with_sly_net() else lnk
60
69
  # self._description = info.path
61
70
  self._url = lnk
62
71
 
63
- def set(self, info: FileInfo = None):
72
+ def set(
73
+ self,
74
+ info: FileInfo = None,
75
+ report_type: Optional[Literal["model_benchmark", "experiment"]] = None,
76
+ ):
77
+ if report_type is not None:
78
+ self._report_type = report_type
64
79
  self._set_info(info)
65
80
  self.update_data()
66
81
  DataJson().send_changes()
@@ -1,5 +1,6 @@
1
1
  import functools
2
2
  import time
3
+
3
4
  from supervisely.sly_logger import logger
4
5
 
5
6
 
@@ -21,6 +22,25 @@ def timeit(func):
21
22
  return wrapper_timer
22
23
 
23
24
 
25
def timeit_with_result(func):
    """Decorate ``func`` so the wall-clock duration of each call is recorded.

    After every call, including one that raises, the duration in seconds is
    stored on the wrapper as ``elapsed`` and logged at DEBUG level. The
    wrapped function's return value is passed through unchanged.
    ``elapsed`` is ``None`` until the first call.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        started = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            duration = time.perf_counter() - started
            wrapper.elapsed = duration
            logger.debug(
                f"Function '{func.__name__}' finished in {duration:.2f} seconds (≈ {duration/60:.2f} minutes)"
            )

    wrapper.elapsed = None  # no measurement until the first call
    return wrapper
42
+
43
+
24
44
  def update_fields(func):
25
45
  """Update state field after executing function"""
26
46
 
@@ -24,5 +24,12 @@ def read_coco_datasets(cocoGt_json, cocoDt_json):
24
24
  cocoGt = COCO()
25
25
  cocoGt.dataset = cocoGt_json
26
26
  cocoGt.createIndex()
27
+
28
+ # Fix key error in pycocotools
29
+ info = cocoGt.dataset.get("info", None)
30
+ if info is None:
31
+ cocoGt.dataset["info"] = {}
32
+ # ------------------------------ #
33
+
27
34
  cocoDt = cocoGt.loadRes(cocoDt_json["annotations"])
28
35
  return cocoGt, cocoDt
@@ -38,6 +38,10 @@ class ExperimentInfo:
38
38
  """Path to .yaml file with hyperparameters used in the experiment"""
39
39
  artifacts_dir: str
40
40
  """Path to the directory with artifacts"""
41
+ base_checkpoint: Optional[str] = None
42
+ """Name of the base checkpoint used for training"""
43
+ base_checkpoint_link: Optional[str] = None
44
+ """Link to the base checkpoint used for training. URL in case of pretrained model, or Team Files path in case of custom model."""
41
45
  export: Optional[dict] = None
42
46
  """Dictionary with exported weights in different formats"""
43
47
  app_state: Optional[str] = None
@@ -7,17 +7,8 @@ import supervisely.io.env as sly_env
7
7
  import supervisely.io.fs as sly_fs
8
8
  import supervisely.io.json as sly_json
9
9
  from supervisely import Api
10
- from supervisely.app.widgets import (
11
- Card,
12
- Container,
13
- Field,
14
- RadioTabs,
15
- SelectString,
16
- Widget,
17
- )
18
- from supervisely.app.widgets.experiment_selector.experiment_selector import (
19
- ExperimentSelector,
20
- )
10
+ from supervisely.app.widgets import Card, Container, Field, RadioTabs, SelectString, Text, Widget
11
+ from supervisely.app.widgets.experiment_selector.experiment_selector import ExperimentSelector
21
12
  from supervisely.app.widgets.pretrained_models_selector.pretrained_models_selector import (
22
13
  PretrainedModelsSelector,
23
14
  )
@@ -133,6 +124,27 @@ class ServingGUITemplate(ServingGUI):
133
124
  card_widgets = [self.model_source_tabs]
134
125
  if runtime_field is not None:
135
126
  card_widgets.append(runtime_field)
127
+
128
+ # Runtime exported checkpoint message
129
+ self._export_msg = Text("")
130
+ self._export_msg.hide()
131
+ card_widgets.append(self._export_msg)
132
+
133
+ if self.runtime_select is not None:
134
+ self.runtime_select.value_changed(lambda _: self._update_export_message())
135
+ if self.experiment_selector is not None:
136
+ self.experiment_selector.value_changed(lambda _: self._update_export_message())
137
+ for task_type in self.experiment_selector.rows:
138
+ for row in self.experiment_selector.rows[task_type]:
139
+ row.checkpoints_selector.value_changed(lambda _: self._update_export_message())
140
+ if self.pretrained_models_table is not None:
141
+ self.pretrained_models_table.model_changed(lambda _: self._update_export_message())
142
+
143
+ if self.model_source_tabs is not None:
144
+ self.model_source_tabs.value_changed(lambda _: self._update_export_message())
145
+
146
+ self._update_export_message()
147
+
136
148
  return card_widgets
137
149
 
138
150
  def _initialize_extra_widgets(self) -> List[Widget]:
@@ -204,3 +216,51 @@ class ServingGUITemplate(ServingGUI):
204
216
  elif self.model_source == ModelSource.CUSTOM and self.experiment_selector:
205
217
  return self.experiment_selector.get_selected_experiment_info()
206
218
  return {}
219
+
220
def _update_export_message(self):
    """Refresh the hint about whether the checkpoint must be converted for the runtime.

    The hint is hidden first. It is shown again only when the selected runtime
    is ONNX Runtime or TensorRT:

    * pretrained model source: "Checkpoint will be converted before deployment.";
    * custom model source: "Runtime checkpoint exists – no conversion needed."
      only if the experiment's ``export`` dict has a key that starts with the
      runtime name (case-insensitive) AND the selected checkpoint is the
      experiment's best checkpoint; otherwise the conversion message.

    For any other model source, or when no row/checkpoint is selected, the
    hint stays hidden.
    """
    self._export_msg.hide()

    runtime = self.runtime
    if runtime not in [RuntimeType.ONNXRUNTIME, RuntimeType.TENSORRT]:
        return

    convert_msg = "Checkpoint will be converted before deployment."
    if self.model_source == ModelSource.PRETRAINED:
        self._export_msg.set(convert_msg, "info")
        self._export_msg.show()
        return

    # Guard clauses: `row` is bound on every path that later reads it.
    if self.model_source != ModelSource.CUSTOM or self.experiment_selector is None:
        return
    row = self.experiment_selector.get_selected_row()
    if row is None:
        return
    selected_ckpt = row.get_selected_checkpoint_name()
    if selected_ckpt is None:
        return

    exports = (self.model_info or {}).get("export", {})
    has_runtime_export = isinstance(exports, dict) and any(
        name.lower().startswith(runtime.lower()) for name in exports.keys()
    )
    # Exported runtime weights correspond to the best checkpoint only.
    is_best = selected_ckpt == row.best_checkpoint

    if has_runtime_export and is_best:
        self._export_msg.set("Runtime checkpoint exists – no conversion needed.", "info")
    else:
        self._export_msg.set(convert_msg, "info")
    self._export_msg.show()
@@ -34,6 +34,7 @@ import supervisely.io.env as sly_env
34
34
  import supervisely.io.fs as sly_fs
35
35
  import supervisely.io.json as sly_json
36
36
  import supervisely.nn.inference.gui as GUI
37
+ from supervisely.nn.experiments import ExperimentInfo
37
38
  from supervisely import DatasetInfo, batched
38
39
  from supervisely._utils import (
39
40
  add_callback,
@@ -277,6 +278,19 @@ class Inference:
277
278
  self.gui.on_serve_callbacks.append(on_serve_callback)
278
279
  self._initialize_app_layout()
279
280
 
281
+ train_task_id = os.getenv("modal.state.trainTaskId")
282
+ if train_task_id:
283
+ try:
284
+ train_task_id = int(train_task_id)
285
+ logger.info(f"Setting best checkpoint from training task id: {train_task_id}")
286
+ experiment_info = self.api.nn.get_experiment_info(train_task_id)
287
+ self._set_checkpoint_from_experiment_info(experiment_info)
288
+ self.gui.deploy_with_current_params()
289
+ except Exception as e:
290
+ logger.warning(f"Failed to set checkpoint from training task id: {repr(e)}")
291
+ logger.info("Resetting UI to default state")
292
+ self._reset_gui_state()
293
+
280
294
  self._inference_requests = {}
281
295
  max_workers = 1 if not multithread_inference else None
282
296
  self._executor = ThreadPoolExecutor(max_workers=max_workers)
@@ -748,7 +762,10 @@ class Inference:
748
762
  """
749
763
  team_id = sly_env.team_id()
750
764
  local_model_files = {}
751
- for file in model_files:
765
+
766
+ # Sort files to download 'checkpoint' first
767
+ files_order = sorted(model_files.keys(), key=lambda x: (0 if x == "checkpoint" else 1, x))
768
+ for file in files_order:
752
769
  file_url = model_files[file]
753
770
  file_info = self.api.file.get_info_by_path(team_id, file_url)
754
771
  if file_info is None:
@@ -773,11 +790,66 @@ class Inference:
773
790
  )
774
791
  else:
775
792
  self.api.file.download(team_id, file_url, file_path)
793
+
794
+ if file == "checkpoint":
795
+ try:
796
+ extracted_files = self._extract_model_files_from_checkpoint(file_path)
797
+ local_model_files.update(extracted_files)
798
+ if extracted_files:
799
+ local_model_files[file] = file_path
800
+ return local_model_files
801
+ except Exception as e:
802
+ logger.debug(f"Failed to process checkpoint '{file_name}' to extract auxiliary files: {repr(e)}")
803
+ logger.debug("Model files will be downloaded from Team Files")
804
+ local_model_files[file] = file_path
805
+ continue
806
+
776
807
  local_model_files[file] = file_path
777
808
  if log_progress:
778
809
  self.gui.download_progress.hide()
779
810
  return local_model_files
780
811
 
812
def _extract_model_files_from_checkpoint(self, checkpoint_path: str) -> dict:
    """Unpack auxiliary files embedded in a PyTorch checkpoint into ``self.model_dir``.

    Two optional entries of the loaded checkpoint are handled:

    * ``"model_files"``: a dict ``{file_key: {"content": ..., "name": ...}}``.
      Each entry is written to ``self.model_dir`` (only the base name of
      ``name`` is used) and reported as ``{file_key: local_path}``.
    * ``"model_meta"``: dumped to ``self.model_dir/model_meta.json``. This file
      is NOT included in the returned dict.

    Per-file write errors are logged at DEBUG level and skipped. Errors raised
    by ``torch.load`` propagate to the caller.

    :param checkpoint_path: Local path to the downloaded checkpoint.
    :return: Mapping of file keys to local paths of extracted files. It is
        empty when the file is not ``.pth``/``.pt``, the loaded object is not a
        dict, or nothing is embedded.
    """
    extracted_files: dict = {}
    file_ext = sly_fs.get_file_ext(checkpoint_path)
    if file_ext.lower() not in (".pth", ".pt"):
        return extracted_files

    import torch  # pylint: disable=import-error

    logger.debug(f"Reading checkpoint: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file; load only checkpoints
    # from trusted storage (here: the team's own Team Files).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        # e.g. a whole pickled module: nothing embedded that we know how to extract
        return extracted_files

    # 1. Extract additional model files embedded into checkpoint (if any)
    ckpt_files = checkpoint.get("model_files", None)
    if ckpt_files and isinstance(ckpt_files, dict):
        for file_key, file_info in ckpt_files.items():
            try:
                content = file_info["content"]
                # Keep only the base name so an embedded name like "../x" or
                # "/etc/x" cannot write outside model_dir.
                fname = os.path.basename(str(file_info.get("name") or ""))
                if not fname:
                    fname = f"{file_key}.txt"
                dst_path = os.path.join(self.model_dir, fname)
                # Overwrite if exists
                if os.path.exists(dst_path):
                    sly_fs.silent_remove(dst_path)
                if isinstance(content, (bytes, bytearray)):
                    with open(dst_path, "wb") as f:
                        f.write(content)
                else:
                    with open(dst_path, "w", encoding="utf-8") as f:
                        f.write(content)
                extracted_files[file_key] = dst_path
            except Exception as e:
                logger.debug(f"Failed to write '{file_key}' from checkpoint to disk: {repr(e)}")

    # 2. Extract project meta (if present)
    model_meta = checkpoint.get("model_meta", None)
    if model_meta is not None:
        try:
            meta_path = os.path.join(self.model_dir, "model_meta.json")
            if os.path.exists(meta_path):
                sly_fs.silent_remove(meta_path)
            sly_json.dump_json_file(model_meta, meta_path)
        except Exception as e:
            logger.debug(f"Failed to dump model_meta from checkpoint: {repr(e)}")

    return extracted_files
852
+
781
853
  def _load_model(self, deploy_params: dict):
782
854
  self.model_source = deploy_params.get("model_source")
783
855
  self.device = deploy_params.get("device")
@@ -910,11 +982,20 @@ class Inference:
910
982
  if isinstance(model_meta, dict):
911
983
  self._model_meta = ProjectMeta.from_json(model_meta)
912
984
  elif isinstance(model_meta, str):
913
- remote_artifacts_dir = model_info["artifacts_dir"]
914
- model_meta_url = os.path.join(remote_artifacts_dir, model_meta)
915
- model_meta_path = self.download(model_meta_url)
916
- model_meta = sly_json.load_json_file(model_meta_path)
917
- self._model_meta = ProjectMeta.from_json(model_meta)
985
+ model_meta_path = os.path.join(self.model_dir, "model_meta.json")
986
+ if os.path.exists(model_meta_path):
987
+ logger.debug("Reading model meta from checkpoint")
988
+ model_meta = sly_json.load_json_file(model_meta_path)
989
+ self._model_meta = ProjectMeta.from_json(model_meta)
990
+ sly_fs.silent_remove(model_meta_path)
991
+ else:
992
+ remote_artifacts_dir = model_info["artifacts_dir"]
993
+ model_meta_url = os.path.join(remote_artifacts_dir, model_meta)
994
+ model_meta_path = self.download(model_meta_url)
995
+ model_meta = sly_json.load_json_file(model_meta_path)
996
+ self._model_meta = ProjectMeta.from_json(model_meta)
997
+ sly_fs.silent_remove(model_meta_path)
998
+
918
999
  else:
919
1000
  raise ValueError(
920
1001
  "model_meta should be a dict or a name of '.json' file in experiment artifacts folder in Team Files"
@@ -3741,6 +3822,27 @@ class Inference:
3741
3822
  f"Checkpoint {checkpoint_url} not found in Team Files. Cannot set workflow input"
3742
3823
  )
3743
3824
 
3825
def _set_checkpoint_from_experiment_info(self, experiment_info: ExperimentInfo):
    """Select the experiment described by ``experiment_info`` and its best checkpoint.

    It switches the model-source tab to CUSTOM, activates the experiment row
    whose task id is ``experiment_info.task_id``, and, when
    ``experiment_info.best_checkpoint`` is set, selects that checkpoint in
    the row.

    It does nothing when the GUI is not a ``GUI.ServingGUITemplate`` or has no
    experiment selector.

    :raises ValueError: if no experiment row matches ``experiment_info.task_id``.
        This and any other failure are logged and re-raised, so the caller can
        restore the default GUI state instead of deploying an unrelated model.
    """
    if not isinstance(self.gui, GUI.ServingGUITemplate) or self.gui.experiment_selector is None:
        return
    try:
        task_id = experiment_info.task_id
        selector = self.gui.experiment_selector
        row = selector.get_by_task_id(task_id)
        if row is None:
            # set_by_task_id silently ignores unknown ids; fail loudly instead.
            raise ValueError(f"Experiment with task id {task_id} not found in experiment selector")

        self.gui.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
        selector.set_by_task_id(task_id)

        best_ckpt = experiment_info.best_checkpoint
        if best_ckpt:
            row.set_selected_checkpoint_by_name(best_ckpt)
    except Exception as e:
        logger.warning(f"Failed to set checkpoint from experiment info: {repr(e)}")
        raise
3841
+
3842
def _reset_gui_state(self):
    """Switch the serving GUI back to the pretrained-models tab.

    This is a no-op unless the GUI is a ``GUI.ServingGUITemplate`` that has an
    experiment selector.
    """
    gui = self.gui
    if isinstance(gui, GUI.ServingGUITemplate) and gui.experiment_selector is not None:
        gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
3744
3846
 
3745
3847
  def _exclude_duplicated_predictions(
3746
3848
  api: Api,