ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic; see the registry's page for this release for more details.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -12,28 +12,34 @@ class YOLO(Model):
12
12
  def task_map(self):
13
13
  """Map head to model, trainer, validator, and predictor classes."""
14
14
  return {
15
- 'classify': {
16
- 'model': ClassificationModel,
17
- 'trainer': yolo.classify.ClassificationTrainer,
18
- 'validator': yolo.classify.ClassificationValidator,
19
- 'predictor': yolo.classify.ClassificationPredictor, },
20
- 'detect': {
21
- 'model': DetectionModel,
22
- 'trainer': yolo.detect.DetectionTrainer,
23
- 'validator': yolo.detect.DetectionValidator,
24
- 'predictor': yolo.detect.DetectionPredictor, },
25
- 'segment': {
26
- 'model': SegmentationModel,
27
- 'trainer': yolo.segment.SegmentationTrainer,
28
- 'validator': yolo.segment.SegmentationValidator,
29
- 'predictor': yolo.segment.SegmentationPredictor, },
30
- 'pose': {
31
- 'model': PoseModel,
32
- 'trainer': yolo.pose.PoseTrainer,
33
- 'validator': yolo.pose.PoseValidator,
34
- 'predictor': yolo.pose.PosePredictor, },
35
- 'obb': {
36
- 'model': OBBModel,
37
- 'trainer': yolo.obb.OBBTrainer,
38
- 'validator': yolo.obb.OBBValidator,
39
- 'predictor': yolo.obb.OBBPredictor, }, }
15
+ "classify": {
16
+ "model": ClassificationModel,
17
+ "trainer": yolo.classify.ClassificationTrainer,
18
+ "validator": yolo.classify.ClassificationValidator,
19
+ "predictor": yolo.classify.ClassificationPredictor,
20
+ },
21
+ "detect": {
22
+ "model": DetectionModel,
23
+ "trainer": yolo.detect.DetectionTrainer,
24
+ "validator": yolo.detect.DetectionValidator,
25
+ "predictor": yolo.detect.DetectionPredictor,
26
+ },
27
+ "segment": {
28
+ "model": SegmentationModel,
29
+ "trainer": yolo.segment.SegmentationTrainer,
30
+ "validator": yolo.segment.SegmentationValidator,
31
+ "predictor": yolo.segment.SegmentationPredictor,
32
+ },
33
+ "pose": {
34
+ "model": PoseModel,
35
+ "trainer": yolo.pose.PoseTrainer,
36
+ "validator": yolo.pose.PoseValidator,
37
+ "predictor": yolo.pose.PosePredictor,
38
+ },
39
+ "obb": {
40
+ "model": OBBModel,
41
+ "trainer": yolo.obb.OBBTrainer,
42
+ "validator": yolo.obb.OBBValidator,
43
+ "predictor": yolo.obb.OBBPredictor,
44
+ },
45
+ }
@@ -4,4 +4,4 @@ from .predict import OBBPredictor
4
4
  from .train import OBBTrainer
5
5
  from .val import OBBValidator
6
6
 
7
- __all__ = 'OBBPredictor', 'OBBTrainer', 'OBBValidator'
7
+ __all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
@@ -23,27 +23,29 @@ class OBBPredictor(DetectionPredictor):
23
23
  """
24
24
 
25
25
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
26
+ """Initializes OBBPredictor with optional model and data configuration overrides."""
26
27
  super().__init__(cfg, overrides, _callbacks)
27
- self.args.task = 'obb'
28
+ self.args.task = "obb"
28
29
 
29
30
  def postprocess(self, preds, img, orig_imgs):
30
31
  """Post-processes predictions and returns a list of Results objects."""
31
- preds = ops.non_max_suppression(preds,
32
- self.args.conf,
33
- self.args.iou,
34
- agnostic=self.args.agnostic_nms,
35
- max_det=self.args.max_det,
36
- nc=len(self.model.names),
37
- classes=self.args.classes,
38
- rotated=True)
32
+ preds = ops.non_max_suppression(
33
+ preds,
34
+ self.args.conf,
35
+ self.args.iou,
36
+ agnostic=self.args.agnostic_nms,
37
+ max_det=self.args.max_det,
38
+ nc=len(self.model.names),
39
+ classes=self.args.classes,
40
+ rotated=True,
41
+ )
39
42
 
40
43
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
41
44
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
42
45
 
43
46
  results = []
44
- for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)):
47
+ for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])):
45
48
  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
46
- img_path = self.batch[0][i]
47
49
  # xywh, r, conf, cls
48
50
  obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
49
51
  results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
@@ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
25
25
  """Initialize a OBBTrainer object with given arguments."""
26
26
  if overrides is None:
27
27
  overrides = {}
28
- overrides['task'] = 'obb'
28
+ overrides["task"] = "obb"
29
29
  super().__init__(cfg, overrides, _callbacks)
30
30
 
31
31
  def get_model(self, cfg=None, weights=None, verbose=True):
32
32
  """Return OBBModel initialized with specified config and weights."""
33
- model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
33
+ model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
34
34
  if weights:
35
35
  model.load(weights)
36
36
 
@@ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
38
38
 
39
39
  def get_validator(self):
40
40
  """Return an instance of OBBValidator for validation of YOLO model."""
41
- self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
41
+ self.loss_names = "box_loss", "cls_loss", "dfl_loss"
42
42
  return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
@@ -27,26 +27,28 @@ class OBBValidator(DetectionValidator):
27
27
  def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
28
28
  """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
29
29
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
30
- self.args.task = 'obb'
30
+ self.args.task = "obb"
31
31
  self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)
32
32
 
33
33
  def init_metrics(self, model):
34
34
  """Initialize evaluation metrics for YOLO."""
35
35
  super().init_metrics(model)
36
- val = self.data.get(self.args.split, '') # validation path
37
- self.is_dota = isinstance(val, str) and 'DOTA' in val # is COCO
36
+ val = self.data.get(self.args.split, "") # validation path
37
+ self.is_dota = isinstance(val, str) and "DOTA" in val # is COCO
38
38
 
39
39
  def postprocess(self, preds):
40
40
  """Apply Non-maximum suppression to prediction outputs."""
41
- return ops.non_max_suppression(preds,
42
- self.args.conf,
43
- self.args.iou,
44
- labels=self.lb,
45
- nc=self.nc,
46
- multi_label=True,
47
- agnostic=self.args.single_cls,
48
- max_det=self.args.max_det,
49
- rotated=True)
41
+ return ops.non_max_suppression(
42
+ preds,
43
+ self.args.conf,
44
+ self.args.iou,
45
+ labels=self.lb,
46
+ nc=self.nc,
47
+ multi_label=True,
48
+ agnostic=self.args.single_cls,
49
+ max_det=self.args.max_det,
50
+ rotated=True,
51
+ )
50
52
 
51
53
  def _process_batch(self, detections, gt_bboxes, gt_cls):
52
54
  """
@@ -65,12 +67,13 @@ class OBBValidator(DetectionValidator):
65
67
  return self.match_predictions(detections[:, 5], gt_cls, iou)
66
68
 
67
69
  def _prepare_batch(self, si, batch):
68
- idx = batch['batch_idx'] == si
69
- cls = batch['cls'][idx].squeeze(-1)
70
- bbox = batch['bboxes'][idx]
71
- ori_shape = batch['ori_shape'][si]
72
- imgsz = batch['img'].shape[2:]
73
- ratio_pad = batch['ratio_pad'][si]
70
+ """Prepares and returns a batch for OBB validation."""
71
+ idx = batch["batch_idx"] == si
72
+ cls = batch["cls"][idx].squeeze(-1)
73
+ bbox = batch["bboxes"][idx]
74
+ ori_shape = batch["ori_shape"][si]
75
+ imgsz = batch["img"].shape[2:]
76
+ ratio_pad = batch["ratio_pad"][si]
74
77
  if len(cls):
75
78
  bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
76
79
  ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
@@ -78,19 +81,23 @@ class OBBValidator(DetectionValidator):
78
81
  return prepared_batch
79
82
 
80
83
  def _prepare_pred(self, pred, pbatch):
84
+ """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
81
85
  predn = pred.clone()
82
- ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'],
83
- xywh=True) # native-space pred
86
+ ops.scale_boxes(
87
+ pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
88
+ ) # native-space pred
84
89
  return predn
85
90
 
86
91
  def plot_predictions(self, batch, preds, ni):
87
92
  """Plots predicted bounding boxes on input images and saves the result."""
88
- plot_images(batch['img'],
89
- *output_to_rotated_target(preds, max_det=self.args.max_det),
90
- paths=batch['im_file'],
91
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
92
- names=self.names,
93
- on_plot=self.on_plot) # pred
93
+ plot_images(
94
+ batch["img"],
95
+ *output_to_rotated_target(preds, max_det=self.args.max_det),
96
+ paths=batch["im_file"],
97
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
98
+ names=self.names,
99
+ on_plot=self.on_plot,
100
+ ) # pred
94
101
 
95
102
  def pred_to_json(self, predn, filename):
96
103
  """Serialize YOLO predictions to COCO json format."""
@@ -99,12 +106,15 @@ class OBBValidator(DetectionValidator):
99
106
  rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
100
107
  poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
101
108
  for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
102
- self.jdict.append({
103
- 'image_id': image_id,
104
- 'category_id': self.class_map[int(predn[i, 5].item())],
105
- 'score': round(predn[i, 4].item(), 5),
106
- 'rbox': [round(x, 3) for x in r],
107
- 'poly': [round(x, 3) for x in b]})
109
+ self.jdict.append(
110
+ {
111
+ "image_id": image_id,
112
+ "category_id": self.class_map[int(predn[i, 5].item())],
113
+ "score": round(predn[i, 4].item(), 5),
114
+ "rbox": [round(x, 3) for x in r],
115
+ "poly": [round(x, 3) for x in b],
116
+ }
117
+ )
108
118
 
109
119
  def save_one_txt(self, predn, save_conf, shape, file):
110
120
  """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
@@ -114,8 +124,8 @@ class OBBValidator(DetectionValidator):
114
124
  xywha[:, :4] /= gn
115
125
  xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist() # normalized xywh
116
126
  line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format
117
- with open(file, 'a') as f:
118
- f.write(('%g ' * len(line)).rstrip() % line + '\n')
127
+ with open(file, "a") as f:
128
+ f.write(("%g " * len(line)).rstrip() % line + "\n")
119
129
 
120
130
  def eval_json(self, stats):
121
131
  """Evaluates YOLO output in JSON format and returns performance statistics."""
@@ -123,42 +133,43 @@ class OBBValidator(DetectionValidator):
123
133
  import json
124
134
  import re
125
135
  from collections import defaultdict
126
- pred_json = self.save_dir / 'predictions.json' # predictions
127
- pred_txt = self.save_dir / 'predictions_txt' # predictions
136
+
137
+ pred_json = self.save_dir / "predictions.json" # predictions
138
+ pred_txt = self.save_dir / "predictions_txt" # predictions
128
139
  pred_txt.mkdir(parents=True, exist_ok=True)
129
140
  data = json.load(open(pred_json))
130
141
  # Save split results
131
- LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...')
142
+ LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
132
143
  for d in data:
133
- image_id = d['image_id']
134
- score = d['score']
135
- classname = self.names[d['category_id']].replace(' ', '-')
144
+ image_id = d["image_id"]
145
+ score = d["score"]
146
+ classname = self.names[d["category_id"]].replace(" ", "-")
136
147
 
137
- lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
148
+ lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
138
149
  image_id,
139
150
  score,
140
- d['poly'][0],
141
- d['poly'][1],
142
- d['poly'][2],
143
- d['poly'][3],
144
- d['poly'][4],
145
- d['poly'][5],
146
- d['poly'][6],
147
- d['poly'][7],
151
+ d["poly"][0],
152
+ d["poly"][1],
153
+ d["poly"][2],
154
+ d["poly"][3],
155
+ d["poly"][4],
156
+ d["poly"][5],
157
+ d["poly"][6],
158
+ d["poly"][7],
148
159
  )
149
- with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f:
160
+ with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
150
161
  f.writelines(lines)
151
162
  # Save merged results, this could result slightly lower map than using official merging script,
152
163
  # because of the probiou calculation.
153
- pred_merged_txt = self.save_dir / 'predictions_merged_txt' # predictions
164
+ pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
154
165
  pred_merged_txt.mkdir(parents=True, exist_ok=True)
155
166
  merged_results = defaultdict(list)
156
- LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...')
167
+ LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...")
157
168
  for d in data:
158
- image_id = d['image_id'].split('__')[0]
159
- pattern = re.compile(r'\d+___\d+')
160
- x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___'))
161
- bbox, score, cls = d['rbox'], d['score'], d['category_id']
169
+ image_id = d["image_id"].split("__")[0]
170
+ pattern = re.compile(r"\d+___\d+")
171
+ x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
172
+ bbox, score, cls = d["rbox"], d["score"], d["category_id"]
162
173
  bbox[0] += x
163
174
  bbox[1] += y
164
175
  bbox.extend([score, cls])
@@ -176,11 +187,11 @@ class OBBValidator(DetectionValidator):
176
187
 
177
188
  b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
178
189
  for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
179
- classname = self.names[int(x[-1])].replace(' ', '-')
190
+ classname = self.names[int(x[-1])].replace(" ", "-")
180
191
  poly = [round(i, 3) for i in x[:-2]]
181
192
  score = round(x[-2], 3)
182
193
 
183
- lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
194
+ lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
184
195
  image_id,
185
196
  score,
186
197
  poly[0],
@@ -192,7 +203,7 @@ class OBBValidator(DetectionValidator):
192
203
  poly[6],
193
204
  poly[7],
194
205
  )
195
- with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f:
206
+ with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f:
196
207
  f.writelines(lines)
197
208
 
198
209
  return stats
@@ -4,4 +4,4 @@ from .predict import PosePredictor
4
4
  from .train import PoseTrainer
5
5
  from .val import PoseValidator
6
6
 
7
- __all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor'
7
+ __all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
@@ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor):
23
23
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
24
24
  """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
25
25
  super().__init__(cfg, overrides, _callbacks)
26
- self.args.task = 'pose'
27
- if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
28
- LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
29
- 'See https://github.com/ultralytics/ultralytics/issues/4031.')
26
+ self.args.task = "pose"
27
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
28
+ LOGGER.warning(
29
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
30
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
31
+ )
30
32
 
31
33
  def postprocess(self, preds, img, orig_imgs):
32
34
  """Return detection results for a given input image or list of images."""
33
- preds = ops.non_max_suppression(preds,
34
- self.args.conf,
35
- self.args.iou,
36
- agnostic=self.args.agnostic_nms,
37
- max_det=self.args.max_det,
38
- classes=self.args.classes,
39
- nc=len(self.model.names))
35
+ preds = ops.non_max_suppression(
36
+ preds,
37
+ self.args.conf,
38
+ self.args.iou,
39
+ agnostic=self.args.agnostic_nms,
40
+ max_det=self.args.max_det,
41
+ classes=self.args.classes,
42
+ nc=len(self.model.names),
43
+ )
40
44
 
41
45
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
42
46
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
@@ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor):
49
53
  pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
50
54
  img_path = self.batch[0][i]
51
55
  results.append(
52
- Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
56
+ Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
57
+ )
53
58
  return results
@@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
26
26
  """Initialize a PoseTrainer object with specified configurations and overrides."""
27
27
  if overrides is None:
28
28
  overrides = {}
29
- overrides['task'] = 'pose'
29
+ overrides["task"] = "pose"
30
30
  super().__init__(cfg, overrides, _callbacks)
31
31
 
32
- if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
33
- LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
34
- 'See https://github.com/ultralytics/ultralytics/issues/4031.')
32
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
33
+ LOGGER.warning(
34
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
35
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
36
+ )
35
37
 
36
38
  def get_model(self, cfg=None, weights=None, verbose=True):
37
39
  """Get pose estimation model with specified configuration and weights."""
38
- model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
40
+ model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
39
41
  if weights:
40
42
  model.load(weights)
41
43
 
@@ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
44
46
  def set_model_attributes(self):
45
47
  """Sets keypoints shape attribute of PoseModel."""
46
48
  super().set_model_attributes()
47
- self.model.kpt_shape = self.data['kpt_shape']
49
+ self.model.kpt_shape = self.data["kpt_shape"]
48
50
 
49
51
  def get_validator(self):
50
52
  """Returns an instance of the PoseValidator class for validation."""
51
- self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
52
- return yolo.pose.PoseValidator(self.test_loader,
53
- save_dir=self.save_dir,
54
- args=copy(self.args),
55
- _callbacks=self.callbacks)
53
+ self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
54
+ return yolo.pose.PoseValidator(
55
+ self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
56
+ )
56
57
 
57
58
  def plot_training_samples(self, batch, ni):
58
59
  """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
59
- images = batch['img']
60
- kpts = batch['keypoints']
61
- cls = batch['cls'].squeeze(-1)
62
- bboxes = batch['bboxes']
63
- paths = batch['im_file']
64
- batch_idx = batch['batch_idx']
65
- plot_images(images,
66
- batch_idx,
67
- cls,
68
- bboxes,
69
- kpts=kpts,
70
- paths=paths,
71
- fname=self.save_dir / f'train_batch{ni}.jpg',
72
- on_plot=self.on_plot)
60
+ images = batch["img"]
61
+ kpts = batch["keypoints"]
62
+ cls = batch["cls"].squeeze(-1)
63
+ bboxes = batch["bboxes"]
64
+ paths = batch["im_file"]
65
+ batch_idx = batch["batch_idx"]
66
+ plot_images(
67
+ images,
68
+ batch_idx,
69
+ cls,
70
+ bboxes,
71
+ kpts=kpts,
72
+ paths=paths,
73
+ fname=self.save_dir / f"train_batch{ni}.jpg",
74
+ on_plot=self.on_plot,
75
+ )
73
76
 
74
77
  def plot_metrics(self):
75
78
  """Plots training/val metrics."""