dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
- dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
- tests/conftest.py +5 -8
- tests/test_cli.py +1 -8
- tests/test_python.py +1 -2
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +34 -49
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +244 -323
- ultralytics/data/base.py +12 -22
- ultralytics/data/build.py +47 -40
- ultralytics/data/converter.py +32 -42
- ultralytics/data/dataset.py +43 -71
- ultralytics/data/loaders.py +22 -34
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +27 -36
- ultralytics/engine/exporter.py +49 -116
- ultralytics/engine/model.py +144 -180
- ultralytics/engine/predictor.py +18 -29
- ultralytics/engine/results.py +165 -231
- ultralytics/engine/trainer.py +11 -19
- ultralytics/engine/tuner.py +13 -23
- ultralytics/engine/validator.py +6 -10
- ultralytics/hub/__init__.py +7 -12
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +3 -6
- ultralytics/models/fastsam/model.py +6 -8
- ultralytics/models/fastsam/predict.py +5 -10
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +2 -4
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -18
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +13 -20
- ultralytics/models/sam/amg.py +12 -18
- ultralytics/models/sam/build.py +6 -9
- ultralytics/models/sam/model.py +16 -23
- ultralytics/models/sam/modules/blocks.py +62 -84
- ultralytics/models/sam/modules/decoders.py +17 -24
- ultralytics/models/sam/modules/encoders.py +40 -56
- ultralytics/models/sam/modules/memory_attention.py +10 -16
- ultralytics/models/sam/modules/sam.py +41 -47
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +17 -27
- ultralytics/models/sam/modules/utils.py +31 -42
- ultralytics/models/sam/predict.py +172 -209
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/classify/predict.py +8 -11
- ultralytics/models/yolo/classify/train.py +8 -16
- ultralytics/models/yolo/classify/val.py +13 -20
- ultralytics/models/yolo/detect/predict.py +4 -8
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +38 -48
- ultralytics/models/yolo/model.py +35 -47
- ultralytics/models/yolo/obb/predict.py +5 -8
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +20 -28
- ultralytics/models/yolo/pose/predict.py +5 -8
- ultralytics/models/yolo/pose/train.py +4 -8
- ultralytics/models/yolo/pose/val.py +31 -39
- ultralytics/models/yolo/segment/predict.py +9 -14
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +16 -26
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -16
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/autobackend.py +10 -18
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +99 -185
- ultralytics/nn/modules/conv.py +45 -90
- ultralytics/nn/modules/head.py +44 -98
- ultralytics/nn/modules/transformer.py +44 -76
- ultralytics/nn/modules/utils.py +14 -19
- ultralytics/nn/tasks.py +86 -146
- ultralytics/nn/text_model.py +25 -40
- ultralytics/solutions/ai_gym.py +10 -16
- ultralytics/solutions/analytics.py +7 -10
- ultralytics/solutions/config.py +4 -5
- ultralytics/solutions/distance_calculation.py +9 -12
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +8 -12
- ultralytics/solutions/object_cropper.py +5 -8
- ultralytics/solutions/parking_management.py +12 -14
- ultralytics/solutions/queue_management.py +4 -6
- ultralytics/solutions/region_counter.py +7 -10
- ultralytics/solutions/security_alarm.py +14 -19
- ultralytics/solutions/similarity_search.py +7 -12
- ultralytics/solutions/solutions.py +31 -53
- ultralytics/solutions/speed_estimation.py +6 -9
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/basetrack.py +2 -4
- ultralytics/trackers/bot_sort.py +6 -11
- ultralytics/trackers/byte_tracker.py +10 -15
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +6 -12
- ultralytics/trackers/utils/kalman_filter.py +35 -43
- ultralytics/trackers/utils/matching.py +6 -10
- ultralytics/utils/__init__.py +61 -100
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +11 -13
- ultralytics/utils/benchmarks.py +25 -35
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +2 -4
- ultralytics/utils/callbacks/comet.py +30 -44
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +4 -6
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +4 -6
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +29 -56
- ultralytics/utils/cpu.py +1 -2
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +17 -27
- ultralytics/utils/errors.py +6 -8
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -239
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +11 -17
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +10 -15
- ultralytics/utils/git.py +5 -7
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +11 -15
- ultralytics/utils/loss.py +8 -14
- ultralytics/utils/metrics.py +98 -138
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +47 -74
- ultralytics/utils/patches.py +11 -18
- ultralytics/utils/plotting.py +29 -42
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +45 -73
- ultralytics/utils/tqdm.py +6 -8
- ultralytics/utils/triton.py +9 -12
- ultralytics/utils/tuner.py +1 -2
- dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from ultralytics.nn.modules import Detect, Pose
|
|
11
|
+
from ultralytics.utils import LOGGER
|
|
12
|
+
from ultralytics.utils.downloads import attempt_download_asset
|
|
13
|
+
from ultralytics.utils.files import spaces_in_path
|
|
14
|
+
from ultralytics.utils.tal import make_anchors
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def tf_wrapper(model: torch.nn.Module) -> torch.nn.Module:
    """Add TensorFlow-compatible inference methods to Detect and Pose layers.

    Patches every Detect head in the model in place: `_inference` is rebound to `_tf_inference`, and for
    modules that are exactly Pose (strict `type` check, not subclasses), `kpts_decode` is rebound to
    `tf_kpts_decode`.

    Args:
        model (torch.nn.Module): Model whose Detect/Pose heads should be patched for TF export.

    Returns:
        (torch.nn.Module): The same model instance, with its heads patched in place.
    """
    import types  # hoisted out of the loop so the import executes once, not per matching module

    for m in model.modules():
        if not isinstance(m, Detect):
            continue
        # Bind the TF-friendly implementations to this module instance only
        m._inference = types.MethodType(_tf_inference, m)
        if type(m) is Pose:  # exact type match as in the original — Pose subclasses are not patched
            m.kpts_decode = types.MethodType(tf_kpts_decode, m)
    return model
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _tf_inference(self, x: list[torch.Tensor]) -> torch.Tensor:
    """Decode boxes and cls scores for tf object detection.

    Bound onto a Detect head as `_inference` by `tf_wrapper`; reads head state (`no`, `reg_max`, `nc`,
    `dynamic`, `shape`, `stride`, `dfl`, `decode_bboxes`) from `self`.

    Args:
        self: Detect head instance this function is bound to.
        x (list[torch.Tensor]): Multi-scale feature maps, each BCHW (see `shape` comment below).

    Returns:
        (torch.Tensor): Decoded boxes concatenated with sigmoid class scores along dim 1.
    """
    shape = x[0].shape  # BCHW
    # Flatten each scale to (B, no, H*W) and concatenate along the anchor axis
    x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
    box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
    if self.dynamic or self.shape != shape:
        # Recompute cached anchors/strides only when the input shape changes
        self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
        self.shape = shape
    grid_h, grid_w = shape[2], shape[3]
    # Scale factors expressed in grid units — presumably keeps decoded values in a small normalized
    # range for TF/TFLite export; TODO confirm against the exporter's post-processing
    grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
    norm = self.strides / (self.stride[0] * grid_size)
    dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
    return torch.cat((dbox, cls.sigmoid()), 1)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def tf_kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:
    """Decode keypoints for tf pose estimation.

    Bound onto a Pose head as `kpts_decode` by `tf_wrapper`; reads `kpt_shape`, `shape`, `strides`,
    `stride`, `anchors`, and `nk` from `self`.

    Args:
        self: Pose head instance this function is bound to.
        bs (int): Batch size.
        kpts (torch.Tensor): Raw keypoint predictions, viewable as (bs, *kpt_shape, num_anchors).

    Returns:
        (torch.Tensor): Decoded keypoints of shape (bs, nk, -1).
    """
    ndim = self.kpt_shape[1]  # coords per keypoint: 2 (x, y) or 3 (x, y + sigmoid channel)
    # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
    # Precompute normalization factor to increase numerical stability
    y = kpts.view(bs, *self.kpt_shape, -1)
    grid_h, grid_w = self.shape[2], self.shape[3]
    grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
    norm = self.strides / (self.stride[0] * grid_size)
    # Decode x, y relative to the cached anchors, scaled into grid-normalized units
    a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
    if ndim == 3:
        # Third channel passed through sigmoid (presumably keypoint visibility/confidence — confirm)
        a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
    return a.view(bs, self.nk, -1)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def onnx2saved_model(
    onnx_file: str,
    output_dir: Path,
    int8: bool = False,
    images: np.ndarray | None = None,
    disable_group_convolution: bool = False,
    prefix: str = "",
):
    """Convert a ONNX model to TensorFlow SavedModel format via ONNX.

    Args:
        onnx_file (str): ONNX file path.
        output_dir (Path): Output directory path for the SavedModel.
        int8 (bool, optional): Enable INT8 quantization. Defaults to False.
        images (np.ndarray | None, optional): Calibration images for INT8 quantization in BHWC format.
        disable_group_convolution (bool, optional): Disable group convolution optimization. Defaults to False.
        prefix (str, optional): Logging prefix. Defaults to "".

    Returns:
        (keras.Model): Converted Keras model.

    Notes:
        - Requires onnx2tf package. Downloads calibration data if INT8 quantization is enabled.
        - Removes temporary files and renames quantized models after conversion.
    """
    # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
    onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
    if not onnx2tf_file.exists():
        attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
    np_data = None
    if int8:
        tmp_file = output_dir / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
        if images is not None:
            # parents/exist_ok so a pre-existing output directory (e.g. re-export) doesn't raise
            output_dir.mkdir(parents=True, exist_ok=True)
            np.save(str(tmp_file), images)  # BHWC
            np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]

    import onnx2tf  # scoped for after ONNX export for reduced conflict during import

    LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
    keras_model = onnx2tf.convert(
        input_onnx_file_path=onnx_file,
        output_folder_path=str(output_dir),
        not_use_onnxsim=True,
        verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
        output_integer_quantized_tflite=int8,
        custom_input_op_name_np_data_path=np_data,
        enable_batchmatmul_unfold=True and not int8,  # fix lower no. of detected objects on GPU delegate
        output_signaturedefs=True,  # fix error with Attention block group convolution
        disable_group_convolution=disable_group_convolution,  # fix error with group convolution
    )

    # Remove/rename TFLite models
    if int8:
        tmp_file.unlink(missing_ok=True)
        for file in output_dir.rglob("*_dynamic_range_quant.tflite"):
            file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
        for file in output_dir.rglob("*_integer_quant_with_int16_act.tflite"):
            file.unlink()  # delete extra fp16 activation TFLite files
    return keras_model
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def keras2pb(keras_model, file: Path, prefix=""):
    """Freeze a Keras model and serialize it as a TensorFlow GraphDef (.pb) file.

    Args:
        keras_model (keras.Model): Keras model to convert to frozen graph format.
        file (Path): Output file path; the graph is written as `file.name` inside `file.parent`.
        prefix (str, optional): Logging prefix. Defaults to "".

    Notes:
        Creates a frozen graph by converting variables to constants for inference optimization.
    """
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    # Trace the full model into a concrete function matching its first input signature
    spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
    concrete = tf.function(lambda x: keras_model(x)).get_concrete_function(spec)
    frozen = convert_variables_to_constants_v2(concrete)
    frozen.graph.as_graph_def()
    tf.io.write_graph(graph_or_graph_def=frozen.graph, logdir=str(file.parent), name=file.name, as_text=False)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def tflite2edgetpu(tflite_file: str | Path, output_dir: str | Path, prefix: str = ""):
    """Convert a TensorFlow Lite model to Edge TPU format using the Edge TPU compiler.

    Args:
        tflite_file (str | Path): Path to the input TensorFlow Lite (.tflite) model file.
        output_dir (str | Path): Output directory path for the compiled Edge TPU model.
        prefix (str, optional): Logging prefix. Defaults to "".

    Notes:
        Requires the Edge TPU compiler to be installed. The compiler is invoked with an argument
        list and no shell, so paths containing spaces, quotes, or shell metacharacters are safe.
    """
    import shlex
    import subprocess

    cmd = [
        "edgetpu_compiler",
        "--out_dir",
        str(output_dir),
        "--show_operations",
        "--search_delegate",
        "--delegate_search_step",
        "30",
        "--timeout_sec",
        "180",
        str(tflite_file),
    ]
    LOGGER.info(f"{prefix} running '{shlex.join(cmd)}'")
    subprocess.run(cmd)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def pb2tfjs(pb_file: str, output_dir: str, half: bool = False, int8: bool = False, prefix: str = ""):
    """Convert a TensorFlow GraphDef (.pb) model to TensorFlow.js format.

    Args:
        pb_file (str): Path to the input TensorFlow GraphDef (.pb) model file.
        output_dir (str): Output directory path for the converted TensorFlow.js model.
        half (bool, optional): Enable FP16 quantization. Defaults to False.
        int8 (bool, optional): Enable INT8 quantization. Defaults to False.
        prefix (str, optional): Logging prefix. Defaults to "".

    Notes:
        Requires tensorflowjs package. Uses tensorflowjs_converter command-line tool for conversion.
        Handles spaces in file paths and warns if output directory contains spaces.
    """
    import subprocess

    import tensorflow as tf
    import tensorflowjs as tfjs

    LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")

    gd = tf.Graph().as_graph_def()  # TF GraphDef
    # Parse the frozen graph from disk to discover its output node names for the converter
    with open(pb_file, "rb") as file:
        gd.ParseFromString(file.read())
    outputs = ",".join(gd_outputs(gd))
    LOGGER.info(f"\n{prefix} output node names: {outputs}")

    # half takes precedence over int8 when both flags are set
    quantization = "--quantize_float16" if half else "--quantize_uint8" if int8 else ""
    with spaces_in_path(pb_file) as fpb_, spaces_in_path(output_dir) as f_:  # exporter can not handle spaces in path
        cmd = (
            "tensorflowjs_converter "
            f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
        )
        LOGGER.info(f"{prefix} running '{cmd}'")
        subprocess.run(cmd, shell=True)

    if " " in output_dir:
        LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{output_dir}'.")
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def gd_outputs(gd):
    """Return TensorFlow GraphDef model output node names.

    An output node is any node whose name is never consumed as an input by another node,
    excluding NoOp nodes. Names are returned sorted with a ':0' tensor suffix appended.
    """
    names, consumed = set(), set()
    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
        names.add(node.name)
        consumed.update(node.input)
    return sorted(f"{name}:0" for name in names - consumed if not name.startswith("NoOp"))
|
ultralytics/utils/files.py
CHANGED
|
@@ -13,11 +13,10 @@ from pathlib import Path
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class WorkingDirectory(contextlib.ContextDecorator):
|
|
16
|
-
"""
|
|
17
|
-
A context manager and decorator for temporarily changing the working directory.
|
|
16
|
+
"""A context manager and decorator for temporarily changing the working directory.
|
|
18
17
|
|
|
19
|
-
This class allows for the temporary change of the working directory using a context manager or decorator.
|
|
20
|
-
|
|
18
|
+
This class allows for the temporary change of the working directory using a context manager or decorator. It ensures
|
|
19
|
+
that the original working directory is restored after the context or decorated function completes.
|
|
21
20
|
|
|
22
21
|
Attributes:
|
|
23
22
|
dir (Path | str): The new directory to switch to.
|
|
@@ -56,8 +55,7 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
|
|
56
55
|
|
|
57
56
|
@contextmanager
|
|
58
57
|
def spaces_in_path(path: str | Path):
|
|
59
|
-
"""
|
|
60
|
-
Context manager to handle paths with spaces in their names.
|
|
58
|
+
"""Context manager to handle paths with spaces in their names.
|
|
61
59
|
|
|
62
60
|
If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes
|
|
63
61
|
the context code block, then copies the file/directory back to its original location.
|
|
@@ -66,8 +64,7 @@ def spaces_in_path(path: str | Path):
|
|
|
66
64
|
path (str | Path): The original path that may contain spaces.
|
|
67
65
|
|
|
68
66
|
Yields:
|
|
69
|
-
(Path | str): Temporary path with spaces replaced by underscores
|
|
70
|
-
original path.
|
|
67
|
+
(Path | str): Temporary path with any spaces replaced by underscores.
|
|
71
68
|
|
|
72
69
|
Examples:
|
|
73
70
|
>>> with spaces_in_path('/path/with spaces') as new_path:
|
|
@@ -107,12 +104,11 @@ def spaces_in_path(path: str | Path):
|
|
|
107
104
|
|
|
108
105
|
|
|
109
106
|
def increment_path(path: str | Path, exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path:
|
|
110
|
-
"""
|
|
111
|
-
Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
|
107
|
+
"""Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
|
112
108
|
|
|
113
|
-
If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
|
|
114
|
-
|
|
115
|
-
|
|
109
|
+
If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to the
|
|
110
|
+
end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the number
|
|
111
|
+
will be appended directly to the end of the path.
|
|
116
112
|
|
|
117
113
|
Args:
|
|
118
114
|
path (str | Path): Path to increment.
|
|
@@ -185,8 +181,7 @@ def get_latest_run(search_dir: str = ".") -> str:
|
|
|
185
181
|
|
|
186
182
|
|
|
187
183
|
def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path("."), update_names: bool = False):
|
|
188
|
-
"""
|
|
189
|
-
Update and re-save specified YOLO models in an 'updated_models' subdirectory.
|
|
184
|
+
"""Update and re-save specified YOLO models in an 'updated_models' subdirectory.
|
|
190
185
|
|
|
191
186
|
Args:
|
|
192
187
|
model_names (tuple, optional): Model filenames to update.
|
ultralytics/utils/git.py
CHANGED
|
@@ -7,13 +7,12 @@ from pathlib import Path
|
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class GitRepo:
|
|
10
|
-
"""
|
|
11
|
-
Represent a local Git repository and expose branch, commit, and remote metadata.
|
|
10
|
+
"""Represent a local Git repository and expose branch, commit, and remote metadata.
|
|
12
11
|
|
|
13
12
|
This class discovers the repository root by searching for a .git entry from the given path upward, resolves the
|
|
14
|
-
actual .git directory (including worktrees), and reads Git metadata directly from on-disk files. It does not
|
|
15
|
-
|
|
16
|
-
|
|
13
|
+
actual .git directory (including worktrees), and reads Git metadata directly from on-disk files. It does not invoke
|
|
14
|
+
the git binary and therefore works in restricted environments. All metadata properties are resolved lazily and
|
|
15
|
+
cached; construct a new instance to refresh state.
|
|
17
16
|
|
|
18
17
|
Attributes:
|
|
19
18
|
root (Path | None): Repository root directory containing the .git entry; None if not in a repository.
|
|
@@ -39,8 +38,7 @@ class GitRepo:
|
|
|
39
38
|
"""
|
|
40
39
|
|
|
41
40
|
def __init__(self, path: Path = Path(__file__).resolve()):
|
|
42
|
-
"""
|
|
43
|
-
Initialize a Git repository context by discovering the repository root from a starting path.
|
|
41
|
+
"""Initialize a Git repository context by discovering the repository root from a starting path.
|
|
44
42
|
|
|
45
43
|
Args:
|
|
46
44
|
path (Path, optional): File or directory path used as the starting point to locate the repository root.
|
ultralytics/utils/instance.py
CHANGED
|
@@ -33,8 +33,7 @@ __all__ = ("Bboxes", "Instances") # tuple or list
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class Bboxes:
|
|
36
|
-
"""
|
|
37
|
-
A class for handling bounding boxes in multiple formats.
|
|
36
|
+
"""A class for handling bounding boxes in multiple formats.
|
|
38
37
|
|
|
39
38
|
The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format
|
|
40
39
|
conversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays.
|
|
@@ -61,8 +60,7 @@ class Bboxes:
|
|
|
61
60
|
"""
|
|
62
61
|
|
|
63
62
|
def __init__(self, bboxes: np.ndarray, format: str = "xyxy") -> None:
|
|
64
|
-
"""
|
|
65
|
-
Initialize the Bboxes class with bounding box data in a specified format.
|
|
63
|
+
"""Initialize the Bboxes class with bounding box data in a specified format.
|
|
66
64
|
|
|
67
65
|
Args:
|
|
68
66
|
bboxes (np.ndarray): Array of bounding boxes with shape (N, 4) or (4,).
|
|
@@ -76,8 +74,7 @@ class Bboxes:
|
|
|
76
74
|
self.format = format
|
|
77
75
|
|
|
78
76
|
def convert(self, format: str) -> None:
|
|
79
|
-
"""
|
|
80
|
-
Convert bounding box format from one type to another.
|
|
77
|
+
"""Convert bounding box format from one type to another.
|
|
81
78
|
|
|
82
79
|
Args:
|
|
83
80
|
format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
|
|
@@ -103,12 +100,11 @@ class Bboxes:
|
|
|
103
100
|
)
|
|
104
101
|
|
|
105
102
|
def mul(self, scale: int | tuple | list) -> None:
|
|
106
|
-
"""
|
|
107
|
-
Multiply bounding box coordinates by scale factor(s).
|
|
103
|
+
"""Multiply bounding box coordinates by scale factor(s).
|
|
108
104
|
|
|
109
105
|
Args:
|
|
110
|
-
scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to
|
|
111
|
-
|
|
106
|
+
scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to all
|
|
107
|
+
coordinates.
|
|
112
108
|
"""
|
|
113
109
|
if isinstance(scale, Number):
|
|
114
110
|
scale = to_4tuple(scale)
|
|
@@ -120,12 +116,11 @@ class Bboxes:
|
|
|
120
116
|
self.bboxes[:, 3] *= scale[3]
|
|
121
117
|
|
|
122
118
|
def add(self, offset: int | tuple | list) -> None:
|
|
123
|
-
"""
|
|
124
|
-
Add offset to bounding box coordinates.
|
|
119
|
+
"""Add offset to bounding box coordinates.
|
|
125
120
|
|
|
126
121
|
Args:
|
|
127
|
-
offset (int | tuple | list): Offset(s) for four coordinates. If int, the same offset is applied to
|
|
128
|
-
|
|
122
|
+
offset (int | tuple | list): Offset(s) for four coordinates. If int, the same offset is applied to all
|
|
123
|
+
coordinates.
|
|
129
124
|
"""
|
|
130
125
|
if isinstance(offset, Number):
|
|
131
126
|
offset = to_4tuple(offset)
|
|
@@ -142,8 +137,7 @@ class Bboxes:
|
|
|
142
137
|
|
|
143
138
|
@classmethod
|
|
144
139
|
def concatenate(cls, boxes_list: list[Bboxes], axis: int = 0) -> Bboxes:
|
|
145
|
-
"""
|
|
146
|
-
Concatenate a list of Bboxes objects into a single Bboxes object.
|
|
140
|
+
"""Concatenate a list of Bboxes objects into a single Bboxes object.
|
|
147
141
|
|
|
148
142
|
Args:
|
|
149
143
|
boxes_list (list[Bboxes]): A list of Bboxes objects to concatenate.
|
|
@@ -165,8 +159,7 @@ class Bboxes:
|
|
|
165
159
|
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
|
|
166
160
|
|
|
167
161
|
def __getitem__(self, index: int | np.ndarray | slice) -> Bboxes:
|
|
168
|
-
"""
|
|
169
|
-
Retrieve a specific bounding box or a set of bounding boxes using indexing.
|
|
162
|
+
"""Retrieve a specific bounding box or a set of bounding boxes using indexing.
|
|
170
163
|
|
|
171
164
|
Args:
|
|
172
165
|
index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes.
|
|
@@ -186,12 +179,11 @@ class Bboxes:
|
|
|
186
179
|
|
|
187
180
|
|
|
188
181
|
class Instances:
|
|
189
|
-
"""
|
|
190
|
-
Container for bounding boxes, segments, and keypoints of detected objects in an image.
|
|
182
|
+
"""Container for bounding boxes, segments, and keypoints of detected objects in an image.
|
|
191
183
|
|
|
192
|
-
This class provides a unified interface for handling different types of object annotations including bounding
|
|
193
|
-
|
|
194
|
-
|
|
184
|
+
This class provides a unified interface for handling different types of object annotations including bounding boxes,
|
|
185
|
+
segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping, and format
|
|
186
|
+
conversion.
|
|
195
187
|
|
|
196
188
|
Attributes:
|
|
197
189
|
_bboxes (Bboxes): Internal object for handling bounding box operations.
|
|
@@ -229,8 +221,7 @@ class Instances:
|
|
|
229
221
|
bbox_format: str = "xywh",
|
|
230
222
|
normalized: bool = True,
|
|
231
223
|
) -> None:
|
|
232
|
-
"""
|
|
233
|
-
Initialize the Instances object with bounding boxes, segments, and keypoints.
|
|
224
|
+
"""Initialize the Instances object with bounding boxes, segments, and keypoints.
|
|
234
225
|
|
|
235
226
|
Args:
|
|
236
227
|
bboxes (np.ndarray): Bounding boxes with shape (N, 4).
|
|
@@ -245,8 +236,7 @@ class Instances:
|
|
|
245
236
|
self.segments = segments
|
|
246
237
|
|
|
247
238
|
def convert_bbox(self, format: str) -> None:
|
|
248
|
-
"""
|
|
249
|
-
Convert bounding box format.
|
|
239
|
+
"""Convert bounding box format.
|
|
250
240
|
|
|
251
241
|
Args:
|
|
252
242
|
format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
|
|
@@ -259,8 +249,7 @@ class Instances:
|
|
|
259
249
|
return self._bboxes.areas()
|
|
260
250
|
|
|
261
251
|
def scale(self, scale_w: float, scale_h: float, bbox_only: bool = False):
|
|
262
|
-
"""
|
|
263
|
-
Scale coordinates by given factors.
|
|
252
|
+
"""Scale coordinates by given factors.
|
|
264
253
|
|
|
265
254
|
Args:
|
|
266
255
|
scale_w (float): Scale factor for width.
|
|
@@ -277,8 +266,7 @@ class Instances:
|
|
|
277
266
|
self.keypoints[..., 1] *= scale_h
|
|
278
267
|
|
|
279
268
|
def denormalize(self, w: int, h: int) -> None:
|
|
280
|
-
"""
|
|
281
|
-
Convert normalized coordinates to absolute coordinates.
|
|
269
|
+
"""Convert normalized coordinates to absolute coordinates.
|
|
282
270
|
|
|
283
271
|
Args:
|
|
284
272
|
w (int): Image width.
|
|
@@ -295,8 +283,7 @@ class Instances:
|
|
|
295
283
|
self.normalized = False
|
|
296
284
|
|
|
297
285
|
def normalize(self, w: int, h: int) -> None:
|
|
298
|
-
"""
|
|
299
|
-
Convert absolute coordinates to normalized coordinates.
|
|
286
|
+
"""Convert absolute coordinates to normalized coordinates.
|
|
300
287
|
|
|
301
288
|
Args:
|
|
302
289
|
w (int): Image width.
|
|
@@ -313,8 +300,7 @@ class Instances:
|
|
|
313
300
|
self.normalized = True
|
|
314
301
|
|
|
315
302
|
def add_padding(self, padw: int, padh: int) -> None:
|
|
316
|
-
"""
|
|
317
|
-
Add padding to coordinates.
|
|
303
|
+
"""Add padding to coordinates.
|
|
318
304
|
|
|
319
305
|
Args:
|
|
320
306
|
padw (int): Padding width.
|
|
@@ -329,8 +315,7 @@ class Instances:
|
|
|
329
315
|
self.keypoints[..., 1] += padh
|
|
330
316
|
|
|
331
317
|
def __getitem__(self, index: int | np.ndarray | slice) -> Instances:
|
|
332
|
-
"""
|
|
333
|
-
Retrieve a specific instance or a set of instances using indexing.
|
|
318
|
+
"""Retrieve a specific instance or a set of instances using indexing.
|
|
334
319
|
|
|
335
320
|
Args:
|
|
336
321
|
index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired instances.
|
|
@@ -355,8 +340,7 @@ class Instances:
|
|
|
355
340
|
)
|
|
356
341
|
|
|
357
342
|
def flipud(self, h: int) -> None:
|
|
358
|
-
"""
|
|
359
|
-
Flip coordinates vertically.
|
|
343
|
+
"""Flip coordinates vertically.
|
|
360
344
|
|
|
361
345
|
Args:
|
|
362
346
|
h (int): Image height.
|
|
@@ -373,8 +357,7 @@ class Instances:
|
|
|
373
357
|
self.keypoints[..., 1] = h - self.keypoints[..., 1]
|
|
374
358
|
|
|
375
359
|
def fliplr(self, w: int) -> None:
|
|
376
|
-
"""
|
|
377
|
-
Flip coordinates horizontally.
|
|
360
|
+
"""Flip coordinates horizontally.
|
|
378
361
|
|
|
379
362
|
Args:
|
|
380
363
|
w (int): Image width.
|
|
@@ -391,8 +374,7 @@ class Instances:
|
|
|
391
374
|
self.keypoints[..., 0] = w - self.keypoints[..., 0]
|
|
392
375
|
|
|
393
376
|
def clip(self, w: int, h: int) -> None:
|
|
394
|
-
"""
|
|
395
|
-
Clip coordinates to stay within image boundaries.
|
|
377
|
+
"""Clip coordinates to stay within image boundaries.
|
|
396
378
|
|
|
397
379
|
Args:
|
|
398
380
|
w (int): Image width.
|
|
@@ -418,8 +400,7 @@ class Instances:
|
|
|
418
400
|
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
|
419
401
|
|
|
420
402
|
def remove_zero_area_boxes(self) -> np.ndarray:
|
|
421
|
-
"""
|
|
422
|
-
Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.
|
|
403
|
+
"""Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.
|
|
423
404
|
|
|
424
405
|
Returns:
|
|
425
406
|
(np.ndarray): Boolean array indicating which boxes were kept.
|
|
@@ -434,8 +415,7 @@ class Instances:
|
|
|
434
415
|
return good
|
|
435
416
|
|
|
436
417
|
def update(self, bboxes: np.ndarray, segments: np.ndarray = None, keypoints: np.ndarray = None):
|
|
437
|
-
"""
|
|
438
|
-
Update instance variables.
|
|
418
|
+
"""Update instance variables.
|
|
439
419
|
|
|
440
420
|
Args:
|
|
441
421
|
bboxes (np.ndarray): New bounding boxes.
|
|
@@ -454,16 +434,15 @@ class Instances:
|
|
|
454
434
|
|
|
455
435
|
@classmethod
|
|
456
436
|
def concatenate(cls, instances_list: list[Instances], axis=0) -> Instances:
|
|
457
|
-
"""
|
|
458
|
-
Concatenate a list of Instances objects into a single Instances object.
|
|
437
|
+
"""Concatenate a list of Instances objects into a single Instances object.
|
|
459
438
|
|
|
460
439
|
Args:
|
|
461
440
|
instances_list (list[Instances]): A list of Instances objects to concatenate.
|
|
462
441
|
axis (int, optional): The axis along which the arrays will be concatenated.
|
|
463
442
|
|
|
464
443
|
Returns:
|
|
465
|
-
(Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints
|
|
466
|
-
|
|
444
|
+
(Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints if
|
|
445
|
+
present.
|
|
467
446
|
|
|
468
447
|
Notes:
|
|
469
448
|
The `Instances` objects in the list should have the same properties, such as the format of the bounding
|
ultralytics/utils/logger.py
CHANGED
|
@@ -19,11 +19,10 @@ if RANK in {-1, 0} and DEFAULT_LOG_PATH.exists():
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class ConsoleLogger:
|
|
22
|
-
"""
|
|
23
|
-
Console output capture with API/file streaming and deduplication.
|
|
22
|
+
"""Console output capture with API/file streaming and deduplication.
|
|
24
23
|
|
|
25
|
-
Captures stdout/stderr output and streams it to either an API endpoint or local file, with intelligent
|
|
26
|
-
|
|
24
|
+
Captures stdout/stderr output and streams it to either an API endpoint or local file, with intelligent deduplication
|
|
25
|
+
to reduce noise from repetitive console output.
|
|
27
26
|
|
|
28
27
|
Attributes:
|
|
29
28
|
destination (str | Path): Target destination for streaming (URL or Path object).
|
|
@@ -53,8 +52,7 @@ class ConsoleLogger:
|
|
|
53
52
|
"""
|
|
54
53
|
|
|
55
54
|
def __init__(self, destination):
|
|
56
|
-
"""
|
|
57
|
-
Initialize with API endpoint or local file path.
|
|
55
|
+
"""Initialize with API endpoint or local file path.
|
|
58
56
|
|
|
59
57
|
Args:
|
|
60
58
|
destination (str | Path): API endpoint URL (http/https) or local file path for streaming output.
|
|
@@ -227,11 +225,10 @@ class ConsoleLogger:
|
|
|
227
225
|
|
|
228
226
|
|
|
229
227
|
class SystemLogger:
|
|
230
|
-
"""
|
|
231
|
-
Log dynamic system metrics for training monitoring.
|
|
228
|
+
"""Log dynamic system metrics for training monitoring.
|
|
232
229
|
|
|
233
|
-
Captures real-time system metrics including CPU, RAM, disk I/O, network I/O, and NVIDIA GPU statistics for
|
|
234
|
-
|
|
230
|
+
Captures real-time system metrics including CPU, RAM, disk I/O, network I/O, and NVIDIA GPU statistics for training
|
|
231
|
+
performance monitoring and analysis.
|
|
235
232
|
|
|
236
233
|
Attributes:
|
|
237
234
|
pynvml: NVIDIA pynvml module instance if successfully imported, None otherwise.
|
|
@@ -277,11 +274,10 @@ class SystemLogger:
|
|
|
277
274
|
return False
|
|
278
275
|
|
|
279
276
|
def get_metrics(self):
|
|
280
|
-
"""
|
|
281
|
-
Get current system metrics.
|
|
277
|
+
"""Get current system metrics.
|
|
282
278
|
|
|
283
|
-
Collects comprehensive system metrics including CPU usage, RAM usage, disk I/O statistics,
|
|
284
|
-
|
|
279
|
+
Collects comprehensive system metrics including CPU usage, RAM usage, disk I/O statistics, network I/O
|
|
280
|
+
statistics, and GPU metrics (if available). Example output:
|
|
285
281
|
|
|
286
282
|
```python
|
|
287
283
|
metrics = {
|
|
@@ -312,7 +308,7 @@ class SystemLogger:
|
|
|
312
308
|
- power (int): GPU power consumption in watts
|
|
313
309
|
|
|
314
310
|
Returns:
|
|
315
|
-
metrics (dict): System metrics containing 'cpu', 'ram', 'disk', 'network', 'gpus' with
|
|
311
|
+
metrics (dict): System metrics containing 'cpu', 'ram', 'disk', 'network', 'gpus' with usage data.
|
|
316
312
|
"""
|
|
317
313
|
import psutil # scoped as slow import
|
|
318
314
|
|
ultralytics/utils/loss.py
CHANGED
|
@@ -18,8 +18,7 @@ from .tal import bbox2dist
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class VarifocalLoss(nn.Module):
|
|
21
|
-
"""
|
|
22
|
-
Varifocal loss by Zhang et al.
|
|
21
|
+
"""Varifocal loss by Zhang et al.
|
|
23
22
|
|
|
24
23
|
Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
|
|
25
24
|
hard-to-classify examples and balancing positive/negative samples.
|
|
@@ -51,11 +50,10 @@ class VarifocalLoss(nn.Module):
|
|
|
51
50
|
|
|
52
51
|
|
|
53
52
|
class FocalLoss(nn.Module):
|
|
54
|
-
"""
|
|
55
|
-
Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
|
|
53
|
+
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
|
|
56
54
|
|
|
57
|
-
Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
|
|
58
|
-
|
|
55
|
+
Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing on
|
|
56
|
+
hard negatives during training.
|
|
59
57
|
|
|
60
58
|
Attributes:
|
|
61
59
|
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
|
|
@@ -399,8 +397,7 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
399
397
|
def single_mask_loss(
|
|
400
398
|
gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
|
|
401
399
|
) -> torch.Tensor:
|
|
402
|
-
"""
|
|
403
|
-
Compute the instance segmentation loss for a single image.
|
|
400
|
+
"""Compute the instance segmentation loss for a single image.
|
|
404
401
|
|
|
405
402
|
Args:
|
|
406
403
|
gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
|
|
@@ -432,8 +429,7 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
432
429
|
imgsz: torch.Tensor,
|
|
433
430
|
overlap: bool,
|
|
434
431
|
) -> torch.Tensor:
|
|
435
|
-
"""
|
|
436
|
-
Calculate the loss for instance segmentation.
|
|
432
|
+
"""Calculate the loss for instance segmentation.
|
|
437
433
|
|
|
438
434
|
Args:
|
|
439
435
|
fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
|
|
@@ -585,8 +581,7 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
585
581
|
target_bboxes: torch.Tensor,
|
|
586
582
|
pred_kpts: torch.Tensor,
|
|
587
583
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
588
|
-
"""
|
|
589
|
-
Calculate the keypoints loss for the model.
|
|
584
|
+
"""Calculate the keypoints loss for the model.
|
|
590
585
|
|
|
591
586
|
This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
|
|
592
587
|
based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
|
|
@@ -760,8 +755,7 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
760
755
|
def bbox_decode(
|
|
761
756
|
self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
|
|
762
757
|
) -> torch.Tensor:
|
|
763
|
-
"""
|
|
764
|
-
Decode predicted object bounding box coordinates from anchor points and distribution.
|
|
758
|
+
"""Decode predicted object bounding box coordinates from anchor points and distribution.
|
|
765
759
|
|
|
766
760
|
Args:
|
|
767
761
|
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
|