ultralytics 8.0.194__py3-none-any.whl → 8.0.196__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (84) hide show
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +5 -6
  3. ultralytics/data/augment.py +234 -29
  4. ultralytics/data/base.py +2 -1
  5. ultralytics/data/build.py +9 -3
  6. ultralytics/data/converter.py +5 -2
  7. ultralytics/data/dataset.py +16 -2
  8. ultralytics/data/loaders.py +111 -7
  9. ultralytics/data/utils.py +3 -3
  10. ultralytics/engine/exporter.py +1 -3
  11. ultralytics/engine/model.py +16 -9
  12. ultralytics/engine/predictor.py +10 -6
  13. ultralytics/engine/results.py +18 -8
  14. ultralytics/engine/trainer.py +19 -31
  15. ultralytics/engine/tuner.py +20 -20
  16. ultralytics/engine/validator.py +3 -4
  17. ultralytics/hub/__init__.py +2 -2
  18. ultralytics/hub/auth.py +18 -3
  19. ultralytics/hub/session.py +1 -0
  20. ultralytics/hub/utils.py +1 -3
  21. ultralytics/models/fastsam/model.py +2 -1
  22. ultralytics/models/fastsam/predict.py +10 -7
  23. ultralytics/models/fastsam/prompt.py +15 -1
  24. ultralytics/models/nas/model.py +3 -1
  25. ultralytics/models/rtdetr/model.py +4 -6
  26. ultralytics/models/rtdetr/predict.py +2 -1
  27. ultralytics/models/rtdetr/train.py +2 -1
  28. ultralytics/models/rtdetr/val.py +1 -0
  29. ultralytics/models/sam/amg.py +12 -6
  30. ultralytics/models/sam/model.py +5 -6
  31. ultralytics/models/sam/modules/decoders.py +5 -1
  32. ultralytics/models/sam/modules/encoders.py +15 -12
  33. ultralytics/models/sam/modules/tiny_encoder.py +38 -2
  34. ultralytics/models/sam/modules/transformer.py +2 -4
  35. ultralytics/models/sam/predict.py +8 -4
  36. ultralytics/models/utils/loss.py +35 -8
  37. ultralytics/models/utils/ops.py +14 -18
  38. ultralytics/models/yolo/classify/predict.py +1 -0
  39. ultralytics/models/yolo/classify/train.py +4 -2
  40. ultralytics/models/yolo/classify/val.py +1 -0
  41. ultralytics/models/yolo/detect/train.py +4 -3
  42. ultralytics/models/yolo/model.py +2 -4
  43. ultralytics/models/yolo/pose/predict.py +1 -0
  44. ultralytics/models/yolo/segment/predict.py +2 -0
  45. ultralytics/models/yolo/segment/val.py +1 -1
  46. ultralytics/nn/autobackend.py +54 -43
  47. ultralytics/nn/modules/__init__.py +13 -9
  48. ultralytics/nn/modules/block.py +11 -5
  49. ultralytics/nn/modules/conv.py +16 -7
  50. ultralytics/nn/modules/head.py +6 -3
  51. ultralytics/nn/modules/transformer.py +47 -15
  52. ultralytics/nn/modules/utils.py +6 -4
  53. ultralytics/nn/tasks.py +61 -21
  54. ultralytics/trackers/bot_sort.py +53 -6
  55. ultralytics/trackers/byte_tracker.py +71 -15
  56. ultralytics/trackers/track.py +0 -1
  57. ultralytics/trackers/utils/gmc.py +23 -0
  58. ultralytics/trackers/utils/kalman_filter.py +6 -6
  59. ultralytics/utils/__init__.py +32 -19
  60. ultralytics/utils/autobatch.py +1 -3
  61. ultralytics/utils/benchmarks.py +14 -1
  62. ultralytics/utils/callbacks/base.py +1 -3
  63. ultralytics/utils/callbacks/comet.py +11 -3
  64. ultralytics/utils/callbacks/dvc.py +9 -0
  65. ultralytics/utils/callbacks/neptune.py +5 -6
  66. ultralytics/utils/callbacks/wb.py +1 -0
  67. ultralytics/utils/checks.py +13 -9
  68. ultralytics/utils/dist.py +2 -1
  69. ultralytics/utils/downloads.py +7 -3
  70. ultralytics/utils/files.py +3 -3
  71. ultralytics/utils/instance.py +12 -3
  72. ultralytics/utils/loss.py +97 -22
  73. ultralytics/utils/metrics.py +35 -34
  74. ultralytics/utils/ops.py +10 -9
  75. ultralytics/utils/patches.py +9 -7
  76. ultralytics/utils/plotting.py +4 -3
  77. ultralytics/utils/torch_utils.py +8 -6
  78. ultralytics/utils/triton.py +87 -0
  79. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/METADATA +1 -1
  80. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/RECORD +84 -83
  81. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/LICENSE +0 -0
  82. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/WHEEL +0 -0
  83. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/entry_points.txt +0 -0
  84. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ from ultralytics.utils.checks import check_requirements
22
22
 
23
23
  @dataclass
24
24
  class SourceTypes:
25
+ """Class to represent various types of input sources for predictions."""
25
26
  webcam: bool = False
26
27
  screenshot: bool = False
27
28
  from_img: bool = False
@@ -29,7 +30,34 @@ class SourceTypes:
29
30
 
30
31
 
31
32
  class LoadStreams:
32
- """Stream Loader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP streams`."""
33
+ """
34
+ Stream Loader for various types of video streams.
35
+
36
+ Suitable for use with `yolo predict source='rtsp://example.com/media.mp4'`, supports RTSP, RTMP, HTTP, and TCP streams.
37
+
38
+ Attributes:
39
+ sources (str): The source input paths or URLs for the video streams.
40
+ imgsz (int): The image size for processing, defaults to 640.
41
+ vid_stride (int): Video frame-rate stride, defaults to 1.
42
+ buffer (bool): Whether to buffer input streams, defaults to False.
43
+ running (bool): Flag to indicate if the streaming thread is running.
44
+ mode (str): Set to 'stream' indicating real-time capture.
45
+ imgs (list): List of image frames for each stream.
46
+ fps (list): List of FPS for each stream.
47
+ frames (list): List of total frames for each stream.
48
+ threads (list): List of threads for each stream.
49
+ shape (list): List of shapes for each stream.
50
+ caps (list): List of cv2.VideoCapture objects for each stream.
51
+ bs (int): Batch size for processing.
52
+
53
+ Methods:
54
+ __init__: Initialize the stream loader.
55
+ update: Read stream frames in daemon thread.
56
+ close: Close stream loader and release resources.
57
+ __iter__: Returns an iterator object for the class.
58
+ __next__: Returns source paths, transformed, and original images for processing.
59
+ __len__: Return the length of the sources object.
60
+ """
33
61
 
34
62
  def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False):
35
63
  """Initialize instance variables and check for consistent input stream shapes."""
@@ -149,10 +177,33 @@ class LoadStreams:
149
177
 
150
178
 
151
179
  class LoadScreenshots:
152
- """YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
180
+ """
181
+ YOLOv8 screenshot dataloader.
182
+
183
+ This class manages the loading of screenshot images for processing with YOLOv8.
184
+ Suitable for use with `yolo predict source=screen`.
185
+
186
+ Attributes:
187
+ source (str): The source input indicating which screen to capture.
188
+ imgsz (int): The image size for processing, defaults to 640.
189
+ screen (int): The screen number to capture.
190
+ left (int): The left coordinate for screen capture area.
191
+ top (int): The top coordinate for screen capture area.
192
+ width (int): The width of the screen capture area.
193
+ height (int): The height of the screen capture area.
194
+ mode (str): Set to 'stream' indicating real-time capture.
195
+ frame (int): Counter for captured frames.
196
+ sct (mss.mss): Screen capture object from `mss` library.
197
+ bs (int): Batch size, set to 1.
198
+ monitor (dict): Monitor configuration details.
199
+
200
+ Methods:
201
+ __iter__: Returns an iterator object.
202
+ __next__: Captures the next screenshot and returns it.
203
+ """
153
204
 
154
205
  def __init__(self, source, imgsz=640):
155
- """source = [screen_number left top width height] (pixels)."""
206
+ """Source = [screen_number left top width height] (pixels)."""
156
207
  check_requirements('mss')
157
208
  import mss # noqa
158
209
 
@@ -192,7 +243,28 @@ class LoadScreenshots:
192
243
 
193
244
 
194
245
  class LoadImages:
195
- """YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
246
+ """
247
+ YOLOv8 image/video dataloader.
248
+
249
+ This class manages the loading and pre-processing of image and video data for YOLOv8. It supports loading from
250
+ various formats, including single image files, video files, and lists of image and video paths.
251
+
252
+ Attributes:
253
+ imgsz (int): Image size, defaults to 640.
254
+ files (list): List of image and video file paths.
255
+ nf (int): Total number of files (images and videos).
256
+ video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
257
+ mode (str): Current mode, 'image' or 'video'.
258
+ vid_stride (int): Stride for video frame-rate, defaults to 1.
259
+ bs (int): Batch size, set to 1 for this class.
260
+ cap (cv2.VideoCapture): Video capture object for OpenCV.
261
+ frame (int): Frame counter for video.
262
+ frames (int): Total number of frames in the video.
263
+ count (int): Counter for iteration, initialized at 0 during `__iter__()`.
264
+
265
+ Methods:
266
+ _new_video(path): Create a new cv2.VideoCapture object for a given video path.
267
+ """
196
268
 
197
269
  def __init__(self, path, imgsz=640, vid_stride=1):
198
270
  """Initialize the Dataloader and raise FileNotFoundError if file not found."""
@@ -285,6 +357,24 @@ class LoadImages:
285
357
 
286
358
 
287
359
  class LoadPilAndNumpy:
360
+ """
361
+ Load images from PIL and Numpy arrays for batch processing.
362
+
363
+ This class is designed to manage loading and pre-processing of image data from both PIL and Numpy formats.
364
+ It performs basic validation and format conversion to ensure that the images are in the required format for
365
+ downstream processing.
366
+
367
+ Attributes:
368
+ paths (list): List of image paths or autogenerated filenames.
369
+ im0 (list): List of images stored as Numpy arrays.
370
+ imgsz (int): Image size, defaults to 640.
371
+ mode (str): Type of data being processed, defaults to 'image'.
372
+ bs (int): Batch size, equivalent to the length of `im0`.
373
+ count (int): Counter for iteration, initialized at 0 during `__iter__()`.
374
+
375
+ Methods:
376
+ _single_check(im): Validate and format a single image to a Numpy array.
377
+ """
288
378
 
289
379
  def __init__(self, im0, imgsz=640):
290
380
  """Initialize PIL and Numpy Dataloader."""
@@ -326,8 +416,24 @@ class LoadPilAndNumpy:
326
416
 
327
417
 
328
418
  class LoadTensor:
419
+ """
420
+ Load images from torch.Tensor data.
421
+
422
+ This class manages the loading and pre-processing of image data from PyTorch tensors for further processing.
423
+
424
+ Attributes:
425
+ im0 (torch.Tensor): The input tensor containing the image(s).
426
+ bs (int): Batch size, inferred from the shape of `im0`.
427
+ mode (str): Current mode, set to 'image'.
428
+ paths (list): List of image paths or filenames.
429
+ count (int): Counter for iteration, initialized at 0 during `__iter__()`.
430
+
431
+ Methods:
432
+ _single_check(im, stride): Validate and possibly modify the input tensor.
433
+ """
329
434
 
330
435
  def __init__(self, im0) -> None:
436
+ """Initialize Tensor Dataloader."""
331
437
  self.im0 = self._single_check(im0)
332
438
  self.bs = self.im0.shape[0]
333
439
  self.mode = 'image'
@@ -370,9 +476,7 @@ class LoadTensor:
370
476
 
371
477
 
372
478
  def autocast_list(source):
373
- """
374
- Merges a list of source of different types into a list of numpy arrays or PIL images
375
- """
479
+ """Merges a list of source of different types into a list of numpy arrays or PIL images."""
376
480
  files = []
377
481
  for im in source:
378
482
  if isinstance(im, (str, Path)): # filename or uri
ultralytics/data/utils.py CHANGED
@@ -547,9 +547,9 @@ class HUBDatasetStats:
547
547
 
548
548
  def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
549
549
  """
550
- Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the
551
- Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will
552
- not be resized.
550
+ Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python
551
+ Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
552
+ resized.
553
553
 
554
554
  Args:
555
555
  f (str): The path to the input image file.
@@ -986,9 +986,7 @@ class Exporter:
986
986
  return model
987
987
 
988
988
  def add_callback(self, event: str, callback):
989
- """
990
- Appends the given callback.
991
- """
989
+ """Appends the given callback."""
992
990
  self.callbacks[event].append(callback)
993
991
 
994
992
  def run_callbacks(self, event: str):
@@ -81,6 +81,12 @@ class Model(nn.Module):
81
81
  self.session = HUBTrainingSession(model)
82
82
  model = self.session.model_file
83
83
 
84
+ # Check if Triton Server model
85
+ elif self.is_triton_model(model):
86
+ self.model = model
87
+ self.task = task
88
+ return
89
+
84
90
  # Load or create new YOLO model
85
91
  suffix = Path(model).suffix
86
92
  if not suffix and Path(model).stem in GITHUB_ASSETS_STEMS:
@@ -94,6 +100,13 @@ class Model(nn.Module):
94
100
  """Calls the 'predict' function with given arguments to perform object detection."""
95
101
  return self.predict(source, stream, **kwargs)
96
102
 
103
+ @staticmethod
104
+ def is_triton_model(model):
105
+ """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
106
+ from urllib.parse import urlsplit
107
+ url = urlsplit(model)
108
+ return url.netloc and url.path and url.scheme in {'http', 'grfc'}
109
+
97
110
  @staticmethod
98
111
  def is_hub_model(model):
99
112
  """Check if the provided model is a HUB model."""
@@ -146,9 +159,7 @@ class Model(nn.Module):
146
159
  self.overrides['task'] = self.task
147
160
 
148
161
  def _check_is_pytorch_model(self):
149
- """
150
- Raises TypeError is model is not a PyTorch model
151
- """
162
+ """Raises TypeError is model is not a PyTorch model."""
152
163
  pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
153
164
  pt_module = isinstance(self.model, nn.Module)
154
165
  if not (pt_module or pt_str):
@@ -160,9 +171,7 @@ class Model(nn.Module):
160
171
  f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'")
161
172
 
162
173
  def reset_weights(self):
163
- """
164
- Resets the model modules parameters to randomly initialized values, losing all training information.
165
- """
174
+ """Resets the model modules parameters to randomly initialized values, losing all training information."""
166
175
  self._check_is_pytorch_model()
167
176
  for m in self.model.modules():
168
177
  if hasattr(m, 'reset_parameters'):
@@ -172,9 +181,7 @@ class Model(nn.Module):
172
181
  return self
173
182
 
174
183
  def load(self, weights='yolov8n.pt'):
175
- """
176
- Transfers parameters with matching names and shapes from 'weights' to model.
177
- """
184
+ """Transfers parameters with matching names and shapes from 'weights' to model."""
178
185
  self._check_is_pytorch_model()
179
186
  if isinstance(weights, (str, Path)):
180
187
  weights, self.ckpt = attempt_load_one_weight(weights)
@@ -58,7 +58,7 @@ Example:
58
58
 
59
59
  class BasePredictor:
60
60
  """
61
- BasePredictor
61
+ BasePredictor.
62
62
 
63
63
  A base class for creating predictors.
64
64
 
@@ -109,7 +109,8 @@ class BasePredictor:
109
109
  callbacks.add_integration_callbacks(self)
110
110
 
111
111
  def preprocess(self, im):
112
- """Prepares input image before inference.
112
+ """
113
+ Prepares input image before inference.
113
114
 
114
115
  Args:
115
116
  im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
@@ -128,6 +129,7 @@ class BasePredictor:
128
129
  return im
129
130
 
130
131
  def inference(self, im, *args, **kwargs):
132
+ """Runs inference on a given image using the specified model and arguments."""
131
133
  visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
132
134
  mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
133
135
  return self.model(im, augment=self.args.augment, visualize=visualize)
@@ -194,7 +196,11 @@ class BasePredictor:
194
196
  return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
195
197
 
196
198
  def predict_cli(self, source=None, model=None):
197
- """Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
199
+ """
200
+ Method used for CLI prediction.
201
+
202
+ It uses always generator as outputs as not required by CLI mode.
203
+ """
198
204
  gen = self.stream_inference(source, model)
199
205
  for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
200
206
  pass
@@ -352,7 +358,5 @@ class BasePredictor:
352
358
  callback(self)
353
359
 
354
360
  def add_callback(self, event: str, func):
355
- """
356
- Add callback
357
- """
361
+ """Add callback."""
358
362
  self.callbacks[event].append(func)
@@ -1,6 +1,6 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
  """
3
- Ultralytics Results, Boxes and Masks classes for handling inference results
3
+ Ultralytics Results, Boxes and Masks classes for handling inference results.
4
4
 
5
5
  Usage: See https://docs.ultralytics.com/modes/predict/
6
6
  """
@@ -19,12 +19,11 @@ from ultralytics.utils.torch_utils import smart_inference_mode
19
19
 
20
20
 
21
21
  class BaseTensor(SimpleClass):
22
- """
23
- Base tensor class with additional methods for easy manipulation and device handling.
24
- """
22
+ """Base tensor class with additional methods for easy manipulation and device handling."""
25
23
 
26
24
  def __init__(self, data, orig_shape) -> None:
27
- """Initialize BaseTensor with data and original shape.
25
+ """
26
+ Initialize BaseTensor with data and original shape.
28
27
 
29
28
  Args:
30
29
  data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
@@ -126,6 +125,18 @@ class Results(SimpleClass):
126
125
  self.probs = probs
127
126
 
128
127
  def _apply(self, fn, *args, **kwargs):
128
+ """
129
+ Applies a function to all non-empty attributes and returns a new Results object with modified attributes. This
130
+ function is internally called by methods like .to(), .cuda(), .cpu(), etc.
131
+
132
+ Args:
133
+ fn (str): The name of the function to apply.
134
+ *args: Variable length argument list to pass to the function.
135
+ **kwargs: Arbitrary keyword arguments to pass to the function.
136
+
137
+ Returns:
138
+ Results: A new Results object with attributes modified by the applied function.
139
+ """
129
140
  r = self.new()
130
141
  for k in self._keys:
131
142
  v = getattr(self, k)
@@ -250,9 +261,7 @@ class Results(SimpleClass):
250
261
  return annotator.result()
251
262
 
252
263
  def verbose(self):
253
- """
254
- Return log string for each task.
255
- """
264
+ """Return log string for each task."""
256
265
  log_string = ''
257
266
  probs = self.probs
258
267
  boxes = self.boxes
@@ -537,6 +546,7 @@ class Probs(BaseTensor):
537
546
  """
538
547
 
539
548
  def __init__(self, probs, orig_shape=None) -> None:
549
+ """Initialize the Probs class with classification probabilities and optional original shape of the image."""
540
550
  super().__init__(probs, orig_shape)
541
551
 
542
552
  @property
@@ -1,6 +1,6 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
  """
3
- Train a model on a dataset
3
+ Train a model on a dataset.
4
4
 
5
5
  Usage:
6
6
  $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
@@ -37,7 +37,7 @@ from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
37
37
 
38
38
  class BaseTrainer:
39
39
  """
40
- BaseTrainer
40
+ BaseTrainer.
41
41
 
42
42
  A base class for creating trainers.
43
43
 
@@ -143,15 +143,11 @@ class BaseTrainer:
143
143
  callbacks.add_integration_callbacks(self)
144
144
 
145
145
  def add_callback(self, event: str, callback):
146
- """
147
- Appends the given callback.
148
- """
146
+ """Appends the given callback."""
149
147
  self.callbacks[event].append(callback)
150
148
 
151
149
  def set_callback(self, event: str, callback):
152
- """
153
- Overrides the existing callbacks with the given callback.
154
- """
150
+ """Overrides the existing callbacks with the given callback."""
155
151
  self.callbacks[event] = [callback]
156
152
 
157
153
  def run_callbacks(self, event: str):
@@ -207,9 +203,7 @@ class BaseTrainer:
207
203
  world_size=world_size)
208
204
 
209
205
  def _setup_train(self, world_size):
210
- """
211
- Builds dataloaders and optimizer on correct rank process.
212
- """
206
+ """Builds dataloaders and optimizer on correct rank process."""
213
207
 
214
208
  # Model
215
209
  self.run_callbacks('on_pretrain_routine_start')
@@ -450,14 +444,14 @@ class BaseTrainer:
450
444
  @staticmethod
451
445
  def get_dataset(data):
452
446
  """
453
- Get train, val path from data dict if it exists. Returns None if data format is not recognized.
447
+ Get train, val path from data dict if it exists.
448
+
449
+ Returns None if data format is not recognized.
454
450
  """
455
451
  return data['train'], data.get('val') or data.get('test')
456
452
 
457
453
  def setup_model(self):
458
- """
459
- load/create/download model for any task.
460
- """
454
+ """Load/create/download model for any task."""
461
455
  if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
462
456
  return
463
457
 
@@ -482,14 +476,14 @@ class BaseTrainer:
482
476
  self.ema.update(self.model)
483
477
 
484
478
  def preprocess_batch(self, batch):
485
- """
486
- Allows custom preprocessing model inputs and ground truths depending on task type.
487
- """
479
+ """Allows custom preprocessing model inputs and ground truths depending on task type."""
488
480
  return batch
489
481
 
490
482
  def validate(self):
491
483
  """
492
- Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
484
+ Runs validation on test set using self.validator.
485
+
486
+ The returned dict is expected to contain "fitness" key.
493
487
  """
494
488
  metrics = self.validator(self)
495
489
  fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
@@ -506,26 +500,20 @@ class BaseTrainer:
506
500
  raise NotImplementedError('get_validator function not implemented in trainer')
507
501
 
508
502
  def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
509
- """
510
- Returns dataloader derived from torch.data.Dataloader.
511
- """
503
+ """Returns dataloader derived from torch.data.Dataloader."""
512
504
  raise NotImplementedError('get_dataloader function not implemented in trainer')
513
505
 
514
506
  def build_dataset(self, img_path, mode='train', batch=None):
515
- """Build dataset"""
507
+ """Build dataset."""
516
508
  raise NotImplementedError('build_dataset function not implemented in trainer')
517
509
 
518
510
  def label_loss_items(self, loss_items=None, prefix='train'):
519
- """
520
- Returns a loss dict with labelled training loss items tensor
521
- """
511
+ """Returns a loss dict with labelled training loss items tensor."""
522
512
  # Not needed for classification but necessary for segmentation & detection
523
513
  return {'loss': loss_items} if loss_items is not None else ['loss']
524
514
 
525
515
  def set_model_attributes(self):
526
- """
527
- To set or update model parameters before training.
528
- """
516
+ """To set or update model parameters before training."""
529
517
  self.model.names = self.data['names']
530
518
 
531
519
  def build_targets(self, preds, targets):
@@ -632,8 +620,8 @@ class BaseTrainer:
632
620
 
633
621
  def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
634
622
  """
635
- Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
636
- momentum, weight decay, and number of iterations.
623
+ Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
624
+ weight decay, and number of iterations.
637
625
 
638
626
  Args:
639
627
  model (torch.nn.Module): The model for which to build an optimizer.
@@ -31,32 +31,32 @@ from ultralytics.utils.plotting import plot_tune_results
31
31
 
32
32
  class Tuner:
33
33
  """
34
- Class responsible for hyperparameter tuning of YOLO models.
34
+ Class responsible for hyperparameter tuning of YOLO models.
35
35
 
36
- The class evolves YOLO model hyperparameters over a given number of iterations
37
- by mutating them according to the search space and retraining the model to evaluate their performance.
36
+ The class evolves YOLO model hyperparameters over a given number of iterations
37
+ by mutating them according to the search space and retraining the model to evaluate their performance.
38
38
 
39
- Attributes:
40
- space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
41
- tune_dir (Path): Directory where evolution logs and results will be saved.
42
- tune_csv (Path): Path to the CSV file where evolution logs are saved.
39
+ Attributes:
40
+ space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
41
+ tune_dir (Path): Directory where evolution logs and results will be saved.
42
+ tune_csv (Path): Path to the CSV file where evolution logs are saved.
43
43
 
44
- Methods:
45
- _mutate(hyp: dict) -> dict:
46
- Mutates the given hyperparameters within the bounds specified in `self.space`.
44
+ Methods:
45
+ _mutate(hyp: dict) -> dict:
46
+ Mutates the given hyperparameters within the bounds specified in `self.space`.
47
47
 
48
- __call__():
49
- Executes the hyperparameter evolution across multiple iterations.
48
+ __call__():
49
+ Executes the hyperparameter evolution across multiple iterations.
50
50
 
51
- Example:
52
- Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
53
- ```python
54
- from ultralytics import YOLO
51
+ Example:
52
+ Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
53
+ ```python
54
+ from ultralytics import YOLO
55
55
 
56
- model = YOLO('yolov8n.pt')
57
- model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
58
- ```
59
- """
56
+ model = YOLO('yolov8n.pt')
57
+ model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
58
+ ```
59
+ """
60
60
 
61
61
  def __init__(self, args=DEFAULT_CFG, _callbacks=None):
62
62
  """
@@ -36,7 +36,7 @@ from ultralytics.utils.torch_utils import de_parallel, select_device, smart_infe
36
36
 
37
37
  class BaseValidator:
38
38
  """
39
- BaseValidator
39
+ BaseValidator.
40
40
 
41
41
  A base class for creating validators.
42
42
 
@@ -102,8 +102,7 @@ class BaseValidator:
102
102
 
103
103
  @smart_inference_mode()
104
104
  def __call__(self, trainer=None, model=None):
105
- """
106
- Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
105
+ """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
107
106
  gets priority).
108
107
  """
109
108
  self.training = trainer is not None
@@ -260,7 +259,7 @@ class BaseValidator:
260
259
  raise NotImplementedError('get_dataloader function not implemented for this validator')
261
260
 
262
261
  def build_dataset(self, img_path):
263
- """Build dataset"""
262
+ """Build dataset."""
264
263
  raise NotImplementedError('build_dataset function not implemented in validator')
265
264
 
266
265
  def preprocess(self, batch):
@@ -80,8 +80,8 @@ def get_export(model_id='', format='torchscript'):
80
80
 
81
81
  def check_dataset(path='', task='detect'):
82
82
  """
83
- Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is
84
- uploaded to the HUB. Usage examples are given below.
83
+ Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
84
+ to the HUB. Usage examples are given below.
85
85
 
86
86
  Args:
87
87
  path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
ultralytics/hub/auth.py CHANGED
@@ -9,6 +9,19 @@ API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
9
9
 
10
10
 
11
11
  class Auth:
12
+ """
13
+ Manages authentication processes including API key handling, cookie-based authentication, and header generation.
14
+
15
+ The class supports different methods of authentication:
16
+ 1. Directly using an API key.
17
+ 2. Authenticating using browser cookies (specifically in Google Colab).
18
+ 3. Prompting the user to enter an API key.
19
+
20
+ Attributes:
21
+ id_token (str or bool): Token used for identity verification, initialized as False.
22
+ api_key (str or bool): API key for authentication, initialized as False.
23
+ model_key (bool): Placeholder for model key, initialized as False.
24
+ """
12
25
  id_token = api_key = model_key = False
13
26
 
14
27
  def __init__(self, api_key='', verbose=False):
@@ -54,7 +67,9 @@ class Auth:
54
67
 
55
68
  def request_api_key(self, max_attempts=3):
56
69
  """
57
- Prompt the user to input their API key. Returns the model ID.
70
+ Prompt the user to input their API key.
71
+
72
+ Returns the model ID.
58
73
  """
59
74
  import getpass
60
75
  for attempts in range(max_attempts):
@@ -86,8 +101,8 @@ class Auth:
86
101
 
87
102
  def auth_with_cookies(self) -> bool:
88
103
  """
89
- Attempt to fetch authentication via cookies and set id_token.
90
- User must be logged in to HUB and running in a supported browser.
104
+ Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a
105
+ supported browser.
91
106
 
92
107
  Returns:
93
108
  bool: True if authentication is successful, False otherwise.
@@ -84,6 +84,7 @@ class HUBTrainingSession:
84
84
  def _handle_signal(self, signum, frame):
85
85
  """
86
86
  Handle kill signals and prevent heartbeats from being sent on Colab after termination.
87
+
87
88
  This method does not use frame, it is included as it is passed by signal.
88
89
  """
89
90
  if self.alive is True:
ultralytics/hub/utils.py CHANGED
@@ -161,9 +161,7 @@ class Events:
161
161
  url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
162
162
 
163
163
  def __init__(self):
164
- """
165
- Initializes the Events object with default values for events, rate_limit, and metadata.
166
- """
164
+ """Initializes the Events object with default values for events, rate_limit, and metadata."""
167
165
  self.events = [] # events list
168
166
  self.rate_limit = 60.0 # rate limit (seconds)
169
167
  self.t = 0.0 # rate limit timer (seconds)
@@ -22,7 +22,7 @@ class FastSAM(Model):
22
22
  """
23
23
 
24
24
  def __init__(self, model='FastSAM-x.pt'):
25
- """Call the __init__ method of the parent class (YOLO) with the updated default model"""
25
+ """Call the __init__ method of the parent class (YOLO) with the updated default model."""
26
26
  if str(model) == 'FastSAM.pt':
27
27
  model = 'FastSAM-x.pt'
28
28
  assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
@@ -30,4 +30,5 @@ class FastSAM(Model):
30
30
 
31
31
  @property
32
32
  def task_map(self):
33
+ """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
33
34
  return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}