ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics may be problematic. See the original diff page for more details (the link was not preserved in this copy).

Files changed (134):
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -31,59 +31,76 @@ class PoseValidator(DetectionValidator):
31
31
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
32
32
  self.sigma = None
33
33
  self.kpt_shape = None
34
- self.args.task = 'pose'
34
+ self.args.task = "pose"
35
35
  self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
36
- if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
37
- LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
38
- 'See https://github.com/ultralytics/ultralytics/issues/4031.')
36
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
37
+ LOGGER.warning(
38
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
39
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
40
+ )
39
41
 
40
42
  def preprocess(self, batch):
41
43
  """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
42
44
  batch = super().preprocess(batch)
43
- batch['keypoints'] = batch['keypoints'].to(self.device).float()
45
+ batch["keypoints"] = batch["keypoints"].to(self.device).float()
44
46
  return batch
45
47
 
46
48
  def get_desc(self):
47
49
  """Returns description of evaluation metrics in string format."""
48
- return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
49
- 'R', 'mAP50', 'mAP50-95)')
50
+ return ("%22s" + "%11s" * 10) % (
51
+ "Class",
52
+ "Images",
53
+ "Instances",
54
+ "Box(P",
55
+ "R",
56
+ "mAP50",
57
+ "mAP50-95)",
58
+ "Pose(P",
59
+ "R",
60
+ "mAP50",
61
+ "mAP50-95)",
62
+ )
50
63
 
51
64
  def postprocess(self, preds):
52
65
  """Apply non-maximum suppression and return detections with high confidence scores."""
53
- return ops.non_max_suppression(preds,
54
- self.args.conf,
55
- self.args.iou,
56
- labels=self.lb,
57
- multi_label=True,
58
- agnostic=self.args.single_cls,
59
- max_det=self.args.max_det,
60
- nc=self.nc)
66
+ return ops.non_max_suppression(
67
+ preds,
68
+ self.args.conf,
69
+ self.args.iou,
70
+ labels=self.lb,
71
+ multi_label=True,
72
+ agnostic=self.args.single_cls,
73
+ max_det=self.args.max_det,
74
+ nc=self.nc,
75
+ )
61
76
 
62
77
  def init_metrics(self, model):
63
78
  """Initiate pose estimation metrics for YOLO model."""
64
79
  super().init_metrics(model)
65
- self.kpt_shape = self.data['kpt_shape']
80
+ self.kpt_shape = self.data["kpt_shape"]
66
81
  is_pose = self.kpt_shape == [17, 3]
67
82
  nkpt = self.kpt_shape[0]
68
83
  self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
69
84
  self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[])
70
85
 
71
86
  def _prepare_batch(self, si, batch):
87
+ """Prepares a batch for processing by converting keypoints to float and moving to device."""
72
88
  pbatch = super()._prepare_batch(si, batch)
73
- kpts = batch['keypoints'][batch['batch_idx'] == si]
74
- h, w = pbatch['imgsz']
89
+ kpts = batch["keypoints"][batch["batch_idx"] == si]
90
+ h, w = pbatch["imgsz"]
75
91
  kpts = kpts.clone()
76
92
  kpts[..., 0] *= w
77
93
  kpts[..., 1] *= h
78
- kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
79
- pbatch['kpts'] = kpts
94
+ kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
95
+ pbatch["kpts"] = kpts
80
96
  return pbatch
81
97
 
82
98
  def _prepare_pred(self, pred, pbatch):
99
+ """Prepares and scales keypoints in a batch for pose processing."""
83
100
  predn = super()._prepare_pred(pred, pbatch)
84
- nk = pbatch['kpts'].shape[1]
101
+ nk = pbatch["kpts"].shape[1]
85
102
  pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
86
- ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
103
+ ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
87
104
  return predn, pred_kpts
88
105
 
89
106
  def update_metrics(self, preds, batch):
@@ -91,14 +108,16 @@ class PoseValidator(DetectionValidator):
91
108
  for si, pred in enumerate(preds):
92
109
  self.seen += 1
93
110
  npr = len(pred)
94
- stat = dict(conf=torch.zeros(0, device=self.device),
95
- pred_cls=torch.zeros(0, device=self.device),
96
- tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
97
- tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
111
+ stat = dict(
112
+ conf=torch.zeros(0, device=self.device),
113
+ pred_cls=torch.zeros(0, device=self.device),
114
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
115
+ tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
116
+ )
98
117
  pbatch = self._prepare_batch(si, batch)
99
- cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
118
+ cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
100
119
  nl = len(cls)
101
- stat['target_cls'] = cls
120
+ stat["target_cls"] = cls
102
121
  if npr == 0:
103
122
  if nl:
104
123
  for k in self.stats.keys():
@@ -111,13 +130,13 @@ class PoseValidator(DetectionValidator):
111
130
  if self.args.single_cls:
112
131
  pred[:, 5] = 0
113
132
  predn, pred_kpts = self._prepare_pred(pred, pbatch)
114
- stat['conf'] = predn[:, 4]
115
- stat['pred_cls'] = predn[:, 5]
133
+ stat["conf"] = predn[:, 4]
134
+ stat["pred_cls"] = predn[:, 5]
116
135
 
117
136
  # Evaluate
118
137
  if nl:
119
- stat['tp'] = self._process_batch(predn, bbox, cls)
120
- stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts'])
138
+ stat["tp"] = self._process_batch(predn, bbox, cls)
139
+ stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
121
140
  if self.args.plots:
122
141
  self.confusion_matrix.process_batch(predn, bbox, cls)
123
142
 
@@ -126,7 +145,7 @@ class PoseValidator(DetectionValidator):
126
145
 
127
146
  # Save
128
147
  if self.args.save_json:
129
- self.pred_to_json(predn, batch['im_file'][si])
148
+ self.pred_to_json(predn, batch["im_file"][si])
130
149
  # if self.args.save_txt:
131
150
  # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
132
151
 
@@ -157,26 +176,30 @@ class PoseValidator(DetectionValidator):
157
176
 
158
177
  def plot_val_samples(self, batch, ni):
159
178
  """Plots and saves validation set samples with predicted bounding boxes and keypoints."""
160
- plot_images(batch['img'],
161
- batch['batch_idx'],
162
- batch['cls'].squeeze(-1),
163
- batch['bboxes'],
164
- kpts=batch['keypoints'],
165
- paths=batch['im_file'],
166
- fname=self.save_dir / f'val_batch{ni}_labels.jpg',
167
- names=self.names,
168
- on_plot=self.on_plot)
179
+ plot_images(
180
+ batch["img"],
181
+ batch["batch_idx"],
182
+ batch["cls"].squeeze(-1),
183
+ batch["bboxes"],
184
+ kpts=batch["keypoints"],
185
+ paths=batch["im_file"],
186
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
187
+ names=self.names,
188
+ on_plot=self.on_plot,
189
+ )
169
190
 
170
191
  def plot_predictions(self, batch, preds, ni):
171
192
  """Plots predictions for YOLO model."""
172
193
  pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
173
- plot_images(batch['img'],
174
- *output_to_target(preds, max_det=self.args.max_det),
175
- kpts=pred_kpts,
176
- paths=batch['im_file'],
177
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
178
- names=self.names,
179
- on_plot=self.on_plot) # pred
194
+ plot_images(
195
+ batch["img"],
196
+ *output_to_target(preds, max_det=self.args.max_det),
197
+ kpts=pred_kpts,
198
+ paths=batch["im_file"],
199
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
200
+ names=self.names,
201
+ on_plot=self.on_plot,
202
+ ) # pred
180
203
 
181
204
  def pred_to_json(self, predn, filename):
182
205
  """Converts YOLO predictions to COCO JSON format."""
@@ -185,37 +208,41 @@ class PoseValidator(DetectionValidator):
185
208
  box = ops.xyxy2xywh(predn[:, :4]) # xywh
186
209
  box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
187
210
  for p, b in zip(predn.tolist(), box.tolist()):
188
- self.jdict.append({
189
- 'image_id': image_id,
190
- 'category_id': self.class_map[int(p[5])],
191
- 'bbox': [round(x, 3) for x in b],
192
- 'keypoints': p[6:],
193
- 'score': round(p[4], 5)})
211
+ self.jdict.append(
212
+ {
213
+ "image_id": image_id,
214
+ "category_id": self.class_map[int(p[5])],
215
+ "bbox": [round(x, 3) for x in b],
216
+ "keypoints": p[6:],
217
+ "score": round(p[4], 5),
218
+ }
219
+ )
194
220
 
195
221
  def eval_json(self, stats):
196
222
  """Evaluates object detection model using COCO JSON format."""
197
223
  if self.args.save_json and self.is_coco and len(self.jdict):
198
- anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
199
- pred_json = self.save_dir / 'predictions.json' # predictions
200
- LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
224
+ anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
225
+ pred_json = self.save_dir / "predictions.json" # predictions
226
+ LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
201
227
  try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
202
- check_requirements('pycocotools>=2.0.6')
228
+ check_requirements("pycocotools>=2.0.6")
203
229
  from pycocotools.coco import COCO # noqa
204
230
  from pycocotools.cocoeval import COCOeval # noqa
205
231
 
206
232
  for x in anno_json, pred_json:
207
- assert x.is_file(), f'{x} file not found'
233
+ assert x.is_file(), f"{x} file not found"
208
234
  anno = COCO(str(anno_json)) # init annotations api
209
235
  pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
210
- for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
236
+ for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
211
237
  if self.is_coco:
212
238
  eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
213
239
  eval.evaluate()
214
240
  eval.accumulate()
215
241
  eval.summarize()
216
242
  idx = i * 4 + 2
217
- stats[self.metrics.keys[idx + 1]], stats[
218
- self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
243
+ stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
244
+ :2
245
+ ] # update mAP50-95 and mAP50
219
246
  except Exception as e:
220
- LOGGER.warning(f'pycocotools unable to run: {e}')
247
+ LOGGER.warning(f"pycocotools unable to run: {e}")
221
248
  return stats
@@ -4,4 +4,4 @@ from .predict import SegmentationPredictor
4
4
  from .train import SegmentationTrainer
5
5
  from .val import SegmentationValidator
6
6
 
7
- __all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'
7
+ __all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
@@ -23,17 +23,19 @@ class SegmentationPredictor(DetectionPredictor):
23
23
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
24
24
  """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
25
25
  super().__init__(cfg, overrides, _callbacks)
26
- self.args.task = 'segment'
26
+ self.args.task = "segment"
27
27
 
28
28
  def postprocess(self, preds, img, orig_imgs):
29
29
  """Applies non-max suppression and processes detections for each image in an input batch."""
30
- p = ops.non_max_suppression(preds[0],
31
- self.args.conf,
32
- self.args.iou,
33
- agnostic=self.args.agnostic_nms,
34
- max_det=self.args.max_det,
35
- nc=len(self.model.names),
36
- classes=self.args.classes)
30
+ p = ops.non_max_suppression(
31
+ preds[0],
32
+ self.args.conf,
33
+ self.args.iou,
34
+ agnostic=self.args.agnostic_nms,
35
+ max_det=self.args.max_det,
36
+ nc=len(self.model.names),
37
+ classes=self.args.classes,
38
+ )
37
39
 
38
40
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
39
41
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
@@ -26,12 +26,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
26
26
  """Initialize a SegmentationTrainer object with given arguments."""
27
27
  if overrides is None:
28
28
  overrides = {}
29
- overrides['task'] = 'segment'
29
+ overrides["task"] = "segment"
30
30
  super().__init__(cfg, overrides, _callbacks)
31
31
 
32
32
  def get_model(self, cfg=None, weights=None, verbose=True):
33
33
  """Return SegmentationModel initialized with specified config and weights."""
34
- model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
34
+ model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
35
35
  if weights:
36
36
  model.load(weights)
37
37
 
@@ -39,22 +39,23 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
39
39
 
40
40
  def get_validator(self):
41
41
  """Return an instance of SegmentationValidator for validation of YOLO model."""
42
- self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
43
- return yolo.segment.SegmentationValidator(self.test_loader,
44
- save_dir=self.save_dir,
45
- args=copy(self.args),
46
- _callbacks=self.callbacks)
42
+ self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
43
+ return yolo.segment.SegmentationValidator(
44
+ self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
45
+ )
47
46
 
48
47
  def plot_training_samples(self, batch, ni):
49
48
  """Creates a plot of training sample images with labels and box coordinates."""
50
- plot_images(batch['img'],
51
- batch['batch_idx'],
52
- batch['cls'].squeeze(-1),
53
- batch['bboxes'],
54
- masks=batch['masks'],
55
- paths=batch['im_file'],
56
- fname=self.save_dir / f'train_batch{ni}.jpg',
57
- on_plot=self.on_plot)
49
+ plot_images(
50
+ batch["img"],
51
+ batch["batch_idx"],
52
+ batch["cls"].squeeze(-1),
53
+ batch["bboxes"],
54
+ masks=batch["masks"],
55
+ paths=batch["im_file"],
56
+ fname=self.save_dir / f"train_batch{ni}.jpg",
57
+ on_plot=self.on_plot,
58
+ )
58
59
 
59
60
  def plot_metrics(self):
60
61
  """Plots training/val metrics."""
@@ -33,13 +33,13 @@ class SegmentationValidator(DetectionValidator):
33
33
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
34
34
  self.plot_masks = None
35
35
  self.process = None
36
- self.args.task = 'segment'
36
+ self.args.task = "segment"
37
37
  self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
38
38
 
39
39
  def preprocess(self, batch):
40
40
  """Preprocesses batch by converting masks to float and sending to device."""
41
41
  batch = super().preprocess(batch)
42
- batch['masks'] = batch['masks'].to(self.device).float()
42
+ batch["masks"] = batch["masks"].to(self.device).float()
43
43
  return batch
44
44
 
45
45
  def init_metrics(self, model):
@@ -47,7 +47,7 @@ class SegmentationValidator(DetectionValidator):
47
47
  super().init_metrics(model)
48
48
  self.plot_masks = []
49
49
  if self.args.save_json:
50
- check_requirements('pycocotools>=2.0.6')
50
+ check_requirements("pycocotools>=2.0.6")
51
51
  self.process = ops.process_mask_upsample # more accurate
52
52
  else:
53
53
  self.process = ops.process_mask # faster
@@ -55,31 +55,46 @@ class SegmentationValidator(DetectionValidator):
55
55
 
56
56
  def get_desc(self):
57
57
  """Return a formatted description of evaluation metrics."""
58
- return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
59
- 'R', 'mAP50', 'mAP50-95)')
58
+ return ("%22s" + "%11s" * 10) % (
59
+ "Class",
60
+ "Images",
61
+ "Instances",
62
+ "Box(P",
63
+ "R",
64
+ "mAP50",
65
+ "mAP50-95)",
66
+ "Mask(P",
67
+ "R",
68
+ "mAP50",
69
+ "mAP50-95)",
70
+ )
60
71
 
61
72
  def postprocess(self, preds):
62
73
  """Post-processes YOLO predictions and returns output detections with proto."""
63
- p = ops.non_max_suppression(preds[0],
64
- self.args.conf,
65
- self.args.iou,
66
- labels=self.lb,
67
- multi_label=True,
68
- agnostic=self.args.single_cls,
69
- max_det=self.args.max_det,
70
- nc=self.nc)
74
+ p = ops.non_max_suppression(
75
+ preds[0],
76
+ self.args.conf,
77
+ self.args.iou,
78
+ labels=self.lb,
79
+ multi_label=True,
80
+ agnostic=self.args.single_cls,
81
+ max_det=self.args.max_det,
82
+ nc=self.nc,
83
+ )
71
84
  proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
72
85
  return p, proto
73
86
 
74
87
  def _prepare_batch(self, si, batch):
88
+ """Prepares a batch for training or inference by processing images and targets."""
75
89
  prepared_batch = super()._prepare_batch(si, batch)
76
- midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si
77
- prepared_batch['masks'] = batch['masks'][midx]
90
+ midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
91
+ prepared_batch["masks"] = batch["masks"][midx]
78
92
  return prepared_batch
79
93
 
80
94
  def _prepare_pred(self, pred, pbatch, proto):
95
+ """Prepares a batch for training or inference by processing images and targets."""
81
96
  predn = super()._prepare_pred(pred, pbatch)
82
- pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz'])
97
+ pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
83
98
  return predn, pred_masks
84
99
 
85
100
  def update_metrics(self, preds, batch):
@@ -87,14 +102,16 @@ class SegmentationValidator(DetectionValidator):
87
102
  for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
88
103
  self.seen += 1
89
104
  npr = len(pred)
90
- stat = dict(conf=torch.zeros(0, device=self.device),
91
- pred_cls=torch.zeros(0, device=self.device),
92
- tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
93
- tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
105
+ stat = dict(
106
+ conf=torch.zeros(0, device=self.device),
107
+ pred_cls=torch.zeros(0, device=self.device),
108
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
109
+ tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
110
+ )
94
111
  pbatch = self._prepare_batch(si, batch)
95
- cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
112
+ cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
96
113
  nl = len(cls)
97
- stat['target_cls'] = cls
114
+ stat["target_cls"] = cls
98
115
  if npr == 0:
99
116
  if nl:
100
117
  for k in self.stats.keys():
@@ -104,24 +121,20 @@ class SegmentationValidator(DetectionValidator):
104
121
  continue
105
122
 
106
123
  # Masks
107
- gt_masks = pbatch.pop('masks')
124
+ gt_masks = pbatch.pop("masks")
108
125
  # Predictions
109
126
  if self.args.single_cls:
110
127
  pred[:, 5] = 0
111
128
  predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
112
- stat['conf'] = predn[:, 4]
113
- stat['pred_cls'] = predn[:, 5]
129
+ stat["conf"] = predn[:, 4]
130
+ stat["pred_cls"] = predn[:, 5]
114
131
 
115
132
  # Evaluate
116
133
  if nl:
117
- stat['tp'] = self._process_batch(predn, bbox, cls)
118
- stat['tp_m'] = self._process_batch(predn,
119
- bbox,
120
- cls,
121
- pred_masks,
122
- gt_masks,
123
- self.args.overlap_mask,
124
- masks=True)
134
+ stat["tp"] = self._process_batch(predn, bbox, cls)
135
+ stat["tp_m"] = self._process_batch(
136
+ predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
137
+ )
125
138
  if self.args.plots:
126
139
  self.confusion_matrix.process_batch(predn, bbox, cls)
127
140
 
@@ -134,10 +147,12 @@ class SegmentationValidator(DetectionValidator):
134
147
 
135
148
  # Save
136
149
  if self.args.save_json:
137
- pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
138
- pbatch['ori_shape'],
139
- ratio_pad=batch['ratio_pad'][si])
140
- self.pred_to_json(predn, batch['im_file'][si], pred_masks)
150
+ pred_masks = ops.scale_image(
151
+ pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
152
+ pbatch["ori_shape"],
153
+ ratio_pad=batch["ratio_pad"][si],
154
+ )
155
+ self.pred_to_json(predn, batch["im_file"][si], pred_masks)
141
156
  # if self.args.save_txt:
142
157
  # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
143
158
 
@@ -164,7 +179,7 @@ class SegmentationValidator(DetectionValidator):
164
179
  gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
165
180
  gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
166
181
  if gt_masks.shape[1:] != pred_masks.shape[1:]:
167
- gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
182
+ gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
168
183
  gt_masks = gt_masks.gt_(0.5)
169
184
  iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
170
185
  else: # boxes
@@ -174,26 +189,29 @@ class SegmentationValidator(DetectionValidator):
174
189
 
175
190
  def plot_val_samples(self, batch, ni):
176
191
  """Plots validation samples with bounding box labels."""
177
- plot_images(batch['img'],
178
- batch['batch_idx'],
179
- batch['cls'].squeeze(-1),
180
- batch['bboxes'],
181
- masks=batch['masks'],
182
- paths=batch['im_file'],
183
- fname=self.save_dir / f'val_batch{ni}_labels.jpg',
184
- names=self.names,
185
- on_plot=self.on_plot)
192
+ plot_images(
193
+ batch["img"],
194
+ batch["batch_idx"],
195
+ batch["cls"].squeeze(-1),
196
+ batch["bboxes"],
197
+ masks=batch["masks"],
198
+ paths=batch["im_file"],
199
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
200
+ names=self.names,
201
+ on_plot=self.on_plot,
202
+ )
186
203
 
187
204
  def plot_predictions(self, batch, preds, ni):
188
205
  """Plots batch predictions with masks and bounding boxes."""
189
206
  plot_images(
190
- batch['img'],
207
+ batch["img"],
191
208
  *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
192
209
  torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
193
- paths=batch['im_file'],
194
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
210
+ paths=batch["im_file"],
211
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
195
212
  names=self.names,
196
- on_plot=self.on_plot) # pred
213
+ on_plot=self.on_plot,
214
+ ) # pred
197
215
  self.plot_masks.clear()
198
216
 
199
217
  def pred_to_json(self, predn, filename, pred_masks):
@@ -203,8 +221,8 @@ class SegmentationValidator(DetectionValidator):
203
221
 
204
222
  def single_encode(x):
205
223
  """Encode predicted masks as RLE and append results to jdict."""
206
- rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
207
- rle['counts'] = rle['counts'].decode('utf-8')
224
+ rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
225
+ rle["counts"] = rle["counts"].decode("utf-8")
208
226
  return rle
209
227
 
210
228
  stem = Path(filename).stem
@@ -215,37 +233,41 @@ class SegmentationValidator(DetectionValidator):
215
233
  with ThreadPool(NUM_THREADS) as pool:
216
234
  rles = pool.map(single_encode, pred_masks)
217
235
  for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
218
- self.jdict.append({
219
- 'image_id': image_id,
220
- 'category_id': self.class_map[int(p[5])],
221
- 'bbox': [round(x, 3) for x in b],
222
- 'score': round(p[4], 5),
223
- 'segmentation': rles[i]})
236
+ self.jdict.append(
237
+ {
238
+ "image_id": image_id,
239
+ "category_id": self.class_map[int(p[5])],
240
+ "bbox": [round(x, 3) for x in b],
241
+ "score": round(p[4], 5),
242
+ "segmentation": rles[i],
243
+ }
244
+ )
224
245
 
225
246
  def eval_json(self, stats):
226
247
  """Return COCO-style object detection evaluation metrics."""
227
248
  if self.args.save_json and self.is_coco and len(self.jdict):
228
- anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
229
- pred_json = self.save_dir / 'predictions.json' # predictions
230
- LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
249
+ anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
250
+ pred_json = self.save_dir / "predictions.json" # predictions
251
+ LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
231
252
  try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
232
- check_requirements('pycocotools>=2.0.6')
253
+ check_requirements("pycocotools>=2.0.6")
233
254
  from pycocotools.coco import COCO # noqa
234
255
  from pycocotools.cocoeval import COCOeval # noqa
235
256
 
236
257
  for x in anno_json, pred_json:
237
- assert x.is_file(), f'{x} file not found'
258
+ assert x.is_file(), f"{x} file not found"
238
259
  anno = COCO(str(anno_json)) # init annotations api
239
260
  pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
240
- for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
261
+ for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]):
241
262
  if self.is_coco:
242
263
  eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
243
264
  eval.evaluate()
244
265
  eval.accumulate()
245
266
  eval.summarize()
246
267
  idx = i * 4 + 2
247
- stats[self.metrics.keys[idx + 1]], stats[
248
- self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
268
+ stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
269
+ :2
270
+ ] # update mAP50-95 and mAP50
249
271
  except Exception as e:
250
- LOGGER.warning(f'pycocotools unable to run: {e}')
272
+ LOGGER.warning(f"pycocotools unable to run: {e}")
251
273
  return stats