ultralytics 8.2.81__py3-none-any.whl → 8.2.82__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ultralytics might be problematic.
- tests/test_solutions.py +0 -4
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +14 -16
- ultralytics/data/annotator.py +1 -1
- ultralytics/data/augment.py +58 -58
- ultralytics/data/base.py +3 -3
- ultralytics/data/converter.py +7 -8
- ultralytics/data/explorer/explorer.py +7 -23
- ultralytics/data/loaders.py +1 -1
- ultralytics/data/split_dota.py +11 -3
- ultralytics/data/utils.py +6 -10
- ultralytics/engine/exporter.py +2 -4
- ultralytics/engine/model.py +47 -47
- ultralytics/engine/predictor.py +1 -1
- ultralytics/engine/results.py +28 -28
- ultralytics/engine/trainer.py +11 -8
- ultralytics/engine/tuner.py +7 -8
- ultralytics/engine/validator.py +3 -5
- ultralytics/hub/__init__.py +5 -5
- ultralytics/hub/auth.py +6 -2
- ultralytics/hub/session.py +3 -5
- ultralytics/models/fastsam/model.py +13 -10
- ultralytics/models/fastsam/predict.py +2 -2
- ultralytics/models/fastsam/utils.py +0 -1
- ultralytics/models/nas/model.py +4 -4
- ultralytics/models/nas/predict.py +1 -2
- ultralytics/models/nas/val.py +1 -1
- ultralytics/models/rtdetr/predict.py +1 -1
- ultralytics/models/rtdetr/train.py +1 -1
- ultralytics/models/rtdetr/val.py +1 -1
- ultralytics/models/sam/model.py +11 -11
- ultralytics/models/sam/modules/decoders.py +7 -4
- ultralytics/models/sam/modules/sam.py +9 -1
- ultralytics/models/sam/modules/tiny_encoder.py +1 -1
- ultralytics/models/sam/modules/transformer.py +0 -2
- ultralytics/models/sam/modules/utils.py +1 -1
- ultralytics/models/sam/predict.py +10 -10
- ultralytics/models/utils/loss.py +29 -17
- ultralytics/models/utils/ops.py +1 -5
- ultralytics/models/yolo/classify/predict.py +1 -1
- ultralytics/models/yolo/classify/train.py +1 -1
- ultralytics/models/yolo/classify/val.py +1 -1
- ultralytics/models/yolo/detect/predict.py +1 -1
- ultralytics/models/yolo/detect/train.py +1 -1
- ultralytics/models/yolo/detect/val.py +1 -1
- ultralytics/models/yolo/model.py +6 -2
- ultralytics/models/yolo/obb/predict.py +1 -1
- ultralytics/models/yolo/obb/train.py +1 -1
- ultralytics/models/yolo/obb/val.py +2 -2
- ultralytics/models/yolo/pose/predict.py +1 -1
- ultralytics/models/yolo/pose/train.py +1 -1
- ultralytics/models/yolo/pose/val.py +1 -1
- ultralytics/models/yolo/segment/predict.py +1 -1
- ultralytics/models/yolo/segment/train.py +1 -1
- ultralytics/models/yolo/segment/val.py +1 -1
- ultralytics/models/yolo/world/train.py +1 -1
- ultralytics/nn/autobackend.py +2 -2
- ultralytics/nn/modules/__init__.py +2 -2
- ultralytics/nn/modules/block.py +8 -20
- ultralytics/nn/modules/conv.py +1 -3
- ultralytics/nn/modules/head.py +16 -31
- ultralytics/nn/modules/transformer.py +0 -1
- ultralytics/nn/modules/utils.py +0 -1
- ultralytics/nn/tasks.py +11 -9
- ultralytics/solutions/__init__.py +1 -0
- ultralytics/solutions/ai_gym.py +0 -2
- ultralytics/solutions/analytics.py +1 -6
- ultralytics/solutions/heatmap.py +0 -1
- ultralytics/solutions/object_counter.py +0 -2
- ultralytics/solutions/queue_management.py +0 -2
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/byte_tracker.py +2 -2
- ultralytics/trackers/utils/gmc.py +5 -5
- ultralytics/trackers/utils/kalman_filter.py +1 -1
- ultralytics/trackers/utils/matching.py +1 -5
- ultralytics/utils/__init__.py +122 -23
- ultralytics/utils/autobatch.py +7 -4
- ultralytics/utils/benchmarks.py +6 -14
- ultralytics/utils/callbacks/base.py +0 -1
- ultralytics/utils/callbacks/comet.py +0 -1
- ultralytics/utils/callbacks/tensorboard.py +0 -1
- ultralytics/utils/checks.py +15 -18
- ultralytics/utils/downloads.py +6 -7
- ultralytics/utils/files.py +3 -4
- ultralytics/utils/instance.py +17 -7
- ultralytics/utils/metrics.py +15 -15
- ultralytics/utils/ops.py +8 -8
- ultralytics/utils/plotting.py +25 -35
- ultralytics/utils/tal.py +27 -18
- ultralytics/utils/torch_utils.py +12 -13
- ultralytics/utils/tuner.py +2 -3
- {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/METADATA +1 -1
- {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/RECORD +97 -97
- {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/LICENSE +0 -0
- {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/WHEEL +0 -0
- {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
@@ -56,8 +56,6 @@ from ultralytics.utils.torch_utils import (
 
 class BaseTrainer:
     """
-    BaseTrainer.
-
     A base class for creating trainers.
 
     Attributes:
@@ -230,7 +228,6 @@ class BaseTrainer:
 
     def _setup_train(self, world_size):
         """Builds dataloaders and optimizer on correct rank process."""
-
         # Model
         self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
@@ -478,12 +475,16 @@ class BaseTrainer:
             torch.cuda.empty_cache()
         self.run_callbacks("teardown")
 
+    def read_results_csv(self):
+        """Read results.csv into a dict using pandas."""
+        import pandas as pd  # scope for faster 'import ultralytics'
+
+        return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
+
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
         import io
 
-        import pandas as pd  # scope for faster 'import ultralytics'
-
         # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
         buffer = io.BytesIO()
         torch.save(
@@ -496,7 +497,7 @@ class BaseTrainer:
                 "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
                 "train_args": vars(self.args),  # save as dict
                 "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
-                "train_results":
+                "train_results": self.read_results_csv(),
                 "date": datetime.now().isoformat(),
                 "version": __version__,
                 "license": "AGPL-3.0 (https://ultralytics.com/license)",
@@ -636,7 +637,7 @@ class BaseTrainer:
         pass
 
     def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)"""
+        """Registers plots (e.g. to be consumed in callbacks)."""
         path = Path(name)
         self.plots[path] = {"data": data, "timestamp": time.time()}
 
@@ -646,6 +647,9 @@ class BaseTrainer:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
                 if f is self.best:
+                    if self.last.is_file():  # update best.pt train_metrics from last.pt
+                        k = "train_results"
+                        torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
                     LOGGER.info(f"\nValidating {f}...")
                     self.validator.args.plots = self.args.plots
                     self.metrics = self.validator(model=f)
@@ -732,7 +736,6 @@ class BaseTrainer:
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
         """
-
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
         if name == "auto":
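The trainer change above factors the results.csv parsing out of save_model() into a reusable read_results_csv(), which the new best.pt update in final_eval() also depends on. A minimal standalone sketch of what that dict-building expression does, using an in-memory CSV with padded headers (the sample data is hypothetical):

```python
import io

import pandas as pd

# results.csv as YOLO writes it: fixed-width columns, so headers carry padding.
csv_text = "                  epoch,             train/box_loss\n                      0,                    1.25\n"

# The same expression read_results_csv() uses: strip whitespace from column
# names and return each column as a plain list.
results = {k.strip(): v for k, v in pd.read_csv(io.StringIO(csv_text)).to_dict(orient="list").items()}
print(results)  # {'epoch': [0], 'train/box_loss': [1.25]}
```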
ultralytics/engine/tuner.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 """
-
-
+Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance
+segmentation, image classification, pose estimation, and multi-object tracking.
 
 Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
 that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
@@ -12,8 +12,8 @@ Example:
     ```python
     from ultralytics import YOLO
 
-    model = YOLO(
-    model.tune(data=
+    model = YOLO("yolov8n.pt")
+    model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
     ```
 """
 
@@ -54,15 +54,15 @@ class Tuner:
        ```python
        from ultralytics import YOLO
 
-       model = YOLO(
-       model.tune(data=
+       model = YOLO("yolov8n.pt")
+       model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
        ```
 
        Tune with custom search space.
        ```python
        from ultralytics import YOLO
 
-       model = YOLO(
+       model = YOLO("yolov8n.pt")
        model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
        ```
    """
@@ -176,7 +176,6 @@ class Tuner:
        The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
        Ensure this path is set correctly in the Tuner instance.
        """
-
        t0 = time.time()
        best_save_dir, best_metrics = None, None
        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
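The tune() example completed above runs as written; a trimmed variant for a quick local smoke test (identical arguments to the docstring, with smaller epoch and iteration counts):

```python
from ultralytics import YOLO

# Same call as the updated docstring example, shortened so it finishes quickly.
model = YOLO("yolov8n.pt")
model.tune(data="coco8.yaml", epochs=1, iterations=2, optimizer="AdamW", plots=False, save=False, val=False)
```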
ultralytics/engine/validator.py
CHANGED
@@ -104,9 +104,7 @@ class BaseValidator:
 
     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
-        """
-        gets priority).
-        """
+        """Executes validation process, running inference on dataloader and computing performance metrics."""
         self.training = trainer is not None
         augment = self.args.augment and (not self.training)
         if self.training:
@@ -280,7 +278,7 @@ class BaseValidator:
         return batch
 
     def postprocess(self, preds):
-        """
+        """Preprocesses the predictions."""
         return preds
 
     def init_metrics(self, model):
@@ -317,7 +315,7 @@ class BaseValidator:
         return []
 
     def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)"""
+        """Registers plots (e.g. to be consumed in callbacks)."""
         self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
 
     # TODO: may need to put these following functions into callback
ultralytics/hub/__init__.py
CHANGED
@@ -136,11 +136,11 @@ def check_dataset(path: str, task: str) -> None:
         ```python
         from ultralytics.hub import check_dataset
 
-        check_dataset(
-        check_dataset(
-        check_dataset(
-        check_dataset(
-        check_dataset(
+        check_dataset("path/to/coco8.zip", task="detect")  # detect dataset
+        check_dataset("path/to/coco8-seg.zip", task="segment")  # segment dataset
+        check_dataset("path/to/coco8-pose.zip", task="pose")  # pose dataset
+        check_dataset("path/to/dota8.zip", task="obb")  # OBB dataset
+        check_dataset("path/to/imagenet10.zip", task="classify")  # classification dataset
         ```
     """
     HUBDatasetStats(path=path, task=task).get_json()
ultralytics/hub/auth.py
CHANGED
@@ -27,10 +27,14 @@ class Auth:
 
     def __init__(self, api_key="", verbose=False):
         """
-        Initialize
+        Initialize Auth class and authenticate user.
+
+        Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful
+        authentication.
 
         Args:
-            api_key (str
+            api_key (str): API key or combined key_id format.
+            verbose (bool): Enable verbose logging.
         """
         # Split the input API key in case it contains a combined key_model and keep only the API key part
         api_key = api_key.split("_")[0]
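The comment on the last context line describes the combined-key handling; the split itself is just this (the key value below is hypothetical):

```python
# A combined "key_model" credential keeps only the part before the first
# underscore, matching the split in Auth.__init__ (hypothetical key value).
api_key = "ab12cd34_exampleModelID"
api_key = api_key.split("_")[0]
print(api_key)  # ab12cd34
```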
ultralytics/hub/session.py
CHANGED
@@ -159,7 +159,6 @@ class HUBTrainingSession:
         Raises:
             HUBModelError: If the identifier format is not recognized.
         """
-
         # Initialize variables
         api_key, model_id, filename = None, None, None
 
@@ -200,7 +199,6 @@ class HUBTrainingSession:
             ValueError: If the model is already trained, if required dataset information is missing, or if there are
                 issues with the provided training arguments.
         """
-
         if self.model.is_resumable():
             # Model has saved weights
             self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
@@ -276,7 +274,7 @@ class HUBTrainingSession:
 
             # if request related to metrics upload and exceed retries
             if response is None and kwargs.get("metrics"):
-                self.metrics_upload_failed_queue.update(kwargs.get("metrics"
+                self.metrics_upload_failed_queue.update(kwargs.get("metrics"))
 
             return response
 
@@ -350,10 +348,10 @@ class HUBTrainingSession:
         last = weights.with_name("last" + weights.suffix)
         if final and last.is_file():
             LOGGER.warning(
-                f"{PREFIX}
+                f"{PREFIX} WARNING ⚠️ Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. "
                 "This often happens when resuming training in transient environments like Google Colab. "
                 "For more reliable training, consider using Ultralytics HUB Cloud. "
-                "Learn more at https://docs.ultralytics.com/hub/cloud-training
+                "Learn more at https://docs.ultralytics.com/hub/cloud-training."
             )
             shutil.copy(last, weights)  # copy last.pt to best.pt
         else:
ultralytics/models/fastsam/model.py
CHANGED
@@ -16,8 +16,8 @@ class FastSAM(Model):
     ```python
     from ultralytics import FastSAM
 
-    model = FastSAM(
-    results = model.predict(
+    model = FastSAM("last.pt")
+    results = model.predict("ultralytics/assets/bus.jpg")
     ```
     """
 
@@ -30,18 +30,21 @@ class FastSAM(Model):
 
     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
         """
-
+        Perform segmentation prediction on image or video source.
+
+        Supports prompted segmentation with bounding boxes, points, labels, and texts.
 
         Args:
-            source (str
-            stream (bool
-            bboxes (list
-            points (list
-            labels (list
-            texts (list
+            source (str | PIL.Image | numpy.ndarray): Input source.
+            stream (bool): Enable real-time streaming.
+            bboxes (list): Bounding box coordinates for prompted segmentation.
+            points (list): Points for prompted segmentation.
+            labels (list): Labels for prompted segmentation.
+            texts (list): Texts for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments.
 
         Returns:
-            (list):
+            (list): Model predictions.
         """
         prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
         return super().predict(source, stream, prompts=prompts, **kwargs)
ultralytics/models/fastsam/predict.py
CHANGED
@@ -92,8 +92,8 @@ class FastSAMPredictor(SegmentationPredictor):
                     if labels.sum() == 0  # all negative points
                     else torch.zeros(len(result), dtype=torch.bool, device=self.device)
                 )
-                for
-                point_idx[torch.nonzero(masks[:,
+                for point, label in zip(points, labels):
+                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = True if label else False
                 idx |= point_idx
             if texts is not None:
                 if isinstance(texts, str):
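The rewritten loop above indexes the mask stack at each prompt point to decide which candidate masks that point supports. A standalone sketch of the indexing pattern with toy tensors (shapes and values are illustrative only):

```python
import torch

# Three hypothetical binary masks of shape (N, H, W) and one positive point (x=2, y=1).
masks = torch.zeros(3, 4, 4, dtype=torch.bool)
masks[0, 1, 2] = True  # only mask 0 covers the point
point, label = (2, 1), 1

point_idx = torch.zeros(len(masks), dtype=torch.bool)
# masks[:, y, x] is a length-N vector of "does mask i contain the point";
# nonzero(...)[0] yields those mask indices, set True for a positive label.
point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = True if label else False
print(point_idx)  # tensor([ True, False, False])
```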
ultralytics/models/nas/model.py
CHANGED
@@ -6,8 +6,8 @@ Example:
     ```python
     from ultralytics import NAS
 
-    model = NAS(
-    results = model.predict(
+    model = NAS("yolo_nas_s")
+    results = model.predict("ultralytics/assets/bus.jpg")
     ```
 """
 
@@ -34,8 +34,8 @@ class NAS(Model):
         ```python
         from ultralytics import NAS
 
-        model = NAS(
-        results = model.predict(
+        model = NAS("yolo_nas_s")
+        results = model.predict("ultralytics/assets/bus.jpg")
         ```
 
     Attributes:
ultralytics/models/nas/predict.py
CHANGED
@@ -22,7 +22,7 @@ class NASPredictor(BasePredictor):
         ```python
         from ultralytics import NAS
 
-        model = NAS(
+        model = NAS("yolo_nas_s")
         predictor = model.predictor
         # Assumes that raw_preds, img, orig_imgs are available
         results = predictor.postprocess(raw_preds, img, orig_imgs)
@@ -34,7 +34,6 @@ class NASPredictor(BasePredictor):
 
     def postprocess(self, preds_in, img, orig_imgs):
         """Postprocess predictions and returns a list of Results objects."""
-
         # Cat boxes and class scores
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
ultralytics/models/nas/val.py
CHANGED
@@ -24,7 +24,7 @@ class NASValidator(DetectionValidator):
         ```python
         from ultralytics import NAS
 
-        model = NAS(
+        model = NAS("yolo_nas_s")
         validator = model.validator
         # Assumes that raw_preds are available
         final_preds = validator.postprocess(raw_preds)
ultralytics/models/rtdetr/predict.py
CHANGED
@@ -21,7 +21,7 @@ class RTDETRPredictor(BasePredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.rtdetr import RTDETRPredictor
 
-        args = dict(model=
+        args = dict(model="rtdetr-l.pt", source=ASSETS)
         predictor = RTDETRPredictor(overrides=args)
         predictor.predict_cli()
         ```
ultralytics/models/rtdetr/train.py
CHANGED
@@ -25,7 +25,7 @@ class RTDETRTrainer(DetectionTrainer):
         ```python
         from ultralytics.models.rtdetr.train import RTDETRTrainer
 
-        args = dict(model=
+        args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
         trainer = RTDETRTrainer(overrides=args)
         trainer.train()
         ```
ultralytics/models/rtdetr/val.py
CHANGED
@@ -62,7 +62,7 @@ class RTDETRValidator(DetectionValidator):
         ```python
         from ultralytics.models.rtdetr import RTDETRValidator
 
-        args = dict(model=
+        args = dict(model="rtdetr-l.pt", data="coco8.yaml")
         validator = RTDETRValidator(args=args)
         validator()
         ```
ultralytics/models/sam/model.py
CHANGED
@@ -41,8 +41,8 @@ class SAM(Model):
         info: Logs information about the SAM model.
 
     Examples:
-        >>> sam = SAM(
-        >>> results = sam.predict(
+        >>> sam = SAM("sam_b.pt")
+        >>> results = sam.predict("image.jpg", points=[[500, 375]])
         >>> for r in results:
         >>>     print(f"Detected {len(r.masks)} masks")
     """
@@ -58,7 +58,7 @@ class SAM(Model):
            NotImplementedError: If the model file extension is not .pt or .pth.
 
        Examples:
-            >>> sam = SAM(
+            >>> sam = SAM("sam_b.pt")
            >>> print(sam.is_sam2)
        """
        if model and Path(model).suffix not in {".pt", ".pth"}:
@@ -78,8 +78,8 @@ class SAM(Model):
            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
 
        Examples:
-            >>> sam = SAM(
-            >>> sam._load(
+            >>> sam = SAM("sam_b.pt")
+            >>> sam._load("path/to/custom_weights.pt")
        """
        self.model = build_sam(weights)
 
@@ -100,8 +100,8 @@ class SAM(Model):
            (List): The model predictions.
 
        Examples:
-            >>> sam = SAM(
-            >>> results = sam.predict(
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam.predict("image.jpg", points=[[500, 375]])
            >>> for r in results:
            ...     print(f"Detected {len(r.masks)} masks")
        """
@@ -130,8 +130,8 @@ class SAM(Model):
            (List): The model predictions, typically containing segmentation masks and other relevant information.
 
        Examples:
-            >>> sam = SAM(
-            >>> results = sam(
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam("image.jpg", points=[[500, 375]])
            >>> print(f"Detected {len(results[0].masks)} masks")
        """
        return self.predict(source, stream, bboxes, points, labels, **kwargs)
@@ -151,7 +151,7 @@ class SAM(Model):
            (Tuple): A tuple containing the model's information (string representations of the model).
 
        Examples:
-            >>> sam = SAM(
+            >>> sam = SAM("sam_b.pt")
            >>> info = sam.info()
            >>> print(info[0])  # Print summary information
        """
@@ -167,7 +167,7 @@ class SAM(Model):
            class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
 
        Examples:
-            >>> sam = SAM(
+            >>> sam = SAM("sam_b.pt")
            >>> task_map = sam.task_map
            >>> print(task_map)
            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
ultralytics/models/sam/modules/decoders.py
CHANGED
@@ -32,8 +32,9 @@ class MaskDecoder(nn.Module):
 
     Examples:
         >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
-        >>> masks, iou_pred = decoder(
-        ...
+        >>> masks, iou_pred = decoder(
+        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
+        ... )
         >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
     """
 
@@ -213,7 +214,8 @@ class SAM2MaskDecoder(nn.Module):
         >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
         >>> decoder = SAM2MaskDecoder(256, transformer)
         >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
-        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+        ... )
         """
 
     def __init__(
@@ -345,7 +347,8 @@ class SAM2MaskDecoder(nn.Module):
            >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
            >>> decoder = SAM2MaskDecoder(256, transformer)
            >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
-            ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+            ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+            ... )
        """
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
ultralytics/models/sam/modules/sam.py
CHANGED
@@ -417,7 +417,15 @@ class SAM2Model(torch.nn.Module):
            >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
            >>> mask_inputs = torch.rand(1, 1, 512, 512)
            >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
-            >>>
+            >>> (
+            ...     low_res_multimasks,
+            ...     high_res_multimasks,
+            ...     ious,
+            ...     low_res_masks,
+            ...     high_res_masks,
+            ...     obj_ptr,
+            ...     object_score_logits,
+            ... ) = results
        """
        B = backbone_features.size(0)
        device = backbone_features.device
ultralytics/models/sam/modules/tiny_encoder.py
CHANGED
@@ -716,7 +716,7 @@ class BasicLayer(nn.Module):
 
     Examples:
         >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
-        >>> x = torch.randn(1, 56*56, 96)
+        >>> x = torch.randn(1, 56 * 56, 96)
         >>> output = layer(x)
         >>> print(output.shape)
     """
ultralytics/models/sam/modules/transformer.py
CHANGED
@@ -232,7 +232,6 @@ class TwoWayAttentionBlock(nn.Module):
 
     def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
         """Applies two-way attention to process query and key embeddings in a transformer block."""
-
         # Self attention block
         if self.skip_first_layer_pe:
             queries = self.self_attn(q=queries, k=queries, v=queries)
@@ -353,7 +352,6 @@ class Attention(nn.Module):
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
-
         # Input projections
         q = self.q_proj(q)
         k = self.k_proj(k)
ultralytics/models/sam/modules/utils.py
CHANGED
@@ -22,7 +22,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
 
     Examples:
         >>> frame_idx = 5
-        >>> cond_frame_outputs = {1:
+        >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
         >>> max_cond_frame_num = 2
         >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
         >>> print(selected)
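The doctest above omits its output; here is a sketch of the selection rule the function name suggests (keep the closest conditioning frame on either side of frame_idx, then fill remaining slots by distance) — an illustration of the expected behavior, not the library's actual implementation:

```python
def closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """Pick the temporally closest conditioning frames around frame_idx (sketch)."""
    selected = {}
    before = [i for i in cond_frame_outputs if i < frame_idx]
    after = [i for i in cond_frame_outputs if i >= frame_idx]
    if before:
        selected[max(before)] = cond_frame_outputs[max(before)]  # nearest earlier frame
    if after:
        selected[min(after)] = cond_frame_outputs[min(after)]  # nearest later frame
    # Fill any remaining slots with the nearest leftover frames.
    leftovers = sorted((i for i in cond_frame_outputs if i not in selected), key=lambda i: abs(i - frame_idx))
    for i in leftovers[: max(0, max_cond_frame_num - len(selected))]:
        selected[i] = cond_frame_outputs[i]
    unselected = {i: v for i, v in cond_frame_outputs.items() if i not in selected}
    return selected, unselected


print(closest_cond_frames(5, {1: "a", 3: "b", 7: "c", 9: "d"}, 2))
# ({3: 'b', 7: 'c'}, {1: 'a', 9: 'd'})
```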
ultralytics/models/sam/predict.py
CHANGED
@@ -69,8 +69,8 @@ class Predictor(BasePredictor):
 
     Examples:
         >>> predictor = Predictor()
-        >>> predictor.setup_model(model_path=
-        >>> predictor.set_image(
+        >>> predictor.setup_model(model_path="sam_model.pt")
+        >>> predictor.set_image("image.jpg")
         >>> masks, scores, boxes = predictor.generate()
         >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
     """
@@ -90,8 +90,8 @@ class Predictor(BasePredictor):
 
        Examples:
            >>> predictor = Predictor(cfg=DEFAULT_CFG)
-            >>> predictor = Predictor(overrides={
-            >>> predictor = Predictor(_callbacks={
+            >>> predictor = Predictor(overrides={"imgsz": 640})
+            >>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback})
        """
        if overrides is None:
            overrides = {}
@@ -188,8 +188,8 @@ class Predictor(BasePredictor):
 
        Examples:
            >>> predictor = Predictor()
-            >>> predictor.setup_model(model_path=
-            >>> predictor.set_image(
+            >>> predictor.setup_model(model_path="sam_model.pt")
+            >>> predictor.set_image("image.jpg")
            >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
        """
        # Override prompts if any stored in self.prompts
@@ -475,8 +475,8 @@ class Predictor(BasePredictor):
 
        Examples:
            >>> predictor = Predictor()
-            >>> predictor.setup_source(
-            >>> predictor.setup_source(
+            >>> predictor.setup_source("path/to/images")
+            >>> predictor.setup_source("video.mp4")
            >>> predictor.setup_source(None)  # Uses default source if available
 
        Notes:
@@ -504,8 +504,8 @@ class Predictor(BasePredictor):
 
        Examples:
            >>> predictor = Predictor()
-            >>> predictor.set_image(
-            >>> predictor.set_image(cv2.imread(
+            >>> predictor.set_image("path/to/image.jpg")
+            >>> predictor.set_image(cv2.imread("path/to/image.jpg"))
 
        Notes:
            - This method should be called before performing inference on a new image.