ultralytics 8.3.218__py3-none-any.whl → 8.3.222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +5 -7
 - tests/conftest.py +3 -7
 - tests/test_cli.py +9 -2
 - tests/test_engine.py +1 -1
 - tests/test_exports.py +37 -9
 - tests/test_integrations.py +4 -4
 - tests/test_python.py +40 -47
 - tests/test_solutions.py +154 -145
 - ultralytics/__init__.py +1 -1
 - ultralytics/cfg/__init__.py +7 -5
 - ultralytics/cfg/default.yaml +1 -1
 - ultralytics/data/__init__.py +4 -4
 - ultralytics/data/augment.py +10 -10
 - ultralytics/data/base.py +2 -2
 - ultralytics/data/build.py +1 -1
 - ultralytics/data/converter.py +3 -3
 - ultralytics/data/dataset.py +3 -3
 - ultralytics/data/loaders.py +2 -2
 - ultralytics/data/utils.py +3 -3
 - ultralytics/engine/exporter.py +79 -29
 - ultralytics/engine/model.py +2 -2
 - ultralytics/engine/predictor.py +1 -0
 - ultralytics/engine/trainer.py +6 -4
 - ultralytics/engine/tuner.py +4 -4
 - ultralytics/hub/__init__.py +9 -7
 - ultralytics/hub/utils.py +2 -2
 - ultralytics/models/__init__.py +1 -1
 - ultralytics/models/fastsam/__init__.py +1 -1
 - ultralytics/models/fastsam/predict.py +10 -16
 - ultralytics/models/nas/__init__.py +1 -1
 - ultralytics/models/rtdetr/__init__.py +1 -1
 - ultralytics/models/sam/__init__.py +1 -1
 - ultralytics/models/sam/amg.py +2 -2
 - ultralytics/models/sam/modules/blocks.py +1 -1
 - ultralytics/models/sam/modules/transformer.py +1 -1
 - ultralytics/models/sam/predict.py +1 -1
 - ultralytics/models/yolo/__init__.py +1 -1
 - ultralytics/models/yolo/classify/train.py +2 -2
 - ultralytics/models/yolo/pose/__init__.py +1 -1
 - ultralytics/models/yolo/segment/val.py +1 -1
 - ultralytics/models/yolo/yoloe/__init__.py +7 -7
 - ultralytics/nn/__init__.py +7 -7
 - ultralytics/nn/autobackend.py +32 -5
 - ultralytics/nn/modules/__init__.py +60 -60
 - ultralytics/nn/modules/block.py +26 -26
 - ultralytics/nn/modules/conv.py +7 -7
 - ultralytics/nn/modules/head.py +1 -1
 - ultralytics/nn/modules/transformer.py +7 -7
 - ultralytics/nn/modules/utils.py +1 -1
 - ultralytics/nn/tasks.py +3 -3
 - ultralytics/solutions/__init__.py +12 -12
 - ultralytics/solutions/object_counter.py +3 -6
 - ultralytics/solutions/queue_management.py +1 -1
 - ultralytics/solutions/similarity_search.py +3 -3
 - ultralytics/trackers/__init__.py +1 -1
 - ultralytics/trackers/byte_tracker.py +2 -2
 - ultralytics/trackers/utils/matching.py +1 -1
 - ultralytics/utils/__init__.py +6 -6
 - ultralytics/utils/benchmarks.py +7 -5
 - ultralytics/utils/callbacks/comet.py +2 -2
 - ultralytics/utils/checks.py +2 -2
 - ultralytics/utils/downloads.py +2 -2
 - ultralytics/utils/export/__init__.py +1 -1
 - ultralytics/utils/export/imx.py +39 -28
 - ultralytics/utils/files.py +1 -1
 - ultralytics/utils/git.py +1 -1
 - ultralytics/utils/logger.py +1 -1
 - ultralytics/utils/metrics.py +15 -11
 - ultralytics/utils/ops.py +8 -8
 - ultralytics/utils/plotting.py +3 -2
 - ultralytics/utils/torch_utils.py +5 -4
 - ultralytics/utils/triton.py +2 -2
 - ultralytics/utils/tuner.py +4 -2
 - {ultralytics-8.3.218.dist-info → ultralytics-8.3.222.dist-info}/METADATA +1 -1
 - {ultralytics-8.3.218.dist-info → ultralytics-8.3.222.dist-info}/RECORD +79 -79
 - {ultralytics-8.3.218.dist-info → ultralytics-8.3.222.dist-info}/WHEEL +0 -0
 - {ultralytics-8.3.218.dist-info → ultralytics-8.3.222.dist-info}/entry_points.txt +0 -0
 - {ultralytics-8.3.218.dist-info → ultralytics-8.3.222.dist-info}/licenses/LICENSE +0 -0
 - {ultralytics-8.3.218.dist-info → ultralytics-8.3.222.dist-info}/top_level.txt +0 -0
 
    
        tests/__init__.py
    CHANGED
    
    | 
         @@ -1,25 +1,23 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
         
     | 
| 
       2 
2 
     | 
    
         | 
| 
       3 
3 
     | 
    
         
             
            from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
         
     | 
| 
       4 
     | 
    
         
            -
            from ultralytics.utils import ASSETS,  
     | 
| 
      
 4 
     | 
    
         
            +
            from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks
         
     | 
| 
       5 
5 
     | 
    
         | 
| 
       6 
6 
     | 
    
         
             
            # Constants used in tests
         
     | 
| 
       7 
7 
     | 
    
         
             
            MODEL = WEIGHTS_DIR / "path with spaces" / "yolo11n.pt"  # test spaces in path
         
     | 
| 
       8 
8 
     | 
    
         
             
            CFG = "yolo11n.yaml"
         
     | 
| 
       9 
9 
     | 
    
         
             
            SOURCE = ASSETS / "bus.jpg"
         
     | 
| 
       10 
10 
     | 
    
         
             
            SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]
         
     | 
| 
       11 
     | 
    
         
            -
            TMP = (ROOT / "../tests/tmp").resolve()  # temp directory for test files
         
     | 
| 
       12 
11 
     | 
    
         
             
            CUDA_IS_AVAILABLE = checks.cuda_is_available()
         
     | 
| 
       13 
12 
     | 
    
         
             
            CUDA_DEVICE_COUNT = checks.cuda_device_count()
         
     | 
| 
       14 
13 
     | 
    
         
             
            TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
         
     | 
| 
       15 
     | 
    
         
            -
            MODELS = frozenset(list(TASK2MODEL.values())  
     | 
| 
      
 14 
     | 
    
         
            +
            MODELS = frozenset([*list(TASK2MODEL.values()), "yolo11n-grayscale.pt"])
         
     | 
| 
       16 
15 
     | 
    
         | 
| 
       17 
16 
     | 
    
         
             
            __all__ = (
         
     | 
| 
       18 
     | 
    
         
            -
                "MODEL",
         
     | 
| 
       19 
17 
     | 
    
         
             
                "CFG",
         
     | 
| 
      
 18 
     | 
    
         
            +
                "CUDA_DEVICE_COUNT",
         
     | 
| 
      
 19 
     | 
    
         
            +
                "CUDA_IS_AVAILABLE",
         
     | 
| 
      
 20 
     | 
    
         
            +
                "MODEL",
         
     | 
| 
       20 
21 
     | 
    
         
             
                "SOURCE",
         
     | 
| 
       21 
22 
     | 
    
         
             
                "SOURCES_LIST",
         
     | 
| 
       22 
     | 
    
         
            -
                "TMP",
         
     | 
| 
       23 
     | 
    
         
            -
                "CUDA_IS_AVAILABLE",
         
     | 
| 
       24 
     | 
    
         
            -
                "CUDA_DEVICE_COUNT",
         
     | 
| 
       25 
23 
     | 
    
         
             
            )
         
     | 
    
        tests/conftest.py
    CHANGED
    
    | 
         @@ -3,8 +3,6 @@ 
     | 
|
| 
       3 
3 
     | 
    
         
             
            import shutil
         
     | 
| 
       4 
4 
     | 
    
         
             
            from pathlib import Path
         
     | 
| 
       5 
5 
     | 
    
         | 
| 
       6 
     | 
    
         
            -
            from tests import TMP
         
     | 
| 
       7 
     | 
    
         
            -
             
     | 
| 
       8 
6 
     | 
    
         | 
| 
       9 
7 
     | 
    
         
             
            def pytest_addoption(parser):
         
     | 
| 
       10 
8 
     | 
    
         
             
                """Add custom command-line options to pytest."""
         
     | 
| 
         @@ -29,7 +27,7 @@ def pytest_sessionstart(session): 
     | 
|
| 
       29 
27 
     | 
    
         
             
                Initialize session configurations for pytest.
         
     | 
| 
       30 
28 
     | 
    
         | 
| 
       31 
29 
     | 
    
         
             
                This function is automatically called by pytest after the 'Session' object has been created but before performing
         
     | 
| 
       32 
     | 
    
         
            -
                test collection. It sets the initial seeds  
     | 
| 
      
 30 
     | 
    
         
            +
                test collection. It sets the initial seeds for the test session.
         
     | 
| 
       33 
31 
     | 
    
         | 
| 
       34 
32 
     | 
    
         
             
                Args:
         
     | 
| 
       35 
33 
     | 
    
         
             
                    session: The pytest session object.
         
     | 
| 
         @@ -37,8 +35,6 @@ def pytest_sessionstart(session): 
     | 
|
| 
       37 
35 
     | 
    
         
             
                from ultralytics.utils.torch_utils import init_seeds
         
     | 
| 
       38 
36 
     | 
    
         | 
| 
       39 
37 
     | 
    
         
             
                init_seeds()
         
     | 
| 
       40 
     | 
    
         
            -
                shutil.rmtree(TMP, ignore_errors=True)  # Delete any existing tests/tmp directory
         
     | 
| 
       41 
     | 
    
         
            -
                TMP.mkdir(parents=True, exist_ok=True)  # Create a new empty directory
         
     | 
| 
       42 
38 
     | 
    
         | 
| 
       43 
39 
     | 
    
         | 
| 
       44 
40 
     | 
    
         
             
            def pytest_terminal_summary(terminalreporter, exitstatus, config):
         
     | 
| 
         @@ -57,10 +53,10 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): 
     | 
|
| 
       57 
53 
     | 
    
         | 
| 
       58 
54 
     | 
    
         
             
                # Remove files
         
     | 
| 
       59 
55 
     | 
    
         
             
                models = [path for x in {"*.onnx", "*.torchscript"} for path in WEIGHTS_DIR.rglob(x)]
         
     | 
| 
       60 
     | 
    
         
            -
                for file in ["decelera_portrait_min.mov", "bus.jpg", "yolo11n.onnx", "yolo11n.torchscript" 
     | 
| 
      
 56 
     | 
    
         
            +
                for file in ["decelera_portrait_min.mov", "bus.jpg", "yolo11n.onnx", "yolo11n.torchscript", *models]:
         
     | 
| 
       61 
57 
     | 
    
         
             
                    Path(file).unlink(missing_ok=True)
         
     | 
| 
       62 
58 
     | 
    
         | 
| 
       63 
59 
     | 
    
         
             
                # Remove directories
         
     | 
| 
       64 
60 
     | 
    
         
             
                models = [path for x in {"*.mlpackage", "*_openvino_model"} for path in WEIGHTS_DIR.rglob(x)]
         
     | 
| 
       65 
     | 
    
         
            -
                for directory in [WEIGHTS_DIR / "path with spaces",  
     | 
| 
      
 61 
     | 
    
         
            +
                for directory in [WEIGHTS_DIR / "path with spaces", *models]:
         
     | 
| 
       66 
62 
     | 
    
         
             
                    shutil.rmtree(directory, ignore_errors=True)
         
     | 
    
        tests/test_cli.py
    CHANGED
    
    | 
         @@ -8,7 +8,7 @@ from PIL import Image 
     | 
|
| 
       8 
8 
     | 
    
         | 
| 
       9 
9 
     | 
    
         
             
            from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODELS, TASK_MODEL_DATA
         
     | 
| 
       10 
10 
     | 
    
         
             
            from ultralytics.utils import ARM64, ASSETS, LINUX, WEIGHTS_DIR, checks
         
     | 
| 
       11 
     | 
    
         
            -
            from ultralytics.utils.torch_utils import TORCH_1_11
         
     | 
| 
      
 11 
     | 
    
         
            +
            from ultralytics.utils.torch_utils import TORCH_1_11, TORCH_2_9, WINDOWS
         
     | 
| 
       12 
12 
     | 
    
         | 
| 
       13 
13 
     | 
    
         | 
| 
       14 
14 
     | 
    
         
             
            def run(cmd: str) -> None:
         
     | 
| 
         @@ -82,7 +82,7 @@ def test_fastsam( 
     | 
|
| 
       82 
82 
     | 
    
         
             
                    everything_results = sam_model(s, device="cpu", retina_masks=True, imgsz=320, conf=0.4, iou=0.9)
         
     | 
| 
       83 
83 
     | 
    
         | 
| 
       84 
84 
     | 
    
         
             
                    # Remove small regions
         
     | 
| 
       85 
     | 
    
         
            -
                     
     | 
| 
      
 85 
     | 
    
         
            +
                    _new_masks, _ = Predictor.remove_small_regions(everything_results[0].masks.data, min_area=20)
         
     | 
| 
       86 
86 
     | 
    
         | 
| 
       87 
87 
     | 
    
         
             
                    # Run inference with bboxes and points and texts prompt at the same time
         
     | 
| 
       88 
88 
     | 
    
         
             
                    sam_model(source, bboxes=[439, 437, 524, 709], points=[[200, 200]], labels=[1], texts="a photo of a dog")
         
     | 
| 
         @@ -129,3 +129,10 @@ def test_train_gpu(task: str, model: str, data: str) -> None: 
     | 
|
| 
       129 
129 
     | 
    
         
             
            def test_solutions(solution: str) -> None:
         
     | 
| 
       130 
130 
     | 
    
         
             
                """Test yolo solutions command-line modes."""
         
     | 
| 
       131 
131 
     | 
    
         
             
                run(f"yolo solutions {solution} verbose=False")
         
     | 
| 
      
 132 
     | 
    
         
            +
             
     | 
| 
      
 133 
     | 
    
         
            +
             
     | 
| 
      
 134 
     | 
    
         
            +
            @pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10 or not TORCH_2_9, reason="Requires Python>=3.10 and Torch>=2.9.0")
         
     | 
| 
      
 135 
     | 
    
         
            +
            @pytest.mark.skipif(WINDOWS, reason="Skipping test on Windows")
         
     | 
| 
      
 136 
     | 
    
         
            +
            def test_export_executorch() -> None:
         
     | 
| 
      
 137 
     | 
    
         
            +
                """Test exporting a YOLO model to ExecuTorch format via CLI."""
         
     | 
| 
      
 138 
     | 
    
         
            +
                run("yolo export model=yolo11n.pt format=executorch imgsz=32")
         
     | 
    
        tests/test_engine.py
    CHANGED
    
    | 
         @@ -13,7 +13,7 @@ from ultralytics.models.yolo import classify, detect, segment 
     | 
|
| 
       13 
13 
     | 
    
         
             
            from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR
         
     | 
| 
       14 
14 
     | 
    
         | 
| 
       15 
15 
     | 
    
         | 
| 
       16 
     | 
    
         
            -
            def test_func(*args): 
     | 
| 
      
 16 
     | 
    
         
            +
            def test_func(*args):
         
     | 
| 
       17 
17 
     | 
    
         
             
                """Test function callback for evaluating YOLO model performance metrics."""
         
     | 
| 
       18 
18 
     | 
    
         
             
                print("callback test passed")
         
     | 
| 
       19 
19 
     | 
    
         | 
    
        tests/test_exports.py
    CHANGED
    
    | 
         @@ -12,15 +12,8 @@ import pytest 
     | 
|
| 
       12 
12 
     | 
    
         
             
            from tests import MODEL, SOURCE
         
     | 
| 
       13 
13 
     | 
    
         
             
            from ultralytics import YOLO
         
     | 
| 
       14 
14 
     | 
    
         
             
            from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
         
     | 
| 
       15 
     | 
    
         
            -
            from ultralytics.utils import  
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
                IS_RASPBERRYPI,
         
     | 
| 
       18 
     | 
    
         
            -
                LINUX,
         
     | 
| 
       19 
     | 
    
         
            -
                MACOS,
         
     | 
| 
       20 
     | 
    
         
            -
                WINDOWS,
         
     | 
| 
       21 
     | 
    
         
            -
                checks,
         
     | 
| 
       22 
     | 
    
         
            -
            )
         
     | 
| 
       23 
     | 
    
         
            -
            from ultralytics.utils.torch_utils import TORCH_1_11, TORCH_1_13, TORCH_2_1
         
     | 
| 
      
 15 
     | 
    
         
            +
            from ultralytics.utils import ARM64, IS_RASPBERRYPI, LINUX, MACOS, WINDOWS, checks
         
     | 
| 
      
 16 
     | 
    
         
            +
            from ultralytics.utils.torch_utils import TORCH_1_11, TORCH_1_13, TORCH_2_1, TORCH_2_9
         
     | 
| 
       24 
17 
     | 
    
         | 
| 
       25 
18 
     | 
    
         | 
| 
       26 
19 
     | 
    
         
             
            def test_export_torchscript():
         
     | 
| 
         @@ -262,3 +255,38 @@ def test_export_imx(): 
     | 
|
| 
       262 
255 
     | 
    
         
             
                model = YOLO("yolov8n.pt")
         
     | 
| 
       263 
256 
     | 
    
         
             
                file = model.export(format="imx", imgsz=32)
         
     | 
| 
       264 
257 
     | 
    
         
             
                YOLO(file)(SOURCE, imgsz=32)
         
     | 
| 
      
 258 
     | 
    
         
            +
             
     | 
| 
      
 259 
     | 
    
         
            +
             
     | 
| 
      
 260 
     | 
    
         
            +
            @pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10 or not TORCH_2_9, reason="Requires Python>=3.10 and Torch>=2.9.0")
         
     | 
| 
      
 261 
     | 
    
         
            +
            @pytest.mark.skipif(WINDOWS, reason="Skipping test on Windows")
         
     | 
| 
      
 262 
     | 
    
         
            +
            def test_export_executorch():
         
     | 
| 
      
 263 
     | 
    
         
            +
                """Test YOLO model export to ExecuTorch format."""
         
     | 
| 
      
 264 
     | 
    
         
            +
                file = YOLO(MODEL).export(format="executorch", imgsz=32)
         
     | 
| 
      
 265 
     | 
    
         
            +
                assert Path(file).exists(), f"ExecuTorch export failed, directory not found: {file}"
         
     | 
| 
      
 266 
     | 
    
         
            +
                # Check that .pte file exists in the exported directory
         
     | 
| 
      
 267 
     | 
    
         
            +
                pte_file = Path(file) / Path(MODEL).with_suffix(".pte").name
         
     | 
| 
      
 268 
     | 
    
         
            +
                assert pte_file.exists(), f"ExecuTorch .pte file not found: {pte_file}"
         
     | 
| 
      
 269 
     | 
    
         
            +
                # Check that metadata.yaml exists
         
     | 
| 
      
 270 
     | 
    
         
            +
                metadata_file = Path(file) / "metadata.yaml"
         
     | 
| 
      
 271 
     | 
    
         
            +
                assert metadata_file.exists(), f"ExecuTorch metadata.yaml not found: {metadata_file}"
         
     | 
| 
      
 272 
     | 
    
         
            +
                # Note: Inference testing skipped as ExecuTorch requires special runtime setup
         
     | 
| 
      
 273 
     | 
    
         
            +
                shutil.rmtree(file, ignore_errors=True)  # cleanup
         
     | 
| 
      
 274 
     | 
    
         
            +
             
     | 
| 
      
 275 
     | 
    
         
            +
             
     | 
| 
      
 276 
     | 
    
         
            +
            @pytest.mark.slow
         
     | 
| 
      
 277 
     | 
    
         
            +
            @pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10 or not TORCH_2_9, reason="Requires Python>=3.10 and Torch>=2.9.0")
         
     | 
| 
      
 278 
     | 
    
         
            +
            @pytest.mark.skipif(WINDOWS, reason="Skipping test on Windows")
         
     | 
| 
      
 279 
     | 
    
         
            +
            @pytest.mark.parametrize("task", TASKS)
         
     | 
| 
      
 280 
     | 
    
         
            +
            def test_export_executorch_matrix(task):
         
     | 
| 
      
 281 
     | 
    
         
            +
                """Test YOLO export to ExecuTorch format for various task types."""
         
     | 
| 
      
 282 
     | 
    
         
            +
                file = YOLO(TASK2MODEL[task]).export(format="executorch", imgsz=32)
         
     | 
| 
      
 283 
     | 
    
         
            +
                assert Path(file).exists(), f"ExecuTorch export failed for task '{task}', directory not found: {file}"
         
     | 
| 
      
 284 
     | 
    
         
            +
                # Check that .pte file exists in the exported directory
         
     | 
| 
      
 285 
     | 
    
         
            +
                model_name = Path(TASK2MODEL[task]).with_suffix(".pte").name
         
     | 
| 
      
 286 
     | 
    
         
            +
                pte_file = Path(file) / model_name
         
     | 
| 
      
 287 
     | 
    
         
            +
                assert pte_file.exists(), f"ExecuTorch .pte file not found for task '{task}': {pte_file}"
         
     | 
| 
      
 288 
     | 
    
         
            +
                # Check that metadata.yaml exists
         
     | 
| 
      
 289 
     | 
    
         
            +
                metadata_file = Path(file) / "metadata.yaml"
         
     | 
| 
      
 290 
     | 
    
         
            +
                assert metadata_file.exists(), f"ExecuTorch metadata.yaml not found for task '{task}': {metadata_file}"
         
     | 
| 
      
 291 
     | 
    
         
            +
                # Note: Inference testing skipped as ExecuTorch requires special runtime setup
         
     | 
| 
      
 292 
     | 
    
         
            +
                shutil.rmtree(file, ignore_errors=True)  # cleanup
         
     | 
    
        tests/test_integrations.py
    CHANGED
    
    | 
         @@ -8,7 +8,7 @@ from pathlib import Path 
     | 
|
| 
       8 
8 
     | 
    
         | 
| 
       9 
9 
     | 
    
         
             
            import pytest
         
     | 
| 
       10 
10 
     | 
    
         | 
| 
       11 
     | 
    
         
            -
            from tests import MODEL, SOURCE 
     | 
| 
      
 11 
     | 
    
         
            +
            from tests import MODEL, SOURCE
         
     | 
| 
       12 
12 
     | 
    
         
             
            from ultralytics import YOLO, download
         
     | 
| 
       13 
13 
     | 
    
         
             
            from ultralytics.utils import ASSETS_URL, DATASETS_DIR, SETTINGS
         
     | 
| 
       14 
14 
     | 
    
         
             
            from ultralytics.utils.checks import check_requirements
         
     | 
| 
         @@ -71,14 +71,14 @@ def test_mlflow_keep_run_active(): 
     | 
|
| 
       71 
71 
     | 
    
         | 
| 
       72 
72 
     | 
    
         | 
| 
       73 
73 
     | 
    
         
             
            @pytest.mark.skipif(not check_requirements("tritonclient", install=False), reason="tritonclient[all] not installed")
         
     | 
| 
       74 
     | 
    
         
            -
            def test_triton():
         
     | 
| 
      
 74 
     | 
    
         
            +
            def test_triton(tmp_path):
         
     | 
| 
       75 
75 
     | 
    
         
             
                """Test NVIDIA Triton Server functionalities with YOLO model."""
         
     | 
| 
       76 
76 
     | 
    
         
             
                check_requirements("tritonclient[all]")
         
     | 
| 
       77 
     | 
    
         
            -
                from tritonclient.http import InferenceServerClient 
     | 
| 
      
 77 
     | 
    
         
            +
                from tritonclient.http import InferenceServerClient
         
     | 
| 
       78 
78 
     | 
    
         | 
| 
       79 
79 
     | 
    
         
             
                # Create variables
         
     | 
| 
       80 
80 
     | 
    
         
             
                model_name = "yolo"
         
     | 
| 
       81 
     | 
    
         
            -
                triton_repo =  
     | 
| 
      
 81 
     | 
    
         
            +
                triton_repo = tmp_path / "triton_repo"  # Triton repo path
         
     | 
| 
       82 
82 
     | 
    
         
             
                triton_model = triton_repo / model_name  # Triton model path
         
     | 
| 
       83 
83 
     | 
    
         | 
| 
       84 
84 
     | 
    
         
             
                # Export model to ONNX
         
     | 
    
        tests/test_python.py
    CHANGED
    
    | 
         @@ -12,7 +12,7 @@ import pytest 
     | 
|
| 
       12 
12 
     | 
    
         
             
            import torch
         
     | 
| 
       13 
13 
     | 
    
         
             
            from PIL import Image
         
     | 
| 
       14 
14 
     | 
    
         | 
| 
       15 
     | 
    
         
            -
            from tests import CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA 
     | 
| 
      
 15 
     | 
    
         
            +
            from tests import CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA
         
     | 
| 
       16 
16 
     | 
    
         
             
            from ultralytics import RTDETR, YOLO
         
     | 
| 
       17 
17 
     | 
    
         
             
            from ultralytics.cfg import TASK2DATA, TASKS
         
     | 
| 
       18 
18 
     | 
    
         
             
            from ultralytics.data.build import load_inference_source
         
     | 
| 
         @@ -33,14 +33,11 @@ from ultralytics.utils import ( 
     | 
|
| 
       33 
33 
     | 
    
         
             
                WINDOWS,
         
     | 
| 
       34 
34 
     | 
    
         
             
                YAML,
         
     | 
| 
       35 
35 
     | 
    
         
             
                checks,
         
     | 
| 
       36 
     | 
    
         
            -
                is_dir_writeable,
         
     | 
| 
       37 
36 
     | 
    
         
             
                is_github_action_running,
         
     | 
| 
       38 
37 
     | 
    
         
             
            )
         
     | 
| 
       39 
38 
     | 
    
         
             
            from ultralytics.utils.downloads import download
         
     | 
| 
       40 
39 
     | 
    
         
             
            from ultralytics.utils.torch_utils import TORCH_1_11, TORCH_1_13
         
     | 
| 
       41 
40 
     | 
    
         | 
| 
       42 
     | 
    
         
            -
            IS_TMP_WRITEABLE = is_dir_writeable(TMP)  # WARNING: must be run once tests start as TMP does not exist on tests/init
         
     | 
| 
       43 
     | 
    
         
            -
             
     | 
| 
       44 
41 
     | 
    
         | 
| 
       45 
42 
     | 
    
         
             
            def test_model_forward():
         
     | 
| 
       46 
43 
     | 
    
         
             
                """Test the forward pass of the YOLO model."""
         
     | 
| 
         @@ -77,10 +74,9 @@ def test_model_profile(): 
     | 
|
| 
       77 
74 
     | 
    
         
             
                _ = model.predict(im, profile=True)
         
     | 
| 
       78 
75 
     | 
    
         | 
| 
       79 
76 
     | 
    
         | 
| 
       80 
     | 
    
         
            -
             
     | 
| 
       81 
     | 
    
         
            -
            def test_predict_txt():
         
     | 
| 
      
 77 
     | 
    
         
            +
            def test_predict_txt(tmp_path):
         
     | 
| 
       82 
78 
     | 
    
         
             
                """Test YOLO predictions with file, directory, and pattern sources listed in a text file."""
         
     | 
| 
       83 
     | 
    
         
            -
                file =  
     | 
| 
      
 79 
     | 
    
         
            +
                file = tmp_path / "sources_multi_row.txt"
         
     | 
| 
       84 
80 
     | 
    
         
             
                with open(file, "w") as f:
         
     | 
| 
       85 
81 
     | 
    
         
             
                    for src in SOURCES_LIST:
         
     | 
| 
       86 
82 
     | 
    
         
             
                        f.write(f"{src}\n")
         
     | 
| 
         @@ -89,10 +85,9 @@ def test_predict_txt(): 
     | 
|
| 
       89 
85 
     | 
    
         | 
| 
       90 
86 
     | 
    
         | 
| 
       91 
87 
     | 
    
         
             
            @pytest.mark.skipif(True, reason="disabled for testing")
         
     | 
| 
       92 
     | 
    
         
            -
             
     | 
| 
       93 
     | 
    
         
            -
            def test_predict_csv_multi_row():
         
     | 
| 
      
 88 
     | 
    
         
            +
            def test_predict_csv_multi_row(tmp_path):
         
     | 
| 
       94 
89 
     | 
    
         
             
                """Test YOLO predictions with sources listed in multiple rows of a CSV file."""
         
     | 
| 
       95 
     | 
    
         
            -
                file =  
     | 
| 
      
 90 
     | 
    
         
            +
                file = tmp_path / "sources_multi_row.csv"
         
     | 
| 
       96 
91 
     | 
    
         
             
                with open(file, "w", newline="") as f:
         
     | 
| 
       97 
92 
     | 
    
         
             
                    writer = csv.writer(f)
         
     | 
| 
       98 
93 
     | 
    
         
             
                    writer.writerow(["source"])
         
     | 
| 
         @@ -102,10 +97,9 @@ def test_predict_csv_multi_row(): 
     | 
|
| 
       102 
97 
     | 
    
         | 
| 
       103 
98 
     | 
    
         | 
| 
       104 
99 
     | 
    
         
             
            @pytest.mark.skipif(True, reason="disabled for testing")
         
     | 
| 
       105 
     | 
    
         
            -
             
     | 
| 
       106 
     | 
    
         
            -
            def test_predict_csv_single_row():
         
     | 
| 
      
 100 
     | 
    
         
            +
            def test_predict_csv_single_row(tmp_path):
         
     | 
| 
       107 
101 
     | 
    
         
             
                """Test YOLO predictions with sources listed in a single row of a CSV file."""
         
     | 
| 
       108 
     | 
    
         
            -
                file =  
     | 
| 
      
 102 
     | 
    
         
            +
                file = tmp_path / "sources_single_row.csv"
         
     | 
| 
       109 
103 
     | 
    
         
             
                with open(file, "w", newline="") as f:
         
     | 
| 
       110 
104 
     | 
    
         
             
                    writer = csv.writer(f)
         
     | 
| 
       111 
105 
     | 
    
         
             
                    writer.writerow(SOURCES_LIST)
         
     | 
| 
         @@ -142,25 +136,23 @@ def test_predict_visualize(model): 
     | 
|
| 
       142 
136 
     | 
    
         
             
                YOLO(WEIGHTS_DIR / model)(SOURCE, imgsz=32, visualize=True)
         
     | 
| 
       143 
137 
     | 
    
         | 
| 
       144 
138 
     | 
    
         | 
| 
       145 
     | 
    
         
            -
            def  
     | 
| 
       146 
     | 
    
         
            -
                """Test YOLO prediction on SOURCE converted to  
     | 
| 
      
 139 
     | 
    
         
            +
            def test_predict_gray_and_4ch(tmp_path):
         
     | 
| 
      
 140 
     | 
    
         
            +
                """Test YOLO prediction on SOURCE converted to grayscale and 4-channel images with various filenames."""
         
     | 
| 
       147 
141 
     | 
    
         
             
                im = Image.open(SOURCE)
         
     | 
| 
       148 
     | 
    
         
            -
                directory = TMP / "im4"
         
     | 
| 
       149 
     | 
    
         
            -
                directory.mkdir(parents=True, exist_ok=True)
         
     | 
| 
       150 
142 
     | 
    
         | 
| 
       151 
     | 
    
         
            -
                 
     | 
| 
       152 
     | 
    
         
            -
                source_rgba =  
     | 
| 
       153 
     | 
    
         
            -
                source_non_utf =  
     | 
| 
       154 
     | 
    
         
            -
                source_spaces =  
     | 
| 
      
 143 
     | 
    
         
            +
                source_grayscale = tmp_path / "grayscale.jpg"
         
     | 
| 
      
 144 
     | 
    
         
            +
                source_rgba = tmp_path / "4ch.png"
         
     | 
| 
      
 145 
     | 
    
         
            +
                source_non_utf = tmp_path / "non_UTF_测试文件_tést_image.jpg"
         
     | 
| 
      
 146 
     | 
    
         
            +
                source_spaces = tmp_path / "image with spaces.jpg"
         
     | 
| 
       155 
147 
     | 
    
         | 
| 
       156 
     | 
    
         
            -
                im.convert("L").save( 
     | 
| 
      
 148 
     | 
    
         
            +
                im.convert("L").save(source_grayscale)  # grayscale
         
     | 
| 
       157 
149 
     | 
    
         
             
                im.convert("RGBA").save(source_rgba)  # 4-ch PNG with alpha
         
     | 
| 
       158 
150 
     | 
    
         
             
                im.save(source_non_utf)  # non-UTF characters in filename
         
     | 
| 
       159 
151 
     | 
    
         
             
                im.save(source_spaces)  # spaces in filename
         
     | 
| 
       160 
152 
     | 
    
         | 
| 
       161 
153 
     | 
    
         
             
                # Inference
         
     | 
| 
       162 
154 
     | 
    
         
             
                model = YOLO(MODEL)
         
     | 
| 
       163 
     | 
    
         
            -
                for f in source_rgba,  
     | 
| 
      
 155 
     | 
    
         
            +
                for f in source_rgba, source_grayscale, source_non_utf, source_spaces:
         
     | 
| 
       164 
156 
     | 
    
         
             
                    for source in Image.open(f), cv2.imread(str(f)), f:
         
     | 
| 
       165 
157 
     | 
    
         
             
                        results = model(source, save=True, verbose=True, imgsz=32)
         
     | 
| 
       166 
158 
     | 
    
         
             
                        assert len(results) == 1  # verify that an image was run
         
     | 
| 
         @@ -181,9 +173,8 @@ def test_youtube(): 
     | 
|
| 
       181 
173 
     | 
    
         | 
| 
       182 
174 
     | 
    
         | 
| 
       183 
175 
     | 
    
         
             
            @pytest.mark.skipif(not ONLINE, reason="environment is offline")
         
     | 
| 
       184 
     | 
    
         
            -
            @pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
         
     | 
| 
       185 
176 
     | 
    
         
             
            @pytest.mark.parametrize("model", MODELS)
         
     | 
| 
       186 
     | 
    
         
            -
            def test_track_stream(model):
         
     | 
| 
      
 177 
     | 
    
         
            +
            def test_track_stream(model, tmp_path):
         
     | 
| 
       187 
178 
     | 
    
         
             
                """
         
     | 
| 
       188 
179 
     | 
    
         
             
                Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
         
     | 
| 
       189 
180 
     | 
    
         | 
| 
         @@ -199,7 +190,7 @@ def test_track_stream(model): 
     | 
|
| 
       199 
190 
     | 
    
         
             
                # Test Global Motion Compensation (GMC) methods and ReID
         
     | 
| 
       200 
191 
     | 
    
         
             
                for gmc, reidm in zip(["orb", "sift", "ecc"], ["auto", "auto", "yolo11n-cls.pt"]):
         
     | 
| 
       201 
192 
     | 
    
         
             
                    default_args = YAML.load(ROOT / "cfg/trackers/botsort.yaml")
         
     | 
| 
       202 
     | 
    
         
            -
                    custom_yaml =  
     | 
| 
      
 193 
     | 
    
         
            +
                    custom_yaml = tmp_path / f"botsort-{gmc}.yaml"
         
     | 
| 
       203 
194 
     | 
    
         
             
                    YAML.save(custom_yaml, {**default_args, "gmc_method": gmc, "with_reid": True, "model": reidm})
         
     | 
| 
       204 
195 
     | 
    
         
             
                    model.track(video_url, imgsz=160, tracker=custom_yaml)
         
     | 
| 
       205 
196 
     | 
    
         | 
| 
         @@ -278,7 +269,7 @@ def test_predict_callback_and_setup(): 
     | 
|
| 
       278 
269 
     | 
    
         
             
                model.add_callback("on_predict_batch_end", on_predict_batch_end)
         
     | 
| 
       279 
270 
     | 
    
         | 
| 
       280 
271 
     | 
    
         
             
                dataset = load_inference_source(source=SOURCE)
         
     | 
| 
       281 
     | 
    
         
            -
                bs = dataset.bs  #  
     | 
| 
      
 272 
     | 
    
         
            +
                bs = dataset.bs  # access predictor properties
         
     | 
| 
       282 
273 
     | 
    
         
             
                results = model.predict(dataset, stream=True, imgsz=160)  # source already setup
         
     | 
| 
       283 
274 
     | 
    
         
             
                for r, im0, bs in results:
         
     | 
| 
       284 
275 
     | 
    
         
             
                    print("test_callback", im0.shape)
         
     | 
| 
         @@ -288,7 +279,7 @@ def test_predict_callback_and_setup(): 
     | 
|
| 
       288 
279 
     | 
    
         | 
| 
       289 
280 
     | 
    
         | 
| 
       290 
281 
     | 
    
         
             
            @pytest.mark.parametrize("model", MODELS)
         
     | 
| 
       291 
     | 
    
         
            -
            def test_results(model: str):
         
     | 
| 
      
 282 
     | 
    
         
            +
            def test_results(model: str, tmp_path):
         
     | 
| 
       292 
283 
     | 
    
         
             
                """Test YOLO model results processing and output in various formats."""
         
     | 
| 
       293 
284 
     | 
    
         
             
                im = f"{ASSETS_URL}/boats.jpg" if model == "yolo11n-obb.pt" else SOURCE
         
     | 
| 
       294 
285 
     | 
    
         
             
                results = YOLO(WEIGHTS_DIR / model)([im, im], imgsz=160)
         
     | 
| 
         @@ -297,12 +288,12 @@ def test_results(model: str): 
     | 
|
| 
       297 
288 
     | 
    
         
             
                    r = r.cpu().numpy()
         
     | 
| 
       298 
289 
     | 
    
         
             
                    print(r, len(r), r.path)  # print numpy attributes
         
     | 
| 
       299 
290 
     | 
    
         
             
                    r = r.to(device="cpu", dtype=torch.float32)
         
     | 
| 
       300 
     | 
    
         
            -
                    r.save_txt(txt_file= 
     | 
| 
       301 
     | 
    
         
            -
                    r.save_crop(save_dir= 
     | 
| 
      
 291 
     | 
    
         
            +
                    r.save_txt(txt_file=tmp_path / "runs/tests/label.txt", save_conf=True)
         
     | 
| 
      
 292 
     | 
    
         
            +
                    r.save_crop(save_dir=tmp_path / "runs/tests/crops/")
         
     | 
| 
       302 
293 
     | 
    
         
             
                    r.to_df(decimals=3)  # Align to_ methods: https://docs.ultralytics.com/modes/predict/#working-with-results
         
     | 
| 
       303 
294 
     | 
    
         
             
                    r.to_csv()
         
     | 
| 
       304 
295 
     | 
    
         
             
                    r.to_json(normalize=True)
         
     | 
| 
       305 
     | 
    
         
            -
                    r.plot(pil=True, save=True, filename= 
     | 
| 
      
 296 
     | 
    
         
            +
                    r.plot(pil=True, save=True, filename=tmp_path / "results_plot_save.jpg")
         
     | 
| 
       306 
297 
     | 
    
         
             
                    r.plot(conf=True, boxes=True)
         
     | 
| 
       307 
298 
     | 
    
         
             
                    print(r, len(r), r.path)  # print after methods
         
     | 
| 
       308 
299 
     | 
    
         | 
| 
         @@ -332,7 +323,7 @@ def test_labels_and_crops(): 
     | 
|
| 
       332 
323 
     | 
    
         | 
| 
       333 
324 
     | 
    
         | 
| 
       334 
325 
     | 
    
         
             
            @pytest.mark.skipif(not ONLINE, reason="environment is offline")
         
     | 
| 
       335 
     | 
    
         
            -
            def test_data_utils():
         
     | 
| 
      
 326 
     | 
    
         
            +
            def test_data_utils(tmp_path):
         
     | 
| 
       336 
327 
     | 
    
         
             
                """Test utility functions in ultralytics/data/utils.py, including dataset stats and auto-splitting."""
         
     | 
| 
       337 
328 
     | 
    
         
             
                from ultralytics.data.split import autosplit
         
     | 
| 
       338 
329 
     | 
    
         
             
                from ultralytics.data.utils import HUBDatasetStats
         
     | 
| 
         @@ -343,26 +334,28 @@ def test_data_utils(): 
     | 
|
| 
       343 
334 
     | 
    
         | 
| 
       344 
335 
     | 
    
         
             
                for task in TASKS:
         
     | 
| 
       345 
336 
     | 
    
         
             
                    file = Path(TASK2DATA[task]).with_suffix(".zip")  # i.e. coco8.zip
         
     | 
| 
       346 
     | 
    
         
            -
                    download(f"https://github.com/ultralytics/hub/raw/main/example_datasets/{file}", unzip=False, dir= 
     | 
| 
       347 
     | 
    
         
            -
                    stats = HUBDatasetStats( 
     | 
| 
      
 337 
     | 
    
         
            +
                    download(f"https://github.com/ultralytics/hub/raw/main/example_datasets/{file}", unzip=False, dir=tmp_path)
         
     | 
| 
      
 338 
     | 
    
         
            +
                    stats = HUBDatasetStats(tmp_path / file, task=task)
         
     | 
| 
       348 
339 
     | 
    
         
             
                    stats.get_json(save=True)
         
     | 
| 
       349 
340 
     | 
    
         
             
                    stats.process_images()
         
     | 
| 
       350 
341 
     | 
    
         | 
| 
       351 
     | 
    
         
            -
                autosplit( 
     | 
| 
       352 
     | 
    
         
            -
                zip_directory( 
     | 
| 
      
 342 
     | 
    
         
            +
                autosplit(tmp_path / "coco8")
         
     | 
| 
      
 343 
     | 
    
         
            +
                zip_directory(tmp_path / "coco8/images/val")  # zip
         
     | 
| 
       353 
344 
     | 
    
         | 
| 
       354 
345 
     | 
    
         | 
| 
       355 
346 
     | 
    
         
             
            @pytest.mark.skipif(not ONLINE, reason="environment is offline")
         
     | 
| 
       356 
     | 
    
         
            -
            def test_data_converter():
         
     | 
| 
      
 347 
     | 
    
         
            +
            def test_data_converter(tmp_path):
         
     | 
| 
       357 
348 
     | 
    
         
             
                """Test dataset conversion functions from COCO to YOLO format and class mappings."""
         
     | 
| 
       358 
349 
     | 
    
         
             
                from ultralytics.data.converter import coco80_to_coco91_class, convert_coco
         
     | 
| 
       359 
350 
     | 
    
         | 
| 
       360 
     | 
    
         
            -
                download(f"{ASSETS_URL}/instances_val2017.json", dir= 
     | 
| 
       361 
     | 
    
         
            -
                convert_coco( 
     | 
| 
      
 351 
     | 
    
         
            +
                download(f"{ASSETS_URL}/instances_val2017.json", dir=tmp_path)
         
     | 
| 
      
 352 
     | 
    
         
            +
                convert_coco(
         
     | 
| 
      
 353 
     | 
    
         
            +
                    labels_dir=tmp_path, save_dir=tmp_path / "yolo_labels", use_segments=True, use_keypoints=False, cls91to80=True
         
     | 
| 
      
 354 
     | 
    
         
            +
                )
         
     | 
| 
       362 
355 
     | 
    
         
             
                coco80_to_coco91_class()
         
     | 
| 
       363 
356 
     | 
    
         | 
| 
       364 
357 
     | 
    
         | 
| 
       365 
     | 
    
         
            -
            def test_data_annotator():
         
     | 
| 
      
 358 
     | 
    
         
            +
            def test_data_annotator(tmp_path):
         
     | 
| 
       366 
359 
     | 
    
         
             
                """Test automatic annotation of data using detection and segmentation models."""
         
     | 
| 
       367 
360 
     | 
    
         
             
                from ultralytics.data.annotator import auto_annotate
         
     | 
| 
       368 
361 
     | 
    
         | 
| 
         @@ -370,7 +363,7 @@ def test_data_annotator(): 
     | 
|
| 
       370 
363 
     | 
    
         
             
                    ASSETS,
         
     | 
| 
       371 
364 
     | 
    
         
             
                    det_model=WEIGHTS_DIR / "yolo11n.pt",
         
     | 
| 
       372 
365 
     | 
    
         
             
                    sam_model=WEIGHTS_DIR / "mobile_sam.pt",
         
     | 
| 
       373 
     | 
    
         
            -
                    output_dir= 
     | 
| 
      
 366 
     | 
    
         
            +
                    output_dir=tmp_path / "auto_annotate_labels",
         
     | 
| 
       374 
367 
     | 
    
         
             
                )
         
     | 
| 
       375 
368 
     | 
    
         | 
| 
       376 
369 
     | 
    
         | 
| 
         @@ -464,7 +457,7 @@ def test_utils_ops(): 
     | 
|
| 
       464 
457 
     | 
    
         
             
                torch.allclose(boxes, xyxyxyxy2xywhr(xywhr2xyxyxyxy(boxes)), rtol=1e-3)
         
     | 
| 
       465 
458 
     | 
    
         | 
| 
       466 
459 
     | 
    
         | 
| 
       467 
     | 
    
         
            -
            def test_utils_files():
         
     | 
| 
      
 460 
     | 
    
         
            +
            def test_utils_files(tmp_path):
         
     | 
| 
       468 
461 
     | 
    
         
             
                """Test file handling utilities including file age, date, and paths with spaces."""
         
     | 
| 
       469 
462 
     | 
    
         
             
                from ultralytics.utils.files import file_age, file_date, get_latest_run, spaces_in_path
         
     | 
| 
       470 
463 
     | 
    
         | 
| 
         @@ -472,14 +465,14 @@ def test_utils_files(): 
     | 
|
| 
       472 
465 
     | 
    
         
             
                file_date(SOURCE)
         
     | 
| 
       473 
466 
     | 
    
         
             
                get_latest_run(ROOT / "runs")
         
     | 
| 
       474 
467 
     | 
    
         | 
| 
       475 
     | 
    
         
            -
                path =  
     | 
| 
      
 468 
     | 
    
         
            +
                path = tmp_path / "path/with spaces"
         
     | 
| 
       476 
469 
     | 
    
         
             
                path.mkdir(parents=True, exist_ok=True)
         
     | 
| 
       477 
470 
     | 
    
         
             
                with spaces_in_path(path) as new_path:
         
     | 
| 
       478 
471 
     | 
    
         
             
                    print(new_path)
         
     | 
| 
       479 
472 
     | 
    
         | 
| 
       480 
473 
     | 
    
         | 
| 
       481 
474 
     | 
    
         
             
            @pytest.mark.slow
         
     | 
| 
       482 
     | 
    
         
            -
            def test_utils_patches_torch_save():
         
     | 
| 
      
 475 
     | 
    
         
            +
            def test_utils_patches_torch_save(tmp_path):
         
     | 
| 
       483 
476 
     | 
    
         
             
                """Test torch_save backoff when _torch_save raises RuntimeError."""
         
     | 
| 
       484 
477 
     | 
    
         
             
                from unittest.mock import MagicMock, patch
         
     | 
| 
       485 
478 
     | 
    
         | 
| 
         @@ -489,7 +482,7 @@ def test_utils_patches_torch_save(): 
     | 
|
| 
       489 
482 
     | 
    
         | 
| 
       490 
483 
     | 
    
         
             
                with patch("ultralytics.utils.patches._torch_save", new=mock):
         
     | 
| 
       491 
484 
     | 
    
         
             
                    with pytest.raises(RuntimeError):
         
     | 
| 
       492 
     | 
    
         
            -
                        torch_save(torch.zeros(1),  
     | 
| 
      
 485 
     | 
    
         
            +
                        torch_save(torch.zeros(1), tmp_path / "test.pt")
         
     | 
| 
       493 
486 
     | 
    
         | 
| 
       494 
487 
     | 
    
         
             
                assert mock.call_count == 4, "torch_save was not attempted the expected number of times"
         
     | 
| 
       495 
488 
     | 
    
         | 
| 
         @@ -722,11 +715,11 @@ def test_multichannel(): 
     | 
|
| 
       722 
715 
     | 
    
         | 
| 
       723 
716 
     | 
    
         | 
| 
       724 
717 
     | 
    
         
             
            @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
         
     | 
| 
       725 
     | 
    
         
            -
            def test_grayscale(task: str, model: str, data: str) -> None:
         
     | 
| 
      
 718 
     | 
    
         
            +
            def test_grayscale(task: str, model: str, data: str, tmp_path) -> None:
         
     | 
| 
       726 
719 
     | 
    
         
             
                """Test YOLO model grayscale training, validation, and prediction functionality."""
         
     | 
| 
       727 
720 
     | 
    
         
             
                if task == "classify":  # not support grayscale classification yet
         
     | 
| 
       728 
721 
     | 
    
         
             
                    return
         
     | 
| 
       729 
     | 
    
         
            -
                grayscale_data =  
     | 
| 
      
 722 
     | 
    
         
            +
                grayscale_data = tmp_path / f"{Path(data).stem}-grayscale.yaml"
         
     | 
| 
       730 
723 
     | 
    
         
             
                data = check_det_dataset(data)
         
     | 
| 
       731 
724 
     | 
    
         
             
                data["channels"] = 1  # add additional channels key for grayscale
         
     | 
| 
       732 
725 
     | 
    
         
             
                YAML.save(grayscale_data, data)
         
     |