ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_exports.py +2 -2
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +15 -19
- ultralytics/engine/exporter.py +24 -23
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +42 -24
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +65 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
- ultralytics-8.3.91.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py
CHANGED
@@ -21,6 +21,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
21
21
|
Attributes:
|
22
22
|
topk (int): The number of top candidates to consider.
|
23
23
|
num_classes (int): The number of object classes.
|
24
|
+
bg_idx (int): Background class index.
|
24
25
|
alpha (float): The alpha parameter for the classification component of the task-aligned metric.
|
25
26
|
beta (float): The beta parameter for the localization component of the task-aligned metric.
|
26
27
|
eps (float): A small value to prevent division by zero.
|
@@ -39,23 +40,25 @@ class TaskAlignedAssigner(nn.Module):
|
|
39
40
|
@torch.no_grad()
|
40
41
|
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
41
42
|
"""
|
42
|
-
Compute the task-aligned assignment.
|
43
|
-
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
|
43
|
+
Compute the task-aligned assignment.
|
44
44
|
|
45
45
|
Args:
|
46
|
-
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
47
|
-
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
48
|
-
anc_points (Tensor): shape(num_total_anchors, 2)
|
49
|
-
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
|
50
|
-
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
|
51
|
-
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
|
46
|
+
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
47
|
+
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
|
48
|
+
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
|
49
|
+
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
|
50
|
+
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
|
51
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
|
52
52
|
|
53
53
|
Returns:
|
54
|
-
target_labels (Tensor): shape(bs, num_total_anchors)
|
55
|
-
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
56
|
-
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
57
|
-
fg_mask (Tensor): shape(bs, num_total_anchors)
|
58
|
-
target_gt_idx (Tensor): shape(bs, num_total_anchors)
|
54
|
+
target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
|
55
|
+
target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
|
56
|
+
target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
|
57
|
+
fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
|
58
|
+
target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
|
59
|
+
|
60
|
+
References:
|
61
|
+
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
|
59
62
|
"""
|
60
63
|
self.bs = pd_scores.shape[0]
|
61
64
|
self.n_max_boxes = gt_bboxes.shape[1]
|
@@ -81,23 +84,22 @@ class TaskAlignedAssigner(nn.Module):
|
|
81
84
|
|
82
85
|
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
83
86
|
"""
|
84
|
-
Compute the task-aligned assignment.
|
85
|
-
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
|
87
|
+
Compute the task-aligned assignment.
|
86
88
|
|
87
89
|
Args:
|
88
|
-
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
89
|
-
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
90
|
-
anc_points (Tensor): shape(num_total_anchors, 2)
|
91
|
-
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
|
92
|
-
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
|
93
|
-
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
|
90
|
+
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
91
|
+
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
|
92
|
+
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
|
93
|
+
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
|
94
|
+
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
|
95
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
|
94
96
|
|
95
97
|
Returns:
|
96
|
-
target_labels (Tensor): shape(bs, num_total_anchors)
|
97
|
-
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
98
|
-
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
99
|
-
fg_mask (Tensor): shape(bs, num_total_anchors)
|
100
|
-
target_gt_idx (Tensor): shape(bs, num_total_anchors)
|
98
|
+
target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
|
99
|
+
target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
|
100
|
+
target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
|
101
|
+
fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
|
102
|
+
target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
|
101
103
|
"""
|
102
104
|
mask_pos, align_metric, overlaps = self.get_pos_mask(
|
103
105
|
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
|
@@ -118,7 +120,22 @@ class TaskAlignedAssigner(nn.Module):
|
|
118
120
|
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
119
121
|
|
120
122
|
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
121
|
-
"""
|
123
|
+
"""
|
124
|
+
Get positive mask for each ground truth box.
|
125
|
+
|
126
|
+
Args:
|
127
|
+
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
128
|
+
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
|
129
|
+
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
|
130
|
+
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
|
131
|
+
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
|
132
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
|
133
|
+
|
134
|
+
Returns:
|
135
|
+
mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
|
136
|
+
align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
|
137
|
+
overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
|
138
|
+
"""
|
122
139
|
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
|
123
140
|
# Get anchor_align metric, (b, max_num_obj, h*w)
|
124
141
|
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
|
@@ -130,7 +147,20 @@ class TaskAlignedAssigner(nn.Module):
|
|
130
147
|
return mask_pos, align_metric, overlaps
|
131
148
|
|
132
149
|
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
133
|
-
"""
|
150
|
+
"""
|
151
|
+
Compute alignment metric given predicted and ground truth bounding boxes.
|
152
|
+
|
153
|
+
Args:
|
154
|
+
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
155
|
+
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
|
156
|
+
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
|
157
|
+
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
|
158
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).
|
159
|
+
|
160
|
+
Returns:
|
161
|
+
align_metric (torch.Tensor): Alignment metric combining classification and localization.
|
162
|
+
overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
|
163
|
+
"""
|
134
164
|
na = pd_bboxes.shape[-2]
|
135
165
|
mask_gt = mask_gt.bool() # b, max_num_obj, h*w
|
136
166
|
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
|
@@ -151,7 +181,16 @@ class TaskAlignedAssigner(nn.Module):
|
|
151
181
|
return align_metric, overlaps
|
152
182
|
|
153
183
|
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
154
|
-
"""
|
184
|
+
"""
|
185
|
+
Calculate IoU for horizontal bounding boxes.
|
186
|
+
|
187
|
+
Args:
|
188
|
+
gt_bboxes (torch.Tensor): Ground truth boxes.
|
189
|
+
pd_bboxes (torch.Tensor): Predicted boxes.
|
190
|
+
|
191
|
+
Returns:
|
192
|
+
(torch.Tensor): IoU values between each pair of boxes.
|
193
|
+
"""
|
155
194
|
return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
|
156
195
|
|
157
196
|
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
|
@@ -159,16 +198,16 @@ class TaskAlignedAssigner(nn.Module):
|
|
159
198
|
Select the top-k candidates based on the given metrics.
|
160
199
|
|
161
200
|
Args:
|
162
|
-
metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
|
201
|
+
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
|
163
202
|
max_num_obj is the maximum number of objects, and h*w represents the
|
164
203
|
total number of anchor points.
|
165
204
|
largest (bool): If True, select the largest values; otherwise, select the smallest values.
|
166
|
-
topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
205
|
+
topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
167
206
|
topk is the number of top candidates to consider. If not provided,
|
168
207
|
the top-k values are automatically computed based on the given metrics.
|
169
208
|
|
170
209
|
Returns:
|
171
|
-
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
210
|
+
(torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
172
211
|
"""
|
173
212
|
# (b, max_num_obj, topk)
|
174
213
|
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
|
@@ -183,7 +222,6 @@ class TaskAlignedAssigner(nn.Module):
|
|
183
222
|
for k in range(self.topk):
|
184
223
|
# Expand topk_idxs for each value of k and add 1 at the specified positions
|
185
224
|
count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
|
186
|
-
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
|
187
225
|
# Filter invalid bboxes
|
188
226
|
count_tensor.masked_fill_(count_tensor > 1, 0)
|
189
227
|
|
@@ -194,24 +232,21 @@ class TaskAlignedAssigner(nn.Module):
|
|
194
232
|
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
195
233
|
|
196
234
|
Args:
|
197
|
-
gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
235
|
+
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
198
236
|
batch size and max_num_obj is the maximum number of objects.
|
199
|
-
gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
200
|
-
target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
|
237
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
238
|
+
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
|
201
239
|
anchor points, with shape (b, h*w), where h*w is the total
|
202
240
|
number of anchor points.
|
203
|
-
fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
|
241
|
+
fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
|
204
242
|
(foreground) anchor points.
|
205
243
|
|
206
244
|
Returns:
|
207
|
-
(
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
- target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
|
213
|
-
for positive anchor points, where num_classes is the number
|
214
|
-
of object classes.
|
245
|
+
target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
|
246
|
+
target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
|
247
|
+
anchor points.
|
248
|
+
target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
|
249
|
+
anchor points.
|
215
250
|
"""
|
216
251
|
# Assigned target labels, (b, 1)
|
217
252
|
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
|
@@ -258,7 +293,6 @@ class TaskAlignedAssigner(nn.Module):
|
|
258
293
|
bs, n_boxes, _ = gt_bboxes.shape
|
259
294
|
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
|
260
295
|
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
|
261
|
-
# return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
|
262
296
|
return bbox_deltas.amin(3).gt_(eps)
|
263
297
|
|
264
298
|
@staticmethod
|
@@ -275,9 +309,6 @@ class TaskAlignedAssigner(nn.Module):
|
|
275
309
|
target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
|
276
310
|
fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
|
277
311
|
mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
|
278
|
-
|
279
|
-
Note:
|
280
|
-
b: batch size, h: height, w: width.
|
281
312
|
"""
|
282
313
|
# Convert (b, n_max_boxes, h*w) -> (b, h*w)
|
283
314
|
fg_mask = mask_pos.sum(-2)
|
@@ -299,7 +330,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
|
299
330
|
"""Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
|
300
331
|
|
301
332
|
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
302
|
-
"""IoU
|
333
|
+
"""Calculate IoU for rotated bounding boxes."""
|
303
334
|
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
|
304
335
|
|
305
336
|
@staticmethod
|
@@ -308,11 +339,11 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
|
308
339
|
Select the positive anchor center in gt for rotated bounding boxes.
|
309
340
|
|
310
341
|
Args:
|
311
|
-
xy_centers (Tensor): shape(h*w, 2)
|
312
|
-
gt_bboxes (Tensor): shape(b, n_boxes, 5)
|
342
|
+
xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
|
343
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
|
313
344
|
|
314
345
|
Returns:
|
315
|
-
(Tensor): shape(b, n_boxes, h*w)
|
346
|
+
(torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
|
316
347
|
"""
|
317
348
|
# (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
|
318
349
|
corners = xywhr2xyxyxyxy(gt_bboxes)
|
@@ -368,13 +399,13 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
|
368
399
|
Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
369
400
|
|
370
401
|
Args:
|
371
|
-
pred_dist (torch.Tensor): Predicted rotated distance
|
372
|
-
pred_angle (torch.Tensor): Predicted angle
|
373
|
-
anchor_points (torch.Tensor): Anchor points
|
402
|
+
pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
|
403
|
+
pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
|
404
|
+
anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
|
374
405
|
dim (int, optional): Dimension along which to split. Defaults to -1.
|
375
406
|
|
376
407
|
Returns:
|
377
|
-
(torch.Tensor): Predicted rotated bounding boxes
|
408
|
+
(torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
|
378
409
|
"""
|
379
410
|
lt, rb = pred_dist.split(2, dim=dim)
|
380
411
|
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
|
ultralytics/utils/torch_utils.py
CHANGED
@@ -90,12 +90,12 @@ def autocast(enabled: bool, device: str = "cuda"):
|
|
90
90
|
Returns:
|
91
91
|
(torch.amp.autocast): The appropriate autocast context manager.
|
92
92
|
|
93
|
-
|
93
|
+
Notes:
|
94
94
|
- For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
|
95
95
|
- For older versions, it uses `torch.cuda.autocast`.
|
96
96
|
|
97
97
|
Examples:
|
98
|
-
>>> with autocast(
|
98
|
+
>>> with autocast(enabled=True):
|
99
99
|
... # Your mixed precision operations here
|
100
100
|
... pass
|
101
101
|
"""
|
@@ -130,7 +130,7 @@ def get_gpu_info(index):
|
|
130
130
|
|
131
131
|
def select_device(device="", batch=0, newline=False, verbose=True):
|
132
132
|
"""
|
133
|
-
|
133
|
+
Select the appropriate PyTorch device based on the provided arguments.
|
134
134
|
|
135
135
|
The function takes a string specifying the device or a torch.device object and returns a torch.device object
|
136
136
|
representing the selected device. The function also validates the number of available devices and raises an
|
@@ -299,7 +299,18 @@ def fuse_deconv_and_bn(deconv, bn):
|
|
299
299
|
|
300
300
|
|
301
301
|
def model_info(model, detailed=False, verbose=True, imgsz=640):
|
302
|
-
"""
|
302
|
+
"""
|
303
|
+
Print and return detailed model information layer by layer.
|
304
|
+
|
305
|
+
Args:
|
306
|
+
model (nn.Module): Model to analyze.
|
307
|
+
detailed (bool, optional): Whether to print detailed layer information. Defaults to False.
|
308
|
+
verbose (bool, optional): Whether to print model information. Defaults to True.
|
309
|
+
imgsz (int | List, optional): Input image size. Defaults to 640.
|
310
|
+
|
311
|
+
Returns:
|
312
|
+
(Tuple[int, int, int, float]): Number of layers, parameters, gradients, and GFLOPs.
|
313
|
+
"""
|
303
314
|
if not verbose:
|
304
315
|
return
|
305
316
|
n_p = get_num_params(model) # number of parameters
|
@@ -343,6 +354,12 @@ def model_info_for_loggers(trainer):
|
|
343
354
|
"""
|
344
355
|
Return model info dict with useful model information.
|
345
356
|
|
357
|
+
Args:
|
358
|
+
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.
|
359
|
+
|
360
|
+
Returns:
|
361
|
+
(dict): Dictionary containing model parameters, GFLOPs, and inference speeds.
|
362
|
+
|
346
363
|
Examples:
|
347
364
|
YOLOv8n info for loggers
|
348
365
|
>>> results = {
|
@@ -368,7 +385,16 @@ def model_info_for_loggers(trainer):
|
|
368
385
|
|
369
386
|
|
370
387
|
def get_flops(model, imgsz=640):
|
371
|
-
"""
|
388
|
+
"""
|
389
|
+
Return a YOLO model's FLOPs.
|
390
|
+
|
391
|
+
Args:
|
392
|
+
model (nn.Module): The model to calculate FLOPs for.
|
393
|
+
imgsz (int | List[int], optional): Input image size. Defaults to 640.
|
394
|
+
|
395
|
+
Returns:
|
396
|
+
(float): The model's FLOPs in billions.
|
397
|
+
"""
|
372
398
|
if not thop:
|
373
399
|
return 0.0 # if not installed return 0.0 GFLOPs
|
374
400
|
|
@@ -392,7 +418,16 @@ def get_flops(model, imgsz=640):
|
|
392
418
|
|
393
419
|
|
394
420
|
def get_flops_with_torch_profiler(model, imgsz=640):
|
395
|
-
"""
|
421
|
+
"""
|
422
|
+
Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).
|
423
|
+
|
424
|
+
Args:
|
425
|
+
model (nn.Module): The model to calculate FLOPs for.
|
426
|
+
imgsz (int | List[int], optional): Input image size. Defaults to 640.
|
427
|
+
|
428
|
+
Returns:
|
429
|
+
(float): The model's FLOPs in billions.
|
430
|
+
"""
|
396
431
|
if not TORCH_2_0: # torch profiler implemented in torch>=2.0
|
397
432
|
return 0.0
|
398
433
|
model = de_parallel(model)
|
@@ -430,7 +465,18 @@ def initialize_weights(model):
|
|
430
465
|
|
431
466
|
|
432
467
|
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
|
433
|
-
"""
|
468
|
+
"""
|
469
|
+
Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple.
|
470
|
+
|
471
|
+
Args:
|
472
|
+
img (torch.Tensor): Input image tensor.
|
473
|
+
ratio (float, optional): Scaling ratio. Defaults to 1.0.
|
474
|
+
same_shape (bool, optional): Whether to maintain the same shape. Defaults to False.
|
475
|
+
gs (int, optional): Grid size for padding. Defaults to 32.
|
476
|
+
|
477
|
+
Returns:
|
478
|
+
(torch.Tensor): Scaled and padded image tensor.
|
479
|
+
"""
|
434
480
|
if ratio == 1.0:
|
435
481
|
return img
|
436
482
|
h, w = img.shape[2:]
|
@@ -442,7 +488,15 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
|
|
442
488
|
|
443
489
|
|
444
490
|
def copy_attr(a, b, include=(), exclude=()):
|
445
|
-
"""
|
491
|
+
"""
|
492
|
+
Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.
|
493
|
+
|
494
|
+
Args:
|
495
|
+
a (object): Destination object to copy attributes to.
|
496
|
+
b (object): Source object to copy attributes from.
|
497
|
+
include (tuple, optional): Attributes to include. If empty, all attributes are included. Defaults to ().
|
498
|
+
exclude (tuple, optional): Attributes to exclude. Defaults to ().
|
499
|
+
"""
|
446
500
|
for k, v in b.__dict__.items():
|
447
501
|
if (len(include) and k not in include) or k.startswith("_") or k in exclude:
|
448
502
|
continue
|
@@ -451,7 +505,12 @@ def copy_attr(a, b, include=(), exclude=()):
|
|
451
505
|
|
452
506
|
|
453
507
|
def get_latest_opset():
|
454
|
-
"""
|
508
|
+
"""
|
509
|
+
Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.
|
510
|
+
|
511
|
+
Returns:
|
512
|
+
(int): The ONNX opset version.
|
513
|
+
"""
|
455
514
|
if TORCH_1_13:
|
456
515
|
# If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
|
457
516
|
return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
|
@@ -461,27 +520,69 @@ def get_latest_opset():
|
|
461
520
|
|
462
521
|
|
463
522
|
def intersect_dicts(da, db, exclude=()):
|
464
|
-
"""
|
523
|
+
"""
|
524
|
+
Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.
|
525
|
+
|
526
|
+
Args:
|
527
|
+
da (dict): First dictionary.
|
528
|
+
db (dict): Second dictionary.
|
529
|
+
exclude (tuple, optional): Keys to exclude. Defaults to ().
|
530
|
+
|
531
|
+
Returns:
|
532
|
+
(dict): Dictionary of intersecting keys with matching shapes.
|
533
|
+
"""
|
465
534
|
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
466
535
|
|
467
536
|
|
468
537
|
def is_parallel(model):
|
469
|
-
"""
|
538
|
+
"""
|
539
|
+
Returns True if model is of type DP or DDP.
|
540
|
+
|
541
|
+
Args:
|
542
|
+
model (nn.Module): Model to check.
|
543
|
+
|
544
|
+
Returns:
|
545
|
+
(bool): True if model is DataParallel or DistributedDataParallel.
|
546
|
+
"""
|
470
547
|
return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
|
471
548
|
|
472
549
|
|
473
550
|
def de_parallel(model):
|
474
|
-
"""
|
551
|
+
"""
|
552
|
+
De-parallelize a model: returns single-GPU model if model is of type DP or DDP.
|
553
|
+
|
554
|
+
Args:
|
555
|
+
model (nn.Module): Model to de-parallelize.
|
556
|
+
|
557
|
+
Returns:
|
558
|
+
(nn.Module): De-parallelized model.
|
559
|
+
"""
|
475
560
|
return model.module if is_parallel(model) else model
|
476
561
|
|
477
562
|
|
478
563
|
def one_cycle(y1=0.0, y2=1.0, steps=100):
|
479
|
-
"""
|
564
|
+
"""
|
565
|
+
Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.
|
566
|
+
|
567
|
+
Args:
|
568
|
+
y1 (float, optional): Initial value. Defaults to 0.0.
|
569
|
+
y2 (float, optional): Final value. Defaults to 1.0.
|
570
|
+
steps (int, optional): Number of steps. Defaults to 100.
|
571
|
+
|
572
|
+
Returns:
|
573
|
+
(function): Lambda function for computing the sinusoidal ramp.
|
574
|
+
"""
|
480
575
|
return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
|
481
576
|
|
482
577
|
|
483
578
|
def init_seeds(seed=0, deterministic=False):
|
484
|
-
"""
|
579
|
+
"""
|
580
|
+
Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.
|
581
|
+
|
582
|
+
Args:
|
583
|
+
seed (int, optional): Random seed. Defaults to 0.
|
584
|
+
deterministic (bool, optional): Whether to set deterministic algorithms. Defaults to False.
|
585
|
+
"""
|
485
586
|
random.seed(seed)
|
486
587
|
np.random.seed(seed)
|
487
588
|
torch.manual_seed(seed)
|
@@ -510,16 +611,30 @@ def unset_deterministic():
|
|
510
611
|
|
511
612
|
class ModelEMA:
|
512
613
|
"""
|
513
|
-
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models.
|
514
|
-
average of everything in the model state_dict (parameters and buffers).
|
614
|
+
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models.
|
515
615
|
|
616
|
+
Keeps a moving average of everything in the model state_dict (parameters and buffers).
|
516
617
|
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
517
618
|
|
518
619
|
To disable EMA set the `enabled` attribute to `False`.
|
620
|
+
|
621
|
+
Attributes:
|
622
|
+
ema (nn.Module): Copy of the model in evaluation mode.
|
623
|
+
updates (int): Number of EMA updates.
|
624
|
+
decay (function): Decay function that determines the EMA weight.
|
625
|
+
enabled (bool): Whether EMA is enabled.
|
519
626
|
"""
|
520
627
|
|
521
628
|
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
522
|
-
"""
|
629
|
+
"""
|
630
|
+
Initialize EMA for 'model' with given arguments.
|
631
|
+
|
632
|
+
Args:
|
633
|
+
model (nn.Module): Model to create EMA for.
|
634
|
+
decay (float, optional): Maximum EMA decay rate. Defaults to 0.9999.
|
635
|
+
tau (int, optional): EMA decay time constant. Defaults to 2000.
|
636
|
+
updates (int, optional): Initial number of updates. Defaults to 0.
|
637
|
+
"""
|
523
638
|
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
524
639
|
self.updates = updates # number of EMA updates
|
525
640
|
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
@@ -528,7 +643,12 @@ class ModelEMA:
|
|
528
643
|
self.enabled = True
|
529
644
|
|
530
645
|
def update(self, model):
|
531
|
-
"""
|
646
|
+
"""
|
647
|
+
Update EMA parameters.
|
648
|
+
|
649
|
+
Args:
|
650
|
+
model (nn.Module): Model to update EMA from.
|
651
|
+
"""
|
532
652
|
if self.enabled:
|
533
653
|
self.updates += 1
|
534
654
|
d = self.decay(self.updates)
|
@@ -541,7 +661,14 @@ class ModelEMA:
|
|
541
661
|
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
|
542
662
|
|
543
663
|
def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
|
544
|
-
"""
|
664
|
+
"""
|
665
|
+
Updates attributes and saves stripped model with optimizer removed.
|
666
|
+
|
667
|
+
Args:
|
668
|
+
model (nn.Module): Model to update attributes from.
|
669
|
+
include (tuple, optional): Attributes to include. Defaults to ().
|
670
|
+
exclude (tuple, optional): Attributes to exclude. Defaults to ("process_group", "reducer").
|
671
|
+
"""
|
545
672
|
if self.enabled:
|
546
673
|
copy_attr(self.ema, model, include, exclude)
|
547
674
|
|
@@ -551,9 +678,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict
|
|
551
678
|
Strip optimizer from 'f' to finalize training, optionally save as 's'.
|
552
679
|
|
553
680
|
Args:
|
554
|
-
f (str):
|
555
|
-
s (str):
|
556
|
-
updates (dict):
|
681
|
+
f (str | Path): File path to model to strip the optimizer from. Defaults to 'best.pt'.
|
682
|
+
s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
|
683
|
+
updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.
|
557
684
|
|
558
685
|
Returns:
|
559
686
|
(dict): The combined checkpoint dictionary.
|
@@ -563,9 +690,6 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict
|
|
563
690
|
>>> from ultralytics.utils.torch_utils import strip_optimizer
|
564
691
|
>>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
|
565
692
|
>>> strip_optimizer(f)
|
566
|
-
|
567
|
-
Note:
|
568
|
-
Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
|
569
693
|
"""
|
570
694
|
try:
|
571
695
|
x = torch.load(f, map_location=torch.device("cpu"))
|
@@ -613,7 +737,11 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
|
|
613
737
|
"""
|
614
738
|
Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
|
615
739
|
|
616
|
-
|
740
|
+
Args:
|
741
|
+
state_dict (dict): Optimizer state dictionary.
|
742
|
+
|
743
|
+
Returns:
|
744
|
+
(dict): Converted optimizer state dictionary with FP16 tensors.
|
617
745
|
"""
|
618
746
|
for state in state_dict["state"].values():
|
619
747
|
for k, v in state.items():
|
@@ -653,6 +781,16 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
|
|
653
781
|
"""
|
654
782
|
Ultralytics speed, memory and FLOPs profiler.
|
655
783
|
|
784
|
+
Args:
|
785
|
+
input (torch.Tensor | List[torch.Tensor]): Input tensor(s) to profile.
|
786
|
+
ops (nn.Module | List[nn.Module]): Model or list of operations to profile.
|
787
|
+
n (int, optional): Number of iterations to average. Defaults to 10.
|
788
|
+
device (str | torch.device, optional): Device to profile on. Defaults to None.
|
789
|
+
max_num_obj (int, optional): Maximum number of objects for simulation. Defaults to 0.
|
790
|
+
|
791
|
+
Returns:
|
792
|
+
(List): Profile results for each operation.
|
793
|
+
|
656
794
|
Examples:
|
657
795
|
>>> from ultralytics.utils.torch_utils import profile
|
658
796
|
>>> input = torch.randn(16, 3, 640, 640)
|
@@ -721,7 +859,15 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
|
|
721
859
|
|
722
860
|
|
723
861
|
class EarlyStopping:
|
724
|
-
"""
|
862
|
+
"""
|
863
|
+
Early stopping class that stops training when a specified number of epochs have passed without improvement.
|
864
|
+
|
865
|
+
Attributes:
|
866
|
+
best_fitness (float): Best fitness value observed.
|
867
|
+
best_epoch (int): Epoch where best fitness was observed.
|
868
|
+
patience (int): Number of epochs to wait after fitness stops improving before stopping.
|
869
|
+
possible_stop (bool): Flag indicating if stopping may occur next epoch.
|
870
|
+
"""
|
725
871
|
|
726
872
|
def __init__(self, patience=50):
|
727
873
|
"""
|
@@ -770,11 +916,12 @@ class FXModel(nn.Module):
|
|
770
916
|
"""
|
771
917
|
A custom model class for torch.fx compatibility.
|
772
918
|
|
773
|
-
This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
|
774
|
-
It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
|
919
|
+
This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
|
920
|
+
manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
|
921
|
+
copying.
|
775
922
|
|
776
|
-
|
777
|
-
model (
|
923
|
+
Attributes:
|
924
|
+
model (nn.Module): The original model's layers.
|
778
925
|
"""
|
779
926
|
|
780
927
|
def __init__(self, model):
|
@@ -782,7 +929,7 @@ class FXModel(nn.Module):
|
|
782
929
|
Initialize the FXModel.
|
783
930
|
|
784
931
|
Args:
|
785
|
-
model (
|
932
|
+
model (nn.Module): The original model to wrap for torch.fx compatibility.
|
786
933
|
"""
|
787
934
|
super().__init__()
|
788
935
|
copy_attr(self, model)
|
@@ -793,7 +940,8 @@ class FXModel(nn.Module):
|
|
793
940
|
"""
|
794
941
|
Forward pass through the model.
|
795
942
|
|
796
|
-
This method performs the forward pass through the model, handling the dependencies between layers and saving
|
943
|
+
This method performs the forward pass through the model, handling the dependencies between layers and saving
|
944
|
+
intermediate outputs.
|
797
945
|
|
798
946
|
Args:
|
799
947
|
x (torch.Tensor): The input tensor to the model.
|