ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/amg.py
@@ -8,10 +8,9 @@ import numpy as np
 import torch
 
 
-def is_box_near_crop_edge(boxes: torch.Tensor,
-                          crop_box: List[int],
-                          orig_box: List[int],
-                          atol: float = 20.0) -> torch.Tensor:
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
     """Return a boolean tensor indicating if boxes are near the crop edge."""
     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
@@ -24,10 +23,10 @@ def is_box_near_crop_edge(boxes: torch.Tensor,
 
 def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
     """Yield batches of data from the input arguments."""
-    assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
+    assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
     for b in range(n_batches):
-        yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
 
 
 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
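For reference, the reformatted batch_iterator behaves exactly as before: every input sequence is sliced into batch_size chunks, with a short final chunk when the lengths are not an exact multiple. A minimal usage sketch (standalone; the import path follows the file list above, which places this helper in ultralytics/models/sam/amg.py):

from ultralytics.models.sam.amg import batch_iterator

boxes = list(range(10))
scores = [b / 10 for b in boxes]
for batch in batch_iterator(4, boxes, scores):
    # First yields [[0, 1, 2, 3], [0.0, 0.1, 0.2, 0.3]], then the next 4, then the final 2.
    print(batch)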
@@ -39,9 +38,8 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
     """
     # One mask is always contained inside the other.
     # Save memory by preventing unnecessary cast to torch.int64
-    intersections = ((masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1,
-                                                                                                  dtype=torch.int32))
-    unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
+    intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     return intersections / unions
 
 
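The stability score computed above is simply the IoU between the mask binarized at a high threshold and at a low threshold. A self-contained toy example of that arithmetic, using the same expressions as the hunk:

import torch

# One 4x4 "mask" of logits; stability = IoU of (logits > threshold + offset) vs (logits > threshold - offset).
masks = torch.tensor([[[2.0, 2.0, 0.5, -1.0],
                       [2.0, 2.0, 0.5, -1.0],
                       [0.5, 0.5, 0.5, -1.0],
                       [-1.0, -1.0, -1.0, -1.0]]])
mask_threshold, offset = 0.0, 1.0
intersections = (masks > (mask_threshold + offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
unions = (masks > (mask_threshold - offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
print(intersections / unions)  # tensor([0.4444]) -> 4 "high" pixels over 9 "low" pixels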
@@ -56,11 +54,12 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
 
 def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
     """Generate point grids for all crop layers."""
-    return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
+    return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
 
 
-def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
-                        overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
+def generate_crop_boxes(
+    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
     """
     Generates a list of crop boxes of different sizes.
 
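build_point_grid itself is not shown in this hunk; in the upstream SAM reference it returns an (n_per_side**2, 2) array of evenly spaced, normalized (x, y) points, and build_all_layer_point_grids shrinks n_per_side by scale_per_layer**i for each crop layer. A sketch under that assumption (the helper body below is inferred from the reference implementation, not taken from this diff):

import numpy as np

def build_point_grid(n_per_side: int) -> np.ndarray:
    """Evenly spaced 2D point grid in [0, 1] x [0, 1] (assumed behavior, mirroring the SAM reference)."""
    offset = 1 / (2 * n_per_side)
    side = np.linspace(offset, 1 - offset, n_per_side)
    xs, ys = np.meshgrid(side, side)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # (n_per_side**2, 2)

# Layer 0 keeps the full grid; deeper crop layers use progressively coarser grids.
grids = [build_point_grid(int(32 / (2**i))) for i in range(2 + 1)]
print([g.shape for g in grids])  # [(1024, 2), (256, 2), (64, 2)]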
@@ -132,8 +131,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
     """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
     import cv2  # type: ignore
 
-    assert mode in {'holes', 'islands'}
-    correct_holes = mode == 'holes'
+    assert mode in {"holes", "islands"}
+    correct_holes = mode == "holes"
     working_mask = (correct_holes ^ mask).astype(np.uint8)
     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
     sizes = stats[:, -1][1:]  # Row 0 is background label
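A hedged usage sketch of remove_small_regions, assuming the usual SAM semantics (connected components smaller than area_thresh are dropped in "islands" mode and the second return value flags whether anything changed); the import path follows the file list above:

import numpy as np
from ultralytics.models.sam.amg import remove_small_regions

# A 6x6 boolean mask with one large blob and one single-pixel "island".
mask = np.zeros((6, 6), dtype=bool)
mask[1:4, 1:4] = True   # 9-pixel region
mask[5, 5] = True       # 1-pixel region
cleaned, modified = remove_small_regions(mask, area_thresh=4, mode="islands")
print(modified, cleaned.sum())  # expected: True, 9 -> the single-pixel island is dropped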
ultralytics/models/sam/build.py
@@ -64,46 +64,47 @@ def build_mobile_sam(checkpoint=None):
     )
 
 
-def _build_sam(encoder_embed_dim,
-               encoder_depth,
-               encoder_num_heads,
-               encoder_global_attn_indexes,
-               checkpoint=None,
-               mobile_sam=False):
+def _build_sam(
+    encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
+):
     """Builds the selected SAM model architecture."""
     prompt_embed_dim = 256
     image_size = 1024
     vit_patch_size = 16
     image_embedding_size = image_size // vit_patch_size
-    image_encoder = (TinyViT(
-        img_size=1024,
-        in_chans=3,
-        num_classes=1000,
-        embed_dims=encoder_embed_dim,
-        depths=encoder_depth,
-        num_heads=encoder_num_heads,
-        window_sizes=[7, 7, 14, 7],
-        mlp_ratio=4.0,
-        drop_rate=0.0,
-        drop_path_rate=0.0,
-        use_checkpoint=False,
-        mbconv_expand_ratio=4.0,
-        local_conv_size=3,
-        layer_lr_decay=0.8,
-    ) if mobile_sam else ImageEncoderViT(
-        depth=encoder_depth,
-        embed_dim=encoder_embed_dim,
-        img_size=image_size,
-        mlp_ratio=4,
-        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
-        num_heads=encoder_num_heads,
-        patch_size=vit_patch_size,
-        qkv_bias=True,
-        use_rel_pos=True,
-        global_attn_indexes=encoder_global_attn_indexes,
-        window_size=14,
-        out_chans=prompt_embed_dim,
-    ))
+    image_encoder = (
+        TinyViT(
+            img_size=1024,
+            in_chans=3,
+            num_classes=1000,
+            embed_dims=encoder_embed_dim,
+            depths=encoder_depth,
+            num_heads=encoder_num_heads,
+            window_sizes=[7, 7, 14, 7],
+            mlp_ratio=4.0,
+            drop_rate=0.0,
+            drop_path_rate=0.0,
+            use_checkpoint=False,
+            mbconv_expand_ratio=4.0,
+            local_conv_size=3,
+            layer_lr_decay=0.8,
+        )
+        if mobile_sam
+        else ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        )
+    )
     sam = Sam(
         image_encoder=image_encoder,
         prompt_encoder=PromptEncoder(
@@ -129,7 +130,7 @@ def _build_sam(encoder_embed_dim,
     )
     if checkpoint is not None:
         checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, 'rb') as f:
+        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    sam.eval()
@@ -139,13 +140,14 @@
 
 
 sam_model_map = {
-    'sam_h.pt': build_sam_vit_h,
-    'sam_l.pt': build_sam_vit_l,
-    'sam_b.pt': build_sam_vit_b,
-    'mobile_sam.pt': build_mobile_sam, }
+    "sam_h.pt": build_sam_vit_h,
+    "sam_l.pt": build_sam_vit_l,
+    "sam_b.pt": build_sam_vit_b,
+    "mobile_sam.pt": build_mobile_sam,
+}
 
 
-def build_sam(ckpt='sam_b.pt'):
+def build_sam(ckpt="sam_b.pt"):
     """Build a SAM model specified by ckpt."""
     model_builder = None
     ckpt = str(ckpt)  # to allow Path ckpt types
@@ -154,6 +156,6 @@ def build_sam(ckpt='sam_b.pt'):
            model_builder = sam_model_map.get(k)
 
     if not model_builder:
-        raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}')
+        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
 
     return model_builder(ckpt)
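build_sam dispatches on the checkpoint filename through sam_model_map, and _build_sam downloads missing weights via attempt_download_asset. A usage sketch (network access is needed on the first run; the import path follows the file list above):

from ultralytics.models.sam.build import build_sam

# The filename selects the architecture: sam_h.pt, sam_l.pt, sam_b.pt or mobile_sam.pt.
sam = build_sam("mobile_sam.pt")   # downloads the checkpoint on first use
sam.eval()
print(type(sam).__name__)  # Sam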
ultralytics/models/sam/model.py
@@ -32,7 +32,7 @@ class SAM(Model):
     dataset.
     """
 
-    def __init__(self, model='sam_b.pt') -> None:
+    def __init__(self, model="sam_b.pt") -> None:
         """
         Initializes the SAM model with a pre-trained model file.
 
@@ -42,9 +42,9 @@ class SAM(Model):
         Raises:
             NotImplementedError: If the model file extension is not .pt or .pth.
         """
-        if model and Path(model).suffix not in ('.pt', '.pth'):
-            raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
-        super().__init__(model=model, task='segment')
+        if model and Path(model).suffix not in (".pt", ".pth"):
+            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+        super().__init__(model=model, task="segment")
 
     def _load(self, weights: str, task=None):
         """
@@ -70,7 +70,7 @@ class SAM(Model):
         Returns:
             (list): The model predictions.
         """
-        overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
+        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
         kwargs.update(overrides)
         prompts = dict(bboxes=bboxes, points=points, labels=labels)
         return super().predict(source, stream, prompts=prompts, **kwargs)
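The predict override above pins conf/task/mode/imgsz and forwards box and point prompts to the predictor. A usage sketch of the high-level API (the image path and box coordinates are illustrative):

from ultralytics import SAM

model = SAM("sam_b.pt")
# Prompted prediction: one box prompt in pixel xyxy coordinates (example values).
results = model.predict("ultralytics/assets/bus.jpg", bboxes=[100, 200, 400, 600])
# Unprompted "segment everything" call on the same image.
results = model("ultralytics/assets/bus.jpg")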
@@ -112,4 +112,4 @@ class SAM(Model):
         Returns:
             (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
         """
-        return {'segment': {'predictor': Predictor}}
+        return {"segment": {"predictor": Predictor}}
ultralytics/models/sam/modules/decoders.py
@@ -64,8 +64,9 @@ class MaskDecoder(nn.Module):
             nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
             activation(),
         )
-        self.output_hypernetworks_mlps = nn.ModuleList([
-            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+        )
 
         self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
 
@@ -132,13 +133,14 @@ class MaskDecoder(nn.Module):
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
-        mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
 
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
         hyper_in_list: List[torch.Tensor] = [
-            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
+            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+        ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
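The rewrapped lines keep the same shape algebra: each mask token is mapped by its own hypernetwork MLP to a c-dimensional weight vector, and the batched matmul against the flattened upscaled embedding yields one mask per token. A standalone shape check with dummy tensors (sizes are illustrative):

import torch

b, num_mask_tokens, c, h, w = 2, 4, 32, 64, 64
hyper_in = torch.randn(b, num_mask_tokens, c)   # one weight vector per mask token
upscaled_embedding = torch.randn(b, c, h, w)    # upscaled image features
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)  # torch.Size([2, 4, 64, 64]) -> one low-res mask per token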
ultralytics/models/sam/modules/encoders.py
@@ -28,23 +28,23 @@ class ImageEncoderViT(nn.Module):
     """
 
     def __init__(
-            self,
-            img_size: int = 1024,
-            patch_size: int = 16,
-            in_chans: int = 3,
-            embed_dim: int = 768,
-            depth: int = 12,
-            num_heads: int = 12,
-            mlp_ratio: float = 4.0,
-            out_chans: int = 256,
-            qkv_bias: bool = True,
-            norm_layer: Type[nn.Module] = nn.LayerNorm,
-            act_layer: Type[nn.Module] = nn.GELU,
-            use_abs_pos: bool = True,
-            use_rel_pos: bool = False,
-            rel_pos_zero_init: bool = True,
-            window_size: int = 0,
-            global_attn_indexes: Tuple[int, ...] = (),
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        global_attn_indexes: Tuple[int, ...] = (),
     ) -> None:
         """
         Args:
@@ -283,9 +283,9 @@ class PromptEncoder(nn.Module):
         if masks is not None:
             dense_embeddings = self._embed_masks(masks)
         else:
-            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
-                                                                 1).expand(bs, -1, self.image_embedding_size[0],
-                                                                           self.image_embedding_size[1])
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
 
         return sparse_embeddings, dense_embeddings
 
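The rewrapped expression still broadcasts a single learned "no mask" embedding over the batch and the spatial grid. A standalone sketch of the reshape/expand step with illustrative sizes:

import torch
import torch.nn as nn

embed_dim, bs, image_embedding_size = 256, 2, (64, 64)
no_mask_embed = nn.Embedding(1, embed_dim)
dense_embeddings = no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
    bs, -1, image_embedding_size[0], image_embedding_size[1]
)
print(dense_embeddings.shape)  # torch.Size([2, 256, 64, 64]); expand() broadcasts without copying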
@@ -298,7 +298,7 @@ class PositionEmbeddingRandom(nn.Module):
         super().__init__()
         if scale is None or scale <= 0.0:
             scale = 1.0
-        self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
+        self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
 
         # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
         torch.use_deterministic_algorithms(False)
@@ -425,14 +425,14 @@ class Attention(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
+        self.scale = head_dim**-0.5
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
 
         self.use_rel_pos = use_rel_pos
         if self.use_rel_pos:
-            assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
+            assert input_size is not None, "Input size must be provided if using relative positional encoding."
             # Initialize relative positional embeddings
             self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
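head_dim**-0.5 is the standard scaled dot-product factor. A minimal standalone sketch of where that scale enters (relative-position terms omitted; sizes are illustrative):

import torch

B, N, num_heads, head_dim = 2, 16, 8, 32
scale = head_dim**-0.5
q = torch.randn(B * num_heads, N, head_dim)
k = torch.randn(B * num_heads, N, head_dim)
v = torch.randn(B * num_heads, N, head_dim)
attn = (q * scale) @ k.transpose(-2, -1)   # (B*num_heads, N, N) scaled similarity
out = attn.softmax(dim=-1) @ v             # (B*num_heads, N, head_dim)
print(out.shape)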
@@ -479,8 +479,9 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
     return windows, (Hp, Wp)
 
 
-def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
-                       hw: Tuple[int, int]) -> torch.Tensor:
+def window_unpartition(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
     """
     Window unpartition into original sequences and removing padding.
 
@@ -523,7 +524,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
         rel_pos_resized = F.interpolate(
             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
             size=max_rel_dist,
-            mode='linear',
+            mode="linear",
         )
         rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
     else:
@@ -567,11 +568,12 @@ def add_decomposed_rel_pos(
 
     B, _, dim = q.shape
     r_q = q.reshape(B, q_h, q_w, dim)
-    rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
-    rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
 
     attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
-        B, q_h * q_w, k_h * k_w)
+        B, q_h * q_w, k_h * k_w
+    )
 
     return attn
 
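The einsums contract the per-axis relative-position tables against the query grid before the result is folded back into the attention matrix. A standalone shape walkthrough with dummy tensors matching the subscript notation:

import torch

B, q_h, q_w, k_h, k_w, dim = 2, 8, 8, 8, 8, 32
r_q = torch.randn(B, q_h, q_w, dim)
Rh = torch.randn(q_h, k_h, dim)   # per-row relative position embeddings
Rw = torch.randn(q_w, k_w, dim)   # per-column relative position embeddings
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)   # (B, q_h, q_w, k_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)   # (B, q_h, q_w, k_w)
attn = torch.randn(B, q_h * q_w, k_h * k_w)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
    B, q_h * q_w, k_h * k_w
)
print(attn.shape)  # torch.Size([2, 64, 64])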
@@ -580,12 +582,12 @@ class PatchEmbed(nn.Module):
     """Image to Patch Embedding."""
 
     def __init__(
-            self,
-            kernel_size: Tuple[int, int] = (16, 16),
-            stride: Tuple[int, int] = (16, 16),
-            padding: Tuple[int, int] = (0, 0),
-            in_chans: int = 3,
-            embed_dim: int = 768,
+        self,
+        kernel_size: Tuple[int, int] = (16, 16),
+        stride: Tuple[int, int] = (16, 16),
+        padding: Tuple[int, int] = (0, 0),
+        in_chans: int = 3,
+        embed_dim: int = 768,
     ) -> None:
         """
         Initialize PatchEmbed module.
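With the defaults above, PatchEmbed is effectively a strided convolution: a 1024x1024 image becomes a 64x64 grid of 768-dimensional tokens. A standalone sketch of that mapping (the channels-last permute at the end is illustrative of how such token grids are typically consumed):

import torch
import torch.nn as nn

# Core of a 16x16 patch embedding: kernel_size == stride == patch size, no padding.
proj = nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0))
x = torch.randn(1, 3, 1024, 1024)
patches = proj(x)                      # (1, 768, 64, 64)
tokens = patches.permute(0, 2, 3, 1)   # (1, 64, 64, 768) channels-last token grid
print(tokens.shape)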
ultralytics/models/sam/modules/sam.py
@@ -30,8 +30,9 @@ class Sam(nn.Module):
         pixel_mean (List[float]): Mean pixel values for image normalization.
         pixel_std (List[float]): Standard deviation values for image normalization.
     """
+
     mask_threshold: float = 0.0
-    image_format: str = 'RGB'
+    image_format: str = "RGB"
 
     def __init__(
         self,
@@ -39,7 +40,7 @@ class Sam(nn.Module):
         prompt_encoder: PromptEncoder,
         mask_decoder: MaskDecoder,
         pixel_mean: List[float] = (123.675, 116.28, 103.53),
-        pixel_std: List[float] = (58.395, 57.12, 57.375)
+        pixel_std: List[float] = (58.395, 57.12, 57.375),
     ) -> None:
         """
         Initialize the Sam class to predict object masks from an image and input prompts.
@@ -60,5 +61,5 @@ class Sam(nn.Module):
         self.image_encoder = image_encoder
         self.prompt_encoder = prompt_encoder
         self.mask_decoder = mask_decoder
-        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
-        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
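The trailing False here is register_buffer's persistent argument, so pixel_mean and pixel_std follow the module's device but stay out of the state_dict. Image preprocessing then reduces to a broadcasted normalize; a minimal sketch, assuming the usual (x - mean) / std step:

import torch

pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
x = torch.randint(0, 256, (3, 1024, 1024)).float()   # a CHW image in the 0-255 range as floats
x = (x - pixel_mean) / pixel_std                      # broadcast over the (H, W) dimensions
print(x.shape, x.mean().item())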