supervisely 6.73.390__py3-none-any.whl → 6.73.392__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. supervisely/app/widgets/experiment_selector/experiment_selector.py +21 -3
  2. supervisely/app/widgets/experiment_selector/template.html +49 -70
  3. supervisely/app/widgets/report_thumbnail/report_thumbnail.py +19 -4
  4. supervisely/decorators/profile.py +20 -0
  5. supervisely/nn/benchmark/utils/detection/utlis.py +7 -0
  6. supervisely/nn/experiments.py +4 -0
  7. supervisely/nn/inference/gui/serving_gui_template.py +71 -11
  8. supervisely/nn/inference/inference.py +108 -6
  9. supervisely/nn/training/gui/classes_selector.py +246 -27
  10. supervisely/nn/training/gui/gui.py +318 -234
  11. supervisely/nn/training/gui/hyperparameters_selector.py +2 -2
  12. supervisely/nn/training/gui/model_selector.py +42 -1
  13. supervisely/nn/training/gui/tags_selector.py +1 -1
  14. supervisely/nn/training/gui/train_val_splits_selector.py +8 -7
  15. supervisely/nn/training/gui/training_artifacts.py +10 -1
  16. supervisely/nn/training/gui/training_process.py +17 -1
  17. supervisely/nn/training/train_app.py +227 -72
  18. supervisely/template/__init__.py +2 -0
  19. supervisely/template/base_generator.py +90 -0
  20. supervisely/template/experiment/__init__.py +0 -0
  21. supervisely/template/experiment/experiment.html.jinja +537 -0
  22. supervisely/template/experiment/experiment_generator.py +996 -0
  23. supervisely/template/experiment/header.html.jinja +154 -0
  24. supervisely/template/experiment/sidebar.html.jinja +240 -0
  25. supervisely/template/experiment/sly-style.css +397 -0
  26. supervisely/template/experiment/template.html.jinja +18 -0
  27. supervisely/template/extensions.py +172 -0
  28. supervisely/template/template_renderer.py +253 -0
  29. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/METADATA +3 -1
  30. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/RECORD +34 -23
  31. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/LICENSE +0 -0
  32. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/WHEEL +0 -0
  33. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/entry_points.txt +0 -0
  34. {supervisely-6.73.390.dist-info → supervisely-6.73.392.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,996 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from typing import Dict, Literal, Optional, Tuple
7
+
8
+ import supervisely.io.env as sly_env
9
+ import supervisely.io.fs as sly_fs
10
+ import supervisely.io.json as sly_json
11
+ from supervisely import logger
12
+ from supervisely.api.api import Api
13
+ from supervisely.api.file_api import FileInfo
14
+ from supervisely.nn.benchmark.object_detection.metric_provider import (
15
+ METRIC_NAMES as OBJECT_DETECTION_METRIC_NAMES,
16
+ )
17
+ from supervisely.nn.benchmark.semantic_segmentation.metric_provider import (
18
+ METRIC_NAMES as SEMANTIC_SEGMENTATION_METRIC_NAMES,
19
+ )
20
+ from supervisely.nn.inference import Inference
21
+ from supervisely.nn.task_type import TaskType
22
+ from supervisely.nn.utils import RuntimeType
23
+ from supervisely.project import ProjectMeta
24
+ from supervisely.template.base_generator import BaseGenerator
25
+
26
+ from supervisely.geometry.any_geometry import AnyGeometry
27
+ from supervisely.geometry.cuboid_3d import Cuboid3d
28
+ from supervisely.geometry.point_3d import Point3d
29
+ from supervisely.geometry.pointcloud import Pointcloud
30
+ from supervisely.geometry.multichannel_bitmap import MultichannelBitmap
31
+ from supervisely.geometry.cuboid import Cuboid
32
+ from supervisely.geometry.graph import GraphNodes
33
+ from supervisely.geometry.bitmap import Bitmap
34
+ from supervisely.geometry.point import Point
35
+ from supervisely.geometry.polyline import Polyline
36
+ from supervisely.geometry.polygon import Polygon
37
+ from supervisely.geometry.rectangle import Rectangle
38
+
39
+ from supervisely.imaging.color import rgb2hex
40
+
41
+
42
+ # @TODO: Partly supports unreleased apps
43
+ class ExperimentGenerator(BaseGenerator):
44
+
45
def __init__(
    self,
    api: Api,
    experiment_info: dict,
    hyperparameters: str,
    model_meta: ProjectMeta,
    serving_class: Optional[Inference] = None,
    team_id: Optional[int] = None,
    output_dir: str = "./experiment_report",
    app_options: Optional[dict] = None,
):
    """Initialize experiment generator class.

    :param api: Supervisely API instance
    :type api: Api
    :param experiment_info: Dictionary with experiment information. Must contain
        the ``"artifacts_dir"`` key.
    :type experiment_info: dict
    :param hyperparameters: Hyperparameters as a YAML string (or ``None``)
    :type hyperparameters: str
    :param model_meta: Project meta with the model classes
    :type model_meta: ProjectMeta
    :param serving_class: Serving class for model inference
    :type serving_class: Optional[Inference]
    :param team_id: Team ID; when falsy, ``sly_env.team_id()`` is used instead
    :type team_id: Optional[int]
    :param output_dir: Local output directory passed to the base generator
    :type output_dir: str
    :param app_options: Training app options, e.g. ``{"demo": {"path": ...}}``
    :type app_options: Optional[dict]
    """
    super().__init__(api, output_dir=output_dir)
    self.team_id = team_id or sly_env.team_id()
    self.info = experiment_info
    self.hyperparameters = hyperparameters
    self.model_meta = model_meta
    self.artifacts_dir = self.info["artifacts_dir"]
    self.serving_class = serving_class
    # Needs self.info: resolves the app of the training task (None if there is no task).
    self.app_info = self._get_app_info()
    self.app_options = app_options
78
+
79
+ def _report_url(self, server_address: str, template_id: int) -> str:
80
+ return f"{server_address}/nn/experiments/{template_id}"
81
+
82
def upload_to_artifacts(self):
    """Upload the generated report into ``<artifacts_dir>/visualization``."""
    artifacts_dir = self.info["artifacts_dir"]
    target_dir = os.path.join(artifacts_dir, "visualization")
    self.upload(target_dir, team_id=self.team_id)
85
+
86
def get_report(self) -> FileInfo:
    """Return the file info of the uploaded ``visualization/template.vue`` report.

    :raises ValueError: if the report file does not exist yet.
    """
    report_path = os.path.join(
        self.info["artifacts_dir"], "visualization", "template.vue"
    )
    report_info = self.api.file.get_info_by_path(self.team_id, report_path)
    if report_info is not None:
        return report_info
    raise ValueError("Generate and upload report first")
96
+
97
def get_report_id(self) -> int:
    """Return the ``id`` of the uploaded report file."""
    report = self.get_report()
    return report.id
99
+
100
def get_report_link(self) -> str:
    """Return the experiment page URL of the uploaded report."""
    server = self.api.server_address
    report_id = self.get_report_id()
    return self._report_url(server, report_id)
102
+
103
def state(self) -> dict:
    """Return the template state; this generator has none, so it is empty."""
    return dict()
105
+
106
def context(self) -> dict:
    """Assemble the full template rendering context from the helper sections."""
    return dict(
        env=self._get_env_context(),
        experiment=self._get_experiment_context(),
        resources=self._get_links_context(),
        code=self._get_code_context(),
        widgets=self._get_widgets_context(),
    )
115
+
116
+ # --------------------------- Context blocks helpers --------------------------- #
117
+ def _get_env_context(self):
118
+ return {"server_address": self.api.server_address}
119
+
120
+ def _get_apps_context(self):
121
+ train_app, serve_app = self._get_app_train_serve_app_info()
122
+ apply_images_app, apply_videos_app = self._get_app_apply_nn_app_info()
123
+ log_viewer_app = self._get_log_viewer_app_info()
124
+ return {
125
+ "train": train_app,
126
+ "serve": serve_app,
127
+ "log_viewer": log_viewer_app,
128
+ "apply_nn_to_images": apply_images_app,
129
+ "apply_nn_to_videos": apply_videos_app,
130
+ }
131
+
132
+ def _get_links_context(self):
133
+ return {"apps": self._get_apps_context()}
134
+
135
+ def _get_code_context(self):
136
+ docker_image = self._get_docker_image()
137
+ repo_info = self._get_repository_info()
138
+ pytorch_demo, onnx_demo, trt_demo = self._get_demo_scripts()
139
+
140
+ return {
141
+ "docker": {"image": docker_image},
142
+ "local_prediction": {
143
+ "repo": repo_info,
144
+ "serving_module": (
145
+ self.serving_class.__module__ if self.serving_class else None
146
+ ),
147
+ "serving_class": (
148
+ self.serving_class.__name__ if self.serving_class else None
149
+ ),
150
+ },
151
+ "demo": {
152
+ "pytorch": pytorch_demo,
153
+ "onnx": onnx_demo,
154
+ "tensorrt": trt_demo,
155
+ },
156
+ }
157
+
158
+ def _get_widgets_context(self):
159
+ checkpoints_table = self._generate_checkpoints_table()
160
+ metrics_table = self._generate_metrics_table(self.info["task_type"])
161
+ sample_gallery = self._get_sample_predictions_gallery()
162
+ classes_table = self._generate_classes_table()
163
+
164
+ return {
165
+ "tables": {
166
+ "checkpoints": checkpoints_table,
167
+ "metrics": metrics_table,
168
+ "classes": classes_table,
169
+ },
170
+ "sample_pred_gallery": sample_gallery,
171
+ }
172
+
173
+ # --------------------------------------------------------------------------- #
174
+
175
+ def _generate_metrics_table(self, task_type: str) -> str:
176
+ """Generate HTML table with evaluation metrics.
177
+
178
+ :returns: HTML string with metrics table
179
+ :rtype: str
180
+ """
181
+ metrics = self.info.get("evaluation_metrics", {})
182
+ if not metrics:
183
+ return None
184
+
185
+ html = ['<table class="table">']
186
+ html.append("<thead><tr><th>Metrics</th><th>Value</th></tr></thead>")
187
+ html.append("<tbody>")
188
+
189
+ if (
190
+ task_type == TaskType.OBJECT_DETECTION
191
+ or task_type == TaskType.INSTANCE_SEGMENTATION
192
+ ):
193
+ metric_names = OBJECT_DETECTION_METRIC_NAMES
194
+ elif task_type == TaskType.SEMANTIC_SEGMENTATION:
195
+ metric_names = SEMANTIC_SEGMENTATION_METRIC_NAMES
196
+ else:
197
+ raise NotImplementedError(f"Task type '{task_type}' is not supported")
198
+
199
+ for metric_name, metric_value in metrics.items():
200
+ formatted_metric_name = metric_names.get(metric_name)
201
+ if formatted_metric_name is None:
202
+ formatted_metric_name = metric_name.replace("_", " ")
203
+ formatted_metric_name = formatted_metric_name.replace("-", " ")
204
+ if isinstance(metric_value, float):
205
+ metric_value = f"{metric_value:.4f}"
206
+ html.append(f"<tr><td>{metric_name}</td><td>{metric_value}</td></tr>")
207
+
208
+ html.append("</tbody>")
209
+ html.append("</table>")
210
+ return "\n".join(html)
211
+
212
+ def _generate_classes_table(self) -> str:
213
+ """Generate HTML table with class names, shapes and colors.
214
+
215
+ :returns: HTML string with classes table
216
+ :rtype: str
217
+ """
218
+ type_to_icon = {
219
+ AnyGeometry: "zmdi zmdi-shape",
220
+ Rectangle: "zmdi zmdi-crop-din",
221
+ Polygon: "",
222
+ Bitmap: "zmdi zmdi-brush",
223
+ Polyline: "zmdi zmdi-gesture",
224
+ Point: "zmdi zmdi-dot-circle-alt",
225
+ Cuboid: "zmdi zmdi-ungroup", #
226
+ GraphNodes: "zmdi zmdi-grain",
227
+ Cuboid3d: "zmdi zmdi-codepen",
228
+ Pointcloud: "zmdi zmdi-cloud-outline",
229
+ MultichannelBitmap: "zmdi zmdi-layers",
230
+ Point3d: "zmdi zmdi-filter-center-focus",
231
+ }
232
+
233
+ if not hasattr(self.model_meta, "obj_classes"):
234
+ return None
235
+
236
+ if len(self.model_meta.obj_classes) == 0:
237
+ return None
238
+
239
+ html = ['<table class="table">']
240
+ html.append("<thead><tr><th>Class name</th><th>Shape</th></tr></thead>")
241
+ html.append("<tbody>")
242
+
243
+ for obj_class in self.model_meta.obj_classes:
244
+ class_name = obj_class.name
245
+ color_hex = rgb2hex(obj_class.color)
246
+ icon = type_to_icon.get(obj_class.geometry_type, "zmdi zmdi-shape")
247
+
248
+ class_cell = (
249
+ f"<i class='zmdi zmdi-circle' style='color: {color_hex}; margin-right: 5px;'></i>"
250
+ f"<span>{class_name}</span>"
251
+ )
252
+
253
+ if isinstance(icon, str) and icon.startswith("data:image"):
254
+ shape_cell = f"<img src='{icon}' style='height: 15px; margin-right: 2px;'/>"
255
+ else:
256
+ shape_cell = f"<i class='{icon}' style='margin-right: 5px;'></i>"
257
+
258
+ shape_name = obj_class.geometry_type.geometry_name()
259
+ shape_cell += f"<span>{shape_name}</span>"
260
+
261
+ html.append(f"<tr><td>{class_cell}</td><td>{shape_cell}</td></tr>")
262
+
263
+ html.append("</tbody>")
264
+ html.append("</table>")
265
+ return "\n".join(html)
266
+
267
+ def _generate_metrics_table_horizontal(self) -> str:
268
+ """Generate HTML table with evaluation metrics.
269
+
270
+ :returns: HTML string with metrics table
271
+ :rtype: str
272
+ """
273
+ metrics = self.info.get("evaluation_metrics", {})
274
+ if not metrics:
275
+ return None
276
+
277
+ html = ['<table class="table">']
278
+ # Generate header row with metric names
279
+ header_cells = []
280
+ for metric_name in metrics.keys():
281
+ metric_name = metric_name.replace("_", " ")
282
+ metric_name = metric_name.replace("-", " ")
283
+ header_cells.append(f"<th>{metric_name}</th>")
284
+ html.append(f"<thead><tr>{''.join(header_cells)}</tr></thead>")
285
+
286
+ # Generate value row
287
+ html.append("<tbody>")
288
+ value_cells = []
289
+ for metric_value in metrics.values():
290
+ if isinstance(metric_value, float):
291
+ metric_value = f"{metric_value:.4f}"
292
+ value_cells.append(f"<td>{metric_value}</td>")
293
+ html.append(f"<tr>{''.join(value_cells)}</tr>")
294
+ html.append("</tbody>")
295
+
296
+ html.append("</table>")
297
+ return "\n".join(html)
298
+
299
+ def _get_primary_metric(self) -> dict:
300
+ """Get primary metric from evaluation metrics.
301
+
302
+ :returns: Primary metric info
303
+ :rtype: dict
304
+ """
305
+ primary_metric = {
306
+ "name": None,
307
+ "value": None,
308
+ "rounded_value": None,
309
+ }
310
+
311
+ eval_metrics = self.info.get("evaluation_metrics", {})
312
+ primary_metric_name = self.info.get("primary_metric")
313
+ primary_metric_value = eval_metrics.get(primary_metric_name)
314
+
315
+ if primary_metric_name is None or primary_metric_value is None:
316
+ logger.debug(
317
+ f"Primary metric is not found in evaluation metrics: {eval_metrics}"
318
+ )
319
+ return primary_metric
320
+
321
+ primary_metric = {
322
+ "name": primary_metric_name,
323
+ "value": primary_metric_value,
324
+ "rounded_value": round(primary_metric_value, 3),
325
+ }
326
+ return primary_metric
327
+
328
+ def _get_display_metrics(self, task_type: str) -> list:
329
+ """Get first 5 metrics for display (excluding primary metric).
330
+
331
+ :param primary_metric: Primary metric info
332
+ :type primary_metric: dict
333
+ :returns: List of tuples (metric_name, metric_value) for display
334
+ :rtype: list
335
+ """
336
+ display_metrics = []
337
+ eval_metrics = self.info.get("evaluation_metrics", {})
338
+ if not eval_metrics:
339
+ return display_metrics
340
+
341
+ main_metrics = []
342
+ if (
343
+ task_type == TaskType.OBJECT_DETECTION
344
+ or task_type == TaskType.INSTANCE_SEGMENTATION
345
+ ):
346
+ main_metrics = ["mAP", "AP75", "AP50", "precision", "recall"]
347
+ elif task_type == TaskType.SEMANTIC_SEGMENTATION:
348
+ main_metrics = ["mIoU", "mPixel", "mPrecision", "mRecall", "mF1"]
349
+ else:
350
+ raise NotImplementedError(f"Task type '{task_type}' is not supported")
351
+
352
+ for metric_name in main_metrics:
353
+ if metric_name in eval_metrics:
354
+ metric_value = eval_metrics[metric_name]
355
+ value = round(metric_value, 3)
356
+ percent_value = round(metric_value * 100, 3)
357
+ display_metrics.append(
358
+ {
359
+ "name": metric_name,
360
+ "value": value,
361
+ "percent_value": percent_value,
362
+ }
363
+ )
364
+
365
+ return display_metrics
366
+
367
+ def _generate_checkpoints_table(self) -> str:
368
+ """Generate HTML table with checkpoint information.
369
+
370
+ :returns: HTML string with checkpoints table
371
+ :rtype: str
372
+ """
373
+ pytorch_checkpoints = self.info.get("checkpoints", None)
374
+ if pytorch_checkpoints is None:
375
+ raise ValueError("Checkpoints are not found in experiment info")
376
+
377
+ checkpoints = pytorch_checkpoints.copy()
378
+ export = self.info.get("export", {})
379
+ if export:
380
+ onnx_checkpoint = export.get(RuntimeType.ONNXRUNTIME)
381
+ trt_checkpoint = export.get(RuntimeType.TENSORRT)
382
+ if onnx_checkpoint is not None:
383
+ checkpoints.append(onnx_checkpoint)
384
+ if trt_checkpoint is not None:
385
+ checkpoints.append(trt_checkpoint)
386
+
387
+ checkpoint_paths = [
388
+ os.path.join(self.artifacts_dir, ckpt) for ckpt in checkpoints
389
+ ]
390
+ checkpoint_infos = [
391
+ self.api.file.get_info_by_path(self.team_id, path)
392
+ for path in checkpoint_paths
393
+ ]
394
+ checkpoint_sizes = [
395
+ f"{info.sizeb / 1024 / 1024:.2f} MB" for info in checkpoint_infos
396
+ ]
397
+ checkpoint_dl_links = [
398
+ f"<a href='{info.full_storage_url}' download='{sly_fs.get_file_name_with_ext(info.path)}'>Download</a>"
399
+ for info in checkpoint_infos
400
+ ]
401
+
402
+ html = ['<table class="table">']
403
+ html.append(
404
+ "<thead><tr><th>Checkpoint</th><th>Size</th><th> </th></tr></thead>"
405
+ )
406
+ html.append("<tbody>")
407
+ for checkpoint, size, dl_link in zip(
408
+ checkpoints, checkpoint_sizes, checkpoint_dl_links
409
+ ):
410
+ if isinstance(checkpoint, str):
411
+ html.append(
412
+ f"<tr><td>{os.path.basename(checkpoint)}</td><td>{size}</td><td>{dl_link}</td></tr>"
413
+ )
414
+ html.append("</tbody>")
415
+ html.append("</table>")
416
+ return "\n".join(html)
417
+
418
+ def _generate_hyperparameters_yaml(self) -> str:
419
+ """Return hyperparameters as YAML string.
420
+
421
+ :returns: YAML string with hyperparameters
422
+ :rtype: str
423
+ """
424
+ if self.hyperparameters is not None:
425
+ if not isinstance(self.hyperparameters, str):
426
+ raise ValueError("Hyperparameters must be a yaml string")
427
+ hyperparameters = self.hyperparameters.split("\n")
428
+ return hyperparameters
429
+ return None
430
+
431
+ def _get_log_viewer_app_info(self):
432
+ """Get log viewer app information.
433
+
434
+ :returns: Log viewer app info
435
+ :rtype: dict
436
+ """
437
+ slug = "supervisely-ecosystem/tensorboard-experiments-viewer"
438
+ module_id = self.api.app.get_ecosystem_module_id(slug)
439
+ return {"slug": slug, "module_id": module_id}
440
+
441
+ def _get_app_info(self):
442
+ """Get app information from task.
443
+
444
+ :returns: App info object or None if not found
445
+ :rtype: Optional[Any]
446
+ """
447
+ task_id = self.info.get("task_id", None)
448
+ if task_id is None or task_id == -1:
449
+ return None
450
+
451
+ task_info = self.api.task.get_info_by_id(task_id)
452
+ app_id = task_info["meta"]["app"]["id"]
453
+ return self.api.app.get_info_by_id(app_id)
454
+
455
+ def _get_training_session(self) -> dict:
456
+ """Get training session information.
457
+
458
+ :returns: Training session info
459
+ :rtype: dict
460
+ """
461
+ task_id = self.info["task_id"]
462
+ if task_id is None or task_id == -1:
463
+ training_session = {
464
+ "id": None,
465
+ "url": None,
466
+ }
467
+ return training_session
468
+
469
+ training_session = {
470
+ "id": task_id,
471
+ "url": f"{self.api.server_address}/apps/sessions/{task_id}",
472
+ }
473
+ return training_session
474
+
475
+ def _get_training_duration(self) -> str:
476
+ """Get training duration.
477
+
478
+ :returns: Training duration in format "{h}h {m}m" or "N/A"
479
+ :rtype: str
480
+ """
481
+ raw_duration = self.info.get("training_duration", "N/A")
482
+ if raw_duration in (None, "N/A"):
483
+ return "N/A"
484
+
485
+ try:
486
+ duration_seconds = float(raw_duration)
487
+ except (TypeError, ValueError):
488
+ return str(raw_duration)
489
+
490
+ hours = int(duration_seconds // 3600)
491
+ minutes = int((duration_seconds % 3600) // 60)
492
+ seconds = int(duration_seconds % 60)
493
+ return f"{hours}h {minutes}m {seconds}s"
494
+
495
+ def _get_date(self) -> str:
496
+ """Format experiment date.
497
+
498
+ :returns: Formatted date string
499
+ :rtype: str
500
+ """
501
+ date_str = self.info.get("datetime", "")
502
+ date = date_str
503
+ if date_str:
504
+ try:
505
+ dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S")
506
+ date = dt.strftime("%d %b %Y")
507
+ except ValueError:
508
+ pass
509
+ return date
510
+
511
+ def _get_best_checkpoint(self) -> dict:
512
+ """Get best checkpoint filename.
513
+
514
+ :returns: Best checkpoint info
515
+ :rtype: dict
516
+ """
517
+ best_checkpoint_path = os.path.join(
518
+ self.artifacts_dir, "checkpoints", self.info["best_checkpoint"]
519
+ )
520
+ best_checkpoint_info = self.api.file.get_info_by_path(
521
+ self.team_id, best_checkpoint_path
522
+ )
523
+ best_checkpoint = {
524
+ "name": self.info["best_checkpoint"],
525
+ "path": best_checkpoint_path,
526
+ "url": best_checkpoint_info.full_storage_url,
527
+ "size": f"{best_checkpoint_info.sizeb / 1024 / 1024:.1f} MB",
528
+ }
529
+ return best_checkpoint
530
+
531
def _get_optimized_checkpoints(self) -> Tuple[dict, dict]:
    """Describe the exported (ONNX and TensorRT) checkpoints.

    Exported checkpoint paths come from ``info["export"]`` and are relative to
    the artifacts directory. For each one, a ``classes.json`` file in the same
    directory is looked up as well.

    :returns: Tuple ``(onnx_checkpoint_data, trt_checkpoint_data)``. Each dict has
        ``name``, ``path``, ``size``, ``url`` and ``classes_url``; all are ``None``
        when the checkpoint was not exported, and ``size``/``url`` are ``None``
        when the exported file cannot be found
    :rtype: Tuple[dict, dict]
    """
    # "export" may be absent or explicitly None.
    export = self.info.get("export") or {}

    def _describe(relative_path) -> dict:
        # Build the description dict for a single exported checkpoint.
        data = {"name": None, "path": None, "size": None, "url": None, "classes_url": None}
        if relative_path is None:
            return data
        checkpoint_path = os.path.join(self.artifacts_dir, relative_path)
        classes_file = self.api.file.get_info_by_path(
            self.team_id,
            os.path.join(os.path.dirname(checkpoint_path), "classes.json"),
        )
        file_info = self.api.file.get_info_by_path(self.team_id, checkpoint_path)
        data["name"] = os.path.basename(relative_path)
        data["path"] = checkpoint_path
        if file_info is not None:
            data["size"] = f"{file_info.sizeb / 1024 / 1024:.1f} MB"
            data["url"] = file_info.full_storage_url
        data["classes_url"] = classes_file.full_storage_url if classes_file else None
        return data

    onnx_checkpoint_data = _describe(export.get(RuntimeType.ONNXRUNTIME))
    trt_checkpoint_data = _describe(export.get(RuntimeType.TENSORRT))
    return onnx_checkpoint_data, trt_checkpoint_data
593
+
594
+ def _get_docker_image(self) -> str:
595
+ """Get Docker image for model.
596
+
597
+ :returns: Docker image name
598
+ :rtype: str
599
+ """
600
+ if self.app_info is None:
601
+ return None
602
+
603
+ docker_image = self.app_info.config["docker_image"]
604
+ if not docker_image:
605
+ raise ValueError("Docker image is not found in app config")
606
+ return docker_image
607
+
608
+ def _get_repository_info(self) -> Dict[str, str]:
609
+ """Get repository information.
610
+
611
+ :returns: Dictionary with repo URL and name
612
+ :rtype: Dict[str, str]
613
+ """
614
+ if self.app_info is None:
615
+ return {"url": None, "name": None}
616
+
617
+ framework_name = self.info["framework_name"]
618
+ if hasattr(self.app_info, "repo"):
619
+ repo_link = self.app_info.repo
620
+ repo_name = repo_link.split("/")[-1]
621
+ return {"url": repo_link, "name": repo_name}
622
+
623
+ # @TODO: for unreleased apps
624
+ repo_name = framework_name.replace(" ", "-")
625
+ repo_link = f"https://github.com/supervisely-ecosystem/{repo_name}"
626
+ return {"url": repo_link, "name": repo_name}
627
+
628
+ # @TODO: method not used (might be helpful for unreleased apps)
629
+ def _find_app_config(self):
630
+ """
631
+ Find app config.json in project structure.
632
+
633
+ :returns: Config dictionary or None if not found
634
+ :rtype: Optional[Dict[str, Any]]
635
+ """
636
+ try:
637
+ current_dir = Path(os.path.abspath(os.path.dirname(__file__)))
638
+ root_dir = current_dir
639
+
640
+ while root_dir.parent != root_dir:
641
+ config_path = (
642
+ root_dir / "supervisely_integration" / "train" / "config.json"
643
+ )
644
+ if config_path.exists():
645
+ break
646
+ root_dir = root_dir.parent
647
+
648
+ config_path = root_dir / "supervisely_integration" / "train" / "config.json"
649
+ if config_path.exists():
650
+ return sly_json.load_json_file(config_path)
651
+
652
+ except Exception as e:
653
+ logger.warning(f"Failed to load config.json: {e}")
654
+ return None
655
+
656
+ def _get_sample_predictions_gallery(self):
657
+ evaluation_report_id = self.info.get("evaluation_report_id")
658
+ if evaluation_report_id is None:
659
+ return None
660
+ benchmark_file_info = self.api.file.get_info_by_id(evaluation_report_id)
661
+ evaluation_report_path = os.path.dirname(benchmark_file_info.path)
662
+ if os.path.basename(evaluation_report_path) != "visualizations":
663
+ logger.debug(
664
+ f"Visualizations directory is not found in the report directory: '{evaluation_report_path}'"
665
+ )
666
+ return None
667
+
668
+ evaluation_report_data_path = os.path.join(evaluation_report_path, "data")
669
+
670
+ seek_file = "explore_predictions_gallery_widget"
671
+ remote_gallery_widget_json = None
672
+ for file in self.api.file.list(
673
+ self.team_id, evaluation_report_data_path, False, "fileinfo"
674
+ ):
675
+ if file.name.startswith(seek_file) and file.name.endswith(".json"):
676
+ remote_gallery_widget_json = file.path
677
+ break
678
+
679
+ if remote_gallery_widget_json is None:
680
+ logger.debug(
681
+ f"Gallery widget is not found in the report directory: '{evaluation_report_path}'"
682
+ )
683
+ return None
684
+
685
+ save_path = os.path.join(
686
+ self.output_dir, "data", "explore_predictions_gallery_widget_expmt.json"
687
+ )
688
+ self.api.file.download(self.team_id, remote_gallery_widget_json, save_path)
689
+
690
+ widget_html = """
691
+ <sly-iw-gallery ref="gallery_widget_expmt" iw-widget-id="gallery_widget_expmt" :options="{'isModalWindow': false}"
692
+ :actions="{
693
+ 'init': {
694
+ 'dataSource': '/data/explore_predictions_gallery_widget_expmt.json',
695
+ },
696
+
697
+ }" :command="command" :data="data">
698
+ </sly-iw-gallery>
699
+ """
700
+ return widget_html
701
+
702
def _get_demo_scripts(self):
    """Read the demo inference scripts of the training app.

    The demo directory comes from ``app_options["demo"]["path"]`` and is resolved
    relative to the current working directory. Files ending with
    ``demo_pytorch.py``, ``demo_onnx.py`` and ``demo_tensorrt.py`` are read.

    :returns: Tuple ``(pytorch_demo, onnx_demo, trt_demo)``; each is a dict with
        ``path`` and ``script`` (both ``None`` when the script is not found)
    :rtype: Tuple[dict, dict, dict]
    """
    suffix_to_runtime = {
        "demo_pytorch.py": "pytorch",
        "demo_onnx.py": "onnx",
        "demo_tensorrt.py": "tensorrt",
    }
    demos = {
        runtime: {"path": None, "script": None}
        for runtime in suffix_to_runtime.values()
    }

    def _result():
        return demos["pytorch"], demos["onnx"], demos["tensorrt"]

    # app_options defaults to None in __init__, and "demo" may be absent or None.
    demo_options = (self.app_options or {}).get("demo") or {}
    demo_path = demo_options.get("path")
    if demo_path is None:
        logger.debug("Demo path is not found in app options")
        return _result()

    local_demo_dir = os.path.join(os.getcwd(), demo_path)
    if not sly_fs.dir_exists(local_demo_dir):
        logger.debug(f"Demo directory '{local_demo_dir}' does not exist")
        return _result()

    for file in sly_fs.list_files(local_demo_dir):
        for suffix, runtime in suffix_to_runtime.items():
            if file.endswith(suffix):
                with open(file, "r", encoding="utf-8") as f:
                    demos[runtime] = {"path": file, "script": f.read()}
                break
    return _result()
743
+
744
+ def _get_app_train_serve_app_info(self):
745
+ """Get app slugs.
746
+
747
+ :returns: App slugs
748
+ :rtype: Tuple[str, str]
749
+ """
750
+
751
+ def find_app_by_framework(
752
+ api: Api, framework: str, action: Literal["train", "serve"]
753
+ ):
754
+ try:
755
+ modules = api.app.get_list_ecosystem_modules(
756
+ categories=[action, f"framework:{framework}"],
757
+ categories_operation="and",
758
+ )
759
+ if len(modules) == 0:
760
+ return None
761
+ return modules[0]
762
+ except Exception as e:
763
+ logger.warning(f"Failed to find {action} app by framework: {e}")
764
+ return None
765
+
766
+ train_app_info = find_app_by_framework(
767
+ self.api, self.info["framework_name"], "train"
768
+ )
769
+ serve_app_info = find_app_by_framework(
770
+ self.api, self.info["framework_name"], "serve"
771
+ )
772
+
773
+ if train_app_info is not None:
774
+ train_app_slug = train_app_info["slug"].replace(
775
+ "supervisely-ecosystem/", ""
776
+ )
777
+ train_app_id = train_app_info["id"]
778
+ else:
779
+ train_app_slug = None
780
+ train_app_id = None
781
+
782
+ if serve_app_info is not None:
783
+ serve_app_slug = serve_app_info["slug"].replace(
784
+ "supervisely-ecosystem/", ""
785
+ )
786
+ serve_app_id = serve_app_info["id"]
787
+ else:
788
+ serve_app_slug = None
789
+ serve_app_id = None
790
+
791
+ train_app = {
792
+ "slug": train_app_slug,
793
+ "module_id": train_app_id,
794
+ }
795
+ serve_app = {
796
+ "slug": serve_app_slug,
797
+ "module_id": serve_app_id,
798
+ }
799
+ return train_app, serve_app
800
+
801
+ def _get_agent_info(self) -> str:
802
+ task_id = self.info.get("task_id", None)
803
+
804
+ agent_info = {
805
+ "name": None,
806
+ "id": None,
807
+ "link": None,
808
+ }
809
+
810
+ if task_id is None or task_id == -1:
811
+ return agent_info
812
+
813
+ task_info = self.api.task.get_info_by_id(task_id)
814
+ if task_info is not None:
815
+ agent_info["name"] = task_info["agentName"]
816
+ agent_info["id"] = task_info["agentId"]
817
+ agent_info["link"] = (
818
+ f"{self.api.server_address}/nodes/{agent_info['id']}/info"
819
+ )
820
+ return agent_info
821
+
822
+ def _get_class_names(self, model_classes: list) -> dict:
823
+ """Get class names from model meta.
824
+
825
+ :returns: List of class names
826
+ :rtype: list
827
+ """
828
+
829
+ return {
830
+ "string": ", ".join(model_classes),
831
+ "short_string": (
832
+ ", ".join(model_classes[:5] + ["..."])
833
+ if len(model_classes) > 5
834
+ else ", ".join(model_classes)
835
+ ),
836
+ "list": model_classes,
837
+ "short_list": (
838
+ model_classes[:3] + ["..."] if len(model_classes) > 3 else model_classes
839
+ ),
840
+ }
841
+
842
+ def _get_app_apply_nn_app_info(self):
843
+ """Get apply NN app info.
844
+
845
+ :returns: Apply NN app info
846
+ :rtype: dict
847
+ """
848
+
849
+ apply_nn_images_slug = "nn-image-labeling/project-dataset"
850
+ apply_nn_images_module_id = self.api.app.get_ecosystem_module_id(
851
+ f"supervisely-ecosystem/{apply_nn_images_slug}"
852
+ )
853
+ apply_nn_videos_slug = "apply-nn-to-videos-project"
854
+ apply_nn_videos_module_id = self.api.app.get_ecosystem_module_id(
855
+ f"supervisely-ecosystem/{apply_nn_videos_slug}"
856
+ )
857
+
858
+ apply_nn_images_app = {
859
+ "slug": apply_nn_images_slug,
860
+ "module_id": apply_nn_images_module_id,
861
+ }
862
+ apply_nn_videos_app = {
863
+ "slug": apply_nn_videos_slug,
864
+ "module_id": apply_nn_videos_module_id,
865
+ }
866
+ return apply_nn_images_app, apply_nn_videos_app
867
+
868
def _get_project_context(self):
    """Build the project section of the experiment template context.

    :returns: Dict with the project ``id``, ``name``, ``url``, ``type``,
        train/val ``splits`` sizes (``"N/A"`` when unknown) and the model
        ``classes`` (count and name representations). If the project info
        cannot be fetched, ``name`` is ``"Project was archived"`` and
        ``type`` is ``None``.
    :rtype: dict
    """
    project_id = self.info["project_id"]
    project_info = self.api.project.get_info_by_id(project_id)
    # The project may no longer be available (get_info_by_id returns None);
    # guard every attribute access instead of only the name.
    if project_info is not None:
        project_name = project_info.name
        project_type = project_info.type
    else:
        project_name = "Project was archived"
        project_type = None

    model_classes = [cls.name for cls in self.model_meta.obj_classes]
    class_names = self._get_class_names(model_classes)

    return {
        "id": project_id,
        "name": project_name,
        "url": f"{self.api.server_address}/projects/{project_id}/datasets",
        "type": project_type,
        "splits": {
            "train": self.info.get("train_size", "N/A"),
            "val": self.info.get("val_size", "N/A"),
        },
        "classes": {
            "count": len(model_classes),
            "names": class_names,
        },
    }
893
+
894
def _get_base_checkpoint_info(self):
    """Describe the checkpoint the training started from.

    Values come from ``self.info["base_checkpoint"]`` and
    ``self.info["base_checkpoint_link"]``. When the link points into Team
    Files (starts with ``"/experiments/"``), the file is looked up to get its
    name, storage URL and a Team Files browser URL.

    :returns: Dict with keys ``"name"`` (``"N/A"`` when unknown), ``"url"``
        and ``"path"`` (``None`` unless the Team Files lookup succeeds).
    :rtype: dict
    """
    name = self.info.get("base_checkpoint", "N/A")
    url = self.info.get("base_checkpoint_link", None)
    path = None
    if url is not None and url.startswith("/experiments/"):
        file_info = self.api.file.get_info_by_path(self.team_id, url)
        # The checkpoint may have been removed from Team Files; keep the values
        # from self.info rather than failing on attribute access of None.
        if file_info is not None:
            name = file_info.name
            url = file_info.full_storage_url
            path = f"{self.api.server_address}/files/?path={file_info.path}"

    return {
        "name": name,
        "url": url,
        "path": path,
    }
911
+
912
def _get_model_context(self):
    """Collect the model section of the experiment template context.

    :returns: Dict with the model ``name``, ``framework``, ``base_checkpoint``
        description (from ``_get_base_checkpoint_info``) and ``task_type``.
    :rtype: dict
    """
    info = self.info
    context = {}
    context["name"] = info["model_name"]
    context["framework"] = info["framework_name"]
    context["base_checkpoint"] = self._get_base_checkpoint_info()
    context["task_type"] = info["task_type"]
    return context
920
+
921
def _get_training_context(self):
    """Return training-related context (checkpoints, metrics, etc.).

    :returns: Dict with ``device`` (``"N/A"`` when unknown), training
        ``session``, ``duration``, ``hyperparameters``, PyTorch/ONNX/TensorRT
        ``checkpoints``, ``export`` info, training ``logs`` (Team Files path
        and browser URL, both ``None`` when absent) and ``evaluation``
        (report id/url, primary metric, display metrics and raw metrics).
    :rtype: dict
    """
    info = self.info
    device = info.get("device", "N/A")
    training_session = self._get_training_session()
    training_duration = self._get_training_duration()
    hyperparameters = self._generate_hyperparameters_yaml()

    best_checkpoint = self._get_best_checkpoint()
    onnx_checkpoint, trt_checkpoint = self._get_optimized_checkpoints()

    # "logs" may be missing OR present with value None; treat both as "no logs"
    # (a plain .get("logs", {}) would return None for the latter and crash).
    logs_path = (info.get("logs") or {}).get("link")
    logs_url = f"{self.api.server_address}/files/?path={logs_path}" if logs_path else None

    primary_metric = self._get_primary_metric()
    display_metrics = self._get_display_metrics(info["task_type"])

    return {
        "device": device,
        "session": training_session,
        "duration": training_duration,
        "hyperparameters": hyperparameters,
        "checkpoints": {
            "pytorch": best_checkpoint,
            "onnx": onnx_checkpoint,
            "tensorrt": trt_checkpoint,
        },
        "export": info.get("export"),
        "logs": {"path": logs_path, "url": logs_url},
        "evaluation": {
            "id": info.get("evaluation_report_id"),
            "url": info.get("evaluation_report_link"),
            "primary_metric": primary_metric,
            "display_metrics": display_metrics,
            "metrics": info.get("evaluation_metrics"),
        },
    }
960
+
961
def _get_experiment_context(self):
    """Assemble the full experiment context used to render the report template.

    :returns: Dict with the task id, experiment name, agent, date, project,
        model and training sections, plus Team Files ``paths``. ``paths`` has
        entries for ``experiment_dir`` (directory name), ``artifacts_dir`` and
        ``checkpoints_dir``, each with a ``path`` and a Team Files browser ``url``.
    :rtype: dict
    """
    task_id = self.info["task_id"]
    exp_name = self.info["experiment_name"]
    agent_info = self._get_agent_info()
    date = self._get_date()
    project_context = self._get_project_context()
    model_context = self._get_model_context()
    training_context = self._get_training_context()

    # Team Files paths are always "/"-separated regardless of the local OS,
    # so build them with string operations rather than os.path.
    artifacts_dir = self.info["artifacts_dir"].rstrip("/")
    experiment_dir = artifacts_dir.rsplit("/", 1)[-1]
    checkpoints_dir = f"{artifacts_dir}/checkpoints"

    def _files_url(path: str) -> str:
        # Team Files browser URL for a directory (always with a trailing slash).
        return f"{self.api.server_address}/files/?path={path.rstrip('/') + '/'}"

    return {
        "task_id": task_id,
        "name": exp_name,
        "agent": agent_info,
        "date": date,
        "project": project_context,
        "model": model_context,
        "training": training_context,
        "paths": {
            "experiment_dir": {
                "path": experiment_dir,
                # experiment_dir is only the directory name; the browser URL
                # needs the full Team Files path of that directory.
                "url": _files_url(artifacts_dir),
            },
            "artifacts_dir": {
                "path": artifacts_dir,
                "url": _files_url(artifacts_dir),
            },
            "checkpoints_dir": {
                "path": checkpoints_dir,
                "url": _files_url(checkpoints_dir),
            },
        },
    }