ultralytics 8.3.137__py3-none-any.whl → 8.3.139__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. tests/test_python.py +6 -1
  2. tests/test_solutions.py +183 -8
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/cfg/__init__.py +1 -1
  5. ultralytics/data/base.py +1 -1
  6. ultralytics/data/build.py +4 -3
  7. ultralytics/data/loaders.py +2 -2
  8. ultralytics/engine/exporter.py +5 -5
  9. ultralytics/engine/model.py +2 -2
  10. ultralytics/engine/predictor.py +3 -10
  11. ultralytics/engine/results.py +2 -209
  12. ultralytics/engine/trainer.py +1 -1
  13. ultralytics/engine/validator.py +1 -1
  14. ultralytics/hub/auth.py +2 -2
  15. ultralytics/hub/utils.py +8 -3
  16. ultralytics/models/yolo/classify/predict.py +11 -0
  17. ultralytics/models/yolo/obb/val.py +1 -1
  18. ultralytics/models/yolo/world/train.py +1 -1
  19. ultralytics/models/yolo/yoloe/val.py +3 -3
  20. ultralytics/solutions/similarity_search.py +3 -6
  21. ultralytics/solutions/streamlit_inference.py +1 -1
  22. ultralytics/utils/__init__.py +159 -1
  23. ultralytics/utils/callbacks/hub.py +5 -4
  24. ultralytics/utils/checks.py +25 -18
  25. ultralytics/utils/downloads.py +7 -5
  26. ultralytics/utils/export.py +1 -1
  27. ultralytics/utils/metrics.py +90 -5
  28. ultralytics/utils/plotting.py +1 -1
  29. ultralytics/utils/torch_utils.py +3 -0
  30. ultralytics/utils/triton.py +1 -1
  31. {ultralytics-8.3.137.dist-info → ultralytics-8.3.139.dist-info}/METADATA +1 -1
  32. {ultralytics-8.3.137.dist-info → ultralytics-8.3.139.dist-info}/RECORD +36 -36
  33. {ultralytics-8.3.137.dist-info → ultralytics-8.3.139.dist-info}/WHEEL +0 -0
  34. {ultralytics-8.3.137.dist-info → ultralytics-8.3.139.dist-info}/entry_points.txt +0 -0
  35. {ultralytics-8.3.137.dist-info → ultralytics-8.3.139.dist-info}/licenses/LICENSE +0 -0
  36. {ultralytics-8.3.137.dist-info → ultralytics-8.3.139.dist-info}/top_level.txt +0 -0
tests/test_python.py CHANGED
@@ -198,7 +198,12 @@ def test_track_stream():

 def test_val():
     """Test the validation mode of the YOLO model."""
-    YOLO(MODEL).val(data="coco8.yaml", imgsz=32)
+    metrics = YOLO(MODEL).val(data="coco8.yaml", imgsz=32)
+    metrics.to_df()
+    metrics.to_csv()
+    metrics.to_xml()
+    metrics.to_html()
+    metrics.to_json()


 def test_train_scratch():
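The updated test exercises the new tabular exporters on the object returned by val(). A minimal usage sketch, assuming only what the test itself shows — that the returned metrics object exposes these methods (the return types, e.g. a pandas DataFrame from to_df() and strings from the others, are assumptions):

from ultralytics import YOLO

metrics = YOLO("yolo11n.pt").val(data="coco8.yaml", imgsz=32)
print(metrics.to_df())    # tabular summary, assumed to be a pandas DataFrame
print(metrics.to_json())  # same results serialized as JSON, assumed to be a str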
tests/test_solutions.py CHANGED
@@ -3,7 +3,11 @@
 # Tests Ultralytics Solutions: https://docs.ultralytics.com/solutions/,
 # including every solution excluding DistanceCalculation and Security Alarm System.

+import os
+from unittest.mock import patch
+
 import cv2
+import numpy as np
 import pytest

 from tests import MODEL, TMP
@@ -19,7 +23,10 @@ POSE_VIDEO = "solution_ci_pose_demo.mp4"  # only for workouts monitoring solutio
 PARKING_VIDEO = "solution_ci_parking_demo.mp4"  # only for parking management solution
 PARKING_AREAS_JSON = "solution_ci_parking_areas.json"  # only for parking management solution
 PARKING_MODEL = "solutions_ci_parking_model.pt"  # only for parking management solution
+VERTICAL_VIDEO = "solution_vertical_demo.mp4"  # only for vertical line counting
 REGION = [(10, 200), (540, 200), (540, 180), (10, 180)]  # for object counting, speed estimation and queue management
+HORIZONTAL_LINE = [(10, 200), (540, 200)]  # for object counting
+VERTICAL_LINE = [(320, 0), (320, 400)]  # for object counting

 # Test configs for each solution : (name, class, needs_frame_count, video, kwargs)
 SOLUTIONS = [
@@ -30,6 +37,27 @@ SOLUTIONS = [
         DEMO_VIDEO,
         {"region": REGION, "model": MODEL, "show": SHOW},
     ),
+    (
+        "ObjectCounter",
+        solutions.ObjectCounter,
+        False,
+        DEMO_VIDEO,
+        {"region": HORIZONTAL_LINE, "model": MODEL, "show": SHOW},
+    ),
+    (
+        "ObjectCounterVertical",
+        solutions.ObjectCounter,
+        False,
+        DEMO_VIDEO,
+        {"region": VERTICAL_LINE, "model": MODEL, "show": SHOW},
+    ),
+    (
+        "ObjectCounterwithOBB",
+        solutions.ObjectCounter,
+        False,
+        DEMO_VIDEO,
+        {"region": REGION, "model": "yolo11n-obb.pt", "show": SHOW},
+    ),
     (
         "Heatmap",
         solutions.Heatmap,
@@ -63,28 +91,28 @@ SOLUTIONS = [
         solutions.Analytics,
         True,
         DEMO_VIDEO,
-        {"analytics_type": "line", "model": MODEL, "show": SHOW},
+        {"analytics_type": "line", "model": MODEL, "show": SHOW, "figsize": (6.4, 3.2)},
     ),
     (
         "PieAnalytics",
         solutions.Analytics,
         True,
         DEMO_VIDEO,
-        {"analytics_type": "pie", "model": MODEL, "show": SHOW},
+        {"analytics_type": "pie", "model": MODEL, "show": SHOW, "figsize": (6.4, 3.2)},
     ),
     (
         "BarAnalytics",
         solutions.Analytics,
         True,
         DEMO_VIDEO,
-        {"analytics_type": "bar", "model": MODEL, "show": SHOW},
+        {"analytics_type": "bar", "model": MODEL, "show": SHOW, "figsize": (6.4, 3.2)},
     ),
     (
         "AreaAnalytics",
         solutions.Analytics,
         True,
         DEMO_VIDEO,
-        {"analytics_type": "area", "model": MODEL, "show": SHOW},
+        {"analytics_type": "area", "model": MODEL, "show": SHOW, "figsize": (6.4, 3.2)},
     ),
     ("TrackZone", solutions.TrackZone, False, DEMO_VIDEO, {"region": REGION, "model": MODEL, "show": SHOW}),
     (
@@ -99,7 +127,7 @@ SOLUTIONS = [
         solutions.ObjectBlurrer,
         False,
         DEMO_VIDEO,
-        {"blur_ratio": 0.5, "model": MODEL, "show": SHOW},
+        {"blur_ratio": 0.02, "model": MODEL, "show": SHOW},
     ),
     (
         "InstanceSegmentation",
@@ -160,7 +188,10 @@ def process_video(solution, video_path, needs_frame_count=False):
 def test_solution(name, solution_class, needs_frame_count, video, kwargs):
     """Test individual Ultralytics solution."""
     if video:
-        safe_download(url=f"{ASSETS_URL}/{video}", dir=TMP)
+        if name != "ObjectCounterVertical":
+            safe_download(url=f"{ASSETS_URL}/{video}", dir=TMP)
+        else:
+            safe_download(url=f"{ASSETS_URL}/{VERTICAL_VIDEO}", dir=TMP)
     if name == "ParkingManager":
         safe_download(url=f"{ASSETS_URL}/{PARKING_AREAS_JSON}", dir=TMP)
         safe_download(url=f"{ASSETS_URL}/{PARKING_MODEL}", dir=TMP)
@@ -169,6 +200,7 @@ def test_solution(name, solution_class, needs_frame_count, video, kwargs):
         solution_class(**kwargs).inference()  # requires interactive GUI environment
         return

+    video = VERTICAL_VIDEO if name == "ObjectCounterVertical" else video
     process_video(
         solution=solution_class(**kwargs),
         video_path=str(TMP / video),
@@ -181,7 +213,150 @@ def test_solution(name, solution_class, needs_frame_count, video, kwargs):
 @pytest.mark.skipif(IS_RASPBERRYPI, reason="Disabled due to slow performance on Raspberry Pi.")
 def test_similarity_search():
     """Test similarity search solution."""
-    from ultralytics import solutions
-
     searcher = solutions.VisualAISearch()
     _ = searcher("a dog sitting on a bench")  # Returns the results in format "- img name | similarity score"
+
+
+def test_left_click_selection():
+    """Test distance calculation left click."""
+    dc = solutions.DistanceCalculation()
+    dc.boxes, dc.track_ids = [[10, 10, 50, 50]], [1]
+    dc.mouse_event_for_distance(cv2.EVENT_LBUTTONDOWN, 30, 30, None, None)
+    assert 1 in dc.selected_boxes
+
+
+def test_right_click_reset():
+    """Test distance calculation right click."""
+    dc = solutions.DistanceCalculation()
+    dc.selected_boxes, dc.left_mouse_count = {1: [10, 10, 50, 50]}, 1
+    dc.mouse_event_for_distance(cv2.EVENT_RBUTTONDOWN, 0, 0, None, None)
+    assert dc.selected_boxes == {}
+    assert dc.left_mouse_count == 0
+
+
+def test_parking_json_none():
+    """Test that ParkingManagement skips or errors cleanly when no JSON is provided."""
+    im0 = np.zeros((640, 480, 3), dtype=np.uint8)
+    try:
+        parkingmanager = solutions.ParkingManagement(json_path=None)
+        parkingmanager(im0)
+    except ValueError:
+        pytest.skip("Skipping test due to missing JSON.")
+
+
+def test_analytics_graph_not_supported():
+    """Test that unsupported analytics type raises ModuleNotFoundError."""
+    try:
+        analytics = solutions.Analytics(analytics_type="test")  # 'test' is unsupported
+        analytics.process(im0=None, frame_number=0)
+        assert False, "Expected ModuleNotFoundError for unsupported chart type"
+    except ModuleNotFoundError as e:
+        assert "test chart is not supported" in str(e)
+
+
+def test_area_chart_padding():
+    """Test area chart graph update with dynamic class padding logic."""
+    analytics = solutions.Analytics(analytics_type="area")
+    analytics.update_graph(frame_number=1, count_dict={"car": 2}, plot="area")
+    plot_im = analytics.update_graph(frame_number=2, count_dict={"car": 3, "person": 1}, plot="area")
+    assert plot_im is not None
+
+
+def test_config_update_method_with_invalid_argument():
+    """Test that update() raises ValueError for invalid config keys."""
+    obj = solutions.config.SolutionConfig()
+    try:
+        obj.update(invalid_key=123)
+        assert False, "Expected ValueError for invalid update argument"
+    except ValueError as e:
+        assert "❌ invalid_key is not a valid solution argument" in str(e)
+
+
+def test_plot_with_no_masks():
+    """Test that instance segmentation handles cases with no masks."""
+    im0 = np.zeros((640, 480, 3), dtype=np.uint8)
+    isegment = solutions.InstanceSegmentation(model="yolo11n-seg.pt")
+    results = isegment(im0)
+    assert results.plot_im is not None
+
+
+def test_streamlit_handle_video_upload_creates_file():
+    """Test Streamlit video upload logic saves file correctly."""
+    import io
+
+    fake_file = io.BytesIO(b"fake video content")
+    fake_file.read = fake_file.getvalue
+    if fake_file is not None:
+        g = io.BytesIO(fake_file.read())
+        with open("ultralytics.mp4", "wb") as out:
+            out.write(g.read())
+        output_path = "ultralytics.mp4"
+    else:
+        output_path = None
+    assert output_path == "ultralytics.mp4"
+    assert os.path.exists("ultralytics.mp4")
+    with open("ultralytics.mp4", "rb") as f:
+        assert f.read() == b"fake video content"
+    os.remove("ultralytics.mp4")
+
+
+@pytest.mark.skipif(IS_RASPBERRYPI, reason="Disabled due to slow performance on Raspberry Pi.")
+def test_similarity_search_app_init():
+    """Test SearchApp initializes with required attributes."""
+    app = solutions.SearchApp(device="cpu")
+    assert hasattr(app, "searcher")
+    assert hasattr(app, "run")
+
+
+@pytest.mark.skipif(IS_RASPBERRYPI, reason="Disabled due to slow performance on Raspberry Pi.")
+def test_similarity_search_complete(tmp_path):
+    """Test VisualAISearch end-to-end with sample image and query."""
+    from PIL import Image
+
+    image_dir = tmp_path / "images"
+    os.makedirs(image_dir, exist_ok=True)
+    for i in range(2):
+        img = Image.fromarray(np.uint8(np.random.rand(224, 224, 3) * 255))
+        img.save(image_dir / f"test_image_{i}.jpg")
+    searcher = solutions.VisualAISearch(data=str(image_dir))
+    results = searcher("a red and white object")
+    assert results
+
+
+def test_distance_calculation_process_method():
+    """Test DistanceCalculation.process() computes distance between selected boxes."""
+    from ultralytics.solutions.solutions import SolutionResults
+
+    dc = solutions.DistanceCalculation()
+    dc.boxes, dc.track_ids, dc.clss, dc.confs = (
+        [[100, 100, 200, 200], [300, 300, 400, 400]],
+        [1, 2],
+        [0, 0],
+        [0.9, 0.95],
+    )
+    dc.selected_boxes = {1: dc.boxes[0], 2: dc.boxes[1]}
+    frame = np.zeros((480, 640, 3), dtype=np.uint8)
+    with patch.object(dc, "extract_tracks"), patch.object(dc, "display_output"), patch("cv2.setMouseCallback"):
+        result = dc.process(frame)
+        assert isinstance(result, SolutionResults)
+        assert result.total_tracks == 2
+        assert result.pixels_distance > 0
+
+
+def test_object_crop_with_show_True():
+    """Test ObjectCropper init with show=True to cover display warning."""
+    solutions.ObjectCropper(show=True)
+
+
+def test_display_output_method():
+    """Test that display_output triggers imshow, waitKey, and destroyAllWindows when enabled."""
+    counter = solutions.ObjectCounter(show=True)
+    counter.env_check = True
+    frame = np.zeros((100, 100, 3), dtype=np.uint8)
+    with patch("cv2.imshow") as mock_imshow, patch("cv2.waitKey", return_value=ord("q")) as mock_wait, patch(
+        "cv2.destroyAllWindows"
+    ) as mock_destroy:
+        counter.display_output(frame)
+        mock_imshow.assert_called_once()
+        mock_wait.assert_called_once()
+        mock_destroy.assert_called_once()
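The new ObjectCounter configs above show that region accepts a two-point line (horizontal or vertical) as well as a polygon. A minimal sketch of counting across a vertical line, mirroring the frame loop the tests run; the video path is hypothetical:

import cv2

from ultralytics import solutions

counter = solutions.ObjectCounter(
    region=[(320, 0), (320, 400)],  # vertical counting line, same shape as VERTICAL_LINE above
    model="yolo11n.pt",
    show=False,
)
cap = cv2.VideoCapture("path/to/video.mp4")  # hypothetical input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = counter(frame)  # SolutionResults carrying counts and the annotated frame
cap.release()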
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-__version__ = "8.3.137"
+__version__ = "8.3.139"

 import os

ultralytics/cfg/__init__.py CHANGED
@@ -311,7 +311,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
         if k in cfg and isinstance(cfg[k], (int, float)):
             cfg[k] = str(cfg[k])
     if cfg.get("name") == "model":  # assign model to 'name' arg
-        cfg["name"] = str(cfg.get("model", "")).split(".")[0]
+        cfg["name"] = str(cfg.get("model", "")).partition(".")[0]
         LOGGER.warning(f"'name=model' automatically updated to 'name={cfg['name']}'.")

     # Type and Value checks
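This hunk and several below (data/base.py, data/build.py, data/loaders.py, engine/model.py, engine/exporter.py) swap str.split for str.partition/str.rpartition where only the piece before the first separator or after the last one is needed, avoiding an intermediate list of all segments. A short plain-Python illustration (not code from the package) of why the two forms agree at these call sites:

name = "yolo11n.pt"
assert name.partition(".")[0] == name.split(".")[0] == "yolo11n"  # stem before the first dot
assert name.rpartition(".")[-1] == name.split(".")[-1] == "pt"    # extension after the last dot
# With no separator present, rpartition returns ('', '', s), so [-1] still matches split
assert "README".rpartition(".")[-1] == "README".split(".")[-1] == "README"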
ultralytics/data/base.py CHANGED
@@ -170,7 +170,7 @@ class BaseDataset(Dataset):
                     # F += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                 else:
                     raise FileNotFoundError(f"{self.prefix}{p} does not exist")
-            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
+            im_files = sorted(x.replace("/", os.sep) for x in f if x.rpartition(".")[-1].lower() in IMG_FORMATS)
             # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
             assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
         except Exception as e:
ultralytics/data/build.py CHANGED
@@ -200,10 +200,11 @@ def check_source(source):
     webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
     if isinstance(source, (str, int, Path)):  # int for local usb camera
         source = str(source)
-        is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)
-        is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
+        source_lower = source.lower()
+        is_file = source_lower.rpartition(".")[-1] in (IMG_FORMATS | VID_FORMATS)
+        is_url = source_lower.startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
         webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
-        screenshot = source.lower() == "screen"
+        screenshot = source_lower == "screen"
         if is_url and is_file:
             source = check_file(source)  # download
     elif isinstance(source, LOADERS):
ultralytics/data/loaders.py CHANGED
@@ -353,7 +353,7 @@ class LoadImagesAndVideos:
         # Define files as images or videos
         images, videos = [], []
         for f in files:
-            suffix = f.split(".")[-1].lower()  # Get file extension without the dot and lowercase
+            suffix = f.rpartition(".")[-1].lower()  # Get file extension without the dot and lowercase
             if suffix in IMG_FORMATS:
                 images.append(f)
             elif suffix in VID_FORMATS:
@@ -427,7 +427,7 @@
         else:
             # Handle image files (including HEIC)
             self.mode = "image"
-            if path.split(".")[-1].lower() == "heic":
+            if path.rpartition(".")[-1].lower() == "heic":
                 # Load HEIC image using Pillow with pillow-heif
                 check_requirements("pillow-heif")

ultralytics/engine/exporter.py CHANGED
@@ -244,7 +244,6 @@

     def __call__(self, model=None) -> str:
         """Return list of exported files/dirs after running callbacks."""
-        self.run_callbacks("on_export_start")
         t = time.time()
         fmt = self.args.format.lower()  # to lowercase
         if fmt in {"tensorrt", "trt"}:  # 'engine' aliases
@@ -277,7 +276,7 @@
             LOGGER.warning("TensorRT requires GPU export, automatically assigning device=0")
             self.args.device = "0"
         if fmt == "engine" and "dla" in str(self.args.device):  # convert int/list to str first
-            dla = self.args.device.split(":")[-1]
+            dla = self.args.device.rsplit(":", 1)[-1]
             self.args.device = "0"  # update device to "0"
             assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
         if imx and self.args.device is None and torch.cuda.is_available():
@@ -450,7 +449,7 @@
             f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
             f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)"
         )
-
+        self.run_callbacks("on_export_start")
         # Exports
         f = [""] * len(fmts)  # exported filenames
         if jit or ncnn:  # TorchScript
@@ -789,7 +788,7 @@
         subprocess.run(cmd, check=True)

         # Remove debug files
-        pnnx_files = [x.split("=")[-1] for x in pnnx_args]
+        pnnx_files = [x.rsplit("=", 1)[-1] for x in pnnx_args]
         for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
             Path(f_debug).unlink(missing_ok=True)

@@ -982,6 +981,7 @@
             custom_input_op_name_np_data_path=np_data,
             enable_batchmatmul_unfold=True,  # fix lower no. of detected objects on GPU delegate
             output_signaturedefs=True,  # fix error with Attention block group convolution
+            disable_group_convolution=self.args.format == "tfjs",  # fix TF.js error with group convolution
             optimization_for_gpu_delegate=True,
         )
         YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
@@ -1048,7 +1048,7 @@
             "sudo apt-get install edgetpu-compiler",
         ):
             subprocess.run(c if is_sudo_available() else c.replace("sudo ", ""), shell=True, check=True)
-        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
+        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]

         LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
         f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model
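The disable_group_convolution flag above is applied only when the target format is tfjs, leaving other TensorFlow-family exports untouched. The affected user-facing path, sketched with the standard export API:

from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.export(format="tfjs")  # TF.js export runs through the onnx2tf conversion patched above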
ultralytics/engine/model.py CHANGED
@@ -288,7 +288,7 @@ class Model(torch.nn.Module):
         weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"])  # download and return local file
         weights = checks.check_model_file_from_stem(weights)  # add suffix, i.e. yolo11n -> yolo11n.pt

-        if Path(weights).suffix == ".pt":
+        if str(weights).rpartition(".")[-1] == "pt":
             self.model, self.ckpt = attempt_load_one_weight(weights)
             self.task = self.model.args["task"]
             self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
@@ -319,7 +319,7 @@ class Model(torch.nn.Module):
             >>> model = Model("yolo11n.onnx")
             >>> model._check_is_pytorch_model()  # Raises TypeError
         """
-        pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
+        pt_str = isinstance(self.model, (str, Path)) and str(self.model).rpartition(".")[-1] == "pt"
         pt_module = isinstance(self.model, torch.nn.Module)
         if not (pt_module or pt_str):
             raise TypeError(
ultralytics/engine/predictor.py CHANGED
@@ -43,7 +43,7 @@ import torch

 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data import load_inference_source
-from ultralytics.data.augment import LetterBox, classify_transforms
+from ultralytics.data.augment import LetterBox
 from ultralytics.nn.autobackend import AutoBackend
 from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
 from ultralytics.utils.checks import check_imgsz, check_imshow
@@ -247,15 +247,6 @@ class BasePredictor:
                 Source for inference.
         """
         self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
-        self.transforms = (
-            getattr(
-                self.model.model,
-                "transforms",
-                classify_transforms(self.imgsz[0]),
-            )
-            if self.args.task == "classify"
-            else None
-        )
         self.dataset = load_inference_source(
             source=source,
             batch=self.args.batch,
@@ -395,6 +386,8 @@

         self.device = self.model.device  # update device
         self.args.half = self.model.fp16  # update half
+        if hasattr(self.model, "imgsz"):
+            self.args.imgsz = self.model.imgsz  # reuse imgsz from export metadata
         self.model.eval()

     def write_results(self, i, p, im, s):
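The final hunk lets a predictor pick up imgsz from an exported model's metadata, so the size chosen at export time no longer has to be repeated at predict time. A minimal sketch of that flow using the standard API (the asset URL is the usual Ultralytics docs example):

from ultralytics import YOLO

YOLO("yolo11n.pt").export(format="onnx", imgsz=320)  # imgsz=320 is written into the export metadata
model = YOLO("yolo11n.onnx")
model.predict("https://ultralytics.com/images/bus.jpg")  # no imgsz argument; 320 is reused from metadata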