dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,32 +1,36 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import math
4
6
  import random
5
7
  from copy import copy
8
+ from typing import Any
6
9
 
7
10
  import numpy as np
11
+ import torch
8
12
  import torch.nn as nn
9
13
 
10
14
  from ultralytics.data import build_dataloader, build_yolo_dataset
11
15
  from ultralytics.engine.trainer import BaseTrainer
12
16
  from ultralytics.models import yolo
13
17
  from ultralytics.nn.tasks import DetectionModel
14
- from ultralytics.utils import LOGGER, RANK
15
- from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
16
- from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
18
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
19
+ from ultralytics.utils.patches import override_configs
20
+ from ultralytics.utils.plotting import plot_images, plot_labels
21
+ from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
17
22
 
18
23
 
19
24
  class DetectionTrainer(BaseTrainer):
20
- """
21
- A class extending the BaseTrainer class for training based on a detection model.
25
+ """A class extending the BaseTrainer class for training based on a detection model.
22
26
 
23
- This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
24
- for object detection.
27
+ This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
28
+ object detection including dataset building, data loading, preprocessing, and model configuration.
25
29
 
26
30
  Attributes:
27
31
  model (DetectionModel): The YOLO detection model being trained.
28
32
  data (dict): Dictionary containing dataset information including class names and number of classes.
29
- loss_names (Tuple[str]): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
33
+ loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
30
34
 
31
35
  Methods:
32
36
  build_dataset: Build YOLO dataset for training or validation.
@@ -38,7 +42,6 @@ class DetectionTrainer(BaseTrainer):
38
42
  label_loss_items: Return a loss dictionary with labeled training loss items.
39
43
  progress_string: Return a formatted string of training progress.
40
44
  plot_training_samples: Plot training samples with their annotations.
41
- plot_metrics: Plot metrics from a CSV file.
42
45
  plot_training_labels: Create a labeled training plot of the YOLO model.
43
46
  auto_batch: Calculate optimal batch size based on model memory requirements.
44
47
 
@@ -49,24 +52,32 @@ class DetectionTrainer(BaseTrainer):
49
52
  >>> trainer.train()
50
53
  """
51
54
 
52
- def build_dataset(self, img_path, mode="train", batch=None):
55
+ def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
56
+ """Initialize a DetectionTrainer object for training YOLO object detection models.
57
+
58
+ Args:
59
+ cfg (dict, optional): Default configuration dictionary containing training parameters.
60
+ overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
61
+ _callbacks (list, optional): List of callback functions to be executed during training.
53
62
  """
54
- Build YOLO Dataset for training or validation.
63
+ super().__init__(cfg, overrides, _callbacks)
64
+
65
+ def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
66
+ """Build YOLO Dataset for training or validation.
55
67
 
56
68
  Args:
57
69
  img_path (str): Path to the folder containing images.
58
- mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
59
- batch (int, optional): Size of batches, this is for `rect`.
70
+ mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
71
+ batch (int, optional): Size of batches, this is for 'rect' mode.
60
72
 
61
73
  Returns:
62
74
  (Dataset): YOLO dataset object configured for the specified mode.
63
75
  """
64
- gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
76
+ gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
65
77
  return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
66
78
 
67
- def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
68
- """
69
- Construct and return dataloader for the specified mode.
79
+ def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
80
+ """Construct and return dataloader for the specified mode.
70
81
 
71
82
  Args:
72
83
  dataset_path (str): Path to the dataset.
@@ -84,12 +95,17 @@ class DetectionTrainer(BaseTrainer):
84
95
  if getattr(dataset, "rect", False) and shuffle:
85
96
  LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
86
97
  shuffle = False
87
- workers = self.args.workers if mode == "train" else self.args.workers * 2
88
- return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
98
+ return build_dataloader(
99
+ dataset,
100
+ batch=batch_size,
101
+ workers=self.args.workers if mode == "train" else self.args.workers * 2,
102
+ shuffle=shuffle,
103
+ rank=rank,
104
+ drop_last=self.args.compile and mode == "train",
105
+ )
89
106
 
90
- def preprocess_batch(self, batch):
91
- """
92
- Preprocess a batch of images by scaling and converting to float.
107
+ def preprocess_batch(self, batch: dict) -> dict:
108
+ """Preprocess a batch of images by scaling and converting to float.
93
109
 
94
110
  Args:
95
111
  batch (dict): Dictionary containing batch data with 'img' tensor.
@@ -97,7 +113,10 @@ class DetectionTrainer(BaseTrainer):
97
113
  Returns:
98
114
  (dict): Preprocessed batch with normalized images.
99
115
  """
100
- batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
116
+ for k, v in batch.items():
117
+ if isinstance(v, torch.Tensor):
118
+ batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
119
+ batch["img"] = batch["img"].float() / 255
101
120
  if self.args.multi_scale:
102
121
  imgs = batch["img"]
103
122
  sz = (
@@ -125,9 +144,8 @@ class DetectionTrainer(BaseTrainer):
125
144
  self.model.args = self.args # attach hyperparameters to model
126
145
  # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
127
146
 
128
- def get_model(self, cfg=None, weights=None, verbose=True):
129
- """
130
- Return a YOLO detection model.
147
+ def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
148
+ """Return a YOLO detection model.
131
149
 
132
150
  Args:
133
151
  cfg (str, optional): Path to model configuration file.
@@ -149,16 +167,15 @@ class DetectionTrainer(BaseTrainer):
149
167
  self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
150
168
  )
151
169
 
152
- def label_loss_items(self, loss_items=None, prefix="train"):
153
- """
154
- Return a loss dict with labeled training loss items tensor.
170
+ def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
171
+ """Return a loss dict with labeled training loss items tensor.
155
172
 
156
173
  Args:
157
- loss_items (List[float], optional): List of loss values.
174
+ loss_items (list[float], optional): List of loss values.
158
175
  prefix (str): Prefix for keys in the returned dictionary.
159
176
 
160
177
  Returns:
161
- (Dict | List): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
178
+ (dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
162
179
  """
163
180
  keys = [f"{prefix}/{x}" for x in self.loss_names]
164
181
  if loss_items is not None:
@@ -177,28 +194,20 @@ class DetectionTrainer(BaseTrainer):
177
194
  "Size",
178
195
  )
179
196
 
180
- def plot_training_samples(self, batch, ni):
181
- """
182
- Plot training samples with their annotations.
197
+ def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
198
+ """Plot training samples with their annotations.
183
199
 
184
200
  Args:
185
- batch (dict): Dictionary containing batch data.
201
+ batch (dict[str, Any]): Dictionary containing batch data.
186
202
  ni (int): Number of iterations.
187
203
  """
188
204
  plot_images(
189
- images=batch["img"],
190
- batch_idx=batch["batch_idx"],
191
- cls=batch["cls"].squeeze(-1),
192
- bboxes=batch["bboxes"],
205
+ labels=batch,
193
206
  paths=batch["im_file"],
194
207
  fname=self.save_dir / f"train_batch{ni}.jpg",
195
208
  on_plot=self.on_plot,
196
209
  )
197
210
 
198
- def plot_metrics(self):
199
- """Plot metrics from a CSV file."""
200
- plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
201
-
202
211
  def plot_training_labels(self):
203
212
  """Create a labeled training plot of the YOLO model."""
204
213
  boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
@@ -206,12 +215,13 @@ class DetectionTrainer(BaseTrainer):
206
215
  plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
207
216
 
208
217
  def auto_batch(self):
209
- """
210
- Get optimal batch size by calculating memory occupation of model.
218
+ """Get optimal batch size by calculating memory occupation of model.
211
219
 
212
220
  Returns:
213
221
  (int): Optimal batch size.
214
222
  """
215
- train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
223
+ with override_configs(self.args, overrides={"cache": False}) as self.args:
224
+ train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
216
225
  max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation
226
+ del train_dataset # free memory
217
227
  return super().auto_batch(max_num_obj)