wandb 0.15.9__py3-none-any.whl → 0.15.11__py3-none-any.whl

Files changed (114)
  1. wandb/__init__.py +5 -1
  2. wandb/apis/public.py +137 -17
  3. wandb/apis/reports/_panels.py +1 -1
  4. wandb/apis/reports/blocks.py +1 -0
  5. wandb/apis/reports/report.py +27 -5
  6. wandb/cli/cli.py +52 -41
  7. wandb/docker/__init__.py +17 -0
  8. wandb/docker/auth.py +1 -1
  9. wandb/env.py +24 -4
  10. wandb/filesync/step_checksum.py +3 -3
  11. wandb/integration/openai/openai.py +3 -0
  12. wandb/integration/ultralytics/__init__.py +9 -0
  13. wandb/integration/ultralytics/bbox_utils.py +196 -0
  14. wandb/integration/ultralytics/callback.py +458 -0 (usage sketch follows this list)
  15. wandb/integration/ultralytics/classification_utils.py +66 -0
  16. wandb/integration/ultralytics/mask_utils.py +141 -0
  17. wandb/integration/ultralytics/pose_utils.py +92 -0
  18. wandb/integration/xgboost/xgboost.py +3 -3
  19. wandb/integration/yolov8/__init__.py +0 -7
  20. wandb/integration/yolov8/yolov8.py +22 -3
  21. wandb/old/settings.py +7 -0
  22. wandb/plot/line_series.py +0 -1
  23. wandb/proto/v3/wandb_internal_pb2.py +353 -300
  24. wandb/proto/v3/wandb_server_pb2.py +37 -41
  25. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  26. wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
  27. wandb/proto/v4/wandb_internal_pb2.py +272 -260
  28. wandb/proto/v4/wandb_server_pb2.py +37 -40
  29. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  30. wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
  31. wandb/proto/wandb_internal_codegen.py +7 -31
  32. wandb/sdk/artifacts/artifact.py +321 -189
  33. wandb/sdk/artifacts/artifact_cache.py +14 -0
  34. wandb/sdk/artifacts/artifact_manifest.py +5 -4
  35. wandb/sdk/artifacts/artifact_manifest_entry.py +37 -9
  36. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -9
  37. wandb/sdk/artifacts/artifact_saver.py +13 -50
  38. wandb/sdk/artifacts/artifact_ttl.py +6 -0
  39. wandb/sdk/artifacts/artifacts_cache.py +119 -93
  40. wandb/sdk/artifacts/staging.py +25 -0
  41. wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
  42. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -3
  43. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  44. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  45. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +4 -3
  46. wandb/sdk/artifacts/storage_policy.py +4 -2
  47. wandb/sdk/backend/backend.py +0 -16
  48. wandb/sdk/data_types/image.py +3 -1
  49. wandb/sdk/integration_utils/auto_logging.py +38 -13
  50. wandb/sdk/interface/interface.py +16 -135
  51. wandb/sdk/interface/interface_shared.py +9 -147
  52. wandb/sdk/interface/interface_sock.py +0 -26
  53. wandb/sdk/internal/file_pusher.py +20 -3
  54. wandb/sdk/internal/file_stream.py +3 -1
  55. wandb/sdk/internal/handler.py +53 -70
  56. wandb/sdk/internal/internal_api.py +220 -130
  57. wandb/sdk/internal/job_builder.py +41 -37
  58. wandb/sdk/internal/sender.py +7 -25
  59. wandb/sdk/internal/system/assets/disk.py +144 -11
  60. wandb/sdk/internal/system/system_info.py +6 -2
  61. wandb/sdk/launch/__init__.py +5 -0
  62. wandb/sdk/launch/{launch.py → _launch.py} +53 -54
  63. wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
  64. wandb/sdk/launch/_project_spec.py +13 -2
  65. wandb/sdk/launch/agent/agent.py +103 -59
  66. wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
  67. wandb/sdk/launch/builder/build.py +19 -1
  68. wandb/sdk/launch/builder/docker_builder.py +5 -1
  69. wandb/sdk/launch/builder/kaniko_builder.py +5 -1
  70. wandb/sdk/launch/create_job.py +20 -5
  71. wandb/sdk/launch/loader.py +14 -5
  72. wandb/sdk/launch/runner/abstract.py +0 -2
  73. wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
  74. wandb/sdk/launch/runner/kubernetes_runner.py +66 -209
  75. wandb/sdk/launch/runner/local_container.py +5 -2
  76. wandb/sdk/launch/runner/local_process.py +4 -1
  77. wandb/sdk/launch/sweeps/scheduler.py +43 -25
  78. wandb/sdk/launch/sweeps/utils.py +5 -3
  79. wandb/sdk/launch/utils.py +3 -1
  80. wandb/sdk/lib/_settings_toposort_generate.py +3 -9
  81. wandb/sdk/lib/_settings_toposort_generated.py +27 -3
  82. wandb/sdk/lib/_wburls_generated.py +1 -0
  83. wandb/sdk/lib/filenames.py +27 -6
  84. wandb/sdk/lib/filesystem.py +181 -7
  85. wandb/sdk/lib/fsm.py +5 -3
  86. wandb/sdk/lib/gql_request.py +3 -0
  87. wandb/sdk/lib/ipython.py +7 -0
  88. wandb/sdk/lib/wburls.py +1 -0
  89. wandb/sdk/service/port_file.py +2 -15
  90. wandb/sdk/service/server.py +7 -55
  91. wandb/sdk/service/service.py +56 -26
  92. wandb/sdk/service/service_base.py +1 -1
  93. wandb/sdk/service/streams.py +11 -5
  94. wandb/sdk/verify/verify.py +2 -2
  95. wandb/sdk/wandb_init.py +8 -2
  96. wandb/sdk/wandb_manager.py +4 -14
  97. wandb/sdk/wandb_run.py +143 -53
  98. wandb/sdk/wandb_settings.py +148 -35
  99. wandb/testing/relay.py +85 -38
  100. wandb/util.py +87 -4
  101. wandb/wandb_torch.py +24 -38
  102. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/METADATA +48 -23
  103. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/RECORD +107 -103
  104. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/WHEEL +1 -1
  105. wandb/proto/v3/wandb_server_pb2_grpc.py +0 -1422
  106. wandb/proto/v4/wandb_server_pb2_grpc.py +0 -1422
  107. wandb/proto/wandb_server_pb2_grpc.py +0 -8
  108. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +0 -61
  109. wandb/sdk/interface/interface_grpc.py +0 -460
  110. wandb/sdk/service/server_grpc.py +0 -444
  111. wandb/sdk/service/service_grpc.py +0 -73
  112. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
  113. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
  114. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
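
The most visible addition in this release is a dedicated Ultralytics integration (items 12 to 17 above), shown in full in the `callback.py` hunk below. As orientation before the raw diff, here is a minimal usage sketch assembled from the `add_wandb_callback` docstring in that hunk; the `wandb.integration.ultralytics` import path is inferred from the new file locations in this list, and the explicit `wandb.init()` / `wandb.finish()` calls are assumptions (the callback logs to the active `wandb.run`), not something the docstring itself shows.

```python
# Minimal sketch, assuming wandb 0.15.11 and an ultralytics version compatible
# with this integration. The import path is inferred from the new files in this
# diff; the docstring below still shows the older `from wandb.yolov8 import ...` path.
import wandb
from ultralytics import YOLO  # callback.py itself imports YOLO from ultralytics.models

from wandb.integration.ultralytics import add_wandb_callback

wandb.init(project="ultralytics")  # assumed: the callback logs tables and checkpoints to the active run

model = YOLO("yolov8n.pt")

# Attach the W&B callbacks: per-epoch validation tables and (optionally) model checkpoints.
add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)

model.train(data="coco128.yaml", epochs=5, imgsz=640)  # train
model.val()  # validate
model(["img1.jpeg", "img2.jpeg"])  # predict

wandb.finish()
```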
wandb/integration/ultralytics/callback.py (new file)
@@ -0,0 +1,458 @@
+ import copy
+ from datetime import datetime
+ from typing import Callable, Dict, Optional, Union
+
+ try:
+     import dill as pickle
+ except ImportError:
+     import pickle
+
+ import wandb
+ from wandb.sdk.lib import telemetry
+
+ try:
+     import torch
+     from tqdm.auto import tqdm
+     from ultralytics.models import YOLO
+     from ultralytics.models.yolo.classify import (
+         ClassificationPredictor,
+         ClassificationTrainer,
+         ClassificationValidator,
+     )
+     from ultralytics.models.yolo.detect import (
+         DetectionPredictor,
+         DetectionTrainer,
+         DetectionValidator,
+     )
+     from ultralytics.models.yolo.pose import PosePredictor, PoseTrainer, PoseValidator
+     from ultralytics.models.yolo.segment import (
+         SegmentationPredictor,
+         SegmentationTrainer,
+         SegmentationValidator,
+     )
+     from ultralytics.utils.torch_utils import de_parallel
+     from ultralytics.yolo.utils import RANK, __version__
+
+     from wandb.integration.ultralytics.bbox_utils import (
+         plot_predictions,
+         plot_validation_results,
+     )
+     from wandb.integration.ultralytics.classification_utils import (
+         plot_classification_predictions,
+         plot_classification_validation_results,
+     )
+     from wandb.integration.ultralytics.mask_utils import (
+         plot_mask_predictions,
+         plot_mask_validation_results,
+     )
+     from wandb.integration.ultralytics.pose_utils import (
+         plot_pose_predictions,
+         plot_pose_validation_results,
+     )
+ except ImportError as e:
+     wandb.error(e)
+
+
+ TRAINER_TYPE = Union[
+     ClassificationTrainer, DetectionTrainer, SegmentationTrainer, PoseTrainer
+ ]
+ VALIDATOR_TYPE = Union[
+     ClassificationValidator, DetectionValidator, SegmentationValidator, PoseValidator
+ ]
+ PREDICTOR_TYPE = Union[
+     ClassificationPredictor, DetectionPredictor, SegmentationPredictor, PosePredictor
+ ]
+
+
+ class WandBUltralyticsCallback:
+     """Stateful callback for logging to W&B.
+
+     In particular, it will log model checkpoints, predictions, and
+     ground-truth annotations with interactive overlays for bounding boxes
+     to Weights & Biases Tables during training, validation and prediction
+     for an `ultralytics` workflow.
+
+     **Usage:**
+
+     ```python
+     from ultralytics.yolo.engine.model import YOLO
+     from wandb.yolov8 import add_wandb_callback
+
+     # initialize YOLO model
+     model = YOLO("yolov8n.pt")
+
+     # add wandb callback
+     add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)
+
+     # train
+     model.train(data="coco128.yaml", epochs=5, imgsz=640)
+
+     # validate
+     model.val()
+
+     # perform inference
+     model(["img1.jpeg", "img2.jpeg"])
+     ```
+
+     Args:
+         model: YOLO Model of type `:class:ultralytics.yolo.engine.model.YOLO`.
+         max_validation_batches: maximum number of validation batches to log to
+             a table per epoch.
+         enable_model_checkpointing: enable logging model checkpoints as
+             artifacts at the end of every epoch if set to `True`.
+         visualize_skeleton: visualize pose skeleton by drawing lines connecting
+             keypoints for human pose.
+     """
+
+     def __init__(
+         self,
+         model: YOLO,
+         max_validation_batches: int = 1,
+         enable_model_checkpointing: bool = False,
+         visualize_skeleton: bool = False,
+     ) -> None:
+         self.max_validation_batches = max_validation_batches
+         self.enable_model_checkpointing = enable_model_checkpointing
+         self.visualize_skeleton = visualize_skeleton
+         self.task = model.task
+         self.task_map = model.task_map
+         self.model_name = model.overrides["model"].split(".")[0]
+         self._make_tables()
+         self._make_predictor(model)
+         self.supported_tasks = ["detect", "segment", "pose", "classify"]
+
+     def _make_tables(self):
+         if self.task in ["detect", "segment"]:
+             validation_columns = [
+                 "Data-Index",
+                 "Batch-Index",
+                 "Image",
+                 "Mean-Confidence",
+                 "Speed",
+             ]
+             train_columns = ["Epoch"] + validation_columns
+             self.train_validation_table = wandb.Table(
+                 columns=["Model-Name"] + train_columns
+             )
+             self.validation_table = wandb.Table(
+                 columns=["Model-Name"] + validation_columns
+             )
+             self.prediction_table = wandb.Table(
+                 columns=[
+                     "Model-Name",
+                     "Image",
+                     "Num-Objects",
+                     "Mean-Confidence",
+                     "Speed",
+                 ]
+             )
+         elif self.task == "classify":
+             classification_columns = [
+                 "Image",
+                 "Predicted-Category",
+                 "Prediction-Confidence",
+                 "Top-5-Prediction-Categories",
+                 "Top-5-Prediction-Confindence",
+                 "Probabilities",
+                 "Speed",
+             ]
+             validation_columns = ["Data-Index", "Batch-Index"] + classification_columns
+             validation_columns.insert(3, "Ground-Truth-Category")
+             self.train_validation_table = wandb.Table(
+                 columns=["Model-Name", "Epoch"] + validation_columns
+             )
+             self.validation_table = wandb.Table(
+                 columns=["Model-Name"] + validation_columns
+             )
+             self.prediction_table = wandb.Table(
+                 columns=["Model-Name"] + classification_columns
+             )
+         elif self.task == "pose":
+             validation_columns = [
+                 "Data-Index",
+                 "Batch-Index",
+                 "Image-Ground-Truth",
+                 "Image-Prediction",
+                 "Num-Instances",
+                 "Mean-Confidence",
+                 "Speed",
+             ]
+             train_columns = ["Epoch"] + validation_columns
+             self.train_validation_table = wandb.Table(
+                 columns=["Model-Name"] + train_columns
+             )
+             self.validation_table = wandb.Table(
+                 columns=["Model-Name"] + validation_columns
+             )
+             self.prediction_table = wandb.Table(
+                 columns=[
+                     "Model-Name",
+                     "Image-Prediction",
+                     "Num-Instances",
+                     "Mean-Confidence",
+                     "Speed",
+                 ]
+             )
+
+     def _make_predictor(self, model: YOLO):
+         overrides = copy.deepcopy(model.overrides)
+         overrides["conf"] = 0.1
+         self.predictor = self.task_map[self.task]["predictor"](
+             overrides=overrides, _callbacks=None
+         )
+
+     def _save_model(self, trainer: TRAINER_TYPE):
+         model_checkpoint_artifact = wandb.Artifact(
+             f"run_{wandb.run.id}_model", "model", metadata=vars(trainer.args)
+         )
+         checkpoint_dict = {
+             "epoch": trainer.epoch,
+             "best_fitness": trainer.best_fitness,
+             "model": copy.deepcopy(de_parallel(self.model)).half(),
+             "ema": copy.deepcopy(trainer.ema.ema).half(),
+             "updates": trainer.ema.updates,
+             "optimizer": trainer.optimizer.state_dict(),
+             "train_args": vars(trainer.args),
+             "date": datetime.now().isoformat(),
+             "version": __version__,
+         }
+         checkpoint_path = trainer.wdir / f"epoch{trainer.epoch}.pt"
+         torch.save(checkpoint_dict, checkpoint_path, pickle_module=pickle)
+         model_checkpoint_artifact.add_file(checkpoint_path)
+         wandb.log_artifact(
+             model_checkpoint_artifact, aliases=[f"epoch_{trainer.epoch}"]
+         )
+
+     def on_train_start(self, trainer: TRAINER_TYPE):
+         with telemetry.context(run=wandb.run) as tel:
+             tel.feature.ultralytics_yolov8 = True
+         wandb.config.train = vars(trainer.args)
+
+     def on_fit_epoch_end(self, trainer: TRAINER_TYPE):
+         if self.task in self.supported_tasks:
+             validator = trainer.validator
+             dataloader = validator.dataloader
+             class_label_map = validator.names
+             with torch.no_grad():
+                 self.device = next(trainer.model.parameters()).device
+                 if isinstance(trainer.model, torch.nn.parallel.DistributedDataParallel):
+                     model = trainer.model.module
+                 else:
+                     model = trainer.model
+                 self.model = copy.deepcopy(model).eval().to(self.device)
+                 self.predictor.setup_model(model=self.model, verbose=False)
+                 if self.task == "pose":
+                     self.train_validation_table = plot_pose_validation_results(
+                         dataloader=dataloader,
+                         class_label_map=class_label_map,
+                         model_name=self.model_name,
+                         predictor=self.predictor,
+                         visualize_skeleton=self.visualize_skeleton,
+                         table=self.train_validation_table,
+                         max_validation_batches=self.max_validation_batches,
+                         epoch=trainer.epoch,
+                     )
+                 elif self.task == "segment":
+                     self.train_validation_table = plot_mask_validation_results(
+                         dataloader=dataloader,
+                         class_label_map=class_label_map,
+                         model_name=self.model_name,
+                         predictor=self.predictor,
+                         table=self.train_validation_table,
+                         max_validation_batches=self.max_validation_batches,
+                         epoch=trainer.epoch,
+                     )
+                 elif self.task == "detect":
+                     self.train_validation_table = plot_validation_results(
+                         dataloader=dataloader,
+                         class_label_map=class_label_map,
+                         model_name=self.model_name,
+                         predictor=self.predictor,
+                         table=self.train_validation_table,
+                         max_validation_batches=self.max_validation_batches,
+                         epoch=trainer.epoch,
+                     )
+                 elif self.task == "classify":
+                     self.train_validation_table = (
+                         plot_classification_validation_results(
+                             dataloader=dataloader,
+                             model_name=self.model_name,
+                             predictor=self.predictor,
+                             table=self.train_validation_table,
+                             max_validation_batches=self.max_validation_batches,
+                             epoch=trainer.epoch,
+                         )
+                     )
+             if self.enable_model_checkpointing:
+                 self._save_model(trainer)
+             self.model.to("cpu")
+             trainer.model.to(self.device)
+
+     def on_train_end(self, trainer: TRAINER_TYPE):
+         if self.task in self.supported_tasks:
+             wandb.log({"Train-Validation-Table": self.train_validation_table})
+
+     def on_val_end(self, trainer: VALIDATOR_TYPE):
+         if self.task in self.supported_tasks:
+             validator = trainer
+             dataloader = validator.dataloader
+             class_label_map = validator.names
+             with torch.no_grad():
+                 self.model.to(self.device)
+                 self.predictor.setup_model(model=self.model, verbose=False)
+                 if self.task == "pose":
+                     self.validation_table = plot_pose_validation_results(
+                         dataloader=dataloader,
+                         class_label_map=class_label_map,
+                         model_name=self.model_name,
+                         predictor=self.predictor,
+                         visualize_skeleton=self.visualize_skeleton,
+                         table=self.validation_table,
+                         max_validation_batches=self.max_validation_batches,
+                     )
+                 elif self.task == "segment":
+                     self.validation_table = plot_mask_validation_results(
+                         dataloader=dataloader,
+                         class_label_map=class_label_map,
+                         model_name=self.model_name,
+                         predictor=self.predictor,
+                         table=self.validation_table,
+                         max_validation_batches=self.max_validation_batches,
+                     )
+                 elif self.task == "detect":
+                     self.validation_table = plot_validation_results(
+                         dataloader=dataloader,
+                         class_label_map=class_label_map,
+                         model_name=self.model_name,
+                         predictor=self.predictor,
+                         table=self.validation_table,
+                         max_validation_batches=self.max_validation_batches,
+                     )
+                 elif self.task == "classify":
+                     self.validation_table = plot_classification_validation_results(
+                         dataloader=dataloader,
+                         model_name=self.model_name,
+                         predictor=self.predictor,
+                         table=self.validation_table,
+                         max_validation_batches=self.max_validation_batches,
+                     )
+             wandb.log({"Validation-Table": self.validation_table})
+
+     def on_predict_end(self, predictor: PREDICTOR_TYPE):
+         wandb.config.prediction_configs = vars(predictor.args)
+         if self.task in self.supported_tasks:
+             for result in tqdm(predictor.results):
+                 if self.task == "pose":
+                     self.prediction_table = plot_pose_predictions(
+                         result,
+                         self.model_name,
+                         self.visualize_skeleton,
+                         self.prediction_table,
+                     )
+                 elif self.task == "segment":
+                     self.prediction_table = plot_mask_predictions(
+                         result, self.model_name, self.prediction_table
+                     )
+                 elif self.task == "detect":
+                     self.prediction_table = plot_predictions(
+                         result, self.model_name, self.prediction_table
+                     )
+                 elif self.task == "classify":
+                     self.prediction_table = plot_classification_predictions(
+                         result, self.model_name, self.prediction_table
+                     )
+
+             wandb.log({"Prediction-Table": self.prediction_table})
+
+     @property
+     def callbacks(self) -> Dict[str, Callable]:
+         """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
+         return {
+             "on_train_start": self.on_train_start,
+             "on_fit_epoch_end": self.on_fit_epoch_end,
+             "on_train_end": self.on_train_end,
+             "on_val_end": self.on_val_end,
+             "on_predict_end": self.on_predict_end,
+         }
+
+
+ def add_wandb_callback(
+     model: YOLO,
+     enable_model_checkpointing: bool = False,
+     enable_train_validation_logging: bool = True,
+     enable_validation_logging: bool = True,
+     enable_prediction_logging: bool = True,
+     max_validation_batches: Optional[int] = 1,
+     visualize_skeleton: Optional[bool] = True,
+ ):
+     """Function to add the `WandBUltralyticsCallback` callback to the `YOLO` model.
+
+     **Usage:**
+
+     ```python
+     from ultralytics.yolo.engine.model import YOLO
+     from wandb.yolov8 import add_wandb_callback
+
+     # initialize YOLO model
+     model = YOLO("yolov8n.pt")
+
+     # add wandb callback
+     add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)
+
+     # train
+     model.train(data="coco128.yaml", epochs=5, imgsz=640)
+
+     # validate
+     model.val()
+
+     # perform inference
+     model(["img1.jpeg", "img2.jpeg"])
+     ```
+
+     Args:
+         model: YOLO Model of type `:class:ultralytics.yolo.engine.model.YOLO`.
+         enable_model_checkpointing: enable logging model checkpoints as
+             artifacts at the end of every epoch if set to `True`.
+         enable_train_validation_logging: enable logging the predictions and
+             ground-truths as interactive image overlays on the images from
+             the validation dataloader to a `wandb.Table` along with
+             mean-confidence of the predictions per-class at the end of each
+             training epoch.
+         enable_validation_logging: enable logging the predictions and
+             ground-truths as interactive image overlays on the images from the
+             validation dataloader to a `wandb.Table` along with
+             mean-confidence of the predictions per-class at the end of
+             validation.
+         enable_prediction_logging: enable logging the predictions and
+             ground-truths as interactive image overlays on the images from the
+             validation dataloader to a `wandb.Table` along with mean-confidence
+             of the predictions per-class at the end of each prediction.
+         max_validation_batches: maximum number of validation batches to log to
+             a table per epoch.
+         visualize_skeleton: visualize pose skeleton by drawing lines connecting
+             keypoints for human pose.
+     """
+     if RANK in [-1, 0]:
+         wandb_callback = WandBUltralyticsCallback(
+             copy.deepcopy(model),
+             max_validation_batches,
+             enable_model_checkpointing,
+             visualize_skeleton,
+         )
+         callbacks = wandb_callback.callbacks
+         if not enable_train_validation_logging:
+             _ = callbacks.pop("on_fit_epoch_end")
+             _ = callbacks.pop("on_train_end")
+         if not enable_validation_logging:
+             _ = callbacks.pop("on_val_end")
+         if not enable_prediction_logging:
+             _ = callbacks.pop("on_predict_end")
+         for event, callback_fn in callbacks.items():
+             model.add_callback(event, callback_fn)
+     else:
+         wandb.termerror(
+             "The RANK of the process to add the callbacks was neither 0 or "
+             "-1. No Weights & Biases callbacks were added to this instance "
+             "of the YOLO model."
+         )
+     return model
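
The `enable_*` switches documented above map directly onto the callback dictionary that `add_wandb_callback` registers. A hedged sketch of selectively disabling parts of the logging, under the same import-path assumption as the earlier example:

```python
# Sketch only: keep checkpointing and the standalone validation table, but skip
# the per-epoch train-validation table and the prediction table.
from ultralytics import YOLO
from wandb.integration.ultralytics import add_wandb_callback  # assumed export

model = YOLO("yolov8n.pt")
add_wandb_callback(
    model,
    enable_model_checkpointing=True,
    enable_train_validation_logging=False,  # pops "on_fit_epoch_end" and "on_train_end"
    enable_validation_logging=True,         # keeps "on_val_end"
    enable_prediction_logging=False,        # pops "on_predict_end"
)
```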
wandb/integration/ultralytics/classification_utils.py (new file)
@@ -0,0 +1,66 @@
+ from typing import Any, Optional
+
+ import numpy as np
+ from ultralytics.engine.results import Results
+ from ultralytics.models.yolo.classify import ClassificationPredictor
+
+ import wandb
+
+
+ def plot_classification_predictions(
+     result: Results, model_name: str, table: Optional[wandb.Table] = None
+ ):
+     """Plot classification prediction results to a `wandb.Table` if the table is passed otherwise return the data."""
+     result = result.to("cpu")
+     probabilities = result.probs
+     probabilities_list = probabilities.data.numpy().tolist()
+     class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
+     table_row = [
+         model_name,
+         wandb.Image(result.orig_img[:, :, ::-1]),
+         class_id_to_label[int(probabilities.top1)],
+         probabilities.top1conf,
+         [class_id_to_label[int(class_idx)] for class_idx in list(probabilities.top5)],
+         [probabilities_list[int(class_idx)] for class_idx in list(probabilities.top5)],
+         {
+             class_id_to_label[int(class_idx)]: probability
+             for class_idx, probability in enumerate(probabilities_list)
+         },
+         result.speed,
+     ]
+     if table is not None:
+         table.add_data(*table_row)
+         return table
+     return class_id_to_label, table_row
+
+
+ def plot_classification_validation_results(
+     dataloader: Any,
+     model_name: str,
+     predictor: ClassificationPredictor,
+     table: wandb.Table,
+     max_validation_batches: int,
+     epoch: Optional[int] = None,
+ ):
+     """Plot classification results to a `wandb.Table`."""
+     data_idx = 0
+     predictor.args.save = False
+     predictor.args.show = False
+     for batch_idx, batch in enumerate(dataloader):
+         image_batch = batch["img"].numpy()
+         ground_truth = batch["cls"].numpy().tolist()
+         for img_idx in range(image_batch.shape[0]):
+             image = np.transpose(image_batch[img_idx], (1, 2, 0))
+             prediction_result = predictor(image, show=False)[0]
+             class_id_to_label, table_row = plot_classification_predictions(
+                 prediction_result, model_name
+             )
+             table_row = [data_idx, batch_idx] + table_row[1:]
+             table_row.insert(3, class_id_to_label[ground_truth[img_idx]])
+             table_row = [epoch] + table_row if epoch is not None else table_row
+             table_row = [model_name] + table_row
+             table.add_data(*table_row)
+             data_idx += 1
+         if batch_idx + 1 == max_validation_batches:
+             break
+     return table
wandb/integration/ultralytics/mask_utils.py (new file)
@@ -0,0 +1,141 @@
+ from typing import Dict, Optional, Tuple
+
+ import numpy as np
+ from ultralytics.engine.results import Results
+ from ultralytics.models.yolo.segment import SegmentationPredictor
+ from ultralytics.utils.ops import scale_image
+
+ import wandb
+
+ from .bbox_utils import get_ground_truth_bbox_annotations, get_mean_confidence_map
+
+
+ def instance_mask_to_semantic_mask(instance_mask, class_indices):
+     height, width, num_instances = instance_mask.shape
+     semantic_mask = np.zeros((height, width), dtype=np.uint8)
+     for i in range(num_instances):
+         instance_map = instance_mask[:, :, i]
+         class_index = class_indices[i]
+         semantic_mask[instance_map == 1] = class_index
+     return semantic_mask
+
+
+ def get_boxes_and_masks(result: Results) -> Tuple[Dict, Dict, Dict]:
+     boxes = result.boxes.xywh.long().numpy()
+     classes = result.boxes.cls.long().numpy()
+     confidence = result.boxes.conf.numpy()
+     class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
+     class_id_to_label.update({len(result.names.items()): "background"})
+     mean_confidence_map = get_mean_confidence_map(
+         classes, confidence, class_id_to_label
+     )
+     masks = None
+     if result.masks is not None:
+         scaled_instance_mask = scale_image(
+             np.transpose(result.masks.data.numpy(), (1, 2, 0)),
+             result.orig_img[:, :, ::-1].shape,
+         )
+         scaled_semantic_mask = instance_mask_to_semantic_mask(
+             scaled_instance_mask, classes.tolist()
+         )
+         scaled_semantic_mask[scaled_semantic_mask == 0] = len(result.names.items())
+         masks = {
+             "predictions": {
+                 "mask_data": scaled_semantic_mask,
+                 "class_labels": class_id_to_label,
+             }
+         }
+     box_data, total_confidence = [], 0.0
+     for idx in range(len(boxes)):
+         box_data.append(
+             {
+                 "position": {
+                     "middle": [int(boxes[idx][0]), int(boxes[idx][1])],
+                     "width": int(boxes[idx][2]),
+                     "height": int(boxes[idx][3]),
+                 },
+                 "domain": "pixel",
+                 "class_id": int(classes[idx]),
+                 "box_caption": class_id_to_label[int(classes[idx])],
+                 "scores": {"confidence": float(confidence[idx])},
+             }
+         )
+         total_confidence += float(confidence[idx])
+
+     boxes = {
+         "predictions": {
+             "box_data": box_data,
+             "class_labels": class_id_to_label,
+         },
+     }
+     return boxes, masks, mean_confidence_map
+
+
+ def plot_mask_predictions(
+     result: Results, model_name: str, table: Optional[wandb.Table] = None
+ ) -> Tuple[wandb.Image, Dict, Dict, Dict]:
+     result = result.to("cpu")
+     boxes, masks, mean_confidence_map = get_boxes_and_masks(result)
+     image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes, masks=masks)
+     if table is not None:
+         table.add_data(
+             model_name,
+             image,
+             len(boxes["predictions"]["box_data"]),
+             mean_confidence_map,
+             result.speed,
+         )
+         return table
+     return image, masks, boxes["predictions"], mean_confidence_map
+
+
+ def plot_mask_validation_results(
+     dataloader,
+     class_label_map,
+     model_name: str,
+     predictor: SegmentationPredictor,
+     table: wandb.Table,
+     max_validation_batches: int,
+     epoch: Optional[int] = None,
+ ):
+     data_idx = 0
+     for batch_idx, batch in enumerate(dataloader):
+         for img_idx, image_path in enumerate(batch["im_file"]):
+             prediction_result = predictor(image_path)[0]
+             (
+                 _,
+                 prediction_mask_data,
+                 prediction_box_data,
+                 mean_confidence_map,
+             ) = plot_mask_predictions(prediction_result, model_name)
+             try:
+                 ground_truth_data = get_ground_truth_bbox_annotations(
+                     img_idx, image_path, batch, class_label_map
+                 )
+                 wandb_image = wandb.Image(
+                     image_path,
+                     boxes={
+                         "ground-truth": {
+                             "box_data": ground_truth_data,
+                             "class_labels": class_label_map,
+                         },
+                         "predictions": prediction_box_data,
+                     },
+                     masks=prediction_mask_data,
+                 )
+                 table_rows = [
+                     data_idx,
+                     batch_idx,
+                     wandb_image,
+                     mean_confidence_map,
+                     prediction_result.speed,
+                 ]
+                 table_rows = [epoch] + table_rows if epoch is not None else table_rows
+                 table_rows = [model_name] + table_rows
+                 table.add_data(*table_rows)
+                 data_idx += 1
+             except TypeError:
+                 pass
+         if batch_idx + 1 == max_validation_batches:
+             break
+     return table
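
The helpers above are also usable outside the callback, since `plot_classification_predictions` simply appends one row per `Results` object when given a table. A hedged sketch, with column names copied from `_make_tables` in `callback.py` (including its `Confindence` spelling) and a classification checkpoint so that `result.probs` is populated:

```python
import wandb
from ultralytics import YOLO
from wandb.integration.ultralytics.classification_utils import (
    plot_classification_predictions,
)

wandb.init(project="ultralytics")

# Columns mirror the classification prediction table built in callback.py.
table = wandb.Table(
    columns=[
        "Model-Name",
        "Image",
        "Predicted-Category",
        "Prediction-Confidence",
        "Top-5-Prediction-Categories",
        "Top-5-Prediction-Confindence",
        "Probabilities",
        "Speed",
    ]
)

model = YOLO("yolov8n-cls.pt")  # classification weights, so each result has probs
for result in model(["img1.jpeg", "img2.jpeg"]):
    table = plot_classification_predictions(result, "yolov8n-cls", table)

wandb.log({"Prediction-Table": table})
wandb.finish()
```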