supervisely 6.73.438__py3-none-any.whl → 6.73.513__py3-none-any.whl
- supervisely/__init__.py +137 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/annotation.py +8 -2
- supervisely/annotation/json_geometries_map.py +14 -11
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +14 -11
- supervisely/api/api.py +59 -38
- supervisely/api/app_api.py +11 -2
- supervisely/api/dataset_api.py +74 -12
- supervisely/api/entities_collection_api.py +10 -0
- supervisely/api/entity_annotation/figure_api.py +52 -4
- supervisely/api/entity_annotation/object_api.py +3 -3
- supervisely/api/entity_annotation/tag_api.py +63 -12
- supervisely/api/guides_api.py +210 -0
- supervisely/api/image_api.py +72 -1
- supervisely/api/labeling_job_api.py +83 -1
- supervisely/api/labeling_queue_api.py +33 -7
- supervisely/api/module_api.py +9 -0
- supervisely/api/project_api.py +71 -26
- supervisely/api/storage_api.py +3 -1
- supervisely/api/task_api.py +13 -2
- supervisely/api/team_api.py +4 -3
- supervisely/api/video/video_annotation_api.py +119 -3
- supervisely/api/video/video_api.py +65 -14
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +23 -7
- supervisely/app/development/development.py +18 -2
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +105 -0
- supervisely/app/fastapi/subapp.py +88 -42
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +6 -0
- supervisely/app/widgets/activity_feed/__init__.py +0 -0
- supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
- supervisely/app/widgets/activity_feed/style.css +78 -0
- supervisely/app/widgets/activity_feed/template.html +22 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
- supervisely/app/widgets/classes_list_selector/template.html +60 -93
- supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
- supervisely/app/widgets/classes_table/classes_table.py +1 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +184 -60
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +564 -0
- supervisely/app/widgets/heatmap/script.js +533 -0
- supervisely/app/widgets/heatmap/style.css +233 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/modal/__init__.py +0 -0
- supervisely/app/widgets/modal/modal.py +198 -0
- supervisely/app/widgets/modal/template.html +10 -0
- supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -3
- supervisely/app/widgets/select_class/__init__.py +0 -0
- supervisely/app/widgets/select_class/select_class.py +363 -0
- supervisely/app/widgets/select_class/template.html +50 -0
- supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
- supervisely/app/widgets/select_tag/__init__.py +0 -0
- supervisely/app/widgets/select_tag/select_tag.py +352 -0
- supervisely/app/widgets/select_tag/template.html +64 -0
- supervisely/app/widgets/select_team/select_team.py +37 -4
- supervisely/app/widgets/select_team/template.html +4 -5
- supervisely/app/widgets/select_user/__init__.py +0 -0
- supervisely/app/widgets/select_user/select_user.py +270 -0
- supervisely/app/widgets/select_user/template.html +13 -0
- supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
- supervisely/app/widgets/select_workspace/template.html +9 -12
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/aug/aug.py +6 -2
- supervisely/convert/base_converter.py +1 -0
- supervisely/convert/converter.py +2 -2
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/image/image_converter.py +3 -1
- supervisely/convert/image/image_helper.py +48 -4
- supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
- supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
- supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
- supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
- supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
- supervisely/convert/pointcloud/las/las_converter.py +13 -1
- supervisely/convert/pointcloud/las/las_helper.py +110 -11
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
- supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
- supervisely/convert/video/__init__.py +1 -0
- supervisely/convert/video/multi_view/__init__.py +0 -0
- supervisely/convert/video/multi_view/multi_view.py +543 -0
- supervisely/convert/video/sly/sly_video_converter.py +359 -3
- supervisely/convert/video/video_converter.py +24 -4
- supervisely/convert/volume/dicom/dicom_converter.py +13 -5
- supervisely/convert/volume/dicom/dicom_helper.py +30 -18
- supervisely/geometry/constants.py +1 -0
- supervisely/geometry/geometry.py +4 -0
- supervisely/geometry/helpers.py +5 -1
- supervisely/geometry/oriented_bbox.py +676 -0
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/geometry/rectangle.py +2 -1
- supervisely/io/env.py +76 -1
- supervisely/io/fs.py +21 -0
- supervisely/nn/benchmark/base_evaluator.py +104 -11
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
- supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
- supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
- supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
- supervisely/nn/inference/cache.py +43 -18
- supervisely/nn/inference/gui/serving_gui_template.py +5 -2
- supervisely/nn/inference/inference.py +916 -222
- supervisely/nn/inference/inference_request.py +55 -10
- supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/inference/session.py +43 -35
- supervisely/nn/inference/tracking/bbox_tracking.py +118 -35
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +10 -1
- supervisely/nn/inference/uploader.py +139 -12
- supervisely/nn/live_training/__init__.py +7 -0
- supervisely/nn/live_training/api_server.py +111 -0
- supervisely/nn/live_training/artifacts_utils.py +243 -0
- supervisely/nn/live_training/checkpoint_utils.py +229 -0
- supervisely/nn/live_training/dynamic_sampler.py +44 -0
- supervisely/nn/live_training/helpers.py +14 -0
- supervisely/nn/live_training/incremental_dataset.py +146 -0
- supervisely/nn/live_training/live_training.py +497 -0
- supervisely/nn/live_training/loss_plateau_detector.py +111 -0
- supervisely/nn/live_training/request_queue.py +52 -0
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/model/prediction.py +2 -1
- supervisely/nn/model/prediction_session.py +26 -14
- supervisely/nn/prediction_dto.py +19 -1
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
- supervisely/nn/tracker/botsort_tracker.py +94 -65
- supervisely/nn/tracker/utils.py +4 -5
- supervisely/nn/tracker/visualize.py +93 -93
- supervisely/nn/training/gui/classes_selector.py +16 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
- supervisely/nn/training/train_app.py +46 -31
- supervisely/project/data_version.py +115 -51
- supervisely/project/download.py +1 -1
- supervisely/project/pointcloud_episode_project.py +37 -8
- supervisely/project/pointcloud_project.py +30 -2
- supervisely/project/project.py +14 -2
- supervisely/project/project_meta.py +27 -1
- supervisely/project/project_settings.py +32 -18
- supervisely/project/versioning/__init__.py +1 -0
- supervisely/project/versioning/common.py +20 -0
- supervisely/project/versioning/schema_fields.py +35 -0
- supervisely/project/versioning/video_schema.py +221 -0
- supervisely/project/versioning/volume_schema.py +87 -0
- supervisely/project/video_project.py +717 -15
- supervisely/project/volume_project.py +623 -5
- supervisely/template/experiment/experiment.html.jinja +4 -4
- supervisely/template/experiment/experiment_generator.py +14 -21
- supervisely/template/live_training/__init__.py +0 -0
- supervisely/template/live_training/header.html.jinja +96 -0
- supervisely/template/live_training/live_training.html.jinja +51 -0
- supervisely/template/live_training/live_training_generator.py +464 -0
- supervisely/template/live_training/sly-style.css +402 -0
- supervisely/template/live_training/template.html.jinja +18 -0
- supervisely/versions.json +28 -26
- supervisely/video/sampling.py +39 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/video_annotation/video_object.py +29 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/METADATA +58 -40
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/RECORD +203 -155
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
- supervisely_lib/__init__.py +6 -1
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
supervisely/nn/live_training/artifacts_utils.py
@@ -0,0 +1,243 @@
+import os
+import shutil
+from pathlib import Path
+from typing import Optional, Dict
+from datetime import datetime
+import supervisely as sly
+from supervisely import logger
+from supervisely.template.live_training.live_training_generator import LiveTrainingGenerator
+import supervisely.io.json as sly_json
+from supervisely.nn.live_training.helpers import ClassMap
+import yaml
+
+
+def upload_artifacts(
+    api: sly.Api,
+    session_info: dict,
+    artifacts: dict,
+) -> str:
+    """
+    Upload artifacts to Team Files and generate experiment report.
+
+    Args:
+        session_info: Training session context
+            - team_id: Team ID
+            - task_id: Task ID
+            - project_id: Project ID
+            - framework_name: Framework name
+            - task_type: Task type
+            - class_map: Model class map
+            - start_time: Training start time string
+            - train_size: Final dataset size
+            - initial_samples: Number of initial samples
+
+
+        artifacts: Framework-specific artifacts
+            - checkpoint_path: Path to checkpoint file
+            - checkpoint_info: Dict with {name, iteration, loss}
+            - config_path: Path to config file
+            - logs_dir: Path to TensorBoard logs or None
+            - model_config: Model configuration dict
+            - loss_history: Dict with loss history
+
+    Returns:
+        report_url: URL to experiment report
+    """
+    logger.info("Starting artifacts upload")
+
+    # Unpack session_info
+    team_id = session_info['team_id']
+    task_id = session_info['task_id']
+    project_id = session_info['project_id']
+    framework_name = session_info['framework_name']
+    task_type = session_info['task_type']
+    class_map: ClassMap = session_info['class_map']
+    model_meta = sly.ProjectMeta(obj_classes=class_map.obj_classes)
+    start_time = session_info['start_time']
+    train_size = session_info['train_size']
+    initial_samples = session_info.get('initial_samples', 0)
+
+    # Unpack artifacts
+    checkpoint_path = artifacts['checkpoint_path']
+    checkpoint_info = artifacts['checkpoint_info']
+    config_path = artifacts['config_path']
+    logs_dir = artifacts.get('logs_dir')
+    model_config = artifacts['model_config']
+
+    work_dir = Path(os.path.dirname(checkpoint_path)).parent
+    output_dir = work_dir / "upload_artifacts"
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    project_info = api.project.get_info_by_id(project_id)
+    project_name = project_info.name if project_info else "unknown"
+    model_name = f"Live training - {project_name}"
+    model_config['model_name'] = model_name
+    remote_dir = f"/experiments/live_training/{project_id}_{project_name}/{task_id}_{framework_name}/"
+    logger.info(f"Remote directory: {remote_dir}")
+
+    checkpoints_dir = output_dir / "checkpoints"
+    checkpoints_dir.mkdir(exist_ok=True)
+    checkpoint_dest = checkpoints_dir / checkpoint_info["name"]
+    shutil.copy2(checkpoint_path, checkpoint_dest)
+
+    state_json_src = checkpoint_path.replace('.pth', '_state.json')
+    if os.path.exists(state_json_src):
+        state_json_dest = str(checkpoint_dest).replace('.pth', '_state.json')
+        shutil.copy2(state_json_src, state_json_dest)
+
+    if config_path and os.path.exists(config_path):
+        config_dest = output_dir / os.path.basename(config_path)
+        shutil.copy2(config_path, config_dest)
+        model_files = {"config": config_dest.name}
+    else:
+        model_files = {}
+
+    sly_json.dump_json_file(
+        model_meta.to_json(),
+        str(output_dir / "model_meta.json")
+    )
+
+    hyperparams = {}
+    if config_path and os.path.exists(config_path):
+        try:
+            hyperparams = LiveTrainingGenerator.parse_hyperparameters(config_path)
+        except Exception as e:
+            logger.warning(f"Failed to parse hyperparameters: {e}")
+
+    with open(output_dir / "hyperparameters.yaml", 'w') as f:
+        yaml.dump(hyperparams, f, default_flow_style=False)
+
+    with open(output_dir / "open_app.lnk", 'w') as f:
+        f.write(f"/apps/sessions/{task_id}")
+
+    if logs_dir and os.path.exists(logs_dir):
+        logs_dest = output_dir / "logs"
+        if logs_dest.exists():
+            shutil.rmtree(logs_dest)
+        shutil.copytree(logs_dir, logs_dest)
+        logger.info(f"Logs copied from {logs_dir}")
+        has_logs = True
+    else:
+        logger.warning("No logs provided")
+        has_logs = False
+
+    logger.info("Uploading to Team Files")
+    api.file.upload_directory_fast(
+        team_id=team_id,
+        local_dir=str(output_dir),
+        remote_dir=remote_dir
+    )
+
+    experiment_info = {
+        "experiment_name": f"Live Training {task_type.capitalize()} - Task {task_id}",
+        "framework_name": framework_name,
+        "model_name": model_name,
+        "base_checkpoint": None,
+        "base_checkpoint_link": None,
+        "task_type": task_type,
+        "project_id": project_id,
+        "project_version": project_info.version if project_info else None,
+        "task_id": task_id,
+        "model_files": model_files,
+        "checkpoints": [f"checkpoints/{checkpoint_info['name']}"],
+        "best_checkpoint": checkpoint_info['name'],
+        "export": {},
+        "model_meta": "model_meta.json",
+        "hyperparameters": "hyperparameters.yaml",
+        "hyperparameters_id": None,
+        "artifacts_dir": remote_dir,
+        "datetime": start_time,
+        "experiment_report_id": None,
+        "evaluation_report_link": None,
+        "evaluation_metrics": {},
+        "primary_metric": None,
+        "logs": {"type": "tensorboard", "link": f"{remote_dir}logs/"} if has_logs else None,
+        "device": get_device_name(),
+        "training_duration": calculate_duration(start_time),
+        "train_collection_id": None,
+        "val_collection_id": None,
+        "project_preview": project_info.image_preview_url if project_info else None,
+        "train_size": train_size,
+        "initial_samples": initial_samples,
+        "val_size": 0,
+    }
+
+    checkpoints_info = [checkpoint_info]
+    loss_history = artifacts.get('loss_history', {})
+
+    session_info = {
+        "session_id": task_id,
+        "session_name": experiment_info["experiment_name"],
+        "project_id": project_id,
+        "start_time": experiment_info["datetime"],
+        "duration": experiment_info["training_duration"],
+        "artifacts_dir": remote_dir,
+        "logs_dir": f"{remote_dir}logs/" if has_logs else None,
+        "checkpoints": checkpoints_info,
+        "loss_history": loss_history,
+        "hyperparameters": hyperparams,
+        "status": "completed",
+        "device": experiment_info["device"],
+        "dataset_size": train_size,
+        "initial_samples": 0,
+        "samples_added": 0,
+        "final_size": train_size,
+        "train_size": train_size,
+        "initial_samples": initial_samples,
+        "val_size": 0,
+    }
+
+    report_dir = work_dir / "live_training_report"
+    report_dir.mkdir(exist_ok=True)
+
+    generator = LiveTrainingGenerator(
+        api=api,
+        session_info=session_info,
+        model_config=model_config,
+        model_meta=model_meta,
+        output_dir=str(report_dir),
+        team_id=team_id,
+        task_type=task_type,
+    )
+
+    generator.generate()
+    file_info = generator.upload_to_artifacts(os.path.join(remote_dir, "visualization"))
+
+    report_id = file_info if isinstance(file_info, int) else getattr(file_info, 'id', file_info)
+    report_url = f"{api.server_address}/nn/experiments/{report_id}"
+
+    logger.info(f"Report URL: {report_url}")
+
+    experiment_info["has_report"] = True
+    experiment_info["experiment_report_id"] = int(report_url.split('/')[-1])
+    response = api.task.set_output_experiment(task_id, experiment_info)
+
+    return report_url
+
+
+def calculate_duration(start_time: str) -> str:
+    """Calculate training duration in 'Xh Ym' format."""
+    try:
+        start_dt = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S")
+        duration_sec = (datetime.now() - start_dt).total_seconds()
+        hours = int(duration_sec // 3600)
+        minutes = int((duration_sec % 3600) // 60)
+        return f"{hours}h {minutes}m"
+    except:
+        return "N/A"
+
+
+def get_device_name() -> str:
+    """Get GPU device name or 'cpu'."""
+    import torch  # pylint: disable=import-error
+    if not os.path.exists("/dev/nvidia0"):
+        return "cpu"
+
+    try:
+        if torch.cuda.is_available():
+            device_id = int(os.getenv("CUDA_VISIBLE_DEVICES", "0").split(",")[0])
+            return torch.cuda.get_device_name(device_id)
+    except Exception as e:
+        logger.warning(f"Failed to get GPU name: {e}")
+
+    return "cuda"
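
For orientation, a minimal usage sketch for this module (not part of the diff). All IDs, paths, and checkpoint metadata below are hypothetical; the dict keys follow the docstring of upload_artifacts above:

    import supervisely as sly
    from supervisely.nn.live_training.artifacts_utils import upload_artifacts
    from supervisely.nn.live_training.helpers import ClassMap

    api = sly.Api.from_env()

    session_info = {
        "team_id": 1,                      # hypothetical IDs
        "task_id": 12345,
        "project_id": 777,
        "framework_name": "mmsegmentation",
        "task_type": "semantic segmentation",
        "class_map": ClassMap([sly.ObjClass("car", sly.Rectangle)]),
        "start_time": "2025-01-01 12:00:00",
        "train_size": 120,
        "initial_samples": 20,
    }
    artifacts = {
        "checkpoint_path": "/work_dir/checkpoints/latest.pth",   # hypothetical paths
        "checkpoint_info": {"name": "latest.pth", "iteration": 5000, "loss": 0.12},
        "config_path": "/work_dir/config.py",
        "logs_dir": "/work_dir/logs",
        "model_config": {},
        "loss_history": {},
    }

    report_url = upload_artifacts(api, session_info, artifacts)
    print(report_url)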
supervisely/nn/live_training/checkpoint_utils.py
@@ -0,0 +1,229 @@
+import os
+import re
+from typing import Tuple, List, Optional
+import json
+from supervisely import logger
+from supervisely.nn.live_training.helpers import ClassMap
+from supervisely.io.json import load_json_file
+import supervisely as sly
+
+
+def validate_classes_exact_match(saved_classes: List[str], current_classes: List[str]):
+    """
+    Validate that saved and current classes match exactly.
+    Raises ValueError if they don't match.
+    """
+    if set(saved_classes) != set(current_classes):
+        raise ValueError(
+            f"Class names in checkpoint do not match current class names. "
+            f"Saved: {saved_classes}, Current: {current_classes}"
+        )
+
+
+def reorder_class_map(saved_classes: List[str], project_meta) -> ClassMap:
+    """
+    Create ClassMap with reordered classes matching checkpoint order.
+
+    Returns:
+        class_map: New ClassMap with correct order
+    """
+    class_mapping = {cls: idx for idx, cls in enumerate(saved_classes)}
+    logger.info(f"Class mapping: {class_mapping}")
+
+    # Create ClassMap from class names in checkpoint order
+    obj_classes = []
+    for name in saved_classes:
+        obj_class = project_meta.get_obj_class(name)
+        if obj_class is None:
+            raise ValueError(f"Class '{name}' not found in project metadata")
+        obj_classes.append(obj_class)
+
+    return ClassMap(obj_classes)
+
+
+def remove_classification_head(checkpoint_path: str) -> str:
+    """
+    Remove classification head weights from checkpoint and save modified version.
+
+    Returns:
+        modified_path: Path to checkpoint without classification head
+    """
+    import torch  # pylint: disable=import-error
+    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
+    state_dict = checkpoint.get('state_dict', {})
+
+    keys_to_remove = []
+    for key in state_dict.keys():
+        if 'decode_head' in key or 'auxiliary_head' in key:
+            keys_to_remove.append(key)
+
+    for key in keys_to_remove:
+        del state_dict[key]
+
+    logger.info(f"Removed {len(keys_to_remove)} classification head parameters")
+
+    modified_path = checkpoint_path.replace('.pth', '_headless.pth')
+    torch.save(checkpoint, modified_path)
+
+    return modified_path
+
+
+def resolve_checkpoint(
+    checkpoint_mode: str,
+    selected_experiment_task_id: Optional[int],
+    class_map: ClassMap,
+    project_meta,
+    api,
+    team_id: int,
+    work_dir: str
+) -> Tuple[Optional[str], ClassMap, Optional[dict]]:
+    """
+    Main orchestrator function to resolve checkpoint loading based on mode.
+
+    Args:
+        checkpoint_mode: One of 'scratch', 'finetune', 'continue'
+        selected_experiment_task_id: Task ID to load checkpoint from (required for finetune/continue)
+        class_map: Current ClassMap
+        project_meta: Project metadata
+        api: Supervisely API instance
+        team_id: Team ID
+        framework_name: Framework name (unused, kept for compatibility)
+        work_dir: Working directory for downloaded files
+
+    Returns:
+        checkpoint_path: Path to checkpoint file (None for scratch mode)
+        class_map: Updated ClassMap (may be reordered for finetune/continue)
+        state: Training state dict (only for continue mode)
+    """
+    checkpoint_name = "latest.pth"
+    current_classes = [cls.name for cls in class_map.obj_classes]
+    logger.info(f"Checkpoint mode: {checkpoint_mode}")
+
+    if checkpoint_mode == "scratch":
+        logger.info("Starting from pretrained weights (scratch mode)")
+        return None, class_map, None
+
+    if selected_experiment_task_id is None:
+        raise ValueError(
+            f"selected_experiment_task_id must be provided when checkpoint_mode='{checkpoint_mode}'"
+        )
+
+    # Get experiment info
+    task_info = api.task.get_info_by_id(selected_experiment_task_id)
+    experiment_info = task_info["meta"]["output"]["experiment"]["data"]
+
+    artifacts_dir = experiment_info["artifacts_dir"]
+    model_meta_filename = experiment_info.get("model_meta", "model_meta.json")
+
+    # Setup local paths
+    local_dir = os.path.join(work_dir, 'downloaded_checkpoints')
+    os.makedirs(local_dir, exist_ok=True)
+
+    # Download checkpoint
+    remote_checkpoint = f"{artifacts_dir}checkpoints/{checkpoint_name}"
+    local_checkpoint = os.path.join(local_dir, checkpoint_name)
+
+    logger.info(f"Downloading checkpoint from {remote_checkpoint}")
+    api.file.download(team_id, remote_checkpoint, local_checkpoint)
+    logger.info(f"Checkpoint downloaded to {local_checkpoint}")
+
+    # Download model_meta.json
+    remote_model_meta = f"{artifacts_dir}{model_meta_filename}"
+    local_model_meta = os.path.join(local_dir, 'model_meta.json')
+
+    logger.info(f"Downloading model_meta from {remote_model_meta}")
+    api.file.download(team_id, remote_model_meta, local_model_meta)
+
+    # Load saved classes
+    model_meta_json = load_json_file(local_model_meta)
+    saved_project_meta = sly.ProjectMeta.from_json(model_meta_json)
+    saved_classes = [cls.name for cls in saved_project_meta.obj_classes]
+
+    logger.info(f"Saved classes: {saved_classes}")
+    logger.info(f"Current classes: {current_classes}")
+
+    # Finetune mode - flexible class handling
+    if checkpoint_mode == "finetune":
+        saved_set = set(saved_classes)
+        current_set = set(current_classes)
+
+        if saved_set == current_set:
+            if saved_classes != current_classes:
+                logger.info("Class order differs. Reordering classes")
+                class_map = reorder_class_map(saved_classes, project_meta)
+            else:
+                logger.info("Class names match exactly")
+            return local_checkpoint, class_map, None
+
+        elif len(saved_classes) == len(current_classes):
+            # logger.info("Class names differ but count matches. Removing classification head")
+            # modified_checkpoint = remove_classification_head(local_checkpoint)
+            logger.warning("Class names differ but count matches. Classification head will be kept as is")
+            return local_checkpoint, class_map, None
+
+        else:
+            logger.info("Classes differ completely. Starting from checkpoint")
+            return local_checkpoint, class_map, None
+
+    # Continue mode - strict matching required
+    elif checkpoint_mode == "continue":
+        logger.info(f"Continue mode: loading from task_id={selected_experiment_task_id}")
+        validate_classes_exact_match(saved_classes, current_classes)
+
+        # Download state JSON
+        state_filename = checkpoint_name.replace('.pth', '_state.json')
+        remote_state = f"{artifacts_dir}checkpoints/{state_filename}"
+        local_state = local_checkpoint.replace('.pth', '_state.json')
+
+        logger.info(f"Downloading state from {remote_state}")
+        api.file.download(team_id, remote_state, local_state)
+
+        # Use existing utility function
+        state = load_state_json(local_checkpoint)
+
+        logger.info(f"State loaded with {state.get('dataset_size', 0)} samples at iter {state.get('iter', 0)}")
+        return local_checkpoint, class_map, state
+
+    else:
+        raise ValueError(
+            f"Invalid checkpoint_mode='{checkpoint_mode}'. Valid: 'scratch', 'finetune', 'continue'"
+        )
+
+def save_state_json(state: dict, checkpoint_path: str):
+    """
+    Save training state as JSON file next to checkpoint.
+
+    Args:
+        state: State dict from LiveTraining.state()
+        checkpoint_path: Path to .pth checkpoint file
+    """
+    state_path = checkpoint_path.replace('.pth', '_state.json')
+
+    with open(state_path, 'w') as f:
+        json.dump(state, f, indent=2)
+
+    logger.info(f"State saved to {state_path}")
+
+def load_state_json(checkpoint_path: str) -> dict:
+    """
+    Load training state from JSON file next to checkpoint.
+
+    Args:
+        checkpoint_path: Path to .pth checkpoint file
+
+    Returns:
+        state: State dict for LiveTraining.load_state()
+    """
+    state_path = checkpoint_path.replace('.pth', '_state.json')
+
+    if not os.path.exists(state_path):
+        raise ValueError(
+            f"State file not found: {state_path}. "
+            f"This checkpoint may not support 'continue' mode."
+        )
+
+    with open(state_path, 'r') as f:
+        state = json.load(f)
+
+    logger.info(f"State loaded from {state_path}")
+    return state
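
The three checkpoint modes differ mainly in what the caller gets back; a short sketch of the call site, not part of the diff (the project and task IDs are hypothetical):

    import supervisely as sly
    from supervisely.nn.live_training.checkpoint_utils import resolve_checkpoint
    from supervisely.nn.live_training.helpers import ClassMap

    api = sly.Api.from_env()
    project_meta = sly.ProjectMeta.from_json(api.project.get_meta(777))  # hypothetical project ID
    class_map = ClassMap(list(project_meta.obj_classes))

    # "scratch" returns (None, class_map, None); "finetune" may reorder the class map;
    # "continue" additionally returns the saved training state.
    checkpoint_path, class_map, state = resolve_checkpoint(
        checkpoint_mode="continue",
        selected_experiment_task_id=12345,  # hypothetical task ID
        class_map=class_map,
        project_meta=project_meta,
        api=api,
        team_id=1,
        work_dir="/work_dir",
    )
    if state is not None:
        print(f"Resuming at iteration {state.get('iter', 0)}")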
supervisely/nn/live_training/dynamic_sampler.py
@@ -0,0 +1,44 @@
+import random
+import time
+from collections.abc import Sized
+
+
+class DynamicSampler:
+    """
+    A sampler that dynamically adjusts to the size of a dataset that grows over time.
+    Implements torch.utils.data.Sampler interface.
+    """
+
+    def __init__(self, dataset: Sized, shuffle: bool = True, seed: int = 0):
+        self.dataset = dataset
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self):
+        remaining_indices = []
+        last_known_len = 0
+
+        # Wait for first samples
+        while len(self.dataset) == 0:
+            time.sleep(0.1)
+
+        while True:
+            current_len = len(self.dataset)
+
+            if current_len > last_known_len:
+                new_indices = list(range(last_known_len, current_len))
+                if self.shuffle:
+                    random.shuffle(new_indices)
+                remaining_indices.extend(new_indices)
+                last_known_len = current_len
+
+            if not remaining_indices:
+                # Reshuffle existing data
+                remaining_indices = list(range(current_len))
+                if self.shuffle:
+                    random.shuffle(remaining_indices)
+
+            yield remaining_indices.pop()
+
+    def __len__(self):
+        return len(self.dataset)
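
Note that __iter__ never raises StopIteration, so the sampler suits iteration-based rather than epoch-based training loops. A self-contained illustration (not part of the diff) with a plain list standing in for the growing dataset; the class only needs len(), so no torch is required here:

    from itertools import islice
    from supervisely.nn.live_training.dynamic_sampler import DynamicSampler

    data = [0, 1, 2]                      # stands in for a dataset that grows over time
    sampler = DynamicSampler(data, shuffle=True)
    stream = iter(sampler)

    print(list(islice(stream, 3)))        # some permutation of indices 0..2
    data.extend([3, 4])                   # dataset grows mid-training
    print(list(islice(stream, 5)))        # indices 3 and 4 join the stream, then a reshuffle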
supervisely/nn/live_training/helpers.py
@@ -0,0 +1,14 @@
+from typing import List, Union
+import supervisely as sly
+
+
+class ClassMap:
+    def __init__(self, obj_classes: Union[sly.ObjClassCollection, List[sly.ObjClass]]):
+        self.obj_classes = obj_classes
+        self.class2idx = {obj_class.name: idx for idx, obj_class in enumerate(self.obj_classes)}
+        self.idx2class = {idx: obj_class.name for idx, obj_class in enumerate(self.obj_classes)}
+        self.classes = [obj_class.name for obj_class in self.obj_classes]
+        self.sly_ids = [obj_class.sly_id for obj_class in self.obj_classes]
+
+    def __len__(self):
+        return len(self.obj_classes)
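
ClassMap is a thin bidirectional lookup between class names and contiguous indices; a quick sketch (not part of the diff):

    import supervisely as sly
    from supervisely.nn.live_training.helpers import ClassMap

    class_map = ClassMap([
        sly.ObjClass("car", sly.Rectangle),
        sly.ObjClass("person", sly.Bitmap),
    ])
    print(class_map.class2idx)     # {'car': 0, 'person': 1}
    print(class_map.idx2class[1])  # person
    print(len(class_map))          # 2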
supervisely/nn/live_training/incremental_dataset.py
@@ -0,0 +1,146 @@
+from typing import Dict, Any
+from pathlib import Path
+from PIL import Image
+import numpy as np
+import supervisely as sly
+import cv2
+
+
+class IncrementalDataset:
+    """
+    1. Save images on disk
+    2. Store annotations in SLY/COCO format. Handle case for Segmentation task
+    3. Implement indexing, adding, and updating samples
+    4. __getitem__ to retrieve samples by index
+    """
+    def __init__(
+        self,
+        class2idx: dict,
+        data_dir: str,
+        save_masks_as_images: bool = False,
+    ):
+        self.class2idx = class2idx
+        self.data_dir = Path(data_dir)
+        self.save_masks_as_images = save_masks_as_images
+        self.images_dir = self.data_dir / "images"
+        self.images_dir.mkdir(parents=True, exist_ok=True)
+        if self.save_masks_as_images:
+            self.masks_dir = self.data_dir / "masks"
+            self.masks_dir.mkdir(parents=True, exist_ok=True)
+        self.samples: Dict[int, dict] = {}
+        self.samples_list = []
+
+    def add(
+        self,
+        image_id: int,
+        image_np: np.ndarray,
+        annotation: sly.Annotation,
+        image_name: str
+    ) -> dict:
+        if image_id in self.samples:
+            raise ValueError(f"Cannot add sample: Image ID {image_id} already exists in the dataset.")
+        image_name = f"{image_id} {image_name}"
+        w, h = image_np.shape[1], image_np.shape[0]
+        img_size = (w, h)
+        image_path = self._save_img(image_np, image_name)
+        mask_path = None
+        if self.save_masks_as_images:
+            mask_path = self._save_mask(annotation, image_name)
+        sample = self._format_sample(
+            image_id,
+            annotation,
+            img_size,
+            image_path,
+            mask_path
+        )
+        assert isinstance(sample, dict), "Sample must be a dict."
+        # add extra fields for internal use
+        sample['image_path'] = image_path
+        sample['size'] = img_size
+        if mask_path is not None:
+            sample['mask_path'] = mask_path
+        # add to dataset
+        self.samples[image_id] = sample
+        self.samples_list.append(sample)
+        return sample
+
+    def update(
+        self,
+        image_id: int,
+        annotation: sly.Annotation,
+    ) -> dict:
+        if image_id not in self.samples:
+            raise ValueError(f"Cannot update sample: Image ID {image_id} does not exist in the dataset.")
+        sample = self.samples[image_id]
+        new_sample = self._format_sample(
+            image_id,
+            annotation,
+            sample['size'],
+            sample['image_path'],
+            sample.get('mask_path')
+        )
+        sample.update(new_sample)
+        return sample
+
+    def add_or_update(
+        self,
+        image_id: int,
+        image_np: np.ndarray,
+        annotation: sly.Annotation,
+        image_name: str
+    ) -> dict:
+        if image_id not in self.samples:
+            return self.add(image_id, image_np, annotation, image_name)
+        else:
+            return self.update(image_id, annotation)
+
+    def _format_sample(
+        self,
+        image_id: int,
+        annotation: sly.Annotation,
+        image_size: tuple,
+        image_path: str,
+        mask_path: str = None
+    ) -> dict:
+        sample = {
+            'image_id': image_id,
+            'width': image_size[0],
+            'height': image_size[1],
+            'annotations': annotation.to_coco(annotation.image_id, self.class2idx)[0],
+            'image_path': image_path,
+            'mask_path': mask_path
+        }
+        return sample
+
+    def _save_img(self, image_np: np.ndarray, image_name: str) -> str:
+        image = Image.fromarray(image_np).convert('RGB')
+        image_path = str(self.images_dir / image_name)
+        image.save(image_path)
+        return image_path
+
+    def _save_mask(self, annotation: sly.Annotation, image_name: str) -> str:
+
+        mapping = {label.obj_class: label.obj_class for label in annotation.labels}
+        ann_nonoverlap = annotation.to_nonoverlapping_masks(mapping)
+        h, w = annotation.img_size
+        mask = np.zeros((h, w), dtype=np.uint8)
+
+        for label in ann_nonoverlap.labels:
+            class_name = label.obj_class.name
+            class_id = self.class2idx.get(class_name)
+            if class_id is not None:
+                label.geometry.draw(mask, color=class_id)
+
+        mask_name = Path(image_name).stem + '.png'
+        mask_path = str(self.masks_dir / mask_name)
+        cv2.imwrite(mask_path, mask)
+
+        return mask_path
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def get_image_ids(self) -> list:
+        """Get list of image IDs in dataset"""
+        return list(self.samples.keys())
+
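
A sketch of the add-then-update flow (not part of the diff; the image and IDs are synthetic, and it assumes sly.Annotation.to_coco, which _format_sample above relies on, is available in this SDK version):

    import numpy as np
    import supervisely as sly
    from supervisely.nn.live_training.incremental_dataset import IncrementalDataset

    ds = IncrementalDataset(class2idx={"car": 0}, data_dir="/tmp/live_ds")

    img = np.zeros((100, 200, 3), dtype=np.uint8)  # synthetic 200x100 RGB image
    ann = sly.Annotation(img_size=(100, 200))      # empty annotation for illustration

    ds.add_or_update(image_id=1, image_np=img, annotation=ann, image_name="frame.jpg")  # adds
    ds.add_or_update(image_id=1, image_np=img, annotation=ann, image_name="frame.jpg")  # updates in place
    print(len(ds), ds.get_image_ids())             # 1 [1]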