dgenerate-ultralytics-headless 8.3.214 → 8.4.7 (py3-none-any.whl)
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/nn/text_model.py
CHANGED

@@ -20,8 +20,7 @@ except ImportError:
 
 
 class TextModel(nn.Module):
-    """
-    Abstract base class for text encoding models.
+    """Abstract base class for text encoding models.
 
     This class defines the interface for text encoding models used in vision-language tasks. Subclasses must implement
     the tokenize and encode_text methods to provide text tokenization and encoding functionality.
@@ -47,11 +46,10 @@ class TextModel(nn.Module):
 
 
 class CLIP(TextModel):
-    """
-    Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.
+    """Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.
 
-    This class provides a text encoder based on OpenAI's CLIP model, which can convert text into feature vectors
-    that are aligned with corresponding image features in a shared embedding space.
+    This class provides a text encoder based on OpenAI's CLIP model, which can convert text into feature vectors that
+    are aligned with corresponding image features in a shared embedding space.
 
     Attributes:
         model (clip.model.CLIP): The loaded CLIP model.
@@ -71,20 +69,14 @@ class CLIP(TextModel):
     """
 
     def __init__(self, size: str, device: torch.device) -> None:
-        """
-        Initialize the CLIP text encoder.
+        """Initialize the CLIP text encoder.
 
-        This class implements the TextModel interface using OpenAI's CLIP model for text encoding. It loads
-        a pre-trained CLIP model of the specified size and prepares it for text encoding tasks.
+        This class implements the TextModel interface using OpenAI's CLIP model for text encoding. It loads a
+        pre-trained CLIP model of the specified size and prepares it for text encoding tasks.
 
         Args:
             size (str): Model size identifier (e.g., 'ViT-B/32').
            device (torch.device): Device to load the model on.
-
-        Examples:
-            >>> import torch
-            >>> clip_model = CLIP("ViT-B/32", device=torch.device("cuda:0"))
-            >>> text_features = clip_model.encode_text(["a photo of a cat", "a photo of a dog"])
         """
         super().__init__()
         self.model, self.image_preprocess = clip.load(size, device=device)
@@ -92,12 +84,13 @@ class CLIP(TextModel):
         self.device = device
         self.eval()
 
-    def tokenize(self, texts: str | list[str]) -> torch.Tensor:
-        """
-        Convert input texts to CLIP tokens.
+    def tokenize(self, texts: str | list[str], truncate: bool = True) -> torch.Tensor:
+        """Convert input texts to CLIP tokens.
 
         Args:
             texts (str | list[str]): Input text or list of texts to tokenize.
+            truncate (bool, optional): Whether to trim texts that exceed CLIP's context length. Defaults to True to
+                avoid RuntimeError from overly long inputs while still allowing explicit opt-out.
 
         Returns:
             (torch.Tensor): Tokenized text tensor with shape (batch_size, context_length) ready for model processing.
@@ -106,13 +99,14 @@ class CLIP(TextModel):
             >>> model = CLIP("ViT-B/32", device="cpu")
             >>> tokens = model.tokenize("a photo of a cat")
             >>> print(tokens.shape)  # torch.Size([1, 77])
+            >>> strict_tokens = model.tokenize("a photo of a cat", truncate=False)  # Enforce strict length checks
+            >>> print(strict_tokens.shape)  # Same shape/content as tokens since prompt less than 77 tokens
         """
-        return clip.tokenize(texts).to(self.device)
+        return clip.tokenize(texts, truncate=truncate).to(self.device)
 
     @smart_inference_mode()
     def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
-        """
-        Encode tokenized texts into normalized feature vectors.
+        """Encode tokenized texts into normalized feature vectors.
 
         This method processes tokenized text inputs through the CLIP model to generate feature vectors, which are then
         normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
@@ -137,15 +131,14 @@ class CLIP(TextModel):
 
     @smart_inference_mode()
     def encode_image(self, image: Image.Image | torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
-        """
-        Encode preprocessed images into normalized feature vectors.
+        """Encode preprocessed images into normalized feature vectors.
 
-        This method processes preprocessed image inputs through the CLIP model to generate feature vectors, which are
-        normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
+        This method processes preprocessed image inputs through the CLIP model to generate feature vectors, which are
+        then normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
 
         Args:
-            image (PIL.Image | torch.Tensor): Preprocessed image input. If a PIL Image is provided, it will be
-                converted to a tensor using the model's image preprocessing function.
+            image (PIL.Image | torch.Tensor): Preprocessed image input. If a PIL Image is provided, it will be converted
+                to a tensor using the model's image preprocessing function.
             dtype (torch.dtype, optional): Data type for output features.
 
         Returns:
@@ -169,8 +162,7 @@ class CLIP(TextModel):
 
 
 class MobileCLIP(TextModel):
-    """
-    Implement Apple's MobileCLIP text encoder for efficient text encoding.
+    """Implement Apple's MobileCLIP text encoder for efficient text encoding.
 
     This class implements the TextModel interface using Apple's MobileCLIP model, providing efficient text encoding
     capabilities for vision-language tasks with reduced computational requirements compared to standard CLIP models.
@@ -195,28 +187,16 @@ class MobileCLIP(TextModel):
     config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}
 
     def __init__(self, size: str, device: torch.device) -> None:
-        """
-        Initialize the MobileCLIP text encoder.
+        """Initialize the MobileCLIP text encoder.
 
         This class implements the TextModel interface using Apple's MobileCLIP model for efficient text encoding.
 
         Args:
             size (str): Model size identifier (e.g., 's0', 's1', 's2', 'b', 'blt').
             device (torch.device): Device to load the model on.
-
-        Examples:
-            >>> import torch
-            >>> model = MobileCLIP("s0", device=torch.device("cpu"))
-            >>> tokens = model.tokenize(["a photo of a cat", "a photo of a dog"])
-            >>> features = model.encode_text(tokens)
         """
         try:
-            import warnings
-
-            # Suppress 'timm.models.layers is deprecated, please import via timm.layers' warning from mobileclip usage
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=FutureWarning)
-                import mobileclip
+            import mobileclip
         except ImportError:
             # Ultralytics fork preferred since Apple MobileCLIP repo has incorrect version of torchvision
             checks.check_requirements("git+https://github.com/ultralytics/mobileclip.git")
@@ -236,8 +216,7 @@ class MobileCLIP(TextModel):
         self.eval()
 
     def tokenize(self, texts: list[str]) -> torch.Tensor:
-        """
-        Convert input texts to MobileCLIP tokens.
+        """Convert input texts to MobileCLIP tokens.
 
         Args:
             texts (list[str]): List of text strings to tokenize.
@@ -253,8 +232,7 @@ class MobileCLIP(TextModel):
 
     @smart_inference_mode()
     def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
-        """
-        Encode tokenized texts into normalized feature vectors.
+        """Encode tokenized texts into normalized feature vectors.
 
         Args:
             texts (torch.Tensor): Tokenized text inputs.
@@ -276,8 +254,7 @@ class MobileCLIP(TextModel):
 
 
 class MobileCLIPTS(TextModel):
-    """
-    Load a TorchScript traced version of MobileCLIP.
+    """Load a TorchScript traced version of MobileCLIP.
 
     This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format, providing
     efficient text encoding capabilities for vision-language tasks with optimized inference performance.
@@ -298,48 +275,46 @@ class MobileCLIPTS(TextModel):
         >>> features = text_encoder.encode_text(tokens)
     """
 
-    def __init__(self, device: torch.device):
-        """
-        Initialize the MobileCLIP TorchScript text encoder.
+    def __init__(self, device: torch.device, weight: str = "mobileclip_blt.ts"):
+        """Initialize the MobileCLIP TorchScript text encoder.
 
-        This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format for
-        efficient text encoding with optimized inference performance.
+        This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format for efficient
+        text encoding with optimized inference performance.
 
         Args:
             device (torch.device): Device to load the model on.
-
-        Examples:
-            >>> model = MobileCLIPTS(device=torch.device("cpu"))
-            >>> tokens = model.tokenize(["a photo of a cat", "a photo of a dog"])
-            >>> features = model.encode_text(tokens)
+            weight (str): Path to the TorchScript model weights.
         """
         super().__init__()
         from ultralytics.utils.downloads import attempt_download_asset
 
-        self.encoder = torch.jit.load(attempt_download_asset("mobileclip_blt.ts"), map_location=device)
+        self.encoder = torch.jit.load(attempt_download_asset(weight), map_location=device)
         self.tokenizer = clip.clip.tokenize
         self.device = device
 
-    def tokenize(self, texts: list[str]) -> torch.Tensor:
-        """
-        Convert input texts to MobileCLIP tokens.
+    def tokenize(self, texts: list[str], truncate: bool = True) -> torch.Tensor:
+        """Convert input texts to MobileCLIP tokens.
 
         Args:
             texts (list[str]): List of text strings to tokenize.
+            truncate (bool, optional): Whether to trim texts that exceed the tokenizer context length. Defaults to True,
+                matching CLIP's behavior to prevent runtime failures on long captions.
 
         Returns:
             (torch.Tensor): Tokenized text inputs with shape (batch_size, sequence_length).
 
         Examples:
-            >>> model = MobileCLIPTS("cpu")
+            >>> model = MobileCLIPTS(device=torch.device("cpu"))
             >>> tokens = model.tokenize(["a photo of a cat", "a photo of a dog"])
+            >>> strict_tokens = model.tokenize(
+            ...     ["a very long caption"], truncate=False
+            ... )  # RuntimeError if exceeds 77-token
         """
-        return self.tokenizer(texts).to(self.device)
+        return self.tokenizer(texts, truncate=truncate).to(self.device)
 
     @smart_inference_mode()
     def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
-        """
-        Encode tokenized texts into normalized feature vectors.
+        """Encode tokenized texts into normalized feature vectors.
 
         Args:
             texts (torch.Tensor): Tokenized text inputs.
@@ -360,8 +335,7 @@ class MobileCLIPTS(TextModel):
 
 
 def build_text_model(variant: str, device: torch.device = None) -> TextModel:
-    """
-    Build a text encoding model based on the specified variant.
+    """Build a text encoding model based on the specified variant.
 
     Args:
         variant (str): Model variant in format "base:size" (e.g., "clip:ViT-B/32" or "mobileclip:s0").
@@ -379,5 +353,7 @@ def build_text_model(variant: str, device: torch.device = None) -> TextModel:
         return CLIP(size, device)
     elif base == "mobileclip":
         return MobileCLIPTS(device)
+    elif base == "mobileclip2":
+        return MobileCLIPTS(device, weight="mobileclip2_b.ts")
     else:
         raise ValueError(f"Unrecognized base model: '{base}'. Supported base models: 'clip', 'mobileclip'.")
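The practical effect of these changes: tokenization now trims over-length prompts by default instead of raising, and a new `mobileclip2` base routes to the TorchScript encoder with `mobileclip2_b.ts` weights. Below is a minimal usage sketch based only on the code above; it assumes the `clip` dependency and the TorchScript weights can be resolved at runtime, and the long caption and the `:b` size token are illustrative (the size token appears unused for this base).

```python
import torch

from ultralytics.nn.text_model import build_text_model

device = torch.device("cpu")

# New variant: returns MobileCLIPTS loaded with the mobileclip2_b.ts weights
encoder = build_text_model("mobileclip2:b", device=device)

# truncate=True is now the default: over-length prompts are trimmed to the
# tokenizer context length (77 tokens) instead of raising
long_caption = ["a photo of " + "a very crowded street " * 30]
tokens = encoder.tokenize(long_caption)
features = encoder.encode_text(tokens)  # unit-normalized text embeddings

# Explicit opt-out restores the previous strict behavior
try:
    encoder.tokenize(long_caption, truncate=False)
except RuntimeError as err:
    print(f"Strict tokenization rejected the prompt: {err}")
```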
ultralytics/optim/muon.py
ADDED

@@ -0,0 +1,338 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+import torch
+from torch import optim
+
+
+def zeropower_via_newtonschulz5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+    """Compute the zeroth power / orthogonalization of matrix G using Newton-Schulz iteration.
+
+    This function implements a quintic Newton-Schulz iteration to compute an approximate orthogonalization of the input
+    matrix G. The iteration coefficients are optimized to maximize convergence slope at zero, producing a result similar
+    to UV^T from SVD, where USV^T = G, but with relaxed convergence guarantees that empirically work well for
+    optimization purposes.
+
+    Args:
+        G (torch.Tensor): Input 2D tensor/matrix to orthogonalize.
+        eps (float, optional): Small epsilon value added to norm for numerical stability. Default: 1e-7.
+
+    Returns:
+        (torch.Tensor): Orthogonalized matrix with same shape as input G.
+
+    Examples:
+        >>> G = torch.randn(128, 64)
+        >>> G_ortho = zeropower_via_newtonschulz5(G)
+        >>> print(G_ortho.shape)
+        torch.Size([128, 64])
+
+    Notes:
+        - Uses bfloat16 precision for computation.
+        - Performs exactly 5 Newton-Schulz iteration steps with fixed coefficients.
+        - Automatically transposes for efficiency when rows > columns.
+        - Output approximates US'V^T where S' has diagonal entries ~ Uniform(0.5, 1.5).
+        - Does not produce exact UV^T but works well empirically for neural network optimization.
+    """
+    assert len(G.shape) == 2
+    X = G.bfloat16()
+    X /= X.norm() + eps  # ensure top singular value <= 1
+    if G.size(0) > G.size(1):
+        X = X.T
+    for a, b, c in [  # num_steps fixed at 5
+        # original params
+        (3.4445, -4.7750, 2.0315),
+        (3.4445, -4.7750, 2.0315),
+        (3.4445, -4.7750, 2.0315),
+        (3.4445, -4.7750, 2.0315),
+        (3.4445, -4.7750, 2.0315),
+    ]:
+        # for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
+
+
+def muon_update(grad: torch.Tensor, momentum: torch.Tensor, beta: float = 0.95, nesterov: bool = True) -> torch.Tensor:
+    """Compute Muon optimizer update with momentum and orthogonalization.
+
+    This function applies momentum to the gradient, optionally uses Nesterov acceleration, and then orthogonalizes the
+    update using Newton-Schulz iterations. For convolutional filters (4D tensors), it reshapes before orthogonalization
+    and scales the final update based on parameter dimensions.
+
+    Args:
+        grad (torch.Tensor): Gradient tensor to update. Can be 2D or 4D (for conv filters).
+        momentum (torch.Tensor): Momentum buffer tensor, modified in-place via lerp.
+        beta (float, optional): Momentum coefficient for exponential moving average. Default: 0.95.
+        nesterov (bool, optional): Whether to use Nesterov momentum acceleration. Default: True.
+
+    Returns:
+        (torch.Tensor): Orthogonalized update tensor with same shape as input grad. For 4D inputs, returns reshaped
+            result matching original dimensions.
+
+    Examples:
+        >>> grad = torch.randn(64, 128)
+        >>> momentum = torch.zeros_like(grad)
+        >>> update = muon_update(grad, momentum, beta=0.95, nesterov=True)
+        >>> print(update.shape)
+        torch.Size([64, 128])
+
+    Notes:
+        - Momentum buffer is updated in-place: momentum = beta * momentum + (1-beta) * grad.
+        - With Nesterov: update = beta * momentum + (1-beta) * grad.
+        - Without Nesterov: update = momentum.
+        - 4D tensors (conv filters) are reshaped to 2D as (channels, height*width*depth) for orthogonalization.
+        - Final update is scaled by sqrt(max(dim[-2], dim[-1])) to account for parameter dimensions.
+    """
+    momentum.lerp_(grad, 1 - beta)
+    update = grad.lerp(momentum, beta) if nesterov else momentum
+    if update.ndim == 4:  # for the case of conv filters
+        update = update.view(len(update), -1)
+    update = zeropower_via_newtonschulz5(update)
+    update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
+    return update
+
+
+class MuSGD(optim.Optimizer):
+    """Hybrid optimizer combining Muon and SGD updates for neural network training.
+
+    This optimizer implements a combination of Muon (a momentum-based optimizer with orthogonalization via Newton-Schulz
+    iterations) and standard SGD with momentum. It allows different parameter groups to use either the hybrid Muon+SGD
+    approach or pure SGD.
+
+    Args:
+        param_groups (list): List of parameter groups with their optimization settings.
+        muon (float, optional): Weight factor for Muon updates in hybrid mode. Default: 0.5.
+        sgd (float, optional): Weight factor for SGD updates in hybrid mode. Default: 0.5.
+
+    Attributes:
+        muon (float): Scaling factor applied to Muon learning rate.
+        sgd (float): Scaling factor applied to SGD learning rate in hybrid mode.
+
+    Examples:
+        >>> param_groups = [
+        ...     {
+        ...         "params": model.conv_params,
+        ...         "lr": 0.02,
+        ...         "use_muon": True,
+        ...         "momentum": 0.95,
+        ...         "nesterov": True,
+        ...         "weight_decay": 0.01,
+        ...     },
+        ...     {
+        ...         "params": model.other_params,
+        ...         "lr": 0.01,
+        ...         "use_muon": False,
+        ...         "momentum": 0.9,
+        ...         "nesterov": False,
+        ...         "weight_decay": 0,
+        ...     },
+        ... ]
+        >>> optimizer = MuSGD(param_groups, muon=0.5, sgd=0.5)
+        >>> loss = model(data)
+        >>> loss.backward()
+        >>> optimizer.step()
+
+    Notes:
+        - Parameter groups with 'use_muon': True will receive both Muon and SGD updates.
+        - Parameter groups with 'use_muon': False will receive only SGD updates.
+        - The Muon update uses orthogonalization which works best for 2D+ parameter tensors.
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        nesterov: bool = False,
+        use_muon: bool = False,
+        muon: float = 0.5,
+        sgd: float = 0.5,
+    ):
+        """Initialize MuSGD optimizer with hybrid Muon and SGD capabilities.
+
+        Args:
+            params: Iterable of parameters to optimize or dicts defining parameter groups.
+            lr (float): Learning rate.
+            momentum (float): Momentum factor for SGD.
+            weight_decay (float): Weight decay (L2 penalty).
+            nesterov (bool): Whether to use Nesterov momentum.
+            use_muon (bool): Whether to enable Muon updates.
+            muon (float): Scaling factor for Muon component.
+            sgd (float): Scaling factor for SGD component.
+        """
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            use_muon=use_muon,
+        )
+        super().__init__(params, defaults)
+        self.muon = muon
+        self.sgd = sgd
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Perform a single optimization step.
+
+        Applies either hybrid Muon+SGD updates or pure SGD updates depending on the
+        'use_muon' flag in each parameter group. For Muon-enabled groups, parameters
+        receive both an orthogonalized Muon update and a standard SGD momentum update.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss. Default: None.
+
+        Returns:
+            (torch.Tensor | None): The loss value if closure is provided, otherwise None.
+
+        Notes:
+            - Parameters with None gradients are assigned zero gradients for synchronization.
+            - Muon updates use Newton-Schulz orthogonalization and work best on 2D+ tensors.
+            - Weight decay is applied only to the SGD component in hybrid mode.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            # Muon
+            if group["use_muon"]:
+                # generate weight updates in distributed fashion
+                for p in group["params"]:
+                    lr = group["lr"]
+                    if p.grad is None:
+                        continue
+                    grad = p.grad
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["momentum_buffer"] = torch.zeros_like(p)
+                        state["momentum_buffer_SGD"] = torch.zeros_like(p)
+
+                    update = muon_update(
+                        grad, state["momentum_buffer"], beta=group["momentum"], nesterov=group["nesterov"]
+                    )
+                    p.add_(update.reshape(p.shape), alpha=-(lr * self.muon))
+
+                    # SGD update
+                    if group["weight_decay"] != 0:
+                        grad = grad.add(p, alpha=group["weight_decay"])
+                    state["momentum_buffer_SGD"].mul_(group["momentum"]).add_(grad)
+                    sgd_update = (
+                        grad.add(state["momentum_buffer_SGD"], alpha=group["momentum"])
+                        if group["nesterov"]
+                        else state["momentum_buffer_SGD"]
+                    )
+                    p.add_(sgd_update, alpha=-(lr * self.sgd))
+            else:  # SGD
+                for p in group["params"]:
+                    lr = group["lr"]
+                    if p.grad is None:
+                        continue
+                    grad = p.grad
+                    if group["weight_decay"] != 0:
+                        grad = grad.add(p, alpha=group["weight_decay"])
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["momentum_buffer"] = torch.zeros_like(p)
+                    state["momentum_buffer"].mul_(group["momentum"]).add_(grad)
+                    update = (
+                        grad.add(state["momentum_buffer"], alpha=group["momentum"])
+                        if group["nesterov"]
+                        else state["momentum_buffer"]
+                    )
+                    p.add_(update, alpha=-lr)
+        return loss
+
+
+class Muon(optim.Optimizer):
+    """Muon optimizer for usage in non-distributed settings.
+
+    This optimizer implements the Muon algorithm, which combines momentum-based updates with orthogonalization via
+    Newton-Schulz iterations. It applies weight decay and learning rate scaling to parameter updates.
+
+    Args:
+        params (iterable): Iterable of parameters to optimize or dicts defining parameter groups.
+        lr (float, optional): Learning rate. Default: 0.02.
+        weight_decay (float, optional): Weight decay (L2 penalty) coefficient. Default: 0.
+        momentum (float, optional): Momentum coefficient for exponential moving average. Default: 0.95.
+
+    Attributes:
+        param_groups (list): List of parameter groups with their optimization settings.
+        state (dict): Dictionary containing optimizer state for each parameter.
+
+    Examples:
+        >>> model = YourModel()
+        >>> optimizer = Muon(model.parameters(), lr=0.02, weight_decay=0.01, momentum=0.95)
+        >>> loss = model(data)
+        >>> loss.backward()
+        >>> optimizer.step()
+
+    Notes:
+        - Designed for non-distributed training environments.
+        - Uses Muon updates with orthogonalization for all parameters.
+        - Weight decay is applied multiplicatively before parameter update.
+        - Parameters with None gradients are assigned zero gradients for synchronization.
+    """
+
+    def __init__(self, params, lr: float = 0.02, weight_decay: float = 0, momentum: float = 0.95):
+        """Initialize Muon optimizer with orthogonalization-based updates.
+
+        Args:
+            params: Iterable of parameters to optimize or dicts defining parameter groups.
+            lr (float): Learning rate.
+            weight_decay (float): Weight decay factor applied multiplicatively.
+            momentum (float): Momentum factor for gradient accumulation.
+        """
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Perform a single optimization step.
+
+        Applies Muon updates to all parameters, incorporating momentum and orthogonalization.
+        Weight decay is applied multiplicatively before the parameter update.
+
+        Args:
+            closure (Callable[[], torch.Tensor] | None, optional): A closure that reevaluates the model
+                and returns the loss. Default: None.
+
+        Returns:
+            (torch.Tensor | None): The loss value if closure is provided, otherwise None.
+
+        Examples:
+            >>> optimizer = Muon(model.parameters())
+            >>> loss = model(inputs)
+            >>> loss.backward()
+            >>> optimizer.step()
+
+        Notes:
+            - Parameters with None gradients are assigned zero gradients for synchronization.
+            - Weight decay is applied as: p *= (1 - lr * weight_decay).
+            - Muon update uses Newton-Schulz orthogonalization and works best on 2D+ tensors.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    # continue
+                    p.grad = torch.zeros_like(p)  # Force synchronization
+                state = self.state[p]
+                if len(state) == 0:
+                    state["momentum_buffer"] = torch.zeros_like(p)
+                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                p.mul_(1 - group["lr"] * group["weight_decay"])
+                p.add_(update.reshape(p.shape), alpha=-group["lr"])

+        return loss
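For context: the Muon update orthogonalizes the momentum-filtered gradient, so each step applies a matrix whose singular values sit near 1 rather than the raw gradient. Below is a minimal sketch of both pieces, assuming only the module above; the toy model and the 2D-parameter filter are illustrative, since 1D biases would trip the 2D assert in `zeropower_via_newtonschulz5`, and the direct submodule import is used because the contents of the new `ultralytics/optim/__init__.py` are not shown here.

```python
import torch
from torch import nn

# Module path per the diff above; the top-level package may also re-export these names
from ultralytics.optim.muon import Muon, zeropower_via_newtonschulz5

# 1) Newton-Schulz pushes singular values toward 1 (docstring: roughly Uniform(0.5, 1.5))
G = torch.randn(128, 64)
Q = zeropower_via_newtonschulz5(G).float()  # computed in bfloat16, cast back for inspection
print(torch.linalg.svdvals(Q)[:5])  # all near 1.0, unlike svdvals(G)

# 2) One training step on a toy model; keep only matrix-shaped parameters, since
#    muon_update orthogonalizes 2D tensors (4D conv filters are flattened, 1D biases fail)
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
matrix_params = [p for p in model.parameters() if p.ndim == 2]
opt = Muon(matrix_params, lr=0.02, weight_decay=0.01, momentum=0.95)

x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()  # applies p *= (1 - lr*wd), then p -= lr * orthogonalized update
opt.zero_grad()
```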
ultralytics/solutions/__init__.py
CHANGED

@@ -19,23 +19,23 @@ from .trackzone import TrackZone
 from .vision_eye import VisionEye
 
 __all__ = (
-    "ObjectCounter",
-    "ObjectCropper",
-    "ObjectBlurrer",
     "AIGym",
-    "RegionCounter",
-    "SecurityAlarm",
+    "Analytics",
+    "DistanceCalculation",
     "Heatmap",
+    "Inference",
     "InstanceSegmentation",
-    "VisionEye",
-    "SpeedEstimator",
-    "DistanceCalculation",
-    "QueueManager",
+    "ObjectBlurrer",
+    "ObjectCounter",
+    "ObjectCropper",
     "ParkingManagement",
     "ParkingPtsSelection",
-    "Analytics",
-    "Inference",
-    "TrackZone",
+    "QueueManager",
+    "RegionCounter",
     "SearchApp",
+    "SecurityAlarm",
+    "SpeedEstimator",
+    "TrackZone",
+    "VisionEye",
     "VisualAISearch",
 )