supervisely 6.73.389__py3-none-any.whl → 6.73.391__py3-none-any.whl
This diff shows the content of publicly released package versions as published to their public registries and is provided for informational purposes only.
- supervisely/app/widgets/experiment_selector/experiment_selector.py +20 -3
- supervisely/app/widgets/experiment_selector/template.html +49 -70
- supervisely/app/widgets/report_thumbnail/report_thumbnail.py +19 -4
- supervisely/decorators/profile.py +20 -0
- supervisely/nn/benchmark/utils/detection/utlis.py +7 -0
- supervisely/nn/experiments.py +4 -0
- supervisely/nn/inference/gui/serving_gui_template.py +71 -11
- supervisely/nn/inference/inference.py +108 -6
- supervisely/nn/training/gui/classes_selector.py +246 -27
- supervisely/nn/training/gui/gui.py +318 -234
- supervisely/nn/training/gui/hyperparameters_selector.py +2 -2
- supervisely/nn/training/gui/model_selector.py +42 -1
- supervisely/nn/training/gui/tags_selector.py +1 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +8 -7
- supervisely/nn/training/gui/training_artifacts.py +10 -1
- supervisely/nn/training/gui/training_process.py +17 -1
- supervisely/nn/training/train_app.py +227 -72
- supervisely/template/__init__.py +2 -0
- supervisely/template/base_generator.py +90 -0
- supervisely/template/experiment/__init__.py +0 -0
- supervisely/template/experiment/experiment.html.jinja +537 -0
- supervisely/template/experiment/experiment_generator.py +996 -0
- supervisely/template/experiment/header.html.jinja +154 -0
- supervisely/template/experiment/sidebar.html.jinja +240 -0
- supervisely/template/experiment/sly-style.css +397 -0
- supervisely/template/experiment/template.html.jinja +18 -0
- supervisely/template/extensions.py +172 -0
- supervisely/template/template_renderer.py +253 -0
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/METADATA +3 -1
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/RECORD +34 -23
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/LICENSE +0 -0
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/WHEEL +0 -0
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/top_level.txt +0 -0
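
The +996 hunk below is the new supervisely/template/experiment/experiment_generator.py, which adds the ExperimentGenerator class used to build and upload the HTML experiment report rendered from the new Jinja templates (experiment.html.jinja, header.html.jinja, sidebar.html.jinja). As a minimal, hypothetical usage sketch (not taken from the package): it assumes the training artifacts are available locally and that the new BaseGenerator (base_generator.py, +90 lines, not shown in this hunk) provides a generate() rendering step before upload.

# Hypothetical usage sketch (not part of the diff). Assumes the artifacts of a finished
# training run are available locally and that BaseGenerator renders the report into
# output_dir via a generate() method (an assumption; base_generator.py is not shown here).
import supervisely as sly
import supervisely.io.json as sly_json
from supervisely.template.experiment.experiment_generator import ExperimentGenerator

api = sly.Api.from_env()

experiment_info = sly_json.load_json_file("experiment_info.json")  # assumed local copy
with open("hyperparameters.yaml", "r", encoding="utf-8") as f:
    hyperparameters = f.read()  # YAML string, as the constructor expects
model_meta = sly.ProjectMeta.from_json(sly_json.load_json_file("model_meta.json"))

generator = ExperimentGenerator(
    api=api,
    experiment_info=experiment_info,
    hyperparameters=hyperparameters,
    model_meta=model_meta,
    team_id=sly.env.team_id(),
    output_dir="./experiment_report",
)
generator.generate()             # assumed BaseGenerator rendering step
generator.upload_to_artifacts()  # uploads to <artifacts_dir>/visualization in Team Files
print(generator.get_report_link())  # -> {server_address}/nn/experiments/{report_file_id}

After upload, get_report() resolves the visualization/template.vue file through the file API and raises if the report has not been generated and uploaded yet.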
@@ -0,0 +1,996 @@
from __future__ import annotations

import os
from datetime import datetime
from pathlib import Path
from typing import Dict, Literal, Optional, Tuple

import supervisely.io.env as sly_env
import supervisely.io.fs as sly_fs
import supervisely.io.json as sly_json
from supervisely import logger
from supervisely.api.api import Api
from supervisely.api.file_api import FileInfo
from supervisely.nn.benchmark.object_detection.metric_provider import (
    METRIC_NAMES as OBJECT_DETECTION_METRIC_NAMES,
)
from supervisely.nn.benchmark.semantic_segmentation.metric_provider import (
    METRIC_NAMES as SEMANTIC_SEGMENTATION_METRIC_NAMES,
)
from supervisely.nn.inference import Inference
from supervisely.nn.task_type import TaskType
from supervisely.nn.utils import RuntimeType
from supervisely.project import ProjectMeta
from supervisely.template.base_generator import BaseGenerator

from supervisely.geometry.any_geometry import AnyGeometry
from supervisely.geometry.cuboid_3d import Cuboid3d
from supervisely.geometry.point_3d import Point3d
from supervisely.geometry.pointcloud import Pointcloud
from supervisely.geometry.multichannel_bitmap import MultichannelBitmap
from supervisely.geometry.cuboid import Cuboid
from supervisely.geometry.graph import GraphNodes
from supervisely.geometry.bitmap import Bitmap
from supervisely.geometry.point import Point
from supervisely.geometry.polyline import Polyline
from supervisely.geometry.polygon import Polygon
from supervisely.geometry.rectangle import Rectangle

from supervisely.imaging.color import rgb2hex


# @TODO: Partly supports unreleased apps
class ExperimentGenerator(BaseGenerator):

    def __init__(
        self,
        api: Api,
        experiment_info: dict,
        hyperparameters: str,
        model_meta: ProjectMeta,
        serving_class: Optional[Inference] = None,
        team_id: Optional[int] = None,
        output_dir: str = "./experiment_report",
        app_options: Optional[dict] = None,
    ):
        """Initialize experiment generator class.

        :param api: Supervisely API instance
        :type api: Api
        :param experiment_info: Dictionary with experiment information
        :type experiment_info: Dict[str, Any]
        :param hyperparameters: Hyperparameters as YAML string or dictionary
        :type hyperparameters: Optional[Union[str, Dict]]
        :param model_meta: Model metadata as dictionary
        :type model_meta: Optional[Union[str, Dict]]
        :param serving_class: Serving class for model inference
        :type serving_class: Optional[Inference]
        """
        super().__init__(api, output_dir=output_dir)
        self.team_id = team_id or sly_env.team_id()
        self.info = experiment_info
        self.hyperparameters = hyperparameters
        self.model_meta = model_meta
        self.artifacts_dir = self.info["artifacts_dir"]
        self.serving_class = serving_class
        self.app_info = self._get_app_info()
        self.app_options = app_options

    def _report_url(self, server_address: str, template_id: int) -> str:
        return f"{server_address}/nn/experiments/{template_id}"

    def upload_to_artifacts(self):
        remote_dir = os.path.join(self.info["artifacts_dir"], "visualization")
        self.upload(remote_dir, team_id=self.team_id)

    def get_report(self) -> FileInfo:
        remote_report_path = os.path.join(
            self.info["artifacts_dir"], "visualization", "template.vue"
        )
        experiment_report = self.api.file.get_info_by_path(
            self.team_id, remote_report_path
        )
        if experiment_report is None:
            raise ValueError("Generate and upload report first")
        return experiment_report

    def get_report_id(self) -> int:
        return self.get_report().id

    def get_report_link(self) -> str:
        return self._report_url(self.api.server_address, self.get_report_id())

    def state(self) -> dict:
        return {}

    def context(self) -> dict:
        context = {
            "env": self._get_env_context(),
            "experiment": self._get_experiment_context(),
            "resources": self._get_links_context(),
            "code": self._get_code_context(),
            "widgets": self._get_widgets_context(),
        }
        return context

    # --------------------------- Context blocks helpers --------------------------- #
    def _get_env_context(self):
        return {"server_address": self.api.server_address}

    def _get_apps_context(self):
        train_app, serve_app = self._get_app_train_serve_app_info()
        apply_images_app, apply_videos_app = self._get_app_apply_nn_app_info()
        log_viewer_app = self._get_log_viewer_app_info()
        return {
            "train": train_app,
            "serve": serve_app,
            "log_viewer": log_viewer_app,
            "apply_nn_to_images": apply_images_app,
            "apply_nn_to_videos": apply_videos_app,
        }

    def _get_links_context(self):
        return {"apps": self._get_apps_context()}

    def _get_code_context(self):
        docker_image = self._get_docker_image()
        repo_info = self._get_repository_info()
        pytorch_demo, onnx_demo, trt_demo = self._get_demo_scripts()

        return {
            "docker": {"image": docker_image},
            "local_prediction": {
                "repo": repo_info,
                "serving_module": (
                    self.serving_class.__module__ if self.serving_class else None
                ),
                "serving_class": (
                    self.serving_class.__name__ if self.serving_class else None
                ),
            },
            "demo": {
                "pytorch": pytorch_demo,
                "onnx": onnx_demo,
                "tensorrt": trt_demo,
            },
        }

    def _get_widgets_context(self):
        checkpoints_table = self._generate_checkpoints_table()
        metrics_table = self._generate_metrics_table(self.info["task_type"])
        sample_gallery = self._get_sample_predictions_gallery()
        classes_table = self._generate_classes_table()

        return {
            "tables": {
                "checkpoints": checkpoints_table,
                "metrics": metrics_table,
                "classes": classes_table,
            },
            "sample_pred_gallery": sample_gallery,
        }

    # --------------------------------------------------------------------------- #

    def _generate_metrics_table(self, task_type: str) -> str:
        """Generate HTML table with evaluation metrics.

        :returns: HTML string with metrics table
        :rtype: str
        """
        metrics = self.info.get("evaluation_metrics", {})
        if not metrics:
            return None

        html = ['<table class="table">']
        html.append("<thead><tr><th>Metrics</th><th>Value</th></tr></thead>")
        html.append("<tbody>")

        if (
            task_type == TaskType.OBJECT_DETECTION
            or task_type == TaskType.INSTANCE_SEGMENTATION
        ):
            metric_names = OBJECT_DETECTION_METRIC_NAMES
        elif task_type == TaskType.SEMANTIC_SEGMENTATION:
            metric_names = SEMANTIC_SEGMENTATION_METRIC_NAMES
        else:
            raise NotImplementedError(f"Task type '{task_type}' is not supported")

        for metric_name, metric_value in metrics.items():
            formatted_metric_name = metric_names.get(metric_name)
            if formatted_metric_name is None:
                formatted_metric_name = metric_name.replace("_", " ")
                formatted_metric_name = formatted_metric_name.replace("-", " ")
            if isinstance(metric_value, float):
                metric_value = f"{metric_value:.4f}"
            html.append(f"<tr><td>{metric_name}</td><td>{metric_value}</td></tr>")

        html.append("</tbody>")
        html.append("</table>")
        return "\n".join(html)

    def _generate_classes_table(self) -> str:
        """Generate HTML table with class names, shapes and colors.

        :returns: HTML string with classes table
        :rtype: str
        """
        type_to_icon = {
            AnyGeometry: "zmdi zmdi-shape",
            Rectangle: "zmdi zmdi-crop-din",
            Polygon: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAABmJLR0QA/wD/AP+gvaeTAAAB6klEQVRYhe2Wuy8EURTGf+u5VESNXq2yhYZCoeBv8RcI1i6NVUpsoVCKkHjUGlFTiYb1mFmh2MiKjVXMudmb3cPOzB0VXzKZm5k53/nmvO6Ff4RHD5AD7gFP1l3Kd11AHvCBEpAVW2esAvWmK6t8l1O+W0lCQEnIJoAZxUnzNQNkZF36jrQjgoA+uaciCgc9VaExBOyh/6WWAi1VhbjOJ4FbIXkBtgkK0BNHnYqNKUIPeBPbKyDdzpld5T6wD9SE4AwYjfEDaXFeFzE/doUWuhqwiFsOCwqv2hV2lU/L+sHBscGTxdvSFVoXpAjCZdauMHVic6ndl6U1VBsJCFhTeNUU9IiIEo3qvQYGHAV0AyfC5wNLhKipXuBCjA5wT8WxcM1FMRoBymK44CjAE57hqIazwCfwQdARcXa3UXHuRXVucIjb7jYvNkdxBZg0TBFid7PQTRAtX2xOiXkuMAMqYwkIE848rZFbjyNAmw9bIeweaZ2A5TgC7PnwKkTPtN+cTOrsyN3FEWAjRTAX6sA5ek77gSL6+WHZVQDAIHAjhJtN78aAS3lXAXYIivBOnCdyOAUYB6o0xqsvziry7FLE/Cp20cNcJEjDr8MUmVOVRzkVN+Nd7vZGVXXgiwxtPiRS5WFhz4fEq/zv4AvToMn7vCn3eAAAAABJRU5ErkJggg==",
            Bitmap: "zmdi zmdi-brush",
            Polyline: "zmdi zmdi-gesture",
            Point: "zmdi zmdi-dot-circle-alt",
            Cuboid: "zmdi zmdi-ungroup", #
            GraphNodes: "zmdi zmdi-grain",
            Cuboid3d: "zmdi zmdi-codepen",
            Pointcloud: "zmdi zmdi-cloud-outline",
            MultichannelBitmap: "zmdi zmdi-layers",
            Point3d: "zmdi zmdi-filter-center-focus",
        }

        if not hasattr(self.model_meta, "obj_classes"):
            return None

        if len(self.model_meta.obj_classes) == 0:
            return None

        html = ['<table class="table">']
        html.append("<thead><tr><th>Class name</th><th>Shape</th></tr></thead>")
        html.append("<tbody>")

        for obj_class in self.model_meta.obj_classes:
            class_name = obj_class.name
            color_hex = rgb2hex(obj_class.color)
            icon = type_to_icon.get(obj_class.geometry_type, "zmdi zmdi-shape")

            class_cell = (
                f"<i class='zmdi zmdi-circle' style='color: {color_hex}; margin-right: 5px;'></i>"
                f"<span>{class_name}</span>"
            )

            if isinstance(icon, str) and icon.startswith("data:image"):
                shape_cell = f"<img src='{icon}' style='height: 15px; margin-right: 2px;'/>"
            else:
                shape_cell = f"<i class='{icon}' style='margin-right: 5px;'></i>"

            shape_name = obj_class.geometry_type.geometry_name()
            shape_cell += f"<span>{shape_name}</span>"

            html.append(f"<tr><td>{class_cell}</td><td>{shape_cell}</td></tr>")

        html.append("</tbody>")
        html.append("</table>")
        return "\n".join(html)

    def _generate_metrics_table_horizontal(self) -> str:
        """Generate HTML table with evaluation metrics.

        :returns: HTML string with metrics table
        :rtype: str
        """
        metrics = self.info.get("evaluation_metrics", {})
        if not metrics:
            return None

        html = ['<table class="table">']
        # Generate header row with metric names
        header_cells = []
        for metric_name in metrics.keys():
            metric_name = metric_name.replace("_", " ")
            metric_name = metric_name.replace("-", " ")
            header_cells.append(f"<th>{metric_name}</th>")
        html.append(f"<thead><tr>{''.join(header_cells)}</tr></thead>")

        # Generate value row
        html.append("<tbody>")
        value_cells = []
        for metric_value in metrics.values():
            if isinstance(metric_value, float):
                metric_value = f"{metric_value:.4f}"
            value_cells.append(f"<td>{metric_value}</td>")
        html.append(f"<tr>{''.join(value_cells)}</tr>")
        html.append("</tbody>")

        html.append("</table>")
        return "\n".join(html)

    def _get_primary_metric(self) -> dict:
        """Get primary metric from evaluation metrics.

        :returns: Primary metric info
        :rtype: dict
        """
        primary_metric = {
            "name": None,
            "value": None,
            "rounded_value": None,
        }

        eval_metrics = self.info.get("evaluation_metrics", {})
        primary_metric_name = self.info.get("primary_metric")
        primary_metric_value = eval_metrics.get(primary_metric_name)

        if primary_metric_name is None or primary_metric_value is None:
            logger.debug(
                f"Primary metric is not found in evaluation metrics: {eval_metrics}"
            )
            return primary_metric

        primary_metric = {
            "name": primary_metric_name,
            "value": primary_metric_value,
            "rounded_value": round(primary_metric_value, 3),
        }
        return primary_metric

    def _get_display_metrics(self, task_type: str) -> list:
        """Get first 5 metrics for display (excluding primary metric).

        :param primary_metric: Primary metric info
        :type primary_metric: dict
        :returns: List of tuples (metric_name, metric_value) for display
        :rtype: list
        """
        display_metrics = []
        eval_metrics = self.info.get("evaluation_metrics", {})
        if not eval_metrics:
            return display_metrics

        main_metrics = []
        if (
            task_type == TaskType.OBJECT_DETECTION
            or task_type == TaskType.INSTANCE_SEGMENTATION
        ):
            main_metrics = ["mAP", "AP75", "AP50", "precision", "recall"]
        elif task_type == TaskType.SEMANTIC_SEGMENTATION:
            main_metrics = ["mIoU", "mPixel", "mPrecision", "mRecall", "mF1"]
        else:
            raise NotImplementedError(f"Task type '{task_type}' is not supported")

        for metric_name in main_metrics:
            if metric_name in eval_metrics:
                metric_value = eval_metrics[metric_name]
                value = round(metric_value, 3)
                percent_value = round(metric_value * 100, 3)
                display_metrics.append(
                    {
                        "name": metric_name,
                        "value": value,
                        "percent_value": percent_value,
                    }
                )

        return display_metrics

    def _generate_checkpoints_table(self) -> str:
        """Generate HTML table with checkpoint information.

        :returns: HTML string with checkpoints table
        :rtype: str
        """
        pytorch_checkpoints = self.info.get("checkpoints", None)
        if pytorch_checkpoints is None:
            raise ValueError("Checkpoints are not found in experiment info")

        checkpoints = pytorch_checkpoints.copy()
        export = self.info.get("export", {})
        if export:
            onnx_checkpoint = export.get(RuntimeType.ONNXRUNTIME)
            trt_checkpoint = export.get(RuntimeType.TENSORRT)
            if onnx_checkpoint is not None:
                checkpoints.append(onnx_checkpoint)
            if trt_checkpoint is not None:
                checkpoints.append(trt_checkpoint)

        checkpoint_paths = [
            os.path.join(self.artifacts_dir, ckpt) for ckpt in checkpoints
        ]
        checkpoint_infos = [
            self.api.file.get_info_by_path(self.team_id, path)
            for path in checkpoint_paths
        ]
        checkpoint_sizes = [
            f"{info.sizeb / 1024 / 1024:.2f} MB" for info in checkpoint_infos
        ]
        checkpoint_dl_links = [
            f"<a href='{info.full_storage_url}' download='{sly_fs.get_file_name_with_ext(info.path)}'>Download</a>"
            for info in checkpoint_infos
        ]

        html = ['<table class="table">']
        html.append(
            "<thead><tr><th>Checkpoint</th><th>Size</th><th> </th></tr></thead>"
        )
        html.append("<tbody>")
        for checkpoint, size, dl_link in zip(
            checkpoints, checkpoint_sizes, checkpoint_dl_links
        ):
            if isinstance(checkpoint, str):
                html.append(
                    f"<tr><td>{os.path.basename(checkpoint)}</td><td>{size}</td><td>{dl_link}</td></tr>"
                )
        html.append("</tbody>")
        html.append("</table>")
        return "\n".join(html)

    def _generate_hyperparameters_yaml(self) -> str:
        """Return hyperparameters as YAML string.

        :returns: YAML string with hyperparameters
        :rtype: str
        """
        if self.hyperparameters is not None:
            if not isinstance(self.hyperparameters, str):
                raise ValueError("Hyperparameters must be a yaml string")
            hyperparameters = self.hyperparameters.split("\n")
            return hyperparameters
        return None

    def _get_log_viewer_app_info(self):
        """Get log viewer app information.

        :returns: Log viewer app info
        :rtype: dict
        """
        slug = "supervisely-ecosystem/tensorboard-experiments-viewer"
        module_id = self.api.app.get_ecosystem_module_id(slug)
        return {"slug": slug, "module_id": module_id}

    def _get_app_info(self):
        """Get app information from task.

        :returns: App info object or None if not found
        :rtype: Optional[Any]
        """
        task_id = self.info.get("task_id", None)
        if task_id is None or task_id == -1:
            return None

        task_info = self.api.task.get_info_by_id(task_id)
        app_id = task_info["meta"]["app"]["id"]
        return self.api.app.get_info_by_id(app_id)

    def _get_training_session(self) -> dict:
        """Get training session information.

        :returns: Training session info
        :rtype: dict
        """
        task_id = self.info["task_id"]
        if task_id is None or task_id == -1:
            training_session = {
                "id": None,
                "url": None,
            }
            return training_session

        training_session = {
            "id": task_id,
            "url": f"{self.api.server_address}/apps/sessions/{task_id}",
        }
        return training_session

    def _get_training_duration(self) -> str:
        """Get training duration.

        :returns: Training duration in format "{h}h {m}m" or "N/A"
        :rtype: str
        """
        raw_duration = self.info.get("training_duration", "N/A")
        if raw_duration in (None, "N/A"):
            return "N/A"

        try:
            duration_seconds = float(raw_duration)
        except (TypeError, ValueError):
            return str(raw_duration)

        hours = int(duration_seconds // 3600)
        minutes = int((duration_seconds % 3600) // 60)
        seconds = int(duration_seconds % 60)
        return f"{hours}h {minutes}m {seconds}s"

    def _get_date(self) -> str:
        """Format experiment date.

        :returns: Formatted date string
        :rtype: str
        """
        date_str = self.info.get("datetime", "")
        date = date_str
        if date_str:
            try:
                dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S")
                date = dt.strftime("%d %b %Y")
            except ValueError:
                pass
        return date

    def _get_best_checkpoint(self) -> dict:
        """Get best checkpoint filename.

        :returns: Best checkpoint info
        :rtype: dict
        """
        best_checkpoint_path = os.path.join(
            self.artifacts_dir, "checkpoints", self.info["best_checkpoint"]
        )
        best_checkpoint_info = self.api.file.get_info_by_path(
            self.team_id, best_checkpoint_path
        )
        best_checkpoint = {
            "name": self.info["best_checkpoint"],
            "path": best_checkpoint_path,
            "url": best_checkpoint_info.full_storage_url,
            "size": f"{best_checkpoint_info.sizeb / 1024 / 1024:.1f} MB",
        }
        return best_checkpoint

    def _get_optimized_checkpoints(self) -> Tuple[dict, dict]:
        """Get optimized checkpoint filename (ONNX or TensorRT).

        :returns: Checkpoint info or None if not available
        :rtype: Optional[dict]
        """
        export = self.info.get("export", {})

        onnx_checkpoint_data = {
            "name": None,
            "path": None,
            "size": None,
            "url": None,
            "classes_url": None,
        }
        trt_checkpoint_data = {
            "name": None,
            "path": None,
            "size": None,
            "url": None,
            "classes_url": None,
        }

        onnx_checkpoint = export.get(RuntimeType.ONNXRUNTIME)
        if onnx_checkpoint is not None:
            onnx_checkpoint_path = os.path.join(
                self.artifacts_dir, export.get(RuntimeType.ONNXRUNTIME)
            )
            classes_file = self.api.file.get_info_by_path(
                self.team_id,
                os.path.join(os.path.dirname(onnx_checkpoint_path), "classes.json"),
            )
            onnx_file_info = self.api.file.get_info_by_path(
                self.team_id, onnx_checkpoint_path
            )
            onnx_checkpoint_data = {
                "name": os.path.basename(export.get(RuntimeType.ONNXRUNTIME)),
                "path": onnx_checkpoint_path,
                "size": f"{onnx_file_info.sizeb / 1024 / 1024:.1f} MB",
                "url": onnx_file_info.full_storage_url,
                "classes_url": classes_file.full_storage_url if classes_file else None,
            }
        trt_checkpoint = export.get(RuntimeType.TENSORRT)
        if trt_checkpoint is not None:
            trt_checkpoint_path = os.path.join(
                self.artifacts_dir, export.get(RuntimeType.TENSORRT)
            )
            classes_file = self.api.file.get_info_by_path(
                self.team_id,
                os.path.join(os.path.dirname(trt_checkpoint_path), "classes.json"),
            )
            trt_file_info = self.api.file.get_info_by_path(
                self.team_id, trt_checkpoint_path
            )
            trt_checkpoint_data = {
                "name": os.path.basename(export.get(RuntimeType.TENSORRT)),
                "path": trt_checkpoint_path,
                "size": f"{trt_file_info.sizeb / 1024 / 1024:.1f} MB",
                "url": trt_file_info.full_storage_url,
                "classes_url": classes_file.full_storage_url if classes_file else None,
            }
        return onnx_checkpoint_data, trt_checkpoint_data

    def _get_docker_image(self) -> str:
        """Get Docker image for model.

        :returns: Docker image name
        :rtype: str
        """
        if self.app_info is None:
            return None

        docker_image = self.app_info.config["docker_image"]
        if not docker_image:
            raise ValueError("Docker image is not found in app config")
        return docker_image

    def _get_repository_info(self) -> Dict[str, str]:
        """Get repository information.

        :returns: Dictionary with repo URL and name
        :rtype: Dict[str, str]
        """
        if self.app_info is None:
            return {"url": None, "name": None}

        framework_name = self.info["framework_name"]
        if hasattr(self.app_info, "repo"):
            repo_link = self.app_info.repo
            repo_name = repo_link.split("/")[-1]
            return {"url": repo_link, "name": repo_name}

        # @TODO: for unreleased apps
        repo_name = framework_name.replace(" ", "-")
        repo_link = f"https://github.com/supervisely-ecosystem/{repo_name}"
        return {"url": repo_link, "name": repo_name}

    # @TODO: method not used (might be helpful for unreleased apps)
    def _find_app_config(self):
        """
        Find app config.json in project structure.

        :returns: Config dictionary or None if not found
        :rtype: Optional[Dict[str, Any]]
        """
        try:
            current_dir = Path(os.path.abspath(os.path.dirname(__file__)))
            root_dir = current_dir

            while root_dir.parent != root_dir:
                config_path = (
                    root_dir / "supervisely_integration" / "train" / "config.json"
                )
                if config_path.exists():
                    break
                root_dir = root_dir.parent

            config_path = root_dir / "supervisely_integration" / "train" / "config.json"
            if config_path.exists():
                return sly_json.load_json_file(config_path)

        except Exception as e:
            logger.warning(f"Failed to load config.json: {e}")
        return None

    def _get_sample_predictions_gallery(self):
        evaluation_report_id = self.info.get("evaluation_report_id")
        if evaluation_report_id is None:
            return None
        benchmark_file_info = self.api.file.get_info_by_id(evaluation_report_id)
        evaluation_report_path = os.path.dirname(benchmark_file_info.path)
        if os.path.basename(evaluation_report_path) != "visualizations":
            logger.debug(
                f"Visualizations directory is not found in the report directory: '{evaluation_report_path}'"
            )
            return None

        evaluation_report_data_path = os.path.join(evaluation_report_path, "data")

        seek_file = "explore_predictions_gallery_widget"
        remote_gallery_widget_json = None
        for file in self.api.file.list(
            self.team_id, evaluation_report_data_path, False, "fileinfo"
        ):
            if file.name.startswith(seek_file) and file.name.endswith(".json"):
                remote_gallery_widget_json = file.path
                break

        if remote_gallery_widget_json is None:
            logger.debug(
                f"Gallery widget is not found in the report directory: '{evaluation_report_path}'"
            )
            return None

        save_path = os.path.join(
            self.output_dir, "data", "explore_predictions_gallery_widget_expmt.json"
        )
        self.api.file.download(self.team_id, remote_gallery_widget_json, save_path)

        widget_html = """
        <sly-iw-gallery ref="gallery_widget_expmt" iw-widget-id="gallery_widget_expmt" :options="{'isModalWindow': false}"
          :actions="{
            'init': {
              'dataSource': '/data/explore_predictions_gallery_widget_expmt.json',
            },

          }" :command="command" :data="data">
        </sly-iw-gallery>
        """
        return widget_html

    def _get_demo_scripts(self):
        """Get demo scripts.

        :returns: Demo scripts
        :rtype: Tuple[dict, dict, dict]
        """

        demo_pytorch_filename = "demo_pytorch.py"
        demo_onnx_filename = "demo_onnx.py"
        demo_tensorrt_filename = "demo_tensorrt.py"

        pytorch_demo = {"path": None, "script": None}
        onnx_demo = {"path": None, "script": None}
        trt_demo = {"path": None, "script": None}

        demo_path = self.app_options.get("demo", {}).get("path")
        if demo_path is None:
            logger.debug("Demo path is not found in app options")
            return pytorch_demo, onnx_demo, trt_demo

        local_demo_dir = os.path.join(os.getcwd(), demo_path)
        if not sly_fs.dir_exists(local_demo_dir):
            logger.debug(f"Demo directory '{local_demo_dir}' does not exist")
            return pytorch_demo, onnx_demo, trt_demo

        local_files = sly_fs.list_files(local_demo_dir)
        for file in local_files:
            if file.endswith(demo_pytorch_filename):
                with open(file, "r", encoding="utf-8") as f:
                    script = f.read()
                pytorch_demo = {"path": file, "script": script}
            elif file.endswith(demo_onnx_filename):
                with open(file, "r", encoding="utf-8") as f:
                    script = f.read()
                onnx_demo = {"path": file, "script": script}
            elif file.endswith(demo_tensorrt_filename):
                with open(file, "r", encoding="utf-8") as f:
                    script = f.read()
                trt_demo = {"path": file, "script": script}

        return pytorch_demo, onnx_demo, trt_demo

    def _get_app_train_serve_app_info(self):
        """Get app slugs.

        :returns: App slugs
        :rtype: Tuple[str, str]
        """

        def find_app_by_framework(
            api: Api, framework: str, action: Literal["train", "serve"]
        ):
            try:
                modules = api.app.get_list_ecosystem_modules(
                    categories=[action, f"framework:{framework}"],
                    categories_operation="and",
                )
                if len(modules) == 0:
                    return None
                return modules[0]
            except Exception as e:
                logger.warning(f"Failed to find {action} app by framework: {e}")
                return None

        train_app_info = find_app_by_framework(
            self.api, self.info["framework_name"], "train"
        )
        serve_app_info = find_app_by_framework(
            self.api, self.info["framework_name"], "serve"
        )

        if train_app_info is not None:
            train_app_slug = train_app_info["slug"].replace(
                "supervisely-ecosystem/", ""
            )
            train_app_id = train_app_info["id"]
        else:
            train_app_slug = None
            train_app_id = None

        if serve_app_info is not None:
            serve_app_slug = serve_app_info["slug"].replace(
                "supervisely-ecosystem/", ""
            )
            serve_app_id = serve_app_info["id"]
        else:
            serve_app_slug = None
            serve_app_id = None

        train_app = {
            "slug": train_app_slug,
            "module_id": train_app_id,
        }
        serve_app = {
            "slug": serve_app_slug,
            "module_id": serve_app_id,
        }
        return train_app, serve_app

    def _get_agent_info(self) -> str:
        task_id = self.info.get("task_id", None)

        agent_info = {
            "name": None,
            "id": None,
            "link": None,
        }

        if task_id is None or task_id == -1:
            return agent_info

        task_info = self.api.task.get_info_by_id(task_id)
        if task_info is not None:
            agent_info["name"] = task_info["agentName"]
            agent_info["id"] = task_info["agentId"]
            agent_info["link"] = (
                f"{self.api.server_address}/nodes/{agent_info['id']}/info"
            )
        return agent_info

    def _get_class_names(self, model_classes: list) -> dict:
        """Get class names from model meta.

        :returns: List of class names
        :rtype: list
        """

        return {
            "string": ", ".join(model_classes),
            "short_string": (
                ", ".join(model_classes[:5] + ["..."])
                if len(model_classes) > 5
                else ", ".join(model_classes)
            ),
            "list": model_classes,
            "short_list": (
                model_classes[:3] + ["..."] if len(model_classes) > 3 else model_classes
            ),
        }

    def _get_app_apply_nn_app_info(self):
        """Get apply NN app info.

        :returns: Apply NN app info
        :rtype: dict
        """

        apply_nn_images_slug = "nn-image-labeling/project-dataset"
        apply_nn_images_module_id = self.api.app.get_ecosystem_module_id(
            f"supervisely-ecosystem/{apply_nn_images_slug}"
        )
        apply_nn_videos_slug = "apply-nn-to-videos-project"
        apply_nn_videos_module_id = self.api.app.get_ecosystem_module_id(
            f"supervisely-ecosystem/{apply_nn_videos_slug}"
        )

        apply_nn_images_app = {
            "slug": apply_nn_images_slug,
            "module_id": apply_nn_images_module_id,
        }
        apply_nn_videos_app = {
            "slug": apply_nn_videos_slug,
            "module_id": apply_nn_videos_module_id,
        }
        return apply_nn_images_app, apply_nn_videos_app

    def _get_project_context(self):
        project_id = self.info["project_id"]
        project_info = self.api.project.get_info_by_id(project_id)
        project_type = project_info.type
        project_url = f"{self.api.server_address}/projects/{project_id}/datasets"
        project_train_size = self.info.get("train_size", "N/A")
        project_val_size = self.info.get("val_size", "N/A")
        model_classes = [cls.name for cls in self.model_meta.obj_classes]
        class_names = self._get_class_names(model_classes)

        project_context = {
            "id": project_id,
            "name": project_info.name if project_info else "Project was archived",
            "url": project_url,
            "type": project_type,
            "splits": {
                "train": project_train_size,
                "val": project_val_size,
            },
            "classes": {
                "count": len(model_classes),
                "names": class_names,
            },
        }
        return project_context

    def _get_base_checkpoint_info(self):
        base_checkpoint_name = self.info.get("base_checkpoint", "N/A")
        base_checkpoint_link = self.info.get("base_checkpoint_link", None)
        base_checkpoint_path = None
        if base_checkpoint_link is not None:
            if base_checkpoint_link.startswith("/experiments/"):
                base_checkpoint_info = self.api.file.get_info_by_path(self.team_id, base_checkpoint_link)
                base_checkpoint_name = base_checkpoint_info.name
                base_checkpoint_link = base_checkpoint_info.full_storage_url
                base_checkpoint_path = f"{self.api.server_address}/files/?path={base_checkpoint_info.path}"

        base_checkpoint = {
            "name": base_checkpoint_name,
            "url": base_checkpoint_link,
            "path": base_checkpoint_path,
        }
        return base_checkpoint

    def _get_model_context(self):
        """Return model description part of context."""
        return {
            "name": self.info["model_name"],
            "framework": self.info["framework_name"],
            "base_checkpoint": self._get_base_checkpoint_info(),
            "task_type": self.info["task_type"],
        }

    def _get_training_context(self):
        """Return training-related context (checkpoints, metrics, etc.)."""

        device = self.info.get("device", "N/A")
        training_session = self._get_training_session()
        training_duration = self._get_training_duration()
        hyperparameters = self._generate_hyperparameters_yaml()

        best_checkpoint = self._get_best_checkpoint()
        onnx_checkpoint, trt_checkpoint = self._get_optimized_checkpoints()

        logs_path = self.info.get("logs", {}).get("link")
        logs_url = (
            f"{self.api.server_address}/files/?path={logs_path}" if logs_path else None
        )

        primary_metric = self._get_primary_metric()
        display_metrics = self._get_display_metrics(self.info["task_type"])

        return {
            "device": device,
            "session": training_session,
            "duration": training_duration,
            "hyperparameters": hyperparameters,
            "checkpoints": {
                "pytorch": best_checkpoint,
                "onnx": onnx_checkpoint,
                "tensorrt": trt_checkpoint,
            },
            "export": self.info.get("export"),
            "logs": {"path": logs_path, "url": logs_url},
            "evaluation": {
                "id": self.info.get("evaluation_report_id"),
                "url": self.info.get("evaluation_report_link"),
                "primary_metric": primary_metric,
                "display_metrics": display_metrics,
                "metrics": self.info.get("evaluation_metrics"),
            },
        }

    def _get_experiment_context(self):
        task_id = self.info["task_id"]
        exp_name = self.info["experiment_name"]
        agent_info = self._get_agent_info()
        date = self._get_date()
        project_context = self._get_project_context()
        model_context = self._get_model_context()
        training_context = self._get_training_context()
        artifacts_dir = self.info["artifacts_dir"].rstrip("/")
        experiment_dir = os.path.basename(artifacts_dir)
        checkpoints_dir = os.path.join(artifacts_dir, "checkpoints")

        experiment_context = {
            "task_id": task_id,
            "name": exp_name,
            "agent": agent_info,
            "date": date,
            "project": project_context,
            "model": model_context,
            "training": training_context,
            "paths": {
                "experiment_dir": {
                    "path": experiment_dir,
                    "url": f"{self.api.server_address}/files/?path={experiment_dir.rstrip('/') + '/'}",
                },
                "artifacts_dir": {
                    "path": artifacts_dir,
                    "url": f"{self.api.server_address}/files/?path={artifacts_dir.rstrip('/') + '/'}",
                },
                "checkpoints_dir": {
                    "path": checkpoints_dir,
                    "url": f"{self.api.server_address}/files/?path={checkpoints_dir.rstrip('/') + '/'}",
                },
            },
        }
        return experiment_context
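
For orientation, ExperimentGenerator.context() above assembles the data handed to the new Jinja templates. A condensed sketch of the structure it returns follows; all values are illustrative placeholders, not real output, and the nested sub-dicts mirror the helper methods shown in the diff.

# Condensed, illustrative shape of the dict returned by ExperimentGenerator.context().
# Placeholder values only; the real ones come from the helper methods in the diff above.
context = {
    "env": {"server_address": "https://app.supervisely.com"},
    "experiment": {
        "task_id": 12345,
        "name": "my_experiment",
        "agent": {"name": "gpu-agent", "id": 7, "link": "https://.../nodes/7/info"},
        "date": "01 Jan 2025",
        "project": {
            "id": 111,
            "name": "my_project",
            "url": "https://.../projects/111/datasets",
            "type": "images",
            "splits": {"train": 800, "val": 200},
            "classes": {
                "count": 2,
                "names": {"string": "cat, dog", "short_string": "cat, dog",
                          "list": ["cat", "dog"], "short_list": ["cat", "dog"]},
            },
        },
        "model": {
            "name": "my_model",
            "framework": "my_framework",
            "base_checkpoint": {"name": "base.pth", "url": None, "path": None},
            "task_type": "object detection",
        },
        "training": {
            "device": "NVIDIA GPU",
            "session": {"id": 12345, "url": "https://.../apps/sessions/12345"},
            "duration": "1h 2m 3s",
            "hyperparameters": ["epochs: 20", "batch_size: 16"],  # YAML split into lines
            "checkpoints": {
                "pytorch": {"name": "best.pth", "path": "...", "url": "...", "size": "150.0 MB"},
                "onnx": {"name": None, "path": None, "size": None, "url": None, "classes_url": None},
                "tensorrt": {"name": None, "path": None, "size": None, "url": None, "classes_url": None},
            },
            "export": None,
            "logs": {"path": None, "url": None},
            "evaluation": {
                "id": None,
                "url": None,
                "primary_metric": {"name": "mAP", "value": 0.501, "rounded_value": 0.501},
                "display_metrics": [],
                "metrics": None,
            },
        },
        "paths": {
            "experiment_dir": {"path": "...", "url": "..."},
            "artifacts_dir": {"path": "...", "url": "..."},
            "checkpoints_dir": {"path": "...", "url": "..."},
        },
    },
    "resources": {"apps": {"train": {}, "serve": {}, "log_viewer": {},
                           "apply_nn_to_images": {}, "apply_nn_to_videos": {}}},
    "code": {"docker": {"image": "supervisely/..."}, "local_prediction": {}, "demo": {}},
    "widgets": {"tables": {"checkpoints": "<table>...</table>", "metrics": "<table>...</table>",
                           "classes": "<table>...</table>"},
                "sample_pred_gallery": None},
}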