ultralytics-opencv-headless 8.3.253__py3-none-any.whl → 8.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tests/__init__.py +2 -2
  2. tests/conftest.py +1 -1
  3. tests/test_cuda.py +8 -2
  4. tests/test_engine.py +6 -6
  5. tests/test_exports.py +10 -3
  6. tests/test_integrations.py +9 -9
  7. tests/test_python.py +14 -14
  8. tests/test_solutions.py +3 -3
  9. ultralytics/__init__.py +1 -1
  10. ultralytics/cfg/__init__.py +6 -6
  11. ultralytics/cfg/default.yaml +3 -1
  12. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  13. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  14. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  15. ultralytics/cfg/models/26/yolo26-p6.yaml +60 -0
  16. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  17. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  18. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  19. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  20. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  21. ultralytics/data/augment.py +7 -0
  22. ultralytics/data/dataset.py +1 -1
  23. ultralytics/engine/exporter.py +10 -3
  24. ultralytics/engine/model.py +1 -1
  25. ultralytics/engine/trainer.py +40 -15
  26. ultralytics/engine/tuner.py +15 -7
  27. ultralytics/models/fastsam/predict.py +1 -1
  28. ultralytics/models/yolo/detect/train.py +3 -2
  29. ultralytics/models/yolo/detect/val.py +6 -0
  30. ultralytics/models/yolo/model.py +1 -1
  31. ultralytics/models/yolo/obb/predict.py +1 -1
  32. ultralytics/models/yolo/obb/train.py +1 -1
  33. ultralytics/models/yolo/pose/train.py +1 -1
  34. ultralytics/models/yolo/segment/predict.py +1 -1
  35. ultralytics/models/yolo/segment/train.py +1 -1
  36. ultralytics/models/yolo/segment/val.py +3 -1
  37. ultralytics/models/yolo/yoloe/train.py +6 -1
  38. ultralytics/models/yolo/yoloe/train_seg.py +6 -1
  39. ultralytics/nn/autobackend.py +7 -3
  40. ultralytics/nn/modules/__init__.py +8 -0
  41. ultralytics/nn/modules/block.py +127 -8
  42. ultralytics/nn/modules/head.py +818 -205
  43. ultralytics/nn/tasks.py +74 -29
  44. ultralytics/nn/text_model.py +5 -2
  45. ultralytics/optim/__init__.py +5 -0
  46. ultralytics/optim/muon.py +338 -0
  47. ultralytics/utils/benchmarks.py +1 -0
  48. ultralytics/utils/callbacks/platform.py +9 -7
  49. ultralytics/utils/downloads.py +3 -1
  50. ultralytics/utils/export/engine.py +19 -10
  51. ultralytics/utils/export/imx.py +22 -11
  52. ultralytics/utils/export/tensorflow.py +1 -41
  53. ultralytics/utils/loss.py +584 -203
  54. ultralytics/utils/metrics.py +1 -0
  55. ultralytics/utils/ops.py +11 -2
  56. ultralytics/utils/tal.py +98 -19
  57. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/METADATA +31 -39
  58. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/RECORD +62 -51
  59. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/WHEEL +0 -0
  60. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/entry_points.txt +0 -0
  61. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/licenses/LICENSE +0 -0
  62. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py CHANGED
@@ -20,6 +20,7 @@ from ultralytics.nn.modules import (
20
20
  C3TR,
21
21
  ELAN1,
22
22
  OBB,
23
+ OBB26,
23
24
  PSA,
24
25
  SPP,
25
26
  SPPELAN,
@@ -55,6 +56,7 @@ from ultralytics.nn.modules import (
55
56
  Index,
56
57
  LRPCHead,
57
58
  Pose,
59
+ Pose26,
58
60
  RepC3,
59
61
  RepConv,
60
62
  RepNCSPELAN4,
@@ -63,16 +65,19 @@ from ultralytics.nn.modules import (
63
65
  RTDETRDecoder,
64
66
  SCDown,
65
67
  Segment,
68
+ Segment26,
66
69
  TorchVision,
67
70
  WorldDetect,
68
71
  YOLOEDetect,
69
72
  YOLOESegment,
73
+ YOLOESegment26,
70
74
  v10Detect,
71
75
  )
72
76
  from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, YAML, colorstr, emojis
73
77
  from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
74
78
  from ultralytics.utils.loss import (
75
- E2EDetectLoss,
79
+ E2ELoss,
80
+ PoseLoss26,
76
81
  v8ClassificationLoss,
77
82
  v8DetectionLoss,
78
83
  v8OBBLoss,
@@ -223,7 +228,7 @@ class BaseModel(torch.nn.Module):
223
228
  Returns:
224
229
  (torch.nn.Module): The fused model is returned.
225
230
  """
226
- if not self.is_fused():
231
+ if True:
227
232
  for m in self.model.modules():
228
233
  if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
229
234
  if isinstance(m, Conv2):
@@ -241,7 +246,7 @@ class BaseModel(torch.nn.Module):
241
246
  if isinstance(m, RepVGGDW):
242
247
  m.fuse()
243
248
  m.forward = m.forward_fuse
244
- if isinstance(m, v10Detect):
249
+ if isinstance(m, Detect) and getattr(m, "end2end", False):
245
250
  m.fuse() # remove one2many head
246
251
  self.info(verbose=verbose)
247
252
 
@@ -386,7 +391,6 @@ class DetectionModel(BaseModel):
386
391
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
387
392
  self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
388
393
  self.inplace = self.yaml.get("inplace", True)
389
- self.end2end = getattr(self.model[-1], "end2end", False)
390
394
 
391
395
  # Build strides
392
396
  m = self.model[-1] # Detect()
@@ -396,9 +400,10 @@ class DetectionModel(BaseModel):
396
400
 
397
401
  def _forward(x):
398
402
  """Perform a forward pass through the model, handling different Detect subclass types accordingly."""
403
+ output = self.forward(x)
399
404
  if self.end2end:
400
- return self.forward(x)["one2many"]
401
- return self.forward(x)[0] if isinstance(m, (Segment, YOLOESegment, Pose, OBB)) else self.forward(x)
405
+ output = output["one2many"]
406
+ return output["feats"]
402
407
 
403
408
  self.model.eval() # Avoid changing batch statistics until training begins
404
409
  m.training = True # Setting it to True to properly return strides
@@ -415,6 +420,11 @@ class DetectionModel(BaseModel):
415
420
  self.info()
416
421
  LOGGER.info("")
417
422
 
423
+ @property
424
+ def end2end(self):
425
+ """Return whether the model uses end-to-end NMS-free detection."""
426
+ return getattr(self.model[-1], "end2end", False)
427
+
418
428
  def _predict_augment(self, x):
419
429
  """Perform augmentations on input image x and return augmented inference and train outputs.
420
430
 
@@ -481,7 +491,7 @@ class DetectionModel(BaseModel):
481
491
 
482
492
  def init_criterion(self):
483
493
  """Initialize the loss criterion for the DetectionModel."""
484
- return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)
494
+ return E2ELoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)
485
495
 
486
496
 
487
497
  class OBBModel(DetectionModel):
@@ -513,7 +523,7 @@ class OBBModel(DetectionModel):
513
523
 
514
524
  def init_criterion(self):
515
525
  """Initialize the loss criterion for the model."""
516
- return v8OBBLoss(self)
526
+ return E2ELoss(self, v8OBBLoss) if getattr(self, "end2end", False) else v8OBBLoss(self)
517
527
 
518
528
 
519
529
  class SegmentationModel(DetectionModel):
@@ -545,7 +555,7 @@ class SegmentationModel(DetectionModel):
545
555
 
546
556
  def init_criterion(self):
547
557
  """Initialize the loss criterion for the SegmentationModel."""
548
- return v8SegmentationLoss(self)
558
+ return E2ELoss(self, v8SegmentationLoss) if getattr(self, "end2end", False) else v8SegmentationLoss(self)
549
559
 
550
560
 
551
561
  class PoseModel(DetectionModel):
@@ -586,7 +596,7 @@ class PoseModel(DetectionModel):
586
596
 
587
597
  def init_criterion(self):
588
598
  """Initialize the loss criterion for the PoseModel."""
589
- return v8PoseLoss(self)
599
+ return E2ELoss(self, PoseLoss26) if getattr(self, "end2end", False) else v8PoseLoss(self)
590
600
 
591
601
 
592
602
  class ClassificationModel(BaseModel):
@@ -984,6 +994,7 @@ class YOLOEModel(DetectionModel):
984
994
  verbose (bool): Whether to display model information.
985
995
  """
986
996
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
997
+ self.text_model = self.yaml.get("text_model", "mobileclip:blt")
987
998
 
988
999
  @smart_inference_mode()
989
1000
  def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
@@ -1003,9 +1014,13 @@ class YOLOEModel(DetectionModel):
1003
1014
  device = next(self.model.parameters()).device
1004
1015
  if not getattr(self, "clip_model", None) and cache_clip_model:
1005
1016
  # For backwards compatibility of models lacking clip_model attribute
1006
- self.clip_model = build_text_model("mobileclip:blt", device=device)
1017
+ self.clip_model = build_text_model(getattr(self, "text_model", "mobileclip:blt"), device=device)
1007
1018
 
1008
- model = self.clip_model if cache_clip_model else build_text_model("mobileclip:blt", device=device)
1019
+ model = (
1020
+ self.clip_model
1021
+ if cache_clip_model
1022
+ else build_text_model(getattr(self, "text_model", "mobileclip:blt"), device=device)
1023
+ )
1009
1024
  text_token = model.tokenize(text)
1010
1025
  txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
1011
1026
  txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
@@ -1045,10 +1060,12 @@ class YOLOEModel(DetectionModel):
1045
1060
  device = next(self.parameters()).device
1046
1061
  self(torch.empty(1, 3, self.args["imgsz"], self.args["imgsz"]).to(device)) # warmup
1047
1062
 
1063
+ cv3 = getattr(head, "one2one_cv3", head.cv3)
1064
+ cv2 = getattr(head, "one2one_cv2", head.cv2)
1065
+
1048
1066
  # re-parameterization for prompt-free model
1049
1067
  self.model[-1].lrpc = nn.ModuleList(
1050
- LRPCHead(cls, pf[-1], loc[-1], enabled=i != 2)
1051
- for i, (cls, pf, loc) in enumerate(zip(vocab, head.cv3, head.cv2))
1068
+ LRPCHead(cls, pf[-1], loc[-1], enabled=i != 2) for i, (cls, pf, loc) in enumerate(zip(vocab, cv3, cv2))
1052
1069
  )
1053
1070
  for loc_head, cls_head in zip(head.cv2, head.cv3):
1054
1071
  assert isinstance(loc_head, nn.Sequential)
@@ -1077,8 +1094,9 @@ class YOLOEModel(DetectionModel):
1077
1094
  device = next(self.model.parameters()).device
1078
1095
  head.fuse(self.pe.to(device)) # fuse prompt embeddings to classify head
1079
1096
 
1097
+ cv3 = getattr(head, "one2one_cv3", head.cv3)
1080
1098
  vocab = nn.ModuleList()
1081
- for cls_head in head.cv3:
1099
+ for cls_head in cv3:
1082
1100
  assert isinstance(cls_head, nn.Sequential)
1083
1101
  vocab.append(cls_head[-1])
1084
1102
  return vocab
@@ -1155,9 +1173,8 @@ class YOLOEModel(DetectionModel):
1155
1173
  cls_pe = self.get_cls_pe(m.get_tpe(tpe), vpe).to(device=x[0].device, dtype=x[0].dtype)
1156
1174
  if cls_pe.shape[0] != b or m.export:
1157
1175
  cls_pe = cls_pe.expand(b, -1, -1)
1158
- x = m(x, cls_pe)
1159
- else:
1160
- x = m(x) # run
1176
+ x.append(cls_pe) # adding cls embedding
1177
+ x = m(x) # run
1161
1178
 
1162
1179
  y.append(x if m.i in self.save else None) # save output
1163
1180
  if visualize:
@@ -1179,10 +1196,17 @@ class YOLOEModel(DetectionModel):
1179
1196
  from ultralytics.utils.loss import TVPDetectLoss
1180
1197
 
1181
1198
  visual_prompt = batch.get("visuals", None) is not None # TODO
1182
- self.criterion = TVPDetectLoss(self) if visual_prompt else self.init_criterion()
1183
-
1199
+ self.criterion = (
1200
+ (E2ELoss(self, TVPDetectLoss) if getattr(self, "end2end", False) else TVPDetectLoss(self))
1201
+ if visual_prompt
1202
+ else self.init_criterion()
1203
+ )
1184
1204
  if preds is None:
1185
- preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
1205
+ preds = self.forward(
1206
+ batch["img"],
1207
+ tpe=None if "visuals" in batch else batch.get("txt_feats", None),
1208
+ vpe=batch.get("visuals", None),
1209
+ )
1186
1210
  return self.criterion(preds, batch)
1187
1211
 
1188
1212
 
@@ -1224,7 +1248,11 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
1224
1248
  from ultralytics.utils.loss import TVPSegmentLoss
1225
1249
 
1226
1250
  visual_prompt = batch.get("visuals", None) is not None # TODO
1227
- self.criterion = TVPSegmentLoss(self) if visual_prompt else self.init_criterion()
1251
+ self.criterion = (
1252
+ (E2ELoss(self, TVPSegmentLoss) if getattr(self, "end2end", False) else TVPSegmentLoss(self))
1253
+ if visual_prompt
1254
+ else self.init_criterion()
1255
+ )
1228
1256
 
1229
1257
  if preds is None:
1230
1258
  preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
@@ -1499,7 +1527,8 @@ def parse_model(d, ch, verbose=True):
1499
1527
  # Args
1500
1528
  legacy = True # backward compatibility for v3/v5/v8/v9 models
1501
1529
  max_channels = float("inf")
1502
- nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
1530
+ nc, act, scales, end2end = (d.get(x) for x in ("nc", "activation", "scales", "end2end"))
1531
+ reg_max = d.get("reg_max", 16)
1503
1532
  depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
1504
1533
  scale = d.get("scale")
1505
1534
  if scales:
@@ -1624,13 +1653,29 @@ def parse_model(d, ch, verbose=True):
1624
1653
  elif m is Concat:
1625
1654
  c2 = sum(ch[x] for x in f)
1626
1655
  elif m in frozenset(
1627
- {Detect, WorldDetect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB, ImagePoolingAttn, v10Detect}
1656
+ {
1657
+ Detect,
1658
+ WorldDetect,
1659
+ YOLOEDetect,
1660
+ Segment,
1661
+ Segment26,
1662
+ YOLOESegment,
1663
+ YOLOESegment26,
1664
+ Pose,
1665
+ Pose26,
1666
+ OBB,
1667
+ OBB26,
1668
+ }
1628
1669
  ):
1629
- args.append([ch[x] for x in f])
1630
- if m is Segment or m is YOLOESegment:
1670
+ args.extend([reg_max, end2end, [ch[x] for x in f]])
1671
+ if m is Segment or m is YOLOESegment or m is Segment26 or m is YOLOESegment26:
1631
1672
  args[2] = make_divisible(min(args[2], max_channels) * width, 8)
1632
- if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:
1673
+ if m in {Detect, YOLOEDetect, Segment, Segment26, YOLOESegment, YOLOESegment26, Pose, Pose26, OBB, OBB26}:
1633
1674
  m.legacy = legacy
1675
+ elif m is v10Detect:
1676
+ args.append([ch[x] for x in f])
1677
+ elif m is ImagePoolingAttn:
1678
+ args.insert(1, [ch[x] for x in f]) # channels as second arg
1634
1679
  elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
1635
1680
  args.insert(1, [ch[x] for x in f])
1636
1681
  elif m is CBLinear:
@@ -1717,9 +1762,9 @@ def guess_model_task(model):
1717
1762
  return "detect"
1718
1763
  if "segment" in m:
1719
1764
  return "segment"
1720
- if m == "pose":
1765
+ if "pose" in m:
1721
1766
  return "pose"
1722
- if m == "obb":
1767
+ if "obb" in m:
1723
1768
  return "obb"
1724
1769
 
1725
1770
  # Guess from model cfg
@@ -275,7 +275,7 @@ class MobileCLIPTS(TextModel):
275
275
  >>> features = text_encoder.encode_text(tokens)
276
276
  """
277
277
 
278
- def __init__(self, device: torch.device):
278
+ def __init__(self, device: torch.device, weight: str = "mobileclip_blt.ts"):
279
279
  """Initialize the MobileCLIP TorchScript text encoder.
280
280
 
281
281
  This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format for efficient
@@ -283,11 +283,12 @@ class MobileCLIPTS(TextModel):
283
283
 
284
284
  Args:
285
285
  device (torch.device): Device to load the model on.
286
+ weight (str): Path to the TorchScript model weights.
286
287
  """
287
288
  super().__init__()
288
289
  from ultralytics.utils.downloads import attempt_download_asset
289
290
 
290
- self.encoder = torch.jit.load(attempt_download_asset("mobileclip_blt.ts"), map_location=device)
291
+ self.encoder = torch.jit.load(attempt_download_asset(weight), map_location=device)
291
292
  self.tokenizer = clip.clip.tokenize
292
293
  self.device = device
293
294
 
@@ -352,5 +353,7 @@ def build_text_model(variant: str, device: torch.device = None) -> TextModel:
352
353
  return CLIP(size, device)
353
354
  elif base == "mobileclip":
354
355
  return MobileCLIPTS(device)
356
+ elif base == "mobileclip2":
357
+ return MobileCLIPTS(device, weight="mobileclip2_b.ts")
355
358
  else:
356
359
  raise ValueError(f"Unrecognized base model: '{base}'. Supported base models: 'clip', 'mobileclip'.")
@@ -0,0 +1,5 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from .muon import Muon, MuSGD
4
+
5
+ __all__ = ["MuSGD", "Muon"]
@@ -0,0 +1,338 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ from torch import optim
7
+
8
+
9
+ def zeropower_via_newtonschulz5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
10
+ """Compute the zeroth power / orthogonalization of matrix G using Newton-Schulz iteration.
11
+
12
+ This function implements a quintic Newton-Schulz iteration to compute an approximate orthogonalization of the input
13
+ matrix G. The iteration coefficients are optimized to maximize convergence slope at zero, producing a result similar
14
+ to UV^T from SVD, where USV^T = G, but with relaxed convergence guarantees that empirically work well for
15
+ optimization purposes.
16
+
17
+ Args:
18
+ G (torch.Tensor): Input 2D tensor/matrix to orthogonalize.
19
+ eps (float, optional): Small epsilon value added to norm for numerical stability. Default: 1e-7.
20
+
21
+ Returns:
22
+ (torch.Tensor): Orthogonalized matrix with same shape as input G.
23
+
24
+ Examples:
25
+ >>> G = torch.randn(128, 64)
26
+ >>> G_ortho = zeropower_via_newtonschulz5(G)
27
+ >>> print(G_ortho.shape)
28
+ torch.Size([128, 64])
29
+
30
+ Notes:
31
+ - Uses bfloat16 precision for computation.
32
+ - Performs exactly 5 Newton-Schulz iteration steps with fixed coefficients.
33
+ - Automatically transposes for efficiency when rows > columns.
34
+ - Output approximates US'V^T where S' has diagonal entries ~ Uniform(0.5, 1.5).
35
+ - Does not produce exact UV^T but works well empirically for neural network optimization.
36
+ """
37
+ assert len(G.shape) == 2
38
+ X = G.bfloat16()
39
+ X /= X.norm() + eps # ensure top singular value <= 1
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ for a, b, c in [ # num_steps fixed at 5
43
+ # original params
44
+ (3.4445, -4.7750, 2.0315),
45
+ (3.4445, -4.7750, 2.0315),
46
+ (3.4445, -4.7750, 2.0315),
47
+ (3.4445, -4.7750, 2.0315),
48
+ (3.4445, -4.7750, 2.0315),
49
+ ]:
50
+ # for _ in range(steps):
51
+ A = X @ X.T
52
+ B = b * A + c * A @ A
53
+ X = a * X + B @ X
54
+ if G.size(0) > G.size(1):
55
+ X = X.T
56
+ return X
57
+
58
+
59
+ def muon_update(grad: torch.Tensor, momentum: torch.Tensor, beta: float = 0.95, nesterov: bool = True) -> torch.Tensor:
60
+ """Compute Muon optimizer update with momentum and orthogonalization.
61
+
62
+ This function applies momentum to the gradient, optionally uses Nesterov acceleration, and then orthogonalizes the
63
+ update using Newton-Schulz iterations. For convolutional filters (4D tensors), it reshapes before orthogonalization
64
+ and scales the final update based on parameter dimensions.
65
+
66
+ Args:
67
+ grad (torch.Tensor): Gradient tensor to update. Can be 2D or 4D (for conv filters).
68
+ momentum (torch.Tensor): Momentum buffer tensor, modified in-place via lerp.
69
+ beta (float, optional): Momentum coefficient for exponential moving average. Default: 0.95.
70
+ nesterov (bool, optional): Whether to use Nesterov momentum acceleration. Default: True.
71
+
72
+ Returns:
73
+ (torch.Tensor): Orthogonalized update tensor with same shape as input grad. For 4D inputs, returns reshaped
74
+ result matching original dimensions.
75
+
76
+ Examples:
77
+ >>> grad = torch.randn(64, 128)
78
+ >>> momentum = torch.zeros_like(grad)
79
+ >>> update = muon_update(grad, momentum, beta=0.95, nesterov=True)
80
+ >>> print(update.shape)
81
+ torch.Size([64, 128])
82
+
83
+ Notes:
84
+ - Momentum buffer is updated in-place: momentum = beta * momentum + (1-beta) * grad.
85
+ - With Nesterov: update = beta * momentum + (1-beta) * grad.
86
+ - Without Nesterov: update = momentum.
87
+ - 4D tensors (conv filters) are reshaped to 2D as (channels, height*width*depth) for orthogonalization.
88
+ - Final update is scaled by sqrt(max(dim[-2], dim[-1])) to account for parameter dimensions.
89
+ """
90
+ momentum.lerp_(grad, 1 - beta)
91
+ update = grad.lerp(momentum, beta) if nesterov else momentum
92
+ if update.ndim == 4: # for the case of conv filters
93
+ update = update.view(len(update), -1)
94
+ update = zeropower_via_newtonschulz5(update)
95
+ update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
96
+ return update
97
+
98
+
99
+ class MuSGD(optim.Optimizer):
100
+ """Hybrid optimizer combining Muon and SGD updates for neural network training.
101
+
102
+ This optimizer implements a combination of Muon (a momentum-based optimizer with orthogonalization via Newton-Schulz
103
+ iterations) and standard SGD with momentum. It allows different parameter groups to use either the hybrid Muon+SGD
104
+ approach or pure SGD.
105
+
106
+ Args:
107
+ param_groups (list): List of parameter groups with their optimization settings.
108
+ muon (float, optional): Weight factor for Muon updates in hybrid mode. Default: 0.5.
109
+ sgd (float, optional): Weight factor for SGD updates in hybrid mode. Default: 0.5.
110
+
111
+ Attributes:
112
+ muon (float): Scaling factor applied to Muon learning rate.
113
+ sgd (float): Scaling factor applied to SGD learning rate in hybrid mode.
114
+
115
+ Examples:
116
+ >>> param_groups = [
117
+ ... {
118
+ ... "params": model.conv_params,
119
+ ... "lr": 0.02,
120
+ ... "use_muon": True,
121
+ ... "momentum": 0.95,
122
+ ... "nesterov": True,
123
+ ... "weight_decay": 0.01,
124
+ ... },
125
+ ... {
126
+ ... "params": model.other_params,
127
+ ... "lr": 0.01,
128
+ ... "use_muon": False,
129
+ ... "momentum": 0.9,
130
+ ... "nesterov": False,
131
+ ... "weight_decay": 0,
132
+ ... },
133
+ ... ]
134
+ >>> optimizer = MuSGD(param_groups, muon=0.5, sgd=0.5)
135
+ >>> loss = model(data)
136
+ >>> loss.backward()
137
+ >>> optimizer.step()
138
+
139
+ Notes:
140
+ - Parameter groups with 'use_muon': True will receive both Muon and SGD updates.
141
+ - Parameter groups with 'use_muon': False will receive only SGD updates.
142
+ - The Muon update uses orthogonalization which works best for 2D+ parameter tensors.
143
+ """
144
+
145
+ def __init__(
146
+ self,
147
+ params,
148
+ lr: float = 1e-3,
149
+ momentum: float = 0.0,
150
+ weight_decay: float = 0.0,
151
+ nesterov: bool = False,
152
+ use_muon: bool = False,
153
+ muon: float = 0.5,
154
+ sgd: float = 0.5,
155
+ ):
156
+ """Initialize MuSGD optimizer with hybrid Muon and SGD capabilities.
157
+
158
+ Args:
159
+ params: Iterable of parameters to optimize or dicts defining parameter groups.
160
+ lr (float): Learning rate.
161
+ momentum (float): Momentum factor for SGD.
162
+ weight_decay (float): Weight decay (L2 penalty).
163
+ nesterov (bool): Whether to use Nesterov momentum.
164
+ use_muon (bool): Whether to enable Muon updates.
165
+ muon (float): Scaling factor for Muon component.
166
+ sgd (float): Scaling factor for SGD component.
167
+ """
168
+ defaults = dict(
169
+ lr=lr,
170
+ momentum=momentum,
171
+ weight_decay=weight_decay,
172
+ nesterov=nesterov,
173
+ use_muon=use_muon,
174
+ )
175
+ super().__init__(params, defaults)
176
+ self.muon = muon
177
+ self.sgd = sgd
178
+
179
+ @torch.no_grad()
180
+ def step(self, closure=None):
181
+ """Perform a single optimization step.
182
+
183
+ Applies either hybrid Muon+SGD updates or pure SGD updates depending on the
184
+ 'use_muon' flag in each parameter group. For Muon-enabled groups, parameters
185
+ receive both an orthogonalized Muon update and a standard SGD momentum update.
186
+
187
+ Args:
188
+ closure (Callable, optional): A closure that reevaluates the model
189
+ and returns the loss. Default: None.
190
+
191
+ Returns:
192
+ (torch.Tensor | None): The loss value if closure is provided, otherwise None.
193
+
194
+ Notes:
195
+ - Parameters with None gradients are assigned zero gradients for synchronization.
196
+ - Muon updates use Newton-Schulz orthogonalization and work best on 2D+ tensors.
197
+ - Weight decay is applied only to the SGD component in hybrid mode.
198
+ """
199
+ loss = None
200
+ if closure is not None:
201
+ with torch.enable_grad():
202
+ loss = closure()
203
+
204
+ for group in self.param_groups:
205
+ # Muon
206
+ if group["use_muon"]:
207
+ # generate weight updates in distributed fashion
208
+ for p in group["params"]:
209
+ lr = group["lr"]
210
+ if p.grad is None:
211
+ continue
212
+ grad = p.grad
213
+ state = self.state[p]
214
+ if len(state) == 0:
215
+ state["momentum_buffer"] = torch.zeros_like(p)
216
+ state["momentum_buffer_SGD"] = torch.zeros_like(p)
217
+
218
+ update = muon_update(
219
+ grad, state["momentum_buffer"], beta=group["momentum"], nesterov=group["nesterov"]
220
+ )
221
+ p.add_(update.reshape(p.shape), alpha=-(lr * self.muon))
222
+
223
+ # SGD update
224
+ if group["weight_decay"] != 0:
225
+ grad = grad.add(p, alpha=group["weight_decay"])
226
+ state["momentum_buffer_SGD"].mul_(group["momentum"]).add_(grad)
227
+ sgd_update = (
228
+ grad.add(state["momentum_buffer_SGD"], alpha=group["momentum"])
229
+ if group["nesterov"]
230
+ else state["momentum_buffer_SGD"]
231
+ )
232
+ p.add_(sgd_update, alpha=-(lr * self.sgd))
233
+ else: # SGD
234
+ for p in group["params"]:
235
+ lr = group["lr"]
236
+ if p.grad is None:
237
+ continue
238
+ grad = p.grad
239
+ if group["weight_decay"] != 0:
240
+ grad = grad.add(p, alpha=group["weight_decay"])
241
+ state = self.state[p]
242
+ if len(state) == 0:
243
+ state["momentum_buffer"] = torch.zeros_like(p)
244
+ state["momentum_buffer"].mul_(group["momentum"]).add_(grad)
245
+ update = (
246
+ grad.add(state["momentum_buffer"], alpha=group["momentum"])
247
+ if group["nesterov"]
248
+ else state["momentum_buffer"]
249
+ )
250
+ p.add_(update, alpha=-lr)
251
+ return loss
252
+
253
+
254
+ class Muon(optim.Optimizer):
255
+ """Muon optimizer for usage in non-distributed settings.
256
+
257
+ This optimizer implements the Muon algorithm, which combines momentum-based updates with orthogonalization via
258
+ Newton-Schulz iterations. It applies weight decay and learning rate scaling to parameter updates.
259
+
260
+ Args:
261
+ params (iterable): Iterable of parameters to optimize or dicts defining parameter groups.
262
+ lr (float, optional): Learning rate. Default: 0.02.
263
+ weight_decay (float, optional): Weight decay (L2 penalty) coefficient. Default: 0.
264
+ momentum (float, optional): Momentum coefficient for exponential moving average. Default: 0.95.
265
+
266
+ Attributes:
267
+ param_groups (list): List of parameter groups with their optimization settings.
268
+ state (dict): Dictionary containing optimizer state for each parameter.
269
+
270
+ Examples:
271
+ >>> model = YourModel()
272
+ >>> optimizer = Muon(model.parameters(), lr=0.02, weight_decay=0.01, momentum=0.95)
273
+ >>> loss = model(data)
274
+ >>> loss.backward()
275
+ >>> optimizer.step()
276
+
277
+ Notes:
278
+ - Designed for non-distributed training environments.
279
+ - Uses Muon updates with orthogonalization for all parameters.
280
+ - Weight decay is applied multiplicatively before parameter update.
281
+ - Parameters with None gradients are assigned zero gradients for synchronization.
282
+ """
283
+
284
+ def __init__(self, params, lr: float = 0.02, weight_decay: float = 0, momentum: float = 0.95):
285
+ """Initialize Muon optimizer with orthogonalization-based updates.
286
+
287
+ Args:
288
+ params: Iterable of parameters to optimize or dicts defining parameter groups.
289
+ lr (float): Learning rate.
290
+ weight_decay (float): Weight decay factor applied multiplicatively.
291
+ momentum (float): Momentum factor for gradient accumulation.
292
+ """
293
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
294
+ super().__init__(params, defaults)
295
+
296
+ @torch.no_grad()
297
+ def step(self, closure=None):
298
+ """Perform a single optimization step.
299
+
300
+ Applies Muon updates to all parameters, incorporating momentum and orthogonalization.
301
+ Weight decay is applied multiplicatively before the parameter update.
302
+
303
+ Args:
304
+ closure (Callable[[], torch.Tensor] | None, optional): A closure that reevaluates the model
305
+ and returns the loss. Default: None.
306
+
307
+ Returns:
308
+ (torch.Tensor | None): The loss value if closure is provided, otherwise None.
309
+
310
+ Examples:
311
+ >>> optimizer = Muon(model.parameters())
312
+ >>> loss = model(inputs)
313
+ >>> loss.backward()
314
+ >>> optimizer.step()
315
+
316
+ Notes:
317
+ - Parameters with None gradients are assigned zero gradients for synchronization.
318
+ - Weight decay is applied as: p *= (1 - lr * weight_decay).
319
+ - Muon update uses Newton-Schulz orthogonalization and works best on 2D+ tensors.
320
+ """
321
+ loss = None
322
+ if closure is not None:
323
+ with torch.enable_grad():
324
+ loss = closure()
325
+
326
+ for group in self.param_groups:
327
+ for p in group["params"]:
328
+ if p.grad is None:
329
+ # continue
330
+ p.grad = torch.zeros_like(p) # Force synchronization
331
+ state = self.state[p]
332
+ if len(state) == 0:
333
+ state["momentum_buffer"] = torch.zeros_like(p)
334
+ update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
335
+ p.mul_(1 - group["lr"] * group["weight_decay"])
336
+ p.add_(update.reshape(p.shape), alpha=-group["lr"])
337
+
338
+ return loss
@@ -141,6 +141,7 @@ def benchmark(
141
141
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
142
142
  if format == "ncnn":
143
143
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
144
+ assert not ARM64, "NCNN not supported on ARM64" # https://github.com/Tencent/ncnn/issues/6509
144
145
  if format == "imx":
145
146
  assert not is_end2end
146
147
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"