ultralytics-8.3.153-py3-none-any.whl → ultralytics-8.3.155-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. tests/test_python.py +1 -0
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +2 -0
  4. ultralytics/engine/predictor.py +1 -1
  5. ultralytics/engine/validator.py +0 -6
  6. ultralytics/models/fastsam/val.py +0 -2
  7. ultralytics/models/rtdetr/val.py +28 -16
  8. ultralytics/models/yolo/classify/val.py +26 -23
  9. ultralytics/models/yolo/detect/train.py +4 -7
  10. ultralytics/models/yolo/detect/val.py +88 -90
  11. ultralytics/models/yolo/obb/val.py +52 -44
  12. ultralytics/models/yolo/pose/train.py +1 -35
  13. ultralytics/models/yolo/pose/val.py +77 -176
  14. ultralytics/models/yolo/segment/train.py +1 -41
  15. ultralytics/models/yolo/segment/val.py +64 -176
  16. ultralytics/models/yolo/yoloe/val.py +2 -1
  17. ultralytics/nn/autobackend.py +2 -2
  18. ultralytics/nn/tasks.py +0 -1
  19. ultralytics/solutions/ai_gym.py +5 -5
  20. ultralytics/solutions/analytics.py +2 -2
  21. ultralytics/solutions/config.py +2 -2
  22. ultralytics/solutions/distance_calculation.py +1 -1
  23. ultralytics/solutions/heatmap.py +5 -3
  24. ultralytics/solutions/instance_segmentation.py +4 -2
  25. ultralytics/solutions/object_blurrer.py +4 -2
  26. ultralytics/solutions/object_counter.py +5 -5
  27. ultralytics/solutions/object_cropper.py +3 -2
  28. ultralytics/solutions/parking_management.py +9 -9
  29. ultralytics/solutions/queue_management.py +4 -2
  30. ultralytics/solutions/region_counter.py +13 -5
  31. ultralytics/solutions/security_alarm.py +6 -4
  32. ultralytics/solutions/similarity_search.py +6 -6
  33. ultralytics/solutions/solutions.py +9 -7
  34. ultralytics/solutions/speed_estimation.py +3 -2
  35. ultralytics/solutions/streamlit_inference.py +6 -6
  36. ultralytics/solutions/templates/similarity-search.html +31 -0
  37. ultralytics/solutions/trackzone.py +4 -2
  38. ultralytics/solutions/vision_eye.py +4 -2
  39. ultralytics/utils/callbacks/comet.py +1 -1
  40. ultralytics/utils/metrics.py +146 -317
  41. ultralytics/utils/ops.py +4 -4
  42. ultralytics/utils/plotting.py +31 -56
  43. {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/METADATA +1 -1
  44. {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/RECORD +48 -48
  45. {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/WHEEL +0 -0
  46. {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/entry_points.txt +0 -0
  47. {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/licenses/LICENSE +0 -0
  48. {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from pathlib import Path
4
- from typing import Any, Dict, List, Optional, Tuple
4
+ from typing import Any, Dict, Tuple
5
5
 
6
6
  import numpy as np
7
7
  import torch
@@ -9,8 +9,7 @@ import torch
9
9
  from ultralytics.models.yolo.detect import DetectionValidator
10
10
  from ultralytics.utils import LOGGER, ops
11
11
  from ultralytics.utils.checks import check_requirements
12
- from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
13
- from ultralytics.utils.plotting import output_to_target, plot_images
12
+ from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
14
13
 
15
14
 
16
15
  class PoseValidator(DetectionValidator):
@@ -33,7 +32,6 @@ class PoseValidator(DetectionValidator):
33
32
  _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
34
33
  dimensions.
35
34
  _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
36
- update_metrics: Update metrics with new predictions and ground truth data.
37
35
  _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
38
36
  detections and ground truth.
39
37
  plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
@@ -77,7 +75,7 @@ class PoseValidator(DetectionValidator):
77
75
  self.sigma = None
78
76
  self.kpt_shape = None
79
77
  self.args.task = "pose"
80
- self.metrics = PoseMetrics(save_dir=self.save_dir)
78
+ self.metrics = PoseMetrics()
81
79
  if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
82
80
  LOGGER.warning(
83
81
  "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
@@ -118,7 +116,36 @@ class PoseValidator(DetectionValidator):
118
116
  is_pose = self.kpt_shape == [17, 3]
119
117
  nkpt = self.kpt_shape[0]
120
118
  self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
121
- self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
119
+
120
+ def postprocess(self, preds: torch.Tensor) -> Dict[str, torch.Tensor]:
121
+ """
122
+ Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
123
+
124
+ This method extends the parent class postprocessing by extracting keypoints from the 'extra'
125
+ field of predictions and reshaping them according to the keypoint shape configuration.
126
+ The keypoints are reshaped from a flattened format to the proper dimensional structure
127
+ (typically [N, 17, 3] for COCO pose format).
128
+
129
+ Args:
130
+ preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
131
+ bounding boxes, confidence scores, class predictions, and keypoint data.
132
+
133
+ Returns:
134
+ (Dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
135
+ - 'bboxes': Bounding box coordinates
136
+ - 'conf': Confidence scores
137
+ - 'cls': Class predictions
138
+ - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
139
+
140
+ Note:
141
+ If no keypoints are present in a prediction (empty keypoints), that prediction
142
+ is skipped and continues to the next one. The keypoints are extracted from the
143
+ 'extra' field which contains additional task-specific data beyond basic detection.
144
+ """
145
+ preds = super().postprocess(preds)
146
+ for pred in preds:
147
+ pred["keypoints"] = pred.pop("extra").reshape(-1, *self.kpt_shape) # remove extra if exists
148
+ return preds
122
149
 
123
150
  def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
124
151
  """
@@ -142,10 +169,10 @@ class PoseValidator(DetectionValidator):
142
169
  kpts[..., 0] *= w
143
170
  kpts[..., 1] *= h
144
171
  kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
145
- pbatch["kpts"] = kpts
172
+ pbatch["keypoints"] = kpts
146
173
  return pbatch
147
174
 
148
- def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
175
+ def _prepare_pred(self, pred: Dict[str, Any], pbatch: Dict[str, Any]) -> Dict[str, Any]:
149
176
  """
150
177
  Prepare and scale keypoints in predictions for pose processing.
151
178
 
@@ -154,189 +181,59 @@ class PoseValidator(DetectionValidator):
154
181
  to match the original image dimensions.
155
182
 
156
183
  Args:
157
- pred (torch.Tensor): Raw prediction tensor from the model.
184
+ pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
158
185
  pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:
159
186
  - imgsz: Image size used for inference
160
187
  - ori_shape: Original image shape
161
188
  - ratio_pad: Ratio and padding information for coordinate scaling
162
189
 
163
190
  Returns:
164
- predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
165
- pred_kpts (torch.Tensor): Predicted keypoints scaled to original image dimensions.
191
+ (Dict[str, Any]): Processed prediction dictionary with keypoints scaled to original image dimensions.
166
192
  """
167
193
  predn = super()._prepare_pred(pred, pbatch)
168
- nk = pbatch["kpts"].shape[1]
169
- pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
170
- ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
171
- return predn, pred_kpts
172
-
173
- def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
174
- """
175
- Update metrics with new predictions and ground truth data.
176
-
177
- This method processes each prediction, compares it with ground truth, and updates various statistics
178
- for performance evaluation.
194
+ predn["keypoints"] = ops.scale_coords(
195
+ pbatch["imgsz"], pred.get("keypoints").clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
196
+ )
197
+ return predn
179
198
 
180
- Args:
181
- preds (List[torch.Tensor]): List of prediction tensors from the model.
182
- batch (Dict[str, Any]): Batch data containing images and ground truth annotations.
183
- """
184
- for si, pred in enumerate(preds):
185
- self.seen += 1
186
- npr = len(pred)
187
- stat = dict(
188
- conf=torch.zeros(0, device=self.device),
189
- pred_cls=torch.zeros(0, device=self.device),
190
- tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
191
- tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
192
- )
193
- pbatch = self._prepare_batch(si, batch)
194
- cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
195
- nl = len(cls)
196
- stat["target_cls"] = cls
197
- stat["target_img"] = cls.unique()
198
- if npr == 0:
199
- if nl:
200
- for k in self.stats.keys():
201
- self.stats[k].append(stat[k])
202
- if self.args.plots:
203
- self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
204
- continue
205
-
206
- # Predictions
207
- if self.args.single_cls:
208
- pred[:, 5] = 0
209
- predn, pred_kpts = self._prepare_pred(pred, pbatch)
210
- stat["conf"] = predn[:, 4]
211
- stat["pred_cls"] = predn[:, 5]
212
-
213
- # Evaluate
214
- if nl:
215
- stat["tp"] = self._process_batch(predn, bbox, cls)
216
- stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
217
- if self.args.plots:
218
- self.confusion_matrix.process_batch(predn, bbox, cls)
219
-
220
- for k in self.stats.keys():
221
- self.stats[k].append(stat[k])
222
-
223
- # Save
224
- if self.args.save_json:
225
- self.pred_to_json(predn, batch["im_file"][si])
226
- if self.args.save_txt:
227
- self.save_one_txt(
228
- predn,
229
- pred_kpts,
230
- self.args.save_conf,
231
- pbatch["ori_shape"],
232
- self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
233
- )
234
-
235
- def _process_batch(
236
- self,
237
- detections: torch.Tensor,
238
- gt_bboxes: torch.Tensor,
239
- gt_cls: torch.Tensor,
240
- pred_kpts: Optional[torch.Tensor] = None,
241
- gt_kpts: Optional[torch.Tensor] = None,
242
- ) -> torch.Tensor:
199
+ def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
243
200
  """
244
201
  Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
245
202
 
246
203
  Args:
247
- detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
248
- detection is of the format (x1, y1, x2, y2, conf, class).
249
- gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
250
- box is of the format (x1, y1, x2, y2).
251
- gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
252
- pred_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing predicted keypoints, where
253
- 51 corresponds to 17 keypoints each having 3 values.
254
- gt_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing ground truth keypoints.
204
+ preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
205
+ and 'keypoints' for keypoint predictions.
206
+ batch (Dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
207
+ 'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
255
208
 
256
209
  Returns:
257
- (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
258
- where N is the number of detections.
210
+ (Dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
211
+ true positives across 10 IoU levels.
259
212
 
260
213
  Notes:
261
214
  `0.53` scale factor used in area computation is referenced from
262
215
  https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
263
216
  """
264
- if pred_kpts is not None and gt_kpts is not None:
217
+ tp = super()._process_batch(preds, batch)
218
+ gt_cls = batch["cls"]
219
+ if len(gt_cls) == 0 or len(preds["cls"]) == 0:
220
+ tp_p = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
221
+ else:
265
222
  # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
266
- area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
267
- iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
268
- else: # boxes
269
- iou = box_iou(gt_bboxes, detections[:, :4])
223
+ area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
224
+ iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
225
+ tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
226
+ tp.update({"tp_p": tp_p}) # update tp with kpts IoU
227
+ return tp
270
228
 
271
- return self.match_predictions(detections[:, 5], gt_cls, iou)
272
-
273
- def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
274
- """
275
- Plot and save validation set samples with ground truth bounding boxes and keypoints.
276
-
277
- Args:
278
- batch (Dict[str, Any]): Dictionary containing batch data with keys:
279
- - img (torch.Tensor): Batch of images
280
- - batch_idx (torch.Tensor): Batch indices for each image
281
- - cls (torch.Tensor): Class labels
282
- - bboxes (torch.Tensor): Bounding box coordinates
283
- - keypoints (torch.Tensor): Keypoint coordinates
284
- - im_file (list): List of image file paths
285
- ni (int): Batch index used for naming the output file
286
- """
287
- plot_images(
288
- batch["img"],
289
- batch["batch_idx"],
290
- batch["cls"].squeeze(-1),
291
- batch["bboxes"],
292
- kpts=batch["keypoints"],
293
- paths=batch["im_file"],
294
- fname=self.save_dir / f"val_batch{ni}_labels.jpg",
295
- names=self.names,
296
- on_plot=self.on_plot,
297
- )
298
-
299
- def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
300
- """
301
- Plot and save model predictions with bounding boxes and keypoints.
302
-
303
- Args:
304
- batch (Dict[str, Any]): Dictionary containing batch data including images, file paths, and other metadata.
305
- preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
306
- confidence scores, class predictions, and keypoints.
307
- ni (int): Batch index used for naming the output file.
308
-
309
- The function extracts keypoints from predictions, converts predictions to target format, and plots them
310
- on the input images. The resulting visualization is saved to the specified save directory.
311
- """
312
- pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
313
- plot_images(
314
- batch["img"],
315
- *output_to_target(preds, max_det=self.args.max_det),
316
- kpts=pred_kpts,
317
- paths=batch["im_file"],
318
- fname=self.save_dir / f"val_batch{ni}_pred.jpg",
319
- names=self.names,
320
- on_plot=self.on_plot,
321
- ) # pred
322
-
323
- def save_one_txt(
324
- self,
325
- predn: torch.Tensor,
326
- pred_kpts: torch.Tensor,
327
- save_conf: bool,
328
- shape: Tuple[int, int],
329
- file: Path,
330
- ) -> None:
229
+ def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
331
230
  """
332
231
  Save YOLO pose detections to a text file in normalized coordinates.
333
232
 
334
233
  Args:
335
- predn (torch.Tensor): Prediction boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).
336
- pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
337
- and D is the dimension (typically 3 for x, y, visibility).
234
+ predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
338
235
  save_conf (bool): Whether to save confidence scores.
339
- shape (tuple): Original image shape (height, width).
236
+ shape (Tuple[int, int]): Shape of the original image (height, width).
340
237
  file (Path): Output file path to save detections.
341
238
 
342
239
  Notes:
@@ -349,11 +246,11 @@ class PoseValidator(DetectionValidator):
349
246
  np.zeros((shape[0], shape[1]), dtype=np.uint8),
350
247
  path=None,
351
248
  names=self.names,
352
- boxes=predn[:, :6],
353
- keypoints=pred_kpts,
249
+ boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
250
+ keypoints=predn["keypoints"],
354
251
  ).save_txt(file, save_conf=save_conf)
355
252
 
356
- def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
253
+ def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
357
254
  """
358
255
  Convert YOLO predictions to COCO JSON format.
359
256
 
@@ -361,10 +258,9 @@ class PoseValidator(DetectionValidator):
361
258
  to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
362
259
 
363
260
  Args:
364
- predn (torch.Tensor): Prediction tensor containing bounding boxes, confidence scores, class IDs,
365
- and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the flattened
366
- keypoints dimension.
367
- filename (str | Path): Path to the image file for which predictions are being processed.
261
+ predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
262
+ and 'keypoints' tensors.
263
+ filename (str): Path to the image file for which predictions are being processed.
368
264
 
369
265
  Notes:
370
266
  The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
@@ -373,16 +269,21 @@ class PoseValidator(DetectionValidator):
373
269
  """
374
270
  stem = Path(filename).stem
375
271
  image_id = int(stem) if stem.isnumeric() else stem
376
- box = ops.xyxy2xywh(predn[:, :4]) # xywh
272
+ box = ops.xyxy2xywh(predn["bboxes"]) # xywh
377
273
  box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
378
- for p, b in zip(predn.tolist(), box.tolist()):
274
+ for b, s, c, k in zip(
275
+ box.tolist(),
276
+ predn["conf"].tolist(),
277
+ predn["cls"].tolist(),
278
+ predn["keypoints"].flatten(1, 2).tolist(),
279
+ ):
379
280
  self.jdict.append(
380
281
  {
381
282
  "image_id": image_id,
382
- "category_id": self.class_map[int(p[5])],
283
+ "category_id": self.class_map[int(c)],
383
284
  "bbox": [round(x, 3) for x in b],
384
- "keypoints": p[6:],
385
- "score": round(p[4], 5),
285
+ "keypoints": k,
286
+ "score": round(s, 5),
386
287
  }
387
288
  )
388
289
 
@@ -7,7 +7,7 @@ from typing import Dict, Optional, Union
7
7
  from ultralytics.models import yolo
8
8
  from ultralytics.nn.tasks import SegmentationModel
9
9
  from ultralytics.utils import DEFAULT_CFG, RANK
10
- from ultralytics.utils.plotting import plot_images, plot_results
10
+ from ultralytics.utils.plotting import plot_results
11
11
 
12
12
 
13
13
  class SegmentationTrainer(yolo.detect.DetectionTrainer):
@@ -82,46 +82,6 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
82
82
  self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
83
83
  )
84
84
 
85
- def plot_training_samples(self, batch: Dict, ni: int):
86
- """
87
- Plot training sample images with labels, bounding boxes, and masks.
88
-
89
- This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
90
- and segmentation masks, saving the result to a file for inspection and debugging.
91
-
92
- Args:
93
- batch (dict): Dictionary containing batch data with the following keys:
94
- 'img': Images tensor
95
- 'batch_idx': Batch indices for each box
96
- 'cls': Class labels tensor (squeezed to remove last dimension)
97
- 'bboxes': Bounding box coordinates tensor
98
- 'masks': Segmentation masks tensor
99
- 'im_file': List of image file paths
100
- ni (int): Current training iteration number, used for naming the output file.
101
-
102
- Examples:
103
- >>> trainer = SegmentationTrainer()
104
- >>> batch = {
105
- ... "img": torch.rand(16, 3, 640, 640),
106
- ... "batch_idx": torch.zeros(16),
107
- ... "cls": torch.randint(0, 80, (16, 1)),
108
- ... "bboxes": torch.rand(16, 4),
109
- ... "masks": torch.rand(16, 640, 640),
110
- ... "im_file": ["image1.jpg", "image2.jpg"],
111
- ... }
112
- >>> trainer.plot_training_samples(batch, ni=5)
113
- """
114
- plot_images(
115
- batch["img"],
116
- batch["batch_idx"],
117
- batch["cls"].squeeze(-1),
118
- batch["bboxes"],
119
- masks=batch["masks"],
120
- paths=batch["im_file"],
121
- fname=self.save_dir / f"train_batch{ni}.jpg",
122
- on_plot=self.on_plot,
123
- )
124
-
125
85
  def plot_metrics(self):
126
86
  """Plot training/validation metrics."""
127
87
  plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png