dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
|
@@ -1,32 +1,36 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
import math
|
|
4
6
|
import random
|
|
5
7
|
from copy import copy
|
|
8
|
+
from typing import Any
|
|
6
9
|
|
|
7
10
|
import numpy as np
|
|
11
|
+
import torch
|
|
8
12
|
import torch.nn as nn
|
|
9
13
|
|
|
10
14
|
from ultralytics.data import build_dataloader, build_yolo_dataset
|
|
11
15
|
from ultralytics.engine.trainer import BaseTrainer
|
|
12
16
|
from ultralytics.models import yolo
|
|
13
17
|
from ultralytics.nn.tasks import DetectionModel
|
|
14
|
-
from ultralytics.utils import LOGGER, RANK
|
|
15
|
-
from ultralytics.utils.
|
|
16
|
-
from ultralytics.utils.
|
|
18
|
+
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
|
|
19
|
+
from ultralytics.utils.patches import override_configs
|
|
20
|
+
from ultralytics.utils.plotting import plot_images, plot_labels
|
|
21
|
+
from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
|
|
17
22
|
|
|
18
23
|
|
|
19
24
|
class DetectionTrainer(BaseTrainer):
|
|
20
|
-
"""
|
|
21
|
-
A class extending the BaseTrainer class for training based on a detection model.
|
|
25
|
+
"""A class extending the BaseTrainer class for training based on a detection model.
|
|
22
26
|
|
|
23
|
-
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
|
|
24
|
-
|
|
27
|
+
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
|
|
28
|
+
object detection including dataset building, data loading, preprocessing, and model configuration.
|
|
25
29
|
|
|
26
30
|
Attributes:
|
|
27
31
|
model (DetectionModel): The YOLO detection model being trained.
|
|
28
32
|
data (dict): Dictionary containing dataset information including class names and number of classes.
|
|
29
|
-
loss_names (
|
|
33
|
+
loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
|
|
30
34
|
|
|
31
35
|
Methods:
|
|
32
36
|
build_dataset: Build YOLO dataset for training or validation.
|
|
@@ -38,7 +42,6 @@ class DetectionTrainer(BaseTrainer):
|
|
|
38
42
|
label_loss_items: Return a loss dictionary with labeled training loss items.
|
|
39
43
|
progress_string: Return a formatted string of training progress.
|
|
40
44
|
plot_training_samples: Plot training samples with their annotations.
|
|
41
|
-
plot_metrics: Plot metrics from a CSV file.
|
|
42
45
|
plot_training_labels: Create a labeled training plot of the YOLO model.
|
|
43
46
|
auto_batch: Calculate optimal batch size based on model memory requirements.
|
|
44
47
|
|
|
@@ -49,24 +52,32 @@ class DetectionTrainer(BaseTrainer):
|
|
|
49
52
|
>>> trainer.train()
|
|
50
53
|
"""
|
|
51
54
|
|
|
52
|
-
def
|
|
55
|
+
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
|
56
|
+
"""Initialize a DetectionTrainer object for training YOLO object detection model training.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
|
60
|
+
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
|
61
|
+
_callbacks (list, optional): List of callback functions to be executed during training.
|
|
53
62
|
"""
|
|
54
|
-
|
|
63
|
+
super().__init__(cfg, overrides, _callbacks)
|
|
64
|
+
|
|
65
|
+
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
|
66
|
+
"""Build YOLO Dataset for training or validation.
|
|
55
67
|
|
|
56
68
|
Args:
|
|
57
69
|
img_path (str): Path to the folder containing images.
|
|
58
|
-
mode (str):
|
|
59
|
-
batch (int, optional): Size of batches, this is for
|
|
70
|
+
mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
|
|
71
|
+
batch (int, optional): Size of batches, this is for 'rect' mode.
|
|
60
72
|
|
|
61
73
|
Returns:
|
|
62
74
|
(Dataset): YOLO dataset object configured for the specified mode.
|
|
63
75
|
"""
|
|
64
|
-
gs = max(int(
|
|
76
|
+
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
|
|
65
77
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
|
66
78
|
|
|
67
|
-
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
|
68
|
-
"""
|
|
69
|
-
Construct and return dataloader for the specified mode.
|
|
79
|
+
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
|
|
80
|
+
"""Construct and return dataloader for the specified mode.
|
|
70
81
|
|
|
71
82
|
Args:
|
|
72
83
|
dataset_path (str): Path to the dataset.
|
|
@@ -84,12 +95,17 @@ class DetectionTrainer(BaseTrainer):
|
|
|
84
95
|
if getattr(dataset, "rect", False) and shuffle:
|
|
85
96
|
LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
|
86
97
|
shuffle = False
|
|
87
|
-
|
|
88
|
-
|
|
98
|
+
return build_dataloader(
|
|
99
|
+
dataset,
|
|
100
|
+
batch=batch_size,
|
|
101
|
+
workers=self.args.workers if mode == "train" else self.args.workers * 2,
|
|
102
|
+
shuffle=shuffle,
|
|
103
|
+
rank=rank,
|
|
104
|
+
drop_last=self.args.compile and mode == "train",
|
|
105
|
+
)
|
|
89
106
|
|
|
90
|
-
def preprocess_batch(self, batch):
|
|
91
|
-
"""
|
|
92
|
-
Preprocess a batch of images by scaling and converting to float.
|
|
107
|
+
def preprocess_batch(self, batch: dict) -> dict:
|
|
108
|
+
"""Preprocess a batch of images by scaling and converting to float.
|
|
93
109
|
|
|
94
110
|
Args:
|
|
95
111
|
batch (dict): Dictionary containing batch data with 'img' tensor.
|
|
@@ -97,7 +113,10 @@ class DetectionTrainer(BaseTrainer):
|
|
|
97
113
|
Returns:
|
|
98
114
|
(dict): Preprocessed batch with normalized images.
|
|
99
115
|
"""
|
|
100
|
-
|
|
116
|
+
for k, v in batch.items():
|
|
117
|
+
if isinstance(v, torch.Tensor):
|
|
118
|
+
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
|
119
|
+
batch["img"] = batch["img"].float() / 255
|
|
101
120
|
if self.args.multi_scale:
|
|
102
121
|
imgs = batch["img"]
|
|
103
122
|
sz = (
|
|
@@ -125,9 +144,8 @@ class DetectionTrainer(BaseTrainer):
|
|
|
125
144
|
self.model.args = self.args # attach hyperparameters to model
|
|
126
145
|
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
|
127
146
|
|
|
128
|
-
def get_model(self, cfg=None, weights=None, verbose=True):
|
|
129
|
-
"""
|
|
130
|
-
Return a YOLO detection model.
|
|
147
|
+
def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
|
|
148
|
+
"""Return a YOLO detection model.
|
|
131
149
|
|
|
132
150
|
Args:
|
|
133
151
|
cfg (str, optional): Path to model configuration file.
|
|
@@ -149,16 +167,15 @@ class DetectionTrainer(BaseTrainer):
|
|
|
149
167
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
|
150
168
|
)
|
|
151
169
|
|
|
152
|
-
def label_loss_items(self, loss_items=None, prefix="train"):
|
|
153
|
-
"""
|
|
154
|
-
Return a loss dict with labeled training loss items tensor.
|
|
170
|
+
def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
|
|
171
|
+
"""Return a loss dict with labeled training loss items tensor.
|
|
155
172
|
|
|
156
173
|
Args:
|
|
157
|
-
loss_items (
|
|
174
|
+
loss_items (list[float], optional): List of loss values.
|
|
158
175
|
prefix (str): Prefix for keys in the returned dictionary.
|
|
159
176
|
|
|
160
177
|
Returns:
|
|
161
|
-
(
|
|
178
|
+
(dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
|
|
162
179
|
"""
|
|
163
180
|
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
|
164
181
|
if loss_items is not None:
|
|
@@ -177,28 +194,20 @@ class DetectionTrainer(BaseTrainer):
|
|
|
177
194
|
"Size",
|
|
178
195
|
)
|
|
179
196
|
|
|
180
|
-
def plot_training_samples(self, batch, ni):
|
|
181
|
-
"""
|
|
182
|
-
Plot training samples with their annotations.
|
|
197
|
+
def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
|
|
198
|
+
"""Plot training samples with their annotations.
|
|
183
199
|
|
|
184
200
|
Args:
|
|
185
|
-
batch (dict): Dictionary containing batch data.
|
|
201
|
+
batch (dict[str, Any]): Dictionary containing batch data.
|
|
186
202
|
ni (int): Number of iterations.
|
|
187
203
|
"""
|
|
188
204
|
plot_images(
|
|
189
|
-
|
|
190
|
-
batch_idx=batch["batch_idx"],
|
|
191
|
-
cls=batch["cls"].squeeze(-1),
|
|
192
|
-
bboxes=batch["bboxes"],
|
|
205
|
+
labels=batch,
|
|
193
206
|
paths=batch["im_file"],
|
|
194
207
|
fname=self.save_dir / f"train_batch{ni}.jpg",
|
|
195
208
|
on_plot=self.on_plot,
|
|
196
209
|
)
|
|
197
210
|
|
|
198
|
-
def plot_metrics(self):
|
|
199
|
-
"""Plot metrics from a CSV file."""
|
|
200
|
-
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
|
|
201
|
-
|
|
202
211
|
def plot_training_labels(self):
|
|
203
212
|
"""Create a labeled training plot of the YOLO model."""
|
|
204
213
|
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
|
|
@@ -206,12 +215,13 @@ class DetectionTrainer(BaseTrainer):
|
|
|
206
215
|
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
|
|
207
216
|
|
|
208
217
|
def auto_batch(self):
|
|
209
|
-
"""
|
|
210
|
-
Get optimal batch size by calculating memory occupation of model.
|
|
218
|
+
"""Get optimal batch size by calculating memory occupation of model.
|
|
211
219
|
|
|
212
220
|
Returns:
|
|
213
221
|
(int): Optimal batch size.
|
|
214
222
|
"""
|
|
215
|
-
|
|
223
|
+
with override_configs(self.args, overrides={"cache": False}) as self.args:
|
|
224
|
+
train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
|
|
216
225
|
max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation
|
|
226
|
+
del train_dataset # free memory
|
|
217
227
|
return super().auto_batch(max_num_obj)
|