dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/logger.py
CHANGED
@@ -1,7 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import logging
-import queue
 import shutil
 import sys
 import threading
@@ -9,77 +8,84 @@ import time
 from datetime import datetime
 from pathlib import Path
 
-from ultralytics.utils import MACOS, RANK
+from ultralytics.utils import LOGGER, MACOS, RANK
 from ultralytics.utils.checks import check_requirements
 
-# Initialize default log file
-DEFAULT_LOG_PATH = Path("train.log")
-if RANK in {-1, 0} and DEFAULT_LOG_PATH.exists():
-    DEFAULT_LOG_PATH.unlink(missing_ok=True)
-
 
 class ConsoleLogger:
-    """
-    Console output capture with API/file streaming and deduplication.
+    """Console output capture with batched streaming to file, API, or custom callback.
 
-    Captures stdout/stderr output and streams it
-    deduplication to reduce noise from repetitive console output.
+    Captures stdout/stderr output and streams it with intelligent deduplication and configurable batching.
 
     Attributes:
-        destination (str | Path): Target destination for streaming (URL or
-
-
-
-        log_queue (queue.Queue): Thread-safe queue for buffering log messages.
+        destination (str | Path | None): Target destination for streaming (URL, Path, or None for callback-only).
+        batch_size (int): Number of lines to batch before flushing (default: 1 for immediate).
+        flush_interval (float): Seconds between automatic flushes (default: 5.0).
+        on_flush (callable | None): Optional callback function called with batched content on flush.
         active (bool): Whether console capture is currently active.
-        worker_thread (threading.Thread): Background thread for processing log queue.
-        last_line (str): Last processed line for deduplication.
-        last_time (float): Timestamp of last processed line.
-        last_progress_line (str): Last progress bar line for progress deduplication.
-        last_was_progress (bool): Whether the last line was a progress bar.
 
     Examples:
-
+        File logging (immediate):
         >>> logger = ConsoleLogger("training.log")
         >>> logger.start_capture()
         >>> print("This will be logged")
         >>> logger.stop_capture()
 
-        API streaming:
-        >>> logger = ConsoleLogger("https://api.example.com/logs")
+        API streaming with batching:
+        >>> logger = ConsoleLogger("https://api.example.com/logs", batch_size=10)
+        >>> logger.start_capture()
+
+        Custom callback with batching:
+        >>> def my_handler(content, line_count, chunk_id):
+        ...     print(f"Received {line_count} lines")
+        >>> logger = ConsoleLogger(on_flush=my_handler, batch_size=5)
         >>> logger.start_capture()
-        >>> # All output streams to API
-        >>> logger.stop_capture()
     """
 
-    def __init__(self, destination):
-        """
-        Initialize with API endpoint or local file path.
+    def __init__(self, destination=None, batch_size=1, flush_interval=5.0, on_flush=None):
+        """Initialize console logger with optional batching.
 
         Args:
-            destination (str | Path): API endpoint URL (http/https)
+            destination (str | Path | None): API endpoint URL (http/https), local file path, or None.
+            batch_size (int): Lines to accumulate before flush (1 = immediate, higher = batched).
+            flush_interval (float): Max seconds between flushes when batching.
+            on_flush (callable | None): Callback(content: str, line_count: int, chunk_id: int) for custom handling.
         """
         self.destination = destination
         self.is_api = isinstance(destination, str) and destination.startswith(("http://", "https://"))
-        if not self.is_api:
+        if destination is not None and not self.is_api:
             self.destination = Path(destination)
 
-        #
+        # Batching configuration
+        self.batch_size = max(1, batch_size)
+        self.flush_interval = flush_interval
+        self.on_flush = on_flush
+
+        # Console capture state
         self.original_stdout = sys.stdout
         self.original_stderr = sys.stderr
-        self.log_queue = queue.Queue(maxsize=1000)
         self.active = False
-        self.
+        self._log_handler = None  # Track handler for cleanup
 
-        #
+        # Buffer for batching
+        self.buffer = []
+        self.buffer_lock = threading.Lock()
+        self.flush_thread = None
+        self.chunk_id = 0
+
+        # Deduplication state
         self.last_line = ""
         self.last_time = 0.0
-        self.last_progress_line = ""  # Track
+        self.last_progress_line = ""  # Track progress sequence key for deduplication
         self.last_was_progress = False  # Track if last line was a progress bar
 
     def start_capture(self):
-        """Start capturing console output and redirect stdout/stderr
-
+        """Start capturing console output and redirect stdout/stderr.
+
+        Notes:
+            In DDP training, only activates on rank 0/-1 to prevent duplicate logging.
+        """
+        if self.active or RANK not in {-1, 0}:
             return
 
         self.active = True
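Taken together, the new constructor supports three modes. A minimal sketch mirroring the docstring's own examples (the URL is the docstring's placeholder and `on_chunk` is a hypothetical handler, not part of the package):

```python
from ultralytics.utils.logger import ConsoleLogger

file_log = ConsoleLogger("training.log")  # batch_size=1: immediate line-by-line file writes
api_log = ConsoleLogger("https://api.example.com/logs", batch_size=10)  # batched JSON POSTs


def on_chunk(content: str, line_count: int, chunk_id: int) -> None:
    """Hypothetical handler: forward each flushed chunk to a database, queue, or websocket."""
    ...


cb_log = ConsoleLogger(on_flush=on_chunk, batch_size=5, flush_interval=5.0)  # callback-only, no destination
```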
@@ -88,23 +94,35 @@ class ConsoleLogger:
 
         # Hook Ultralytics logger
         try:
-
-            logging.getLogger("ultralytics").addHandler(
+            self._log_handler = self._LogHandler(self._queue_log)
+            logging.getLogger("ultralytics").addHandler(self._log_handler)
         except Exception:
             pass
 
-
-        self.
+        # Start background flush thread for batched mode
+        if self.batch_size > 1:
+            self.flush_thread = threading.Thread(target=self._flush_worker, daemon=True)
+            self.flush_thread.start()
 
     def stop_capture(self):
-        """Stop capturing console output and
+        """Stop capturing console output and flush remaining buffer."""
         if not self.active:
             return
 
         self.active = False
         sys.stdout = self.original_stdout
         sys.stderr = self.original_stderr
-
+
+        # Remove logging handler to prevent memory leak
+        if self._log_handler:
+            try:
+                logging.getLogger("ultralytics").removeHandler(self._log_handler)
+            except Exception:
+                pass
+            self._log_handler = None
+
+        # Final flush
+        self._flush_buffer()
 
     def _queue_log(self, text):
         """Queue console text with deduplication and timestamp processing."""
@@ -128,12 +146,34 @@
             if "─" in line:  # Has thin lines but no thick lines
                 continue
 
-            #
+            # Only show 100% completion lines for progress bars
             if " ━━" in line:
-
-
+                is_complete = "100%" in line
+
+                # Skip ALL non-complete progress lines
+                if not is_complete:
                     continue
-
+
+                # Extract sequence key to deduplicate multiple 100% lines for same sequence
+                parts = line.split()
+                seq_key = ""
+                if parts:
+                    # Check for epoch pattern (X/Y at start)
+                    if "/" in parts[0] and parts[0].replace("/", "").isdigit():
+                        seq_key = parts[0]  # e.g., "1/3"
+                    elif parts[0] == "Class" and len(parts) > 1:
+                        seq_key = f"{parts[0]}_{parts[1]}"  # e.g., "Class_train:" or "Class_val:"
+                    elif parts[0] in ("train:", "val:"):
+                        seq_key = parts[0]  # Phase identifier
+
+                # Skip if we already showed 100% for this sequence
+                if seq_key and self.last_progress_line == f"{seq_key}:done":
+                    continue
+
+                # Mark this sequence as done
+                if seq_key:
+                    self.last_progress_line = f"{seq_key}:done"
+
                 self.last_was_progress = True
             else:
                 # Skip empty line after progress bar
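The deduplication rule above keys completed progress bars by sequence. Its branch logic can be traced standalone; this sketch copies the key extraction from the hunk (the function name is hypothetical):

```python
def progress_seq_key(line: str) -> str:
    """Replicate the sequence-key extraction for 100% progress lines (logic copied from the diff above)."""
    parts = line.split()
    if not parts:
        return ""
    if "/" in parts[0] and parts[0].replace("/", "").isdigit():
        return parts[0]  # epoch pattern, e.g. "1/3"
    if parts[0] == "Class" and len(parts) > 1:
        return f"{parts[0]}_{parts[1]}"  # e.g. "Class_train:"
    if parts[0] in ("train:", "val:"):
        return parts[0]  # phase identifier
    return ""


print(progress_seq_key("1/3 4.52G 1.21 100% ━━━━━━"))  # "1/3": only one 100% line kept per epoch
print(progress_seq_key("val: 100% ━━━━━━"))  # "val:"
```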
@@ -154,63 +194,80 @@
             timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             line = f"[{timestamp}] {line}"
 
-            #
-
-
+            # Add to buffer and check if flush needed
+            should_flush = False
+            with self.buffer_lock:
+                self.buffer.append(line)
+                if len(self.buffer) >= self.batch_size:
+                    should_flush = True
 
-
-
-
-
-
-
-        try:
-            self.log_queue.get_nowait()  # Drop oldest
-            self.log_queue.put_nowait(item)
-            return True
-        except queue.Empty:
-            return False
-
-    def _stream_worker(self):
-        """Background worker for streaming logs to destination."""
+            # Flush outside lock to avoid deadlock
+            if should_flush:
+                self._flush_buffer()
+
+    def _flush_worker(self):
+        """Background worker that flushes buffer periodically."""
         while self.active:
+            time.sleep(self.flush_interval)
+            if self.active:
+                self._flush_buffer()
+
+    def _flush_buffer(self):
+        """Flush buffered lines to destination and/or callback."""
+        with self.buffer_lock:
+            if not self.buffer:
+                return
+            lines = self.buffer.copy()
+            self.buffer.clear()
+            self.chunk_id += 1
+            chunk_id = self.chunk_id  # Capture under lock to avoid race
+
+        content = "\n".join(lines)
+        line_count = len(lines)
+
+        # Call custom callback if provided
+        if self.on_flush:
             try:
-
-
-
-
-
-
+                self.on_flush(content, line_count, chunk_id)
+            except Exception:
+                pass  # Silently ignore callback errors to avoid flooding stderr
+
+        # Write to destination (file or API)
+        if self.destination is not None:
+            self._write_destination(content)
 
-    def
-        """Write
+    def _write_destination(self, content):
+        """Write content to file or API destination."""
         try:
             if self.is_api:
-                import requests
+                import requests
 
-                payload = {"timestamp": datetime.now().isoformat(), "message":
+                payload = {"timestamp": datetime.now().isoformat(), "message": content}
                 requests.post(str(self.destination), json=payload, timeout=5)
             else:
                 self.destination.parent.mkdir(parents=True, exist_ok=True)
                 with self.destination.open("a", encoding="utf-8") as f:
-                    f.write(
+                    f.write(content + "\n")
         except Exception as e:
-            print(f"
+            print(f"Console logger write error: {e}", file=self.original_stderr)
 
     class _ConsoleCapture:
         """Lightweight stdout/stderr capture."""
 
-        __slots__ = ("
+        __slots__ = ("callback", "original")
 
         def __init__(self, original, callback):
+            """Initialize a stream wrapper that redirects writes to a callback while preserving the original."""
            self.original = original
            self.callback = callback
 
        def write(self, text):
+            """Forward text to the wrapped original stream, preserving default stdout/stderr semantics."""
            self.original.write(text)
            self.callback(text)
 
        def flush(self):
+            """Flush the wrapped stream to propagate buffered output promptly during console capture."""
            self.original.flush()
 
     class _LogHandler(logging.Handler):
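One caveat follows from `_ConsoleCapture.write` above: `on_flush` fires while `sys.stdout` is still redirected, so a callback that calls `print()` would feed its own output back into the buffer. A sketch of a handler that avoids this by writing to the interpreter's original stream (an assumption about typical usage, not an API requirement):

```python
import sys

from ultralytics.utils.logger import ConsoleLogger


def safe_handler(content: str, line_count: int, chunk_id: int) -> None:
    # Write to the real stdout: a plain print() here would go through the
    # redirected sys.stdout and be re-captured on a later flush.
    sys.__stdout__.write(f"[chunk {chunk_id}] {line_count} lines\n")


logger = ConsoleLogger(on_flush=safe_handler, batch_size=8)
logger.start_capture()
for i in range(20):
    print(f"step {i}")  # buffered; flushed every 8 lines, plus a final flush on stop
logger.stop_capture()  # also removes the "ultralytics" logging handler
```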
@@ -219,19 +276,20 @@
         __slots__ = ("callback",)
 
         def __init__(self, callback):
+            """Initialize a lightweight logging.Handler that forwards log records to the provided callback."""
             super().__init__()
             self.callback = callback
 
         def emit(self, record):
+            """Format and forward LogRecord messages to the capture callback for unified log streaming."""
             self.callback(self.format(record) + "\n")
 
 
 class SystemLogger:
-    """
-    Log dynamic system metrics for training monitoring.
+    """Log dynamic system metrics for training monitoring.
 
-    Captures real-time system metrics including CPU, RAM, disk I/O, network I/O, and NVIDIA GPU statistics for
-
+    Captures real-time system metrics including CPU, RAM, disk I/O, network I/O, and NVIDIA GPU statistics for training
+    performance monitoring and analysis.
 
     Attributes:
         pynvml: NVIDIA pynvml module instance if successfully imported, None otherwise.
@@ -265,54 +323,71 @@
         self.net_start = psutil.net_io_counters()
         self.disk_start = psutil.disk_io_counters()
 
+        # For rate calculation
+        self._prev_net = self.net_start
+        self._prev_disk = self.disk_start
+        self._prev_time = time.time()
+
     def _init_nvidia(self):
         """Initialize NVIDIA GPU monitoring with pynvml."""
+        if MACOS:
+            return False
+
         try:
-            assert not MACOS
             check_requirements("nvidia-ml-py>=12.0.0")
             self.pynvml = __import__("pynvml")
             self.pynvml.nvmlInit()
             return True
-        except Exception:
+        except Exception as e:
+            import torch
+
+            if torch.cuda.is_available():
+                LOGGER.warning(f"SystemLogger NVML init failed: {e}")
             return False
 
-    def get_metrics(self):
-        """
-        Get current system metrics.
+    def get_metrics(self, rates=False):
+        """Get current system metrics including CPU, RAM, disk, network, and GPU usage.
 
-        Collects comprehensive system metrics including CPU usage, RAM usage, disk I/O statistics,
-
+        Collects comprehensive system metrics including CPU usage, RAM usage, disk I/O statistics, network I/O
+        statistics, and GPU metrics (if available).
 
+        Example output (rates=False, default):
         ```python
-
+        {
             "cpu": 45.2,
             "ram": 78.9,
             "disk": {"read_mb": 156.7, "write_mb": 89.3, "used_gb": 256.8},
             "network": {"recv_mb": 157.2, "sent_mb": 89.1},
             "gpus": {
-                0: {"usage": 95.6, "memory": 85.4, "temp": 72, "power": 285},
-                1: {"usage": 94.1, "memory": 82.7, "temp": 70, "power": 278},
+                "0": {"usage": 95.6, "memory": 85.4, "temp": 72, "power": 285},
+                "1": {"usage": 94.1, "memory": 82.7, "temp": 70, "power": 278},
+            },
+        }
+        ```
+
+        Example output (rates=True):
+        ```python
+        {
+            "cpu": 45.2,
+            "ram": 78.9,
+            "disk": {"read_mbs": 12.5, "write_mbs": 8.3, "used_gb": 256.8},
+            "network": {"recv_mbs": 5.2, "sent_mbs": 1.1},
+            "gpus": {
+                "0": {"usage": 95.6, "memory": 85.4, "temp": 72, "power": 285},
             },
         }
         ```
 
-
-
-            - disk (dict):
-                - read_mb (float): Cumulative disk read in MB since initialization
-                - write_mb (float): Cumulative disk write in MB since initialization
-                - used_gb (float): Total disk space used in GB
-            - network (dict):
-                - recv_mb (float): Cumulative network received in MB since initialization
-                - sent_mb (float): Cumulative network sent in MB since initialization
-            - gpus (dict): GPU metrics by device index (e.g., 0, 1) containing:
-                - usage (int): GPU utilization percentage (0-100%)
-                - memory (float): CUDA memory usage percentage (0-100%)
-                - temp (int): GPU temperature in degrees Celsius
-                - power (int): GPU power consumption in watts
+        Args:
+            rates (bool): If True, return disk/network as MB/s rates instead of cumulative MB.
 
         Returns:
-
+            (dict): Metrics dictionary with cpu, ram, disk, network, and gpus keys.
+
+        Examples:
+            >>> logger = SystemLogger()
+            >>> logger.get_metrics()["cpu"]  # CPU percentage
+            >>> logger.get_metrics(rates=True)["network"]["recv_mbs"]  # MB/s download rate
         """
         import psutil  # scoped as slow import
 
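Mirroring the Examples block just added, a quick smoke test of the `rates` flag (key names as documented in the two example outputs above):

```python
from ultralytics.utils.logger import SystemLogger

sys_log = SystemLogger()
totals = sys_log.get_metrics()  # cumulative since init: disk["read_mb"], network["recv_mb"]
per_sec = sys_log.get_metrics(rates=True)  # since previous call: disk["read_mbs"], network["recv_mbs"]
print(totals["cpu"], totals["network"]["recv_mb"], per_sec["network"]["recv_mbs"])
```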
@@ -320,21 +395,44 @@
         disk = psutil.disk_io_counters()
         memory = psutil.virtual_memory()
         disk_usage = shutil.disk_usage("/")
+        now = time.time()
 
         metrics = {
             "cpu": round(psutil.cpu_percent(), 3),
             "ram": round(memory.percent, 3),
-            "
+            "gpus": {},
+        }
+
+        # Calculate elapsed time since last call
+        elapsed = max(0.1, now - self._prev_time)  # Avoid division by zero
+
+        if rates:
+            # Calculate MB/s rates from delta since last call
+            metrics["disk"] = {
+                "read_mbs": round(max(0, (disk.read_bytes - self._prev_disk.read_bytes) / (1 << 20) / elapsed), 3),
+                "write_mbs": round(max(0, (disk.write_bytes - self._prev_disk.write_bytes) / (1 << 20) / elapsed), 3),
+                "used_gb": round(disk_usage.used / (1 << 30), 3),
+            }
+            metrics["network"] = {
+                "recv_mbs": round(max(0, (net.bytes_recv - self._prev_net.bytes_recv) / (1 << 20) / elapsed), 3),
+                "sent_mbs": round(max(0, (net.bytes_sent - self._prev_net.bytes_sent) / (1 << 20) / elapsed), 3),
+            }
+        else:
+            # Cumulative MB since initialization (original behavior)
+            metrics["disk"] = {
                 "read_mb": round((disk.read_bytes - self.disk_start.read_bytes) / (1 << 20), 3),
                 "write_mb": round((disk.write_bytes - self.disk_start.write_bytes) / (1 << 20), 3),
                 "used_gb": round(disk_usage.used / (1 << 30), 3),
-            }
-            "network"
+            }
+            metrics["network"] = {
                 "recv_mb": round((net.bytes_recv - self.net_start.bytes_recv) / (1 << 20), 3),
                 "sent_mb": round((net.bytes_sent - self.net_start.bytes_sent) / (1 << 20), 3),
-            }
-
-
+            }
+
+        # Always update previous values for accurate rate calculation on next call
+        self._prev_net = net
+        self._prev_disk = disk
+        self._prev_time = now
 
         # Add GPU metrics (NVIDIA only)
         if self.nvidia_initialized:
ultralytics/utils/loss.py
CHANGED
@@ -18,8 +18,7 @@ from .tal import bbox2dist
 
 
 class VarifocalLoss(nn.Module):
-    """
-    Varifocal loss by Zhang et al.
+    """Varifocal loss by Zhang et al.
 
     Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
     hard-to-classify examples and balancing positive/negative samples.
@@ -51,11 +50,10 @@ class VarifocalLoss(nn.Module):
 
 
 class FocalLoss(nn.Module):
-    """
-    Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
+    """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
 
-    Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
-
+    Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing on
+    hard negatives during training.
 
     Attributes:
         gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
@@ -399,8 +397,7 @@ class v8SegmentationLoss(v8DetectionLoss):
     def single_mask_loss(
         gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
     ) -> torch.Tensor:
-        """
-        Compute the instance segmentation loss for a single image.
+        """Compute the instance segmentation loss for a single image.
 
         Args:
             gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
@@ -432,8 +429,7 @@ class v8SegmentationLoss(v8DetectionLoss):
         imgsz: torch.Tensor,
         overlap: bool,
     ) -> torch.Tensor:
-        """
-        Calculate the loss for instance segmentation.
+        """Calculate the loss for instance segmentation.
 
         Args:
             fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
@@ -502,7 +498,7 @@ class v8PoseLoss(v8DetectionLoss):
 
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the total loss and detach it for pose estimation."""
-        loss = torch.zeros(5, device=self.device)  # box,
+        loss = torch.zeros(5, device=self.device)  # box, pose, kobj, cls, dfl
         feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
             (self.reg_max * 4, self.nc), 1
@@ -564,7 +560,7 @@ class v8PoseLoss(v8DetectionLoss):
         loss[3] *= self.hyp.cls  # cls gain
         loss[4] *= self.hyp.dfl  # dfl gain
 
-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
 
     @staticmethod
     def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
@@ -585,8 +581,7 @@ class v8PoseLoss(v8DetectionLoss):
         target_bboxes: torch.Tensor,
         pred_kpts: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Calculate the keypoints loss for the model.
+        """Calculate the keypoints loss for the model.
 
         This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
         based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
@@ -689,7 +684,7 @@ class v8OBBLoss(v8DetectionLoss):
         """Calculate and return the loss for oriented bounding box detection."""
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
         feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
-        batch_size = pred_angle.shape[0]  # batch size
+        batch_size = pred_angle.shape[0]  # batch size
         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
             (self.reg_max * 4, self.nc), 1
         )
@@ -707,7 +702,7 @@ class v8OBBLoss(v8DetectionLoss):
         try:
             batch_idx = batch["batch_idx"].view(-1, 1)
             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
-            rw, rh = targets[:, 4] * imgsz[
+            rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
             targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
             targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
             gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
@@ -760,8 +755,7 @@ class v8OBBLoss(v8DetectionLoss):
     def bbox_decode(
         self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
     ) -> torch.Tensor:
-        """
-        Decode predicted object bounding box coordinates from anchor points and distribution.
+        """Decode predicted object bounding box coordinates from anchor points and distribution.
 
         Args:
             anchor_points (torch.Tensor): Anchor points, (h*w, 2).
@@ -809,7 +803,6 @@ class TVPDetectLoss:
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt detection."""
         feats = preds[1] if isinstance(preds, tuple) else preds
-        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
         if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
             loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
@@ -817,8 +810,8 @@ class TVPDetectLoss:
 
         vp_feats = self._get_vp_features(feats)
         vp_loss = self.vp_criterion(vp_feats, batch)
-
-        return
+        cls_loss = vp_loss[0][1]
+        return cls_loss, vp_loss[1]
 
     def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
         """Extract visual-prompt features from the model output."""
@@ -845,7 +838,6 @@ class TVPSegmentLoss(TVPDetectLoss):
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt segmentation."""
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
         if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
             loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)