ultralytics 8.3.100__py3-none-any.whl → 8.3.102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_solutions.py +140 -76
- ultralytics/__init__.py +1 -1
- ultralytics/engine/exporter.py +20 -5
- ultralytics/engine/model.py +1 -1
- ultralytics/engine/predictor.py +3 -1
- ultralytics/hub/__init__.py +29 -2
- ultralytics/hub/google/__init__.py +18 -1
- ultralytics/models/fastsam/predict.py +12 -1
- ultralytics/models/nas/predict.py +21 -3
- ultralytics/models/rtdetr/val.py +26 -2
- ultralytics/models/sam/amg.py +22 -1
- ultralytics/models/sam/modules/encoders.py +85 -4
- ultralytics/models/sam/modules/memory_attention.py +61 -3
- ultralytics/models/sam/modules/utils.py +108 -5
- ultralytics/models/utils/loss.py +38 -2
- ultralytics/models/utils/ops.py +15 -1
- ultralytics/models/yolo/classify/predict.py +11 -1
- ultralytics/models/yolo/classify/train.py +17 -1
- ultralytics/models/yolo/classify/val.py +82 -6
- ultralytics/models/yolo/detect/predict.py +20 -1
- ultralytics/models/yolo/model.py +69 -4
- ultralytics/models/yolo/obb/predict.py +16 -1
- ultralytics/models/yolo/obb/train.py +35 -2
- ultralytics/models/yolo/obb/val.py +87 -6
- ultralytics/models/yolo/pose/predict.py +18 -1
- ultralytics/models/yolo/pose/train.py +48 -3
- ultralytics/models/yolo/pose/val.py +113 -8
- ultralytics/models/yolo/segment/predict.py +27 -2
- ultralytics/models/yolo/segment/train.py +61 -3
- ultralytics/models/yolo/segment/val.py +10 -1
- ultralytics/models/yolo/world/train_world.py +29 -1
- ultralytics/models/yolo/yoloe/train.py +47 -3
- ultralytics/nn/modules/activation.py +26 -3
- ultralytics/nn/modules/block.py +89 -0
- ultralytics/nn/modules/head.py +3 -92
- ultralytics/nn/modules/utils.py +70 -4
- ultralytics/nn/tasks.py +2 -2
- ultralytics/nn/text_model.py +93 -17
- ultralytics/utils/benchmarks.py +1 -1
- ultralytics/utils/callbacks/base.py +22 -5
- ultralytics/utils/callbacks/comet.py +93 -5
- ultralytics/utils/callbacks/dvc.py +64 -5
- ultralytics/utils/callbacks/neptune.py +25 -2
- ultralytics/utils/callbacks/tensorboard.py +30 -2
- ultralytics/utils/callbacks/wb.py +16 -1
- ultralytics/utils/dist.py +35 -2
- ultralytics/utils/errors.py +27 -6
- ultralytics/utils/patches.py +33 -5
- ultralytics/utils/triton.py +16 -3
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/METADATA +1 -2
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/RECORD +55 -55
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/memory_attention.py
CHANGED
@@ -60,7 +60,17 @@ class MemoryAttentionLayer(nn.Module):
         pos_enc_at_cross_attn_keys: bool = True,
         pos_enc_at_cross_attn_queries: bool = False,
     ):
-        """
+        """
+        Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
+
+        Args:
+            d_model (int): Dimensionality of the model.
+            dim_feedforward (int): Dimensionality of the feedforward network.
+            dropout (float): Dropout rate for regularization.
+            pos_enc_at_attn (bool): Whether to add positional encoding at attention.
+            pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
+            pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
+        """
         super().__init__()
         self.d_model = d_model
         self.dim_feedforward = dim_feedforward
@@ -183,7 +193,31 @@ class MemoryAttention(nn.Module):
         num_layers: int,
         batch_first: bool = True,  # Do layers expect batch first input?
     ):
-        """
+        """
+        Initialize MemoryAttention with specified layers and normalization for sequential data processing.
+
+        This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
+        for processing sequential data, particularly useful in transformer-like architectures.
+
+        Args:
+            d_model (int): The dimension of the model's hidden state.
+            pos_enc_at_input (bool): Whether to apply positional encoding at the input.
+            layer (nn.Module): The attention layer to be used in the module.
+            num_layers (int): The number of attention layers.
+            batch_first (bool): Whether the input tensors are in batch-first format.
+
+        Examples:
+            >>> d_model = 256
+            >>> layer = MemoryAttentionLayer(d_model)
+            >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
+            >>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
+            >>> memory = torch.randn(20, 32, d_model)  # (mem_len, batch_size, d_model)
+            >>> curr_pos = torch.randn(10, 32, d_model)
+            >>> memory_pos = torch.randn(20, 32, d_model)
+            >>> output = attention(curr, memory, curr_pos, memory_pos)
+            >>> print(output.shape)
+            torch.Size([10, 32, 256])
+        """
         super().__init__()
         self.d_model = d_model
         self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
@@ -200,7 +234,31 @@ class MemoryAttention(nn.Module):
         memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
         num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
     ) -> torch.Tensor:
-        """
+        """
+        Process inputs through attention layers, applying self and cross-attention with positional encoding.
+
+        Args:
+            curr (torch.Tensor): Self-attention input tensor, representing the current state.
+            memory (torch.Tensor): Cross-attention input tensor, representing memory information.
+            curr_pos (Optional[Tensor]): Positional encoding for self-attention inputs.
+            memory_pos (Optional[Tensor]): Positional encoding for cross-attention inputs.
+            num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.
+
+        Returns:
+            (torch.Tensor): Processed output tensor after applying attention layers and normalization.
+
+        Examples:
+            >>> d_model = 256
+            >>> layer = MemoryAttentionLayer(d_model)
+            >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
+            >>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
+            >>> memory = torch.randn(20, 32, d_model)  # (mem_len, batch_size, d_model)
+            >>> curr_pos = torch.randn(10, 32, d_model)
+            >>> memory_pos = torch.randn(20, 32, d_model)
+            >>> output = attention(curr, memory, curr_pos, memory_pos)
+            >>> print(output.shape)
+            torch.Size([10, 32, 256])
+        """
         if isinstance(curr, list):
             assert isinstance(curr_pos, list)
             assert len(curr) == len(curr_pos) == 1
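As an aside to the docstrings above, here is a minimal, illustrative sketch of the same self-attention plus cross-attention-into-memory pattern they describe. It is not the Ultralytics/SAM 2 implementation; the `TinyMemoryLayer`/`TinyMemoryAttention` names, head counts, and norm placement are assumptions, but the deep-copied layer stack mirrors the `nn.ModuleList([copy.deepcopy(layer) ...])` line visible in the diff.

```python
# Illustrative sketch only (not the Ultralytics module): a minimal memory-attention stack where
# each layer runs self-attention on the current tokens and cross-attention into a memory sequence.
import copy
import torch
import torch.nn as nn


class TinyMemoryLayer(nn.Module):
    def __init__(self, d_model=256, nhead=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Linear(dim_feedforward, d_model))
        self.norm1, self.norm2, self.norm3 = nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, curr, memory):
        q = self.norm1(curr)
        curr = curr + self.self_attn(q, q, q)[0]  # self-attention on current tokens
        q = self.norm2(curr)
        curr = curr + self.cross_attn(q, memory, memory)[0]  # cross-attention into memory
        return curr + self.ffn(self.norm3(curr))


class TinyMemoryAttention(nn.Module):
    def __init__(self, layer, num_layers=3, d_model=256):
        super().__init__()
        # Same stacking pattern as shown in the diff: deep-copy one layer num_layers times.
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, curr, memory):
        for layer in self.layers:
            curr = layer(curr, memory)
        return self.norm(curr)


curr = torch.randn(10, 32, 256)    # (seq_len, batch, d_model)
memory = torch.randn(20, 32, 256)  # (mem_len, batch, d_model)
out = TinyMemoryAttention(TinyMemoryLayer(), num_layers=3)(curr, memory)
print(out.shape)  # torch.Size([10, 32, 256])
```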
ultralytics/models/sam/modules/utils.py
CHANGED
@@ -61,7 +61,23 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num


 def get_1d_sine_pe(pos_inds, dim, temperature=10000):
-    """
+    """
+    Generate 1D sinusoidal positional embeddings for given positions and dimensions.
+
+    Args:
+        pos_inds (torch.Tensor): Position indices for which to generate embeddings.
+        dim (int): Dimension of the positional embeddings. Should be an even number.
+        temperature (float): Scaling factor for the frequency of the sinusoidal functions.
+
+    Returns:
+        (torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).
+
+    Examples:
+        >>> pos = torch.tensor([0, 1, 2, 3])
+        >>> embeddings = get_1d_sine_pe(pos, 128)
+        >>> embeddings.shape
+        torch.Size([4, 128])
+    """
     pe_dim = dim // 2
     dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
     dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
@@ -72,7 +88,30 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):


 def init_t_xy(end_x: int, end_y: int):
-    """
+    """
+    Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
+
+    This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor
+    and corresponding x and y coordinate tensors.
+
+    Args:
+        end_x (int): Width of the grid (number of columns).
+        end_y (int): Height of the grid (number of rows).
+
+    Returns:
+        t (torch.Tensor): Linear indices for each position in the grid, with shape (end_x * end_y).
+        t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
+        t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).
+
+    Examples:
+        >>> t, t_x, t_y = init_t_xy(3, 2)
+        >>> print(t)
+        tensor([0., 1., 2., 3., 4., 5.])
+        >>> print(t_x)
+        tensor([0., 1., 2., 0., 1., 2.])
+        >>> print(t_y)
+        tensor([0., 0., 0., 1., 1., 1.])
+    """
     t = torch.arange(end_x * end_y, dtype=torch.float32)
     t_x = (t % end_x).float()
     t_y = torch.div(t, end_x, rounding_mode="floor").float()
@@ -80,7 +119,32 @@ def init_t_xy(end_x: int, end_y: int):


 def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
-    """
+    """
+    Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
+
+    This function generates complex exponential positional encodings for a 2D grid of spatial positions,
+    using separate frequency components for the x and y dimensions.
+
+    Args:
+        dim (int): Dimension of the positional encoding.
+        end_x (int): Width of the 2D grid.
+        end_y (int): Height of the 2D grid.
+        theta (float, optional): Scaling factor for frequency computation.
+
+    Returns:
+        freqs_cis_x (torch.Tensor): Complex exponential positional encodings for x-dimension with shape
+            (end_x*end_y, dim//4).
+        freqs_cis_y (torch.Tensor): Complex exponential positional encodings for y-dimension with shape
+            (end_x*end_y, dim//4).
+
+    Examples:
+        >>> dim, end_x, end_y = 128, 8, 8
+        >>> freqs_cis_x, freqs_cis_y = compute_axial_cis(dim, end_x, end_y)
+        >>> freqs_cis_x.shape
+        torch.Size([64, 32])
+        >>> freqs_cis_y.shape
+        torch.Size([64, 32])
+    """
     freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
     freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

@@ -93,7 +157,22 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):


 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    """
+    """
+    Reshape frequency tensor for broadcasting with input tensor.
+
+    Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
+    This function is typically used in positional encoding operations.
+
+    Args:
+        freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
+        x (torch.Tensor): Input tensor to broadcast with.
+
+    Returns:
+        (torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.
+
+    Raises:
+        AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.
+    """
     ndim = x.ndim
     assert 0 <= 1 < ndim
     assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
@@ -107,7 +186,31 @@ def apply_rotary_enc(
     freqs_cis: torch.Tensor,
     repeat_freqs_k: bool = False,
 ):
-    """
+    """
+    Apply rotary positional encoding to query and key tensors.
+
+    This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
+    components. RoPE is a technique that injects relative position information into self-attention mechanisms.
+
+    Args:
+        xq (torch.Tensor): Query tensor to encode with positional information.
+        xk (torch.Tensor): Key tensor to encode with positional information.
+        freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
+            last two dimensions of xq.
+        repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
+            to match key sequence length.
+
+    Returns:
+        xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
+        xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.
+
+    Examples:
+        >>> import torch
+        >>> xq = torch.randn(2, 8, 16, 64)  # [batch, heads, seq_len, dim]
+        >>> xk = torch.randn(2, 8, 16, 64)
+        >>> freqs_cis = compute_axial_cis(64, 4, 4)  # For a 4x4 spatial grid with dim=64
+        >>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)
+    """
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
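For readers new to RoPE, here is a minimal sketch of the complex-multiplication mechanics the docstrings above describe. It uses a simplified 1D frequency construction rather than the axial 2D variant of compute_axial_cis, and the helper names (`simple_freqs_cis`, `simple_rotary`) are illustrative, not part of the package.

```python
# Illustrative sketch only: rotary position embedding (RoPE) via complex multiplication,
# mirroring the torch.view_as_complex pattern visible in the diff above.
import torch


def simple_freqs_cis(seq_len: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # One complex rotation per (position, channel-pair): shape (seq_len, dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)  # unit-magnitude complex exponentials


def simple_rotary(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # View the last dimension as complex pairs, rotate, then flatten back to real values.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    rot = freqs_cis[None, None]  # broadcast over (batch, heads)
    xq_out = torch.view_as_real(xq_ * rot).flatten(-2)
    xk_out = torch.view_as_real(xk_ * rot).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


xq = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
xk = torch.randn(2, 8, 16, 64)
q_rot, k_rot = simple_rotary(xq, xk, simple_freqs_cis(16, 64))
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 8, 16, 64]) torch.Size([2, 8, 16, 64])
```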
ultralytics/models/utils/loss.py
CHANGED
@@ -65,7 +65,25 @@ class DETRLoss(nn.Module):
         self.device = None

     def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
-        """
+        """
+        Compute classification loss based on predictions, target values, and ground truth scores.
+
+        Args:
+            pred_scores (torch.Tensor): Predicted class scores with shape (batch_size, num_queries, num_classes).
+            targets (torch.Tensor): Target class indices with shape (batch_size, num_queries).
+            gt_scores (torch.Tensor): Ground truth confidence scores with shape (batch_size, num_queries).
+            num_gts (int): Number of ground truth objects.
+            postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.
+
+        Returns:
+            loss_cls (torch.Tensor): Classification loss value.
+
+        Notes:
+            The function supports different classification loss types:
+            - Varifocal Loss (if self.vfl is True and num_gts > 0)
+            - Focal Loss (if self.fl is True)
+            - BCE Loss (default fallback)
+        """
         # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
         name_class = f"loss_class{postfix}"
         bs, nq = pred_scores.shape[:2]
@@ -87,7 +105,25 @@ class DETRLoss(nn.Module):
         return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

     def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
-        """
+        """
+        Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
+
+        Args:
+            pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
+            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4), where N is the total
+                number of ground truth boxes.
+            postfix (str): String to append to the loss names for identification in multi-loss scenarios.
+
+        Returns:
+            loss (dict): Dictionary containing:
+                - loss_bbox{postfix} (torch.Tensor): L1 loss between predicted and ground truth boxes,
+                    scaled by the bbox loss gain.
+                - loss_giou{postfix} (torch.Tensor): GIoU loss between predicted and ground truth boxes,
+                    scaled by the giou loss gain.
+
+        Notes:
+            If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
+        """
         # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
         name_bbox = f"loss_bbox{postfix}"
         name_giou = f"loss_giou{postfix}"
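The _get_loss_bbox docstring above describes the standard DETR-style pairing of an L1 box loss with a GIoU loss. The sketch below illustrates that pairing with torchvision's generalized_box_iou; the xyxy box format, the normalization by the number of ground truths, and the gain values are assumptions for illustration, not the Ultralytics implementation.

```python
# Illustrative sketch only: L1 + GIoU bounding-box losses for matched prediction/target pairs.
import torch
import torch.nn.functional as F
from torchvision.ops import generalized_box_iou


def bbox_losses(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, bbox_gain=5.0, giou_gain=2.0):
    """Return scaled L1 and GIoU losses for matched predicted/ground-truth boxes in xyxy format."""
    if len(gt_boxes) == 0:  # no targets: zero-valued losses, as the docstring notes
        zero = pred_boxes.sum() * 0.0
        return {"loss_bbox": zero, "loss_giou": zero}
    num_gts = len(gt_boxes)
    loss_bbox = F.l1_loss(pred_boxes, gt_boxes, reduction="sum") / num_gts
    giou = torch.diag(generalized_box_iou(pred_boxes, gt_boxes))  # keep only matched pairs
    loss_giou = (1.0 - giou).sum() / num_gts
    return {"loss_bbox": bbox_gain * loss_bbox, "loss_giou": giou_gain * loss_giou}


pred = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]])
gt = torch.tensor([[1.0, 1.0, 9.0, 9.0], [4.0, 6.0, 21.0, 19.0]])
print(bbox_losses(pred, gt))
```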
ultralytics/models/utils/ops.py
CHANGED
@@ -31,7 +31,21 @@ class HungarianMatcher(nn.Module):
     """

     def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
-        """
+        """
+        Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes.
+
+        The HungarianMatcher uses a cost function that considers classification scores, bounding box coordinates,
+        and optionally mask predictions to perform optimal bipartite matching between predictions and ground truths.
+
+        Args:
+            cost_gain (dict, optional): Dictionary of cost coefficients for different components of the matching cost.
+                Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
+            use_fl (bool, optional): Whether to use Focal Loss for the classification cost calculation.
+            with_mask (bool, optional): Whether the model makes mask predictions.
+            num_sample_points (int, optional): Number of sample points used in mask cost calculation.
+            alpha (float, optional): Alpha factor in Focal Loss calculation.
+            gamma (float, optional): Gamma factor in Focal Loss calculation.
+        """
         super().__init__()
         if cost_gain is None:
             cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
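A minimal sketch of the bipartite matching idea behind HungarianMatcher follows: build a cost matrix from class, L1-box, and GIoU terms weighted by the gains named in the docstring, then solve it with scipy's linear_sum_assignment. The `match` helper, box format, and the simple softmax-based class cost are illustrative assumptions; the real module also supports focal-loss class costs and mask/dice costs.

```python
# Illustrative sketch only: Hungarian (bipartite) matching over a combined cost matrix.
import torch
from scipy.optimize import linear_sum_assignment
from torchvision.ops import generalized_box_iou


def match(pred_scores, pred_boxes, gt_cls, gt_boxes, gains=None):
    """Return (pred_idx, gt_idx) index pairs minimizing the combined matching cost (boxes in xyxy)."""
    gains = gains or {"class": 1.0, "bbox": 5.0, "giou": 2.0}
    cost_class = -pred_scores.softmax(-1)[:, gt_cls]       # (num_queries, num_gt)
    cost_bbox = torch.cdist(pred_boxes, gt_boxes, p=1)      # L1 distance between boxes
    cost_giou = -generalized_box_iou(pred_boxes, gt_boxes)  # higher overlap -> lower cost
    cost = gains["class"] * cost_class + gains["bbox"] * cost_bbox + gains["giou"] * cost_giou
    pred_idx, gt_idx = linear_sum_assignment(cost.detach().cpu().numpy())
    return torch.as_tensor(pred_idx), torch.as_tensor(gt_idx)


scores = torch.randn(5, 3)  # 5 queries, 3 classes
boxes = torch.tensor(
    [[0, 0, 10, 10], [2, 2, 8, 8], [20, 20, 30, 30], [1, 1, 12, 12], [25, 18, 35, 28]], dtype=torch.float
)
gt_cls = torch.tensor([0, 2])
gt_boxes = torch.tensor([[1, 1, 9, 9], [22, 21, 31, 29]], dtype=torch.float)
print(match(scores, boxes, gt_cls, gt_boxes))
```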
ultralytics/models/yolo/classify/predict.py
CHANGED
@@ -36,7 +36,17 @@ class ClassificationPredictor(BasePredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
+
+        This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
+        tasks. It ensures the task is set to 'classify' regardless of input configuration.
+
+        Args:
+            cfg (dict): Default configuration dictionary containing prediction settings. Defaults to DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides that take precedence over cfg.
+            _callbacks (list, optional): List of callback functions to be executed during prediction.
+        """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "classify"
         self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
ultralytics/models/yolo/classify/train.py
CHANGED
@@ -48,7 +48,23 @@ class ClassificationTrainer(BaseTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize a ClassificationTrainer object.
+
+        This constructor sets up a trainer for image classification tasks, configuring the task type and default
+        image size if not specified.
+
+        Args:
+            cfg (dict, optional): Default configuration dictionary containing training parameters.
+            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (list, optional): List of callback functions to be executed during training.
+
+        Examples:
+            >>> from ultralytics.models.yolo.classify import ClassificationTrainer
+            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
+            >>> trainer = ClassificationTrainer(overrides=args)
+            >>> trainer.train()
+        """
         if overrides is None:
             overrides = {}
         overrides["task"] = "classify"
ultralytics/models/yolo/classify/val.py
CHANGED
@@ -49,7 +49,25 @@ class ClassificationValidator(BaseValidator):
     """

     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
+        """
+        Initialize ClassificationValidator with dataloader, save directory, and other parameters.
+
+        This validator handles the validation process for classification models, including metrics calculation,
+        confusion matrix generation, and visualization of results.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            save_dir (str | Path, optional): Directory to save results.
+            pbar (bool, optional): Display a progress bar.
+            args (dict, optional): Arguments containing model and validation configuration.
+            _callbacks (list, optional): List of callback functions to be called during validation.
+
+        Examples:
+            >>> from ultralytics.models.yolo.classify import ClassificationValidator
+            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
+            >>> validator = ClassificationValidator(args=args)
+            >>> validator()
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.targets = None
         self.pred = None
@@ -76,13 +94,38 @@ class ClassificationValidator(BaseValidator):
         return batch

     def update_metrics(self, preds, batch):
-        """
+        """
+        Update running metrics with model predictions and batch targets.
+
+        Args:
+            preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
+            batch (dict): Batch data containing images and class labels.
+
+        This method appends the top-N predictions (sorted by confidence in descending order) to the
+        prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
+        """
         n5 = min(len(self.names), 5)
         self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
         self.targets.append(batch["cls"].type(torch.int32).cpu())

     def finalize_metrics(self, *args, **kwargs):
-        """
+        """
+        Finalize metrics including confusion matrix and processing speed.
+
+        This method processes the accumulated predictions and targets to generate the confusion matrix,
+        optionally plots it, and updates the metrics object with speed information.
+
+        Args:
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+
+        Examples:
+            >>> validator = ClassificationValidator()
+            >>> validator.pred = [torch.tensor([[0, 1, 2]])]  # Top-3 predictions for one sample
+            >>> validator.targets = [torch.tensor([0])]  # Ground truth class
+            >>> validator.finalize_metrics()
+            >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
+        """
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
             for normalize in True, False:
@@ -107,7 +150,16 @@ class ClassificationValidator(BaseValidator):
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)

     def get_dataloader(self, dataset_path, batch_size):
-        """
+        """
+        Build and return a data loader for classification validation.
+
+        Args:
+            dataset_path (str | Path): Path to the dataset directory.
+            batch_size (int): Number of samples per batch.
+
+        Returns:
+            (torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
+        """
         dataset = self.build_dataset(dataset_path)
         return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)

@@ -117,7 +169,18 @@ class ClassificationValidator(BaseValidator):
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))

     def plot_val_samples(self, batch, ni):
-        """
+        """
+        Plot validation image samples with their ground truth labels.
+
+        Args:
+            batch (dict): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
+            ni (int): Batch index used for naming the output file.
+
+        Examples:
+            >>> validator = ClassificationValidator()
+            >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
+            >>> validator.plot_val_samples(batch, 0)
+        """
         plot_images(
             images=batch["img"],
             batch_idx=torch.arange(len(batch["img"])),
@@ -128,7 +191,20 @@ class ClassificationValidator(BaseValidator):
         )

     def plot_predictions(self, batch, preds, ni):
-        """
+        """
+        Plot images with their predicted class labels and save the visualization.
+
+        Args:
+            batch (dict): Batch data containing images and other information.
+            preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
+            ni (int): Batch index used for naming the output file.
+
+        Examples:
+            >>> validator = ClassificationValidator()
+            >>> batch = {"img": torch.rand(16, 3, 224, 224)}
+            >>> preds = torch.rand(16, 10)  # 16 images, 10 classes
+            >>> validator.plot_predictions(batch, preds, 0)
+        """
         plot_images(
             batch["img"],
             batch_idx=torch.arange(len(batch["img"])),
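The update_metrics docstring above explains that the validator stores the top-N class indices per image (argsort of the scores, descending). The sketch below shows how top-1/top-5 accuracy can be computed from exactly that representation; the `topk_accuracy` helper is illustrative and not part of the package, which delegates this to its metrics and confusion-matrix objects.

```python
# Illustrative sketch only: top-1 / top-5 accuracy from accumulated top-N prediction indices.
import torch


def topk_accuracy(pred_topn: torch.Tensor, targets: torch.Tensor):
    """pred_topn: (N, k) class indices sorted by confidence; targets: (N,) ground-truth class indices."""
    correct = pred_topn == targets[:, None]  # (N, k) boolean matrix
    top1 = correct[:, :1].any(dim=1).float().mean()
    top5 = correct[:, :5].any(dim=1).float().mean()
    return top1.item(), top5.item()


scores = torch.randn(8, 10)                             # 8 images, 10 classes
pred_topn = scores.argsort(1, descending=True)[:, :5]   # same top-5 selection as update_metrics()
targets = torch.randint(0, 10, (8,))
print(topk_accuracy(pred_topn, targets))
```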
ultralytics/models/yolo/detect/predict.py
CHANGED
@@ -31,7 +31,26 @@ class DetectionPredictor(BasePredictor):
     """

     def postprocess(self, preds, img, orig_imgs, **kwargs):
-        """
+        """
+        Post-process predictions and return a list of Results objects.
+
+        This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
+        further analysis.
+
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+            img (torch.Tensor): Processed input image tensor in model input format.
+            orig_imgs (torch.Tensor | list): Original input images before preprocessing.
+            **kwargs (Any): Additional keyword arguments.
+
+        Returns:
+            (list): List of Results objects containing the post-processed predictions.
+
+        Examples:
+            >>> predictor = DetectionPredictor(overrides=dict(model="yolov8n.pt"))
+            >>> results = predictor.predict("path/to/image.jpg")
+            >>> processed_results = predictor.postprocess(preds, img, orig_imgs)
+        """
         preds = ops.non_max_suppression(
             preds,
             self.args.conf,
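For context on the non-maximum suppression step the postprocess docstring refers to, here is a minimal class-aware NMS sketch using torchvision. It is not ops.non_max_suppression, which supports additional options such as max_det, agnostic NMS, and multi-label outputs; the thresholds and the `simple_nms` name are illustrative.

```python
# Illustrative sketch only: confidence filtering followed by per-class NMS on xyxy boxes.
import torch
from torchvision.ops import batched_nms


def simple_nms(boxes, scores, classes, conf_thres=0.25, iou_thres=0.45):
    """Filter boxes by confidence, then suppress per-class overlaps above iou_thres."""
    keep = scores > conf_thres
    boxes, scores, classes = boxes[keep], scores[keep], classes[keep]
    idx = batched_nms(boxes, scores, classes, iou_thres)  # class-aware NMS
    return boxes[idx], scores[idx], classes[idx]


boxes = torch.tensor([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], dtype=torch.float)
scores = torch.tensor([0.9, 0.8, 0.7])
classes = torch.tensor([0, 0, 1])
print(simple_nms(boxes, scores, classes))
```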
ultralytics/models/yolo/model.py
CHANGED
@@ -2,6 +2,7 @@

 from pathlib import Path

+from ultralytics.data.build import load_inference_source
 from ultralytics.engine.model import Model
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import (
@@ -21,7 +22,24 @@ class YOLO(Model):
     """YOLO (You Only Look Once) object detection model."""

     def __init__(self, model="yolo11n.pt", task=None, verbose=False):
-        """
+        """
+        Initialize a YOLO model.
+
+        This constructor initializes a YOLO model, automatically switching to specialized model types
+        (YOLOWorld or YOLOE) based on the model filename.
+
+        Args:
+            model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolov8n.yaml'.
+            task (str | None): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
+                Defaults to auto-detection based on model.
+            verbose (bool): Display model info on load.
+
+        Examples:
+            >>> from ultralytics import YOLO
+            >>> model = YOLO("yolov8n.pt")  # load a pretrained YOLOv8n detection model
+            >>> model = YOLO("yolov8n-seg.pt")  # load a pretrained YOLOv8n segmentation model
+            >>> model = YOLO("yolo11n.pt")  # load a pretrained YOLOv11n detection model
+        """
         path = Path(model)
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(path, verbose=verbose)
@@ -165,12 +183,46 @@ class YOLOE(Model):
         return self.model.get_text_pe(texts)

     def get_visual_pe(self, img, visual):
-        """
+        """
+        Get visual positional embeddings for the given image and visual features.
+
+        This method extracts positional embeddings from visual features based on the input image. It requires
+        that the model is an instance of YOLOEModel.
+
+        Args:
+            img (torch.Tensor): Input image tensor.
+            visual (torch.Tensor): Visual features extracted from the image.
+
+        Returns:
+            (torch.Tensor): Visual positional embeddings.
+
+        Examples:
+            >>> model = YOLOE("yoloe-v8s.pt")
+            >>> img = torch.rand(1, 3, 640, 640)
+            >>> visual_features = model.model.backbone(img)
+            >>> pe = model.get_visual_pe(img, visual_features)
+        """
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_visual_pe(img, visual)

     def set_vocab(self, vocab, names):
-        """
+        """
+        Set vocabulary and class names for the YOLOE model.
+
+        This method configures the vocabulary and class names used by the model for text processing and
+        classification tasks. The model must be an instance of YOLOEModel.
+
+        Args:
+            vocab (list): Vocabulary list containing tokens or words used by the model for text processing.
+            names (list): List of class names that the model can detect or classify.
+
+        Raises:
+            AssertionError: If the model is not an instance of YOLOEModel.
+
+        Examples:
+            >>> model = YOLOE("yoloe-v8s.pt")
+            >>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
+        """
         assert isinstance(self.model, YOLOEModel)
         self.model.set_vocab(vocab, names=names)

@@ -267,7 +319,14 @@ class YOLOE(Model):
                 f"{len(visual_prompts['cls'])} respectively"
             )
         self.predictor = (predictor or self._smart_load("predictor"))(
-            overrides={
+            overrides={
+                "task": self.model.task,
+                "mode": "predict",
+                "save": False,
+                "verbose": refer_image is None,
+                "batch": 1,
+            },
+            _callbacks=self.callbacks,
         )

         if len(visual_prompts):
@@ -281,6 +340,12 @@ class YOLOE(Model):
             self.predictor.set_prompts(visual_prompts.copy())

         self.predictor.setup_model(model=self.model)
+
+        if refer_image is None and source:
+            dataset = load_inference_source(source)
+            if dataset.mode in {"video", "stream"}:
+                # NOTE: set the first frame as refer image for videos/streams inference
+                refer_image = next(iter(dataset))[1][0]
         if refer_image is not None and len(visual_prompts):
             vpe = self.predictor.get_vpe(refer_image)
             self.model.set_classes(self.model.names, vpe)
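The YOLO.__init__ docstring above describes dispatching to a specialized class based on the weight filename. The sketch below shows that kind of filename-based routing in isolation; the class names are placeholders and do not reflect the actual Ultralytics type hierarchy.

```python
# Illustrative sketch only: route construction to a specialized subclass based on the weight filename.
from pathlib import Path


class BaseDetector:
    def __init__(self, weights: str):
        self.weights = weights


class WorldDetector(BaseDetector):  # stand-in for an open-vocabulary variant
    pass


class PromptableDetector(BaseDetector):  # stand-in for a visual/text-prompt variant
    pass


def build_detector(weights: str) -> BaseDetector:
    stem, suffix = Path(weights).stem, Path(weights).suffix
    if "-world" in stem and suffix in {".pt", ".yaml", ".yml"}:
        return WorldDetector(weights)
    if "yoloe" in stem and suffix in {".pt", ".yaml", ".yml"}:
        return PromptableDetector(weights)
    return BaseDetector(weights)


print(type(build_detector("yolov8s-world.pt")).__name__)  # WorldDetector
print(type(build_detector("yolo11n.pt")).__name__)        # BaseDetector
```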
ultralytics/models/yolo/obb/predict.py
CHANGED
@@ -27,7 +27,22 @@ class OBBPredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize OBBPredictor with optional model and data configuration overrides.
+
+        This constructor sets up an OBBPredictor instance for oriented bounding box detection tasks.
+
+        Args:
+            cfg (dict, optional): Default configuration for the predictor.
+            overrides (dict, optional): Configuration overrides that take precedence over the default config.
+            _callbacks (list, optional): List of callback functions to be invoked during prediction.
+
+        Examples:
+            >>> from ultralytics.utils import ASSETS
+            >>> from ultralytics.models.yolo.obb import OBBPredictor
+            >>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
+            >>> predictor = OBBPredictor(overrides=args)
+        """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"
