ultralytics 8.0.195__py3-none-any.whl → 8.0.196__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ultralytics might be problematic.

Files changed (84)
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +5 -6
  3. ultralytics/data/augment.py +234 -29
  4. ultralytics/data/base.py +2 -1
  5. ultralytics/data/build.py +9 -3
  6. ultralytics/data/converter.py +5 -2
  7. ultralytics/data/dataset.py +16 -2
  8. ultralytics/data/loaders.py +111 -7
  9. ultralytics/data/utils.py +3 -3
  10. ultralytics/engine/exporter.py +1 -3
  11. ultralytics/engine/model.py +3 -9
  12. ultralytics/engine/predictor.py +10 -6
  13. ultralytics/engine/results.py +18 -8
  14. ultralytics/engine/trainer.py +19 -31
  15. ultralytics/engine/tuner.py +20 -20
  16. ultralytics/engine/validator.py +3 -4
  17. ultralytics/hub/__init__.py +2 -2
  18. ultralytics/hub/auth.py +18 -3
  19. ultralytics/hub/session.py +1 -0
  20. ultralytics/hub/utils.py +1 -3
  21. ultralytics/models/fastsam/model.py +2 -1
  22. ultralytics/models/fastsam/predict.py +2 -0
  23. ultralytics/models/fastsam/prompt.py +15 -1
  24. ultralytics/models/nas/model.py +3 -1
  25. ultralytics/models/rtdetr/model.py +4 -6
  26. ultralytics/models/rtdetr/predict.py +2 -1
  27. ultralytics/models/rtdetr/train.py +2 -1
  28. ultralytics/models/rtdetr/val.py +1 -0
  29. ultralytics/models/sam/amg.py +12 -6
  30. ultralytics/models/sam/model.py +5 -6
  31. ultralytics/models/sam/modules/decoders.py +5 -1
  32. ultralytics/models/sam/modules/encoders.py +15 -12
  33. ultralytics/models/sam/modules/tiny_encoder.py +38 -2
  34. ultralytics/models/sam/modules/transformer.py +2 -4
  35. ultralytics/models/sam/predict.py +8 -4
  36. ultralytics/models/utils/loss.py +35 -8
  37. ultralytics/models/utils/ops.py +14 -18
  38. ultralytics/models/yolo/classify/predict.py +1 -0
  39. ultralytics/models/yolo/classify/train.py +4 -2
  40. ultralytics/models/yolo/classify/val.py +1 -0
  41. ultralytics/models/yolo/detect/train.py +4 -3
  42. ultralytics/models/yolo/model.py +2 -4
  43. ultralytics/models/yolo/pose/predict.py +1 -0
  44. ultralytics/models/yolo/segment/predict.py +2 -0
  45. ultralytics/models/yolo/segment/val.py +1 -1
  46. ultralytics/nn/autobackend.py +45 -32
  47. ultralytics/nn/modules/__init__.py +13 -9
  48. ultralytics/nn/modules/block.py +11 -5
  49. ultralytics/nn/modules/conv.py +16 -7
  50. ultralytics/nn/modules/head.py +6 -3
  51. ultralytics/nn/modules/transformer.py +47 -15
  52. ultralytics/nn/modules/utils.py +6 -4
  53. ultralytics/nn/tasks.py +61 -21
  54. ultralytics/trackers/bot_sort.py +53 -6
  55. ultralytics/trackers/byte_tracker.py +71 -15
  56. ultralytics/trackers/track.py +0 -1
  57. ultralytics/trackers/utils/gmc.py +23 -0
  58. ultralytics/trackers/utils/kalman_filter.py +6 -6
  59. ultralytics/utils/__init__.py +31 -18
  60. ultralytics/utils/autobatch.py +1 -3
  61. ultralytics/utils/benchmarks.py +14 -1
  62. ultralytics/utils/callbacks/base.py +1 -3
  63. ultralytics/utils/callbacks/comet.py +11 -3
  64. ultralytics/utils/callbacks/dvc.py +9 -0
  65. ultralytics/utils/callbacks/neptune.py +5 -6
  66. ultralytics/utils/callbacks/wb.py +1 -0
  67. ultralytics/utils/checks.py +13 -9
  68. ultralytics/utils/dist.py +2 -1
  69. ultralytics/utils/downloads.py +7 -3
  70. ultralytics/utils/files.py +3 -3
  71. ultralytics/utils/instance.py +12 -3
  72. ultralytics/utils/loss.py +97 -22
  73. ultralytics/utils/metrics.py +34 -34
  74. ultralytics/utils/ops.py +10 -9
  75. ultralytics/utils/patches.py +9 -7
  76. ultralytics/utils/plotting.py +4 -3
  77. ultralytics/utils/torch_utils.py +8 -6
  78. ultralytics/utils/triton.py +2 -1
  79. {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/METADATA +1 -1
  80. {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/RECORD +84 -84
  81. {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/LICENSE +0 -0
  82. {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/WHEEL +0 -0
  83. {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/entry_points.txt +0 -0
  84. {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/top_level.txt +0 -0
ultralytics/models/fastsam/prompt.py
@@ -15,6 +15,7 @@ from ultralytics.utils import TQDM
  class FastSAMPrompt:

  def __init__(self, source, results, device='cuda') -> None:
+ """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
  self.device = device
  self.results = results
  self.source = source
@@ -30,6 +31,7 @@ class FastSAMPrompt:

  @staticmethod
  def _segment_image(image, bbox):
+ """Segments the given image according to the provided bounding box coordinates."""
  image_array = np.array(image)
  segmented_image_array = np.zeros_like(image_array)
  x1, y1, x2, y2 = bbox
@@ -45,6 +47,9 @@ class FastSAMPrompt:

  @staticmethod
  def _format_results(result, filter=0):
+ """Formats detection results into list of annotations each containing ID, segmentation, bounding box, score and
+ area.
+ """
  annotations = []
  n = len(result.masks.data) if result.masks is not None else 0
  for i in range(n):
@@ -61,6 +66,9 @@ class FastSAMPrompt:

  @staticmethod
  def _get_bbox_from_mask(mask):
+ """Applies morphological transformations to the mask, displays it, and if with_contours is True, draws
+ contours.
+ """
  mask = mask.astype(np.uint8)
  contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  x1, y1, w, h = cv2.boundingRect(contours[0])
@@ -195,6 +203,7 @@ class FastSAMPrompt:

  @torch.no_grad()
  def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
+ """Processes images and text with a model, calculates similarity, and returns softmax score."""
  preprocessed_images = [preprocess(image).to(device) for image in elements]
  tokenized_text = self.clip.tokenize([search_text]).to(device)
  stacked_images = torch.stack(preprocessed_images)
@@ -206,6 +215,7 @@ class FastSAMPrompt:
  return probs[:, 0].softmax(dim=0)

  def _crop_image(self, format_results):
+ """Crops an image based on provided annotation format and returns cropped images and related data."""
  if os.path.isdir(self.source):
  raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
  image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
@@ -229,6 +239,7 @@ class FastSAMPrompt:
  return cropped_boxes, cropped_images, not_crop, filter_id, annotations

  def box_prompt(self, bbox):
+ """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
  if self.results[0].masks is not None:
  assert (bbox[2] != 0 and bbox[3] != 0)
  if os.path.isdir(self.source):
@@ -261,7 +272,8 @@ class FastSAMPrompt:
  self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
  return self.results

- def point_prompt(self, points, pointlabel): # numpy 处理
+ def point_prompt(self, points, pointlabel): # numpy
+ """Adjusts points on detected masks based on user input and returns the modified results."""
  if self.results[0].masks is not None:
  if os.path.isdir(self.source):
  raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
@@ -284,6 +296,7 @@ class FastSAMPrompt:
  return self.results

  def text_prompt(self, text):
+ """Processes a text prompt, applies it to existing results and returns the updated results."""
  if self.results[0].masks is not None:
  format_results = self._format_results(self.results[0], 0)
  cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
@@ -296,4 +309,5 @@ class FastSAMPrompt:
  return self.results

  def everything_prompt(self):
+ """Returns the processed results from the previous methods in the class."""
  return self.results
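Note: the docstrings added above document the FastSAMPrompt prompt API (box_prompt, point_prompt, text_prompt, everything_prompt). As a rough illustration of how that API is typically driven (not part of this diff; the weights file 'FastSAM-s.pt' and image path 'bus.jpg' are placeholders):

    from ultralytics import FastSAM
    from ultralytics.models.fastsam import FastSAMPrompt

    # Run FastSAM once to get "everything" results, then refine them with prompts
    model = FastSAM('FastSAM-s.pt')  # placeholder weights path
    source = 'bus.jpg'               # placeholder image path
    everything_results = model(source, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

    prompt_process = FastSAMPrompt(source, everything_results, device='cpu')
    ann = prompt_process.everything_prompt()                                # keep all masks
    ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])              # mask with best IoU against the box
    ann = prompt_process.text_prompt(text='a photo of a dog')               # CLIP text matching via retrieve()
    ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])  # point prompt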
ultralytics/models/nas/model.py
@@ -25,12 +25,13 @@ from .val import NASValidator
  class NAS(Model):

  def __init__(self, model='yolo_nas_s.pt') -> None:
+ """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
  assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
  super().__init__(model, task='detect')

  @smart_inference_mode()
  def _load(self, weights: str, task: str):
- # Load or create new NAS model
+ """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
  import super_gradients
  suffix = Path(weights).suffix
  if suffix == '.pt':
@@ -58,4 +59,5 @@ class NAS(Model):

  @property
  def task_map(self):
+ """Returns a dictionary mapping tasks to respective predictor and validator classes."""
  return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
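For orientation, the NAS interface whose docstrings are added above is normally used along these lines (an illustrative sketch, not taken from the diff; the weights and image paths are placeholders, and loading '.pt' weights requires the super-gradients package):

    from ultralytics import NAS

    model = NAS('yolo_nas_s.pt')        # only pre-trained .pt weights are supported, per the assert above
    model.info()                        # print a model summary
    results = model.predict('bus.jpg')  # run detection on a placeholder image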
ultralytics/models/rtdetr/model.py
@@ -1,7 +1,5 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
- """
- RT-DETR model interface
- """
+ """RT-DETR model interface."""
  from ultralytics.engine.model import Model
  from ultralytics.nn.tasks import RTDETRDetectionModel

@@ -11,17 +9,17 @@ from .val import RTDETRValidator


  class RTDETR(Model):
- """
- RTDETR model interface.
- """
+ """RTDETR model interface."""

  def __init__(self, model='rtdetr-l.pt') -> None:
+ """Initializes the RTDETR model with the given model file, defaulting to 'rtdetr-l.pt'."""
  if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
  raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
  super().__init__(model=model, task='detect')

  @property
  def task_map(self):
+ """Returns a dictionary mapping task names to corresponding Ultralytics task classes for RTDETR model."""
  return {
  'detect': {
  'predictor': RTDETRPredictor,
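Likewise, a minimal usage sketch for the RTDETR class documented above (illustrative only; file names are placeholders):

    from ultralytics import RTDETR

    model = RTDETR('rtdetr-l.pt')  # only *.pt, *.yaml or *.yml sources are accepted, per the check above
    results = model('bus.jpg')     # inference is dispatched to the RTDETRPredictor from task_map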
ultralytics/models/rtdetr/predict.py
@@ -48,7 +48,8 @@ class RTDETRPredictor(BasePredictor):
  return results

  def pre_transform(self, im):
- """Pre-transform input image before inference.
+ """
+ Pre-transform input image before inference.

  Args:
  im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
ultralytics/models/rtdetr/train.py
@@ -37,7 +37,8 @@ class RTDETRTrainer(DetectionTrainer):
  return model

  def build_dataset(self, img_path, mode='val', batch=None):
- """Build RTDETR Dataset
+ """
+ Build RTDETR Dataset.

  Args:
  img_path (str): Path to the folder containing images.
ultralytics/models/rtdetr/val.py
@@ -16,6 +16,7 @@ __all__ = 'RTDETRValidator', # tuple or list
  class RTDETRDataset(YOLODataset):

  def __init__(self, *args, data=None, **kwargs):
+ """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
  super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)

  # NOTE: add stretch version load_image for rtdetr mosaic
ultralytics/models/sam/amg.py
@@ -32,9 +32,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:

  def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
  """
- Computes the stability score for a batch of masks. The stability
- score is the IoU between the binary masks obtained by thresholding
- the predicted mask logits at high and low values.
+ Computes the stability score for a batch of masks.
+
+ The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
+ and low values.
  """
  # One mask is always contained inside the other.
  # Save memory by preventing unnecessary cast to torch.int64
@@ -60,7 +61,11 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:

  def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
  overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
- """Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer."""
+ """
+ Generates a list of crop boxes of different sizes.
+
+ Each layer has (2**i)**2 boxes for the ith layer.
+ """
  crop_boxes, layer_idxs = [], []
  im_h, im_w = im_size
  short_side = min(im_h, im_w)
@@ -145,8 +150,9 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup

  def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
  """
- Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
- an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ Calculates boxes in XYXY format around masks.
+
+ Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
  """
  # torch.max below raises an error on empty inputs, just skip in this case
  if torch.numel(masks) == 0:
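The reworded calculate_stability_score docstring describes an IoU between the binarizations of the mask logits at (mask_threshold + offset) and (mask_threshold - offset). A standalone sketch of that computation, written here only to illustrate the formula (an independent re-implementation, not the library code):

    import torch

    def stability_score_sketch(mask_logits: torch.Tensor, mask_threshold: float, offset: float) -> torch.Tensor:
        """IoU of the high- and low-threshold binary masks; the high mask always lies inside the low mask."""
        high = (mask_logits > (mask_threshold + offset)).flatten(-2).sum(-1, dtype=torch.float32)
        low = (mask_logits > (mask_threshold - offset)).flatten(-2).sum(-1, dtype=torch.float32)
        return high / low  # intersection == |high|, union == |low|

    scores = stability_score_sketch(torch.randn(2, 256, 256), mask_threshold=0.0, offset=1.0)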
ultralytics/models/sam/model.py
@@ -1,7 +1,5 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
- """
- SAM model interface
- """
+ """SAM model interface."""

  from pathlib import Path

@@ -13,16 +11,16 @@ from .predict import Predictor


  class SAM(Model):
- """
- SAM model interface.
- """
+ """SAM model interface."""

  def __init__(self, model='sam_b.pt') -> None:
+ """Initializes the SAM model instance with the specified pre-trained model file."""
  if model and Path(model).suffix not in ('.pt', '.pth'):
  raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
  super().__init__(model=model, task='segment')

  def _load(self, weights: str, task=None):
+ """Loads the provided weights into the SAM model."""
  self.model = build_sam(weights)

  def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
@@ -48,4 +46,5 @@ class SAM(Model):

  @property
  def task_map(self):
+ """Returns a dictionary mapping the 'segment' task to its corresponding 'Predictor'."""
  return {'segment': {'predictor': Predictor}}
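And a corresponding sketch for the SAM class above (illustrative; 'sam_b.pt' and the image path are placeholders):

    from ultralytics import SAM

    model = SAM('sam_b.pt')                                      # .pt / .pth checkpoints only, per the check above
    model.info()                                                 # display model information
    results = model('image.jpg', bboxes=[100, 100, 200, 200])    # box prompt
    results = model('image.jpg', points=[150, 150], labels=[1])  # point prompt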
ultralytics/models/sam/modules/decoders.py
@@ -98,7 +98,11 @@ class MaskDecoder(nn.Module):
  sparse_prompt_embeddings: torch.Tensor,
  dense_prompt_embeddings: torch.Tensor,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Predicts masks. See 'forward' for more details."""
+ """
+ Predicts masks.
+
+ See 'forward' for more details.
+ """
  # Concatenate output tokens
  output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
  output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
ultralytics/models/sam/modules/encoders.py
@@ -100,6 +100,9 @@ class ImageEncoderViT(nn.Module):
  )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Processes input through patch embedding, applies positional embedding if present, and passes through blocks
+ and neck.
+ """
  x = self.patch_embed(x)
  if self.pos_embed is not None:
  x = x + self.pos_embed
@@ -157,8 +160,8 @@ class PromptEncoder(nn.Module):

  def get_dense_pe(self) -> torch.Tensor:
  """
- Returns the positional encoding used to encode point prompts,
- applied to a dense set of points the shape of the image encoding.
+ Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
+ image encoding.

  Returns:
  torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
@@ -204,9 +207,7 @@ class PromptEncoder(nn.Module):
  boxes: Optional[torch.Tensor],
  masks: Optional[torch.Tensor],
  ) -> int:
- """
- Gets the batch size of the output given the batch size of the input prompts.
- """
+ """Gets the batch size of the output given the batch size of the input prompts."""
  if points is not None:
  return points[0].shape[0]
  elif boxes is not None:
@@ -217,6 +218,7 @@ class PromptEncoder(nn.Module):
  return 1

  def _get_device(self) -> torch.device:
+ """Returns the device of the first point embedding's weight tensor."""
  return self.point_embeddings[0].weight.device

  def forward(
@@ -259,11 +261,10 @@ class PromptEncoder(nn.Module):


  class PositionEmbeddingRandom(nn.Module):
- """
- Positional encoding using random spatial frequencies.
- """
+ """Positional encoding using random spatial frequencies."""

  def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ """Initializes a position embedding using random spatial frequencies."""
  super().__init__()
  if scale is None or scale <= 0.0:
  scale = 1.0
@@ -304,7 +305,7 @@ class PositionEmbeddingRandom(nn.Module):


  class Block(nn.Module):
- """Transformer blocks with support of window attention and residual propagation blocks"""
+ """Transformer blocks with support of window attention and residual propagation blocks."""

  def __init__(
  self,
@@ -351,6 +352,7 @@ class Block(nn.Module):
  self.window_size = window_size

  def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
  shortcut = x
  x = self.norm1(x)
  # Window partition
@@ -404,6 +406,7 @@ class Attention(nn.Module):
  self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
  B, H, W, _ = x.shape
  # qkv with shape (3, B, nHead, H * W, C)
  qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -448,6 +451,7 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[in
  hw: Tuple[int, int]) -> torch.Tensor:
  """
  Window unpartition into original sequences and removing padding.
+
  Args:
  windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  window_size (int): window size.
@@ -540,9 +544,7 @@ def add_decomposed_rel_pos(


  class PatchEmbed(nn.Module):
- """
- Image to Patch Embedding.
- """
+ """Image to Patch Embedding."""

  def __init__(
  self,
@@ -565,4 +567,5 @@ class PatchEmbed(nn.Module):
  self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Computes patch embedding by applying convolution and transposing resulting tensor."""
  return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
ultralytics/models/sam/modules/tiny_encoder.py
@@ -23,6 +23,9 @@ from ultralytics.utils.instance import to_2tuple
  class Conv2d_BN(torch.nn.Sequential):

  def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
+ """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
+ drop path.
+ """
  super().__init__()
  self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
  bn = torch.nn.BatchNorm2d(b)
@@ -34,6 +37,9 @@ class Conv2d_BN(torch.nn.Sequential):
  class PatchEmbed(nn.Module):

  def __init__(self, in_chans, embed_dim, resolution, activation):
+ """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
+ function.
+ """
  super().__init__()
  img_size: Tuple[int, int] = to_2tuple(resolution)
  self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
@@ -48,12 +54,16 @@ class PatchEmbed(nn.Module):
  )

  def forward(self, x):
+ """Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
  return self.seq(x)


  class MBConv(nn.Module):

  def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
+ """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
+ function.
+ """
  super().__init__()
  self.in_chans = in_chans
  self.hidden_chans = int(in_chans * expand_ratio)
@@ -73,6 +83,7 @@ class MBConv(nn.Module):
  self.drop_path = nn.Identity()

  def forward(self, x):
+ """Implements the forward pass for the model architecture."""
  shortcut = x
  x = self.conv1(x)
  x = self.act1(x)
@@ -87,6 +98,9 @@ class MBConv(nn.Module):
  class PatchMerging(nn.Module):

  def __init__(self, input_resolution, dim, out_dim, activation):
+ """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
+ optional parameters.
+ """
  super().__init__()

  self.input_resolution = input_resolution
@@ -99,6 +113,7 @@ class PatchMerging(nn.Module):
  self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

  def forward(self, x):
+ """Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
  if x.ndim == 3:
  H, W = self.input_resolution
  B = len(x)
@@ -149,6 +164,7 @@ class ConvLayer(nn.Module):
  input_resolution, dim=dim, out_dim=out_dim, activation=activation)

  def forward(self, x):
+ """Processes the input through a series of convolutional layers and returns the activated output."""
  for blk in self.blocks:
  x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
  return x if self.downsample is None else self.downsample(x)
@@ -157,6 +173,7 @@ class ConvLayer(nn.Module):
  class Mlp(nn.Module):

  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
  super().__init__()
  out_features = out_features or in_features
  hidden_features = hidden_features or in_features
@@ -167,6 +184,7 @@ class Mlp(nn.Module):
  self.drop = nn.Dropout(drop)

  def forward(self, x):
+ """Applies operations on input x and returns modified x, runs downsample if not None."""
  x = self.norm(x)
  x = self.fc1(x)
  x = self.act(x)
@@ -216,6 +234,7 @@ class Attention(torch.nn.Module):

  @torch.no_grad()
  def train(self, mode=True):
+ """Sets the module in training mode and handles attribute 'ab' based on the mode."""
  super().train(mode)
  if mode and hasattr(self, 'ab'):
  del self.ab
@@ -298,6 +317,9 @@ class TinyViTBlock(nn.Module):
  self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)

  def forward(self, x):
+ """Applies attention-based transformation or padding to input 'x' before passing it through a local
+ convolution.
+ """
  H, W = self.input_resolution
  B, L, C = x.shape
  assert L == H * W, 'input feature has wrong size'
@@ -337,6 +359,9 @@ class TinyViTBlock(nn.Module):
  return x + self.drop_path(self.mlp(x))

  def extra_repr(self) -> str:
+ """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
+ attentions heads, window size, and MLP ratio.
+ """
  return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
  f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'

@@ -402,23 +427,28 @@ class BasicLayer(nn.Module):
  input_resolution, dim=dim, out_dim=out_dim, activation=activation)

  def forward(self, x):
+ """Performs forward propagation on the input tensor and returns a normalized tensor."""
  for blk in self.blocks:
  x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
  return x if self.downsample is None else self.downsample(x)

  def extra_repr(self) -> str:
+ """Returns a string representation of the extra_repr function with the layer's parameters."""
  return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'


  class LayerNorm2d(nn.Module):
+ """A PyTorch implementation of Layer Normalization in 2D."""

  def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ """Initialize LayerNorm2d with the number of channels and an optional epsilon."""
  super().__init__()
  self.weight = nn.Parameter(torch.ones(num_channels))
  self.bias = nn.Parameter(torch.zeros(num_channels))
  self.eps = eps

  def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Perform a forward pass, normalizing the input tensor."""
  u = x.mean(1, keepdim=True)
  s = (x - u).pow(2).mean(1, keepdim=True)
  x = (x - u) / torch.sqrt(s + self.eps)
@@ -518,6 +548,7 @@ class TinyViT(nn.Module):
  )

  def set_layer_lr_decay(self, layer_lr_decay):
+ """Sets the learning rate decay for each layer in the TinyViT model."""
  decay_rate = layer_lr_decay

  # layers -> blocks (depth)
@@ -525,6 +556,7 @@ class TinyViT(nn.Module):
  lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]

  def _set_lr_scale(m, scale):
+ """Sets the learning rate scale for each layer in the model based on the layer's depth."""
  for p in m.parameters():
  p.lr_scale = scale

@@ -544,12 +576,14 @@ class TinyViT(nn.Module):
  p.param_name = k

  def _check_lr_scale(m):
+ """Checks if the learning rate scale attribute is present in module's parameters."""
  for p in m.parameters():
  assert hasattr(p, 'lr_scale'), p.param_name

  self.apply(_check_lr_scale)

  def _init_weights(self, m):
+ """Initializes weights for linear layers and layer normalization in the given module."""
  if isinstance(m, nn.Linear):
  # NOTE: This initialization is needed only for training.
  # trunc_normal_(m.weight, std=.02)
@@ -561,11 +595,12 @@ class TinyViT(nn.Module):

  @torch.jit.ignore
  def no_weight_decay_keywords(self):
+ """Returns a dictionary of parameter names where weight decay should not be applied."""
  return {'attention_biases'}

  def forward_features(self, x):
- # x: (N, C, H, W)
- x = self.patch_embed(x)
+ """Runs the input through the model layers and returns the transformed output."""
+ x = self.patch_embed(x) # x input is (N, C, H, W)

  x = self.layers[0](x)
  start_i = 1
@@ -579,4 +614,5 @@ class TinyViT(nn.Module):
  return self.neck(x)

  def forward(self, x):
+ """Executes a forward pass on the input tensor through the constructed model layers."""
  return self.forward_features(x)
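set_layer_lr_decay above tags each parameter with an lr_scale of decay_rate ** (depth - i - 1), so earlier blocks get smaller learning rates. One hedged sketch of how such per-parameter scales could be turned into optimizer parameter groups (purely illustrative; this is not how the Ultralytics trainer consumes them):

    import torch

    def lr_scaled_param_groups(model: torch.nn.Module, base_lr: float = 1e-3):
        """Group parameters by the lr_scale attribute assigned by set_layer_lr_decay."""
        groups = {}
        for p in model.parameters():
            groups.setdefault(getattr(p, 'lr_scale', 1.0), []).append(p)
        return [{'params': params, 'lr': base_lr * scale} for scale, params in groups.items()]

    # e.g. with 4 blocks and layer_lr_decay=0.8 the scales are [0.512, 0.64, 0.8, 1.0]
    # optimizer = torch.optim.AdamW(lr_scaled_param_groups(tiny_vit_model), weight_decay=0.05)  # tiny_vit_model is hypothetical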
ultralytics/models/sam/modules/transformer.py
@@ -21,8 +21,7 @@ class TwoWayTransformer(nn.Module):
  attention_downsample_rate: int = 2,
  ) -> None:
  """
- A transformer decoder that attends to an input image using
- queries whose positional embedding is supplied.
+ A transformer decoder that attends to an input image using queries whose positional embedding is supplied.

  Args:
  depth (int): number of layers in the transformer
@@ -171,8 +170,7 @@ class TwoWayAttentionBlock(nn.Module):


  class Attention(nn.Module):
- """
- An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+ """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  values.
  """

ultralytics/models/sam/predict.py
@@ -19,6 +19,7 @@ from .build import build_sam
  class Predictor(BasePredictor):

  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+ """Initializes the Predictor class with default or provided configuration, overrides, and callbacks."""
  if overrides is None:
  overrides = {}
  overrides.update(dict(task='segment', mode='predict', imgsz=1024))
@@ -34,7 +35,8 @@ class Predictor(BasePredictor):
  self.segment_all = False

  def preprocess(self, im):
- """Prepares input image before inference.
+ """
+ Prepares input image before inference.

  Args:
  im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
@@ -189,7 +191,8 @@ class Predictor(BasePredictor):
  stability_score_thresh=0.95,
  stability_score_offset=0.95,
  crop_nms_thresh=0.7):
- """Segment the whole image.
+ """
+ Segment the whole image.

  Args:
  im (torch.Tensor): The preprocessed image, (N, C, H, W).
@@ -360,14 +363,15 @@ class Predictor(BasePredictor):
  self.prompts = prompts

  def reset_image(self):
+ """Resets the image and its features to None."""
  self.im = None
  self.features = None

  @staticmethod
  def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
  """
- Removes small disconnected regions and holes in masks, then reruns
- box NMS to remove any new duplicates. Requires open-cv as a dependency.
+ Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates.
+ Requires open-cv as a dependency.

  Args:
  masks (torch.Tensor): Masks, (N, H, W).