ultralytics-8.0.237-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic.

Files changed (137)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
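
For readers who want to reproduce a file-by-file report like the one above, the short sketch below unpacks both wheels and prints unified diffs of their Python sources using only the standard library. The wheel filenames and extraction directories are assumptions, not part of this page; download the two wheels first (for example with "pip download ultralytics==8.0.237 --no-deps", and likewise for 8.0.239).

# Minimal sketch (assumed local filenames): unpack two wheels and diff their .py files.
# Note: this only walks files present in the newer wheel, so deleted files are not reported.
import difflib
import zipfile
from pathlib import Path

OLD_WHEEL = "ultralytics-8.0.237-py3-none-any.whl"  # assumed to be in the working directory
NEW_WHEEL = "ultralytics-8.0.239-py3-none-any.whl"  # assumed to be in the working directory


def extract(wheel, dest):
    """Extract a wheel (a zip archive) into dest and return the extraction root."""
    with zipfile.ZipFile(wheel) as zf:
        zf.extractall(dest)
    return Path(dest)


old_root, new_root = extract(OLD_WHEEL, "old_whl"), extract(NEW_WHEEL, "new_whl")
for new_file in sorted(new_root.rglob("*.py")):
    rel = new_file.relative_to(new_root)
    old_file = old_root / rel
    old_lines = old_file.read_text().splitlines(keepends=True) if old_file.exists() else []
    new_lines = new_file.read_text().splitlines(keepends=True)
    diff = difflib.unified_diff(old_lines, new_lines, fromfile=str(rel), tofile=str(rel))
    print("".join(diff), end="")

The registry's own tooling may differ; this only approximates the per-file view. The captured hunks follow, grouped by file.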
ultralytics/models/sam/modules/tiny_encoder.py

@@ -28,11 +28,11 @@ class Conv2d_BN(torch.nn.Sequential):
         drop path.
         """
         super().__init__()
-        self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
+        self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
         bn = torch.nn.BatchNorm2d(b)
         torch.nn.init.constant_(bn.weight, bn_weight_init)
         torch.nn.init.constant_(bn.bias, 0)
-        self.add_module('bn', bn)
+        self.add_module("bn", bn)


 class PatchEmbed(nn.Module):
@@ -146,11 +146,11 @@ class ConvLayer(nn.Module):
         input_resolution,
         depth,
         activation,
-        drop_path=0.,
+        drop_path=0.0,
         downsample=None,
         use_checkpoint=False,
         out_dim=None,
-        conv_expand_ratio=4.,
+        conv_expand_ratio=4.0,
     ):
         """
         Initializes the ConvLayer with the given dimensions and settings.
@@ -173,18 +173,25 @@ class ConvLayer(nn.Module):
         self.use_checkpoint = use_checkpoint

         # Build blocks
-        self.blocks = nn.ModuleList([
-            MBConv(
-                dim,
-                dim,
-                conv_expand_ratio,
-                activation,
-                drop_path[i] if isinstance(drop_path, list) else drop_path,
-            ) for i in range(depth)])
+        self.blocks = nn.ModuleList(
+            [
+                MBConv(
+                    dim,
+                    dim,
+                    conv_expand_ratio,
+                    activation,
+                    drop_path[i] if isinstance(drop_path, list) else drop_path,
+                )
+                for i in range(depth)
+            ]
+        )

         # Patch merging layer
-        self.downsample = None if downsample is None else downsample(
-            input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        self.downsample = (
+            None
+            if downsample is None
+            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        )

     def forward(self, x):
         """Processes the input through a series of convolutional layers and returns the activated output."""
@@ -200,7 +207,7 @@ class Mlp(nn.Module):
     This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
     """

-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
         """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
         super().__init__()
         out_features = out_features or in_features
@@ -232,12 +239,12 @@ class Attention(torch.nn.Module):
     """

     def __init__(
-            self,
-            dim,
-            key_dim,
-            num_heads=8,
-            attn_ratio=4,
-            resolution=(14, 14),
+        self,
+        dim,
+        key_dim,
+        num_heads=8,
+        attn_ratio=4,
+        resolution=(14, 14),
     ):
         """
         Initializes the Attention module.
@@ -256,7 +263,7 @@ class Attention(torch.nn.Module):

         assert isinstance(resolution, tuple) and len(resolution) == 2
         self.num_heads = num_heads
-        self.scale = key_dim ** -0.5
+        self.scale = key_dim**-0.5
         self.key_dim = key_dim
         self.nh_kd = nh_kd = key_dim * num_heads
         self.d = int(attn_ratio * key_dim)
@@ -279,13 +286,13 @@ class Attention(torch.nn.Module):
                     attention_offsets[offset] = len(attention_offsets)
                 idxs.append(attention_offsets[offset])
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
-        self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
+        self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)

     @torch.no_grad()
     def train(self, mode=True):
         """Sets the module in training mode and handles attribute 'ab' based on the mode."""
         super().train(mode)
-        if mode and hasattr(self, 'ab'):
+        if mode and hasattr(self, "ab"):
             del self.ab
         else:
             self.ab = self.attention_biases[:, self.attention_bias_idxs]
@@ -306,8 +313,9 @@ class Attention(torch.nn.Module):
         v = v.permute(0, 2, 1, 3)
         self.ab = self.ab.to(self.attention_biases.device)

-        attn = ((q @ k.transpose(-2, -1)) * self.scale +
-                (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
+        attn = (q @ k.transpose(-2, -1)) * self.scale + (
+            self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
+        )
         attn = attn.softmax(dim=-1)
         x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         return self.proj(x)
@@ -322,9 +330,9 @@ class TinyViTBlock(nn.Module):
         input_resolution,
         num_heads,
         window_size=7,
-        mlp_ratio=4.,
-        drop=0.,
-        drop_path=0.,
+        mlp_ratio=4.0,
+        drop=0.0,
+        drop_path=0.0,
         local_conv_size=3,
         activation=nn.GELU,
     ):
@@ -350,7 +358,7 @@ class TinyViTBlock(nn.Module):
         self.dim = dim
         self.input_resolution = input_resolution
         self.num_heads = num_heads
-        assert window_size > 0, 'window_size must be greater than 0'
+        assert window_size > 0, "window_size must be greater than 0"
         self.window_size = window_size
         self.mlp_ratio = mlp_ratio

@@ -358,7 +366,7 @@ class TinyViTBlock(nn.Module):
         # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.drop_path = nn.Identity()

-        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
+        assert dim % num_heads == 0, "dim must be divisible by num_heads"
         head_dim = dim // num_heads

         window_resolution = (window_size, window_size)
@@ -377,7 +385,7 @@ class TinyViTBlock(nn.Module):
         """
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, 'input feature has wrong size'
+        assert L == H * W, "input feature has wrong size"
         res_x = x
         if H == self.window_size and W == self.window_size:
             x = self.attn(x)
@@ -394,8 +402,11 @@ class TinyViTBlock(nn.Module):
             nH = pH // self.window_size
             nW = pW // self.window_size
             # Window partition
-            x = x.view(B, nH, self.window_size, nW, self.window_size,
-                       C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
+            x = (
+                x.view(B, nH, self.window_size, nW, self.window_size, C)
+                .transpose(2, 3)
+                .reshape(B * nH * nW, self.window_size * self.window_size, C)
+            )
             x = self.attn(x)
             # Window reverse
             x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
@@ -417,8 +428,10 @@ class TinyViTBlock(nn.Module):
         """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
         attentions heads, window size, and MLP ratio.
         """
-        return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
-               f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
+        return (
+            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+            f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
+        )


 class BasicLayer(nn.Module):
@@ -431,9 +444,9 @@ class BasicLayer(nn.Module):
         depth,
         num_heads,
         window_size,
-        mlp_ratio=4.,
-        drop=0.,
-        drop_path=0.,
+        mlp_ratio=4.0,
+        drop=0.0,
+        drop_path=0.0,
         downsample=None,
         use_checkpoint=False,
         local_conv_size=3,
@@ -468,22 +481,29 @@ class BasicLayer(nn.Module):
         self.use_checkpoint = use_checkpoint

         # Build blocks
-        self.blocks = nn.ModuleList([
-            TinyViTBlock(
-                dim=dim,
-                input_resolution=input_resolution,
-                num_heads=num_heads,
-                window_size=window_size,
-                mlp_ratio=mlp_ratio,
-                drop=drop,
-                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
-                local_conv_size=local_conv_size,
-                activation=activation,
-            ) for i in range(depth)])
+        self.blocks = nn.ModuleList(
+            [
+                TinyViTBlock(
+                    dim=dim,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    mlp_ratio=mlp_ratio,
+                    drop=drop,
+                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                    local_conv_size=local_conv_size,
+                    activation=activation,
+                )
+                for i in range(depth)
+            ]
+        )

         # Patch merging layer
-        self.downsample = None if downsample is None else downsample(
-            input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        self.downsample = (
+            None
+            if downsample is None
+            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        )

     def forward(self, x):
         """Performs forward propagation on the input tensor and returns a normalized tensor."""
@@ -493,7 +513,7 @@ class BasicLayer(nn.Module):

     def extra_repr(self) -> str:
         """Returns a string representation of the extra_repr function with the layer's parameters."""
-        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


 class LayerNorm2d(nn.Module):
@@ -549,8 +569,8 @@ class TinyViT(nn.Module):
         depths=[2, 2, 6, 2],
         num_heads=[3, 6, 12, 24],
         window_sizes=[7, 7, 14, 7],
-        mlp_ratio=4.,
-        drop_rate=0.,
+        mlp_ratio=4.0,
+        drop_rate=0.0,
         drop_path_rate=0.1,
         use_checkpoint=False,
         mbconv_expand_ratio=4.0,
@@ -585,10 +605,9 @@ class TinyViT(nn.Module):

         activation = nn.GELU

-        self.patch_embed = PatchEmbed(in_chans=in_chans,
-                                      embed_dim=embed_dims[0],
-                                      resolution=img_size,
-                                      activation=activation)
+        self.patch_embed = PatchEmbed(
+            in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
+        )

         patches_resolution = self.patch_embed.patches_resolution
         self.patches_resolution = patches_resolution
@@ -601,27 +620,30 @@ class TinyViT(nn.Module):
         for i_layer in range(self.num_layers):
             kwargs = dict(
                 dim=embed_dims[i_layer],
-                input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
-                                  patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
+                input_resolution=(
+                    patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                    patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                ),
                 # input_resolution=(patches_resolution[0] // (2 ** i_layer),
                 #                   patches_resolution[1] // (2 ** i_layer)),
                 depth=depths[i_layer],
-                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                 downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                 use_checkpoint=use_checkpoint,
-                out_dim=embed_dims[min(i_layer + 1,
-                                       len(embed_dims) - 1)],
+                out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
                 activation=activation,
             )
             if i_layer == 0:
                 layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
             else:
-                layer = BasicLayer(num_heads=num_heads[i_layer],
-                                   window_size=window_sizes[i_layer],
-                                   mlp_ratio=self.mlp_ratio,
-                                   drop=drop_rate,
-                                   local_conv_size=local_conv_size,
-                                   **kwargs)
+                layer = BasicLayer(
+                    num_heads=num_heads[i_layer],
+                    window_size=window_sizes[i_layer],
+                    mlp_ratio=self.mlp_ratio,
+                    drop=drop_rate,
+                    local_conv_size=local_conv_size,
+                    **kwargs,
+                )
             self.layers.append(layer)

         # Classifier head
@@ -680,7 +702,7 @@ class TinyViT(nn.Module):
         def _check_lr_scale(m):
             """Checks if the learning rate scale attribute is present in module's parameters."""
             for p in m.parameters():
-                assert hasattr(p, 'lr_scale'), p.param_name
+                assert hasattr(p, "lr_scale"), p.param_name

         self.apply(_check_lr_scale)

@@ -698,7 +720,7 @@ class TinyViT(nn.Module):
     @torch.jit.ignore
     def no_weight_decay_keywords(self):
         """Returns a dictionary of parameter names where weight decay should not be applied."""
-        return {'attention_biases'}
+        return {"attention_biases"}

     def forward_features(self, x):
         """Runs the input through the model layers and returns the transformed output."""
ultralytics/models/sam/modules/transformer.py

@@ -62,7 +62,8 @@ class TwoWayTransformer(nn.Module):
                     activation=activation,
                     attention_downsample_rate=attention_downsample_rate,
                     skip_first_layer_pe=(i == 0),
-                ))
+                )
+            )

         self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
         self.norm_final_attn = nn.LayerNorm(embedding_dim)
@@ -227,7 +228,7 @@ class Attention(nn.Module):
         self.embedding_dim = embedding_dim
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
-        assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.'
+        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
         self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
ultralytics/models/sam/predict.py

@@ -19,8 +19,17 @@ from ultralytics.engine.results import Results
 from ultralytics.utils import DEFAULT_CFG, ops
 from ultralytics.utils.torch_utils import select_device

-from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
-                  generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
+from .amg import (
+    batch_iterator,
+    batched_mask_to_box,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    remove_small_regions,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+)
 from .build import build_sam


@@ -58,7 +67,7 @@ class Predictor(BasePredictor):
         """
         if overrides is None:
             overrides = {}
-        overrides.update(dict(task='segment', mode='predict', imgsz=1024))
+        overrides.update(dict(task="segment", mode="predict", imgsz=1024))
         super().__init__(cfg, overrides, _callbacks)
         self.args.retina_masks = True
         self.im = None
@@ -107,7 +116,7 @@ class Predictor(BasePredictor):
         Returns:
             (List[np.ndarray]): List of transformed images.
         """
-        assert len(im) == 1, 'SAM model does not currently support batched inference'
+        assert len(im) == 1, "SAM model does not currently support batched inference"
         letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
         return [letterbox(image=x) for x in im]

@@ -132,9 +141,9 @@ class Predictor(BasePredictor):
                 - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
         """
         # Override prompts if any stored in self.prompts
-        bboxes = self.prompts.pop('bboxes', bboxes)
-        points = self.prompts.pop('points', points)
-        masks = self.prompts.pop('masks', masks)
+        bboxes = self.prompts.pop("bboxes", bboxes)
+        points = self.prompts.pop("points", points)
+        masks = self.prompts.pop("masks", masks)

         if all(i is None for i in [bboxes, points, masks]):
             return self.generate(im, *args, **kwargs)
@@ -199,18 +208,20 @@ class Predictor(BasePredictor):
         # `d` could be 1 or 3 depends on `multimask_output`.
         return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

-    def generate(self,
-                 im,
-                 crop_n_layers=0,
-                 crop_overlap_ratio=512 / 1500,
-                 crop_downscale_factor=1,
-                 point_grids=None,
-                 points_stride=32,
-                 points_batch_size=64,
-                 conf_thres=0.88,
-                 stability_score_thresh=0.95,
-                 stability_score_offset=0.95,
-                 crop_nms_thresh=0.7):
+    def generate(
+        self,
+        im,
+        crop_n_layers=0,
+        crop_overlap_ratio=512 / 1500,
+        crop_downscale_factor=1,
+        point_grids=None,
+        points_stride=32,
+        points_batch_size=64,
+        conf_thres=0.88,
+        stability_score_thresh=0.95,
+        stability_score_offset=0.95,
+        crop_nms_thresh=0.7,
+    ):
         """
         Perform image segmentation using the Segment Anything Model (SAM).

@@ -248,19 +259,20 @@ class Predictor(BasePredictor):
             area = torch.tensor(w * h, device=im.device)
             points_scale = np.array([[w, h]])  # w, h
             # Crop image and interpolate to input size
-            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
+            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
             # (num_points, 2)
             points_for_image = point_grids[layer_idx] * points_scale
             crop_masks, crop_scores, crop_bboxes = [], [], []
-            for (points, ) in batch_iterator(points_batch_size, points_for_image):
+            for (points,) in batch_iterator(points_batch_size, points_for_image):
                 pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
                 # Interpolate predicted masks to input size
-                pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
+                pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
                 idx = pred_score > conf_thres
                 pred_mask, pred_score = pred_mask[idx], pred_score[idx]

-                stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
-                                                             stability_score_offset)
+                stability_score = calculate_stability_score(
+                    pred_mask, self.model.mask_threshold, stability_score_offset
+                )
                 idx = stability_score > stability_score_thresh
                 pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                 # Bool type is much more memory-efficient.
@@ -404,7 +416,7 @@ class Predictor(BasePredictor):
             model = build_sam(self.args.model)
             self.setup_model(model)
         self.setup_source(image)
-        assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
+        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
         for batch in self.dataset:
             im = self.preprocess(batch[1])
             self.features = self.model.image_encoder(im)
@@ -446,9 +458,9 @@ class Predictor(BasePredictor):
         scores = []
         for mask in masks:
             mask = mask.cpu().numpy().astype(np.uint8)
-            mask, changed = remove_small_regions(mask, min_area, mode='holes')
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
             unchanged = not changed
-            mask, changed = remove_small_regions(mask, min_area, mode='islands')
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
             unchanged = unchanged and not changed

             new_masks.append(torch.as_tensor(mask).unsqueeze(0))