supervisely 6.73.438__py3-none-any.whl → 6.73.513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +137 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/annotation.py +8 -2
- supervisely/annotation/json_geometries_map.py +14 -11
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +14 -11
- supervisely/api/api.py +59 -38
- supervisely/api/app_api.py +11 -2
- supervisely/api/dataset_api.py +74 -12
- supervisely/api/entities_collection_api.py +10 -0
- supervisely/api/entity_annotation/figure_api.py +52 -4
- supervisely/api/entity_annotation/object_api.py +3 -3
- supervisely/api/entity_annotation/tag_api.py +63 -12
- supervisely/api/guides_api.py +210 -0
- supervisely/api/image_api.py +72 -1
- supervisely/api/labeling_job_api.py +83 -1
- supervisely/api/labeling_queue_api.py +33 -7
- supervisely/api/module_api.py +9 -0
- supervisely/api/project_api.py +71 -26
- supervisely/api/storage_api.py +3 -1
- supervisely/api/task_api.py +13 -2
- supervisely/api/team_api.py +4 -3
- supervisely/api/video/video_annotation_api.py +119 -3
- supervisely/api/video/video_api.py +65 -14
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +23 -7
- supervisely/app/development/development.py +18 -2
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +105 -0
- supervisely/app/fastapi/subapp.py +88 -42
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +6 -0
- supervisely/app/widgets/activity_feed/__init__.py +0 -0
- supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
- supervisely/app/widgets/activity_feed/style.css +78 -0
- supervisely/app/widgets/activity_feed/template.html +22 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
- supervisely/app/widgets/classes_list_selector/template.html +60 -93
- supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
- supervisely/app/widgets/classes_table/classes_table.py +1 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +184 -60
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +564 -0
- supervisely/app/widgets/heatmap/script.js +533 -0
- supervisely/app/widgets/heatmap/style.css +233 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/modal/__init__.py +0 -0
- supervisely/app/widgets/modal/modal.py +198 -0
- supervisely/app/widgets/modal/template.html +10 -0
- supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -3
- supervisely/app/widgets/select_class/__init__.py +0 -0
- supervisely/app/widgets/select_class/select_class.py +363 -0
- supervisely/app/widgets/select_class/template.html +50 -0
- supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
- supervisely/app/widgets/select_tag/__init__.py +0 -0
- supervisely/app/widgets/select_tag/select_tag.py +352 -0
- supervisely/app/widgets/select_tag/template.html +64 -0
- supervisely/app/widgets/select_team/select_team.py +37 -4
- supervisely/app/widgets/select_team/template.html +4 -5
- supervisely/app/widgets/select_user/__init__.py +0 -0
- supervisely/app/widgets/select_user/select_user.py +270 -0
- supervisely/app/widgets/select_user/template.html +13 -0
- supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
- supervisely/app/widgets/select_workspace/template.html +9 -12
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/aug/aug.py +6 -2
- supervisely/convert/base_converter.py +1 -0
- supervisely/convert/converter.py +2 -2
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/image/image_converter.py +3 -1
- supervisely/convert/image/image_helper.py +48 -4
- supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
- supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
- supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
- supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
- supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
- supervisely/convert/pointcloud/las/las_converter.py +13 -1
- supervisely/convert/pointcloud/las/las_helper.py +110 -11
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
- supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
- supervisely/convert/video/__init__.py +1 -0
- supervisely/convert/video/multi_view/__init__.py +0 -0
- supervisely/convert/video/multi_view/multi_view.py +543 -0
- supervisely/convert/video/sly/sly_video_converter.py +359 -3
- supervisely/convert/video/video_converter.py +24 -4
- supervisely/convert/volume/dicom/dicom_converter.py +13 -5
- supervisely/convert/volume/dicom/dicom_helper.py +30 -18
- supervisely/geometry/constants.py +1 -0
- supervisely/geometry/geometry.py +4 -0
- supervisely/geometry/helpers.py +5 -1
- supervisely/geometry/oriented_bbox.py +676 -0
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/geometry/rectangle.py +2 -1
- supervisely/io/env.py +76 -1
- supervisely/io/fs.py +21 -0
- supervisely/nn/benchmark/base_evaluator.py +104 -11
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
- supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
- supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
- supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
- supervisely/nn/inference/cache.py +43 -18
- supervisely/nn/inference/gui/serving_gui_template.py +5 -2
- supervisely/nn/inference/inference.py +916 -222
- supervisely/nn/inference/inference_request.py +55 -10
- supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/inference/session.py +43 -35
- supervisely/nn/inference/tracking/bbox_tracking.py +118 -35
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +10 -1
- supervisely/nn/inference/uploader.py +139 -12
- supervisely/nn/live_training/__init__.py +7 -0
- supervisely/nn/live_training/api_server.py +111 -0
- supervisely/nn/live_training/artifacts_utils.py +243 -0
- supervisely/nn/live_training/checkpoint_utils.py +229 -0
- supervisely/nn/live_training/dynamic_sampler.py +44 -0
- supervisely/nn/live_training/helpers.py +14 -0
- supervisely/nn/live_training/incremental_dataset.py +146 -0
- supervisely/nn/live_training/live_training.py +497 -0
- supervisely/nn/live_training/loss_plateau_detector.py +111 -0
- supervisely/nn/live_training/request_queue.py +52 -0
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/model/prediction.py +2 -1
- supervisely/nn/model/prediction_session.py +26 -14
- supervisely/nn/prediction_dto.py +19 -1
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
- supervisely/nn/tracker/botsort_tracker.py +94 -65
- supervisely/nn/tracker/utils.py +4 -5
- supervisely/nn/tracker/visualize.py +93 -93
- supervisely/nn/training/gui/classes_selector.py +16 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
- supervisely/nn/training/train_app.py +46 -31
- supervisely/project/data_version.py +115 -51
- supervisely/project/download.py +1 -1
- supervisely/project/pointcloud_episode_project.py +37 -8
- supervisely/project/pointcloud_project.py +30 -2
- supervisely/project/project.py +14 -2
- supervisely/project/project_meta.py +27 -1
- supervisely/project/project_settings.py +32 -18
- supervisely/project/versioning/__init__.py +1 -0
- supervisely/project/versioning/common.py +20 -0
- supervisely/project/versioning/schema_fields.py +35 -0
- supervisely/project/versioning/video_schema.py +221 -0
- supervisely/project/versioning/volume_schema.py +87 -0
- supervisely/project/video_project.py +717 -15
- supervisely/project/volume_project.py +623 -5
- supervisely/template/experiment/experiment.html.jinja +4 -4
- supervisely/template/experiment/experiment_generator.py +14 -21
- supervisely/template/live_training/__init__.py +0 -0
- supervisely/template/live_training/header.html.jinja +96 -0
- supervisely/template/live_training/live_training.html.jinja +51 -0
- supervisely/template/live_training/live_training_generator.py +464 -0
- supervisely/template/live_training/sly-style.css +402 -0
- supervisely/template/live_training/template.html.jinja +18 -0
- supervisely/versions.json +28 -26
- supervisely/video/sampling.py +39 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/video_annotation/video_object.py +29 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/METADATA +58 -40
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/RECORD +203 -155
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
- supervisely_lib/__init__.py +6 -1
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import math
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Dict, Optional
|
|
9
|
+
import plotly.graph_objects as go # pylint: disable=import-error
|
|
10
|
+
from plotly.subplots import make_subplots # pylint: disable=import-error
|
|
11
|
+
|
|
12
|
+
import supervisely as sly
|
|
13
|
+
from supervisely import Api, ProjectMeta, logger
|
|
14
|
+
from supervisely.template.base_generator import BaseGenerator
|
|
15
|
+
from supervisely.imaging.color import rgb2hex
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LiveTrainingGenerator(BaseGenerator):
    """
    Generator for Live training session reports.

    Logs:
    - Model hyperparameters
    - Training loss graphs
    - Checkpoints
    - Dataset size over time
    """

    def __init__(
        self,
        api: Api,
        session_info: dict,
        model_config: dict,
        model_meta: ProjectMeta,
        task_type: str,
        output_dir: str = "./live_training_report",
        team_id: Optional[int] = None,
    ):
        """
        Initialize Live training generator.

        :param api: Supervisely API instance
        :param session_info: Session metadata. Must contain ``session_id``, ``project_id`` and
            ``start_time``; other keys read by this class (``artifacts_dir``, ``logs_dir``,
            ``checkpoints``, ``loss_history``, ``dataset_size``, ...) are optional.
        :param model_config: Model configuration (hyperparameters, backbone, etc.)
        :param model_meta: Model metadata with classes
        :param task_type: ``"semantic segmentation"`` or ``"object detection"``; selects the
            ecosystem app slug used for the report's app buttons.
        :param output_dir: Local output directory
        :param team_id: Team ID; falls back to ``sly.env.team_id()`` when not given.
        :raises KeyError: If ``task_type`` is not supported.
        :raises ValueError: If ``session_info`` lacks a required field.
        """
        super().__init__(api, output_dir)
        self.team_id = team_id or sly.env.team_id()
        self.session_info = session_info
        self.model_config = model_config
        self.model_meta = model_meta
        self.task_type = task_type
        self._slug_map = {
            "semantic segmentation": "supervisely-ecosystem/live-training---semantic-segmentation",
            "object detection": "supervisely-ecosystem/live-training---object-detection",
        }
        try:
            self.slug = self._slug_map[task_type]
        except KeyError:
            # Keep the KeyError type callers already see, but make the message actionable.
            raise KeyError(
                f"Unsupported task_type {task_type!r}; expected one of {sorted(self._slug_map)}"
            ) from None

        # Validate required fields
        self._validate_session_info()

    def _validate_session_info(self):
        """Validate that session_info contains required fields.

        :raises ValueError: If any of ``session_id``, ``project_id``, ``start_time`` is missing.
        """
        required = ["session_id", "project_id", "start_time"]
        missing = [k for k in required if k not in self.session_info]
        if missing:
            raise ValueError(f"Missing required fields in session_info: {missing}")

    def _report_url(self, server_address: str, template_id: int) -> str:
        """Generate URL to open the Live training report."""
        return f"{server_address}/nn/experiments/{template_id}"

    def context(self) -> dict:
        """Build the full template rendering context for the report."""
        return {
            "env": self._get_env_context(),
            "session": self._get_session_context(),
            "model": self._get_model_context(),
            "training": self._get_training_context(),
            "dataset": self._get_dataset_context(),
            "widgets": self._get_widgets_context(),
            "resources": self._get_resources_context(),
        }

    def _get_env_context(self) -> dict:
        """Environment info."""
        return {
            "server_address": self.api.server_address,
        }

    def _get_session_context(self) -> dict:
        """Session info, including the project looked up via ``api.project.get_info_by_id``."""
        session_id = self.session_info["session_id"]
        project_id = self.session_info["project_id"]
        artifacts_dir = self.session_info.get("artifacts_dir", "")
        task_id = self.session_info.get("task_id", session_id)

        project_info = self.api.project.get_info_by_id(project_id)
        project_url = f"{self.api.server_address}/projects/{project_id}/datasets"
        artifacts_url = f"{self.api.server_address}/files/?path={artifacts_dir}" if artifacts_dir else None

        return {
            "id": session_id,
            "task_id": task_id,
            "name": self.session_info.get("session_name", f"Session {session_id}"),
            "start_time": self.session_info["start_time"],
            "duration": self.session_info.get("duration"),
            "current_iteration": self.session_info.get("current_iteration", 0),
            "artifacts_url": artifacts_url,
            "artifacts_dir": artifacts_dir,
            "project": {
                "id": project_id,
                "name": project_info.name if project_info else "Unknown",
                "url": project_url if project_info else None,
            },
            "status": self.session_info.get("status", "running"),
        }

    @staticmethod
    def parse_hyperparameters(config_path: str) -> dict:
        """
        Parse hyperparameters from MMEngine config file.

        Only a handful of well-known assignments are extracted with regular expressions;
        the config is never executed. Keys are matched on word boundaries so that e.g.
        ``min_lr=`` / ``base_batch_size=`` are not mistaken for ``lr=`` / ``batch_size=``,
        and whitespace around ``=`` is tolerated.

        :param config_path: Path to config.py
        :return: Dict with extracted hyperparameters (any subset of ``crop_size``,
            ``learning_rate``, ``batch_size``, ``max_epochs``, ``weight_decay``,
            ``optimizer``); empty if the file does not exist.
        """
        # TODO: only basic parsing for segmentation
        hyperparams = {}

        if not os.path.isfile(config_path):
            return hyperparams

        with open(config_path, "r", encoding="utf-8", errors="replace") as f:
            content = f.read()

        # A syntactically valid float literal (e.g. 0.01, .5, 1e-4, 5.0E-04) so float() never fails.
        num = r"[-+]?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?"

        # Extract crop_size
        match = re.search(r"\bcrop_size\s*=\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", content)
        if match:
            hyperparams["crop_size"] = f"({match.group(1)}, {match.group(2)})"

        # Extract learning rate
        match = re.search(r"\blr\s*=\s*(" + num + r")", content)
        if match:
            hyperparams["learning_rate"] = float(match.group(1))

        # Extract batch_size
        match = re.search(r"\bbatch_size\s*=\s*(\d+)", content)
        if match:
            hyperparams["batch_size"] = int(match.group(1))

        # Extract max_epochs
        match = re.search(r"\bmax_epochs\s*=\s*(\d+)", content)
        if match:
            hyperparams["max_epochs"] = int(match.group(1))

        # Extract weight_decay
        match = re.search(r"\bweight_decay\s*=\s*(" + num + r")", content)
        if match:
            hyperparams["weight_decay"] = float(match.group(1))

        # Extract optimizer type (single or double quotes)
        match = re.search(r"\boptimizer\s*=\s*dict\(\s*type\s*=\s*['\"](\w+)['\"]", content)
        if match:
            hyperparams["optimizer"] = match.group(1)

        return hyperparams

    def _get_model_context(self) -> dict:
        """Model configuration info (the ``_background_`` class is excluded from counts)."""
        classes = [cls.name for cls in self.model_meta.obj_classes if cls.name != "_background_"]
        display_name = self.model_config.get("display_name", self.model_config.get("model_name", "Unknown"))

        return {
            "name": display_name,
            "backbone": self.model_config.get("backbone", "N/A"),
            "num_classes": len(classes),
            "classes": classes,
            "classes_short": classes[:3] + (["..."] if len(classes) > 3 else []),
            "config_file": self.model_config.get("config_file", "N/A"),
            "task_type": self.model_config.get("task_type", "Live Training"),
        }

    def _collect_checkpoints(self) -> list:
        """Return checkpoint records from ``session_info`` enriched with a Team Files URL.

        The URL is ``None`` when no ``artifacts_dir`` is known, since the checkpoint
        location cannot be resolved without it.
        """
        artifacts_dir = (self.session_info.get("artifacts_dir") or "").rstrip("/")
        checkpoints = []
        for ckpt in self.session_info.get("checkpoints") or []:
            url = None
            if artifacts_dir:
                url = f"{self.api.server_address}/files/?path={artifacts_dir}/checkpoints/{ckpt['name']}"
            checkpoints.append(
                {
                    "name": ckpt["name"],
                    "iteration": ckpt["iteration"],
                    "loss": ckpt.get("loss"),
                    "url": url,
                }
            )
        return checkpoints

    @staticmethod
    def _total_iterations(loss_history, raw_checkpoints) -> int:
        """Derive the total number of iterations seen so far.

        ``loss_history`` is either the old list format (records with an ``"iteration"`` key,
        the last one is used) or the new dict format (metric name -> records with a
        ``"step"`` key, the maximum step is used). If it yields nothing, the maximum
        checkpoint ``"iteration"`` is used; otherwise 0.
        """
        if isinstance(loss_history, list) and loss_history:
            return loss_history[-1]["iteration"]
        if isinstance(loss_history, dict):
            steps = [item["step"] for metric_data in loss_history.values() for item in metric_data]
            if steps:
                return max(steps)
        if raw_checkpoints:
            return max(c["iteration"] for c in raw_checkpoints)
        return 0

    def _get_training_context(self) -> dict:
        """Training logs, checkpoints and progress info."""
        logs_path = self.session_info.get("logs_dir")
        logs_url = f"{self.api.server_address}/files/?path={logs_path}" if logs_path else None

        total_iterations = self._total_iterations(
            self.session_info.get("loss_history", []),
            self.session_info.get("checkpoints") or [],
        )

        return {
            "total_iterations": total_iterations,
            "device": self.session_info.get("device", "N/A"),
            "session_url": self.session_info.get("session_url"),
            "checkpoints": self._collect_checkpoints(),
            "logs": {
                "path": logs_path,
                "url": logs_url,
            },
        }

    def _get_dataset_context(self) -> dict:
        """Dataset info."""
        return {
            "current_size": self.session_info.get("dataset_size", 0),
            "initial_samples": self.session_info.get("initial_samples", 0),
        }

    def _get_training_plots_html(self) -> Optional[str]:
        """
        Generate HTML for training loss plot.

        Currently returns None - to be implemented later with actual loss data.
        """
        # TODO: Generate plot from loss history
        # For now return placeholder
        return None

    def _generate_classes_table(self) -> Optional[str]:
        """Generate HTML table with class names, shapes and colors.

        Class and shape names are HTML-escaped before being embedded.

        :returns: HTML string with classes table, or None if the meta has no classes
        :rtype: Optional[str]
        """
        from html import escape

        type_to_icon = {
            sly.AnyGeometry: "zmdi zmdi-shape",
            sly.Rectangle: "zmdi zmdi-crop-din",
            sly.Polygon: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAABmJLR0QA/wD/AP+gvaeTAAAB6klEQVRYhe2Wuy8EURTGf+u5VESNXq2yhYZCoeBv8RcI1i6NVUpsoVCKkHjUGlFTiYb1mFmh2MiKjVXMudmb3cPOzB0VXzKZm5k53/nmvO6Ff4RHD5AD7gFP1l3Kd11AHvCBEpAVW2esAvWmK6t8l1O+W0lCQEnIJoAZxUnzNQNkZF36jrQjgoA+uaciCgc9VaExBOyh/6WWAi1VhbjOJ4FbIXkBtgkK0BNHnYqNKUIPeBPbKyDdzpld5T6wD9SE4AwYjfEDaXFeFzE/doUWuhqwiFsOCwqv2hV2lU/L+sHBscGTxdvSFVoXpAjCZdauMHVic6ndl6U1VBsJCFhTeNUU9IiIEo3qvQYGHAV0AyfC5wNLhKipXuBCjA5wT8WxcM1FMRoBymK44CjAE57hqIazwCfwQdARcXa3UXHuRXVucIjb7jYvNkdxBZg0TBFid7PQTRAtX2xOiXkuMAMqYwkIE848rZFbjyNAmw9bIeweaZ2A5TgC7PnwKkTPtN+cTOrsyN3FEWAjRTAX6sA5ek77gSL6+WHZVQDAIHAjhJtN78aAS3lXAXYIivBOnCdyOAUYB6o0xqsvziry7FLE/Cp20cNcJEjDr8MUmVOVRzkVN+Nd7vZGVXXgiwxtPiRS5WFhz4fEq/zv4AvToMn7vCn3eAAAAABJRU5ErkJggg==",
            sly.Bitmap: "zmdi zmdi-brush",
            sly.Polyline: "zmdi zmdi-gesture",
            sly.Point: "zmdi zmdi-dot-circle-alt",
            sly.Cuboid: "zmdi zmdi-ungroup",
            sly.GraphNodes: "zmdi zmdi-grain",
            sly.MultichannelBitmap: "zmdi zmdi-layers",
        }

        obj_classes = getattr(self.model_meta, "obj_classes", None)
        if obj_classes is None or len(obj_classes) == 0:
            return None

        rows = ['<table class="table">']
        rows.append("<thead><tr><th>Class name</th><th>Shape</th></tr></thead>")
        rows.append("<tbody>")

        for obj_class in obj_classes:
            color_hex = rgb2hex(obj_class.color)
            icon = type_to_icon.get(obj_class.geometry_type, "zmdi zmdi-shape")

            class_cell = (
                f"<i class='zmdi zmdi-circle' style='color: {color_hex}; margin-right: 5px;'></i>"
                f"<span>{escape(obj_class.name)}</span>"
            )

            # Polygon uses an inline PNG (data URL); every other shape uses an icon-font class.
            if icon.startswith("data:image"):
                shape_cell = f"<img src='{icon}' style='height: 15px; margin-right: 2px;'/>"
            else:
                shape_cell = f"<i class='{icon}' style='margin-right: 5px;'></i>"

            shape_name = obj_class.geometry_type.geometry_name()
            shape_cell += f"<span>{escape(shape_name)}</span>"

            rows.append(f"<tr><td>{class_cell}</td><td>{shape_cell}</td></tr>")

        rows.append("</tbody>")
        rows.append("</table>")
        return "\n".join(rows)

    def upload_to_artifacts(self, remote_dir: str):
        """
        Upload report to team files.

        :param remote_dir: Destination directory in Team Files, presumably of the form
            ``/live-training/{project_id}_{project_name}/{session_id}/``. A trailing slash
            is stripped before uploading.
        :return: Whatever ``BaseGenerator.upload`` returns (also stored for ``get_report``).
        """
        # Normalize path - remove trailing slash
        remote_dir = remote_dir.rstrip("/")
        file_info = self.upload(remote_dir, team_id=self.team_id)
        self._report_file_info = file_info
        return file_info

    def _get_widgets_context(self) -> dict:
        """Generate widgets (tables, plots) for the report."""
        checkpoints_table = self._generate_checkpoints_table()
        training_plot = self._generate_training_plot()
        classes = self._generate_classes_table()

        return {
            "tables": {
                "checkpoints": checkpoints_table,
                "classes": classes,
            },
            "training_plot": training_plot,
        }

    def _generate_checkpoints_table(self) -> Optional[str]:
        """Generate HTML table with checkpoints, or None if there are none.

        Names, iterations and URLs are HTML-escaped before being embedded.
        """
        from html import escape

        checkpoints = self._collect_checkpoints()
        if not checkpoints:
            return None

        rows = ['<table class="table">']
        rows.append("<thead><tr><th>Checkpoint Name</th><th>Iteration</th><th>Loss</th><th>Actions</th></tr></thead>")
        rows.append("<tbody>")

        for checkpoint in checkpoints:
            name = escape(str(checkpoint.get("name", "N/A")))
            iteration = escape(str(checkpoint.get("iteration", "N/A")))
            loss = checkpoint.get("loss")
            url = checkpoint.get("url")
            loss_str = f"{loss:.6f}" if loss is not None else "N/A"

            download_link = (
                f'<a href="{escape(url)}" target="_blank" class="download-link">Download</a>' if url else ""
            )

            rows.append(f"<tr><td>{name}</td><td>{iteration}</td><td>{loss_str}</td><td>{download_link}</td></tr>")

        rows.append("</tbody>")
        rows.append("</table>")
        return "\n".join(rows)

    def _generate_training_plot(self) -> str:
        """Generate training plots grid (like Experiments).

        Expects ``session_info["loss_history"]`` in the dict format (metric name -> list of
        ``{"step", "value"}`` records). Renders one subplot per metric, saves it as
        ``<output_dir>/data/training_plots_grid.png`` via kaleido and returns the Vue
        image tag referencing it, or a short HTML message if there is no data or the
        image could not be written.
        """
        loss_history = self.session_info.get("loss_history", {})

        # A non-empty dict always has at least one metric key.
        if not loss_history or not isinstance(loss_history, dict):
            return "<p>No training data available yet.</p>"

        metrics = list(loss_history.keys())
        n_metrics = len(metrics)

        # Calculate grid size (like in Experiments): 2..4 columns
        side = min(4, max(2, math.ceil(math.sqrt(n_metrics))))
        cols = side
        rows = math.ceil(n_metrics / cols)

        fig = make_subplots(rows=rows, cols=cols, subplot_titles=metrics)

        for idx, metric in enumerate(metrics, start=1):
            data = loss_history[metric]
            if not data:
                continue

            steps = [item["step"] for item in data]
            values = [item["value"] for item in data]

            row = (idx - 1) // cols + 1
            col = (idx - 1) % cols + 1

            fig.add_trace(
                go.Scatter(
                    x=steps,
                    y=values,
                    mode="lines",
                    name=metric,
                    showlegend=False,
                ),
                row=row,
                col=col,
            )

            # Learning-rate values are tiny; show them in scientific notation.
            if metric.startswith("lr"):
                fig.update_yaxes(tickformat=".0e", row=row, col=col)

        fig.update_layout(
            height=300 * rows,
            width=400 * cols,
            showlegend=False,
        )

        # Save as PNG
        data_dir = os.path.join(self.output_dir, "data")
        os.makedirs(data_dir, exist_ok=True)
        img_path = os.path.join(data_dir, "training_plots_grid.png")

        try:
            fig.write_image(img_path, engine="kaleido")
        except Exception as e:
            # Best-effort: a missing/broken kaleido must not break report generation.
            logger.warning(f"Failed to save training plot: {e}")
            return "<p>Failed to generate training plot</p>"

        # Return Vue image component
        return f'<sly-iw-image src="/data/training_plots_grid.png" :template-base-path="templateBasePath" :options="{{ style: {{ width: \'70%\', height: \'auto\' }} }}" />'

    def _get_online_training_app_info(self):
        """Get online training app info (slug and ecosystem module id).

        :raises Exception: Re-raises the lookup error unless running on the dev server,
            where a hardcoded module id is used instead.
        """
        try:
            # TODO: only works for public apps.
            # Exception handles only private apps on dev server. Need implement for private apps on any server.
            module_id = self.api.app.get_ecosystem_module_id(self.slug)
        except Exception as e:
            logger.warning(f"Failed to get module ID for slug {self.slug}: {e}.")
            if self.api.server_address.endswith("dev.internal.supervisely.com"):
                logger.warning("Using hardcoded module ID for dev server")
                task2module_map = {
                    "object detection": 620,
                    "semantic segmentation": 621,
                }
                module_id = task2module_map.get(self.task_type)
            else:
                # Bare raise preserves the original traceback.
                raise
        return {
            "slug": self.slug,
            "module_id": module_id,
        }

    def _get_resources_context(self):
        """Return apps module IDs for buttons."""
        online_training_app = self._get_online_training_app_info()

        return {
            "apps": {
                "online_training": online_training_app,
            }
        }

    def _get_report_file_id(self) -> int:
        """Return the uploaded report's file id.

        The stored upload result is handled both as a bare int id and as an object
        exposing ``.id``.

        :raises RuntimeError: If ``upload_to_artifacts`` has not been called yet.
        """
        file_info = getattr(self, "_report_file_info", None)
        if file_info is None:
            raise RuntimeError("Report not uploaded yet. Call upload_to_artifacts() first.")
        return file_info if isinstance(file_info, int) else file_info.id

    def get_report(self) -> str:
        """Get report URL after upload.

        :raises RuntimeError: If the report has not been uploaded yet.
        """
        return self._report_url(self.api.server_address, self._get_report_file_id())

    def get_report_id(self) -> int:
        """Get report file ID.

        :raises RuntimeError: If the report has not been uploaded yet.
        """
        return self._get_report_file_id()
|