wandb 0.15.9__py3-none-any.whl → 0.15.11__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +5 -1
- wandb/apis/public.py +137 -17
- wandb/apis/reports/_panels.py +1 -1
- wandb/apis/reports/blocks.py +1 -0
- wandb/apis/reports/report.py +27 -5
- wandb/cli/cli.py +52 -41
- wandb/docker/__init__.py +17 -0
- wandb/docker/auth.py +1 -1
- wandb/env.py +24 -4
- wandb/filesync/step_checksum.py +3 -3
- wandb/integration/openai/openai.py +3 -0
- wandb/integration/ultralytics/__init__.py +9 -0
- wandb/integration/ultralytics/bbox_utils.py +196 -0
- wandb/integration/ultralytics/callback.py +458 -0
- wandb/integration/ultralytics/classification_utils.py +66 -0
- wandb/integration/ultralytics/mask_utils.py +141 -0
- wandb/integration/ultralytics/pose_utils.py +92 -0
- wandb/integration/xgboost/xgboost.py +3 -3
- wandb/integration/yolov8/__init__.py +0 -7
- wandb/integration/yolov8/yolov8.py +22 -3
- wandb/old/settings.py +7 -0
- wandb/plot/line_series.py +0 -1
- wandb/proto/v3/wandb_internal_pb2.py +353 -300
- wandb/proto/v3/wandb_server_pb2.py +37 -41
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
- wandb/proto/v4/wandb_internal_pb2.py +272 -260
- wandb/proto/v4/wandb_server_pb2.py +37 -40
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
- wandb/proto/wandb_internal_codegen.py +7 -31
- wandb/sdk/artifacts/artifact.py +321 -189
- wandb/sdk/artifacts/artifact_cache.py +14 -0
- wandb/sdk/artifacts/artifact_manifest.py +5 -4
- wandb/sdk/artifacts/artifact_manifest_entry.py +37 -9
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -9
- wandb/sdk/artifacts/artifact_saver.py +13 -50
- wandb/sdk/artifacts/artifact_ttl.py +6 -0
- wandb/sdk/artifacts/artifacts_cache.py +119 -93
- wandb/sdk/artifacts/staging.py +25 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -3
- wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
- wandb/sdk/artifacts/storage_policies/register.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +4 -3
- wandb/sdk/artifacts/storage_policy.py +4 -2
- wandb/sdk/backend/backend.py +0 -16
- wandb/sdk/data_types/image.py +3 -1
- wandb/sdk/integration_utils/auto_logging.py +38 -13
- wandb/sdk/interface/interface.py +16 -135
- wandb/sdk/interface/interface_shared.py +9 -147
- wandb/sdk/interface/interface_sock.py +0 -26
- wandb/sdk/internal/file_pusher.py +20 -3
- wandb/sdk/internal/file_stream.py +3 -1
- wandb/sdk/internal/handler.py +53 -70
- wandb/sdk/internal/internal_api.py +220 -130
- wandb/sdk/internal/job_builder.py +41 -37
- wandb/sdk/internal/sender.py +7 -25
- wandb/sdk/internal/system/assets/disk.py +144 -11
- wandb/sdk/internal/system/system_info.py +6 -2
- wandb/sdk/launch/__init__.py +5 -0
- wandb/sdk/launch/{launch.py → _launch.py} +53 -54
- wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
- wandb/sdk/launch/_project_spec.py +13 -2
- wandb/sdk/launch/agent/agent.py +103 -59
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
- wandb/sdk/launch/builder/build.py +19 -1
- wandb/sdk/launch/builder/docker_builder.py +5 -1
- wandb/sdk/launch/builder/kaniko_builder.py +5 -1
- wandb/sdk/launch/create_job.py +20 -5
- wandb/sdk/launch/loader.py +14 -5
- wandb/sdk/launch/runner/abstract.py +0 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +66 -209
- wandb/sdk/launch/runner/local_container.py +5 -2
- wandb/sdk/launch/runner/local_process.py +4 -1
- wandb/sdk/launch/sweeps/scheduler.py +43 -25
- wandb/sdk/launch/sweeps/utils.py +5 -3
- wandb/sdk/launch/utils.py +3 -1
- wandb/sdk/lib/_settings_toposort_generate.py +3 -9
- wandb/sdk/lib/_settings_toposort_generated.py +27 -3
- wandb/sdk/lib/_wburls_generated.py +1 -0
- wandb/sdk/lib/filenames.py +27 -6
- wandb/sdk/lib/filesystem.py +181 -7
- wandb/sdk/lib/fsm.py +5 -3
- wandb/sdk/lib/gql_request.py +3 -0
- wandb/sdk/lib/ipython.py +7 -0
- wandb/sdk/lib/wburls.py +1 -0
- wandb/sdk/service/port_file.py +2 -15
- wandb/sdk/service/server.py +7 -55
- wandb/sdk/service/service.py +56 -26
- wandb/sdk/service/service_base.py +1 -1
- wandb/sdk/service/streams.py +11 -5
- wandb/sdk/verify/verify.py +2 -2
- wandb/sdk/wandb_init.py +8 -2
- wandb/sdk/wandb_manager.py +4 -14
- wandb/sdk/wandb_run.py +143 -53
- wandb/sdk/wandb_settings.py +148 -35
- wandb/testing/relay.py +85 -38
- wandb/util.py +87 -4
- wandb/wandb_torch.py +24 -38
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/METADATA +48 -23
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/RECORD +107 -103
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/WHEEL +1 -1
- wandb/proto/v3/wandb_server_pb2_grpc.py +0 -1422
- wandb/proto/v4/wandb_server_pb2_grpc.py +0 -1422
- wandb/proto/wandb_server_pb2_grpc.py +0 -8
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +0 -61
- wandb/sdk/interface/interface_grpc.py +0 -460
- wandb/sdk/service/server_grpc.py +0 -444
- wandb/sdk/service/service_grpc.py +0 -73
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,458 @@
|
|
1
|
+
import copy
|
2
|
+
from datetime import datetime
|
3
|
+
from typing import Callable, Dict, Optional, Union
|
4
|
+
|
5
|
+
try:
|
6
|
+
import dill as pickle
|
7
|
+
except ImportError:
|
8
|
+
import pickle
|
9
|
+
|
10
|
+
import wandb
|
11
|
+
from wandb.sdk.lib import telemetry
|
12
|
+
|
13
|
+
try:
|
14
|
+
import torch
|
15
|
+
from tqdm.auto import tqdm
|
16
|
+
from ultralytics.models import YOLO
|
17
|
+
from ultralytics.models.yolo.classify import (
|
18
|
+
ClassificationPredictor,
|
19
|
+
ClassificationTrainer,
|
20
|
+
ClassificationValidator,
|
21
|
+
)
|
22
|
+
from ultralytics.models.yolo.detect import (
|
23
|
+
DetectionPredictor,
|
24
|
+
DetectionTrainer,
|
25
|
+
DetectionValidator,
|
26
|
+
)
|
27
|
+
from ultralytics.models.yolo.pose import PosePredictor, PoseTrainer, PoseValidator
|
28
|
+
from ultralytics.models.yolo.segment import (
|
29
|
+
SegmentationPredictor,
|
30
|
+
SegmentationTrainer,
|
31
|
+
SegmentationValidator,
|
32
|
+
)
|
33
|
+
from ultralytics.utils.torch_utils import de_parallel
|
34
|
+
from ultralytics.yolo.utils import RANK, __version__
|
35
|
+
|
36
|
+
from wandb.integration.ultralytics.bbox_utils import (
|
37
|
+
plot_predictions,
|
38
|
+
plot_validation_results,
|
39
|
+
)
|
40
|
+
from wandb.integration.ultralytics.classification_utils import (
|
41
|
+
plot_classification_predictions,
|
42
|
+
plot_classification_validation_results,
|
43
|
+
)
|
44
|
+
from wandb.integration.ultralytics.mask_utils import (
|
45
|
+
plot_mask_predictions,
|
46
|
+
plot_mask_validation_results,
|
47
|
+
)
|
48
|
+
from wandb.integration.ultralytics.pose_utils import (
|
49
|
+
plot_pose_predictions,
|
50
|
+
plot_pose_validation_results,
|
51
|
+
)
|
52
|
+
except ImportError as e:
|
53
|
+
wandb.error(e)
|
54
|
+
|
55
|
+
|
56
|
+
# Union aliases used to annotate the callback hooks below: each hook receives
# the task-specific trainer, validator or predictor instance.
TRAINER_TYPE = Union[
    ClassificationTrainer,
    DetectionTrainer,
    SegmentationTrainer,
    PoseTrainer,
]
VALIDATOR_TYPE = Union[
    ClassificationValidator,
    DetectionValidator,
    SegmentationValidator,
    PoseValidator,
]
PREDICTOR_TYPE = Union[
    ClassificationPredictor,
    DetectionPredictor,
    SegmentationPredictor,
    PosePredictor,
]
|
65
|
+
|
66
|
+
|
67
|
+
class WandBUltralyticsCallback:
    """Stateful callback for logging to W&B.

    In particular, it will log model checkpoints, predictions, and
    ground-truth annotations with interactive overlays for bounding boxes
    to Weights & Biases Tables during training, validation and prediction
    for an `ultralytics` workflow.

    **Usage:**

    ```python
    from ultralytics.models import YOLO
    from wandb.integration.ultralytics.callback import add_wandb_callback

    # initialize YOLO model
    model = YOLO("yolov8n.pt")

    # add wandb callback
    add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)

    # train
    model.train(data="coco128.yaml", epochs=5, imgsz=640)

    # validate
    model.val()

    # perform inference
    model(["img1.jpeg", "img2.jpeg"])
    ```

    Args:
        model: YOLO Model of type `:class:ultralytics.models.YOLO`.
        max_validation_batches: maximum number of validation batches to log to
            a table per epoch.
        enable_model_checkpointing: enable logging model checkpoints as
            artifacts at the end of every epoch if set to `True`.
        visualize_skeleton: visualize pose skeleton by drawing lines connecting
            keypoints for human pose.
    """

    def __init__(
        self,
        model: YOLO,
        max_validation_batches: int = 1,
        enable_model_checkpointing: bool = False,
        visualize_skeleton: bool = False,
    ) -> None:
        self.max_validation_batches = max_validation_batches
        self.enable_model_checkpointing = enable_model_checkpointing
        self.visualize_skeleton = visualize_skeleton
        self.task = model.task
        self.task_map = model.task_map
        self.model_name = model.overrides["model"].split(".")[0]
        self.supported_tasks = ["detect", "segment", "pose", "classify"]
        # Both are filled in by `on_fit_epoch_end`. They start as `None` so
        # that `on_val_end` can detect a validation run that was not preceded
        # by a training epoch instead of failing with an AttributeError.
        self.model = None
        self.device = None
        self._make_tables()
        self._make_predictor(model)

    def _make_tables(self):
        """Create the empty train-validation, validation and prediction tables.

        The column layout depends on `self.task`; no table is created for a
        task outside detect/segment/classify/pose.
        """
        if self.task in ["detect", "segment"]:
            validation_columns = [
                "Data-Index",
                "Batch-Index",
                "Image",
                "Mean-Confidence",
                "Speed",
            ]
            prediction_columns = [
                "Image",
                "Num-Objects",
                "Mean-Confidence",
                "Speed",
            ]
        elif self.task == "classify":
            prediction_columns = [
                "Image",
                "Predicted-Category",
                "Prediction-Confidence",
                "Top-5-Prediction-Categories",
                # NOTE: the misspelling is the published column name; it is
                # kept so existing tables/dashboards keep the same schema.
                "Top-5-Prediction-Confindence",
                "Probabilities",
                "Speed",
            ]
            validation_columns = ["Data-Index", "Batch-Index"] + prediction_columns
            # The ground truth sits right after the image column.
            validation_columns.insert(3, "Ground-Truth-Category")
        elif self.task == "pose":
            validation_columns = [
                "Data-Index",
                "Batch-Index",
                "Image-Ground-Truth",
                "Image-Prediction",
                "Num-Instances",
                "Mean-Confidence",
                "Speed",
            ]
            prediction_columns = [
                "Image-Prediction",
                "Num-Instances",
                "Mean-Confidence",
                "Speed",
            ]
        else:
            return
        self.train_validation_table = wandb.Table(
            columns=["Model-Name", "Epoch"] + validation_columns
        )
        self.validation_table = wandb.Table(
            columns=["Model-Name"] + validation_columns
        )
        self.prediction_table = wandb.Table(
            columns=["Model-Name"] + prediction_columns
        )

    def _make_predictor(self, model: YOLO):
        """Build the task-specific predictor used to produce logged predictions."""
        overrides = copy.deepcopy(model.overrides)
        overrides["conf"] = 0.1
        self.predictor = self.task_map[self.task]["predictor"](
            overrides=overrides, _callbacks=None
        )

    def _save_model(self, trainer: TRAINER_TYPE):
        """Save a checkpoint for the current epoch and log it as a W&B artifact."""
        model_checkpoint_artifact = wandb.Artifact(
            f"run_{wandb.run.id}_model", "model", metadata=vars(trainer.args)
        )
        checkpoint_dict = {
            "epoch": trainer.epoch,
            "best_fitness": trainer.best_fitness,
            "model": copy.deepcopy(de_parallel(self.model)).half(),
            "ema": copy.deepcopy(trainer.ema.ema).half(),
            "updates": trainer.ema.updates,
            "optimizer": trainer.optimizer.state_dict(),
            "train_args": vars(trainer.args),
            "date": datetime.now().isoformat(),
            "version": __version__,
        }
        checkpoint_path = trainer.wdir / f"epoch{trainer.epoch}.pt"
        torch.save(checkpoint_dict, checkpoint_path, pickle_module=pickle)
        model_checkpoint_artifact.add_file(checkpoint_path)
        wandb.log_artifact(
            model_checkpoint_artifact, aliases=[f"epoch_{trainer.epoch}"]
        )

    def _plot_validation_results(
        self, dataloader, class_label_map, table, epoch: Optional[int] = None
    ):
        """Dispatch to the task-specific validation plotter and return its table."""
        kwargs = {
            "dataloader": dataloader,
            "model_name": self.model_name,
            "predictor": self.predictor,
            "table": table,
            "max_validation_batches": self.max_validation_batches,
        }
        # `epoch` is only forwarded when set, so the plotters fall back to
        # their own default for plain validation runs.
        if epoch is not None:
            kwargs["epoch"] = epoch
        if self.task == "pose":
            return plot_pose_validation_results(
                class_label_map=class_label_map,
                visualize_skeleton=self.visualize_skeleton,
                **kwargs,
            )
        if self.task == "segment":
            return plot_mask_validation_results(
                class_label_map=class_label_map, **kwargs
            )
        if self.task == "detect":
            return plot_validation_results(class_label_map=class_label_map, **kwargs)
        if self.task == "classify":
            return plot_classification_validation_results(**kwargs)
        return table

    def on_train_start(self, trainer: TRAINER_TYPE):
        """Record the telemetry flag and store the training arguments in the config."""
        with telemetry.context(run=wandb.run) as tel:
            tel.feature.ultralytics_yolov8 = True
        wandb.config.train = vars(trainer.args)

    def on_fit_epoch_end(self, trainer: TRAINER_TYPE):
        """Append this epoch's validation predictions and optionally checkpoint."""
        if self.task not in self.supported_tasks:
            return
        validator = trainer.validator
        dataloader = validator.dataloader
        class_label_map = validator.names
        with torch.no_grad():
            self.device = next(trainer.model.parameters()).device
            if isinstance(trainer.model, torch.nn.parallel.DistributedDataParallel):
                model = trainer.model.module
            else:
                model = trainer.model
            # Predict with an eval-mode copy so the training model is untouched.
            self.model = copy.deepcopy(model).eval().to(self.device)
            self.predictor.setup_model(model=self.model, verbose=False)
            self.train_validation_table = self._plot_validation_results(
                dataloader,
                class_label_map,
                self.train_validation_table,
                epoch=trainer.epoch,
            )
        if self.enable_model_checkpointing:
            self._save_model(trainer)
        self.model.to("cpu")
        trainer.model.to(self.device)

    def on_train_end(self, trainer: TRAINER_TYPE):
        """Log the accumulated train-validation table."""
        if self.task in self.supported_tasks:
            wandb.log({"Train-Validation-Table": self.train_validation_table})

    def on_val_end(self, trainer: VALIDATOR_TYPE):
        """Log predictions on the validation dataloader to the validation table.

        Despite its name, the `trainer` argument is the validator instance.
        """
        if self.task not in self.supported_tasks:
            return
        if self.model is None or self.device is None:
            # The model copy is only captured in `on_fit_epoch_end`; without
            # it there is nothing to run the predictor with.
            wandb.termwarn(
                "Skipping the Weights & Biases validation table: no model has "
                "been captured yet because no training epoch has completed "
                "with train-validation logging enabled."
            )
            return
        validator = trainer
        dataloader = validator.dataloader
        class_label_map = validator.names
        with torch.no_grad():
            self.model.to(self.device)
            self.predictor.setup_model(model=self.model, verbose=False)
            self.validation_table = self._plot_validation_results(
                dataloader, class_label_map, self.validation_table
            )
        wandb.log({"Validation-Table": self.validation_table})

    def on_predict_end(self, predictor: PREDICTOR_TYPE):
        """Store the prediction arguments and log every result to the prediction table."""
        wandb.config.prediction_configs = vars(predictor.args)
        if self.task not in self.supported_tasks:
            return
        for result in tqdm(predictor.results):
            if self.task == "pose":
                self.prediction_table = plot_pose_predictions(
                    result,
                    self.model_name,
                    self.visualize_skeleton,
                    self.prediction_table,
                )
            elif self.task == "segment":
                self.prediction_table = plot_mask_predictions(
                    result, self.model_name, self.prediction_table
                )
            elif self.task == "detect":
                self.prediction_table = plot_predictions(
                    result, self.model_name, self.prediction_table
                )
            elif self.task == "classify":
                self.prediction_table = plot_classification_predictions(
                    result, self.model_name, self.prediction_table
                )

        wandb.log({"Prediction-Table": self.prediction_table})

    @property
    def callbacks(self) -> Dict[str, Callable]:
        """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
        return {
            "on_train_start": self.on_train_start,
            "on_fit_epoch_end": self.on_fit_epoch_end,
            "on_train_end": self.on_train_end,
            "on_val_end": self.on_val_end,
            "on_predict_end": self.on_predict_end,
        }
|
377
|
+
|
378
|
+
|
379
|
+
def add_wandb_callback(
    model: YOLO,
    enable_model_checkpointing: bool = False,
    enable_train_validation_logging: bool = True,
    enable_validation_logging: bool = True,
    enable_prediction_logging: bool = True,
    max_validation_batches: Optional[int] = 1,
    visualize_skeleton: Optional[bool] = True,
):
    """Function to add the `WandBUltralyticsCallback` callback to the `YOLO` model.

    **Usage:**

    ```python
    from ultralytics.models import YOLO
    from wandb.integration.ultralytics.callback import add_wandb_callback

    # initialize YOLO model
    model = YOLO("yolov8n.pt")

    # add wandb callback
    add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)

    # train
    model.train(data="coco128.yaml", epochs=5, imgsz=640)

    # validate
    model.val()

    # perform inference
    model(["img1.jpeg", "img2.jpeg"])
    ```

    Args:
        model: YOLO Model of type `:class:ultralytics.models.YOLO`.
        enable_model_checkpointing: enable logging model checkpoints as
            artifacts at the end of every epoch if set to `True`.
            Checkpoints are written from the `on_fit_epoch_end` hook, so
            they are not logged when `enable_train_validation_logging` is
            `False`.
        enable_train_validation_logging: enable logging the predictions and
            ground-truths as interactive image overlays on the images from
            the validation dataloader to a `wandb.Table` along with
            mean-confidence of the predictions per-class at the end of each
            training epoch.
        enable_validation_logging: enable logging the predictions and
            ground-truths as interactive image overlays on the images from the
            validation dataloader to a `wandb.Table` along with
            mean-confidence of the predictions per-class at the end of
            validation.
        enable_prediction_logging: enable logging the predictions as
            interactive image overlays on the predicted images to a
            `wandb.Table` along with mean-confidence of the predictions
            per-class at the end of each prediction.
        max_validation_batches: maximum number of validation batches to log to
            a table per epoch.
        visualize_skeleton: visualize pose skeleton by drawing lines connecting
            keypoints for human pose.

    Returns:
        The same `model` object that was passed in, with the callbacks
        registered when the process RANK is -1 or 0.
    """
    if RANK not in [-1, 0]:
        wandb.termerror(
            "The RANK of the process to add the callbacks was neither 0 or "
            "-1. No Weights & Biases callbacks were added to this instance "
            "of the YOLO model."
        )
        return model

    wandb_callback = WandBUltralyticsCallback(
        copy.deepcopy(model),
        max_validation_batches,
        enable_model_checkpointing,
        visualize_skeleton,
    )
    callbacks = wandb_callback.callbacks

    # Collect the hooks switched off by the flags, then drop them in one pass.
    disabled_events = []
    if not enable_train_validation_logging:
        disabled_events += ["on_fit_epoch_end", "on_train_end"]
    if not enable_validation_logging:
        disabled_events.append("on_val_end")
    if not enable_prediction_logging:
        disabled_events.append("on_predict_end")
    for event in disabled_events:
        callbacks.pop(event)

    for event, callback_fn in callbacks.items():
        model.add_callback(event, callback_fn)
    return model
|
@@ -0,0 +1,66 @@
|
|
1
|
+
from typing import Any, Optional
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from ultralytics.engine.results import Results
|
5
|
+
from ultralytics.models.yolo.classify import ClassificationPredictor
|
6
|
+
|
7
|
+
import wandb
|
8
|
+
|
9
|
+
|
10
|
+
def plot_classification_predictions(
    result: Results, model_name: str, table: Optional[wandb.Table] = None
):
    """Plot classification prediction results to a `wandb.Table` if the table is passed otherwise return the data.

    Args:
        result: a single classification result; it is moved to the CPU first.
        model_name: value stored in the first column of the row.
        table: table to append the row to, if any.

    Returns:
        `table` when one was passed, otherwise a
        `(class_id_to_label, table_row)` tuple.
    """
    cpu_result = result.to("cpu")
    probs = cpu_result.probs
    per_class_probability = probs.data.numpy().tolist()
    class_id_to_label = {int(k): str(v) for k, v in cpu_result.names.items()}

    image = wandb.Image(cpu_result.orig_img[:, :, ::-1])
    top1_label = class_id_to_label[int(probs.top1)]
    top5_ids = [int(class_idx) for class_idx in list(probs.top5)]
    top5_labels = [class_id_to_label[class_idx] for class_idx in top5_ids]
    top5_probabilities = [per_class_probability[class_idx] for class_idx in top5_ids]
    label_to_probability = {
        class_id_to_label[int(class_idx)]: probability
        for class_idx, probability in enumerate(per_class_probability)
    }

    table_row = [
        model_name,
        image,
        top1_label,
        probs.top1conf,
        top5_labels,
        top5_probabilities,
        label_to_probability,
        cpu_result.speed,
    ]
    if table is None:
        return class_id_to_label, table_row
    table.add_data(*table_row)
    return table
|
35
|
+
|
36
|
+
|
37
|
+
def plot_classification_validation_results(
    dataloader: Any,
    model_name: str,
    predictor: ClassificationPredictor,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: Optional[int] = None,
):
    """Plot classification results to a `wandb.Table`.

    One row is added per image for the first `max_validation_batches`
    batches of `dataloader`. Each row holds the running data index, the batch
    index, the prediction columns and the ground-truth label, prefixed by
    `model_name` and, when given, `epoch`.

    Returns:
        The `table` that was passed in.
    """
    predictor.args.save = False
    predictor.args.show = False
    row_prefix = [model_name] if epoch is None else [model_name, epoch]
    data_idx = 0
    for batch_idx, batch in enumerate(dataloader):
        image_batch = batch["img"].numpy()
        ground_truth = batch["cls"].numpy().tolist()
        for img_idx, channels_first in enumerate(image_batch):
            # Move the channel axis last before handing the image to the predictor.
            image = np.transpose(channels_first, (1, 2, 0))
            prediction_result = predictor(image, show=False)[0]
            class_id_to_label, prediction_row = plot_classification_predictions(
                prediction_result, model_name
            )
            # prediction_row[0] is the model name and [1] the image; the
            # ground-truth label goes directly after the image.
            table.add_data(
                *row_prefix,
                data_idx,
                batch_idx,
                prediction_row[1],
                class_id_to_label[ground_truth[img_idx]],
                *prediction_row[2:],
            )
            data_idx += 1
        if batch_idx + 1 == max_validation_batches:
            break
    return table
|
@@ -0,0 +1,141 @@
|
|
1
|
+
from typing import Dict, Optional, Tuple
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from ultralytics.engine.results import Results
|
5
|
+
from ultralytics.models.yolo.segment import SegmentationPredictor
|
6
|
+
from ultralytics.utils.ops import scale_image
|
7
|
+
|
8
|
+
import wandb
|
9
|
+
|
10
|
+
from .bbox_utils import get_ground_truth_bbox_annotations, get_mean_confidence_map
|
11
|
+
|
12
|
+
|
13
|
+
def instance_mask_to_semantic_mask(instance_mask, class_indices, background_index=0):
    """Collapse per-instance binary masks into a single semantic mask.

    Args:
        instance_mask: array of shape `(height, width, num_instances)`; a
            pixel belongs to instance `i` where `instance_mask[:, :, i] == 1`.
        class_indices: class index of each instance, indexed like the last
            axis of `instance_mask`.
        background_index: value given to pixels that no instance covers.
            The default `0` keeps the historical behaviour, in which
            background cannot be told apart from class index `0`; pass an
            unused index to keep the two distinct.

    Returns:
        A `(height, width)` `uint8` array of class indices. Where instances
        overlap, the instance with the highest position on the last axis wins.

    Raises:
        IndexError: if `class_indices` has fewer entries than there are
            instances.
    """
    height, width, num_instances = instance_mask.shape
    semantic_mask = np.full((height, width), background_index, dtype=np.uint8)
    for instance_idx in range(num_instances):
        covered = instance_mask[:, :, instance_idx] == 1
        semantic_mask[covered] = class_indices[instance_idx]
    return semantic_mask
|
21
|
+
|
22
|
+
|
23
|
+
def get_boxes_and_masks(result: Results) -> Tuple[Dict, Optional[Dict], Dict]:
    """Convert a segmentation result into W&B box and mask overlay dictionaries.

    Args:
        result: a segmentation result whose tensors are already on the CPU
            (they are converted with `.numpy()`).

    Returns:
        A `(boxes, masks, mean_confidence_map)` tuple. `boxes` and `masks`
        are keyed by `"predictions"`; `masks` is `None` when the result has
        no masks. The class-label mapping gets one extra entry,
        `"background"`, whose id is the number of classes.
    """
    boxes = result.boxes.xywh.long().numpy()
    classes = result.boxes.cls.long().numpy()
    confidence = result.boxes.conf.numpy()
    class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
    background_id = len(result.names)
    class_id_to_label[background_id] = "background"
    mean_confidence_map = get_mean_confidence_map(
        classes, confidence, class_id_to_label
    )

    masks = None
    if result.masks is not None:
        scaled_instance_mask = scale_image(
            np.transpose(result.masks.data.numpy(), (1, 2, 0)),
            result.orig_img[:, :, ::-1].shape,
        )
        scaled_semantic_mask = instance_mask_to_semantic_mask(
            scaled_instance_mask, classes.tolist()
        )
        # Background is every pixel that no instance covers. Testing the
        # semantic mask for 0 instead would also relabel real instances of
        # class index 0 as background.
        covered = (scaled_instance_mask == 1).any(axis=-1)
        scaled_semantic_mask[~covered] = background_id
        masks = {
            "predictions": {
                "mask_data": scaled_semantic_mask,
                "class_labels": class_id_to_label,
            }
        }

    box_data = [
        {
            "position": {
                "middle": [int(box[0]), int(box[1])],
                "width": int(box[2]),
                "height": int(box[3]),
            },
            "domain": "pixel",
            "class_id": int(class_id),
            "box_caption": class_id_to_label[int(class_id)],
            "scores": {"confidence": float(score)},
        }
        for box, class_id, score in zip(boxes, classes, confidence)
    ]

    boxes = {
        "predictions": {
            "box_data": box_data,
            "class_labels": class_id_to_label,
        },
    }
    return boxes, masks, mean_confidence_map
|
72
|
+
|
73
|
+
|
74
|
+
def plot_mask_predictions(
    result: Results, model_name: str, table: Optional[wandb.Table] = None
) -> Tuple[wandb.Image, Dict, Dict, Dict]:
    """Build the W&B image overlay for one segmentation result.

    Args:
        result: a single segmentation result; it is moved to the CPU first.
        model_name: value stored in the first column of the row.
        table: table to append a prediction row to, if any.

    Returns:
        `table` when one was passed (note that this differs from the
        annotated return type), otherwise an
        `(image, masks, box_predictions, mean_confidence_map)` tuple.
    """
    cpu_result = result.to("cpu")
    boxes, masks, mean_confidence_map = get_boxes_and_masks(cpu_result)
    image = wandb.Image(cpu_result.orig_img[:, :, ::-1], boxes=boxes, masks=masks)
    if table is None:
        return image, masks, boxes["predictions"], mean_confidence_map
    num_objects = len(boxes["predictions"]["box_data"])
    table.add_data(
        model_name, image, num_objects, mean_confidence_map, cpu_result.speed
    )
    return table
|
90
|
+
|
91
|
+
|
92
|
+
def plot_mask_validation_results(
    dataloader,
    class_label_map,
    model_name: str,
    predictor: SegmentationPredictor,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: Optional[int] = None,
):
    """Log segmentation predictions alongside ground-truth boxes to a `wandb.Table`.

    One row is added per image path in `batch["im_file"]` for the first
    `max_validation_batches` batches of `dataloader`. A row is skipped when
    building or adding it raises `TypeError`.

    Returns:
        The `table` that was passed in.
    """
    row_prefix = [model_name] if epoch is None else [model_name, epoch]
    data_idx = 0
    for batch_idx, batch in enumerate(dataloader):
        for img_idx, image_path in enumerate(batch["im_file"]):
            prediction_result = predictor(image_path)[0]
            prediction = plot_mask_predictions(prediction_result, model_name)
            _, prediction_mask_data, prediction_box_data, mean_confidence_map = (
                prediction
            )
            try:
                ground_truth_data = get_ground_truth_bbox_annotations(
                    img_idx, image_path, batch, class_label_map
                )
                box_overlays = {
                    "ground-truth": {
                        "box_data": ground_truth_data,
                        "class_labels": class_label_map,
                    },
                    "predictions": prediction_box_data,
                }
                wandb_image = wandb.Image(
                    image_path, boxes=box_overlays, masks=prediction_mask_data
                )
                table.add_data(
                    *row_prefix,
                    data_idx,
                    batch_idx,
                    wandb_image,
                    mean_confidence_map,
                    prediction_result.speed,
                )
                data_idx += 1
            except TypeError:
                # Best effort: leave this image out and carry on with the rest.
                continue
        if batch_idx + 1 == max_validation_batches:
            break
    return table
|