ultralytics-8.0.238-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics has been flagged as potentially problematic; review the advisory details on the registry page before upgrading.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
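
Before walking through the hunks below, it can help to confirm which side of this diff an environment is actually on; a minimal check (assuming only that the `ultralytics` package is importable):

```python
# Confirm the installed version before comparing behaviour between 8.0.238 and 8.0.239.
import ultralytics

print(ultralytics.__version__)  # e.g. "8.0.239"
```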
ultralytics/models/sam/modules/tiny_encoder.py

@@ -28,11 +28,11 @@ class Conv2d_BN(torch.nn.Sequential):
         drop path.
         """
         super().__init__()
-        self.add_module(
+        self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
         bn = torch.nn.BatchNorm2d(b)
         torch.nn.init.constant_(bn.weight, bn_weight_init)
         torch.nn.init.constant_(bn.bias, 0)
-        self.add_module(
+        self.add_module("bn", bn)


 class PatchEmbed(nn.Module):
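
The Conv2d_BN hunk above is a formatting cleanup of a common pattern: a bias-free convolution registered as `"c"`, followed by a `BatchNorm2d` registered as `"bn"`, on one `nn.Sequential`. A stand-alone sketch of that pattern (the class name here is hypothetical; the structure mirrors the added lines):

```python
import torch


class ConvBN(torch.nn.Sequential):
    """Bias-free Conv2d followed by BatchNorm2d, registered as named submodules (illustrative)."""

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        # The convolution carries no bias; the BatchNorm affine terms play that role.
        self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
        bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module("bn", bn)


x = torch.randn(1, 3, 32, 32)
print(ConvBN(3, 16, ks=3, pad=1)(x).shape)  # torch.Size([1, 16, 32, 32])
```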
@@ -146,11 +146,11 @@ class ConvLayer(nn.Module):
         input_resolution,
         depth,
         activation,
-        drop_path=0
+        drop_path=0.0,
         downsample=None,
         use_checkpoint=False,
         out_dim=None,
-        conv_expand_ratio=4
+        conv_expand_ratio=4.0,
     ):
         """
         Initializes the ConvLayer with the given dimensions and settings.

@@ -173,18 +173,25 @@ class ConvLayer(nn.Module):
         self.use_checkpoint = use_checkpoint

         # Build blocks
-        self.blocks = nn.ModuleList(
-
-
-
-
-
-
-
+        self.blocks = nn.ModuleList(
+            [
+                MBConv(
+                    dim,
+                    dim,
+                    conv_expand_ratio,
+                    activation,
+                    drop_path[i] if isinstance(drop_path, list) else drop_path,
+                )
+                for i in range(depth)
+            ]
+        )

         # Patch merging layer
-        self.downsample =
-
+        self.downsample = (
+            None
+            if downsample is None
+            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        )

     def forward(self, x):
         """Processes the input through a series of convolutional layers and returns the activated output."""
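
Both rewritten blocks in the ConvLayer hunk follow the same idiom: an `nn.ModuleList` built from a list comprehension, where each block receives its own drop-path rate when `drop_path` is a list. A small self-contained sketch of the idiom (here `nn.Conv2d`/`nn.Dropout` stand in for the `MBConv` and drop-path modules used in the real file):

```python
import torch.nn as nn

dim, depth = 64, 4
drop_path = [0.0, 0.05, 0.1, 0.15]  # may also be a single float

# One block per depth level; each block gets its own rate when a list is supplied.
blocks = nn.ModuleList(
    [
        nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),  # stand-in for MBConv(dim, dim, ...)
            nn.Dropout(drop_path[i] if isinstance(drop_path, list) else drop_path),  # stand-in for DropPath
        )
        for i in range(depth)
    ]
)
print(len(blocks))  # 4
```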
@@ -200,7 +207,7 @@ class Mlp(nn.Module):
     This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
     """

-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
         """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
         super().__init__()
         out_features = out_features or in_features

@@ -232,12 +239,12 @@ class Attention(torch.nn.Module):
     """

     def __init__(
-
-
-
-
-
-
+        self,
+        dim,
+        key_dim,
+        num_heads=8,
+        attn_ratio=4,
+        resolution=(14, 14),
     ):
         """
         Initializes the Attention module.

@@ -256,7 +263,7 @@ class Attention(torch.nn.Module):

         assert isinstance(resolution, tuple) and len(resolution) == 2
         self.num_heads = num_heads
-        self.scale = key_dim
+        self.scale = key_dim**-0.5
         self.key_dim = key_dim
         self.nh_kd = nh_kd = key_dim * num_heads
         self.d = int(attn_ratio * key_dim)
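
`self.scale = key_dim**-0.5` above is the standard 1/sqrt(d_k) factor of scaled dot-product attention, applied to the query-key similarities before softmax. A minimal sketch of where that factor enters:

```python
import torch

B, N, key_dim = 2, 16, 32
q = torch.randn(B, N, key_dim)
k = torch.randn(B, N, key_dim)

scale = key_dim**-0.5  # 1 / sqrt(d_k)
attn = (q @ k.transpose(-2, -1)) * scale  # (B, N, N) similarity scores
attn = attn.softmax(dim=-1)  # each row now sums to 1
print(attn.shape)  # torch.Size([2, 16, 16])
```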
@@ -279,13 +286,13 @@
                     attention_offsets[offset] = len(attention_offsets)
                 idxs.append(attention_offsets[offset])
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
-        self.register_buffer(
+        self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)

     @torch.no_grad()
     def train(self, mode=True):
         """Sets the module in training mode and handles attribute 'ab' based on the mode."""
         super().train(mode)
-        if mode and hasattr(self,
+        if mode and hasattr(self, "ab"):
             del self.ab
         else:
             self.ab = self.attention_biases[:, self.attention_bias_idxs]

@@ -306,8 +313,9 @@
         v = v.permute(0, 2, 1, 3)
         self.ab = self.ab.to(self.attention_biases.device)

-        attn = (
-
+        attn = (q @ k.transpose(-2, -1)) * self.scale + (
+            self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
+        )
         attn = attn.softmax(dim=-1)
         x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         return self.proj(x)
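
The attention hunks above cache a learned relative-position bias: a per-head table indexed by an `(N, N)` buffer of pairwise offsets, added to the scaled `q @ k^T` scores before softmax. A small sketch of that indexing for a 2x2 token grid (the surrounding offset loop is reconstructed here and is not part of the hunk itself):

```python
import itertools

import torch

resolution = (2, 2)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
N = len(points)  # 4 tokens

# Map each pairwise (|dx|, |dy|) offset to a bias index.
attention_offsets, idxs = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        idxs.append(attention_offsets[offset])

num_heads = 2
attention_biases = torch.zeros(num_heads, len(attention_offsets))  # a learnable Parameter in the real module
attention_bias_idxs = torch.LongTensor(idxs).view(N, N)  # registered as a non-persistent buffer

bias = attention_biases[:, attention_bias_idxs]  # per-head (N, N) bias added before softmax
print(bias.shape)  # torch.Size([2, 4, 4])
```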
@@ -322,9 +330,9 @@ class TinyViTBlock(nn.Module):
         input_resolution,
         num_heads,
         window_size=7,
-        mlp_ratio=4
-        drop=0
-        drop_path=0
+        mlp_ratio=4.0,
+        drop=0.0,
+        drop_path=0.0,
         local_conv_size=3,
         activation=nn.GELU,
     ):

@@ -350,7 +358,7 @@ class TinyViTBlock(nn.Module):
         self.dim = dim
         self.input_resolution = input_resolution
         self.num_heads = num_heads
-        assert window_size > 0,
+        assert window_size > 0, "window_size must be greater than 0"
         self.window_size = window_size
         self.mlp_ratio = mlp_ratio

@@ -358,7 +366,7 @@ class TinyViTBlock(nn.Module):
         # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.drop_path = nn.Identity()

-        assert dim % num_heads == 0,
+        assert dim % num_heads == 0, "dim must be divisible by num_heads"
         head_dim = dim // num_heads

         window_resolution = (window_size, window_size)

@@ -377,7 +385,7 @@ class TinyViTBlock(nn.Module):
         """
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W,
+        assert L == H * W, "input feature has wrong size"
         res_x = x
         if H == self.window_size and W == self.window_size:
             x = self.attn(x)
@@ -394,8 +402,11 @@
             nH = pH // self.window_size
             nW = pW // self.window_size
             # Window partition
-            x =
-
+            x = (
+                x.view(B, nH, self.window_size, nW, self.window_size, C)
+                .transpose(2, 3)
+                .reshape(B * nH * nW, self.window_size * self.window_size, C)
+            )
             x = self.attn(x)
             # Window reverse
             x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
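
The window-partition rewrite above is a pure reshape: `(B, H, W, C)` feature maps are cut into non-overlapping `window_size x window_size` windows and flattened to `(B * num_windows, window_size**2, C)`. A runnable sketch of the round trip on a dummy tensor:

```python
import torch

B, pH, pW, C, window_size = 1, 8, 8, 4, 4
x = torch.randn(B, pH, pW, C)

nH, nW = pH // window_size, pW // window_size
# Partition: (B, pH, pW, C) -> (B * nH * nW, window_size * window_size, C)
windows = (
    x.view(B, nH, window_size, nW, window_size, C)
    .transpose(2, 3)
    .reshape(B * nH * nW, window_size * window_size, C)
)
print(windows.shape)  # torch.Size([4, 16, 4])

# Reverse: restore the original (B, pH, pW, C) layout.
restored = windows.view(B, nH, nW, window_size, window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
print(torch.equal(x, restored))  # True
```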
@@ -417,8 +428,10 @@
         """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
         attentions heads, window size, and MLP ratio.
         """
-        return
-
+        return (
+            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+            f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
+        )


 class BasicLayer(nn.Module):

@@ -431,9 +444,9 @@ class BasicLayer(nn.Module):
         depth,
         num_heads,
         window_size,
-        mlp_ratio=4
-        drop=0
-        drop_path=0
+        mlp_ratio=4.0,
+        drop=0.0,
+        drop_path=0.0,
         downsample=None,
         use_checkpoint=False,
         local_conv_size=3,

@@ -468,22 +481,29 @@ class BasicLayer(nn.Module):
         self.use_checkpoint = use_checkpoint

         # Build blocks
-        self.blocks = nn.ModuleList(
-
-
-
-
-
-
-
-
-
-
-
+        self.blocks = nn.ModuleList(
+            [
+                TinyViTBlock(
+                    dim=dim,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    mlp_ratio=mlp_ratio,
+                    drop=drop,
+                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                    local_conv_size=local_conv_size,
+                    activation=activation,
+                )
+                for i in range(depth)
+            ]
+        )

         # Patch merging layer
-        self.downsample =
-
+        self.downsample = (
+            None
+            if downsample is None
+            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        )

     def forward(self, x):
         """Performs forward propagation on the input tensor and returns a normalized tensor."""

@@ -493,7 +513,7 @@ class BasicLayer(nn.Module):

     def extra_repr(self) -> str:
         """Returns a string representation of the extra_repr function with the layer's parameters."""
-        return f
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


 class LayerNorm2d(nn.Module):

@@ -549,8 +569,8 @@ class TinyViT(nn.Module):
         depths=[2, 2, 6, 2],
         num_heads=[3, 6, 12, 24],
         window_sizes=[7, 7, 14, 7],
-        mlp_ratio=4
-        drop_rate=0
+        mlp_ratio=4.0,
+        drop_rate=0.0,
         drop_path_rate=0.1,
         use_checkpoint=False,
         mbconv_expand_ratio=4.0,
@@ -585,10 +605,9 @@ class TinyViT(nn.Module):

         activation = nn.GELU

-        self.patch_embed = PatchEmbed(
-
-
-            activation=activation)
+        self.patch_embed = PatchEmbed(
+            in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
+        )

         patches_resolution = self.patch_embed.patches_resolution
         self.patches_resolution = patches_resolution

@@ -601,27 +620,30 @@ class TinyViT(nn.Module):
         for i_layer in range(self.num_layers):
             kwargs = dict(
                 dim=embed_dims[i_layer],
-                input_resolution=(
-
+                input_resolution=(
+                    patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                    patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                ),
                 # input_resolution=(patches_resolution[0] // (2 ** i_layer),
                 #                   patches_resolution[1] // (2 ** i_layer)),
                 depth=depths[i_layer],
-                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                 downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                 use_checkpoint=use_checkpoint,
-                out_dim=embed_dims[min(i_layer + 1,
-                                       len(embed_dims) - 1)],
+                out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
                 activation=activation,
             )
             if i_layer == 0:
                 layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
             else:
-                layer = BasicLayer(
-
-
-
-
-
+                layer = BasicLayer(
+                    num_heads=num_heads[i_layer],
+                    window_size=window_sizes[i_layer],
+                    mlp_ratio=self.mlp_ratio,
+                    drop=drop_rate,
+                    local_conv_size=local_conv_size,
+                    **kwargs,
+                )
             self.layers.append(layer)

         # Classifier head
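
`dpr` in the hunk above is the stochastic-depth schedule built earlier in `__init__` (not shown in these hunks); the slicing hands each stage its own span of rates out of one ramp over all blocks. A self-contained sketch of that ramp-and-slice pattern, assuming the usual `torch.linspace` construction:

```python
import torch

depths = [2, 2, 6, 2]
drop_path_rate = 0.1

# One linear ramp over every block in the network ...
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

# ... sliced into per-stage spans by cumulative depth, as in the hunk above.
for i_layer in range(len(depths)):
    stage_dpr = dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])]
    print(i_layer, [round(r, 3) for r in stage_dpr])
# 0 [0.0, 0.009]
# 1 [0.018, 0.027]
# 2 [0.036, 0.045, 0.055, 0.064, 0.073, 0.082]
# 3 [0.091, 0.1]
```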
@@ -680,7 +702,7 @@ class TinyViT(nn.Module):
         def _check_lr_scale(m):
             """Checks if the learning rate scale attribute is present in module's parameters."""
             for p in m.parameters():
-                assert hasattr(p,
+                assert hasattr(p, "lr_scale"), p.param_name

         self.apply(_check_lr_scale)

@@ -698,7 +720,7 @@ class TinyViT(nn.Module):
     @torch.jit.ignore
     def no_weight_decay_keywords(self):
         """Returns a dictionary of parameter names where weight decay should not be applied."""
-        return {
+        return {"attention_biases"}

     def forward_features(self, x):
         """Runs the input through the model layers and returns the transformed output."""

ultralytics/models/sam/modules/transformer.py

@@ -62,7 +62,8 @@ class TwoWayTransformer(nn.Module):
                     activation=activation,
                     attention_downsample_rate=attention_downsample_rate,
                     skip_first_layer_pe=(i == 0),
-                )
+                )
+            )

         self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
         self.norm_final_attn = nn.LayerNorm(embedding_dim)

@@ -227,7 +228,7 @@ class Attention(nn.Module):
         self.embedding_dim = embedding_dim
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
-        assert self.internal_dim % num_heads == 0,
+        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
         self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
ultralytics/models/sam/predict.py

@@ -19,8 +19,17 @@ from ultralytics.engine.results import Results
 from ultralytics.utils import DEFAULT_CFG, ops
 from ultralytics.utils.torch_utils import select_device

-from .amg import (
-
+from .amg import (
+    batch_iterator,
+    batched_mask_to_box,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    remove_small_regions,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+)
 from .build import build_sam
@@ -58,7 +67,7 @@ class Predictor(BasePredictor):
         """
         if overrides is None:
             overrides = {}
-        overrides.update(dict(task=
+        overrides.update(dict(task="segment", mode="predict", imgsz=1024))
         super().__init__(cfg, overrides, _callbacks)
         self.args.retina_masks = True
         self.im = None

@@ -107,7 +116,7 @@ class Predictor(BasePredictor):
         Returns:
             (List[np.ndarray]): List of transformed images.
         """
-        assert len(im) == 1,
+        assert len(im) == 1, "SAM model does not currently support batched inference"
         letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
         return [letterbox(image=x) for x in im]

@@ -132,9 +141,9 @@ class Predictor(BasePredictor):
             - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
         """
         # Override prompts if any stored in self.prompts
-        bboxes = self.prompts.pop(
-        points = self.prompts.pop(
-        masks = self.prompts.pop(
+        bboxes = self.prompts.pop("bboxes", bboxes)
+        points = self.prompts.pop("points", points)
+        masks = self.prompts.pop("masks", masks)

         if all(i is None for i in [bboxes, points, masks]):
             return self.generate(im, *args, **kwargs)
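
The `self.prompts.pop(...)` lines above are where prompts passed through the high-level `SAM` wrapper reach the predictor. A hedged usage sketch (it assumes a local `sam_b.pt` checkpoint and an `image.jpg` exist; argument names follow the documented SAM interface):

```python
from ultralytics import SAM

# Prompts passed here are stored on the predictor and popped as "bboxes" / "points" / "masks".
model = SAM("sam_b.pt")
results = model("image.jpg", bboxes=[100, 100, 400, 400])   # box prompt
results = model("image.jpg", points=[250, 250], labels=[1])  # point prompt
```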
@@ -199,18 +208,20 @@ class Predictor(BasePredictor):
         # `d` could be 1 or 3 depends on `multimask_output`.
         return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

-    def generate(
-
-
-
-
-
-
-
-
-
-
-
+    def generate(
+        self,
+        im,
+        crop_n_layers=0,
+        crop_overlap_ratio=512 / 1500,
+        crop_downscale_factor=1,
+        point_grids=None,
+        points_stride=32,
+        points_batch_size=64,
+        conf_thres=0.88,
+        stability_score_thresh=0.95,
+        stability_score_offset=0.95,
+        crop_nms_thresh=0.7,
+    ):
         """
         Perform image segmentation using the Segment Anything Model (SAM).
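
Several of the keyword defaults spelled out above gate mask quality; `stability_score_thresh` in particular filters candidate masks by how stable their binarisation is around the model's mask threshold. A hedged sketch of that metric (the real helper is `calculate_stability_score`, imported from `.amg` in the hunk further up):

```python
import torch


def stability_score(mask_logits: torch.Tensor, mask_threshold: float, offset: float) -> torch.Tensor:
    """IoU between the mask binarised at (threshold + offset) and at (threshold - offset)."""
    intersections = (mask_logits > (mask_threshold + offset)).sum(dim=(-2, -1)).float()
    unions = (mask_logits > (mask_threshold - offset)).sum(dim=(-2, -1)).float()
    return intersections / unions


logits = torch.randn(3, 64, 64)  # three candidate mask logit maps
print(stability_score(logits, mask_threshold=0.0, offset=0.95))  # values in [0, 1]
```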
@@ -248,19 +259,20 @@ class Predictor(BasePredictor):
             area = torch.tensor(w * h, device=im.device)
             points_scale = np.array([[w, h]])  # w, h
             # Crop image and interpolate to input size
-            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode=
+            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
             # (num_points, 2)
             points_for_image = point_grids[layer_idx] * points_scale
             crop_masks, crop_scores, crop_bboxes = [], [], []
-            for (points,
+            for (points,) in batch_iterator(points_batch_size, points_for_image):
                 pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
                 # Interpolate predicted masks to input size
-                pred_mask = F.interpolate(pred_mask[None], (h, w), mode=
+                pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
                 idx = pred_score > conf_thres
                 pred_mask, pred_score = pred_mask[idx], pred_score[idx]

-                stability_score = calculate_stability_score(
-
+                stability_score = calculate_stability_score(
+                    pred_mask, self.model.mask_threshold, stability_score_offset
+                )
                 idx = stability_score > stability_score_thresh
                 pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                 # Bool type is much more memory-efficient.

@@ -404,7 +416,7 @@ class Predictor(BasePredictor):
             model = build_sam(self.args.model)
             self.setup_model(model)
         self.setup_source(image)
-        assert len(self.dataset) == 1,
+        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
         for batch in self.dataset:
             im = self.preprocess(batch[1])
             self.features = self.model.image_encoder(im)
@@ -446,9 +458,9 @@ class Predictor(BasePredictor):
         scores = []
         for mask in masks:
             mask = mask.cpu().numpy().astype(np.uint8)
-            mask, changed = remove_small_regions(mask, min_area, mode=
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
             unchanged = not changed
-            mask, changed = remove_small_regions(mask, min_area, mode=
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
             unchanged = unchanged and not changed

             new_masks.append(torch.as_tensor(mask).unsqueeze(0))
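
`remove_small_regions` above runs twice per mask: once with `mode="holes"` to fill small background holes and once with `mode="islands"` to drop small disconnected specks. A sketch of the connected-components approach such a helper typically uses (this version is illustrative, not the packaged implementation in `.amg`; it assumes OpenCV and NumPy are available):

```python
import cv2
import numpy as np


def remove_small_regions_sketch(mask: np.ndarray, min_area: int, mode: str):
    """Fill small holes (mode="holes") or drop small islands (mode="islands") in a boolean mask."""
    assert mode in {"holes", "islands"}
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)  # invert when looking for holes
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # component areas, skipping background label 0
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < min_area]
    if not small_regions:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        # Keep everything that is not small; if nothing is left, keep the largest component.
        fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
    return np.isin(regions, fill_labels), True


noisy = np.zeros((64, 64), dtype=bool)
noisy[8:40, 8:40] = True  # main object
noisy[50, 50] = True      # 1-pixel speck
cleaned, changed = remove_small_regions_sketch(noisy, min_area=10, mode="islands")
print(changed, cleaned.sum())  # True 1024
```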