ultralytics-8.0.237-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics has been flagged as potentially problematic; the original page linked to further details about the issue.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/data/base.py
CHANGED
|
@@ -47,20 +47,22 @@ class BaseDataset(Dataset):
|
|
|
47
47
|
transforms (callable): Image transformation function.
|
|
48
48
|
"""
|
|
49
49
|
|
|
50
|
-
def __init__(
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
50
|
+
def __init__(
|
|
51
|
+
self,
|
|
52
|
+
img_path,
|
|
53
|
+
imgsz=640,
|
|
54
|
+
cache=False,
|
|
55
|
+
augment=True,
|
|
56
|
+
hyp=DEFAULT_CFG,
|
|
57
|
+
prefix="",
|
|
58
|
+
rect=False,
|
|
59
|
+
batch_size=16,
|
|
60
|
+
stride=32,
|
|
61
|
+
pad=0.5,
|
|
62
|
+
single_cls=False,
|
|
63
|
+
classes=None,
|
|
64
|
+
fraction=1.0,
|
|
65
|
+
):
|
|
64
66
|
"""Initialize BaseDataset with given configuration and options."""
|
|
65
67
|
super().__init__()
|
|
66
68
|
self.img_path = img_path
|
|
@@ -86,10 +88,10 @@ class BaseDataset(Dataset):
|
|
|
86
88
|
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
|
|
87
89
|
|
|
88
90
|
# Cache images
|
|
89
|
-
if cache ==
|
|
91
|
+
if cache == "ram" and not self.check_cache_ram():
|
|
90
92
|
cache = False
|
|
91
93
|
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
|
|
92
|
-
self.npy_files = [Path(f).with_suffix(
|
|
94
|
+
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
|
|
93
95
|
if cache:
|
|
94
96
|
self.cache_images(cache)
|
|
95
97
|
|
|
@@ -103,23 +105,23 @@ class BaseDataset(Dataset):
|
|
|
103
105
|
for p in img_path if isinstance(img_path, list) else [img_path]:
|
|
104
106
|
p = Path(p) # os-agnostic
|
|
105
107
|
if p.is_dir(): # dir
|
|
106
|
-
f += glob.glob(str(p /
|
|
108
|
+
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
|
|
107
109
|
# F = list(p.rglob('*.*')) # pathlib
|
|
108
110
|
elif p.is_file(): # file
|
|
109
111
|
with open(p) as t:
|
|
110
112
|
t = t.read().strip().splitlines()
|
|
111
113
|
parent = str(p.parent) + os.sep
|
|
112
|
-
f += [x.replace(
|
|
114
|
+
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
|
|
113
115
|
# F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
|
114
116
|
else:
|
|
115
|
-
raise FileNotFoundError(f
|
|
116
|
-
im_files = sorted(x.replace(
|
|
117
|
+
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
|
118
|
+
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
|
117
119
|
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
|
118
|
-
assert im_files, f
|
|
120
|
+
assert im_files, f"{self.prefix}No images found in {img_path}"
|
|
119
121
|
except Exception as e:
|
|
120
|
-
raise FileNotFoundError(f
|
|
122
|
+
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
|
121
123
|
if self.fraction < 1:
|
|
122
|
-
im_files = im_files[:round(len(im_files) * self.fraction)]
|
|
124
|
+
im_files = im_files[: round(len(im_files) * self.fraction)]
|
|
123
125
|
return im_files
|
|
124
126
|
|
|
125
127
|
def update_labels(self, include_class: Optional[list]):
|
|
@@ -127,19 +129,19 @@ class BaseDataset(Dataset):
|
|
|
127
129
|
include_class_array = np.array(include_class).reshape(1, -1)
|
|
128
130
|
for i in range(len(self.labels)):
|
|
129
131
|
if include_class is not None:
|
|
130
|
-
cls = self.labels[i][
|
|
131
|
-
bboxes = self.labels[i][
|
|
132
|
-
segments = self.labels[i][
|
|
133
|
-
keypoints = self.labels[i][
|
|
132
|
+
cls = self.labels[i]["cls"]
|
|
133
|
+
bboxes = self.labels[i]["bboxes"]
|
|
134
|
+
segments = self.labels[i]["segments"]
|
|
135
|
+
keypoints = self.labels[i]["keypoints"]
|
|
134
136
|
j = (cls == include_class_array).any(1)
|
|
135
|
-
self.labels[i][
|
|
136
|
-
self.labels[i][
|
|
137
|
+
self.labels[i]["cls"] = cls[j]
|
|
138
|
+
self.labels[i]["bboxes"] = bboxes[j]
|
|
137
139
|
if segments:
|
|
138
|
-
self.labels[i][
|
|
140
|
+
self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
|
|
139
141
|
if keypoints is not None:
|
|
140
|
-
self.labels[i][
|
|
142
|
+
self.labels[i]["keypoints"] = keypoints[j]
|
|
141
143
|
if self.single_cls:
|
|
142
|
-
self.labels[i][
|
|
144
|
+
self.labels[i]["cls"][:, 0] = 0
|
|
143
145
|
|
|
144
146
|
def load_image(self, i, rect_mode=True):
|
|
145
147
|
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
|
|
@@ -149,13 +151,13 @@ class BaseDataset(Dataset):
|
|
|
149
151
|
try:
|
|
150
152
|
im = np.load(fn)
|
|
151
153
|
except Exception as e:
|
|
152
|
-
LOGGER.warning(f
|
|
154
|
+
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
|
|
153
155
|
Path(fn).unlink(missing_ok=True)
|
|
154
156
|
im = cv2.imread(f) # BGR
|
|
155
157
|
else: # read image
|
|
156
158
|
im = cv2.imread(f) # BGR
|
|
157
159
|
if im is None:
|
|
158
|
-
raise FileNotFoundError(f
|
|
160
|
+
raise FileNotFoundError(f"Image Not Found {f}")
|
|
159
161
|
|
|
160
162
|
h0, w0 = im.shape[:2] # orig hw
|
|
161
163
|
if rect_mode: # resize long side to imgsz while maintaining aspect ratio
|
|
@@ -181,17 +183,17 @@ class BaseDataset(Dataset):
|
|
|
181
183
|
def cache_images(self, cache):
|
|
182
184
|
"""Cache images to memory or disk."""
|
|
183
185
|
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
|
184
|
-
fcn = self.cache_images_to_disk if cache ==
|
|
186
|
+
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
|
|
185
187
|
with ThreadPool(NUM_THREADS) as pool:
|
|
186
188
|
results = pool.imap(fcn, range(self.ni))
|
|
187
189
|
pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
|
|
188
190
|
for i, x in pbar:
|
|
189
|
-
if cache ==
|
|
191
|
+
if cache == "disk":
|
|
190
192
|
b += self.npy_files[i].stat().st_size
|
|
191
193
|
else: # 'ram'
|
|
192
194
|
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
|
193
195
|
b += self.ims[i].nbytes
|
|
194
|
-
pbar.desc = f
|
|
196
|
+
pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {cache})"
|
|
195
197
|
pbar.close()
|
|
196
198
|
|
|
197
199
|
def cache_images_to_disk(self, i):
|
|
@@ -207,15 +209,17 @@ class BaseDataset(Dataset):
|
|
|
207
209
|
for _ in range(n):
|
|
208
210
|
im = cv2.imread(random.choice(self.im_files)) # sample image
|
|
209
211
|
ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
|
|
210
|
-
b += im.nbytes * ratio
|
|
212
|
+
b += im.nbytes * ratio**2
|
|
211
213
|
mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
|
|
212
214
|
mem = psutil.virtual_memory()
|
|
213
215
|
cache = mem_required < mem.available # to cache or not to cache, that is the question
|
|
214
216
|
if not cache:
|
|
215
|
-
LOGGER.info(
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
217
|
+
LOGGER.info(
|
|
218
|
+
f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
|
|
219
|
+
f'with {int(safety_margin * 100)}% safety margin but only '
|
|
220
|
+
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
|
|
221
|
+
f"{'caching images ✅' if cache else 'not caching images ⚠️'}"
|
|
222
|
+
)
|
|
219
223
|
return cache
|
|
220
224
|
|
|
221
225
|
def set_rectangle(self):
|
|
@@ -223,7 +227,7 @@ class BaseDataset(Dataset):
|
|
|
223
227
|
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
|
224
228
|
nb = bi[-1] + 1 # number of batches
|
|
225
229
|
|
|
226
|
-
s = np.array([x.pop(
|
|
230
|
+
s = np.array([x.pop("shape") for x in self.labels]) # hw
|
|
227
231
|
ar = s[:, 0] / s[:, 1] # aspect ratio
|
|
228
232
|
irect = ar.argsort()
|
|
229
233
|
self.im_files = [self.im_files[i] for i in irect]
|
|
@@ -250,12 +254,14 @@ class BaseDataset(Dataset):
|
|
|
250
254
|
def get_image_and_label(self, index):
|
|
251
255
|
"""Get and return label information from the dataset."""
|
|
252
256
|
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
|
253
|
-
label.pop(
|
|
254
|
-
label[
|
|
255
|
-
label[
|
|
256
|
-
|
|
257
|
+
label.pop("shape", None) # shape is for rect, remove it
|
|
258
|
+
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
|
|
259
|
+
label["ratio_pad"] = (
|
|
260
|
+
label["resized_shape"][0] / label["ori_shape"][0],
|
|
261
|
+
label["resized_shape"][1] / label["ori_shape"][1],
|
|
262
|
+
) # for evaluation
|
|
257
263
|
if self.rect:
|
|
258
|
-
label[
|
|
264
|
+
label["rect_shape"] = self.batch_shapes[self.batch[index]]
|
|
259
265
|
return self.update_labels_info(label)
|
|
260
266
|
|
|
261
267
|
def __len__(self):
|
ultralytics/data/build.py
CHANGED
|
@@ -9,8 +9,16 @@ import torch
|
|
|
9
9
|
from PIL import Image
|
|
10
10
|
from torch.utils.data import dataloader, distributed
|
|
11
11
|
|
|
12
|
-
from ultralytics.data.loaders import (
|
|
13
|
-
|
|
12
|
+
from ultralytics.data.loaders import (
|
|
13
|
+
LOADERS,
|
|
14
|
+
LoadImages,
|
|
15
|
+
LoadPilAndNumpy,
|
|
16
|
+
LoadScreenshots,
|
|
17
|
+
LoadStreams,
|
|
18
|
+
LoadTensor,
|
|
19
|
+
SourceTypes,
|
|
20
|
+
autocast_list,
|
|
21
|
+
)
|
|
14
22
|
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
|
15
23
|
from ultralytics.utils import RANK, colorstr
|
|
16
24
|
from ultralytics.utils.checks import check_file
|
|
@@ -29,7 +37,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
|
|
|
29
37
|
def __init__(self, *args, **kwargs):
|
|
30
38
|
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
|
|
31
39
|
super().__init__(*args, **kwargs)
|
|
32
|
-
object.__setattr__(self,
|
|
40
|
+
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
|
|
33
41
|
self.iterator = super().__iter__()
|
|
34
42
|
|
|
35
43
|
def __len__(self):
|
|
@@ -70,29 +78,30 @@ class _RepeatSampler:
|
|
|
70
78
|
|
|
71
79
|
def seed_worker(worker_id): # noqa
|
|
72
80
|
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
|
|
73
|
-
worker_seed = torch.initial_seed() % 2
|
|
81
|
+
worker_seed = torch.initial_seed() % 2**32
|
|
74
82
|
np.random.seed(worker_seed)
|
|
75
83
|
random.seed(worker_seed)
|
|
76
84
|
|
|
77
85
|
|
|
78
|
-
def build_yolo_dataset(cfg, img_path, batch, data, mode=
|
|
86
|
+
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
|
|
79
87
|
"""Build YOLO Dataset."""
|
|
80
88
|
return YOLODataset(
|
|
81
89
|
img_path=img_path,
|
|
82
90
|
imgsz=cfg.imgsz,
|
|
83
91
|
batch_size=batch,
|
|
84
|
-
augment=mode ==
|
|
92
|
+
augment=mode == "train", # augmentation
|
|
85
93
|
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
|
86
94
|
rect=cfg.rect or rect, # rectangular batches
|
|
87
95
|
cache=cfg.cache or None,
|
|
88
96
|
single_cls=cfg.single_cls or False,
|
|
89
97
|
stride=int(stride),
|
|
90
|
-
pad=0.0 if mode ==
|
|
91
|
-
prefix=colorstr(f
|
|
98
|
+
pad=0.0 if mode == "train" else 0.5,
|
|
99
|
+
prefix=colorstr(f"{mode}: "),
|
|
92
100
|
task=cfg.task,
|
|
93
101
|
classes=cfg.classes,
|
|
94
102
|
data=data,
|
|
95
|
-
fraction=cfg.fraction if mode ==
|
|
103
|
+
fraction=cfg.fraction if mode == "train" else 1.0,
|
|
104
|
+
)
|
|
96
105
|
|
|
97
106
|
|
|
98
107
|
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
|
@@ -103,15 +112,17 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
|
|
103
112
|
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
|
104
113
|
generator = torch.Generator()
|
|
105
114
|
generator.manual_seed(6148914691236517205 + RANK)
|
|
106
|
-
return InfiniteDataLoader(
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
+
return InfiniteDataLoader(
|
|
116
|
+
dataset=dataset,
|
|
117
|
+
batch_size=batch,
|
|
118
|
+
shuffle=shuffle and sampler is None,
|
|
119
|
+
num_workers=nw,
|
|
120
|
+
sampler=sampler,
|
|
121
|
+
pin_memory=PIN_MEMORY,
|
|
122
|
+
collate_fn=getattr(dataset, "collate_fn", None),
|
|
123
|
+
worker_init_fn=seed_worker,
|
|
124
|
+
generator=generator,
|
|
125
|
+
)
|
|
115
126
|
|
|
116
127
|
|
|
117
128
|
def check_source(source):
|
|
@@ -120,9 +131,9 @@ def check_source(source):
|
|
|
120
131
|
if isinstance(source, (str, int, Path)): # int for local usb camera
|
|
121
132
|
source = str(source)
|
|
122
133
|
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
|
123
|
-
is_url = source.lower().startswith((
|
|
124
|
-
webcam = source.isnumeric() or source.endswith(
|
|
125
|
-
screenshot = source.lower() ==
|
|
134
|
+
is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
|
|
135
|
+
webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
|
|
136
|
+
screenshot = source.lower() == "screen"
|
|
126
137
|
if is_url and is_file:
|
|
127
138
|
source = check_file(source) # download
|
|
128
139
|
elif isinstance(source, LOADERS):
|
|
@@ -135,7 +146,7 @@ def check_source(source):
|
|
|
135
146
|
elif isinstance(source, torch.Tensor):
|
|
136
147
|
tensor = True
|
|
137
148
|
else:
|
|
138
|
-
raise TypeError(
|
|
149
|
+
raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")
|
|
139
150
|
|
|
140
151
|
return source, webcam, screenshot, from_img, in_memory, tensor
|
|
141
152
|
|
|
@@ -171,6 +182,6 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
|
|
|
171
182
|
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
|
|
172
183
|
|
|
173
184
|
# Attach source types to the dataset
|
|
174
|
-
setattr(dataset,
|
|
185
|
+
setattr(dataset, "source_type", source_type)
|
|
175
186
|
|
|
176
187
|
return dataset
|