ultralytics 8.0.142__py3-none-any.whl → 8.0.144__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +2 -3
- ultralytics/engine/exporter.py +32 -3
- ultralytics/engine/model.py +55 -32
- ultralytics/engine/predictor.py +3 -1
- ultralytics/engine/results.py +3 -2
- ultralytics/hub/__init__.py +6 -6
- ultralytics/hub/auth.py +2 -2
- ultralytics/hub/session.py +5 -5
- ultralytics/hub/utils.py +1 -0
- ultralytics/models/__init__.py +4 -1
- ultralytics/models/fastsam/model.py +18 -98
- ultralytics/models/nas/model.py +14 -88
- ultralytics/models/rtdetr/model.py +16 -159
- ultralytics/models/sam/model.py +16 -25
- ultralytics/models/sam/predict.py +11 -1
- ultralytics/models/yolo/__init__.py +3 -1
- ultralytics/models/yolo/detect/train.py +1 -0
- ultralytics/models/yolo/model.py +36 -0
- ultralytics/models/yolo/segment/train.py +1 -0
- ultralytics/nn/autobackend.py +8 -8
- ultralytics/utils/callbacks/hub.py +4 -4
- ultralytics/utils/ops.py +1 -4
- {ultralytics-8.0.142.dist-info → ultralytics-8.0.144.dist-info}/METADATA +1 -1
- {ultralytics-8.0.142.dist-info → ultralytics-8.0.144.dist-info}/RECORD +28 -27
- {ultralytics-8.0.142.dist-info → ultralytics-8.0.144.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.142.dist-info → ultralytics-8.0.144.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.142.dist-info → ultralytics-8.0.144.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.142.dist-info → ultralytics-8.0.144.dist-info}/top_level.txt +0 -0
ultralytics/models/nas/model.py
CHANGED
|
@@ -13,105 +13,36 @@ from pathlib import Path
|
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
15
|
|
|
16
|
-
from ultralytics.
|
|
17
|
-
from ultralytics.engine.exporter import Exporter
|
|
18
|
-
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
|
|
19
|
-
from ultralytics.utils.checks import check_imgsz
|
|
16
|
+
from ultralytics.engine.model import Model
|
|
20
17
|
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
|
21
18
|
|
|
22
19
|
from .predict import NASPredictor
|
|
23
20
|
from .val import NASValidator
|
|
24
21
|
|
|
25
22
|
|
|
26
|
-
class NAS:
|
|
23
|
+
class NAS(Model):
|
|
27
24
|
|
|
28
25
|
def __init__(self, model='yolo_nas_s.pt') -> None:
|
|
26
|
+
assert Path(model).suffix != '.yaml', 'YOLO-NAS models only support pre-trained models.'
|
|
27
|
+
super().__init__(model, task='detect')
|
|
28
|
+
|
|
29
|
+
@smart_inference_mode()
|
|
30
|
+
def _load(self, weights: str, task: str):
|
|
29
31
|
# Load or create new NAS model
|
|
30
32
|
import super_gradients
|
|
31
|
-
|
|
32
|
-
self.predictor = None
|
|
33
|
-
suffix = Path(model).suffix
|
|
33
|
+
suffix = Path(weights).suffix
|
|
34
34
|
if suffix == '.pt':
|
|
35
|
-
self.
|
|
35
|
+
self.model = torch.load(weights)
|
|
36
36
|
elif suffix == '':
|
|
37
|
-
self.model = super_gradients.training.models.get(
|
|
38
|
-
self.task = 'detect'
|
|
39
|
-
self.model.args = DEFAULT_CFG_DICT # attach args to model
|
|
40
|
-
|
|
37
|
+
self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
|
|
41
38
|
# Standardize model
|
|
42
39
|
self.model.fuse = lambda verbose=True: self.model
|
|
43
40
|
self.model.stride = torch.tensor([32])
|
|
44
41
|
self.model.names = dict(enumerate(self.model._class_names))
|
|
45
42
|
self.model.is_fused = lambda: False # for info()
|
|
46
43
|
self.model.yaml = {} # for info()
|
|
47
|
-
self.model.pt_path =
|
|
44
|
+
self.model.pt_path = weights # for export()
|
|
48
45
|
self.model.task = 'detect' # for export()
|
|
49
|
-
self.info()
|
|
50
|
-
|
|
51
|
-
@smart_inference_mode()
|
|
52
|
-
def _load(self, weights: str):
|
|
53
|
-
self.model = torch.load(weights)
|
|
54
|
-
|
|
55
|
-
@smart_inference_mode()
|
|
56
|
-
def predict(self, source=None, stream=False, **kwargs):
|
|
57
|
-
"""
|
|
58
|
-
Perform prediction using the YOLO model.
|
|
59
|
-
|
|
60
|
-
Args:
|
|
61
|
-
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
|
62
|
-
Accepts all source types accepted by the YOLO model.
|
|
63
|
-
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
|
64
|
-
**kwargs : Additional keyword arguments passed to the predictor.
|
|
65
|
-
Check the 'configuration' section in the documentation for all available options.
|
|
66
|
-
|
|
67
|
-
Returns:
|
|
68
|
-
(List[ultralytics.engine.results.Results]): The prediction results.
|
|
69
|
-
"""
|
|
70
|
-
if source is None:
|
|
71
|
-
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
|
72
|
-
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
|
73
|
-
overrides = dict(conf=0.25, task='detect', mode='predict')
|
|
74
|
-
overrides.update(kwargs) # prefer kwargs
|
|
75
|
-
if not self.predictor:
|
|
76
|
-
self.predictor = NASPredictor(overrides=overrides)
|
|
77
|
-
self.predictor.setup_model(model=self.model)
|
|
78
|
-
else: # only update args if predictor is already setup
|
|
79
|
-
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
|
80
|
-
return self.predictor(source, stream=stream)
|
|
81
|
-
|
|
82
|
-
def train(self, **kwargs):
|
|
83
|
-
"""Function trains models but raises an error as NAS models do not support training."""
|
|
84
|
-
raise NotImplementedError("NAS models don't support training")
|
|
85
|
-
|
|
86
|
-
def val(self, **kwargs):
|
|
87
|
-
"""Run validation given dataset."""
|
|
88
|
-
overrides = dict(task='detect', mode='val')
|
|
89
|
-
overrides.update(kwargs) # prefer kwargs
|
|
90
|
-
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
|
91
|
-
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
|
92
|
-
validator = NASValidator(args=args)
|
|
93
|
-
validator(model=self.model)
|
|
94
|
-
self.metrics = validator.metrics
|
|
95
|
-
return validator.metrics
|
|
96
|
-
|
|
97
|
-
@smart_inference_mode()
|
|
98
|
-
def export(self, **kwargs):
|
|
99
|
-
"""
|
|
100
|
-
Export model.
|
|
101
|
-
|
|
102
|
-
Args:
|
|
103
|
-
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
|
104
|
-
"""
|
|
105
|
-
overrides = dict(task='detect')
|
|
106
|
-
overrides.update(kwargs)
|
|
107
|
-
overrides['mode'] = 'export'
|
|
108
|
-
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
|
109
|
-
args.task = self.task
|
|
110
|
-
if args.imgsz == DEFAULT_CFG.imgsz:
|
|
111
|
-
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
|
112
|
-
if args.batch == DEFAULT_CFG.batch:
|
|
113
|
-
args.batch = 1 # default to 1 if not modified
|
|
114
|
-
return Exporter(overrides=args)(model=self.model)
|
|
115
46
|
|
|
116
47
|
def info(self, detailed=False, verbose=True):
|
|
117
48
|
"""
|
|
@@ -123,11 +54,6 @@ class NAS:
|
|
|
123
54
|
"""
|
|
124
55
|
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
|
|
125
56
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
return
|
|
129
|
-
|
|
130
|
-
def __getattr__(self, attr):
|
|
131
|
-
"""Raises error if object has no requested attribute."""
|
|
132
|
-
name = self.__class__.__name__
|
|
133
|
-
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
57
|
+
@property
|
|
58
|
+
def task_map(self):
|
|
59
|
+
return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
|
|
@@ -2,172 +2,29 @@
|
|
|
2
2
|
"""
|
|
3
3
|
RT-DETR model interface
|
|
4
4
|
"""
|
|
5
|
-
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
-
import torch.nn as nn
|
|
9
|
-
|
|
10
|
-
from ultralytics.cfg import get_cfg
|
|
11
|
-
from ultralytics.engine.exporter import Exporter
|
|
12
|
-
from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
|
|
13
|
-
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
|
|
14
|
-
from ultralytics.utils.checks import check_imgsz
|
|
15
|
-
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
|
5
|
+
from ultralytics.engine.model import Model
|
|
6
|
+
from ultralytics.nn.tasks import RTDETRDetectionModel
|
|
16
7
|
|
|
17
8
|
from .predict import RTDETRPredictor
|
|
18
9
|
from .train import RTDETRTrainer
|
|
19
10
|
from .val import RTDETRValidator
|
|
20
11
|
|
|
21
12
|
|
|
22
|
-
class RTDETR:
|
|
13
|
+
class RTDETR(Model):
|
|
14
|
+
"""
|
|
15
|
+
RTDETR model interface.
|
|
16
|
+
"""
|
|
23
17
|
|
|
24
18
|
def __init__(self, model='rtdetr-l.pt') -> None:
|
|
25
19
|
if model and not model.endswith('.pt') and not model.endswith('.yaml'):
|
|
26
20
|
raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
cfg_dict = yaml_model_load(cfg)
|
|
38
|
-
self.cfg = cfg
|
|
39
|
-
self.task = 'detect'
|
|
40
|
-
self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
|
|
41
|
-
|
|
42
|
-
# Below added to allow export from YAMLs
|
|
43
|
-
self.model.args = DEFAULT_CFG_DICT # attach args to model
|
|
44
|
-
self.model.task = self.task
|
|
45
|
-
|
|
46
|
-
@smart_inference_mode()
|
|
47
|
-
def _load(self, weights: str):
|
|
48
|
-
self.model, self.ckpt = attempt_load_one_weight(weights)
|
|
49
|
-
self.model.args = DEFAULT_CFG_DICT # attach args to model
|
|
50
|
-
self.task = self.model.args['task']
|
|
51
|
-
|
|
52
|
-
@smart_inference_mode()
|
|
53
|
-
def load(self, weights='yolov8n.pt'):
|
|
54
|
-
"""
|
|
55
|
-
Transfers parameters with matching names and shapes from 'weights' to model.
|
|
56
|
-
"""
|
|
57
|
-
if isinstance(weights, (str, Path)):
|
|
58
|
-
weights, self.ckpt = attempt_load_one_weight(weights)
|
|
59
|
-
self.model.load(weights)
|
|
60
|
-
return self
|
|
61
|
-
|
|
62
|
-
@smart_inference_mode()
|
|
63
|
-
def predict(self, source=None, stream=False, **kwargs):
|
|
64
|
-
"""
|
|
65
|
-
Perform prediction using the YOLO model.
|
|
66
|
-
|
|
67
|
-
Args:
|
|
68
|
-
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
|
69
|
-
Accepts all source types accepted by the YOLO model.
|
|
70
|
-
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
|
71
|
-
**kwargs : Additional keyword arguments passed to the predictor.
|
|
72
|
-
Check the 'configuration' section in the documentation for all available options.
|
|
73
|
-
|
|
74
|
-
Returns:
|
|
75
|
-
(List[ultralytics.engine.results.Results]): The prediction results.
|
|
76
|
-
"""
|
|
77
|
-
if source is None:
|
|
78
|
-
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
|
79
|
-
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
|
80
|
-
overrides = dict(conf=0.25, task='detect', mode='predict')
|
|
81
|
-
overrides.update(kwargs) # prefer kwargs
|
|
82
|
-
if not self.predictor:
|
|
83
|
-
self.predictor = RTDETRPredictor(overrides=overrides)
|
|
84
|
-
self.predictor.setup_model(model=self.model)
|
|
85
|
-
else: # only update args if predictor is already setup
|
|
86
|
-
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
|
87
|
-
return self.predictor(source, stream=stream)
|
|
88
|
-
|
|
89
|
-
def train(self, **kwargs):
|
|
90
|
-
"""
|
|
91
|
-
Trains the model on a given dataset.
|
|
92
|
-
|
|
93
|
-
Args:
|
|
94
|
-
**kwargs (Any): Any number of arguments representing the training configuration.
|
|
95
|
-
"""
|
|
96
|
-
overrides = dict(task='detect', mode='train')
|
|
97
|
-
overrides.update(kwargs)
|
|
98
|
-
overrides['deterministic'] = False
|
|
99
|
-
if not overrides.get('data'):
|
|
100
|
-
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
|
101
|
-
if overrides.get('resume'):
|
|
102
|
-
overrides['resume'] = self.ckpt_path
|
|
103
|
-
self.task = overrides.get('task') or self.task
|
|
104
|
-
self.trainer = RTDETRTrainer(overrides=overrides)
|
|
105
|
-
if not overrides.get('resume'): # manually set model only if not resuming
|
|
106
|
-
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
|
107
|
-
self.model = self.trainer.model
|
|
108
|
-
self.trainer.train()
|
|
109
|
-
# Update model and cfg after training
|
|
110
|
-
if RANK in (-1, 0):
|
|
111
|
-
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
|
112
|
-
self.overrides = self.model.args
|
|
113
|
-
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
|
|
114
|
-
|
|
115
|
-
def val(self, **kwargs):
|
|
116
|
-
"""Run validation given dataset."""
|
|
117
|
-
overrides = dict(task='detect', mode='val')
|
|
118
|
-
overrides.update(kwargs) # prefer kwargs
|
|
119
|
-
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
|
120
|
-
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
|
121
|
-
validator = RTDETRValidator(args=args)
|
|
122
|
-
validator(model=self.model)
|
|
123
|
-
self.metrics = validator.metrics
|
|
124
|
-
return validator.metrics
|
|
125
|
-
|
|
126
|
-
def info(self, verbose=True):
|
|
127
|
-
"""Get model info"""
|
|
128
|
-
return model_info(self.model, verbose=verbose)
|
|
129
|
-
|
|
130
|
-
def _check_is_pytorch_model(self):
|
|
131
|
-
"""
|
|
132
|
-
Raises TypeError is model is not a PyTorch model
|
|
133
|
-
"""
|
|
134
|
-
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
|
|
135
|
-
pt_module = isinstance(self.model, nn.Module)
|
|
136
|
-
if not (pt_module or pt_str):
|
|
137
|
-
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
|
|
138
|
-
f'PyTorch models can be used to train, val, predict and export, i.e. '
|
|
139
|
-
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
|
140
|
-
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
|
141
|
-
|
|
142
|
-
def fuse(self):
|
|
143
|
-
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
|
|
144
|
-
self._check_is_pytorch_model()
|
|
145
|
-
self.model.fuse()
|
|
146
|
-
|
|
147
|
-
@smart_inference_mode()
|
|
148
|
-
def export(self, **kwargs):
|
|
149
|
-
"""
|
|
150
|
-
Export model.
|
|
151
|
-
|
|
152
|
-
Args:
|
|
153
|
-
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
|
154
|
-
"""
|
|
155
|
-
overrides = dict(task='detect')
|
|
156
|
-
overrides.update(kwargs)
|
|
157
|
-
overrides['mode'] = 'export'
|
|
158
|
-
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
|
159
|
-
args.task = self.task
|
|
160
|
-
if args.imgsz == DEFAULT_CFG.imgsz:
|
|
161
|
-
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
|
162
|
-
if args.batch == DEFAULT_CFG.batch:
|
|
163
|
-
args.batch = 1 # default to 1 if not modified
|
|
164
|
-
return Exporter(overrides=args)(model=self.model)
|
|
165
|
-
|
|
166
|
-
def __call__(self, source=None, stream=False, **kwargs):
|
|
167
|
-
"""Calls the 'predict' function with given arguments to perform object detection."""
|
|
168
|
-
return self.predict(source, stream, **kwargs)
|
|
169
|
-
|
|
170
|
-
def __getattr__(self, attr):
|
|
171
|
-
"""Raises error if object has no requested attribute."""
|
|
172
|
-
name = self.__class__.__name__
|
|
173
|
-
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
21
|
+
super().__init__(model=model, task='detect')
|
|
22
|
+
|
|
23
|
+
@property
|
|
24
|
+
def task_map(self):
|
|
25
|
+
return {
|
|
26
|
+
'detect': {
|
|
27
|
+
'predictor': RTDETRPredictor,
|
|
28
|
+
'validator': RTDETRValidator,
|
|
29
|
+
'trainer': RTDETRTrainer,
|
|
30
|
+
'model': RTDETRDetectionModel}}
|
ultralytics/models/sam/model.py
CHANGED
|
@@ -3,51 +3,38 @@
|
|
|
3
3
|
SAM model interface
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
from ultralytics.
|
|
6
|
+
from ultralytics.engine.model import Model
|
|
7
7
|
from ultralytics.utils.torch_utils import model_info
|
|
8
8
|
|
|
9
9
|
from .build import build_sam
|
|
10
10
|
from .predict import Predictor
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
class SAM:
|
|
13
|
+
class SAM(Model):
|
|
14
|
+
"""
|
|
15
|
+
SAM model interface.
|
|
16
|
+
"""
|
|
14
17
|
|
|
15
18
|
def __init__(self, model='sam_b.pt') -> None:
|
|
16
19
|
if model and not model.endswith('.pt') and not model.endswith('.pth'):
|
|
17
20
|
# Should raise AssertionError instead?
|
|
18
21
|
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
+
super().__init__(model=model, task='segment')
|
|
23
|
+
|
|
24
|
+
def _load(self, weights: str, task=None):
|
|
25
|
+
self.model = build_sam(weights)
|
|
22
26
|
|
|
23
27
|
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
|
24
28
|
"""Predicts and returns segmentation masks for given image or video source."""
|
|
25
29
|
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
self.predictor.setup_model(model=self.model)
|
|
30
|
-
else: # only update args if predictor is already setup
|
|
31
|
-
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
|
32
|
-
return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
|
|
33
|
-
|
|
34
|
-
def train(self, **kwargs):
|
|
35
|
-
"""Function trains models but raises an error as SAM models do not support training."""
|
|
36
|
-
raise NotImplementedError("SAM models don't support training")
|
|
37
|
-
|
|
38
|
-
def val(self, **kwargs):
|
|
39
|
-
"""Run validation given dataset."""
|
|
40
|
-
raise NotImplementedError("SAM models don't support validation")
|
|
30
|
+
kwargs.update(overrides)
|
|
31
|
+
prompts = dict(bboxes=bboxes, points=points, labels=labels)
|
|
32
|
+
return super().predict(source, stream, prompts=prompts, **kwargs)
|
|
41
33
|
|
|
42
34
|
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
|
43
35
|
"""Calls the 'predict' function with given arguments to perform object detection."""
|
|
44
36
|
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
|
45
37
|
|
|
46
|
-
def __getattr__(self, attr):
|
|
47
|
-
"""Raises error if object has no requested attribute."""
|
|
48
|
-
name = self.__class__.__name__
|
|
49
|
-
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
50
|
-
|
|
51
38
|
def info(self, detailed=False, verbose=True):
|
|
52
39
|
"""
|
|
53
40
|
Logs model info.
|
|
@@ -57,3 +44,7 @@ class SAM:
|
|
|
57
44
|
verbose (bool): Controls verbosity.
|
|
58
45
|
"""
|
|
59
46
|
return model_info(self.model, detailed=detailed, verbose=verbose)
|
|
47
|
+
|
|
48
|
+
@property
|
|
49
|
+
def task_map(self):
|
|
50
|
+
return {'segment': {'predictor': Predictor}}
|
|
@@ -28,6 +28,8 @@ class Predictor(BasePredictor):
|
|
|
28
28
|
# Args for set_image
|
|
29
29
|
self.im = None
|
|
30
30
|
self.features = None
|
|
31
|
+
# Args for set_prompts
|
|
32
|
+
self.prompts = {}
|
|
31
33
|
# Args for segment everything
|
|
32
34
|
self.segment_all = False
|
|
33
35
|
|
|
@@ -92,6 +94,10 @@ class Predictor(BasePredictor):
|
|
|
92
94
|
of masks and H=W=256. These low resolution logits can be passed to
|
|
93
95
|
a subsequent iteration as mask input.
|
|
94
96
|
"""
|
|
97
|
+
# Get prompts from self.prompts first
|
|
98
|
+
bboxes = self.prompts.pop('bboxes', bboxes)
|
|
99
|
+
points = self.prompts.pop('points', points)
|
|
100
|
+
masks = self.prompts.pop('masks', masks)
|
|
95
101
|
if all(i is None for i in [bboxes, points, masks]):
|
|
96
102
|
return self.generate(im, *args, **kwargs)
|
|
97
103
|
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
|
|
@@ -288,7 +294,7 @@ class Predictor(BasePredictor):
|
|
|
288
294
|
|
|
289
295
|
def setup_model(self, model, verbose=True):
|
|
290
296
|
"""Set up YOLO model with specified thresholds and device."""
|
|
291
|
-
device = select_device(self.args.device)
|
|
297
|
+
device = select_device(self.args.device, verbose=verbose)
|
|
292
298
|
if model is None:
|
|
293
299
|
model = build_sam(self.args.model)
|
|
294
300
|
model.eval()
|
|
@@ -348,6 +354,10 @@ class Predictor(BasePredictor):
|
|
|
348
354
|
self.im = im
|
|
349
355
|
break
|
|
350
356
|
|
|
357
|
+
def set_prompts(self, prompts):
|
|
358
|
+
"""Set prompts in advance."""
|
|
359
|
+
self.prompts = prompts
|
|
360
|
+
|
|
351
361
|
def reset_image(self):
|
|
352
362
|
self.im = None
|
|
353
363
|
self.features = None
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
+
|
|
3
|
+
from ultralytics.engine.model import Model
|
|
4
|
+
from ultralytics.models import yolo # noqa
|
|
5
|
+
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class YOLO(Model):
|
|
9
|
+
"""
|
|
10
|
+
YOLO (You Only Look Once) object detection model.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
@property
|
|
14
|
+
def task_map(self):
|
|
15
|
+
"""Map head to model, trainer, validator, and predictor classes"""
|
|
16
|
+
return {
|
|
17
|
+
'classify': {
|
|
18
|
+
'model': ClassificationModel,
|
|
19
|
+
'trainer': yolo.classify.ClassificationTrainer,
|
|
20
|
+
'validator': yolo.classify.ClassificationValidator,
|
|
21
|
+
'predictor': yolo.classify.ClassificationPredictor, },
|
|
22
|
+
'detect': {
|
|
23
|
+
'model': DetectionModel,
|
|
24
|
+
'trainer': yolo.detect.DetectionTrainer,
|
|
25
|
+
'validator': yolo.detect.DetectionValidator,
|
|
26
|
+
'predictor': yolo.detect.DetectionPredictor, },
|
|
27
|
+
'segment': {
|
|
28
|
+
'model': SegmentationModel,
|
|
29
|
+
'trainer': yolo.segment.SegmentationTrainer,
|
|
30
|
+
'validator': yolo.segment.SegmentationValidator,
|
|
31
|
+
'predictor': yolo.segment.SegmentationPredictor, },
|
|
32
|
+
'pose': {
|
|
33
|
+
'model': PoseModel,
|
|
34
|
+
'trainer': yolo.pose.PoseTrainer,
|
|
35
|
+
'validator': yolo.pose.PoseValidator,
|
|
36
|
+
'predictor': yolo.pose.PosePredictor, }, }
|
ultralytics/nn/autobackend.py
CHANGED
|
@@ -400,21 +400,21 @@ class AutoBackend(nn.Module):
|
|
|
400
400
|
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
|
401
401
|
self.names = {i: f'class{i}' for i in range(nc)}
|
|
402
402
|
else: # Lite or Edge TPU
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
if
|
|
406
|
-
scale, zero_point =
|
|
407
|
-
im = (im / scale + zero_point).astype(
|
|
408
|
-
self.interpreter.set_tensor(
|
|
403
|
+
details = self.input_details[0]
|
|
404
|
+
integer = details['dtype'] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
|
|
405
|
+
if integer:
|
|
406
|
+
scale, zero_point = details['quantization']
|
|
407
|
+
im = (im / scale + zero_point).astype(details['dtype']) # de-scale
|
|
408
|
+
self.interpreter.set_tensor(details['index'], im)
|
|
409
409
|
self.interpreter.invoke()
|
|
410
410
|
y = []
|
|
411
411
|
for output in self.output_details:
|
|
412
412
|
x = self.interpreter.get_tensor(output['index'])
|
|
413
|
-
if
|
|
413
|
+
if integer:
|
|
414
414
|
scale, zero_point = output['quantization']
|
|
415
415
|
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
|
416
416
|
if x.ndim > 2: # if task is not classification
|
|
417
|
-
#
|
|
417
|
+
# Denormalize xywh with input image size
|
|
418
418
|
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
|
|
419
419
|
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
|
|
420
420
|
x[:, 0] *= w
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import json
|
|
4
4
|
from time import time
|
|
5
5
|
|
|
6
|
-
from ultralytics.hub.utils import PREFIX, events
|
|
6
|
+
from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
|
|
7
7
|
from ultralytics.utils import LOGGER, SETTINGS
|
|
8
8
|
from ultralytics.utils.torch_utils import model_info_for_loggers
|
|
9
9
|
|
|
@@ -13,7 +13,7 @@ def on_pretrain_routine_end(trainer):
|
|
|
13
13
|
session = getattr(trainer, 'hub_session', None)
|
|
14
14
|
if session:
|
|
15
15
|
# Start timer for upload rate limit
|
|
16
|
-
LOGGER.info(f'{PREFIX}View model at
|
|
16
|
+
LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
|
|
17
17
|
session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
|
|
18
18
|
|
|
19
19
|
|
|
@@ -39,7 +39,7 @@ def on_model_save(trainer):
|
|
|
39
39
|
# Upload checkpoints with rate limiting
|
|
40
40
|
is_best = trainer.best_fitness == trainer.fitness
|
|
41
41
|
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
|
|
42
|
-
LOGGER.info(f'{PREFIX}Uploading checkpoint
|
|
42
|
+
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}')
|
|
43
43
|
session.upload_model(trainer.epoch, trainer.last, is_best)
|
|
44
44
|
session.timers['ckpt'] = time() # reset timer
|
|
45
45
|
|
|
@@ -53,7 +53,7 @@ def on_train_end(trainer):
|
|
|
53
53
|
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
|
|
54
54
|
session.alive = False # stop heartbeats
|
|
55
55
|
LOGGER.info(f'{PREFIX}Done ✅\n'
|
|
56
|
-
f'{PREFIX}View model at
|
|
56
|
+
f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
def on_train_start(trainer):
|
ultralytics/utils/ops.py
CHANGED
|
@@ -321,7 +321,7 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
|
|
321
321
|
Takes a mask, and resizes it to the original image size
|
|
322
322
|
|
|
323
323
|
Args:
|
|
324
|
-
masks (
|
|
324
|
+
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
|
|
325
325
|
im0_shape (tuple): the original image shape
|
|
326
326
|
ratio_pad (tuple): the ratio of the padding to the original image.
|
|
327
327
|
|
|
@@ -344,9 +344,6 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
|
|
344
344
|
if len(masks.shape) < 2:
|
|
345
345
|
raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
|
|
346
346
|
masks = masks[top:bottom, left:right]
|
|
347
|
-
# masks = masks.permute(2, 0, 1).contiguous()
|
|
348
|
-
# masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
|
|
349
|
-
# masks = masks.permute(1, 2, 0).contiguous()
|
|
350
347
|
masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
|
|
351
348
|
if len(masks.shape) == 2:
|
|
352
349
|
masks = masks[:, :, None]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ultralytics
|
|
3
|
-
Version: 8.0.
|
|
3
|
+
Version: 8.0.144
|
|
4
4
|
Summary: Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
|
|
5
5
|
Home-page: https://github.com/ultralytics/ultralytics
|
|
6
6
|
Author: Ultralytics
|