ultralytics 8.0.142__py3-none-any.whl → 8.0.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. See the package registry's advisory for this release for more details.

ultralytics/__init__.py CHANGED
@@ -1,10 +1,9 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
- __version__ = '8.0.142'
3
+ __version__ = '8.0.144'
4
4
 
5
- from ultralytics.engine.model import YOLO
6
5
  from ultralytics.hub import start
7
- from ultralytics.models import RTDETR, SAM
6
+ from ultralytics.models import RTDETR, SAM, YOLO
8
7
  from ultralytics.models.fastsam import FastSAM
9
8
  from ultralytics.models.nas import NAS
10
9
  from ultralytics.utils import SETTINGS as settings
@@ -593,14 +593,43 @@ class Exporter:
593
593
  f_onnx, _ = self.export_onnx()
594
594
 
595
595
  # Export to TF
596
- int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
597
- cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'
598
- LOGGER.info(f"\n{prefix} running '{cmd}'")
596
+ tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
597
+ if self.args.int8:
598
+ if self.args.data:
599
+ import numpy as np
600
+
601
+ from ultralytics.data.dataset import YOLODataset
602
+ from ultralytics.data.utils import check_det_dataset
603
+
604
+ # Generate calibration data for integer quantization
605
+ LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
606
+ dataset = YOLODataset(check_det_dataset(self.args.data)['val'], imgsz=self.imgsz[0], augment=False)
607
+ images = []
608
+ n_images = 100 # maximum number of images
609
+ for n, batch in enumerate(dataset):
610
+ if n >= n_images:
611
+ break
612
+ im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC,
613
+ images.append(im)
614
+ f.mkdir()
615
+ images = torch.cat(images, 0).float()
616
+ # mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
617
+ # std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
618
+ np.save(str(tmp_file), images.numpy()) # BHWC
619
+ int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
620
+ else:
621
+ int8 = '-oiqt -qt per-tensor'
622
+ else:
623
+ int8 = ''
624
+
625
+ cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'.strip()
626
+ LOGGER.info(f"{prefix} running '{cmd}'")
599
627
  subprocess.run(cmd, shell=True)
600
628
  yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
601
629
 
602
630
  # Remove/rename TFLite models
603
631
  if self.args.int8:
632
+ tmp_file.unlink(missing_ok=True)
604
633
  for file in f.rglob('*_dynamic_range_quant.tflite'):
605
634
  file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
606
635
  for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
@@ -1,36 +1,24 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
+ import inspect
3
4
  import sys
4
5
  from pathlib import Path
5
6
  from typing import Union
6
7
 
7
8
  from ultralytics.cfg import get_cfg
8
9
  from ultralytics.engine.exporter import Exporter
9
- from ultralytics.models import yolo # noqa
10
- from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
11
- attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
10
+ from ultralytics.hub.utils import HUB_WEB_ROOT
11
+ from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
12
12
  from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
13
13
  is_git_dir, yaml_load)
14
14
  from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
15
15
  from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
16
16
  from ultralytics.utils.torch_utils import smart_inference_mode
17
17
 
18
- # Map head to model, trainer, validator, and predictor classes
19
- TASK_MAP = {
20
- 'classify': [
21
- ClassificationModel, yolo.classify.ClassificationTrainer, yolo.classify.ClassificationValidator,
22
- yolo.classify.ClassificationPredictor],
23
- 'detect':
24
- [DetectionModel, yolo.detect.DetectionTrainer, yolo.detect.DetectionValidator, yolo.detect.DetectionPredictor],
25
- 'segment': [
26
- SegmentationModel, yolo.segment.SegmentationTrainer, yolo.segment.SegmentationValidator,
27
- yolo.segment.SegmentationPredictor],
28
- 'pose': [PoseModel, yolo.pose.PoseTrainer, yolo.pose.PoseValidator, yolo.pose.PosePredictor]}
29
18
 
30
-
31
- class YOLO:
19
+ class Model:
32
20
  """
33
- YOLO (You Only Look Once) object detection model.
21
+ A base model class to unify apis for all the models.
34
22
 
35
23
  Args:
36
24
  model (str, Path): Path to the model file to load or create.
@@ -81,13 +69,13 @@ class YOLO:
81
69
  self.predictor = None # reuse predictor
82
70
  self.model = None # model object
83
71
  self.trainer = None # trainer object
84
- self.task = None # task type
85
72
  self.ckpt = None # if loaded from *.pt
86
73
  self.cfg = None # if loaded from *.yaml
87
74
  self.ckpt_path = None
88
75
  self.overrides = {} # overrides for trainer object
89
76
  self.metrics = None # validation/training metrics
90
77
  self.session = None # HUB session
78
+ self.task = task # task type
91
79
  model = str(model).strip() # strip spaces
92
80
 
93
81
  # Check if Ultralytics HUB model from https://hub.ultralytics.com
@@ -109,32 +97,29 @@ class YOLO:
109
97
  """Calls the 'predict' function with given arguments to perform object detection."""
110
98
  return self.predict(source, stream, **kwargs)
111
99
 
112
- def __getattr__(self, attr):
113
- """Raises error if object has no requested attribute."""
114
- name = self.__class__.__name__
115
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
116
-
117
100
  @staticmethod
118
101
  def is_hub_model(model):
119
102
  """Check if the provided model is a HUB model."""
120
103
  return any((
121
- model.startswith('https://hub.ultralytics.com/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
104
+ model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
122
105
  [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
123
106
  len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
124
107
 
125
- def _new(self, cfg: str, task=None, verbose=True):
108
+ def _new(self, cfg: str, task=None, model=None, verbose=True):
126
109
  """
127
110
  Initializes a new model and infers the task type from the model definitions.
128
111
 
129
112
  Args:
130
113
  cfg (str): model configuration file
131
114
  task (str | None): model task
115
+ model (BaseModel): Customized model.
132
116
  verbose (bool): display model info on load
133
117
  """
134
118
  cfg_dict = yaml_model_load(cfg)
135
119
  self.cfg = cfg
136
120
  self.task = task or guess_model_task(cfg_dict)
137
- self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
121
+ model = model or self.smart_load('model')
122
+ self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
138
123
  self.overrides['model'] = self.cfg
139
124
 
140
125
  # Below added to allow export from yamls
@@ -217,7 +202,7 @@ class YOLO:
217
202
  self.model.fuse()
218
203
 
219
204
  @smart_inference_mode()
220
- def predict(self, source=None, stream=False, **kwargs):
205
+ def predict(self, source=None, stream=False, predictor=None, **kwargs):
221
206
  """
222
207
  Perform prediction using the YOLO model.
223
208
 
@@ -225,6 +210,7 @@ class YOLO:
225
210
  source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
226
211
  Accepts all source types accepted by the YOLO model.
227
212
  stream (bool): Whether to stream the predictions or not. Defaults to False.
213
+ predictor (BasePredictor): Customized predictor.
228
214
  **kwargs : Additional keyword arguments passed to the predictor.
229
215
  Check the 'configuration' section in the documentation for all available options.
230
216
 
@@ -236,6 +222,8 @@ class YOLO:
236
222
  LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
237
223
  is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
238
224
  x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
225
+ # Check prompts for SAM/FastSAM
226
+ prompts = kwargs.pop('prompts', None)
239
227
  overrides = self.overrides.copy()
240
228
  overrides['conf'] = 0.25
241
229
  overrides.update(kwargs) # prefer kwargs
@@ -245,12 +233,16 @@ class YOLO:
245
233
  overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
246
234
  if not self.predictor:
247
235
  self.task = overrides.get('task') or self.task
248
- self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
236
+ predictor = predictor or self.smart_load('predictor')
237
+ self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
249
238
  self.predictor.setup_model(model=self.model, verbose=is_cli)
250
239
  else: # only update args if predictor is already setup
251
240
  self.predictor.args = get_cfg(self.predictor.args, overrides)
252
241
  if 'project' in overrides or 'name' in overrides:
253
242
  self.predictor.save_dir = self.predictor.get_save_dir()
243
+ # Set prompts for SAM/FastSAM
244
+ if len and hasattr(self.predictor, 'set_prompts'):
245
+ self.predictor.set_prompts(prompts)
254
246
  return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
255
247
 
256
248
  def track(self, source=None, stream=False, persist=False, **kwargs):
@@ -277,12 +269,13 @@ class YOLO:
277
269
  return self.predict(source=source, stream=stream, **kwargs)
278
270
 
279
271
  @smart_inference_mode()
280
- def val(self, data=None, **kwargs):
272
+ def val(self, data=None, validator=None, **kwargs):
281
273
  """
282
274
  Validate a model on a given dataset.
283
275
 
284
276
  Args:
285
277
  data (str): The dataset to validate on. Accepts all formats accepted by yolo
278
+ validator (BaseValidator): Customized validator.
286
279
  **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
287
280
  """
288
281
  overrides = self.overrides.copy()
@@ -295,11 +288,12 @@ class YOLO:
295
288
  self.task = args.task
296
289
  else:
297
290
  args.task = self.task
291
+ validator = validator or self.smart_load('validator')
298
292
  if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
299
293
  args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
300
294
  args.imgsz = check_imgsz(args.imgsz, max_dim=1)
301
295
 
302
- validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
296
+ validator = validator(args=args, _callbacks=self.callbacks)
303
297
  validator(model=self.model)
304
298
  self.metrics = validator.metrics
305
299
 
@@ -343,15 +337,18 @@ class YOLO:
343
337
  overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
344
338
  if 'batch' not in kwargs:
345
339
  overrides['batch'] = 1 # default to 1 if not modified
340
+ if 'data' not in kwargs:
341
+ overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
346
342
  args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
347
343
  args.task = self.task
348
344
  return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
349
345
 
350
- def train(self, **kwargs):
346
+ def train(self, trainer=None, **kwargs):
351
347
  """
352
348
  Trains the model on a given dataset.
353
349
 
354
350
  Args:
351
+ trainer (BaseTrainer, optional): Customized trainer.
355
352
  **kwargs (Any): Any number of arguments representing the training configuration.
356
353
  """
357
354
  self._check_is_pytorch_model()
@@ -371,7 +368,8 @@ class YOLO:
371
368
  if overrides.get('resume'):
372
369
  overrides['resume'] = self.ckpt_path
373
370
  self.task = overrides.get('task') or self.task
374
- self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
371
+ trainer = trainer or self.smart_load('trainer')
372
+ self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
375
373
  if not overrides.get('resume'): # manually set model only if not resuming
376
374
  self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
377
375
  self.model = self.trainer.model
@@ -440,3 +438,28 @@ class YOLO:
440
438
  """Reset all registered callbacks."""
441
439
  for event in callbacks.default_callbacks.keys():
442
440
  self.callbacks[event] = [callbacks.default_callbacks[event][0]]
441
+
442
+ def __getattr__(self, attr):
443
+ """Raises error if object has no requested attribute."""
444
+ name = self.__class__.__name__
445
+ raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
446
+
447
+ def smart_load(self, key):
448
+ """Load model/trainer/validator/predictor."""
449
+ try:
450
+ return self.task_map[self.task][key]
451
+ except Exception:
452
+ name = self.__class__.__name__
453
+ mode = inspect.stack()[1][3] # get the function name.
454
+ raise NotImplementedError(
455
+ f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
456
+
457
+ @property
458
+ def task_map(self):
459
+ """
460
+ Map head to model, trainer, validator, and predictor classes.
461
+
462
+ Returns:
463
+ task_map (dict): The map of model task to mode classes.
464
+ """
465
+ raise NotImplementedError('Please provide task map for your model!')
@@ -105,6 +105,7 @@ class BasePredictor:
105
105
  self.results = None
106
106
  self.transforms = None
107
107
  self.callbacks = _callbacks or callbacks.get_default_callbacks()
108
+ self.txt_path = None
108
109
  callbacks.add_integration_callbacks(self)
109
110
 
110
111
  def get_save_dir(self):
@@ -178,7 +179,8 @@ class BasePredictor:
178
179
  if self.args.save_txt:
179
180
  result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
180
181
  if self.args.save_crop:
181
- result.save_crop(save_dir=self.save_dir / 'crops', file_name=self.data_path.stem)
182
+ result.save_crop(save_dir=self.save_dir / 'crops',
183
+ file_name=self.data_path.stem + ('' if self.dataset.mode == 'image' else f'_{frame}'))
182
184
 
183
185
  return log_string
184
186
 
@@ -76,7 +76,6 @@ class Results(SimpleClass):
76
76
  probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
77
77
  keypoints (List[List[float]], optional): A list of detected keypoints for each object.
78
78
 
79
-
80
79
  Attributes:
81
80
  orig_img (numpy.ndarray): The original image as a numpy array.
82
81
  orig_shape (tuple): The original image shape in (height, width) format.
@@ -172,6 +171,7 @@ class Results(SimpleClass):
172
171
  pil=False,
173
172
  img=None,
174
173
  im_gpu=None,
174
+ kpt_radius=5,
175
175
  kpt_line=True,
176
176
  labels=True,
177
177
  boxes=True,
@@ -190,6 +190,7 @@ class Results(SimpleClass):
190
190
  pil (bool): Whether to return the image as a PIL Image.
191
191
  img (numpy.ndarray): Plot to another image. if not, plot to original image.
192
192
  im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
193
+ kpt_radius (int, optional): Radius of the drawn keypoints. Default is 5.
193
194
  kpt_line (bool): Whether to draw lines connecting keypoints.
194
195
  labels (bool): Whether to plot the label of bounding boxes.
195
196
  boxes (bool): Whether to plot the bounding boxes.
@@ -251,7 +252,7 @@ class Results(SimpleClass):
251
252
  # Plot Pose results
252
253
  if self.keypoints is not None:
253
254
  for k in reversed(self.keypoints.data):
254
- annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
255
+ annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
255
256
 
256
257
  return annotator.result()
257
258
 
@@ -4,7 +4,7 @@ import requests
4
4
 
5
5
  from ultralytics.data.utils import HUBDatasetStats
6
6
  from ultralytics.hub.auth import Auth
7
- from ultralytics.hub.utils import PREFIX
7
+ from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
8
8
  from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
9
9
 
10
10
 
@@ -50,13 +50,13 @@ WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to
50
50
  from ultralytics import YOLO, hub
51
51
 
52
52
  hub.login('{api_key}')
53
- model = YOLO('https://hub.ultralytics.com/models/{model_id}')
53
+ model = YOLO('{HUB_WEB_ROOT}/models/{model_id}')
54
54
  model.train()""")
55
55
 
56
56
 
57
57
  def reset_model(model_id=''):
58
58
  """Reset a trained model to an untrained state."""
59
- r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
59
+ r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
60
60
  if r.status_code == 200:
61
61
  LOGGER.info(f'{PREFIX}Model reset successfully')
62
62
  return
@@ -72,7 +72,7 @@ def export_fmts_hub():
72
72
  def export_model(model_id='', format='torchscript'):
73
73
  """Export a model to all formats."""
74
74
  assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
75
- r = requests.post(f'https://api.ultralytics.com/v1/models/{model_id}/export',
75
+ r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export',
76
76
  json={'format': format},
77
77
  headers={'x-api-key': Auth().api_key})
78
78
  assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
@@ -82,7 +82,7 @@ def export_model(model_id='', format='torchscript'):
82
82
  def get_export(model_id='', format='torchscript'):
83
83
  """Get an exported model dictionary with download URL."""
84
84
  assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
85
- r = requests.post('https://api.ultralytics.com/get-export',
85
+ r = requests.post(f'{HUB_API_ROOT}/get-export',
86
86
  json={
87
87
  'apiKey': Auth().api_key,
88
88
  'modelId': model_id,
@@ -110,7 +110,7 @@ def check_dataset(path='', task='detect'):
110
110
  ```
111
111
  """
112
112
  HUBDatasetStats(path=path, task=task).get_json()
113
- LOGGER.info('Checks completed correctly ✅. Upload this dataset to https://hub.ultralytics.com/datasets/.')
113
+ LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
114
114
 
115
115
 
116
116
  if __name__ == '__main__':
ultralytics/hub/auth.py CHANGED
@@ -2,10 +2,10 @@
2
2
 
3
3
  import requests
4
4
 
5
- from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials
5
+ from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
6
6
  from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
7
7
 
8
- API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'
8
+ API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
9
9
 
10
10
 
11
11
  class Auth:
@@ -6,7 +6,7 @@ from time import sleep
6
6
 
7
7
  import requests
8
8
 
9
- from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, smart_request
9
+ from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request
10
10
  from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
11
11
  from ultralytics.utils.errors import HUBModelError
12
12
 
@@ -49,21 +49,21 @@ class HUBTrainingSession:
49
49
  from ultralytics.hub.auth import Auth
50
50
 
51
51
  # Parse input
52
- if url.startswith('https://hub.ultralytics.com/models/'):
53
- url = url.split('https://hub.ultralytics.com/models/')[-1]
52
+ if url.startswith(f'{HUB_WEB_ROOT}/models/'):
53
+ url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
54
54
  if [len(x) for x in url.split('_')] == [42, 20]:
55
55
  key, model_id = url.split('_')
56
56
  elif len(url) == 20:
57
57
  key, model_id = '', url
58
58
  else:
59
59
  raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
60
- f"model='https://hub.ultralytics.com/models/MODEL_ID' and try again.")
60
+ f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")
61
61
 
62
62
  # Authorize
63
63
  auth = Auth(key)
64
64
  self.agent_id = None # identifies which instance is communicating with server
65
65
  self.model_id = model_id
66
- self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
66
+ self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
67
67
  self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
68
68
  self.auth_header = auth.get_auth_header()
69
69
  self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
ultralytics/hub/utils.py CHANGED
@@ -18,6 +18,7 @@ from ultralytics.utils.downloads import GITHUB_ASSET_NAMES
18
18
  PREFIX = colorstr('Ultralytics HUB: ')
19
19
  HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
20
20
  HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
21
+ HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.com')
21
22
 
22
23
 
23
24
  def request_with_credentials(url: str) -> any:
@@ -1,4 +1,7 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
1
3
  from .rtdetr import RTDETR
2
4
  from .sam import SAM
5
+ from .yolo import YOLO
3
6
 
4
- __all__ = 'RTDETR', 'SAM' # allow simpler import
7
+ __all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import
@@ -1,111 +1,31 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
- """
3
- FastSAM model interface.
4
2
 
5
- Usage - Predict:
6
- from ultralytics import FastSAM
3
+ from pathlib import Path
7
4
 
8
- model = FastSAM('last.pt')
9
- results = model.predict('ultralytics/assets/bus.jpg')
10
- """
11
-
12
- from ultralytics.cfg import get_cfg
13
- from ultralytics.engine.exporter import Exporter
14
- from ultralytics.engine.model import YOLO
15
- from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
16
- from ultralytics.utils.checks import check_imgsz
17
- from ultralytics.utils.torch_utils import model_info, smart_inference_mode
5
+ from ultralytics.engine.model import Model
18
6
 
19
7
  from .predict import FastSAMPredictor
8
+ from .val import FastSAMValidator
9
+
20
10
 
11
+ class FastSAM(Model):
12
+ """
13
+ FastSAM model interface.
21
14
 
22
- class FastSAM(YOLO):
15
+ Usage - Predict:
16
+ from ultralytics import FastSAM
17
+
18
+ model = FastSAM('last.pt')
19
+ results = model.predict('ultralytics/assets/bus.jpg')
20
+ """
23
21
 
24
22
  def __init__(self, model='FastSAM-x.pt'):
25
23
  """Call the __init__ method of the parent class (YOLO) with the updated default model"""
26
24
  if model == 'FastSAM.pt':
27
25
  model = 'FastSAM-x.pt'
28
- super().__init__(model=model)
29
- # any additional initialization code for FastSAM
30
-
31
- @smart_inference_mode()
32
- def predict(self, source=None, stream=False, **kwargs):
33
- """
34
- Perform prediction using the YOLO model.
35
-
36
- Args:
37
- source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
38
- Accepts all source types accepted by the YOLO model.
39
- stream (bool): Whether to stream the predictions or not. Defaults to False.
40
- **kwargs : Additional keyword arguments passed to the predictor.
41
- Check the 'configuration' section in the documentation for all available options.
42
-
43
- Returns:
44
- (List[ultralytics.engine.results.Results]): The prediction results.
45
- """
46
- if source is None:
47
- source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
48
- LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
49
- overrides = self.overrides.copy()
50
- overrides['conf'] = 0.25
51
- overrides.update(kwargs) # prefer kwargs
52
- overrides['mode'] = kwargs.get('mode', 'predict')
53
- assert overrides['mode'] in ['track', 'predict']
54
- overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
55
- self.predictor = FastSAMPredictor(overrides=overrides)
56
- self.predictor.setup_model(model=self.model, verbose=False)
57
-
58
- return self.predictor(source, stream=stream)
59
-
60
- def train(self, **kwargs):
61
- """Function trains models but raises an error as FastSAM models do not support training."""
62
- raise NotImplementedError("FastSAM models don't support training")
63
-
64
- def val(self, **kwargs):
65
- """Run validation given dataset."""
66
- overrides = dict(task='segment', mode='val')
67
- overrides.update(kwargs) # prefer kwargs
68
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
69
- args.imgsz = check_imgsz(args.imgsz, max_dim=1)
70
- validator = FastSAM(args=args)
71
- validator(model=self.model)
72
- self.metrics = validator.metrics
73
- return validator.metrics
74
-
75
- @smart_inference_mode()
76
- def export(self, **kwargs):
77
- """
78
- Export model.
79
-
80
- Args:
81
- **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
82
- """
83
- overrides = dict(task='detect')
84
- overrides.update(kwargs)
85
- overrides['mode'] = 'export'
86
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
87
- args.task = self.task
88
- if args.imgsz == DEFAULT_CFG.imgsz:
89
- args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
90
- if args.batch == DEFAULT_CFG.batch:
91
- args.batch = 1 # default to 1 if not modified
92
- return Exporter(overrides=args)(model=self.model)
93
-
94
- def info(self, detailed=False, verbose=True):
95
- """
96
- Logs model info.
97
-
98
- Args:
99
- detailed (bool): Show detailed information about model.
100
- verbose (bool): Controls verbosity.
101
- """
102
- return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
103
-
104
- def __call__(self, source=None, stream=False, **kwargs):
105
- """Calls the 'predict' function with given arguments to perform object detection."""
106
- return self.predict(source, stream, **kwargs)
26
+ assert Path(model).suffix != '.yaml', 'FastSAM models only support pre-trained models.'
27
+ super().__init__(model=model, task='segment')
107
28
 
108
- def __getattr__(self, attr):
109
- """Raises error if object has no requested attribute."""
110
- name = self.__class__.__name__
111
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
29
+ @property
30
+ def task_map(self):
31
+ return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}