supervisely 6.73.389__py3-none-any.whl → 6.73.391__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/app/widgets/experiment_selector/experiment_selector.py +20 -3
- supervisely/app/widgets/experiment_selector/template.html +49 -70
- supervisely/app/widgets/report_thumbnail/report_thumbnail.py +19 -4
- supervisely/decorators/profile.py +20 -0
- supervisely/nn/benchmark/utils/detection/utlis.py +7 -0
- supervisely/nn/experiments.py +4 -0
- supervisely/nn/inference/gui/serving_gui_template.py +71 -11
- supervisely/nn/inference/inference.py +108 -6
- supervisely/nn/training/gui/classes_selector.py +246 -27
- supervisely/nn/training/gui/gui.py +318 -234
- supervisely/nn/training/gui/hyperparameters_selector.py +2 -2
- supervisely/nn/training/gui/model_selector.py +42 -1
- supervisely/nn/training/gui/tags_selector.py +1 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +8 -7
- supervisely/nn/training/gui/training_artifacts.py +10 -1
- supervisely/nn/training/gui/training_process.py +17 -1
- supervisely/nn/training/train_app.py +227 -72
- supervisely/template/__init__.py +2 -0
- supervisely/template/base_generator.py +90 -0
- supervisely/template/experiment/__init__.py +0 -0
- supervisely/template/experiment/experiment.html.jinja +537 -0
- supervisely/template/experiment/experiment_generator.py +996 -0
- supervisely/template/experiment/header.html.jinja +154 -0
- supervisely/template/experiment/sidebar.html.jinja +240 -0
- supervisely/template/experiment/sly-style.css +397 -0
- supervisely/template/experiment/template.html.jinja +18 -0
- supervisely/template/extensions.py +172 -0
- supervisely/template/template_renderer.py +253 -0
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/METADATA +3 -1
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/RECORD +34 -23
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/LICENSE +0 -0
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/WHEEL +0 -0
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/top_level.txt +0 -0
|
@@ -142,6 +142,14 @@ class ExperimentSelector(Widget):
|
|
|
142
142
|
@property
|
|
143
143
|
def checkpoints_selector(self) -> Select:
|
|
144
144
|
return self._checkpoints_widget
|
|
145
|
+
|
|
146
|
+
@property
|
|
147
|
+
def experiment_info(self) -> ExperimentInfo:
|
|
148
|
+
return self._experiment_info
|
|
149
|
+
|
|
150
|
+
@property
|
|
151
|
+
def best_checkpoint(self) -> str:
|
|
152
|
+
return self.experiment_info.best_checkpoint
|
|
145
153
|
|
|
146
154
|
@property
|
|
147
155
|
def session_link(self) -> str:
|
|
@@ -459,6 +467,12 @@ class ExperimentSelector(Widget):
|
|
|
459
467
|
return
|
|
460
468
|
selected_row = self.get_selected_row()
|
|
461
469
|
return selected_row.get_selected_checkpoint_path()
|
|
470
|
+
|
|
471
|
+
def get_selected_checkpoint_name(self) -> str:
|
|
472
|
+
if len(self._rows) == 0:
|
|
473
|
+
return
|
|
474
|
+
selected_row = self.get_selected_row()
|
|
475
|
+
return selected_row.get_selected_checkpoint_name()
|
|
462
476
|
|
|
463
477
|
def get_model_files(self) -> Dict[str, str]:
|
|
464
478
|
"""
|
|
@@ -485,9 +499,12 @@ class ExperimentSelector(Widget):
|
|
|
485
499
|
}
|
|
486
500
|
return deploy_params
|
|
487
501
|
|
|
488
|
-
def set_active_row(self, row_index: int) -> None:
|
|
489
|
-
if
|
|
502
|
+
def set_active_row(self, row_index: int, task_type: str = None) -> None:
|
|
503
|
+
if task_type is None:
|
|
504
|
+
task_type = self.get_selected_task_type()
|
|
505
|
+
if row_index < 0 or row_index > len(self._rows[task_type]) - 1:
|
|
490
506
|
raise ValueError(f'Row with index "{row_index}" does not exist')
|
|
507
|
+
self.set_active_task_type(task_type)
|
|
491
508
|
StateJson()[self.widget_id]["selectedRow"] = row_index
|
|
492
509
|
StateJson().send_changes()
|
|
493
510
|
|
|
@@ -495,7 +512,7 @@ class ExperimentSelector(Widget):
|
|
|
495
512
|
for task_type in self._rows:
|
|
496
513
|
for i, row in enumerate(self._rows[task_type]):
|
|
497
514
|
if row.task_id == task_id:
|
|
498
|
-
self.set_active_row(i)
|
|
515
|
+
self.set_active_row(i, task_type)
|
|
499
516
|
return
|
|
500
517
|
|
|
501
518
|
def get_by_task_id(self, task_id: int) -> Union[ModelRow, None]:
|
|
@@ -1,82 +1,61 @@
|
|
|
1
|
-
<link rel="stylesheet" href="./sly/css/app/widgets/custom_models_selector/style.css"/>
|
|
1
|
+
<link rel="stylesheet" href="./sly/css/app/widgets/custom_models_selector/style.css" />
|
|
2
2
|
|
|
3
|
-
<div
|
|
4
|
-
{% if widget._changes_handled == true %}
|
|
5
|
-
@change="post('/{{{widget.widget_id}}}/value_changed')"
|
|
6
|
-
{% endif %}
|
|
7
|
-
>
|
|
3
|
+
<div {% if widget._changes_handled==true %} @change="post('/{{{widget.widget_id}}}/value_changed')" {% endif %}>
|
|
8
4
|
|
|
9
5
|
<div v-if="Object.keys(data.{{{widget.widget_id}}}.rowsHtml).length === 0"> You don't have any custom models</div>
|
|
10
6
|
<div v-else>
|
|
11
7
|
|
|
12
8
|
<div v-if="data.{{{widget.widget_id}}}.taskTypes.length > 1">
|
|
13
|
-
<sly-field
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class="multi-line mt10"
|
|
18
|
-
:value="state.{{{widget.widget_id}}}.selectedTaskType"
|
|
19
|
-
|
|
20
|
-
{% if widget._task_type_changes_handled == true %}
|
|
21
|
-
@input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0; post('/{{{widget.widget_id}}}/task_type_changed')}"
|
|
9
|
+
<sly-field title="Task Type">
|
|
10
|
+
<el-radio-group class="multi-line mt10" :value="state.{{{widget.widget_id}}}.selectedTaskType" {% if
|
|
11
|
+
widget._task_type_changes_handled==true %}
|
|
12
|
+
@input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0; post('/{{{widget.widget_id}}}/task_type_changed')}"
|
|
22
13
|
{% else %}
|
|
23
|
-
|
|
24
|
-
{% endif %}
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
v-for="(item, idx) in {{{widget._task_types}}}"
|
|
29
|
-
:key="item"
|
|
30
|
-
:label="item"
|
|
31
|
-
>
|
|
32
|
-
{{ item }}
|
|
14
|
+
@input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0;}"
|
|
15
|
+
{% endif %}>
|
|
16
|
+
|
|
17
|
+
<el-radio v-for="(item, idx) in {{{widget._task_types}}}" :key="item" :label="item">
|
|
18
|
+
{{ item }}
|
|
33
19
|
</el-radio>
|
|
34
20
|
</el-radio-group>
|
|
35
|
-
|
|
36
|
-
|
|
21
|
+
</sly-field>
|
|
22
|
+
</div>
|
|
37
23
|
|
|
38
24
|
<div>
|
|
39
|
-
|
|
40
|
-
<table class="custom-models-selector-table">
|
|
41
|
-
<thead>
|
|
42
|
-
<tr>
|
|
43
|
-
<th v-for="col in data.{{{widget.widget_id}}}.columns">
|
|
44
|
-
<div> {{col}} </div>
|
|
45
|
-
</th>
|
|
46
|
-
</tr>
|
|
47
|
-
</thead>
|
|
48
|
-
<tbody>
|
|
49
|
-
<tr v-for="row, ridx in data.{{{widget.widget_id}}}.rowsHtml[state.{{{widget.widget_id}}}.selectedTaskType]">
|
|
50
|
-
<td v-for="col, vidx in row">
|
|
51
|
-
<div v-if="vidx === 0" style="display: flex;">
|
|
52
|
-
<el-radio
|
|
53
|
-
style="display: flex;"
|
|
54
|
-
v-model="state.{{{widget.widget_id}}}.selectedRow"
|
|
55
|
-
:label="ridx"
|
|
56
|
-
>‍</el-radio>
|
|
57
|
-
|
|
58
|
-
<sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data" :state="state"></sly-html-compiler>
|
|
59
|
-
|
|
60
|
-
</div>
|
|
61
|
-
|
|
62
|
-
<div v-else>
|
|
63
25
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
26
|
+
<table class="custom-models-selector-table">
|
|
27
|
+
<thead>
|
|
28
|
+
<tr>
|
|
29
|
+
<th v-for="col in data.{{{widget.widget_id}}}.columns">
|
|
30
|
+
<div> {{col}} </div>
|
|
31
|
+
</th>
|
|
32
|
+
</tr>
|
|
33
|
+
</thead>
|
|
34
|
+
<tbody>
|
|
35
|
+
<tr
|
|
36
|
+
v-for="row, ridx in data.{{{widget.widget_id}}}.rowsHtml[state.{{{widget.widget_id}}}.selectedTaskType]">
|
|
37
|
+
<td v-for="col, vidx in row">
|
|
38
|
+
<div v-if="vidx === 0" style="display: flex;">
|
|
39
|
+
<el-radio style="display: flex;" v-model="state.{{{widget.widget_id}}}.selectedRow"
|
|
40
|
+
:label="ridx">‍</el-radio>
|
|
41
|
+
|
|
42
|
+
<sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data"
|
|
43
|
+
:state="state"></sly-html-compiler>
|
|
44
|
+
|
|
45
|
+
</div>
|
|
46
|
+
|
|
47
|
+
<div v-else>
|
|
48
|
+
|
|
49
|
+
<sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data"
|
|
50
|
+
:state="state">
|
|
51
|
+
</sly-html-compiler>
|
|
52
|
+
|
|
53
|
+
</div>
|
|
54
|
+
|
|
55
|
+
</td>
|
|
56
|
+
</tr>
|
|
57
|
+
</tbody>
|
|
58
|
+
</table>
|
|
59
|
+
</div>
|
|
60
|
+
</div>
|
|
61
|
+
</div>
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import Literal, Optional
|
|
2
2
|
|
|
3
3
|
from supervisely._utils import abs_url, is_debug_with_sly_net, is_development
|
|
4
4
|
from supervisely.api.file_api import FileInfo
|
|
@@ -16,11 +16,13 @@ class ReportThumbnail(Widget):
|
|
|
16
16
|
title: Optional[str] = None,
|
|
17
17
|
color: Optional[str] = "#dcb0ff",
|
|
18
18
|
bg_color: Optional[str] = "#faebff",
|
|
19
|
+
report_type: Literal["model_benchmark", "experiment"] = "model_benchmark",
|
|
19
20
|
):
|
|
20
21
|
self._id: int = None
|
|
21
22
|
self._info: FileInfo = None
|
|
22
23
|
self._description: str = None
|
|
23
24
|
self._url: str = None
|
|
25
|
+
self._report_type = report_type
|
|
24
26
|
self._set_info(info)
|
|
25
27
|
self._title = title
|
|
26
28
|
self._color = color if _validate_hex_color(color) else "#dcb0ff"
|
|
@@ -54,13 +56,26 @@ class ReportThumbnail(Widget):
|
|
|
54
56
|
return
|
|
55
57
|
self._id = info.id
|
|
56
58
|
self._info = info
|
|
57
|
-
self.
|
|
58
|
-
|
|
59
|
+
if self._report_type == "model_benchmark":
|
|
60
|
+
self._description = "Open the Model Benchmark evaluation report."
|
|
61
|
+
lnk = f"/model-benchmark?id={info.id}"
|
|
62
|
+
elif self._report_type == "experiment":
|
|
63
|
+
self._description = "Open the Experiment report."
|
|
64
|
+
lnk = f"/nn/experiments/{info.id}"
|
|
65
|
+
else:
|
|
66
|
+
raise ValueError(f"Invalid report type: {self._report_type}")
|
|
67
|
+
|
|
59
68
|
lnk = abs_url(lnk) if is_development() or is_debug_with_sly_net() else lnk
|
|
60
69
|
# self._description = info.path
|
|
61
70
|
self._url = lnk
|
|
62
71
|
|
|
63
|
-
def set(
|
|
72
|
+
def set(
|
|
73
|
+
self,
|
|
74
|
+
info: FileInfo = None,
|
|
75
|
+
report_type: Optional[Literal["model_benchmark", "experiment"]] = None,
|
|
76
|
+
):
|
|
77
|
+
if report_type is not None:
|
|
78
|
+
self._report_type = report_type
|
|
64
79
|
self._set_info(info)
|
|
65
80
|
self.update_data()
|
|
66
81
|
DataJson().send_changes()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import time
|
|
3
|
+
|
|
3
4
|
from supervisely.sly_logger import logger
|
|
4
5
|
|
|
5
6
|
|
|
@@ -21,6 +22,25 @@ def timeit(func):
|
|
|
21
22
|
return wrapper_timer
|
|
22
23
|
|
|
23
24
|
|
|
25
|
+
def timeit_with_result(func):
|
|
26
|
+
"""Measures execution time and stores it in function's 'elapsed' attribute."""
|
|
27
|
+
|
|
28
|
+
@functools.wraps(func)
|
|
29
|
+
def wrapper(*args, **kwargs):
|
|
30
|
+
start_time = time.perf_counter()
|
|
31
|
+
try:
|
|
32
|
+
result = func(*args, **kwargs)
|
|
33
|
+
return result
|
|
34
|
+
finally:
|
|
35
|
+
wrapper.elapsed = time.perf_counter() - start_time
|
|
36
|
+
logger.debug(
|
|
37
|
+
f"Function '{func.__name__}' finished in {wrapper.elapsed:.2f} seconds (≈ {wrapper.elapsed/60:.2f} minutes)"
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
wrapper.elapsed = None # Initial value before first call
|
|
41
|
+
return wrapper
|
|
42
|
+
|
|
43
|
+
|
|
24
44
|
def update_fields(func):
|
|
25
45
|
"""Update state field after executing function"""
|
|
26
46
|
|
|
@@ -24,5 +24,12 @@ def read_coco_datasets(cocoGt_json, cocoDt_json):
|
|
|
24
24
|
cocoGt = COCO()
|
|
25
25
|
cocoGt.dataset = cocoGt_json
|
|
26
26
|
cocoGt.createIndex()
|
|
27
|
+
|
|
28
|
+
# Fix key error in pycocotools
|
|
29
|
+
info = cocoGt.dataset.get("info", None)
|
|
30
|
+
if info is None:
|
|
31
|
+
cocoGt.dataset["info"] = {}
|
|
32
|
+
# ------------------------------ #
|
|
33
|
+
|
|
27
34
|
cocoDt = cocoGt.loadRes(cocoDt_json["annotations"])
|
|
28
35
|
return cocoGt, cocoDt
|
supervisely/nn/experiments.py
CHANGED
|
@@ -20,6 +20,10 @@ class ExperimentInfo:
|
|
|
20
20
|
"""Name of the framework used in the experiment"""
|
|
21
21
|
model_name: str
|
|
22
22
|
"""Name of the model used in the experiment. Defined by the user in the training app"""
|
|
23
|
+
base_checkpoint: str
|
|
24
|
+
"""Name of the base checkpoint used for training"""
|
|
25
|
+
base_checkpoint_link: str
|
|
26
|
+
"""Link to the base checkpoint used for training. URL in case of pretrained model, or Team Files path in case of custom model."""
|
|
23
27
|
task_type: str
|
|
24
28
|
"""Task type of the experiment"""
|
|
25
29
|
project_id: int
|
|
@@ -7,17 +7,8 @@ import supervisely.io.env as sly_env
|
|
|
7
7
|
import supervisely.io.fs as sly_fs
|
|
8
8
|
import supervisely.io.json as sly_json
|
|
9
9
|
from supervisely import Api
|
|
10
|
-
from supervisely.app.widgets import
|
|
11
|
-
|
|
12
|
-
Container,
|
|
13
|
-
Field,
|
|
14
|
-
RadioTabs,
|
|
15
|
-
SelectString,
|
|
16
|
-
Widget,
|
|
17
|
-
)
|
|
18
|
-
from supervisely.app.widgets.experiment_selector.experiment_selector import (
|
|
19
|
-
ExperimentSelector,
|
|
20
|
-
)
|
|
10
|
+
from supervisely.app.widgets import Card, Container, Field, RadioTabs, SelectString, Text, Widget
|
|
11
|
+
from supervisely.app.widgets.experiment_selector.experiment_selector import ExperimentSelector
|
|
21
12
|
from supervisely.app.widgets.pretrained_models_selector.pretrained_models_selector import (
|
|
22
13
|
PretrainedModelsSelector,
|
|
23
14
|
)
|
|
@@ -133,6 +124,27 @@ class ServingGUITemplate(ServingGUI):
|
|
|
133
124
|
card_widgets = [self.model_source_tabs]
|
|
134
125
|
if runtime_field is not None:
|
|
135
126
|
card_widgets.append(runtime_field)
|
|
127
|
+
|
|
128
|
+
# Runtime exported checkpoint message
|
|
129
|
+
self._export_msg = Text("")
|
|
130
|
+
self._export_msg.hide()
|
|
131
|
+
card_widgets.append(self._export_msg)
|
|
132
|
+
|
|
133
|
+
if self.runtime_select is not None:
|
|
134
|
+
self.runtime_select.value_changed(lambda _: self._update_export_message())
|
|
135
|
+
if self.experiment_selector is not None:
|
|
136
|
+
self.experiment_selector.value_changed(lambda _: self._update_export_message())
|
|
137
|
+
for task_type in self.experiment_selector.rows:
|
|
138
|
+
for row in self.experiment_selector.rows[task_type]:
|
|
139
|
+
row.checkpoints_selector.value_changed(lambda _: self._update_export_message())
|
|
140
|
+
if self.pretrained_models_table is not None:
|
|
141
|
+
self.pretrained_models_table.model_changed(lambda _: self._update_export_message())
|
|
142
|
+
|
|
143
|
+
if self.model_source_tabs is not None:
|
|
144
|
+
self.model_source_tabs.value_changed(lambda _: self._update_export_message())
|
|
145
|
+
|
|
146
|
+
self._update_export_message()
|
|
147
|
+
|
|
136
148
|
return card_widgets
|
|
137
149
|
|
|
138
150
|
def _initialize_extra_widgets(self) -> List[Widget]:
|
|
@@ -204,3 +216,51 @@ class ServingGUITemplate(ServingGUI):
|
|
|
204
216
|
elif self.model_source == ModelSource.CUSTOM and self.experiment_selector:
|
|
205
217
|
return self.experiment_selector.get_selected_experiment_info()
|
|
206
218
|
return {}
|
|
219
|
+
|
|
220
|
+
def _update_export_message(self):
|
|
221
|
+
self._export_msg.hide()
|
|
222
|
+
|
|
223
|
+
runtime = self.runtime
|
|
224
|
+
non_conversion_runtimes = [RuntimeType.ONNXRUNTIME, RuntimeType.TENSORRT]
|
|
225
|
+
if runtime not in non_conversion_runtimes:
|
|
226
|
+
return
|
|
227
|
+
|
|
228
|
+
if self.model_source == ModelSource.PRETRAINED:
|
|
229
|
+
self._export_msg.set(
|
|
230
|
+
"Checkpoint will be converted before deployment.",
|
|
231
|
+
"info",
|
|
232
|
+
)
|
|
233
|
+
self._export_msg.show()
|
|
234
|
+
return
|
|
235
|
+
|
|
236
|
+
checkpoint_name = None
|
|
237
|
+
if self.model_source == ModelSource.CUSTOM and self.experiment_selector is not None:
|
|
238
|
+
selected_row = self.experiment_selector.get_selected_row()
|
|
239
|
+
if selected_row is None:
|
|
240
|
+
return
|
|
241
|
+
checkpoint_name = selected_row.get_selected_checkpoint_name()
|
|
242
|
+
if checkpoint_name is None:
|
|
243
|
+
return
|
|
244
|
+
|
|
245
|
+
model_info = self.model_info or {}
|
|
246
|
+
export_info = model_info.get("export", {})
|
|
247
|
+
available = False
|
|
248
|
+
if isinstance(export_info, dict):
|
|
249
|
+
for key in export_info.keys():
|
|
250
|
+
if key.lower().startswith(runtime.lower()):
|
|
251
|
+
available = True
|
|
252
|
+
break
|
|
253
|
+
if checkpoint_name != selected_row.best_checkpoint:
|
|
254
|
+
available = False
|
|
255
|
+
|
|
256
|
+
if available:
|
|
257
|
+
self._export_msg.set(
|
|
258
|
+
"Runtime checkpoint exists – no conversion needed.",
|
|
259
|
+
"info",
|
|
260
|
+
)
|
|
261
|
+
else:
|
|
262
|
+
self._export_msg.set(
|
|
263
|
+
"Checkpoint will be converted before deployment.",
|
|
264
|
+
"info",
|
|
265
|
+
)
|
|
266
|
+
self._export_msg.show()
|
|
@@ -34,6 +34,7 @@ import supervisely.io.env as sly_env
|
|
|
34
34
|
import supervisely.io.fs as sly_fs
|
|
35
35
|
import supervisely.io.json as sly_json
|
|
36
36
|
import supervisely.nn.inference.gui as GUI
|
|
37
|
+
from supervisely.nn.experiments import ExperimentInfo
|
|
37
38
|
from supervisely import DatasetInfo, batched
|
|
38
39
|
from supervisely._utils import (
|
|
39
40
|
add_callback,
|
|
@@ -277,6 +278,19 @@ class Inference:
|
|
|
277
278
|
self.gui.on_serve_callbacks.append(on_serve_callback)
|
|
278
279
|
self._initialize_app_layout()
|
|
279
280
|
|
|
281
|
+
train_task_id = os.getenv("modal.state.trainTaskId")
|
|
282
|
+
if train_task_id:
|
|
283
|
+
try:
|
|
284
|
+
train_task_id = int(train_task_id)
|
|
285
|
+
logger.info(f"Setting best checkpoint from training task id: {train_task_id}")
|
|
286
|
+
experiment_info = self.api.nn.get_experiment_info(train_task_id)
|
|
287
|
+
self._set_checkpoint_from_experiment_info(experiment_info)
|
|
288
|
+
self.gui.deploy_with_current_params()
|
|
289
|
+
except Exception as e:
|
|
290
|
+
logger.warning(f"Failed to set checkpoint from training task id: {repr(e)}")
|
|
291
|
+
logger.info("Resetting UI to default state")
|
|
292
|
+
self._reset_gui_state()
|
|
293
|
+
|
|
280
294
|
self._inference_requests = {}
|
|
281
295
|
max_workers = 1 if not multithread_inference else None
|
|
282
296
|
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
|
@@ -748,7 +762,10 @@ class Inference:
|
|
|
748
762
|
"""
|
|
749
763
|
team_id = sly_env.team_id()
|
|
750
764
|
local_model_files = {}
|
|
751
|
-
|
|
765
|
+
|
|
766
|
+
# Sort files to download 'checkpoint' first
|
|
767
|
+
files_order = sorted(model_files.keys(), key=lambda x: (0 if x == "checkpoint" else 1, x))
|
|
768
|
+
for file in files_order:
|
|
752
769
|
file_url = model_files[file]
|
|
753
770
|
file_info = self.api.file.get_info_by_path(team_id, file_url)
|
|
754
771
|
if file_info is None:
|
|
@@ -773,11 +790,66 @@ class Inference:
|
|
|
773
790
|
)
|
|
774
791
|
else:
|
|
775
792
|
self.api.file.download(team_id, file_url, file_path)
|
|
793
|
+
|
|
794
|
+
if file == "checkpoint":
|
|
795
|
+
try:
|
|
796
|
+
extracted_files = self._extract_model_files_from_checkpoint(file_path)
|
|
797
|
+
local_model_files.update(extracted_files)
|
|
798
|
+
if extracted_files:
|
|
799
|
+
local_model_files[file] = file_path
|
|
800
|
+
return local_model_files
|
|
801
|
+
except Exception as e:
|
|
802
|
+
logger.debug(f"Failed to process checkpoint '{file_name}' to extract auxiliary files: {repr(e)}")
|
|
803
|
+
logger.debug("Model files will be downloaded from Team Files")
|
|
804
|
+
local_model_files[file] = file_path
|
|
805
|
+
continue
|
|
806
|
+
|
|
776
807
|
local_model_files[file] = file_path
|
|
777
808
|
if log_progress:
|
|
778
809
|
self.gui.download_progress.hide()
|
|
779
810
|
return local_model_files
|
|
780
811
|
|
|
812
|
+
def _extract_model_files_from_checkpoint(self, checkpoint_path: str) -> dict:
|
|
813
|
+
extracted_files: dict = {}
|
|
814
|
+
file_ext = sly_fs.get_file_ext(checkpoint_path)
|
|
815
|
+
if file_ext not in (".pth", ".pt"):
|
|
816
|
+
return extracted_files
|
|
817
|
+
|
|
818
|
+
import torch # pylint: disable=import-error
|
|
819
|
+
|
|
820
|
+
logger.debug(f"Reading checkpoint: {checkpoint_path}")
|
|
821
|
+
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
|
822
|
+
|
|
823
|
+
# 1. Extract additional model files embedded into checkpoint (if any)
|
|
824
|
+
ckpt_files = checkpoint.get("model_files", None)
|
|
825
|
+
if ckpt_files and isinstance(ckpt_files, dict):
|
|
826
|
+
for file_key, file_info in ckpt_files.items():
|
|
827
|
+
try:
|
|
828
|
+
content = file_info["content"]
|
|
829
|
+
fname = file_info.get("name", f"{file_key}.txt")
|
|
830
|
+
dst_path = os.path.join(self.model_dir, fname)
|
|
831
|
+
# Overwrite if exists
|
|
832
|
+
if os.path.exists(dst_path):
|
|
833
|
+
sly_fs.silent_remove(dst_path)
|
|
834
|
+
with open(dst_path, "w") as f:
|
|
835
|
+
f.write(content)
|
|
836
|
+
extracted_files[file_key] = dst_path
|
|
837
|
+
except Exception as e:
|
|
838
|
+
logger.debug(f"Failed to write '{file_key}' from checkpoint to disk: {repr(e)}")
|
|
839
|
+
|
|
840
|
+
# 2. Extract project meta (if present)
|
|
841
|
+
model_meta = checkpoint.get("model_meta", None)
|
|
842
|
+
if model_meta is not None:
|
|
843
|
+
try:
|
|
844
|
+
meta_path = os.path.join(self.model_dir, "model_meta.json")
|
|
845
|
+
if os.path.exists(meta_path):
|
|
846
|
+
sly_fs.silent_remove(meta_path)
|
|
847
|
+
sly_json.dump_json_file(model_meta, meta_path)
|
|
848
|
+
except Exception as e:
|
|
849
|
+
logger.debug(f"Failed to dump model_meta from checkpoint: {repr(e)}")
|
|
850
|
+
|
|
851
|
+
return extracted_files
|
|
852
|
+
|
|
781
853
|
def _load_model(self, deploy_params: dict):
|
|
782
854
|
self.model_source = deploy_params.get("model_source")
|
|
783
855
|
self.device = deploy_params.get("device")
|
|
@@ -910,11 +982,20 @@ class Inference:
|
|
|
910
982
|
if isinstance(model_meta, dict):
|
|
911
983
|
self._model_meta = ProjectMeta.from_json(model_meta)
|
|
912
984
|
elif isinstance(model_meta, str):
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
985
|
+
model_meta_path = os.path.join(self.model_dir, "model_meta.json")
|
|
986
|
+
if os.path.exists(model_meta_path):
|
|
987
|
+
logger.debug("Reading model meta from checkpoint")
|
|
988
|
+
model_meta = sly_json.load_json_file(model_meta_path)
|
|
989
|
+
self._model_meta = ProjectMeta.from_json(model_meta)
|
|
990
|
+
sly_fs.silent_remove(model_meta_path)
|
|
991
|
+
else:
|
|
992
|
+
remote_artifacts_dir = model_info["artifacts_dir"]
|
|
993
|
+
model_meta_url = os.path.join(remote_artifacts_dir, model_meta)
|
|
994
|
+
model_meta_path = self.download(model_meta_url)
|
|
995
|
+
model_meta = sly_json.load_json_file(model_meta_path)
|
|
996
|
+
self._model_meta = ProjectMeta.from_json(model_meta)
|
|
997
|
+
sly_fs.silent_remove(model_meta_path)
|
|
998
|
+
|
|
918
999
|
else:
|
|
919
1000
|
raise ValueError(
|
|
920
1001
|
"model_meta should be a dict or a name of '.json' file in experiment artifacts folder in Team Files"
|
|
@@ -3741,6 +3822,27 @@ class Inference:
|
|
|
3741
3822
|
f"Checkpoint {checkpoint_url} not found in Team Files. Cannot set workflow input"
|
|
3742
3823
|
)
|
|
3743
3824
|
|
|
3825
|
+
def _set_checkpoint_from_experiment_info(self, experiment_info: ExperimentInfo):
|
|
3826
|
+
try:
|
|
3827
|
+
if not isinstance(self.gui, GUI.ServingGUITemplate) or self.gui.experiment_selector is None:
|
|
3828
|
+
return
|
|
3829
|
+
|
|
3830
|
+
task_id = experiment_info.task_id
|
|
3831
|
+
self.gui.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
|
|
3832
|
+
self.gui.experiment_selector.set_by_task_id(task_id)
|
|
3833
|
+
|
|
3834
|
+
best_ckpt = experiment_info.best_checkpoint
|
|
3835
|
+
if best_ckpt:
|
|
3836
|
+
row = self.gui.experiment_selector.get_by_task_id(task_id)
|
|
3837
|
+
if row is not None:
|
|
3838
|
+
row.set_selected_checkpoint_by_name(best_ckpt)
|
|
3839
|
+
except Exception as e:
|
|
3840
|
+
logger.warning(f"Failed to set checkpoint from experiment info: {repr(e)}")
|
|
3841
|
+
|
|
3842
|
+
def _reset_gui_state(self):
|
|
3843
|
+
if not isinstance(self.gui, GUI.ServingGUITemplate) or self.gui.experiment_selector is None:
|
|
3844
|
+
return
|
|
3845
|
+
self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
|
|
3744
3846
|
|
|
3745
3847
|
def _exclude_duplicated_predictions(
|
|
3746
3848
|
api: Api,
|