dgenerate-ultralytics-headless 8.3.197-py3-none-any.whl → 8.3.198-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {dgenerate_ultralytics_headless-8.3.197.dist-info → dgenerate_ultralytics_headless-8.3.198.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.197.dist-info → dgenerate_ultralytics_headless-8.3.198.dist-info}/RECORD +42 -42
  3. tests/test_engine.py +9 -1
  4. ultralytics/__init__.py +1 -1
  5. ultralytics/cfg/__init__.py +0 -1
  6. ultralytics/cfg/default.yaml +96 -94
  7. ultralytics/cfg/trackers/botsort.yaml +16 -17
  8. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  9. ultralytics/data/augment.py +1 -1
  10. ultralytics/data/dataset.py +1 -1
  11. ultralytics/engine/exporter.py +35 -35
  12. ultralytics/engine/predictor.py +1 -2
  13. ultralytics/engine/results.py +1 -1
  14. ultralytics/engine/trainer.py +5 -5
  15. ultralytics/engine/tuner.py +54 -32
  16. ultralytics/models/sam/modules/decoders.py +3 -3
  17. ultralytics/models/sam/modules/sam.py +5 -5
  18. ultralytics/models/sam/predict.py +11 -11
  19. ultralytics/models/yolo/classify/train.py +2 -7
  20. ultralytics/models/yolo/classify/val.py +2 -2
  21. ultralytics/models/yolo/detect/predict.py +1 -1
  22. ultralytics/models/yolo/detect/train.py +1 -6
  23. ultralytics/models/yolo/detect/val.py +4 -4
  24. ultralytics/models/yolo/obb/val.py +3 -3
  25. ultralytics/models/yolo/pose/predict.py +1 -1
  26. ultralytics/models/yolo/pose/train.py +0 -6
  27. ultralytics/models/yolo/pose/val.py +2 -2
  28. ultralytics/models/yolo/segment/predict.py +2 -2
  29. ultralytics/models/yolo/segment/train.py +0 -5
  30. ultralytics/models/yolo/segment/val.py +9 -7
  31. ultralytics/models/yolo/yoloe/val.py +1 -1
  32. ultralytics/nn/modules/block.py +1 -1
  33. ultralytics/nn/tasks.py +2 -2
  34. ultralytics/utils/checks.py +1 -1
  35. ultralytics/utils/metrics.py +6 -6
  36. ultralytics/utils/nms.py +5 -13
  37. ultralytics/utils/plotting.py +22 -36
  38. ultralytics/utils/torch_utils.py +9 -5
  39. {dgenerate_ultralytics_headless-8.3.197.dist-info → dgenerate_ultralytics_headless-8.3.198.dist-info}/WHEEL +0 -0
  40. {dgenerate_ultralytics_headless-8.3.197.dist-info → dgenerate_ultralytics_headless-8.3.198.dist-info}/entry_points.txt +0 -0
  41. {dgenerate_ultralytics_headless-8.3.197.dist-info → dgenerate_ultralytics_headless-8.3.198.dist-info}/licenses/LICENSE +0 -0
  42. {dgenerate_ultralytics_headless-8.3.197.dist-info → dgenerate_ultralytics_headless-8.3.198.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/classify/val.py CHANGED
@@ -178,7 +178,7 @@ class ClassificationValidator(BaseValidator):
         >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
         >>> validator.plot_val_samples(batch, 0)
         """
-        batch["batch_idx"] = torch.arange(len(batch["img"]))  # add batch index for plotting
+        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
         plot_images(
             labels=batch,
             fname=self.save_dir / f"val_batch{ni}_labels.jpg",
@@ -203,7 +203,7 @@ class ClassificationValidator(BaseValidator):
         """
         batched_preds = dict(
             img=batch["img"],
-            batch_idx=torch.arange(len(batch["img"])),
+            batch_idx=torch.arange(batch["img"].shape[0]),
             cls=torch.argmax(preds, dim=1),
         )
         plot_images(
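
The change made here recurs through most hunks below: `len(tensor)` becomes `tensor.shape[0]`. Both count the first dimension, but `.shape[0]` names the batch dimension explicitly and skips the Python `__len__` round-trip. A minimal sketch (sizes hypothetical):

```python
import torch

imgs = torch.rand(16, 3, 224, 224)
# Equivalent for any tensor with at least one dimension; .shape[0] names the
# batch dimension explicitly instead of going through Python's len().
assert len(imgs) == imgs.shape[0] == 16
batch_idx = torch.arange(imgs.shape[0])  # tensor([0, 1, ..., 15])
```
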
ultralytics/models/yolo/detect/predict.py CHANGED
@@ -89,7 +89,7 @@ class DetectionPredictor(BasePredictor):
         obj_feats = torch.cat(
             [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
         )  # mean reduce all vectors to same length
-        return [feats[idx] if len(idx) else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch
+        return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch

     def construct_results(self, preds, img, orig_imgs):
         """
ultralytics/models/yolo/detect/train.py CHANGED
@@ -17,7 +17,7 @@ from ultralytics.models import yolo
 from ultralytics.nn.tasks import DetectionModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
 from ultralytics.utils.patches import override_configs
-from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
+from ultralytics.utils.plotting import plot_images, plot_labels
 from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model


@@ -43,7 +43,6 @@ class DetectionTrainer(BaseTrainer):
         label_loss_items: Return a loss dictionary with labeled training loss items.
         progress_string: Return a formatted string of training progress.
         plot_training_samples: Plot training samples with their annotations.
-        plot_metrics: Plot metrics from a CSV file.
         plot_training_labels: Create a labeled training plot of the YOLO model.
         auto_batch: Calculate optimal batch size based on model memory requirements.

@@ -217,10 +216,6 @@ class DetectionTrainer(BaseTrainer):
             on_plot=self.on_plot,
         )

-    def plot_metrics(self):
-        """Plot metrics from a CSV file."""
-        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png
-
     def plot_training_labels(self):
         """Create a labeled training plot of the YOLO model."""
         boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
ultralytics/models/yolo/detect/val.py CHANGED
@@ -146,7 +146,7 @@ class DetectionValidator(BaseValidator):
         ori_shape = batch["ori_shape"][si]
         imgsz = batch["img"].shape[2:]
         ratio_pad = batch["ratio_pad"][si]
-        if len(cls):
+        if cls.shape[0]:
             bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
         return {
             "cls": cls,
@@ -185,7 +185,7 @@ class DetectionValidator(BaseValidator):
             predn = self._prepare_pred(pred)

             cls = pbatch["cls"].cpu().numpy()
-            no_pred = len(predn["cls"]) == 0
+            no_pred = predn["cls"].shape[0] == 0
             self.metrics.update_stats(
                 {
                     **self._process_batch(predn, pbatch),
@@ -268,8 +268,8 @@ class DetectionValidator(BaseValidator):
         Returns:
             (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
         """
-        if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
-            return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
+        if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
+            return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
         iou = box_iou(batch["bboxes"], preds["bboxes"])
         return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
ultralytics/models/yolo/obb/val.py CHANGED
@@ -93,8 +93,8 @@ class OBBValidator(DetectionValidator):
         >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
         >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
-        if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
-            return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
+        if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
+            return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
         iou = batch_probiou(batch["bboxes"], preds["bboxes"])
         return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}

@@ -134,7 +134,7 @@ class OBBValidator(DetectionValidator):
         ori_shape = batch["ori_shape"][si]
         imgsz = batch["img"].shape[2:]
         ratio_pad = batch["ratio_pad"][si]
-        if len(cls):
+        if cls.shape[0]:
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
         return {
             "cls": cls,
ultralytics/models/yolo/pose/predict.py CHANGED
@@ -73,7 +73,7 @@ class PosePredictor(DetectionPredictor):
         """
         result = super().construct_result(pred, img, orig_img, img_path)
         # Extract keypoints from prediction and reshape according to model's keypoint shape
-        pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape)
+        pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
         # Scale keypoints coordinates to match the original image dimensions
         pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
         result.update(keypoints=pred_kpts)
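
The reshape touched here unpacks the flattened keypoint block that follows the six box/conf/cls columns. A sketch assuming a COCO-style `kpt_shape`:

```python
import torch

kpt_shape = (17, 3)               # e.g. COCO pose: 17 keypoints as (x, y, visibility)
pred = torch.rand(8, 6 + 17 * 3)  # 8 detections: 4 box coords + conf + cls + flat keypoints
pred_kpts = pred[:, 6:].view(pred.shape[0], *kpt_shape)
print(pred_kpts.shape)            # torch.Size([8, 17, 3])
```
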
ultralytics/models/yolo/pose/train.py CHANGED
@@ -9,7 +9,6 @@ from typing import Any
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import PoseModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER
-from ultralytics.utils.plotting import plot_results


 class PoseTrainer(yolo.detect.DetectionTrainer):
@@ -30,7 +29,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         set_model_attributes: Set keypoints shape attribute on the model.
         get_validator: Create a validator instance for model evaluation.
         plot_training_samples: Visualize training samples with keypoints.
-        plot_metrics: Generate and save training/validation metric plots.
         get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.

     Examples:
@@ -101,10 +99,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )

-    def plot_metrics(self):
-        """Plot training/validation metrics."""
-        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
-
     def get_dataset(self) -> dict[str, Any]:
         """
         Retrieve the dataset and ensure it contains the required `kpt_shape` key.
ultralytics/models/yolo/pose/val.py CHANGED
@@ -192,8 +192,8 @@ class PoseValidator(DetectionValidator):
         """
         tp = super()._process_batch(preds, batch)
         gt_cls = batch["cls"]
-        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
-            tp_p = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
         else:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
             area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
ultralytics/models/yolo/segment/predict.py CHANGED
@@ -90,7 +90,7 @@ class SegmentationPredictor(DetectionPredictor):
         Construct a single result object from the prediction.

         Args:
-            pred (np.ndarray): The predicted bounding boxes, scores, and masks.
+            pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
             img (torch.Tensor): The image after preprocessing.
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
@@ -99,7 +99,7 @@ class SegmentationPredictor(DetectionPredictor):
         Returns:
             (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
         """
-        if not len(pred):  # save empty boxes
+        if pred.shape[0] == 0:  # save empty boxes
             masks = None
         elif self.args.retina_masks:
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
ultralytics/models/yolo/segment/train.py CHANGED
@@ -8,7 +8,6 @@ from pathlib import Path
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.utils import DEFAULT_CFG, RANK
-from ultralytics.utils.plotting import plot_results


 class SegmentationTrainer(yolo.detect.DetectionTrainer):
@@ -71,7 +70,3 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         return yolo.segment.SegmentationValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
-
-    def plot_metrics(self):
-        """Plot training/validation metrics."""
-        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
ultralytics/models/yolo/segment/val.py CHANGED
@@ -112,7 +112,7 @@ class SegmentationValidator(DetectionValidator):
             coefficient = pred.pop("extra")
             pred["masks"] = (
                 self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
-                if len(coefficient)
+                if coefficient.shape[0]
                 else torch.zeros(
                     (0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),
                     dtype=torch.uint8,
@@ -133,16 +133,18 @@ class SegmentationValidator(DetectionValidator):
             (dict[str, Any]): Prepared batch with processed annotations.
         """
         prepared_batch = super()._prepare_batch(si, batch)
-        nl = len(prepared_batch["cls"])
+        nl = prepared_batch["cls"].shape[0]
         if self.args.overlap_mask:
             masks = batch["masks"][si]
             index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
             masks = (masks == index).float()
         else:
             masks = batch["masks"][batch["batch_idx"] == si]
-            if nl and self.process is ops.process_mask_native:
-                masks = F.interpolate(masks[None], prepared_batch["imgsz"], mode="bilinear", align_corners=False)[0]
-                masks = masks.gt_(0.5)
+        if nl:
+            mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
+            if masks.shape[1:] != mask_size:
+                masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
+                masks = masks.gt_(0.5)
         prepared_batch["masks"] = masks
         return prepared_batch

@@ -168,8 +170,8 @@ class SegmentationValidator(DetectionValidator):
         """
         tp = super()._process_batch(preds, batch)
         gt_cls = batch["cls"]
-        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
-            tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
         else:
             iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
             tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
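
The `_prepare_batch` rewrite above resizes ground-truth masks to whatever resolution the active mask processor expects (full `imgsz` for `process_mask_native`, quarter resolution otherwise) and interpolates only when the stored masks differ from that size. A self-contained sketch of the same logic, with `use_native` standing in for `self.process is ops.process_mask_native`:

```python
import torch
import torch.nn.functional as F

imgsz = (640, 640)
masks = torch.rand(8, 168, 168)  # hypothetical GT masks stored at a mismatched size
use_native = False               # stand-in for `self.process is ops.process_mask_native`
mask_size = tuple(s if use_native else s // 4 for s in imgsz)  # (160, 160) here
if tuple(masks.shape[1:]) != mask_size:
    masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
    masks = masks.gt_(0.5)  # re-binarize only after bilinear resampling
print(masks.shape)  # torch.Size([8, 160, 160])
```
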
ultralytics/models/yolo/yoloe/val.py CHANGED
@@ -89,7 +89,7 @@ class YOLOEDetectValidator(DetectionValidator):
         for i in range(preds.shape[0]):
             cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
             pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
-            pad_cls[: len(cls)] = cls
+            pad_cls[: cls.shape[0]] = cls
             for c in cls:
                 visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]

ultralytics/nn/modules/block.py CHANGED
@@ -1921,7 +1921,7 @@ class A2C2f(nn.Module):
         y.extend(m(y[-1]) for m in self.m)
         y = self.cv2(torch.cat(y, 1))
         if self.gamma is not None:
-            return x + self.gamma.view(-1, len(self.gamma), 1, 1) * y
+            return x + self.gamma.view(-1, self.gamma.shape[0], 1, 1) * y
         return y

ultralytics/nn/tasks.py CHANGED
@@ -766,7 +766,7 @@ class RTDETRDetectionModel(DetectionModel):

         img = batch["img"]
         # NOTE: preprocess gt_bbox and gt_labels to list.
-        bs = len(img)
+        bs = img.shape[0]
         batch_idx = batch["batch_idx"]
         gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
         targets = {
@@ -923,7 +923,7 @@ class WorldModel(DetectionModel):
             (torch.Tensor): Model's output tensor.
         """
         txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
-        if len(txt_feats) != len(x) or self.model[-1].export:
+        if txt_feats.shape[0] != x.shape[0] or self.model[-1].export:
             txt_feats = txt_feats.expand(x.shape[0], -1, -1)
         ori_txt_feats = txt_feats.clone()
         y, dt, embeddings = [], [], []  # outputs
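
The guard in `WorldModel` broadcasts a single set of text embeddings across the image batch when the batch sizes differ. A sketch with hypothetical sizes:

```python
import torch

txt_feats = torch.rand(1, 80, 512)  # hypothetical: one set of 80 class embeddings, dim 512
x = torch.rand(4, 3, 640, 640)      # batch of 4 images
if txt_feats.shape[0] != x.shape[0]:
    txt_feats = txt_feats.expand(x.shape[0], -1, -1)  # broadcast view, no copy
print(txt_feats.shape)  # torch.Size([4, 80, 512])
```
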
ultralytics/utils/checks.py CHANGED
@@ -907,7 +907,7 @@ def is_intel():
     try:
         result = subprocess.run(["xpu-smi", "discovery"], capture_output=True, text=True, timeout=5)
         return "intel" in result.stdout.lower()
-    except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError):
+    except Exception:  # broad clause to capture all Intel GPU exception types
        return False

ultralytics/utils/metrics.py CHANGED
@@ -397,11 +397,11 @@ class ConfusionMatrix(DataExportMixin):
         gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
         if self.matches is not None:  # only if visualization is enabled
             self.matches = {k: defaultdict(list) for k in {"TP", "FP", "FN", "GT"}}
-            for i in range(len(gt_cls)):
+            for i in range(gt_cls.shape[0]):
                 self._append_matches("GT", batch, i)  # store GT
         is_obb = gt_bboxes.shape[1] == 5  # check if boxes contains angle for OBB
         conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf  # apply 0.25 if default val conf is passed
-        no_pred = len(detections["cls"]) == 0
+        no_pred = detections["cls"].shape[0] == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
             if not no_pred:
                 detections = {k: detections[k][detections["conf"] > conf] for k in detections}
@@ -491,13 +491,13 @@ class ConfusionMatrix(DataExportMixin):
         for i, mtype in enumerate(["GT", "FP", "TP", "FN"]):
             mbatch = self.matches[mtype]
             if "conf" not in mbatch:
-                mbatch["conf"] = torch.tensor([1.0] * len(mbatch["bboxes"]), device=img.device)
-            mbatch["batch_idx"] = torch.ones(len(mbatch["bboxes"]), device=img.device) * i
+                mbatch["conf"] = torch.tensor([1.0] * mbatch["bboxes"].shape[0], device=img.device)
+            mbatch["batch_idx"] = torch.ones(mbatch["bboxes"].shape[0], device=img.device) * i
             for k in mbatch.keys():
                 labels[k] += mbatch[k]

         labels = {k: torch.stack(v, 0) if len(v) else v for k, v in labels.items()}
-        if self.task != "obb" and len(labels["bboxes"]):
+        if self.task != "obb" and labels["bboxes"].shape[0]:
             labels["bboxes"] = xyxy2xywh(labels["bboxes"])
         (save_dir / "visualizations").mkdir(parents=True, exist_ok=True)
         plot_images(
@@ -980,7 +980,7 @@ class Metric(SimpleClass):

     def fitness(self) -> float:
         """Return model fitness as a weighted combination of metrics."""
-        w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
+        w = [0.0, 0.0, 0.0, 1.0]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
         return (np.nan_to_num(np.array(self.mean_results())) * w).sum()

     def update(self, results: tuple):
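
The new weights make fitness equal to mAP@0.5:0.95 alone instead of a 10/90 blend with mAP@0.5. A quick numeric check (hypothetical mean results):

```python
import numpy as np

mean_results = np.array([0.71, 0.64, 0.69, 0.47])  # hypothetical [P, R, mAP@0.5, mAP@0.5:0.95]
print((mean_results * np.array([0.0, 0.0, 0.1, 0.9])).sum())  # 0.492: old blended fitness
print((mean_results * np.array([0.0, 0.0, 0.0, 1.0])).sum())  # 0.470: now exactly mAP@0.5:0.95
```
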
ultralytics/utils/nms.py CHANGED
@@ -263,12 +263,11 @@ class TorchNMS:
         areas = (x2 - x1) * (y2 - y1)

         # Sort by scores descending
-        _, order = scores.sort(0, descending=True)
+        order = scores.argsort(0, descending=True)

         # Pre-allocate keep list with maximum possible size
         keep = torch.zeros(order.numel(), dtype=torch.int64, device=boxes.device)
         keep_idx = 0
-
         while order.numel() > 0:
             i = order[0]
             keep[keep_idx] = i
@@ -276,7 +275,6 @@ class TorchNMS:

             if order.numel() == 1:
                 break
-
             # Vectorized IoU calculation for remaining boxes
             rest = order[1:]
             xx1 = torch.maximum(x1[i], x1[rest])
@@ -288,20 +286,14 @@ class TorchNMS:
             w = (xx2 - xx1).clamp_(min=0)
             h = (yy2 - yy1).clamp_(min=0)
             inter = w * h
-
-            # Early termination: skip IoU calculation if no intersection
+            # Early exit: skip IoU calculation if no intersection
             if inter.sum() == 0:
                 # No overlaps with current box, keep all remaining boxes
-                remaining_count = rest.numel()
-                keep[keep_idx : keep_idx + remaining_count] = rest
-                keep_idx += remaining_count
-                break
-
+                order = rest
+                continue
             iou = inter / (areas[i] + areas[rest] - inter)
-
             # Keep boxes with IoU <= threshold
-            mask = iou <= iou_threshold
-            order = rest[mask]
+            order = rest[iou <= iou_threshold]

         return keep[:keep_idx]
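
The early-exit rewrite above is behavioral, not just cosmetic: the old path kept every remaining box unconditionally once the current box had zero intersection with all of them, even if those remaining boxes overlapped each other; the loop now simply advances so later boxes can still suppress one another. A self-contained sketch of the loop after this change (not the library API; boxes hypothetical):

```python
import torch


def nms_sketch(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    """Minimal sketch of the post-change NMS loop; not the TorchNMS API itself."""
    x1, y1, x2, y2 = boxes.unbind(1)
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort(0, descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        rest = order[1:]
        xx1, yy1 = torch.maximum(x1[i], x1[rest]), torch.maximum(y1[i], y1[rest])
        xx2, yy2 = torch.minimum(x2[i], x2[rest]), torch.minimum(y2[i], y2[rest])
        inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
        if inter.sum() == 0:
            order = rest  # keep scanning: the rest may still suppress each other
            continue
        order = rest[inter / (areas[i] + areas[rest] - inter) <= iou_threshold]
    return torch.tensor(keep)


# Box 0 is isolated; boxes 1 and 2 overlap each other heavily (IoU ~0.82).
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [100.0, 100.0, 120.0, 120.0], [101.0, 101.0, 121.0, 121.0]])
scores = torch.tensor([0.9, 0.8, 0.7])
print(nms_sketch(boxes, scores, 0.45))  # tensor([0, 1]); the old break path kept all three
```
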
ultralytics/utils/plotting.py CHANGED
@@ -812,14 +812,13 @@ def plot_images(

         # Plot masks
         if len(masks):
-            if idx.shape[0] == masks.shape[0]:  # overlap_mask=False
+            if idx.shape[0] == masks.shape[0] and masks.max() <= 1:  # overlap_mask=False
                 image_masks = masks[idx]
             else:  # overlap_mask=True
                 image_masks = masks[[i]]  # (1, 640, 640)
                 nl = idx.sum()
-                index = np.arange(nl).reshape((nl, 1, 1)) + 1
-                image_masks = np.repeat(image_masks, nl, axis=0)
-                image_masks = np.where(image_masks == index, 1.0, 0.0)
+                index = np.arange(1, nl + 1).reshape((nl, 1, 1))
+                image_masks = (image_masks == index).astype(np.float32)

             im = np.asarray(annotator.im).copy()
             for j in range(len(image_masks)):
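
With `overlap_mask=True` a single mask per image encodes instance k as pixel value k, and the broadcast comparison decodes it to per-instance binary masks in one step; the added `masks.max() <= 1` test is what distinguishes binary per-instance masks from this overlap encoding. A tiny decode sketch:

```python
import numpy as np

overlap = np.zeros((1, 6, 6), dtype=np.uint8)  # one overlap-encoded mask; 0 = background
overlap[0, :3, :3] = 1                         # pixels of instance 1
overlap[0, 3:, 3:] = 2                         # pixels of instance 2
nl = 2                                         # number of instances in this image
index = np.arange(1, nl + 1).reshape((nl, 1, 1))
per_instance = (overlap == index).astype(np.float32)  # (2, 6, 6) via broadcasting
assert per_instance[0].sum() == per_instance[1].sum() == 9.0
```
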
@@ -847,14 +846,7 @@ def plot_images(


 @plt_settings()
-def plot_results(
-    file: str = "path/to/results.csv",
-    dir: str = "",
-    segment: bool = False,
-    pose: bool = False,
-    classify: bool = False,
-    on_plot: Callable | None = None,
-):
+def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
     """
     Plot training results from a results CSV file. The function supports various types of data including segmentation,
     pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
@@ -862,9 +854,6 @@ def plot_results(
     Args:
         file (str, optional): Path to the CSV file containing the training results.
         dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
-        segment (bool, optional): Flag to indicate if the data is for segmentation.
-        pose (bool, optional): Flag to indicate if the data is for pose estimation.
-        classify (bool, optional): Flag to indicate if the data is for classification.
         on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.

     Examples:
@@ -876,34 +865,31 @@ def plot_results(
     from scipy.ndimage import gaussian_filter1d

     save_dir = Path(file).parent if file else Path(dir)
-    if classify:
-        fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
-        index = [2, 5, 3, 4]
-    elif segment:
-        fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
-        index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
-    elif pose:
-        fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
-        index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
-    else:
-        fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
-        index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
-    ax = ax.ravel()
     files = list(save_dir.glob("results*.csv"))
     assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
-    for f in files:
+
+    loss_keys, metric_keys = [], []
+    for i, f in enumerate(files):
         try:
             data = pl.read_csv(f, infer_schema_length=None)
-            s = [x.strip() for x in data.columns]
+            if i == 0:
+                for c in data.columns:
+                    if "loss" in c:
+                        loss_keys.append(c)
+                    elif "metric" in c:
+                        metric_keys.append(c)
+                loss_mid, metric_mid = len(loss_keys) // 2, len(metric_keys) // 2
+                columns = (
+                    loss_keys[:loss_mid] + metric_keys[:metric_mid] + loss_keys[loss_mid:] + metric_keys[metric_mid:]
+                )
+                fig, ax = plt.subplots(2, len(columns) // 2, figsize=(len(columns) + 2, 6), tight_layout=True)
+                ax = ax.ravel()
             x = data.select(data.columns[0]).to_numpy().flatten()
-            for i, j in enumerate(index):
-                y = data.select(data.columns[j]).to_numpy().flatten().astype("float")
-                # y[y == 0] = np.nan  # don't show zero values
+            for i, j in enumerate(columns):
+                y = data.select(j).to_numpy().flatten().astype("float")
                 ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)  # actual results
                 ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)  # smoothing line
-                ax[i].set_title(s[j], fontsize=12)
-                # if j in {8, 9, 10}:  # share train and val loss y axes
-                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
+                ax[i].set_title(j, fontsize=12)
         except Exception as e:
             LOGGER.error(f"Plotting error for {f}: {e}")
     ax[1].legend()
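
`plot_results` no longer needs per-task flags: it derives the layout from the CSV header, placing train losses plus the first half of the metrics on row one and val losses plus the remaining metrics on row two. A sketch of that ordering rule on a hypothetical detect-task header:

```python
cols = [
    "epoch", "time",
    "train/box_loss", "train/cls_loss", "train/dfl_loss",
    "metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
    "val/box_loss", "val/cls_loss", "val/dfl_loss",
]
loss_keys = [c for c in cols if "loss" in c]      # train losses first, then val losses
metric_keys = [c for c in cols if "metric" in c]  # the four metrics columns
loss_mid, metric_mid = len(loss_keys) // 2, len(metric_keys) // 2
columns = loss_keys[:loss_mid] + metric_keys[:metric_mid] + loss_keys[loss_mid:] + metric_keys[metric_mid:]
print(columns)  # 10 panels -> a 2x5 grid, matching the old hard-coded detect layout
```
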
ultralytics/utils/torch_utils.py CHANGED
@@ -1012,7 +1012,7 @@ def attempt_compile(
     imgsz: int = 640,
     use_autocast: bool = False,
     warmup: bool = False,
-    prefix: str = colorstr("compile:"),
+    mode: bool | str = "default",
 ) -> torch.nn.Module:
     """
     Compile a model with torch.compile and optionally warm up the graph to reduce first-iteration latency.
@@ -1027,7 +1027,8 @@ def attempt_compile(
         imgsz (int, optional): Square input size to create a dummy tensor with shape (1, 3, imgsz, imgsz) for warmup.
         use_autocast (bool, optional): Whether to run warmup under autocast on CUDA or MPS devices.
         warmup (bool, optional): Whether to execute a single dummy forward pass to warm up the compiled model.
-        prefix (str, optional): Message prefix for logger output.
+        mode (bool | str, optional): torch.compile mode. True → "default", False → no compile, or a string like
+            "default", "reduce-overhead", "max-autotune".

     Returns:
         model (torch.nn.Module): Compiled model if compilation succeeds, otherwise the original unmodified model.
@@ -1042,13 +1043,16 @@ def attempt_compile(
         >>> # Try to compile and warm up a model with a 640x640 input
         >>> model = attempt_compile(model, device=device, imgsz=640, use_autocast=True, warmup=True)
     """
-    if not hasattr(torch, "compile"):
+    if not hasattr(torch, "compile") or not mode:
         return model

-    LOGGER.info(f"{prefix} starting torch.compile...")
+    if mode is True:
+        mode = "default"
+    prefix = colorstr("compile:")
+    LOGGER.info(f"{prefix} starting torch.compile with '{mode}' mode...")
     t0 = time.perf_counter()
     try:
-        model = torch.compile(model, mode="max-autotune", backend="inductor")
+        model = torch.compile(model, mode=mode, backend="inductor")
     except Exception as e:
         LOGGER.warning(f"{prefix} torch.compile failed, continuing uncompiled: {e}")
         return model
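
With the new signature, callers choose the compile mode (or disable compilation) instead of passing a log prefix. A usage sketch assuming only what this diff shows:

```python
import torch

from ultralytics.utils.torch_utils import attempt_compile

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.SiLU())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# mode=False returns the model untouched; mode=True compiles with "default".
model = attempt_compile(model.to(device), device=device, warmup=False, mode="reduce-overhead")
```
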