ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. (The original page linked to more details; that link was not preserved.)

Files changed (137):
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -12,28 +12,34 @@ class YOLO(Model):
12
12
  def task_map(self):
13
13
  """Map head to model, trainer, validator, and predictor classes."""
14
14
  return {
15
- 'classify': {
16
- 'model': ClassificationModel,
17
- 'trainer': yolo.classify.ClassificationTrainer,
18
- 'validator': yolo.classify.ClassificationValidator,
19
- 'predictor': yolo.classify.ClassificationPredictor, },
20
- 'detect': {
21
- 'model': DetectionModel,
22
- 'trainer': yolo.detect.DetectionTrainer,
23
- 'validator': yolo.detect.DetectionValidator,
24
- 'predictor': yolo.detect.DetectionPredictor, },
25
- 'segment': {
26
- 'model': SegmentationModel,
27
- 'trainer': yolo.segment.SegmentationTrainer,
28
- 'validator': yolo.segment.SegmentationValidator,
29
- 'predictor': yolo.segment.SegmentationPredictor, },
30
- 'pose': {
31
- 'model': PoseModel,
32
- 'trainer': yolo.pose.PoseTrainer,
33
- 'validator': yolo.pose.PoseValidator,
34
- 'predictor': yolo.pose.PosePredictor, },
35
- 'obb': {
36
- 'model': OBBModel,
37
- 'trainer': yolo.obb.OBBTrainer,
38
- 'validator': yolo.obb.OBBValidator,
39
- 'predictor': yolo.obb.OBBPredictor, }, }
15
+ "classify": {
16
+ "model": ClassificationModel,
17
+ "trainer": yolo.classify.ClassificationTrainer,
18
+ "validator": yolo.classify.ClassificationValidator,
19
+ "predictor": yolo.classify.ClassificationPredictor,
20
+ },
21
+ "detect": {
22
+ "model": DetectionModel,
23
+ "trainer": yolo.detect.DetectionTrainer,
24
+ "validator": yolo.detect.DetectionValidator,
25
+ "predictor": yolo.detect.DetectionPredictor,
26
+ },
27
+ "segment": {
28
+ "model": SegmentationModel,
29
+ "trainer": yolo.segment.SegmentationTrainer,
30
+ "validator": yolo.segment.SegmentationValidator,
31
+ "predictor": yolo.segment.SegmentationPredictor,
32
+ },
33
+ "pose": {
34
+ "model": PoseModel,
35
+ "trainer": yolo.pose.PoseTrainer,
36
+ "validator": yolo.pose.PoseValidator,
37
+ "predictor": yolo.pose.PosePredictor,
38
+ },
39
+ "obb": {
40
+ "model": OBBModel,
41
+ "trainer": yolo.obb.OBBTrainer,
42
+ "validator": yolo.obb.OBBValidator,
43
+ "predictor": yolo.obb.OBBPredictor,
44
+ },
45
+ }
@@ -4,4 +4,4 @@ from .predict import OBBPredictor
4
4
  from .train import OBBTrainer
5
5
  from .val import OBBValidator
6
6
 
7
- __all__ = 'OBBPredictor', 'OBBTrainer', 'OBBValidator'
7
+ __all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
@@ -23,28 +23,29 @@ class OBBPredictor(DetectionPredictor):
23
23
  """
24
24
 
25
25
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
26
+ """Initializes OBBPredictor with optional model and data configuration overrides."""
26
27
  super().__init__(cfg, overrides, _callbacks)
27
- self.args.task = 'obb'
28
+ self.args.task = "obb"
28
29
 
29
30
  def postprocess(self, preds, img, orig_imgs):
30
31
  """Post-processes predictions and returns a list of Results objects."""
31
- preds = ops.non_max_suppression(preds,
32
- self.args.conf,
33
- self.args.iou,
34
- agnostic=self.args.agnostic_nms,
35
- max_det=self.args.max_det,
36
- nc=len(self.model.names),
37
- classes=self.args.classes,
38
- rotated=True)
32
+ preds = ops.non_max_suppression(
33
+ preds,
34
+ self.args.conf,
35
+ self.args.iou,
36
+ agnostic=self.args.agnostic_nms,
37
+ max_det=self.args.max_det,
38
+ nc=len(self.model.names),
39
+ classes=self.args.classes,
40
+ rotated=True,
41
+ )
39
42
 
40
43
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
41
44
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
42
45
 
43
46
  results = []
44
- for i, pred in enumerate(preds):
45
- orig_img = orig_imgs[i]
47
+ for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])):
46
48
  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
47
- img_path = self.batch[0][i]
48
49
  # xywh, r, conf, cls
49
50
  obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
50
51
  results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
@@ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
25
25
  """Initialize a OBBTrainer object with given arguments."""
26
26
  if overrides is None:
27
27
  overrides = {}
28
- overrides['task'] = 'obb'
28
+ overrides["task"] = "obb"
29
29
  super().__init__(cfg, overrides, _callbacks)
30
30
 
31
31
  def get_model(self, cfg=None, weights=None, verbose=True):
32
32
  """Return OBBModel initialized with specified config and weights."""
33
- model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
33
+ model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
34
34
  if weights:
35
35
  model.load(weights)
36
36
 
@@ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
38
38
 
39
39
  def get_validator(self):
40
40
  """Return an instance of OBBValidator for validation of YOLO model."""
41
- self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
41
+ self.loss_names = "box_loss", "cls_loss", "dfl_loss"
42
42
  return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
@@ -27,26 +27,28 @@ class OBBValidator(DetectionValidator):
27
27
  def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
28
28
  """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
29
29
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
30
- self.args.task = 'obb'
30
+ self.args.task = "obb"
31
31
  self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)
32
32
 
33
33
  def init_metrics(self, model):
34
34
  """Initialize evaluation metrics for YOLO."""
35
35
  super().init_metrics(model)
36
- val = self.data.get(self.args.split, '') # validation path
37
- self.is_dota = isinstance(val, str) and 'DOTA' in val # is COCO
36
+ val = self.data.get(self.args.split, "") # validation path
37
+ self.is_dota = isinstance(val, str) and "DOTA" in val # is COCO
38
38
 
39
39
  def postprocess(self, preds):
40
40
  """Apply Non-maximum suppression to prediction outputs."""
41
- return ops.non_max_suppression(preds,
42
- self.args.conf,
43
- self.args.iou,
44
- labels=self.lb,
45
- nc=self.nc,
46
- multi_label=True,
47
- agnostic=self.args.single_cls,
48
- max_det=self.args.max_det,
49
- rotated=True)
41
+ return ops.non_max_suppression(
42
+ preds,
43
+ self.args.conf,
44
+ self.args.iou,
45
+ labels=self.lb,
46
+ nc=self.nc,
47
+ multi_label=True,
48
+ agnostic=self.args.single_cls,
49
+ max_det=self.args.max_det,
50
+ rotated=True,
51
+ )
50
52
 
51
53
  def _process_batch(self, detections, gt_bboxes, gt_cls):
52
54
  """
@@ -61,16 +63,17 @@ class OBBValidator(DetectionValidator):
61
63
  Returns:
62
64
  (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
63
65
  """
64
- iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -2:-1]], dim=-1))
66
+ iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
65
67
  return self.match_predictions(detections[:, 5], gt_cls, iou)
66
68
 
67
69
  def _prepare_batch(self, si, batch):
68
- idx = batch['batch_idx'] == si
69
- cls = batch['cls'][idx].squeeze(-1)
70
- bbox = batch['bboxes'][idx]
71
- ori_shape = batch['ori_shape'][si]
72
- imgsz = batch['img'].shape[2:]
73
- ratio_pad = batch['ratio_pad'][si]
70
+ """Prepares and returns a batch for OBB validation."""
71
+ idx = batch["batch_idx"] == si
72
+ cls = batch["cls"][idx].squeeze(-1)
73
+ bbox = batch["bboxes"][idx]
74
+ ori_shape = batch["ori_shape"][si]
75
+ imgsz = batch["img"].shape[2:]
76
+ ratio_pad = batch["ratio_pad"][si]
74
77
  if len(cls):
75
78
  bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
76
79
  ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
@@ -78,19 +81,23 @@ class OBBValidator(DetectionValidator):
78
81
  return prepared_batch
79
82
 
80
83
  def _prepare_pred(self, pred, pbatch):
84
+ """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
81
85
  predn = pred.clone()
82
- ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'],
83
- xywh=True) # native-space pred
86
+ ops.scale_boxes(
87
+ pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
88
+ ) # native-space pred
84
89
  return predn
85
90
 
86
91
  def plot_predictions(self, batch, preds, ni):
87
92
  """Plots predicted bounding boxes on input images and saves the result."""
88
- plot_images(batch['img'],
89
- *output_to_rotated_target(preds, max_det=self.args.max_det),
90
- paths=batch['im_file'],
91
- fname=self.save_dir / f'val_batch{ni}_pred.jpg',
92
- names=self.names,
93
- on_plot=self.on_plot) # pred
93
+ plot_images(
94
+ batch["img"],
95
+ *output_to_rotated_target(preds, max_det=self.args.max_det),
96
+ paths=batch["im_file"],
97
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
98
+ names=self.names,
99
+ on_plot=self.on_plot,
100
+ ) # pred
94
101
 
95
102
  def pred_to_json(self, predn, filename):
96
103
  """Serialize YOLO predictions to COCO json format."""
@@ -99,12 +106,26 @@ class OBBValidator(DetectionValidator):
99
106
  rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
100
107
  poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
101
108
  for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
102
- self.jdict.append({
103
- 'image_id': image_id,
104
- 'category_id': self.class_map[int(predn[i, 5].item())],
105
- 'score': round(predn[i, 4].item(), 5),
106
- 'rbox': [round(x, 3) for x in r],
107
- 'poly': [round(x, 3) for x in b]})
109
+ self.jdict.append(
110
+ {
111
+ "image_id": image_id,
112
+ "category_id": self.class_map[int(predn[i, 5].item())],
113
+ "score": round(predn[i, 4].item(), 5),
114
+ "rbox": [round(x, 3) for x in r],
115
+ "poly": [round(x, 3) for x in b],
116
+ }
117
+ )
118
+
119
+ def save_one_txt(self, predn, save_conf, shape, file):
120
+ """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
121
+ gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
122
+ for *xyxy, conf, cls, angle in predn.tolist():
123
+ xywha = torch.tensor([*xyxy, angle]).view(1, 5)
124
+ xywha[:, :4] /= gn
125
+ xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist() # normalized xywh
126
+ line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format
127
+ with open(file, "a") as f:
128
+ f.write(("%g " * len(line)).rstrip() % line + "\n")
108
129
 
109
130
  def eval_json(self, stats):
110
131
  """Evaluates YOLO output in JSON format and returns performance statistics."""
@@ -112,42 +133,43 @@ class OBBValidator(DetectionValidator):
112
133
  import json
113
134
  import re
114
135
  from collections import defaultdict
115
- pred_json = self.save_dir / 'predictions.json' # predictions
116
- pred_txt = self.save_dir / 'predictions_txt' # predictions
136
+
137
+ pred_json = self.save_dir / "predictions.json" # predictions
138
+ pred_txt = self.save_dir / "predictions_txt" # predictions
117
139
  pred_txt.mkdir(parents=True, exist_ok=True)
118
140
  data = json.load(open(pred_json))
119
141
  # Save split results
120
- LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...')
142
+ LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
121
143
  for d in data:
122
- image_id = d['image_id']
123
- score = d['score']
124
- classname = self.names[d['category_id']].replace(' ', '-')
144
+ image_id = d["image_id"]
145
+ score = d["score"]
146
+ classname = self.names[d["category_id"]].replace(" ", "-")
125
147
 
126
- lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
148
+ lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
127
149
  image_id,
128
150
  score,
129
- d['poly'][0],
130
- d['poly'][1],
131
- d['poly'][2],
132
- d['poly'][3],
133
- d['poly'][4],
134
- d['poly'][5],
135
- d['poly'][6],
136
- d['poly'][7],
151
+ d["poly"][0],
152
+ d["poly"][1],
153
+ d["poly"][2],
154
+ d["poly"][3],
155
+ d["poly"][4],
156
+ d["poly"][5],
157
+ d["poly"][6],
158
+ d["poly"][7],
137
159
  )
138
- with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f:
160
+ with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
139
161
  f.writelines(lines)
140
162
  # Save merged results, this could result slightly lower map than using official merging script,
141
163
  # because of the probiou calculation.
142
- pred_merged_txt = self.save_dir / 'predictions_merged_txt' # predictions
164
+ pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
143
165
  pred_merged_txt.mkdir(parents=True, exist_ok=True)
144
166
  merged_results = defaultdict(list)
145
- LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...')
167
+ LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...")
146
168
  for d in data:
147
- image_id = d['image_id'].split('__')[0]
148
- pattern = re.compile(r'\d+___\d+')
149
- x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___'))
150
- bbox, score, cls = d['rbox'], d['score'], d['category_id']
169
+ image_id = d["image_id"].split("__")[0]
170
+ pattern = re.compile(r"\d+___\d+")
171
+ x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
172
+ bbox, score, cls = d["rbox"], d["score"], d["category_id"]
151
173
  bbox[0] += x
152
174
  bbox[1] += y
153
175
  bbox.extend([score, cls])
@@ -165,11 +187,11 @@ class OBBValidator(DetectionValidator):
165
187
 
166
188
  b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
167
189
  for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
168
- classname = self.names[int(x[-1])].replace(' ', '-')
190
+ classname = self.names[int(x[-1])].replace(" ", "-")
169
191
  poly = [round(i, 3) for i in x[:-2]]
170
192
  score = round(x[-2], 3)
171
193
 
172
- lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
194
+ lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
173
195
  image_id,
174
196
  score,
175
197
  poly[0],
@@ -181,7 +203,7 @@ class OBBValidator(DetectionValidator):
181
203
  poly[6],
182
204
  poly[7],
183
205
  )
184
- with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f:
206
+ with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f:
185
207
  f.writelines(lines)
186
208
 
187
209
  return stats
@@ -4,4 +4,4 @@ from .predict import PosePredictor
4
4
  from .train import PoseTrainer
5
5
  from .val import PoseValidator
6
6
 
7
- __all__ = 'PoseTrainer', 'PoseValidator', 'PosePredictor'
7
+ __all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
@@ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor):
23
23
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
24
24
  """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
25
25
  super().__init__(cfg, overrides, _callbacks)
26
- self.args.task = 'pose'
27
- if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
28
- LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
29
- 'See https://github.com/ultralytics/ultralytics/issues/4031.')
26
+ self.args.task = "pose"
27
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
28
+ LOGGER.warning(
29
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
30
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
31
+ )
30
32
 
31
33
  def postprocess(self, preds, img, orig_imgs):
32
34
  """Return detection results for a given input image or list of images."""
33
- preds = ops.non_max_suppression(preds,
34
- self.args.conf,
35
- self.args.iou,
36
- agnostic=self.args.agnostic_nms,
37
- max_det=self.args.max_det,
38
- classes=self.args.classes,
39
- nc=len(self.model.names))
35
+ preds = ops.non_max_suppression(
36
+ preds,
37
+ self.args.conf,
38
+ self.args.iou,
39
+ agnostic=self.args.agnostic_nms,
40
+ max_det=self.args.max_det,
41
+ classes=self.args.classes,
42
+ nc=len(self.model.names),
43
+ )
40
44
 
41
45
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
42
46
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
@@ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor):
49
53
  pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
50
54
  img_path = self.batch[0][i]
51
55
  results.append(
52
- Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
56
+ Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
57
+ )
53
58
  return results
@@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
26
26
  """Initialize a PoseTrainer object with specified configurations and overrides."""
27
27
  if overrides is None:
28
28
  overrides = {}
29
- overrides['task'] = 'pose'
29
+ overrides["task"] = "pose"
30
30
  super().__init__(cfg, overrides, _callbacks)
31
31
 
32
- if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
33
- LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
34
- 'See https://github.com/ultralytics/ultralytics/issues/4031.')
32
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
33
+ LOGGER.warning(
34
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
35
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
36
+ )
35
37
 
36
38
  def get_model(self, cfg=None, weights=None, verbose=True):
37
39
  """Get pose estimation model with specified configuration and weights."""
38
- model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
40
+ model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
39
41
  if weights:
40
42
  model.load(weights)
41
43
 
@@ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
44
46
  def set_model_attributes(self):
45
47
  """Sets keypoints shape attribute of PoseModel."""
46
48
  super().set_model_attributes()
47
- self.model.kpt_shape = self.data['kpt_shape']
49
+ self.model.kpt_shape = self.data["kpt_shape"]
48
50
 
49
51
  def get_validator(self):
50
52
  """Returns an instance of the PoseValidator class for validation."""
51
- self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
52
- return yolo.pose.PoseValidator(self.test_loader,
53
- save_dir=self.save_dir,
54
- args=copy(self.args),
55
- _callbacks=self.callbacks)
53
+ self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
54
+ return yolo.pose.PoseValidator(
55
+ self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
56
+ )
56
57
 
57
58
  def plot_training_samples(self, batch, ni):
58
59
  """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
59
- images = batch['img']
60
- kpts = batch['keypoints']
61
- cls = batch['cls'].squeeze(-1)
62
- bboxes = batch['bboxes']
63
- paths = batch['im_file']
64
- batch_idx = batch['batch_idx']
65
- plot_images(images,
66
- batch_idx,
67
- cls,
68
- bboxes,
69
- kpts=kpts,
70
- paths=paths,
71
- fname=self.save_dir / f'train_batch{ni}.jpg',
72
- on_plot=self.on_plot)
60
+ images = batch["img"]
61
+ kpts = batch["keypoints"]
62
+ cls = batch["cls"].squeeze(-1)
63
+ bboxes = batch["bboxes"]
64
+ paths = batch["im_file"]
65
+ batch_idx = batch["batch_idx"]
66
+ plot_images(
67
+ images,
68
+ batch_idx,
69
+ cls,
70
+ bboxes,
71
+ kpts=kpts,
72
+ paths=paths,
73
+ fname=self.save_dir / f"train_batch{ni}.jpg",
74
+ on_plot=self.on_plot,
75
+ )
73
76
 
74
77
  def plot_metrics(self):
75
78
  """Plots training/val metrics."""