ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/data/dataset.py

@@ -18,7 +18,7 @@ from .base import BaseDataset
 from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label

 # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
-DATASET_CACHE_VERSION = '1.0.3'
+DATASET_CACHE_VERSION = "1.0.3"


 class YOLODataset(BaseDataset):
@@ -33,16 +33,16 @@ class YOLODataset(BaseDataset):
         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
     """

-    def __init__(self, *args, data=None, task='detect', **kwargs):
+    def __init__(self, *args, data=None, task="detect", **kwargs):
         """Initializes the YOLODataset with optional configurations for segments and keypoints."""
-        self.use_segments = task == 'segment'
-        self.use_keypoints = task == 'pose'
-        self.use_obb = task == 'obb'
+        self.use_segments = task == "segment"
+        self.use_keypoints = task == "pose"
+        self.use_obb = task == "obb"
         self.data = data
-        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
+        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
         super().__init__(*args, **kwargs)

-    def cache_labels(self, path=Path('./labels.cache')):
+    def cache_labels(self, path=Path("./labels.cache")):
         """
         Cache dataset labels, check images and read shapes.

@@ -51,19 +51,29 @@ class YOLODataset(BaseDataset):
         Returns:
             (dict): labels.
         """
-        x = {'labels': []}
+        x = {"labels": []}
         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
-        desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
+        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
         total = len(self.im_files)
-        nkpt, ndim = self.data.get('kpt_shape', (0, 0))
+        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
         if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
-            raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
-                             "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
+            raise ValueError(
+                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
+                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
+            )
         with ThreadPool(NUM_THREADS) as pool:
-            results = pool.imap(func=verify_image_label,
-                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
-                                             repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
-                                             repeat(ndim)))
+            results = pool.imap(
+                func=verify_image_label,
+                iterable=zip(
+                    self.im_files,
+                    self.label_files,
+                    repeat(self.prefix),
+                    repeat(self.use_keypoints),
+                    repeat(len(self.data["names"])),
+                    repeat(nkpt),
+                    repeat(ndim),
+                ),
+            )
             pbar = TQDM(results, desc=desc, total=total)
             for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                 nm += nm_f
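
For readers unfamiliar with the multiprocessing idiom in cache_labels above, here is a minimal standalone sketch (verify and the file names are stand-ins, not package code): itertools.repeat supplies the constant arguments so zip can pair them with each file, and ThreadPool.imap streams results back in order.

    from itertools import repeat
    from multiprocessing.pool import ThreadPool

    def verify(args):
        """Toy stand-in for verify_image_label: unpacks one zipped tuple."""
        im_file, prefix = args
        return f"{prefix}{im_file}: ok"

    im_files = ["a.jpg", "b.jpg", "c.jpg"]
    with ThreadPool(4) as pool:
        for msg in pool.imap(func=verify, iterable=zip(im_files, repeat("train: "))):
            print(msg)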
@@ -71,7 +81,7 @@ class YOLODataset(BaseDataset):
                 ne += ne_f
                 nc += nc_f
                 if im_file:
-                    x['labels'].append(
+                    x["labels"].append(
                         dict(
                             im_file=im_file,
                             shape=shape,
@@ -80,60 +90,63 @@ class YOLODataset(BaseDataset):
                             segments=segments,
                             keypoints=keypoint,
                             normalized=True,
-                            bbox_format='xywh'))
+                            bbox_format="xywh",
+                        )
+                    )
                 if msg:
                     msgs.append(msg)
-                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
+                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
             pbar.close()

         if msgs:
-            LOGGER.info('\n'.join(msgs))
+            LOGGER.info("\n".join(msgs))
         if nf == 0:
-            LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
-        x['hash'] = get_hash(self.label_files + self.im_files)
-        x['results'] = nf, nm, ne, nc, len(self.im_files)
-        x['msgs'] = msgs  # warnings
+            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
+        x["hash"] = get_hash(self.label_files + self.im_files)
+        x["results"] = nf, nm, ne, nc, len(self.im_files)
+        x["msgs"] = msgs  # warnings
         save_dataset_cache_file(self.prefix, path, x)
         return x

     def get_labels(self):
         """Returns dictionary of labels for YOLO training."""
         self.label_files = img2label_paths(self.im_files)
-        cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
+        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
         try:
             cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
-            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
-            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
         except (FileNotFoundError, AssertionError, AttributeError):
             cache, exists = self.cache_labels(cache_path), False  # run cache ops

         # Display cache
-        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
+        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
         if exists and LOCAL_RANK in (-1, 0):
-            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
+            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
             TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
-            if cache['msgs']:
-                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+            if cache["msgs"]:
+                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

         # Read cache
-        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
-        labels = cache['labels']
+        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
+        labels = cache["labels"]
         if not labels:
-            LOGGER.warning(f'WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}')
-        self.im_files = [lb['im_file'] for lb in labels]  # update im_files
+            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
+        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

         # Check if the dataset is all boxes or all segments
-        lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
+        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
         len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
         if len_segments and len_boxes != len_segments:
             LOGGER.warning(
-                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
-                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
-                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
+                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
+                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
+                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
+            )
             for lb in labels:
-                lb['segments'] = []
+                lb["segments"] = []
         if len_cls == 0:
-            LOGGER.warning(f'WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}')
+            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
         return labels

     def build_transforms(self, hyp=None):
@@ -145,14 +158,17 @@ class YOLODataset(BaseDataset):
         else:
             transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
         transforms.append(
-            Format(bbox_format='xywh',
-                   normalize=True,
-                   return_mask=self.use_segments,
-                   return_keypoint=self.use_keypoints,
-                   return_obb=self.use_obb,
-                   batch_idx=True,
-                   mask_ratio=hyp.mask_ratio,
-                   mask_overlap=hyp.overlap_mask))
+            Format(
+                bbox_format="xywh",
+                normalize=True,
+                return_mask=self.use_segments,
+                return_keypoint=self.use_keypoints,
+                return_obb=self.use_obb,
+                batch_idx=True,
+                mask_ratio=hyp.mask_ratio,
+                mask_overlap=hyp.overlap_mask,
+            )
+        )
         return transforms

     def close_mosaic(self, hyp):
@@ -166,11 +182,11 @@ class YOLODataset(BaseDataset):
         """Custom your label format here."""
         # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
         # We can make it also support classification and semantic segmentation by add or remove some dict keys there.
-        bboxes = label.pop('bboxes')
-        segments = label.pop('segments', [])
-        keypoints = label.pop('keypoints', None)
-        bbox_format = label.pop('bbox_format')
-        normalized = label.pop('normalized')
+        bboxes = label.pop("bboxes")
+        segments = label.pop("segments", [])
+        keypoints = label.pop("keypoints", None)
+        bbox_format = label.pop("bbox_format")
+        normalized = label.pop("normalized")

         # NOTE: do NOT resample oriented boxes
         segment_resamples = 100 if self.use_obb else 1000
@@ -180,7 +196,7 @@ class YOLODataset(BaseDataset):
             segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
         else:
             segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
-        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
+        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
         return label

     @staticmethod
@@ -191,15 +207,15 @@ class YOLODataset(BaseDataset):
         values = list(zip(*[list(b.values()) for b in batch]))
         for i, k in enumerate(keys):
             value = values[i]
-            if k == 'img':
+            if k == "img":
                 value = torch.stack(value, 0)
-            if k in ['masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb']:
+            if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
                 value = torch.cat(value, 0)
             new_batch[k] = value
-        new_batch['batch_idx'] = list(new_batch['batch_idx'])
-        for i in range(len(new_batch['batch_idx'])):
-            new_batch['batch_idx'][i] += i  # add target image index for build_targets()
-        new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
+        new_batch["batch_idx"] = list(new_batch["batch_idx"])
+        for i in range(len(new_batch["batch_idx"])):
+            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
+        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
         return new_batch

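The batch_idx arithmetic at the end of collate_fn is easy to miss: Format(batch_idx=True) gives every sample a zero tensor with one entry per label, and offsetting each tensor by the sample's position before concatenation lets downstream target-building code map every label row back to its source image. A standalone sketch with made-up label counts:

    import torch

    # image 0 contributes 2 labels, image 1 contributes 3 labels
    batch_idx = [torch.zeros(2), torch.zeros(3)]
    for i in range(len(batch_idx)):
        batch_idx[i] += i  # add target image index, as in collate_fn above
    print(torch.cat(batch_idx, 0))  # tensor([0., 0., 1., 1., 1.])
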
@@ -219,7 +235,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
     """

-    def __init__(self, root, args, augment=False, cache=False, prefix=''):
+    def __init__(self, root, args, augment=False, cache=False, prefix=""):
         """
         Initialize YOLO object with root, image size, augmentations, and cache settings.

@@ -231,23 +247,28 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         """
         super().__init__(root=root)
         if augment and args.fraction < 1.0:  # reduce training fraction
-            self.samples = self.samples[:round(len(self.samples) * args.fraction)]
-        self.prefix = colorstr(f'{prefix}: ') if prefix else ''
-        self.cache_ram = cache is True or cache == 'ram'
-        self.cache_disk = cache == 'disk'
+            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
+        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
+        self.cache_ram = cache is True or cache == "ram"
+        self.cache_disk = cache == "disk"
         self.samples = self.verify_images()  # filter out bad images
-        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
+        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
         scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
-        self.torch_transforms = classify_augmentations(size=args.imgsz,
-                                                       scale=scale,
-                                                       hflip=args.fliplr,
-                                                       vflip=args.flipud,
-                                                       erasing=args.erasing,
-                                                       auto_augment=args.auto_augment,
-                                                       hsv_h=args.hsv_h,
-                                                       hsv_s=args.hsv_s,
-                                                       hsv_v=args.hsv_v) if augment else classify_transforms(
-            size=args.imgsz, crop_fraction=args.crop_fraction)
+        self.torch_transforms = (
+            classify_augmentations(
+                size=args.imgsz,
+                scale=scale,
+                hflip=args.fliplr,
+                vflip=args.flipud,
+                erasing=args.erasing,
+                auto_augment=args.auto_augment,
+                hsv_h=args.hsv_h,
+                hsv_s=args.hsv_s,
+                hsv_v=args.hsv_v,
+            )
+            if augment
+            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+        )

     def __getitem__(self, i):
         """Returns subset of data and targets corresponding to given indices."""
@@ -263,7 +284,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         # Convert NumPy array to PIL image
         im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
         sample = self.torch_transforms(im)
-        return {'img': sample, 'cls': j}
+        return {"img": sample, "cls": j}

     def __len__(self) -> int:
         """Return the total number of samples in the dataset."""
@@ -271,19 +292,19 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):

     def verify_images(self):
         """Verify all images in dataset."""
-        desc = f'{self.prefix}Scanning {self.root}...'
-        path = Path(self.root).with_suffix('.cache')  # *.cache file path
+        desc = f"{self.prefix}Scanning {self.root}..."
+        path = Path(self.root).with_suffix(".cache")  # *.cache file path

         with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
             cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
-            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
-            assert cache['hash'] == get_hash([x[0] for x in self.samples])  # identical hash
-            nf, nc, n, samples = cache.pop('results')  # found, missing, empty, corrupt, total
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
+            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
             if LOCAL_RANK in (-1, 0):
-                d = f'{desc} {nf} images, {nc} corrupt'
+                d = f"{desc} {nf} images, {nc} corrupt"
                 TQDM(None, desc=d, total=n, initial=n)
-                if cache['msgs']:
-                    LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+                if cache["msgs"]:
+                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
             return samples

         # Run scan if *.cache retrieval failed
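
verify_images leans on contextlib.suppress for its cache fast path: if loading or validating the cache raises any of the listed exceptions, the block is abandoned silently and control falls through to the rescan below. A standalone sketch of the idiom (load_cache and rescan are hypothetical stand-ins, not package code):

    import contextlib

    def load_cache(path):
        """Hypothetical loader: raises FileNotFoundError when no cache exists."""
        raise FileNotFoundError(path)

    def rescan(path):
        """Hypothetical slow path: rebuild the sample list from scratch."""
        return ["a.jpg", "b.jpg"]

    def load_or_scan(path):
        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
            cache = load_cache(path)            # may raise FileNotFoundError
            assert cache["version"] == "1.0.3"  # stale cache raises AssertionError
            return cache["samples"]             # fast path: valid, current cache
        return rescan(path)                     # reached only if the block above raised

    print(load_or_scan("labels.cache"))  # -> ['a.jpg', 'b.jpg']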
@@ -298,13 +319,13 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
                 msgs.append(msg)
                 nf += nf_f
                 nc += nc_f
-                pbar.desc = f'{desc} {nf} images, {nc} corrupt'
+                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
             pbar.close()
         if msgs:
-            LOGGER.info('\n'.join(msgs))
-        x['hash'] = get_hash([x[0] for x in self.samples])
-        x['results'] = nf, nc, len(samples), samples
-        x['msgs'] = msgs  # warnings
+            LOGGER.info("\n".join(msgs))
+        x["hash"] = get_hash([x[0] for x in self.samples])
+        x["results"] = nf, nc, len(samples), samples
+        x["msgs"] = msgs  # warnings
         save_dataset_cache_file(self.prefix, path, x)
         return samples

@@ -312,6 +333,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
 def load_dataset_cache_file(path):
     """Load an Ultralytics *.cache dictionary from path."""
     import gc
+
     gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
     cache = np.load(str(path), allow_pickle=True).item()  # load dict
     gc.enable()
@@ -320,15 +342,15 @@ def load_dataset_cache_file(path):

 def save_dataset_cache_file(prefix, path, x):
     """Save an Ultralytics dataset *.cache dictionary x to path."""
-    x['version'] = DATASET_CACHE_VERSION  # add cache version
+    x["version"] = DATASET_CACHE_VERSION  # add cache version
     if is_dir_writeable(path.parent):
         if path.exists():
             path.unlink()  # remove *.cache file if exists
         np.save(str(path), x)  # save cache for next time
-        path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
-        LOGGER.info(f'{prefix}New cache created: {path}')
+        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
+        LOGGER.info(f"{prefix}New cache created: {path}")
     else:
-        LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")


 # TODO: support semantic segmentation
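
The rename in save_dataset_cache_file exists because np.save always appends a .npy suffix: the file lands on disk as <name>.cache.npy and is then renamed back to the intended <name>.cache. A minimal standalone sketch of the round trip (file name illustrative, not package code):

    from pathlib import Path
    import numpy as np

    path = Path("labels.cache")
    np.save(str(path), {"version": "1.0.3", "labels": []})  # actually writes labels.cache.npy
    path.with_suffix(".cache.npy").rename(path)             # strip the forced .npy suffix
    cache = np.load(str(path), allow_pickle=True).item()    # load the dict back
    print(cache["version"])  # -> 1.0.3
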
ultralytics/data/explorer/__init__.py

@@ -1,3 +1,5 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
 from .utils import plot_query_result

-__all__ = ['plot_query_result']
+__all__ = ["plot_query_result"]