ultralytics 8.3.153__py3-none-any.whl → 8.3.155__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their respective public registries, and is provided for informational purposes only.
- tests/test_python.py +1 -0
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +2 -0
- ultralytics/engine/predictor.py +1 -1
- ultralytics/engine/validator.py +0 -6
- ultralytics/models/fastsam/val.py +0 -2
- ultralytics/models/rtdetr/val.py +28 -16
- ultralytics/models/yolo/classify/val.py +26 -23
- ultralytics/models/yolo/detect/train.py +4 -7
- ultralytics/models/yolo/detect/val.py +88 -90
- ultralytics/models/yolo/obb/val.py +52 -44
- ultralytics/models/yolo/pose/train.py +1 -35
- ultralytics/models/yolo/pose/val.py +77 -176
- ultralytics/models/yolo/segment/train.py +1 -41
- ultralytics/models/yolo/segment/val.py +64 -176
- ultralytics/models/yolo/yoloe/val.py +2 -1
- ultralytics/nn/autobackend.py +2 -2
- ultralytics/nn/tasks.py +0 -1
- ultralytics/solutions/ai_gym.py +5 -5
- ultralytics/solutions/analytics.py +2 -2
- ultralytics/solutions/config.py +2 -2
- ultralytics/solutions/distance_calculation.py +1 -1
- ultralytics/solutions/heatmap.py +5 -3
- ultralytics/solutions/instance_segmentation.py +4 -2
- ultralytics/solutions/object_blurrer.py +4 -2
- ultralytics/solutions/object_counter.py +5 -5
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +4 -2
- ultralytics/solutions/region_counter.py +13 -5
- ultralytics/solutions/security_alarm.py +6 -4
- ultralytics/solutions/similarity_search.py +6 -6
- ultralytics/solutions/solutions.py +9 -7
- ultralytics/solutions/speed_estimation.py +3 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/templates/similarity-search.html +31 -0
- ultralytics/solutions/trackzone.py +4 -2
- ultralytics/solutions/vision_eye.py +4 -2
- ultralytics/utils/callbacks/comet.py +1 -1
- ultralytics/utils/metrics.py +146 -317
- ultralytics/utils/ops.py +4 -4
- ultralytics/utils/plotting.py +31 -56
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/METADATA +1 -1
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/RECORD +48 -48
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.155.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/pose/val.py

@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Tuple
 
 import numpy as np
 import torch
@@ -9,8 +9,7 @@ import torch
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
 
 
 class PoseValidator(DetectionValidator):
@@ -33,7 +32,6 @@ class PoseValidator(DetectionValidator):
         _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
             dimensions.
         _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
-        update_metrics: Update metrics with new predictions and ground truth data.
         _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
             detections and ground truth.
         plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
@@ -77,7 +75,7 @@ class PoseValidator(DetectionValidator):
         self.sigma = None
         self.kpt_shape = None
         self.args.task = "pose"
-        self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+        self.metrics = PoseMetrics()
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
             LOGGER.warning(
                 "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
@@ -118,7 +116,36 @@ class PoseValidator(DetectionValidator):
         is_pose = self.kpt_shape == [17, 3]
         nkpt = self.kpt_shape[0]
         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
-
+
+    def postprocess(self, preds: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """
+        Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
+
+        This method extends the parent class postprocessing by extracting keypoints from the 'extra'
+        field of predictions and reshaping them according to the keypoint shape configuration.
+        The keypoints are reshaped from a flattened format to the proper dimensional structure
+        (typically [N, 17, 3] for COCO pose format).
+
+        Args:
+            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
+                bounding boxes, confidence scores, class predictions, and keypoint data.
+
+        Returns:
+            (Dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
+                - 'bboxes': Bounding box coordinates
+                - 'conf': Confidence scores
+                - 'cls': Class predictions
+                - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
+
+        Note:
+            If no keypoints are present in a prediction (empty keypoints), that prediction
+            is skipped and continues to the next one. The keypoints are extracted from the
+            'extra' field which contains additional task-specific data beyond basic detection.
+        """
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["keypoints"] = pred.pop("extra").reshape(-1, *self.kpt_shape)  # remove extra if exists
+        return preds
 
     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
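The dict-based prediction format introduced by this `postprocess` override replaces the old flat (N, 6 + K) tensors that the removed code below still assumes. A minimal sketch of the new shape contract, using toy tensors rather than real model output:

```python
import torch

# Hypothetical post-processed prediction for one image (not real model output):
# 2 detections, COCO-style 17 keypoints with (x, y, visibility) each.
pred = {
    "bboxes": torch.rand(2, 4),  # xyxy boxes
    "conf": torch.rand(2),       # confidence per detection
    "cls": torch.zeros(2),       # class index per detection
    "extra": torch.rand(2, 51),  # flattened keypoints, 17 * 3 = 51 values
}
kpt_shape = (17, 3)
# The new PoseValidator.postprocess pops 'extra' and reshapes it:
pred["keypoints"] = pred.pop("extra").reshape(-1, *kpt_shape)
print(pred["keypoints"].shape)  # torch.Size([2, 17, 3])
```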
@@ -142,10 +169,10 @@ class PoseValidator(DetectionValidator):
             kpts[..., 0] *= w
             kpts[..., 1] *= h
             kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-            pbatch["kpts"] = kpts
+            pbatch["keypoints"] = kpts
         return pbatch
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _prepare_pred(self, pred: Dict[str, Any], pbatch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Prepare and scale keypoints in predictions for pose processing.
 
@@ -154,117 +181,48 @@ class PoseValidator(DetectionValidator):
         to match the original image dimensions.
 
         Args:
-            pred (torch.Tensor): Raw prediction tensor with bounding boxes, scores, classes, and keypoints.
+            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
             pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:
                 - imgsz: Image size used for inference
                 - ori_shape: Original image shape
                 - ratio_pad: Ratio and padding information for coordinate scaling
 
         Returns:
-            predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
-            pred_kpts (torch.Tensor): Predicted keypoints scaled to original image dimensions.
+            (Dict[str, Any]): Processed prediction dictionary with keypoints scaled to original image dimensions.
         """
         predn = super()._prepare_pred(pred, pbatch)
-        nk = pbatch["kpts"].shape[1]
-        pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
-        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-        return predn, pred_kpts
-
-    def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
-        """
-        Update metrics with new predictions and ground truth data.
-
-        This method processes each prediction, compares it with ground truth, and updates various statistics
-        for performance evaluation.
+        predn["keypoints"] = ops.scale_coords(
+            pbatch["imgsz"], pred.get("keypoints").clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+        )
+        return predn
 
-        Args:
-            preds (List[torch.Tensor]): List of prediction tensors from the model.
-            batch (Dict[str, Any]): Batch data containing images and ground truth annotations.
-        """
-        for si, pred in enumerate(preds):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-            )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_kpts = self._prepare_pred(pred, pbatch)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
-            if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(predn, batch["im_file"][si])
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_kpts,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def _process_batch(
-        self,
-        detections: torch.Tensor,
-        gt_bboxes: torch.Tensor,
-        gt_cls: torch.Tensor,
-        pred_kpts: Optional[torch.Tensor] = None,
-        gt_kpts: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
 
         Args:
-            detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
-                detection is of the format (x1, y1, x2, y2, conf, class).
-            gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
-                box is of the format (x1, y1, x2, y2).
-            gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
-            pred_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing predicted keypoints, where
-                51 corresponds to 17 keypoints each having 3 values.
-            gt_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing ground truth keypoints.
+            preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
+                and 'keypoints' for keypoint predictions.
+            batch (Dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
+                'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
 
         Returns:
-            (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
-                where N is the number of detections.
+            (Dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
+                true positives across 10 IoU levels.
 
         Notes:
             `0.53` scale factor used in area computation is referenced from
             https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
         """
-        if pred_kpts is not None and gt_kpts is not None:
+        tp = super()._process_batch(preds, batch)
+        gt_cls = batch["cls"]
+        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
+            tp_p = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        else:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
-            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
-            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
+            area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
+            iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
+            tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_p": tp_p})  # update tp with kpts IoU
+        return tp
 
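The new `tp_p` entry mirrors how COCO's OKS metric normalizes keypoint distance by object scale; the 0.53 factor comes from xtcocotools, as the docstring notes. A self-contained sketch of the same computation using the public ultralytics helpers (toy tensors, not validator internals):

```python
import torch
from ultralytics.utils import ops
from ultralytics.utils.metrics import OKS_SIGMA, kpt_iou

# Toy ground truth: 3 boxes (xyxy) and matching 17-point keypoints.
gt_boxes = torch.tensor([[0.0, 0.0, 100.0, 200.0],
                         [50.0, 50.0, 150.0, 250.0],
                         [10.0, 10.0, 90.0, 120.0]])
gt_kpts = torch.rand(3, 17, 3) * 100
pred_kpts = gt_kpts + torch.randn(3, 17, 3)  # slightly perturbed "predictions"

# Same recipe as the new _process_batch: box area scaled by 0.53,
# the factor xtcocotools uses for OKS normalization.
area = ops.xyxy2xywh(gt_boxes)[:, 2:].prod(1) * 0.53
iou = kpt_iou(gt_kpts, pred_kpts, sigma=OKS_SIGMA, area=area)
print(iou.shape)  # torch.Size([3, 3]): OKS between every GT/prediction pair
```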
@@ -271,72 +229,11 @@ class PoseValidator(DetectionValidator):
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
-
-    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
-        """
-        Plot and save validation set samples with ground truth bounding boxes and keypoints.
-
-        Args:
-            batch (Dict[str, Any]): Dictionary containing batch data with keys:
-                - img (torch.Tensor): Batch of images
-                - batch_idx (torch.Tensor): Batch indices for each image
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - keypoints (torch.Tensor): Keypoint coordinates
-                - im_file (list): List of image file paths
-            ni (int): Batch index used for naming the output file
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            kpts=batch["keypoints"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
-
-    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
-        """
-        Plot and save model predictions with bounding boxes and keypoints.
-
-        Args:
-            batch (Dict[str, Any]): Dictionary containing batch data including images, file paths, and other metadata.
-            preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
-                confidence scores, class predictions, and keypoints.
-            ni (int): Batch index used for naming the output file.
-
-        The function extracts keypoints from predictions, converts predictions to target format, and plots them
-        on the input images. The resulting visualization is saved to the specified save directory.
-        """
-        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
-        plot_images(
-            batch["img"],
-            *output_to_target(preds, max_det=self.args.max_det),
-            kpts=pred_kpts,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
-
-    def save_one_txt(
-        self,
-        predn: torch.Tensor,
-        pred_kpts: torch.Tensor,
-        save_conf: bool,
-        shape: Tuple[int, int],
-        file: Path,
-    ) -> None:
+    def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         Save YOLO pose detections to a text file in normalized coordinates.
 
         Args:
-            predn (torch.Tensor): Predicted boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).
-            pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
-                and D is the dimension (typically 3 for x, y, visibility).
+            predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
             save_conf (bool): Whether to save confidence scores.
-            shape (tuple): Shape of the original image.
+            shape (Tuple[int, int]): Shape of the original image (height, width).
             file (Path): Output file path to save detections.
 
         Notes:
@@ -349,11 +246,11 @@ class PoseValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            keypoints=pred_kpts,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            keypoints=predn["keypoints"],
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
         """
         Convert YOLO predictions to COCO JSON format.
 
@@ -361,10 +258,9 @@ class PoseValidator(DetectionValidator):
         to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
 
         Args:
-            predn (torch.Tensor): Prediction tensor containing bounding boxes, confidence scores, class indices,
-                and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the number of
-                flattened keypoint values.
-            filename (str | Path): Path to the image file for which predictions are being processed.
+            predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
+                and 'keypoints' tensors.
+            filename (str): Path to the image file for which predictions are being processed.
 
         Notes:
             The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
@@ -373,16 +269,21 @@ class PoseValidator(DetectionValidator):
         """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        for p, b in zip(predn.tolist(), box.tolist()):
+        for b, s, c, k in zip(
+            box.tolist(),
+            predn["conf"].tolist(),
+            predn["cls"].tolist(),
+            predn["keypoints"].flatten(1, 2).tolist(),
+        ):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(c)],
                     "bbox": [round(x, 3) for x in b],
-                    "keypoints": p[6:],
-                    "score": round(p[4], 5),
+                    "keypoints": k,
+                    "score": round(s, 5),
                 }
             )
 
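For context, this is the shape of COCO-style record the new loop builds from dict predictions. The sketch below mirrors it with toy inputs; the image_id and the category mapping are made up (the real code uses self.class_map):

```python
import torch
from ultralytics.utils import ops

# Toy prediction dict in the new format (2 detections, 17 keypoints each).
predn = {
    "bboxes": torch.tensor([[10.0, 20.0, 110.0, 220.0], [5.0, 5.0, 55.0, 105.0]]),
    "conf": torch.tensor([0.92, 0.41]),
    "cls": torch.tensor([0.0, 0.0]),
    "keypoints": torch.rand(2, 17, 3) * 100,
}
box = ops.xyxy2xywh(predn["bboxes"])  # xyxy -> xywh (center-based)
box[:, :2] -= box[:, 2:] / 2          # center -> top-left, as COCO expects
records = [
    {
        "image_id": 42,             # hypothetical image id
        "category_id": int(c) + 1,  # stand-in for self.class_map; COCO 'person' is 1
        "bbox": [round(x, 3) for x in b],
        "keypoints": k,             # flat [x1, y1, v1, ...] list, 17 * 3 = 51 values
        "score": round(s, 5),
    }
    for b, s, c, k in zip(
        box.tolist(),
        predn["conf"].tolist(),
        predn["cls"].tolist(),
        predn["keypoints"].flatten(1, 2).tolist(),
    )
]
print(records[0]["bbox"], len(records[0]["keypoints"]))  # [x, y, w, h] and 51
```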
ultralytics/models/yolo/segment/train.py

@@ -7,7 +7,7 @@ from typing import Dict, Optional, Union
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.utils import DEFAULT_CFG, RANK
-from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.plotting import plot_results
 
 
 class SegmentationTrainer(yolo.detect.DetectionTrainer):
@@ -82,46 +82,6 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def plot_training_samples(self, batch: Dict, ni: int):
-        """
-        Plot training sample images with labels, bounding boxes, and masks.
-
-        This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
-        and segmentation masks, saving the result to a file for inspection and debugging.
-
-        Args:
-            batch (dict): Dictionary containing batch data with the following keys:
-                'img': Images tensor
-                'batch_idx': Batch indices for each box
-                'cls': Class labels tensor (squeezed to remove last dimension)
-                'bboxes': Bounding box coordinates tensor
-                'masks': Segmentation masks tensor
-                'im_file': List of image file paths
-            ni (int): Current training iteration number, used for naming the output file.
-
-        Examples:
-            >>> trainer = SegmentationTrainer()
-            >>> batch = {
-            ...     "img": torch.rand(16, 3, 640, 640),
-            ...     "batch_idx": torch.zeros(16),
-            ...     "cls": torch.randint(0, 80, (16, 1)),
-            ...     "bboxes": torch.rand(16, 4),
-            ...     "masks": torch.rand(16, 640, 640),
-            ...     "im_file": ["image1.jpg", "image2.jpg"],
-            ... }
-            >>> trainer.plot_training_samples(batch, ni=5)
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            masks=batch["masks"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"train_batch{ni}.jpg",
-            on_plot=self.on_plot,
-        )
-
     def plot_metrics(self):
         """Plot training/validation metrics."""
         plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png