dgenerate-ultralytics-headless 8.3.195__py3-none-any.whl → 8.3.196__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {dgenerate_ultralytics_headless-8.3.195.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.195.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/RECORD +35 -35
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/cfg/__init__.py +1 -0
  5. ultralytics/cfg/default.yaml +1 -0
  6. ultralytics/data/augment.py +1 -1
  7. ultralytics/data/build.py +5 -1
  8. ultralytics/engine/exporter.py +19 -31
  9. ultralytics/engine/predictor.py +3 -1
  10. ultralytics/engine/trainer.py +15 -4
  11. ultralytics/engine/validator.py +6 -2
  12. ultralytics/models/yolo/classify/train.py +1 -11
  13. ultralytics/models/yolo/detect/train.py +32 -6
  14. ultralytics/models/yolo/detect/val.py +6 -5
  15. ultralytics/models/yolo/obb/train.py +0 -9
  16. ultralytics/models/yolo/pose/train.py +1 -9
  17. ultralytics/models/yolo/pose/val.py +1 -1
  18. ultralytics/models/yolo/segment/train.py +1 -9
  19. ultralytics/models/yolo/segment/val.py +1 -1
  20. ultralytics/models/yolo/world/train.py +4 -4
  21. ultralytics/models/yolo/world/train_world.py +2 -2
  22. ultralytics/models/yolo/yoloe/train.py +3 -12
  23. ultralytics/models/yolo/yoloe/val.py +0 -7
  24. ultralytics/nn/modules/head.py +2 -1
  25. ultralytics/nn/tasks.py +4 -2
  26. ultralytics/utils/__init__.py +30 -19
  27. ultralytics/utils/callbacks/tensorboard.py +2 -2
  28. ultralytics/utils/checks.py +2 -0
  29. ultralytics/utils/loss.py +14 -8
  30. ultralytics/utils/plotting.py +1 -0
  31. ultralytics/utils/torch_utils.py +111 -9
  32. {dgenerate_ultralytics_headless-8.3.195.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/WHEEL +0 -0
  33. {dgenerate_ultralytics_headless-8.3.195.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/entry_points.txt +0 -0
  34. {dgenerate_ultralytics_headless-8.3.195.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/licenses/LICENSE +0 -0
  35. {dgenerate_ultralytics_headless-8.3.195.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/top_level.txt +0 -0
@@ -37,21 +37,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
37
37
  """
38
38
  Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
39
39
 
40
- This trainer extends the DetectionTrainer class to specialize in training models that detect oriented
41
- bounding boxes. It automatically sets the task to 'obb' in the configuration.
42
-
43
40
  Args:
44
41
  cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
45
42
  model configuration.
46
43
  overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
47
44
  will take precedence over those in cfg.
48
45
  _callbacks (list[Any], optional): List of callback functions to be invoked during training.
49
-
50
- Examples:
51
- >>> from ultralytics.models.yolo.obb import OBBTrainer
52
- >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
53
- >>> trainer = OBBTrainer(overrides=args)
54
- >>> trainer.train()
55
46
  """
56
47
  if overrides is None:
57
48
  overrides = {}
@@ -44,9 +44,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
44
44
  """
45
45
  Initialize a PoseTrainer object for training YOLO pose estimation models.
46
46
 
47
- This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
48
- handling specific configurations needed for keypoint detection models.
49
-
50
47
  Args:
51
48
  cfg (dict, optional): Default configuration dictionary containing training parameters.
52
49
  overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
@@ -55,17 +52,12 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
55
52
  Notes:
56
53
  This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
57
54
  A warning is issued when using Apple MPS device due to known bugs with pose models.
58
-
59
- Examples:
60
- >>> from ultralytics.models.yolo.pose import PoseTrainer
61
- >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
62
- >>> trainer = PoseTrainer(overrides=args)
63
- >>> trainer.train()
64
55
  """
65
56
  if overrides is None:
66
57
  overrides = {}
67
58
  overrides["task"] = "pose"
68
59
  super().__init__(cfg, overrides, _callbacks)
60
+ self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "keypoints"]
69
61
 
70
62
  if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
71
63
  LOGGER.warning(
@@ -86,7 +86,7 @@ class PoseValidator(DetectionValidator):
86
86
  def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
87
87
  """Preprocess batch by converting keypoints data to float and moving it to the device."""
88
88
  batch = super().preprocess(batch)
89
- batch["keypoints"] = batch["keypoints"].to(self.device, non_blocking=True).float()
89
+ batch["keypoints"] = batch["keypoints"].float()
90
90
  return batch
91
91
 
92
92
  def get_desc(self) -> str:
@@ -32,24 +32,16 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
32
32
  """
33
33
  Initialize a SegmentationTrainer object.
34
34
 
35
- This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific
36
- functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.
37
-
38
35
  Args:
39
36
  cfg (dict): Configuration dictionary with default training settings.
40
37
  overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
41
38
  _callbacks (list, optional): List of callback functions to be executed during training.
42
-
43
- Examples:
44
- >>> from ultralytics.models.yolo.segment import SegmentationTrainer
45
- >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
46
- >>> trainer = SegmentationTrainer(overrides=args)
47
- >>> trainer.train()
48
39
  """
49
40
  if overrides is None:
50
41
  overrides = {}
51
42
  overrides["task"] = "segment"
52
43
  super().__init__(cfg, overrides, _callbacks)
44
+ self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "masks"]
53
45
 
54
46
  def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
55
47
  """
@@ -63,7 +63,7 @@ class SegmentationValidator(DetectionValidator):
63
63
  (dict[str, Any]): Preprocessed batch.
64
64
  """
65
65
  batch = super().preprocess(batch)
66
- batch["masks"] = batch["masks"].to(self.device, non_blocking=True).float()
66
+ batch["masks"] = batch["masks"].float()
67
67
  return batch
68
68
 
69
69
  def init_metrics(self, model: torch.nn.Module) -> None:
@@ -12,7 +12,7 @@ from ultralytics.data import build_yolo_dataset
12
12
  from ultralytics.models.yolo.detect import DetectionTrainer
13
13
  from ultralytics.nn.tasks import WorldModel
14
14
  from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
15
- from ultralytics.utils.torch_utils import de_parallel
15
+ from ultralytics.utils.torch_utils import unwrap_model
16
16
 
17
17
 
18
18
  def on_pretrain_routine_end(trainer) -> None:
@@ -20,7 +20,7 @@ def on_pretrain_routine_end(trainer) -> None:
20
20
  if RANK in {-1, 0}:
21
21
  # Set class names for evaluation
22
22
  names = [name.split("/", 1)[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
23
- de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
23
+ unwrap_model(trainer.ema.ema).set_classes(names, cache_clip_model=False)
24
24
 
25
25
 
26
26
  class WorldTrainer(DetectionTrainer):
@@ -105,7 +105,7 @@ class WorldTrainer(DetectionTrainer):
105
105
  Returns:
106
106
  (Any): YOLO dataset configured for training or validation.
107
107
  """
108
- gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
108
+ gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
109
109
  dataset = build_yolo_dataset(
110
110
  self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
111
111
  )
@@ -160,7 +160,7 @@ class WorldTrainer(DetectionTrainer):
160
160
  return txt_map
161
161
  LOGGER.info(f"Caching text embeddings to '{cache_path}'")
162
162
  assert self.model is not None
163
- txt_feats = de_parallel(self.model).get_text_pe(texts, batch, cache_clip_model=False)
163
+ txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, cache_clip_model=False)
164
164
  txt_map = dict(zip(texts, txt_feats.squeeze(0)))
165
165
  torch.save(txt_map, cache_path)
166
166
  return txt_map
@@ -6,7 +6,7 @@ from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_data
6
6
  from ultralytics.data.utils import check_det_dataset
7
7
  from ultralytics.models.yolo.world import WorldTrainer
8
8
  from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER
9
- from ultralytics.utils.torch_utils import de_parallel
9
+ from ultralytics.utils.torch_utils import unwrap_model
10
10
 
11
11
 
12
12
  class WorldTrainerFromScratch(WorldTrainer):
@@ -101,7 +101,7 @@ class WorldTrainerFromScratch(WorldTrainer):
101
101
  Returns:
102
102
  (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
103
103
  """
104
- gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
104
+ gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
105
105
  if mode != "train":
106
106
  return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)
107
107
  datasets = [
@@ -13,7 +13,7 @@ from ultralytics.data.augment import LoadVisualPrompt
13
13
  from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
14
14
  from ultralytics.nn.tasks import YOLOEModel
15
15
  from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
16
- from ultralytics.utils.torch_utils import de_parallel
16
+ from ultralytics.utils.torch_utils import unwrap_model
17
17
 
18
18
  from ..world.train_world import WorldTrainerFromScratch
19
19
  from .val import YOLOEDetectValidator
@@ -39,9 +39,6 @@ class YOLOETrainer(DetectionTrainer):
39
39
  """
40
40
  Initialize the YOLOE Trainer with specified configurations.
41
41
 
42
- This method sets up the YOLOE trainer with the provided configuration and overrides, initializing
43
- the training environment, model, and callbacks for YOLOE object detection training.
44
-
45
42
  Args:
46
43
  cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
47
44
  overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
@@ -102,7 +99,7 @@ class YOLOETrainer(DetectionTrainer):
102
99
  Returns:
103
100
  (Dataset): YOLO dataset configured for training or validation.
104
101
  """
105
- gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
102
+ gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
106
103
  return build_yolo_dataset(
107
104
  self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
108
105
  )
@@ -223,7 +220,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
223
220
  return txt_map
224
221
  LOGGER.info(f"Caching text embeddings to '{cache_path}'")
225
222
  assert self.model is not None
226
- txt_feats = de_parallel(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
223
+ txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
227
224
  txt_map = dict(zip(texts, txt_feats.squeeze(0)))
228
225
  torch.save(txt_map, cache_path)
229
226
  return txt_map
@@ -313,9 +310,3 @@ class YOLOEVPTrainer(YOLOETrainerFromScratch):
313
310
  d.transforms.append(LoadVisualPrompt())
314
311
  else:
315
312
  self.train_loader.dataset.transforms.append(LoadVisualPrompt())
316
-
317
- def preprocess_batch(self, batch):
318
- """Preprocess a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
319
- batch = super().preprocess_batch(batch)
320
- batch["visuals"] = batch["visuals"].to(self.device, non_blocking=True)
321
- return batch
@@ -98,13 +98,6 @@ class YOLOEDetectValidator(DetectionValidator):
98
98
  visual_pe[cls_visual_num == 0] = 0
99
99
  return visual_pe.unsqueeze(0)
100
100
 
101
- def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
102
- """Preprocess batch data, ensuring visuals are on the same device as images."""
103
- batch = super().preprocess(batch)
104
- if "visuals" in batch:
105
- batch["visuals"] = batch["visuals"].to(batch["img"].device, non_blocking=True)
106
- return batch
107
-
108
101
  def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
109
102
  """
110
103
  Create a dataloader for LVIS training visual prompt samples.
@@ -13,7 +13,7 @@ from torch.nn.init import constant_, xavier_uniform_
13
13
 
14
14
  from ultralytics.utils import NOT_MACOS14
15
15
  from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
16
- from ultralytics.utils.torch_utils import fuse_conv_and_bn, smart_inference_mode
16
+ from ultralytics.utils.torch_utils import disable_dynamo, fuse_conv_and_bn, smart_inference_mode
17
17
 
18
18
  from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN
19
19
  from .conv import Conv, DWConv
@@ -149,6 +149,7 @@ class Detect(nn.Module):
149
149
  y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
150
150
  return y if self.export else (y, {"one2many": x, "one2one": one2one})
151
151
 
152
+ @disable_dynamo
152
153
  def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
153
154
  """
154
155
  Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
ultralytics/nn/tasks.py CHANGED
@@ -334,7 +334,8 @@ class BaseModel(torch.nn.Module):
334
334
  if getattr(self, "criterion", None) is None:
335
335
  self.criterion = self.init_criterion()
336
336
 
337
- preds = self.forward(batch["img"]) if preds is None else preds
337
+ if preds is None:
338
+ preds = self.forward(batch["img"])
338
339
  return self.criterion(preds, batch)
339
340
 
340
341
  def init_criterion(self):
@@ -775,7 +776,8 @@ class RTDETRDetectionModel(DetectionModel):
775
776
  "gt_groups": gt_groups,
776
777
  }
777
778
 
778
- preds = self.predict(img, batch=targets) if preds is None else preds
779
+ if preds is None:
780
+ preds = self.predict(img, batch=targets)
779
781
  dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
780
782
  if dn_meta is None:
781
783
  dn_bboxes, dn_scores = None, None
@@ -857,7 +857,7 @@ def get_ubuntu_version():
857
857
 
858
858
  def get_user_config_dir(sub_dir="Ultralytics"):
859
859
  """
860
- Return the appropriate config directory based on the environment operating system.
860
+ Return a writable config dir, preferring YOLO_CONFIG_DIR and being OS-aware.
861
861
 
862
862
  Args:
863
863
  sub_dir (str): The name of the subdirectory to create.
@@ -865,27 +865,38 @@ def get_user_config_dir(sub_dir="Ultralytics"):
865
865
  Returns:
866
866
  (Path): The path to the user config directory.
867
867
  """
868
- if WINDOWS:
869
- path = Path.home() / "AppData" / "Roaming" / sub_dir
870
- elif MACOS: # macOS
871
- path = Path.home() / "Library" / "Application Support" / sub_dir
868
+ if env_dir := os.getenv("YOLO_CONFIG_DIR"):
869
+ p = Path(env_dir).expanduser() / sub_dir
872
870
  elif LINUX:
873
- path = Path.home() / ".config" / sub_dir
871
+ p = Path(os.getenv("XDG_CONFIG_HOME", Path.home() / ".config")) / sub_dir
872
+ elif WINDOWS:
873
+ p = Path.home() / "AppData" / "Roaming" / sub_dir
874
+ elif MACOS:
875
+ p = Path.home() / "Library" / "Application Support" / sub_dir
874
876
  else:
875
877
  raise ValueError(f"Unsupported operating system: {platform.system()}")
876
878
 
877
- # GCP and AWS lambda fix, only /tmp is writeable
878
- if not is_dir_writeable(path.parent):
879
- LOGGER.warning(
880
- f"user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD. "
881
- "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path."
882
- )
883
- path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir
884
-
885
- # Create the subdirectory if it does not exist
886
- path.mkdir(parents=True, exist_ok=True)
879
+ if p.exists(): # already created trust it
880
+ return p
881
+ if is_dir_writeable(p.parent): # create if possible
882
+ p.mkdir(parents=True, exist_ok=True)
883
+ return p
884
+
885
+ # Fallbacks for Docker, GCP/AWS functions where only /tmp is writeable
886
+ for alt in [Path("/tmp") / sub_dir, Path.cwd() / sub_dir]:
887
+ if alt.exists():
888
+ return alt
889
+ if is_dir_writeable(alt.parent):
890
+ alt.mkdir(parents=True, exist_ok=True)
891
+ LOGGER.warning(
892
+ f"user config directory '{p}' is not writeable, using '{alt}'. Set YOLO_CONFIG_DIR to override."
893
+ )
894
+ return alt
887
895
 
888
- return path
896
+ # Last fallback → CWD
897
+ p = Path.cwd() / sub_dir
898
+ p.mkdir(parents=True, exist_ok=True)
899
+ return p
889
900
 
890
901
 
891
902
  # Define constants (required below)
@@ -899,7 +910,7 @@ IS_JUPYTER = is_jupyter()
899
910
  IS_PIP_PACKAGE = is_pip_package()
900
911
  IS_RASPBERRYPI = is_raspberrypi()
901
912
  GIT = GitRepo()
902
- USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir
913
+ USER_CONFIG_DIR = get_user_config_dir() # Ultralytics settings dir
903
914
  SETTINGS_FILE = USER_CONFIG_DIR / "settings.json"
904
915
 
905
916
 
@@ -1383,7 +1394,7 @@ class SettingsManager(JSONDict):
1383
1394
 
1384
1395
  def deprecation_warn(arg, new_arg=None):
1385
1396
  """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
1386
- msg = f"'{arg}' is deprecated and will be removed in in the future."
1397
+ msg = f"'{arg}' is deprecated and will be removed in the future."
1387
1398
  if new_arg is not None:
1388
1399
  msg += f" Use '{new_arg}' instead."
1389
1400
  LOGGER.warning(msg)
@@ -70,14 +70,14 @@ def _log_tensorboard_graph(trainer) -> None:
70
70
  # Try simple method first (YOLO)
71
71
  try:
72
72
  trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
73
- WRITER.add_graph(torch.jit.trace(torch_utils.de_parallel(trainer.model), im, strict=False), [])
73
+ WRITER.add_graph(torch.jit.trace(torch_utils.unwrap_model(trainer.model), im, strict=False), [])
74
74
  LOGGER.info(f"{PREFIX}model graph visualization added ✅")
75
75
  return
76
76
 
77
77
  except Exception:
78
78
  # Fallback to TorchScript export steps (RTDETR)
79
79
  try:
80
- model = deepcopy(torch_utils.de_parallel(trainer.model))
80
+ model = deepcopy(torch_utils.unwrap_model(trainer.model))
81
81
  model.eval()
82
82
  model = model.fuse(verbose=False)
83
83
  for m in model.modules():
@@ -452,6 +452,8 @@ def check_torchvision():
452
452
  to the compatibility table based on: https://github.com/pytorch/vision#installation.
453
453
  """
454
454
  compatibility_table = {
455
+ "2.9": ["0.24"],
456
+ "2.8": ["0.23"],
455
457
  "2.7": ["0.22"],
456
458
  "2.6": ["0.21"],
457
459
  "2.5": ["0.20"],
ultralytics/utils/loss.py CHANGED
@@ -11,7 +11,7 @@ import torch.nn.functional as F
11
11
  from ultralytics.utils.metrics import OKS_SIGMA
12
12
  from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
13
13
  from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
14
- from ultralytics.utils.torch_utils import autocast
14
+ from ultralytics.utils.torch_utils import autocast, disable_dynamo
15
15
 
16
16
  from .metrics import bbox_iou, probiou
17
17
  from .tal import bbox2dist
@@ -215,6 +215,7 @@ class v8DetectionLoss:
215
215
  self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
216
216
  self.bbox_loss = BboxLoss(m.reg_max).to(device)
217
217
  self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
218
+ disable_dynamo(self.__class__) # exclude from compile
218
219
 
219
220
  def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
220
221
  """Preprocess targets by converting to tensor format and scaling coordinates."""
@@ -260,7 +261,7 @@ class v8DetectionLoss:
260
261
 
261
262
  # Targets
262
263
  targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
263
- targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
264
+ targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
264
265
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
265
266
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
266
267
 
@@ -287,9 +288,14 @@ class v8DetectionLoss:
287
288
 
288
289
  # Bbox loss
289
290
  if fg_mask.sum():
290
- target_bboxes /= stride_tensor
291
291
  loss[0], loss[2] = self.bbox_loss(
292
- pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
292
+ pred_distri,
293
+ pred_bboxes,
294
+ anchor_points,
295
+ target_bboxes / stride_tensor,
296
+ target_scores,
297
+ target_scores_sum,
298
+ fg_mask,
293
299
  )
294
300
 
295
301
  loss[0] *= self.hyp.box # box gain
@@ -329,7 +335,7 @@ class v8SegmentationLoss(v8DetectionLoss):
329
335
  try:
330
336
  batch_idx = batch["batch_idx"].view(-1, 1)
331
337
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
332
- targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
338
+ targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
333
339
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
334
340
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
335
341
  except RuntimeError as e:
@@ -388,7 +394,7 @@ class v8SegmentationLoss(v8DetectionLoss):
388
394
  loss[2] *= self.hyp.cls # cls gain
389
395
  loss[3] *= self.hyp.dfl # dfl gain
390
396
 
391
- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
397
+ return loss * batch_size, loss.detach() # loss(box, seg, cls, dfl)
392
398
 
393
399
  @staticmethod
394
400
  def single_mask_loss(
@@ -516,7 +522,7 @@ class v8PoseLoss(v8DetectionLoss):
516
522
  batch_size = pred_scores.shape[0]
517
523
  batch_idx = batch["batch_idx"].view(-1, 1)
518
524
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
519
- targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
525
+ targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
520
526
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
521
527
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
522
528
 
@@ -704,7 +710,7 @@ class v8OBBLoss(v8DetectionLoss):
704
710
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
705
711
  rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
706
712
  targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
707
- targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
713
+ targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
708
714
  gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
709
715
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
710
716
  except RuntimeError as e:
@@ -1004,6 +1004,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
1004
1004
  _save_one_file(csv_file.with_name("tune_fitness.png"))
1005
1005
 
1006
1006
 
1007
+ @plt_settings()
1007
1008
  def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
1008
1009
  """
1009
1010
  Visualize feature maps of a given model module during inference.
@@ -429,7 +429,7 @@ def get_flops(model, imgsz=640):
429
429
  return 0.0 # if not installed return 0.0 GFLOPs
430
430
 
431
431
  try:
432
- model = de_parallel(model)
432
+ model = unwrap_model(model)
433
433
  p = next(model.parameters())
434
434
  if not isinstance(imgsz, list):
435
435
  imgsz = [imgsz, imgsz] # expand if int/float
@@ -460,7 +460,7 @@ def get_flops_with_torch_profiler(model, imgsz=640):
460
460
  """
461
461
  if not TORCH_2_0: # torch profiler implemented in torch>=2.0
462
462
  return 0.0
463
- model = de_parallel(model)
463
+ model = unwrap_model(model)
464
464
  p = next(model.parameters())
465
465
  if not isinstance(imgsz, list):
466
466
  imgsz = [imgsz, imgsz] # expand if int/float
@@ -577,17 +577,24 @@ def is_parallel(model):
577
577
  return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
578
578
 
579
579
 
580
- def de_parallel(model):
580
+ def unwrap_model(m: nn.Module) -> nn.Module:
581
581
  """
582
- De-parallelize a model: return single-GPU model if model is of type DP or DDP.
582
+ Unwrap compiled and parallel models to get the base model.
583
583
 
584
584
  Args:
585
- model (nn.Module): Model to de-parallelize.
585
+ m (nn.Module): A model that may be wrapped by torch.compile (._orig_mod) or parallel wrappers such as
586
+ DataParallel/DistributedDataParallel (.module).
586
587
 
587
588
  Returns:
588
- (nn.Module): De-parallelized model.
589
+ m (nn.Module): The unwrapped base model without compile or parallel wrappers.
589
590
  """
590
- return model.module if is_parallel(model) else model
591
+ while True:
592
+ if hasattr(m, "_orig_mod") and isinstance(m._orig_mod, nn.Module):
593
+ m = m._orig_mod
594
+ elif hasattr(m, "module") and isinstance(m.module, nn.Module):
595
+ m = m.module
596
+ else:
597
+ return m
591
598
 
592
599
 
593
600
  def one_cycle(y1=0.0, y2=1.0, steps=100):
@@ -669,7 +676,7 @@ class ModelEMA:
669
676
  tau (int, optional): EMA decay time constant.
670
677
  updates (int, optional): Initial number of updates.
671
678
  """
672
- self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
679
+ self.ema = deepcopy(unwrap_model(model)).eval() # FP32 EMA
673
680
  self.updates = updates # number of EMA updates
674
681
  self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
675
682
  for p in self.ema.parameters():
@@ -687,7 +694,7 @@ class ModelEMA:
687
694
  self.updates += 1
688
695
  d = self.decay(self.updates)
689
696
 
690
- msd = de_parallel(model).state_dict() # model state_dict
697
+ msd = unwrap_model(model).state_dict() # model state_dict
691
698
  for k, v in self.ema.state_dict().items():
692
699
  if v.dtype.is_floating_point: # true for FP16 and FP32
693
700
  v *= d
@@ -997,3 +1004,98 @@ class FXModel(nn.Module):
997
1004
  x = m(x) # run
998
1005
  y.append(x) # save output
999
1006
  return x
1007
+
1008
+
1009
+ def disable_dynamo(func: Any) -> Any:
1010
+ """
1011
+ Disable torch.compile/dynamo for a callable when available.
1012
+
1013
+ Args:
1014
+ func (Any): Callable object to wrap. Could be a function, method, or class.
1015
+
1016
+ Returns:
1017
+ func (Any): Same callable, wrapped by torch._dynamo.disable when available, otherwise unchanged.
1018
+
1019
+ Examples:
1020
+ >>> @disable_dynamo
1021
+ ... def fn(x):
1022
+ ... return x + 1
1023
+ >>> # Works even if torch._dynamo is not available
1024
+ >>> _ = fn(1)
1025
+ """
1026
+ if hasattr(torch, "_dynamo"):
1027
+ return torch._dynamo.disable(func)
1028
+ return func
1029
+
1030
+
1031
+ def attempt_compile(
1032
+ model: torch.nn.Module,
1033
+ device: torch.device,
1034
+ imgsz: int = 640,
1035
+ use_autocast: bool = False,
1036
+ warmup: bool = False,
1037
+ prefix: str = colorstr("compile:"),
1038
+ ) -> torch.nn.Module:
1039
+ """
1040
+ Compile a model with torch.compile and optionally warm up the graph to reduce first-iteration latency.
1041
+
1042
+ This utility attempts to compile the provided model using the inductor backend with dynamic shapes enabled and an
1043
+ autotuning mode. If compilation is unavailable or fails, the original model is returned unchanged. An optional
1044
+ warmup performs a single forward pass on a dummy input to prime the compiled graph and measure compile/warmup time.
1045
+
1046
+ Args:
1047
+ model (torch.nn.Module): Model to compile.
1048
+ device (torch.device): Inference device used for warmup and autocast decisions.
1049
+ imgsz (int, optional): Square input size to create a dummy tensor with shape (1, 3, imgsz, imgsz) for warmup.
1050
+ use_autocast (bool, optional): Whether to run warmup under autocast on CUDA or MPS devices.
1051
+ warmup (bool, optional): Whether to execute a single dummy forward pass to warm up the compiled model.
1052
+ prefix (str, optional): Message prefix for logger output.
1053
+
1054
+ Returns:
1055
+ model (torch.nn.Module): Compiled model if compilation succeeds, otherwise the original unmodified model.
1056
+
1057
+ Notes:
1058
+ - If the current PyTorch build does not provide torch.compile, the function returns the input model immediately.
1059
+ - Warmup runs under torch.inference_mode and may use torch.autocast for CUDA/MPS to align compute precision.
1060
+ - CUDA devices are synchronized after warmup to account for asynchronous kernel execution.
1061
+
1062
+ Examples:
1063
+ >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1064
+ >>> # Try to compile and warm up a model with a 640x640 input
1065
+ >>> model = attempt_compile(model, device=device, imgsz=640, use_autocast=True, warmup=True)
1066
+ """
1067
+ if not hasattr(torch, "compile"):
1068
+ return model
1069
+
1070
+ LOGGER.info(f"{prefix} starting torch.compile...")
1071
+ t0 = time.perf_counter()
1072
+ try:
1073
+ model = torch.compile(model, mode="max-autotune", backend="inductor")
1074
+ except Exception as e:
1075
+ LOGGER.warning(f"{prefix} torch.compile failed, continuing uncompiled: {e}")
1076
+ return model
1077
+ t_compile = time.perf_counter() - t0
1078
+
1079
+ t_warm = 0.0
1080
+ if warmup:
1081
+ # Use a single dummy tensor to build the graph shape state and reduce first-iteration latency
1082
+ dummy = torch.zeros(1, 3, imgsz, imgsz, device=device)
1083
+ if use_autocast and device.type == "cuda":
1084
+ dummy = dummy.half()
1085
+ t1 = time.perf_counter()
1086
+ with torch.inference_mode():
1087
+ if use_autocast and device.type in {"cuda", "mps"}:
1088
+ with torch.autocast(device.type):
1089
+ _ = model(dummy)
1090
+ else:
1091
+ _ = model(dummy)
1092
+ if device.type == "cuda":
1093
+ torch.cuda.synchronize(device)
1094
+ t_warm = time.perf_counter() - t1
1095
+
1096
+ total = t_compile + t_warm
1097
+ if warmup:
1098
+ LOGGER.info(f"{prefix} complete in {total:.1f}s (compile {t_compile:.1f}s + warmup {t_warm:.1f}s)")
1099
+ else:
1100
+ LOGGER.info(f"{prefix} compile complete in {t_compile:.1f}s (no warmup)")
1101
+ return model