ultralytics 8.3.101__py3-none-any.whl → 8.3.103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_exports.py +14 -5
- tests/test_solutions.py +140 -76
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +1 -1
- ultralytics/engine/exporter.py +23 -8
- ultralytics/engine/tuner.py +8 -2
- ultralytics/hub/__init__.py +29 -2
- ultralytics/hub/google/__init__.py +18 -1
- ultralytics/models/fastsam/predict.py +12 -1
- ultralytics/models/nas/predict.py +21 -3
- ultralytics/models/rtdetr/val.py +26 -2
- ultralytics/models/sam/amg.py +22 -1
- ultralytics/models/sam/modules/encoders.py +85 -4
- ultralytics/models/sam/modules/memory_attention.py +61 -3
- ultralytics/models/sam/modules/utils.py +108 -5
- ultralytics/models/utils/loss.py +38 -2
- ultralytics/models/utils/ops.py +15 -1
- ultralytics/models/yolo/classify/predict.py +11 -1
- ultralytics/models/yolo/classify/train.py +17 -1
- ultralytics/models/yolo/classify/val.py +82 -6
- ultralytics/models/yolo/detect/predict.py +20 -1
- ultralytics/models/yolo/model.py +55 -4
- ultralytics/models/yolo/obb/predict.py +16 -1
- ultralytics/models/yolo/obb/train.py +35 -2
- ultralytics/models/yolo/obb/val.py +87 -6
- ultralytics/models/yolo/pose/predict.py +18 -1
- ultralytics/models/yolo/pose/train.py +48 -3
- ultralytics/models/yolo/pose/val.py +113 -8
- ultralytics/models/yolo/segment/predict.py +27 -2
- ultralytics/models/yolo/segment/train.py +61 -3
- ultralytics/models/yolo/segment/val.py +10 -1
- ultralytics/models/yolo/world/train_world.py +29 -1
- ultralytics/models/yolo/yoloe/train.py +47 -3
- ultralytics/nn/autobackend.py +9 -8
- ultralytics/nn/modules/activation.py +26 -3
- ultralytics/nn/modules/block.py +89 -0
- ultralytics/nn/modules/head.py +3 -92
- ultralytics/nn/modules/utils.py +70 -4
- ultralytics/nn/tasks.py +3 -0
- ultralytics/nn/text_model.py +93 -17
- ultralytics/solutions/instance_segmentation.py +15 -7
- ultralytics/solutions/solutions.py +2 -47
- ultralytics/utils/benchmarks.py +1 -1
- ultralytics/utils/callbacks/base.py +22 -5
- ultralytics/utils/callbacks/comet.py +93 -5
- ultralytics/utils/callbacks/dvc.py +64 -5
- ultralytics/utils/callbacks/neptune.py +25 -2
- ultralytics/utils/callbacks/tensorboard.py +30 -2
- ultralytics/utils/callbacks/wb.py +16 -1
- ultralytics/utils/dist.py +35 -2
- ultralytics/utils/errors.py +27 -6
- ultralytics/utils/metrics.py +1 -1
- ultralytics/utils/patches.py +33 -5
- ultralytics/utils/torch_utils.py +14 -6
- ultralytics/utils/triton.py +16 -3
- ultralytics/utils/tuner.py +17 -9
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/METADATA +3 -4
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/RECORD +62 -62
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/top_level.txt +0 -0
@@ -24,14 +24,42 @@ except (ImportError, AssertionError, TypeError, AttributeError):
 
 
 def _log_scalars(scalars: dict, step: int = 0) -> None:
-    """
+    """
+    Log scalar values to TensorBoard.
+
+    Args:
+        scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the
+            corresponding scalar values.
+        step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.
+
+    Examples:
+        >>> # Log training metrics
+        >>> metrics = {"loss": 0.5, "accuracy": 0.95}
+        >>> _log_scalars(metrics, step=100)
+    """
     if WRITER:
         for k, v in scalars.items():
             WRITER.add_scalar(k, v, step)
 
 
 def _log_tensorboard_graph(trainer) -> None:
-    """
+    """
+    Log model graph to TensorBoard.
+
+    This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input
+    tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex
+    approach for models like RTDETR that may require special handling.
+
+    Args:
+        trainer (BaseTrainer): The trainer object containing the model to visualize. Must have attributes:
+            - model: PyTorch model to visualize
+            - args: Configuration arguments with 'imgsz' attribute
+
+    Notes:
+        This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.
+        It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different
+        model architectures.
+    """
     # Input image
     imgsz = trainer.args.imgsz
     imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
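The two-path graph logging described in the new `_log_tensorboard_graph` docstring can be pictured with a small sketch. This is illustrative only, not the ultralytics implementation; the helper name `log_graph` and the exact fallback are assumptions:

```python
import torch
from torch.utils.tensorboard import SummaryWriter


def log_graph(writer: SummaryWriter, model: torch.nn.Module, imgsz: int = 640) -> None:
    """Sketch of two-path graph logging: try a direct add_graph first, then a JIT trace on failure."""
    im = torch.zeros((1, 3, imgsz, imgsz))  # dummy BCHW input used only for tracing
    try:
        writer.add_graph(model.eval(), im)  # simple path: works for most plain YOLO-style models
    except Exception:
        # Fallback path: pre-trace with strict=False so dict/list outputs (e.g. RTDETR heads) are tolerated
        traced = torch.jit.trace(model.eval(), im, strict=False)
        writer.add_graph(traced, im)
```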
ultralytics/utils/callbacks/wb.py
CHANGED
@@ -99,7 +99,22 @@ def _plot_curve(
 
 
 def _log_plots(plots, step):
-    """
+    """
+    Log plots to WandB at a specific step if they haven't been logged already.
+
+    This function checks each plot in the input dictionary against previously processed plots and logs
+    new or updated plots to WandB at the specified step.
+
+    Args:
+        plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries
+            containing plot metadata including timestamps.
+        step (int): The step/epoch at which to log the plots in the WandB run.
+
+    Notes:
+        - The function uses a shallow copy of the plots dictionary to prevent modification during iteration
+        - Plots are identified by their stem name (filename without extension)
+        - Each plot is logged as a WandB Image object
+    """
     for name, params in plots.copy().items():  # shallow copy to prevent plots dict changing during iteration
         timestamp = params["timestamp"]
         if _processed_plots.get(name) != timestamp:
ultralytics/utils/dist.py
CHANGED
@@ -26,7 +26,26 @@ def find_free_network_port() -> int:
 
 
 def generate_ddp_file(trainer):
-    """
+    """
+    Generate a DDP (Distributed Data Parallel) file for multi-GPU training.
+
+    This function creates a temporary Python file that enables distributed training across multiple GPUs.
+    The file contains the necessary configuration to initialize the trainer in a distributed environment.
+
+    Args:
+        trainer (object): The trainer object containing training configuration and arguments.
+            Must have args attribute and be a class instance.
+
+    Returns:
+        (str): Path to the generated temporary DDP file.
+
+    Notes:
+        The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:
+        - Trainer class import
+        - Configuration overrides from the trainer arguments
+        - Model path configuration
+        - Training initialization code
+    """
     module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
 
     content = f"""
@@ -80,6 +99,20 @@ def generate_ddp_command(world_size, trainer):
 
 
 def ddp_cleanup(trainer, file):
-    """
+    """
+    Delete temporary file if created during distributed data parallel (DDP) training.
+
+    This function checks if the provided file contains the trainer's ID in its name, indicating it was created
+    as a temporary file for DDP training, and deletes it if so.
+
+    Args:
+        trainer (object): The trainer object used for distributed training.
+        file (str): Path to the file that might need to be deleted.
+
+    Examples:
+        >>> trainer = YOLOTrainer()
+        >>> file = "/tmp/ddp_temp_123456789.py"
+        >>> ddp_cleanup(trainer, file)
+    """
     if f"{id(trainer)}.py" in file:  # if temp_file suffix in file
         os.remove(file)
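To make the Notes above concrete, the file written by `generate_ddp_file` is essentially a tiny launcher script that each DDP process executes. The sketch below illustrates the idea only; the trainer class, override values, and layout are assumptions, not the exact template ultralytics writes:

```python
# Illustrative shape of a generated DDP launcher file (assumed example values, not the real template)
overrides = {"model": "yolov8n.pt", "data": "coco8.yaml", "epochs": 100, "device": "0,1"}

if __name__ == "__main__":
    from ultralytics.models.yolo.detect import DetectionTrainer  # trainer class re-imported by module path

    trainer = DetectionTrainer(overrides=overrides)  # rebuild the trainer from the serialized overrides
    results = trainer.train()  # each DDP process runs this file via torch.distributed.run
```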
ultralytics/utils/errors.py
CHANGED
@@ -5,18 +5,39 @@ from ultralytics.utils import emojis
 
 
 class HUBModelError(Exception):
     """
-
+    Exception raised when a model cannot be found or retrieved from Ultralytics HUB.
 
-    This exception is
-    The message is
+    This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.
+    The error message is processed to include emojis for better user experience.
 
     Attributes:
         message (str): The error message displayed when the exception is raised.
 
-
-
+    Methods:
+        __init__: Initialize the HUBModelError with a custom message.
+
+    Examples:
+        >>> try:
+        >>> # Code that might fail to find a model
+        >>> raise HUBModelError("Custom model not found message")
+        >>> except HUBModelError as e:
+        >>> print(e)  # Displays the emoji-enhanced error message
     """
 
     def __init__(self, message="Model not found. Please check model URL and try again."):
-        """
+        """
+        Initialize a HUBModelError exception.
+
+        This exception is raised when a requested model is not found or cannot be retrieved from Ultralytics HUB.
+        The message is processed to include emojis for better user experience.
+
+        Args:
+            message (str, optional): The error message to display when the exception is raised.
+
+        Examples:
+            >>> try:
+            ...     raise HUBModelError("Custom model error message")
+            ... except HUBModelError as e:
+            ...     print(e)
+        """
         super().__init__(emojis(message))
ultralytics/utils/metrics.py
CHANGED
@@ -523,7 +523,7 @@ def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confi
     else:
         ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)
 
-    y = smooth(py.mean(0), 0.
+    y = smooth(py.mean(0), 0.1)
     ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
     ax.set_xlabel(xlabel)
     ax.set_ylabel(ylabel)
ultralytics/utils/patches.py
CHANGED
@@ -18,10 +18,14 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
 
     Args:
         filename (str): Path to the file to read.
-        flags (int
+        flags (int): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
 
     Returns:
         (np.ndarray): The read image.
+
+    Examples:
+        >>> img = imread("path/to/image.jpg")
+        >>> img = imread("path/to/image.jpg", cv2.IMREAD_GRAYSCALE)
     """
     return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
 
@@ -36,7 +40,14 @@ def imwrite(filename: str, img: np.ndarray, params=None):
         params (List[int], optional): Additional parameters for image encoding.
 
     Returns:
-        (bool): True if the file was written, False otherwise.
+        (bool): True if the file was written successfully, False otherwise.
+
+    Examples:
+        >>> import numpy as np
+        >>> img = np.zeros((100, 100, 3), dtype=np.uint8)  # Create a black image
+        >>> success = imwrite("output.jpg", img)  # Write image to file
+        >>> print(success)
+        True
     """
     try:
         cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
@@ -49,9 +60,19 @@ def imshow(winname: str, mat: np.ndarray):
     """
     Display an image in the specified window.
 
+    This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It is
+    particularly useful for visualizing images during development and debugging.
+
     Args:
-        winname (str): Name of the window.
-
+        winname (str): Name of the window where the image will be displayed. If a window with this name already
+            exists, the image will be displayed in that window.
+        mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
+
+    Examples:
+        >>> import numpy as np
+        >>> img = np.zeros((300, 300, 3), dtype=np.uint8)  # Create a black image
+        >>> img[:100, :100] = [255, 0, 0]  # Add a blue square
+        >>> imshow("Example Window", img)  # Display the image
     """
     _imshow(winname.encode("unicode_escape").decode(), mat)
 
@@ -74,7 +95,7 @@ def torch_load(*args, **kwargs):
     Returns:
         (Any): The loaded PyTorch object.
 
-
+    Notes:
         For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
         if the argument is not provided, to avoid deprecation warnings.
     """
@@ -96,6 +117,13 @@ def torch_save(*args, **kwargs):
     Args:
         *args (Any): Positional arguments to pass to torch.save.
         **kwargs (Any): Keyword arguments to pass to torch.save.
+
+    Returns:
+        (Any): Result of torch.save operation if successful, None otherwise.
+
+    Examples:
+        >>> model = torch.nn.Linear(10, 1)
+        >>> torch_save(model.state_dict(), "model.pt")
     """
     for i in range(4):  # 3 retries
         try:
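The `for i in range(4):  # 3 retries` loop visible in the last hunk hints at the retry behaviour wrapped around `torch.save`. A generic sketch of that save-with-backoff pattern (a standalone illustration, not the patched ultralytics function):

```python
import time

import torch


def save_with_retries(obj, path, retries: int = 3):
    """Save obj with torch.save, retrying with exponential backoff on transient errors."""
    for i in range(retries + 1):
        try:
            return torch.save(obj, path)
        except RuntimeError:  # e.g. a device is busy or the filesystem briefly refuses the write
            if i == retries:
                raise
            time.sleep((2**i) / 2)  # backoff: 0.5s, 1s, 2s


save_with_retries(torch.nn.Linear(10, 1).state_dict(), "model.pt")
```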
ultralytics/utils/torch_utils.py
CHANGED
@@ -386,14 +386,18 @@ def model_info_for_loggers(trainer):
 
 def get_flops(model, imgsz=640):
     """
-
+    Calculate FLOPs (floating point operations) for a model in billions.
+
+    Attempts two calculation methods: first with a stride-based tensor for efficiency,
+    then falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0
+    if thop library is unavailable or calculation fails.
 
     Args:
         model (nn.Module): The model to calculate FLOPs for.
         imgsz (int | List[int], optional): Input image size. Defaults to 640.
 
     Returns:
-        (float): The model
+        (float): The model FLOPs in billions.
     """
     if not thop:
         return 0.0  # if not installed return 0.0 GFLOPs
@@ -404,13 +408,13 @@ def get_flops(model, imgsz=640):
         if not isinstance(imgsz, list):
             imgsz = [imgsz, imgsz]  # expand if int/float
         try:
-            # Use stride
+            # Method 1: Use stride-based input tensor
             stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
             im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
             flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
             return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
         except Exception:
-            # Use actual image size
+            # Method 2: Use actual image size (required for RTDETR models)
             im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
             return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
     except Exception:
@@ -611,10 +615,10 @@ def unset_deterministic():
 
 class ModelEMA:
     """
-    Updated Exponential Moving Average (EMA)
+    Updated Exponential Moving Average (EMA) implementation.
 
     Keeps a moving average of everything in the model state_dict (parameters and buffers).
-    For EMA details see
+    For EMA details see References.
 
     To disable EMA set the `enabled` attribute to `False`.
 
@@ -623,6 +627,10 @@ class ModelEMA:
         updates (int): Number of EMA updates.
         decay (function): Decay function that determines the EMA weight.
         enabled (bool): Whether EMA is enabled.
+
+    References:
+        - https://github.com/rwightman/pytorch-image-models
+        - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
     """
 
     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
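The stride-based method works because the FLOPs of a (mostly) convolutional model scale linearly with input area, so profiling a small stride-sized tensor and rescaling is much cheaper than profiling at full resolution. A quick worked example of the rescaling arithmetic shown in the hunk above, with hypothetical values:

```python
stride = 32            # model's maximum stride
flops_at_stride = 1.2  # hypothetical GFLOPs measured by thop on a 1x3x32x32 input
imgsz = (640, 640)     # target inference size

# Same scaling as `flops * imgsz[0] / stride * imgsz[1] / stride` in get_flops()
flops_full = flops_at_stride * (imgsz[0] / stride) * (imgsz[1] / stride)
print(flops_full)      # 1.2 * 20 * 20 = 480.0 GFLOPs
```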
ultralytics/utils/triton.py
CHANGED
@@ -25,6 +25,9 @@ class TritonRemoteModel:
         output_names (List[str]): The names of the model outputs.
         metadata: The metadata associated with the model.
 
+    Methods:
+        __call__: Call the model with the given inputs and return the outputs.
+
     Examples:
         Initialize a Triton client with HTTP
         >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
@@ -34,7 +37,7 @@ class TritonRemoteModel:
 
     def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
         """
-        Initialize the TritonRemoteModel.
+        Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.
 
         Arguments may be provided individually or parsed from a collective 'url' argument of the form
             <scheme>://<netloc>/<endpoint>/<task_name>
@@ -43,6 +46,10 @@ class TritonRemoteModel:
             url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc').
+
+        Examples:
+            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
+            >>> model = TritonRemoteModel(url="http://localhost:8000/yolov8")
         """
         if not endpoint and not scheme:  # Parse all args from URL string
             splits = urlsplit(url)
@@ -83,10 +90,16 @@ class TritonRemoteModel:
         Call the model with the given inputs.
 
         Args:
-            *inputs (np.ndarray): Input data to the model.
+            *inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type
+                for the corresponding model input.
 
         Returns:
-            (List[np.ndarray]): Model outputs with the same dtype as the input.
+            (List[np.ndarray]): Model outputs with the same dtype as the input. Each element in the list
+                corresponds to one of the model's output tensors.
+
+        Examples:
+            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
+            >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
         """
         infer_inputs = []
         input_format = inputs[0].dtype
ultralytics/utils/tuner.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
-from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks
+from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr
 
 
 def run_ray_tune(
@@ -95,7 +95,7 @@ def run_ray_tune(
         return results.results_dict
 
     # Get search space
-    if not space:
+    if not space and not train_args.get("resume"):
         space = default_space
         LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")
 
@@ -123,15 +123,23 @@ def run_ray_tune(
 
     # Create the Ray Tune hyperparameter search tuner
     tune_dir = get_save_dir(
-        get_cfg(
+        get_cfg(
+            DEFAULT_CFG,
+            {**train_args, **{"exist_ok": train_args.pop("resume", False)}},  # resume w/ same tune_dir
+        ),
+        name=train_args.pop("name", "tune"),  # runs/{task}/{tune_dir}
     ).resolve()  # must be absolute dir
     tune_dir.mkdir(parents=True, exist_ok=True)
-
-
-
-
-
-
+    if tune.Tuner.can_restore(tune_dir):
+        LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...")
+        tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)
+    else:
+        tuner = tune.Tuner(
+            trainable_with_resources,
+            param_space=space,
+            tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
+            run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name),
+        )
 
     # Run the hyperparameter search
     tuner.fit()
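Functionally, the tuner changes above let an interrupted Ray Tune session pick up where it left off: passing `resume` keeps the same `tune_dir` (via `exist_ok`), and `tune.Tuner.can_restore()` then restores the previous state instead of starting fresh. A sketch of how this is driven from the model API; the exact keyword set is an assumption, with `use_ray=True` routing `YOLO.tune()` through `run_ray_tune`:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")

# The first call creates e.g. runs/detect/tune; if it is interrupted, calling tune() again with
# resume=True reuses that directory and restores the existing Ray Tune state instead of starting over.
model.tune(data="coco8.yaml", epochs=10, iterations=20, use_ray=True, resume=True)
```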
{ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ultralytics
-Version: 8.3.101
+Version: 8.3.103
 Summary: Ultralytics YOLO 🚀 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
 Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
 Maintainer-email: Ultralytics <hello@ultralytics.com>
@@ -56,7 +56,6 @@ Requires-Dist: coverage[toml]; extra == "dev"
 Requires-Dist: mkdocs>=1.6.0; extra == "dev"
 Requires-Dist: mkdocs-material>=9.5.9; extra == "dev"
 Requires-Dist: mkdocstrings[python]; extra == "dev"
-Requires-Dist: mkdocs-redirects; extra == "dev"
 Requires-Dist: mkdocs-ultralytics-plugin>=0.1.17; extra == "dev"
 Requires-Dist: mkdocs-macros-plugin>=1.0.5; extra == "dev"
 Provides-Extra: export
@@ -71,8 +70,8 @@ Requires-Dist: keras; extra == "export"
 Requires-Dist: flatbuffers<100,>=23.5.26; platform_machine == "aarch64" and extra == "export"
 Requires-Dist: h5py!=3.11.0; platform_machine == "aarch64" and extra == "export"
 Provides-Extra: solutions
-Requires-Dist: shapely
-Requires-Dist: streamlit; extra == "solutions"
+Requires-Dist: shapely<2.1.0,>=2.0.0; extra == "solutions"
+Requires-Dist: streamlit<1.44.0,>=1.29.0; extra == "solutions"
 Provides-Extra: logging
 Requires-Dist: comet; extra == "logging"
 Requires-Dist: tensorboard>=2.13.0; extra == "logging"