ultralytics 8.3.8__py3-none-any.whl → 8.3.10__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of ultralytics has been flagged in the registry; see the registry page for details.

tests/test_solutions.py CHANGED
@@ -14,24 +14,21 @@ WORKOUTS_SOLUTION_DEMO = "https://github.com/ultralytics/assets/releases/downloa
 def test_major_solutions():
     """Test the object counting, heatmap, speed estimation and queue management solution."""
     safe_download(url=MAJOR_SOLUTIONS_DEMO)
-    model = YOLO("yolo11n.pt")
-    names = model.names
     cap = cv2.VideoCapture("solutions_ci_demo.mp4")
     assert cap.isOpened(), "Error reading video file"
     region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
     counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
     heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
-    speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False)
+    speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)
     queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)
     while cap.isOpened():
         success, im0 = cap.read()
         if not success:
             break
         original_im0 = im0.copy()
-        tracks = model.track(im0, persist=True, show=False)
         _ = counter.count(original_im0.copy())
         _ = heatmap.generate_heatmap(original_im0.copy())
-        _ = speed.estimate_speed(original_im0.copy(), tracks)
+        _ = speed.estimate_speed(original_im0.copy())
         _ = queue.process_queue(original_im0.copy())
     cap.release()
     cv2.destroyAllWindows()
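
Note on the change above: speed estimation now uses the unified solutions API, where each solution builds its own model and tracker from the `model` argument, so the test no longer needs an explicit `YOLO(...)` instance or a separate `model.track(...)` call. A minimal sketch of the new calling convention (assuming `yolo11n.pt` weights and a local demo video are available):

    import cv2
    from ultralytics import solutions

    cap = cv2.VideoCapture("solutions_ci_demo.mp4")  # any readable video file
    speed = solutions.SpeedEstimator(
        region=[(20, 400), (1080, 404), (1080, 360), (20, 360)],  # speed-estimation region
        model="yolo11n.pt",  # detection + tracking run internally per frame
        show=False,
    )
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break
        annotated = speed.estimate_speed(frame)  # one call: track, estimate, annotate
    cap.release()
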
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.3.8"
+__version__ = "8.3.10"
 
 import os
 
ultralytics/cfg/solutions/default.yaml CHANGED
@@ -2,15 +2,15 @@
 
 # Configuration for Ultralytics Solutions
 
-model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano version)
+model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano version and yolov8n.pt for YOLOv8 nano version)
 
-region: # Object counting, queue or speed estimation region points
-line_width: 2 # Thickness of the lines used to draw regions on the image/video frames
-show: True # Flag to control whether to display output image or not
+region: # Object counting, queue or speed estimation region points. Default region points are [(20, 400), (1080, 404), (1080, 360), (20, 360)]
+line_width: 2 # Width of the annotator used to draw regions on the image/video frames + bounding boxes and tracks drawing. Default value is 2.
+show: True # Flag to control whether to display output image or not, you can set this as False i.e. when deploying it on some embedded devices.
 show_in: True # Flag to display objects moving *into* the defined region
 show_out: True # Flag to display objects moving *out of* the defined region
-classes: # To count specific classes
-up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value
-down_angle: 90 # Workouts down_angle for counts, 90 is default value
-kpts: [6, 8, 10] # Keypoints for workouts monitoring
-colormap: # Colormap for heatmap
+classes: # To count specific classes. i.e, if you want to detect, track and count the person with COCO model, you can use classes=0, Default its None
+up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value. You can adjust it for different workouts, based on position of keypoints.
+down_angle: 90 # Workouts down_angle for counts, 90 is default value. You can change it for different workouts, based on position of keypoints.
+kpts: [6, 8, 10] # Keypoints for workouts monitoring, i.e. If you want to consider keypoints for pushups that have mostly values of [6, 8, 10].
+colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization.
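
These defaults can be overridden per solution with keyword arguments. A sketch (class index 0 is "person" for COCO-pretrained weights such as yolo11n.pt):

    from ultralytics import solutions

    counter = solutions.ObjectCounter(
        model="yolo11n.pt",
        classes=[0],   # track and count only class 0 ("person")
        line_width=3,  # thicker regions, boxes and tracks than the default of 2
        show=False,    # e.g. on a headless or embedded device
    )
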
ultralytics/data/loaders.py CHANGED
@@ -18,11 +18,29 @@ from PIL import Image
 from ultralytics.data.utils import FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS
 from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, ops
 from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.patches import imread
 
 
 @dataclass
 class SourceTypes:
-    """Class to represent various types of input sources for predictions."""
+    """
+    Class to represent various types of input sources for predictions.
+
+    This class uses dataclass to define boolean flags for different types of input sources that can be used for
+    making predictions with YOLO models.
+
+    Attributes:
+        stream (bool): Flag indicating if the input source is a video stream.
+        screenshot (bool): Flag indicating if the input source is a screenshot.
+        from_img (bool): Flag indicating if the input source is an image file.
+
+    Examples:
+        >>> source_types = SourceTypes(stream=True, screenshot=False, from_img=False)
+        >>> print(source_types.stream)
+        True
+        >>> print(source_types.from_img)
+        False
+    """
 
     stream: bool = False
     screenshot: bool = False
@@ -32,38 +50,47 @@ class SourceTypes:
 
 class LoadStreams:
     """
-    Stream Loader for various types of video streams, Supports RTSP, RTMP, HTTP, and TCP streams.
+    Stream Loader for various types of video streams.
+
+    Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video
+    streams simultaneously, making it suitable for real-time video analysis tasks.
 
     Attributes:
-        sources (str): The source input paths or URLs for the video streams.
-        vid_stride (int): Video frame-rate stride, defaults to 1.
-        buffer (bool): Whether to buffer input streams, defaults to False.
+        sources (List[str]): The source input paths or URLs for the video streams.
+        vid_stride (int): Video frame-rate stride.
+        buffer (bool): Whether to buffer input streams.
         running (bool): Flag to indicate if the streaming thread is running.
         mode (str): Set to 'stream' indicating real-time capture.
-        imgs (list): List of image frames for each stream.
-        fps (list): List of FPS for each stream.
-        frames (list): List of total frames for each stream.
-        threads (list): List of threads for each stream.
-        shape (list): List of shapes for each stream.
-        caps (list): List of cv2.VideoCapture objects for each stream.
+        imgs (List[List[np.ndarray]]): List of image frames for each stream.
+        fps (List[float]): List of FPS for each stream.
+        frames (List[int]): List of total frames for each stream.
+        threads (List[Thread]): List of threads for each stream.
+        shape (List[Tuple[int, int, int]]): List of shapes for each stream.
+        caps (List[cv2.VideoCapture]): List of cv2.VideoCapture objects for each stream.
         bs (int): Batch size for processing.
 
     Methods:
-        __init__: Initialize the stream loader.
         update: Read stream frames in daemon thread.
         close: Close stream loader and release resources.
         __iter__: Returns an iterator object for the class.
         __next__: Returns source paths, transformed, and original images for processing.
         __len__: Return the length of the sources object.
 
-    Example:
-        ```bash
-        yolo predict source='rtsp://example.com/media.mp4'
-        ```
+    Examples:
+        >>> stream_loader = LoadStreams("rtsp://example.com/stream1.mp4")
+        >>> for sources, imgs, _ in stream_loader:
+        ...     # Process the images
+        ...     pass
+        >>> stream_loader.close()
+
+    Notes:
+        - The class uses threading to efficiently load frames from multiple streams simultaneously.
+        - It automatically handles YouTube links, converting them to the best available stream URL.
+        - The class implements a buffer system to manage frame storage and retrieval.
     """
 
     def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
-        """Initialize instance variables and check for consistent input stream shapes."""
+        """Initialize stream loader for multiple video sources, supporting various stream types."""
         torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
         self.buffer = buffer  # buffer input streams
         self.running = True  # running flag for Thread
@@ -114,7 +141,7 @@ class LoadStreams:
         LOGGER.info("")  # newline
 
     def update(self, i, cap, stream):
-        """Read stream `i` frames in daemon thread."""
+        """Read stream frames in daemon thread and update image buffer."""
         n, f = 0, self.frames[i]  # frame number, frame array
         while self.running and cap.isOpened() and n < (f - 1):
             if len(self.imgs[i]) < 30:  # keep a <=30-image buffer
@@ -134,7 +161,7 @@ class LoadStreams:
         time.sleep(0.01)  # wait until the buffer is empty
 
     def close(self):
-        """Close stream loader and release resources."""
+        """Terminates stream loader, stops threads, and releases video capture resources."""
         self.running = False  # stop flag for Thread
         for thread in self.threads:
             if thread.is_alive():
@@ -152,7 +179,7 @@ class LoadStreams:
         return self
 
     def __next__(self):
-        """Returns source paths, transformed and original images for processing."""
+        """Returns the next batch of frames from multiple video streams for processing."""
        self.count += 1
 
        images = []
@@ -179,16 +206,16 @@ class LoadStreams:
         return self.sources, images, [""] * self.bs
 
     def __len__(self):
-        """Return the length of the sources object."""
+        """Return the number of video streams in the LoadStreams object."""
         return self.bs  # 1E12 frames = 32 streams at 30 FPS for 30 years
 
 
 class LoadScreenshots:
     """
-    YOLOv8 screenshot dataloader.
+    Ultralytics screenshot dataloader for capturing and processing screen images.
 
-    This class manages the loading of screenshot images for processing with YOLOv8.
-    Suitable for use with `yolo predict source=screen`.
+    This class manages the loading of screenshot images for processing with YOLO. It is suitable for use with
+    `yolo predict source=screen`.
 
     Attributes:
         source (str): The source input indicating which screen to capture.
@@ -201,15 +228,21 @@ class LoadScreenshots:
         frame (int): Counter for captured frames.
         sct (mss.mss): Screen capture object from `mss` library.
         bs (int): Batch size, set to 1.
-        monitor (dict): Monitor configuration details.
+        fps (int): Frames per second, set to 30.
+        monitor (Dict[str, int]): Monitor configuration details.
 
     Methods:
         __iter__: Returns an iterator object.
         __next__: Captures the next screenshot and returns it.
+
+    Examples:
+        >>> loader = LoadScreenshots("0 100 100 640 480")  # screen 0, top-left (100,100), 640x480
+        >>> for source, im, im0s, vid_cap, s in loader:
+        ...     print(f"Captured frame: {im.shape}")
     """
 
     def __init__(self, source):
-        """Source = [screen_number left top width height] (pixels)."""
+        """Initialize screenshot capture with specified screen and region parameters."""
         check_requirements("mss")
         import mss  # noqa
 
@@ -236,11 +269,11 @@ class LoadScreenshots:
         self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
 
     def __iter__(self):
-        """Returns an iterator of the object."""
+        """Yields the next screenshot image from the specified screen or region for processing."""
         return self
 
     def __next__(self):
-        """Screen capture with 'mss' to get raw pixels from the screen as np array."""
+        """Captures and returns the next screenshot as a numpy array using the mss library."""
         im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3]  # BGRA to BGR
         s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
 
@@ -250,29 +283,45 @@ class LoadScreenshots:
 
 class LoadImagesAndVideos:
     """
-    YOLOv8 image/video dataloader.
+    A class for loading and processing images and videos for YOLO object detection.
 
-    This class manages the loading and pre-processing of image and video data for YOLOv8. It supports loading from
-    various formats, including single image files, video files, and lists of image and video paths.
+    This class manages the loading and pre-processing of image and video data from various sources, including
+    single image files, video files, and lists of image and video paths.
 
     Attributes:
-        files (list): List of image and video file paths.
+        files (List[str]): List of image and video file paths.
         nf (int): Total number of files (images and videos).
-        video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
+        video_flag (List[bool]): Flags indicating whether a file is a video (True) or an image (False).
         mode (str): Current mode, 'image' or 'video'.
-        vid_stride (int): Stride for video frame-rate, defaults to 1.
-        bs (int): Batch size, set to 1 for this class.
+        vid_stride (int): Stride for video frame-rate.
+        bs (int): Batch size.
         cap (cv2.VideoCapture): Video capture object for OpenCV.
         frame (int): Frame counter for video.
         frames (int): Total number of frames in the video.
-        count (int): Counter for iteration, initialized at 0 during `__iter__()`.
+        count (int): Counter for iteration, initialized at 0 during __iter__().
+        ni (int): Number of images.
 
     Methods:
-        _new_video(path): Create a new cv2.VideoCapture object for a given video path.
+        __init__: Initialize the LoadImagesAndVideos object.
+        __iter__: Returns an iterator object for VideoStream or ImageFolder.
+        __next__: Returns the next batch of images or video frames along with their paths and metadata.
+        _new_video: Creates a new video capture object for the given path.
+        __len__: Returns the number of batches in the object.
+
+    Examples:
+        >>> loader = LoadImagesAndVideos("path/to/data", batch=32, vid_stride=1)
+        >>> for paths, imgs, info in loader:
+        ...     # Process batch of images or video frames
+        ...     pass
+
+    Notes:
+        - Supports various image formats including HEIC.
+        - Handles both local files and directories.
+        - Can read from a text file containing paths to images and videos.
    """
 
     def __init__(self, path, batch=1, vid_stride=1):
-        """Initialize the Dataloader and raise FileNotFoundError if file not found."""
+        """Initialize dataloader for images and videos, supporting various input formats."""
         parent = None
         if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
             parent = Path(path).parent
@@ -316,12 +365,12 @@ class LoadImagesAndVideos:
             raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
 
     def __iter__(self):
-        """Returns an iterator object for VideoStream or ImageFolder."""
+        """Iterates through image/video files, yielding source paths, images, and metadata."""
         self.count = 0
         return self
 
     def __next__(self):
-        """Returns the next batch of images or video frames along with their paths and metadata."""
+        """Returns the next batch of images or video frames with their paths and metadata."""
         paths, imgs, info = [], [], []
         while len(imgs) < self.bs:
             if self.count >= self.nf:  # end of file list
@@ -336,6 +385,7 @@ class LoadImagesAndVideos:
                 if not self.cap or not self.cap.isOpened():
                     self._new_video(path)
 
+                success = False
                 for _ in range(self.vid_stride):
                     success = self.cap.grab()
                     if not success:
@@ -359,8 +409,19 @@ class LoadImagesAndVideos:
                 if self.count < self.nf:
                     self._new_video(self.files[self.count])
             else:
+                # Handle image files (including HEIC)
                 self.mode = "image"
-                im0 = cv2.imread(path)  # BGR
+                if path.split(".")[-1].lower() == "heic":
+                    # Load HEIC image using Pillow with pillow-heif
+                    check_requirements("pillow-heif")
+
+                    from pillow_heif import register_heif_opener
+
+                    register_heif_opener()  # Register HEIF opener with Pillow
+                    with Image.open(path) as img:
+                        im0 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)  # convert image to BGR nparray
+                else:
+                    im0 = imread(path)  # BGR
                 if im0 is None:
                     LOGGER.warning(f"WARNING ⚠️ Image Read Error {path}")
                 else:
@@ -374,7 +435,7 @@ class LoadImagesAndVideos:
         return paths, imgs, info
 
     def _new_video(self, path):
-        """Creates a new video capture object for the given path."""
+        """Creates a new video capture object for the given path and initializes video-related attributes."""
         self.frame = 0
         self.cap = cv2.VideoCapture(path)
         self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
@@ -383,40 +444,50 @@ class LoadImagesAndVideos:
         self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
 
     def __len__(self):
-        """Returns the number of batches in the object."""
-        return math.ceil(self.nf / self.bs)  # number of files
+        """Returns the number of files (images and videos) in the dataset."""
+        return math.ceil(self.nf / self.bs)  # number of batches
 
 
 class LoadPilAndNumpy:
     """
     Load images from PIL and Numpy arrays for batch processing.
 
-    This class is designed to manage loading and pre-processing of image data from both PIL and Numpy formats.
-    It performs basic validation and format conversion to ensure that the images are in the required format for
-    downstream processing.
+    This class manages loading and pre-processing of image data from both PIL and Numpy formats. It performs basic
+    validation and format conversion to ensure that the images are in the required format for downstream processing.
 
     Attributes:
-        paths (list): List of image paths or autogenerated filenames.
-        im0 (list): List of images stored as Numpy arrays.
-        mode (str): Type of data being processed, defaults to 'image'.
+        paths (List[str]): List of image paths or autogenerated filenames.
+        im0 (List[np.ndarray]): List of images stored as Numpy arrays.
+        mode (str): Type of data being processed, set to 'image'.
         bs (int): Batch size, equivalent to the length of `im0`.
 
     Methods:
-        _single_check(im): Validate and format a single image to a Numpy array.
+        _single_check: Validate and format a single image to a Numpy array.
+
+    Examples:
+        >>> from PIL import Image
+        >>> import numpy as np
+        >>> pil_img = Image.new("RGB", (100, 100))
+        >>> np_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+        >>> loader = LoadPilAndNumpy([pil_img, np_img])
+        >>> paths, images, _ = next(iter(loader))
+        >>> print(f"Loaded {len(images)} images")
+        Loaded 2 images
     """
 
     def __init__(self, im0):
-        """Initialize PIL and Numpy Dataloader."""
+        """Initializes a loader for PIL and Numpy images, converting inputs to a standardized format."""
         if not isinstance(im0, list):
             im0 = [im0]
-        self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
+        # use `image{i}.jpg` when Image.filename returns an empty path.
+        self.paths = [getattr(im, "filename", "") or f"image{i}.jpg" for i, im in enumerate(im0)]
         self.im0 = [self._single_check(im) for im in im0]
         self.mode = "image"
         self.bs = len(self.im0)
 
     @staticmethod
     def _single_check(im):
-        """Validate and format an image to numpy array."""
+        """Validate and format an image to numpy array, ensuring RGB order and contiguous memory."""
         assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
         if isinstance(im, Image.Image):
             if im.mode != "RGB":
@@ -426,41 +497,48 @@ class LoadPilAndNumpy:
         return im
 
     def __len__(self):
-        """Returns the length of the 'im0' attribute."""
+        """Returns the length of the 'im0' attribute, representing the number of loaded images."""
         return len(self.im0)
 
     def __next__(self):
-        """Returns batch paths, images, processed images, None, ''."""
+        """Returns the next batch of images, paths, and metadata for processing."""
         if self.count == 1:  # loop only once as it's batch inference
             raise StopIteration
         self.count += 1
         return self.paths, self.im0, [""] * self.bs
 
     def __iter__(self):
-        """Enables iteration for class LoadPilAndNumpy."""
+        """Iterates through PIL/numpy images, yielding paths, raw images, and metadata for processing."""
        self.count = 0
        return self
 
 
 class LoadTensor:
     """
-    Load images from torch.Tensor data.
+    A class for loading and processing tensor data for object detection tasks.
 
-    This class manages the loading and pre-processing of image data from PyTorch tensors for further processing.
+    This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for
+    further processing in object detection pipelines.
 
     Attributes:
-        im0 (torch.Tensor): The input tensor containing the image(s).
+        im0 (torch.Tensor): The input tensor containing the image(s) with shape (B, C, H, W).
         bs (int): Batch size, inferred from the shape of `im0`.
-        mode (str): Current mode, set to 'image'.
-        paths (list): List of image paths or filenames.
-        count (int): Counter for iteration, initialized at 0 during `__iter__()`.
+        mode (str): Current processing mode, set to 'image'.
+        paths (List[str]): List of image paths or auto-generated filenames.
 
     Methods:
-        _single_check(im, stride): Validate and possibly modify the input tensor.
+        _single_check: Validates and formats an input tensor.
+
+    Examples:
+        >>> import torch
+        >>> tensor = torch.rand(1, 3, 640, 640)
+        >>> loader = LoadTensor(tensor)
+        >>> paths, images, info = next(iter(loader))
+        >>> print(f"Processed {len(images)} images")
    """
 
     def __init__(self, im0) -> None:
-        """Initialize Tensor Dataloader."""
+        """Initialize LoadTensor object for processing torch.Tensor image data."""
         self.im0 = self._single_check(im0)
         self.bs = self.im0.shape[0]
         self.mode = "image"
@@ -468,7 +546,7 @@ class LoadTensor:
 
     @staticmethod
     def _single_check(im, stride=32):
-        """Validate and format an image to torch.Tensor."""
+        """Validates and formats a single image tensor, ensuring correct shape and normalization."""
         s = (
             f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
             f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
@@ -490,24 +568,24 @@ class LoadTensor:
         return im
 
     def __iter__(self):
-        """Returns an iterator object."""
+        """Yields an iterator object for iterating through tensor image data."""
         self.count = 0
         return self
 
     def __next__(self):
-        """Return next item in the iterator."""
+        """Yields the next batch of tensor images and metadata for processing."""
         if self.count == 1:
             raise StopIteration
         self.count += 1
         return self.paths, self.im0, [""] * self.bs
 
     def __len__(self):
-        """Returns the batch size."""
+        """Returns the batch size of the tensor input."""
         return self.bs
 
 
 def autocast_list(source):
-    """Merges a list of source of different types into a list of numpy arrays or PIL images."""
+    """Merges a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction."""
     files = []
     for im in source:
         if isinstance(im, (str, Path)):  # filename or uri
@@ -527,21 +605,24 @@ def get_best_youtube_url(url, method="pytube"):
     """
     Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
 
-    This function uses the specified method to extract the video info from YouTube. It supports the following methods:
-    - "pytube": Uses the pytube library to fetch the video streams.
-    - "pafy": Uses the pafy library to fetch the video streams.
-    - "yt-dlp": Uses the yt-dlp library to fetch the video streams.
-
-    The function then finds the highest quality MP4 format that has a video codec but no audio codec, and returns the
-    URL of this video stream.
-
     Args:
         url (str): The URL of the YouTube video.
-        method (str): The method to use for extracting video info. Default is "pytube". Other options are "pafy" and
-            "yt-dlp".
+        method (str): The method to use for extracting video info. Options are "pytube", "pafy", and "yt-dlp".
+            Defaults to "pytube".
 
     Returns:
-        (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
+        (str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
+
+    Examples:
+        >>> url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
+        >>> best_url = get_best_youtube_url(url)
+        >>> print(best_url)
+        https://rr4---sn-q4flrnek.googlevideo.com/videoplayback?expire=...
+
+    Notes:
+        - Requires additional libraries based on the chosen method: pytubefix, pafy, or yt-dlp.
+        - The function prioritizes streams with at least 1080p resolution when available.
+        - For the "yt-dlp" method, it looks for formats with video codec, no audio, and *.mp4 extension.
     """
     if method == "pytube":
         # Switched from pytube to pytubefix to resolve https://github.com/pytube/pytube/issues/1954
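
Note on the HEIC branch above: `cv2.imread` cannot decode HEIC files, so they are opened through Pillow with the pillow-heif plugin and converted to the BGR `np.ndarray` the rest of the pipeline expects. A standalone sketch of the same conversion (assuming pillow-heif is installed and a file named photo.heic exists):

    import cv2
    import numpy as np
    from PIL import Image
    from pillow_heif import register_heif_opener

    register_heif_opener()  # teach Pillow to open .heic files

    with Image.open("photo.heic") as img:  # hypothetical input file
        im0 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)  # PIL RGB -> OpenCV BGR
    print(im0.shape)  # (H, W, 3) uint8 array, ready for inference
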
ultralytics/data/utils.py CHANGED
@@ -35,7 +35,7 @@ from ultralytics.utils.downloads import download, safe_download, unzip_file
 from ultralytics.utils.ops import segments2boxes
 
 HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
-IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"}  # image suffixes
+IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"}  # image suffixes
 VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
 PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
 FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
ultralytics/engine/predictor.py CHANGED
@@ -381,7 +381,7 @@ class BasePredictor:
 
         # Save images
         else:
-            cv2.imwrite(save_path, im)
+            cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support
 
     def show(self, p=""):
         """Display an image in a window using the OpenCV imshow function."""
ultralytics/engine/trainer.py CHANGED
@@ -469,11 +469,9 @@ class BaseTrainer:
 
         if RANK in {-1, 0}:
             # Do final val with best.pt
-            epochs = epoch - self.start_epoch + 1  # total training epochs
-            seconds = time.time() - self.train_time_start  # total training seconds
-            LOGGER.info(f"\n{epochs} epochs completed in {seconds / 3600:.3f} hours.")
+            seconds = time.time() - self.train_time_start
+            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
             self.final_eval()
-            self.validator.metrics.training = {"epochs": epochs, "seconds": seconds}  # add training speed
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
@@ -504,7 +502,7 @@ class BaseTrainer:
         """Read results.csv into a dict using pandas."""
         import pandas as pd  # scope for faster 'import ultralytics'
 
-        return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
+        return pd.read_csv(self.csv).to_dict(orient="list")
 
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
@@ -654,10 +652,11 @@ class BaseTrainer:
     def save_metrics(self, metrics):
         """Saves training metrics to a CSV file."""
         keys, vals = list(metrics.keys()), list(metrics.values())
-        n = len(metrics) + 1  # number of cols
-        s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header
+        n = len(metrics) + 2  # number of cols
+        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
+        t = time.time() - self.train_time_start
         with open(self.csv, "a") as f:
-            f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
+            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
 
     def plot_metrics(self):
         """Plot and display metrics visually."""
ultralytics/nn/autobackend.py CHANGED
@@ -265,8 +265,8 @@ class AutoBackend(nn.Module):
                 if -1 in tuple(model.get_tensor_shape(name)):
                     dynamic = True
                     context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
-                    if dtype == np.float16:
-                        fp16 = True
+                if dtype == np.float16:
+                    fp16 = True
             else:
                 output_names.append(name)
                 shape = tuple(context.get_tensor_shape(name))
ultralytics/nn/modules/head.py CHANGED
@@ -19,7 +19,7 @@ __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10D
 
 
 class Detect(nn.Module):
-    """YOLOv8 Detect head for detection models."""
+    """YOLO Detect head for detection models."""
 
     dynamic = False  # force grid reconstruction
     export = False  # export mode
@@ -30,7 +30,7 @@ class Detect(nn.Module):
     strides = torch.empty(0)  # init
 
     def __init__(self, nc=80, ch=()):
-        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
+        """Initializes the YOLO detection layer with specified number of classes and channels."""
         super().__init__()
         self.nc = nc  # number of classes
         self.nl = len(ch)  # number of detection layers
@@ -162,7 +162,7 @@ class Detect(nn.Module):
 
 
 class Segment(Detect):
-    """YOLOv8 Segment head for segmentation models."""
+    """YOLO Segment head for segmentation models."""
 
     def __init__(self, nc=80, nm=32, npr=256, ch=()):
         """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
@@ -187,7 +187,7 @@ class Segment(Detect):
 
 
 class OBB(Detect):
-    """YOLOv8 OBB detection head for detection with rotation models."""
+    """YOLO OBB detection head for detection with rotation models."""
 
     def __init__(self, nc=80, ne=1, ch=()):
         """Initialize OBB with number of classes `nc` and layer channels `ch`."""
@@ -217,7 +217,7 @@ class OBB(Detect):
 
 
 class Pose(Detect):
-    """YOLOv8 Pose head for keypoints models."""
+    """YOLO Pose head for keypoints models."""
 
     def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
         """Initialize YOLO network with default parameters and Convolutional Layers."""
@@ -257,10 +257,10 @@ class Pose(Detect):
 
 
 class Classify(nn.Module):
-    """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
+    """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
 
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
-        """Initializes YOLOv8 classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
+        """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
         super().__init__()
         c_ = 1280  # efficientnet_b0 size
         self.conv = Conv(c1, c_, k, s, p, g)
@@ -277,10 +277,10 @@ class Classify(nn.Module):
 
 
 class WorldDetect(Detect):
-    """Head for integrating YOLOv8 detection models with semantic understanding from text embeddings."""
+    """Head for integrating YOLO detection models with semantic understanding from text embeddings."""
 
     def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
-        """Initialize YOLOv8 detection layer with nc classes and layer channels ch."""
+        """Initialize YOLO detection layer with nc classes and layer channels ch."""
         super().__init__(nc, ch)
         c3 = max(ch[0], min(self.nc, 100))
         self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
ultralytics/nn/tasks.py CHANGED
@@ -1061,10 +1061,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
         m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
         t = str(m)[8:-2].replace("__main__.", "")  # module type
-        m.np = sum(x.numel() for x in m_.parameters())  # number params
+        m_.np = sum(x.numel() for x in m_.parameters())  # number params
         m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
         if verbose:
-            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")  # print
+            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}")  # print
         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
         layers.append(m_)
         if i == 0:
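
The one-character fix above matters because `m` is the layer's class object while `m_` is the instantiated module: `m.np = ...` mutated a class attribute shared by every layer of that type, whereas `m_.np` stores the parameter count on the instance the rest of the code reads. A minimal sketch of the distinction:

    import torch.nn as nn

    m = nn.Linear  # class object, like parse_model's `m`
    m_ = m(4, 2)   # instantiated module, like parse_model's `m_`

    m_.np = sum(x.numel() for x in m_.parameters())  # per-instance parameter count
    print(m_.np)  # 10 (4*2 weights + 2 biases)
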
ultralytics/solutions/speed_estimation.py CHANGED
@@ -1,116 +1,76 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from collections import defaultdict
 from time import time
 
-import cv2
 import numpy as np
 
-from ultralytics.utils.checks import check_imshow
+from ultralytics.solutions.solutions import BaseSolution, LineString
 from ultralytics.utils.plotting import Annotator, colors
 
 
-class SpeedEstimator:
+class SpeedEstimator(BaseSolution):
     """A class to estimate the speed of objects in a real-time video stream based on their tracks."""
 
-    def __init__(self, names, reg_pts=None, view_img=False, line_thickness=2, spdl_dist_thresh=10):
-        """
-        Initializes the SpeedEstimator with the given parameters.
-
-        Args:
-            names (dict): Dictionary of class names.
-            reg_pts (list, optional): List of region points for speed estimation. Defaults to [(20, 400), (1260, 400)].
-            view_img (bool, optional): Whether to display the image with annotations. Defaults to False.
-            line_thickness (int, optional): Thickness of the lines for drawing boxes and tracks. Defaults to 2.
-            spdl_dist_thresh (int, optional): Distance threshold for speed calculation. Defaults to 10.
-        """
-        # Region information
-        self.reg_pts = reg_pts if reg_pts is not None else [(20, 400), (1260, 400)]
+    def __init__(self, **kwargs):
+        """Initializes the SpeedEstimator with the given parameters."""
+        super().__init__(**kwargs)
 
-        self.names = names  # Classes names
+        self.initialize_region()  # Initialize speed region
 
-        # Tracking information
-        self.trk_history = defaultdict(list)
-
-        self.view_img = view_img  # bool for displaying inference
-        self.tf = line_thickness  # line thickness for annotator
         self.spd = {}  # set for speed data
         self.trkd_ids = []  # list for already speed_estimated and tracked ID's
-        self.spdl = spdl_dist_thresh  # Speed line distance threshold
         self.trk_pt = {}  # set for tracks previous time
         self.trk_pp = {}  # set for tracks previous point
 
-        # Check if the environment supports imshow
-        self.env_check = check_imshow(warn=True)
-
-    def estimate_speed(self, im0, tracks):
+    def estimate_speed(self, im0):
         """
         Estimates the speed of objects based on tracking data.
 
         Args:
-            im0 (ndarray): Image.
-            tracks (list): List of tracks obtained from the object tracking process.
-
-        Returns:
-            (ndarray): The image with annotated boxes and tracks.
+            im0 (ndarray): The input image that will be used for processing
+        Returns
+            im0 (ndarray): The processed image for more usage
         """
-        if tracks[0].boxes.id is None:
-            return im0
+        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
+        self.extract_tracks(im0)  # Extract tracks
 
-        boxes = tracks[0].boxes.xyxy.cpu()
-        clss = tracks[0].boxes.cls.cpu().tolist()
-        t_ids = tracks[0].boxes.id.int().cpu().tolist()
-        annotator = Annotator(im0, line_width=self.tf)
-        annotator.draw_region(reg_pts=self.reg_pts, color=(255, 0, 255), thickness=self.tf * 2)
+        self.annotator.draw_region(
+            reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
+        )  # Draw region
 
-        for box, t_id, cls in zip(boxes, t_ids, clss):
-            track = self.trk_history[t_id]
-            bbox_center = (float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))
-            track.append(bbox_center)
+        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
+            self.store_tracking_history(track_id, box)  # Store track history
 
-            if len(track) > 30:
-                track.pop(0)
+            # Check if track_id is already in self.trk_pp or trk_pt initialize if not
+            if track_id not in self.trk_pt:
+                self.trk_pt[track_id] = 0
+            if track_id not in self.trk_pp:
+                self.trk_pp[track_id] = self.track_line[-1]
 
-            trk_pts = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
+            speed_label = f"{int(self.spd[track_id])} km/h" if track_id in self.spd else self.names[int(cls)]
+            self.annotator.box_label(box, label=speed_label, color=colors(track_id, True))  # Draw bounding box
 
-            if t_id not in self.trk_pt:
-                self.trk_pt[t_id] = 0
+            # Draw tracks of objects
+            self.annotator.draw_centroid_and_tracks(
+                self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
+            )
 
-            speed_label = f"{int(self.spd[t_id])} km/h" if t_id in self.spd else self.names[int(cls)]
-            bbox_color = colors(int(t_id), True)
-
-            annotator.box_label(box, speed_label, bbox_color)
-            cv2.polylines(im0, [trk_pts], isClosed=False, color=bbox_color, thickness=self.tf)
-            cv2.circle(im0, (int(track[-1][0]), int(track[-1][1])), self.tf * 2, bbox_color, -1)
-
-            # Calculation of object speed
-            if not self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]:
-                return
-            if self.reg_pts[1][1] - self.spdl < track[-1][1] < self.reg_pts[1][1] + self.spdl:
-                direction = "known"
-            elif self.reg_pts[0][1] - self.spdl < track[-1][1] < self.reg_pts[0][1] + self.spdl:
+            # Calculate object speed and direction based on region intersection
+            if LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.l_s):
                 direction = "known"
             else:
                 direction = "unknown"
 
-            if self.trk_pt.get(t_id) != 0 and direction != "unknown" and t_id not in self.trkd_ids:
-                self.trkd_ids.append(t_id)
-
-                time_difference = time() - self.trk_pt[t_id]
+            # Perform speed calculation and tracking updates if direction is valid
+            if direction == "known" and track_id not in self.trkd_ids:
+                self.trkd_ids.append(track_id)
+                time_difference = time() - self.trk_pt[track_id]
                 if time_difference > 0:
-                    self.spd[t_id] = np.abs(track[-1][1] - self.trk_pp[t_id][1]) / time_difference
-
-            self.trk_pt[t_id] = time()
-            self.trk_pp[t_id] = track[-1]
-
-        if self.view_img and self.env_check:
-            cv2.imshow("Ultralytics Speed Estimation", im0)
-            if cv2.waitKey(1) & 0xFF == ord("q"):
-                return
+                    self.spd[track_id] = np.abs(self.track_line[-1][1] - self.trk_pp[track_id][1]) / time_difference
 
-        return im0
+                self.trk_pt[track_id] = time()
+                self.trk_pp[track_id] = self.track_line[-1]
 
+        self.display_output(im0)  # display output with base class function
 
-if __name__ == "__main__":
-    names = {0: "person", 1: "car"}  # example class names
-    speed_estimator = SpeedEstimator(names)
+        return im0  # return output image for more usage
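
The direction test also changed: instead of comparing the track's y-coordinate against a ±spdl band around the region line, the new code asks whether the segment from the previous tracked point to the current one geometrically crosses the region segment `self.l_s`. A sketch of that predicate (LineString comes from shapely, which the solutions module imports):

    from shapely.geometry import LineString

    region_line = LineString([(20, 400), (1080, 400)])  # hypothetical speed region
    prev_point, curr_point = (550, 380), (555, 415)  # centroid on consecutive frames

    movement = LineString([prev_point, curr_point])
    print(movement.intersects(region_line))  # True: the object crossed the line
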
ultralytics/utils/autobatch.py CHANGED
@@ -67,7 +67,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
     LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
 
     # Profile batch sizes
-    batch_sizes = [1, 2, 4, 8, 16]
+    batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
     try:
         img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
         results = profile(img, model, n=1, device=device)
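
Here `t` is the total CUDA memory logged just above, so GPUs with 16 GiB or more now also profile batch sizes 32 and 64 before the memory curve is fit:

    t = 24.0  # hypothetical total GPU memory in GiB
    batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
    print(batch_sizes)  # [1, 2, 4, 8, 16, 32, 64]
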
ultralytics/utils/callbacks/clearml.py CHANGED
@@ -68,9 +68,9 @@ def on_pretrain_routine_start(trainer):
             PatchedMatplotlib.update_current_task(None)
         else:
             task = Task.init(
-                project_name=trainer.args.project or "YOLOv8",
+                project_name=trainer.args.project or "Ultralytics",
                 task_name=trainer.args.name,
-                tags=["YOLOv8"],
+                tags=["Ultralytics"],
                 output_uri=True,
                 reuse_last_task_id=False,
                 auto_connect_frameworks={"pytorch": False, "matplotlib": False},
ultralytics/utils/callbacks/comet.py CHANGED
@@ -15,7 +15,7 @@ try:
     # Ensures certain logging functions only run for supported tasks
     COMET_SUPPORTED_TASKS = ["detect"]
 
-    # Names of plots created by YOLOv8 that are logged to Comet
+    # Names of plots created by Ultralytics that are logged to Comet
     EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
     LABEL_PLOT_NAMES = "labels", "labels_correlogram"
 
@@ -31,8 +31,8 @@ def _get_comet_mode():
 
 
 def _get_comet_model_name():
-    """Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
-    return os.getenv("COMET_MODEL_NAME", "YOLOv8")
+    """Returns the model name for Comet from the environment variable COMET_MODEL_NAME or defaults to 'Ultralytics'."""
+    return os.getenv("COMET_MODEL_NAME", "Ultralytics")
 
 
 def _get_eval_batch_logging_interval():
@@ -110,7 +110,7 @@ def _fetch_trainer_metadata(trainer):
 
 def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
     """
-    YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
+    YOLO resizes images during training and the label values are normalized based on this resized shape.
 
     This function rescales the bounding box labels to the original image shape.
     """
ultralytics/utils/callbacks/mlflow.py CHANGED
@@ -71,7 +71,7 @@ def on_pretrain_routine_end(trainer):
     mlflow.set_tracking_uri(uri)
 
     # Set experiment and run names
-    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/YOLOv8"
+    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/Ultralytics"
     run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
     mlflow.set_experiment(experiment_name)
 
ultralytics/utils/callbacks/neptune.py CHANGED
@@ -52,7 +52,11 @@ def on_pretrain_routine_start(trainer):
     """Callback function called before the training routine starts."""
     try:
         global run
-        run = neptune.init_run(project=trainer.args.project or "YOLOv8", name=trainer.args.name, tags=["YOLOv8"])
+        run = neptune.init_run(
+            project=trainer.args.project or "Ultralytics",
+            name=trainer.args.name,
+            tags=["Ultralytics"],
+        )
         run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
     except Exception as e:
         LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
ultralytics/utils/callbacks/wb.py CHANGED
@@ -109,7 +109,7 @@ def _log_plots(plots, step):
 
 def on_pretrain_routine_start(trainer):
     """Initiate and start project if module is present."""
-    wb.run or wb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=vars(trainer.args))
+    wb.run or wb.init(project=trainer.args.project or "Ultralytics", name=trainer.args.name, config=vars(trainer.args))
 
 
 def on_fit_epoch_end(trainer):
ultralytics/utils/checks.py CHANGED
@@ -238,12 +238,14 @@ def check_version(
     c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
     for r in required.strip(",").split(","):
         op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()  # split '>=22.04' -> ('>=', '22.04')
+        if not op:
+            op = ">="  # assume >= if no op passed
         v = parse_version(version)  # '1.2.3' -> (1, 2, 3)
         if op == "==" and c != v:
             result = False
         elif op == "!=" and c == v:
             result = False
-        elif op in {">=", ""} and not (c >= v):  # if no constraint passed assume '>=required'
+        elif op == ">=" and not (c >= v):
             result = False
         elif op == "<=" and not (c <= v):
             result = False
@@ -333,18 +335,19 @@ def check_font(font="Arial.ttf"):
     return file
 
 
-def check_python(minimum: str = "3.8.0", hard: bool = True) -> bool:
+def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = True) -> bool:
     """
     Check current python version against the required minimum version.
 
     Args:
         minimum (str): Required minimum version of python.
         hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
+        verbose (bool, optional): If True, print warning message if requirement is not met.
 
     Returns:
         (bool): Whether the installed Python version meets the minimum constraints.
     """
-    return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard)
+    return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose)
 
 
 @TryExcept()
@@ -374,8 +377,6 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
     ```
     """
     prefix = colorstr("red", "bold", "requirements:")
-    check_python()  # check python version
-    check_torchvision()  # check torch-torchvision compatibility
     if isinstance(requirements, Path):  # requirements.txt file
         file = requirements.resolve()
         assert file.exists(), f"{prefix} {file} not found, check failed."
@@ -770,6 +771,8 @@ def cuda_is_available() -> bool:
     return cuda_device_count() > 0
 
 
-# Define constants
+# Run checks and define constants
+check_python("3.8", hard=False, verbose=True)  # check python version
+check_torchvision()  # check torch-torchvision compatibility
 IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
 IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")
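
Two things happen in the checks.py hunks above: a missing comparison operator is normalized to ">=" before any version comparison runs, and the Python/torchvision compatibility checks move out of check_requirements to import time (non-fatal, via hard=False). A sketch of the operator normalization, using the same regex as the code above:

    import re

    for r in "21.3".strip(",").split(","):  # bare requirement string, no operator
        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()
        if not op:
            op = ">="  # assume >= if no op passed
        print(op, version)  # >= 21.3
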
{ultralytics-8.3.8.dist-info → ultralytics-8.3.10.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ultralytics
-Version: 8.3.8
+Version: 8.3.10
 Summary: Ultralytics YOLO for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
 Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
 Maintainer-email: Ultralytics <hello@ultralytics.com>
{ultralytics-8.3.8.dist-info → ultralytics-8.3.10.dist-info}/RECORD RENAMED
@@ -3,12 +3,11 @@ tests/conftest.py,sha256=9PFAiwAy6eeORGspr5dOKxVuFDVKqYg8Nn_RxSJ27UI,2919
 tests/test_cli.py,sha256=E4lMt49TGo12Lb5CgQfpk1bwyFUZuFxF0V9j_ykV7xM,4821
 tests/test_cuda.py,sha256=KoRtRLUB7KOb9IXYX4mCi295Uh_cZEEFhCyvCDGRK9s,5381
 tests/test_engine.py,sha256=dcEcJsMQh61rDSNv7l4TIAgybLpzjVwerv9JZC_KCM8,4934
-tests/test_explorer.py,sha256=9EeMtt4-K3-MeGnAc7NemTg3uTo-Xr6AYJlTJZJJeF8,2572
 tests/test_exports.py,sha256=fpTKEVBUGLF3WiZPNKRs-IEcIY4cfxgvgKjUNfodjww,8042
 tests/test_integrations.py,sha256=f5-QCUk1SU_-qn4mBCZwS3GN3tXEBIIXo4z2EhExbHw,6126
 tests/test_python.py,sha256=I1RRdCwLdrc3jX06huVxct8HX8ccQOmQgVpuEflRl0U,23560
-tests/test_solutions.py,sha256=kJzvUiOTmFVeM_90_7vwvhrREGrN2aGwcBi-F-a13NU,3126
-ultralytics/__init__.py,sha256=xqr6W7SxucuDpDIAK7_hCzACbw8DLZQHjSo-7vDqXBo,752
+tests/test_solutions.py,sha256=dpxWGKO-aJ3Yff4KR7BQGajX9VyFdGTWEtcbmFC3WwE,3005
+ultralytics/__init__.py,sha256=r6LeW7qfMLanc7g2MYd3t48Oqt6LLDRB_PJFMVyyK3E,753
 ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
 ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
 ultralytics/cfg/__init__.py,sha256=N-XONBXwmD3vzoE4icBXznkV8LOLmf6ak6mRdGPucvw,33146
@@ -86,7 +85,7 @@ ultralytics/cfg/models/v9/yolov9e.yaml,sha256=dhaR47WxuLOrZWDCceS4bQG00sQdrMc8FQ
 ultralytics/cfg/models/v9/yolov9m.yaml,sha256=l6CmivzNu44sRVmkQXk4-tXflbV1nWnk5MSc8su2vhs,1311
 ultralytics/cfg/models/v9/yolov9s.yaml,sha256=lPWcu-6ub1kCBD6zIDFwthYZ3RvdJfODWKy3vEQWRjo,1291
 ultralytics/cfg/models/v9/yolov9t.yaml,sha256=qL__kr6GoefpQWP4jV0jdzwTp46bdFUcqtPRnfDbkY8,1275
-ultralytics/cfg/solutions/default.yaml,sha256=H4pXUoA-IafiHL6NNNTyWXrlHAjMRqaItGK1U5amNE4,825
+ultralytics/cfg/solutions/default.yaml,sha256=CByxINYMyoGzGKdurDk2GhYc8XOa8Z6H7CZx7uZSPpc,1532
 ultralytics/cfg/trackers/botsort.yaml,sha256=8B0xNbnG_E-9DCUpap72PWkUgBb1AjuApEn7gHiVngE,916
 ultralytics/cfg/trackers/bytetrack.yaml,sha256=8vpTZ2x9mhRXJymoJvs1G8kTXo_HxbSwHup2FQALT3A,721
 ultralytics/data/__init__.py,sha256=VGe-ATG7j35F4A4r8Jmzffjlhve4JAJPgRa5ahKTU18,616
@@ -96,9 +95,9 @@ ultralytics/data/base.py,sha256=ZCIhAyFfxXVp5fVnYD8mwbksNALJTayBKIR5FKGV7ZM,1516
 ultralytics/data/build.py,sha256=AfMmz0sHIYmwry_90tEJFRk_kz0S3SolScVXqYHiT08,7261
 ultralytics/data/converter.py,sha256=QCtrcbNz9kid8nvHfGIWt02nH1wwMKv6HI-8s927CR8,24251
 ultralytics/data/dataset.py,sha256=D556AW0ZEsW3V8c5zJiHM_prc_YfZqymIkDKPw3k9Io,22936
-ultralytics/data/loaders.py,sha256=JF2Z_ESK6RweavOuYWejYSGJwmqINb5hNwwCb3AAf0M,24094
+ultralytics/data/loaders.py,sha256=Fr70Q9p9t7buLW_8R2_lI_nyCMG033gWSxvwy1M-a-U,28449
 ultralytics/data/split_dota.py,sha256=yOtypHoY5HvIVBKZgFXdfj2tuCLLEBnMwNfAeG94Eik,10680
-ultralytics/data/utils.py,sha256=BK4Z87fDHfNCd6RYVYVWdTVWc8-tCqNJ-VfeN8ZG8l0,31068
+ultralytics/data/utils.py,sha256=u6OZ7InLpI1em5aEPz13ZzS9BcO37dcY9_s2btXGZYQ,31076
 ultralytics/data/explorer/__init__.py,sha256=-Y3m1ZedepOQUv_KW82zaGxvU_PSHcuwUTFqG9BhAr4,113
 ultralytics/data/explorer/explorer.py,sha256=JWmLHHhp68h2q3vx4poBou5RYoAX3R89yihR50YLDb0,18881
 ultralytics/data/explorer/utils.py,sha256=EvvukQiQUTBrsZznmMnyEX2EqTuwZo_Geyc8yfi8NIA,7085
@@ -107,9 +106,9 @@ ultralytics/data/explorer/gui/dash.py,sha256=6XOZy9NrkPEXREJPbi0EBkGgu78TAdHpdhS
 ultralytics/engine/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
 ultralytics/engine/exporter.py,sha256=DeHW_T_Zd3A21BLQYV1-FnS5EcmepMOy9nrussYNieU,57505
 ultralytics/engine/model.py,sha256=Vtkza7cQrxvowb0PqGFhp7eC3cXRIKj6OUaR5d9w1-U,51464
-ultralytics/engine/predictor.py,sha256=MgMWHUJdRcVCaVmOyvdy2Gjk_EyRHv-ar0SSGxQe8F4,17471
+ultralytics/engine/predictor.py,sha256=keTelEeo23Dcbs-XvmRWAPIs4pbCNDtsMBz88WM1eK8,17534
 ultralytics/engine/results.py,sha256=8RJlN8J-_9w-mrDZm9wC-DZJTPBS7v1c_r_R173QyRM,75043
-ultralytics/engine/trainer.py,sha256=ZCEXUPbJG_8Hzn2mLergk3WV-41ei0LT84Tspk0le30,37147
+ultralytics/engine/trainer.py,sha256=6dGOEZvMo3o97SLpKlcR5XmhWhUHh05uLYpj3jNn0jU,36981
 ultralytics/engine/tuner.py,sha256=gPqDTHH7vRB2O3YyH26m1BjVKbXxuA2XAlPRzTKFZsc,11838
 ultralytics/engine/validator.py,sha256=2C_qXI36Z9rLOpmS0YR8Qe3ka4p23YiH2w5ai7-XBwE,14811
 ultralytics/hub/__init__.py,sha256=3SKvZ5aRina3h94xMPQIB3D4maF62qFcyIqPPHRHNAc,5644
@@ -175,13 +174,13 @@ ultralytics/models/yolo/world/__init__.py,sha256=3VTH0q4NOt2EWRom15yCymvmvm0Etp2
 ultralytics/models/yolo/world/train.py,sha256=gaDrAmLJpg9qDtmL5evA5HsV2yb4RTRSfk2EDYrHdRg,3686
 ultralytics/models/yolo/world/train_world.py,sha256=IsnCEVt6DcM9lUskCKmIN-M8MM79xLpwTRqRoAHUnZ4,4857
 ultralytics/nn/__init__.py,sha256=4BPLHY89xEM_al5uK0aOmFgiML6CMGEZbezxOvTjOEs,587
-ultralytics/nn/autobackend.py,sha256=lyOXfZC4jgSebv52YpHlrfUNKp_kVBmIvydb9k0OKFQ,31607
-ultralytics/nn/tasks.py,sha256=O4i5JywqZZ2llESZ39PbojhsQcbFV5Yc1G5moiS80bM,48397
+ultralytics/nn/autobackend.py,sha256=aBW_Z8XnSsD-vW7Ek873dyKX9h55XHIYwTG22M3eOIk,31599
+ultralytics/nn/tasks.py,sha256=ssBZR4LY4rvaxYawXq5-yWSBAZ9oCz6BgxWYXB2YD68,48399
 ultralytics/nn/modules/__init__.py,sha256=xhW2BennT9U_VaMXVpRu-bdLgp1BXt9L8mkIUBE3idU,2625
 ultralytics/nn/modules/activation.py,sha256=chhn469wnRHEs5BMGNBYXwPYZc_7-urspTT8fnBd-xA,895
 ultralytics/nn/modules/block.py,sha256=thcIPcnGRRxDDDswywJsfzbewr9XfTrzl_UvSl-bJ3c,41832
 ultralytics/nn/modules/conv.py,sha256=vOeHZ6Z4sc6-9PrDmRGT1hFkxSBbbWkQm2jRbGGjpqQ,12705
-ultralytics/nn/modules/head.py,sha256=x0Y8lTKFqYC4oAN1JTJ-yQ43sIXEIp35dmC14vdtQnk,26627
+ultralytics/nn/modules/head.py,sha256=WnCpQDBlMDStpEs-m-R0vcKq28OX2FEgTcmHEpRL_pA,26609
 ultralytics/nn/modules/transformer.py,sha256=tGiK8NmPfswwW1rbF21r5ILUkkZQ6Nk4s8j16vFBmps,18069
 ultralytics/nn/modules/utils.py,sha256=a88cKl2wz1nMVSEBiajtvaCbDBQIkESWOKTZ_WAJy90,3195
 ultralytics/solutions/__init__.py,sha256=6RDeXWO1QSaMgCq8YrWXaj2xvPw2sJwJL_a0dgjCvz0,648
@@ -193,7 +192,7 @@ ultralytics/solutions/object_counter.py,sha256=1Nsivk-cyGBM1G6eWe11_vdDWTdbJwaUF
 ultralytics/solutions/parking_management.py,sha256=VgYyhoSEo7fnPegIhNUqnFL0jlMEevALx0QQbzJ3vGI,9049
 ultralytics/solutions/queue_management.py,sha256=5d1RURQiqffAoET8S66gHimK0l3gKNAfuPO5U6_08jc,2716
 ultralytics/solutions/solutions.py,sha256=qWKGlwlH9858GfAdZkcu_QXbrzjTFStDvg16Eky0oyo,3541
-ultralytics/solutions/speed_estimation.py,sha256=c9OPGpDU9x6Dj4SobNc-sO90EZTPTGeKkW5u6C6Zj7g,4623
+ultralytics/solutions/speed_estimation.py,sha256=2jLTEdnSF3Mm3Z7QJVPCUq84-7L6ELIJIR_sPFBW_cU,3164
 ultralytics/solutions/streamlit_inference.py,sha256=qA2EtwUC7ADOQ8P-zs3VPyrIoRArhcZz9CxkFbH63bw,5699
 ultralytics/trackers/__init__.py,sha256=j72IgH2dZHQArMPK4YwcV5ieIw94fYvlGdQjB9cOQKw,227
 ultralytics/trackers/basetrack.py,sha256=dXnXW3cxxd7lPm20JJCNO2voCIrQ4vhbNI1g4YEgn-Y,4423
@@ -205,9 +204,9 @@ ultralytics/trackers/utils/gmc.py,sha256=VcURuY041qGCeWUGMxHZBr10T16LtcMqyv7AmTf
 ultralytics/trackers/utils/kalman_filter.py,sha256=cH9zD3fwkuezP97H9mw8cSBN7a8hHKx_Sx1j7t3oYGs,21349
 ultralytics/trackers/utils/matching.py,sha256=3Ie1WNNRZ4_q3365F03XD7Nr9juZB_08mw4yUKC3w74,7162
 ultralytics/utils/__init__.py,sha256=du1Y1LMU0jQn_zWWnAIx9U8wn6Vh7ce-k7qMwi6y0po,48698
-ultralytics/utils/autobatch.py,sha256=1ZDy3vvUDKkxROHnxT3_vI4MJ52l9ap7SiuQvG4B-8k,4290
+ultralytics/utils/autobatch.py,sha256=BO9MCRtrLDtrDQaxqV0BdjaYsgXf-q07Y3_VdGp4URY,4330
 ultralytics/utils/benchmarks.py,sha256=8FYp5WPzcxcDaeg8ol2sgzRBHVGYatEO7f3MrmPF6nI,25097
-ultralytics/utils/checks.py,sha256=7oWc91HqQdH9EHuHysxk_ZltiRrGt6eq-pUf0TkA3gU,29579
+ultralytics/utils/checks.py,sha256=SsB3s1z9TtMjGelDkGZIi6B40VXmCtGw2hcOCyPikx4,29765
 ultralytics/utils/dist.py,sha256=NDFga-uKxkBX2zLxFHSene_cCiGQJoyOeCXcN9JIOIk,2358
 ultralytics/utils/downloads.py,sha256=o8RY9f0KrzWfueLs8DuJ5w8OWQ-ll4ZS9lX6MEFDi70,21977
 ultralytics/utils/errors.py,sha256=GqP_Jgj_n0paxn8OMhn3DTCgoNkB2WjUcUaqs-M6SQk,816
@@ -224,18 +223,18 @@ ultralytics/utils/triton.py,sha256=gg1finxno_tY2Ge9PMhmu7PI9wvoFZoiicdT4Bhqv3w,3
 ultralytics/utils/tuner.py,sha256=AtEtK6pOt9xVTyx864OpNRVxNdAxz5aKHzveiXwkD1A,6250
 ultralytics/utils/callbacks/__init__.py,sha256=YrWqC3BVVaTLob4iCPR6I36mUxIUOpPJW7B_LjT78Qw,214
 ultralytics/utils/callbacks/base.py,sha256=PHjQ6RITwC2dylCQTB0bdPgAsHjxVeuDb5N1NPTbHGc,5775
-ultralytics/utils/callbacks/clearml.py,sha256=M9Fi1OfdWqcm8uVkauuX3zJIYhNh6Tp7Jo4CfA0u0nw,5923
-ultralytics/utils/callbacks/comet.py,sha256=ATWjZJigLy8lJVYjlwyCha-lJ-QlMfXw-zE9PA7UxqY,13743
+ultralytics/utils/callbacks/clearml.py,sha256=qbLbqzMVWAnjqg5YUM-Ue6CmGueFCvqKpHFKlw-MyVc,5933
+ultralytics/utils/callbacks/comet.py,sha256=DS5w9fgo0eWfjRuIywTlKEv2LY4eOKklEq-DyoIQn7U,13754
 ultralytics/utils/callbacks/dvc.py,sha256=WIClMsuvhiiyrwRv5BsZLxjsxYNJ3Y8Vq7zN0Bthtro,5045
 ultralytics/utils/callbacks/hub.py,sha256=EPewsLigFQc9ucTX2exKSlKBiaBNhYYyGC_nR2ragJo,3997
-ultralytics/utils/callbacks/mlflow.py,sha256=_bUzHyPb0npne0WFlGzlGCy-X5sxGQhC_xA3dZbF08I,5391
-ultralytics/utils/callbacks/neptune.py,sha256=5Z3ua5YBTUS56FH8VQKQG1aaIo9fH8GEyzC5q7p4ipQ,3756
+ultralytics/utils/callbacks/mlflow.py,sha256=mkl_rK0Gy02cXnQUYmzmLE5W97fMgfEb7IlgOAdnjHg,5396
+ultralytics/utils/callbacks/neptune.py,sha256=IbGQfEltamUKXJt93uSLQFn8c2rYh3DMTgVE1xsnmUI,3813
 ultralytics/utils/callbacks/raytune.py,sha256=ODVYzy-CoM4Uge0zjkh3Hnh9nF2M0vhDrSenXnvcizw,705
 ultralytics/utils/callbacks/tensorboard.py,sha256=bv4fkkesdgmZv_E2MU6wuaMBwEV5iI2G53RHPyD9quw,4170
-ultralytics/utils/callbacks/wb.py,sha256=9-fjQIdLjr3b73DTE3rHO171KvbH1VweJ-bmbv-rqTw,6747
-ultralytics-8.3.8.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
-ultralytics-8.3.8.dist-info/METADATA,sha256=omBKv11I1DidGjRCYOu07_VBN1yVEDB8Ccx2jNIw5Rk,34699
-ultralytics-8.3.8.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-ultralytics-8.3.8.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
-ultralytics-8.3.8.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
-ultralytics-8.3.8.dist-info/RECORD,,
+ultralytics/utils/callbacks/wb.py,sha256=upfbF8-LLXueUvulLaMDmKDhKCl_PWbNa_87PQ0L0Rc,6752
+ultralytics-8.3.10.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ultralytics-8.3.10.dist-info/METADATA,sha256=erZGLlFck6gorIKxGLLR-ymgpHCb5WiGGa89PyM_sQs,34700
+ultralytics-8.3.10.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+ultralytics-8.3.10.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ultralytics-8.3.10.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+ultralytics-8.3.10.dist-info/RECORD,,
tests/test_explorer.py DELETED
@@ -1,66 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-import PIL
-import pytest
-
-from ultralytics import Explorer
-from ultralytics.utils import ASSETS
-from ultralytics.utils.torch_utils import TORCH_1_13
-
-
-@pytest.mark.slow
-@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
-def test_similarity():
-    """Test the correctness and response length of similarity calculations and SQL queries in the Explorer."""
-    exp = Explorer(data="coco8.yaml")
-    exp.create_embeddings_table()
-    similar = exp.get_similar(idx=1)
-    assert len(similar) == 4
-    similar = exp.get_similar(img=ASSETS / "bus.jpg")
-    assert len(similar) == 4
-    similar = exp.get_similar(idx=[1, 2], limit=2)
-    assert len(similar) == 2
-    sim_idx = exp.similarity_index()
-    assert len(sim_idx) == 4
-    sql = exp.sql_query("WHERE labels LIKE '%zebra%'")
-    assert len(sql) == 1
-
-
-@pytest.mark.slow
-@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
-def test_det():
-    """Test detection functionalities and verify embedding table includes bounding boxes."""
-    exp = Explorer(data="coco8.yaml", model="yolo11n.pt")
-    exp.create_embeddings_table(force=True)
-    assert len(exp.table.head()["bboxes"]) > 0
-    similar = exp.get_similar(idx=[1, 2], limit=10)
-    assert len(similar) > 0
-    # This is a loose test, just checks errors not correctness
-    similar = exp.plot_similar(idx=[1, 2], limit=10)
-    assert isinstance(similar, PIL.Image.Image)
-
-
-@pytest.mark.slow
-@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
-def test_seg():
-    """Test segmentation functionalities and ensure the embedding table includes segmentation masks."""
-    exp = Explorer(data="coco8-seg.yaml", model="yolo11n-seg.pt")
-    exp.create_embeddings_table(force=True)
-    assert len(exp.table.head()["masks"]) > 0
-    similar = exp.get_similar(idx=[1, 2], limit=10)
-    assert len(similar) > 0
-    similar = exp.plot_similar(idx=[1, 2], limit=10)
-    assert isinstance(similar, PIL.Image.Image)
-
-
-@pytest.mark.slow
-@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
-def test_pose():
-    """Test pose estimation functionality and verify the embedding table includes keypoints."""
-    exp = Explorer(data="coco8-pose.yaml", model="yolo11n-pose.pt")
-    exp.create_embeddings_table(force=True)
-    assert len(exp.table.head()["keypoints"]) > 0
-    similar = exp.get_similar(idx=[1, 2], limit=10)
-    assert len(similar) > 0
-    similar = exp.plot_similar(idx=[1, 2], limit=10)
-    assert isinstance(similar, PIL.Image.Image)