ultralytics 8.1.41__py3-none-any.whl → 8.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +38 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +3 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +62 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +3 -1
- ultralytics/data/augment.py +12 -9
- ultralytics/data/dataset.py +147 -142
- ultralytics/data/explorer/explorer.py +4 -6
- ultralytics/data/explorer/gui/dash.py +3 -3
- ultralytics/data/explorer/utils.py +3 -2
- ultralytics/engine/exporter.py +3 -2
- ultralytics/engine/trainer.py +11 -10
- ultralytics/models/fastsam/prompt.py +4 -1
- ultralytics/models/sam/predict.py +4 -1
- ultralytics/models/yolo/classify/train.py +2 -1
- ultralytics/solutions/heatmap.py +14 -27
- ultralytics/solutions/object_counter.py +12 -23
- ultralytics/solutions/queue_management.py +187 -0
- ultralytics/utils/__init__.py +22 -14
- ultralytics/utils/benchmarks.py +1 -2
- ultralytics/utils/callbacks/clearml.py +4 -3
- ultralytics/utils/callbacks/wb.py +5 -5
- ultralytics/utils/checks.py +6 -9
- ultralytics/utils/metrics.py +3 -3
- ultralytics/utils/ops.py +1 -1
- ultralytics/utils/plotting.py +96 -37
- ultralytics/utils/torch_utils.py +15 -7
- {ultralytics-8.1.41.dist-info → ultralytics-8.1.43.dist-info}/METADATA +2 -1
- {ultralytics-8.1.41.dist-info → ultralytics-8.1.43.dist-info}/RECORD +33 -30
- {ultralytics-8.1.41.dist-info → ultralytics-8.1.43.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.41.dist-info → ultralytics-8.1.43.dist-info}/WHEEL +0 -0
- {ultralytics-8.1.41.dist-info → ultralytics-8.1.43.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.41.dist-info → ultralytics-8.1.43.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
|
@@ -331,6 +331,10 @@ class BaseTrainer:
|
|
|
331
331
|
while True:
|
|
332
332
|
self.epoch = epoch
|
|
333
333
|
self.run_callbacks("on_train_epoch_start")
|
|
334
|
+
with warnings.catch_warnings():
|
|
335
|
+
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
|
336
|
+
self.scheduler.step()
|
|
337
|
+
|
|
334
338
|
self.model.train()
|
|
335
339
|
if RANK != -1:
|
|
336
340
|
self.train_loader.sampler.set_epoch(epoch)
|
|
@@ -426,15 +430,12 @@ class BaseTrainer:
|
|
|
426
430
|
t = time.time()
|
|
427
431
|
self.epoch_time = t - self.epoch_time_start
|
|
428
432
|
self.epoch_time_start = t
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
self.scheduler.last_epoch = self.epoch # do not move
|
|
436
|
-
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
|
437
|
-
self.scheduler.step()
|
|
433
|
+
if self.args.time:
|
|
434
|
+
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
|
435
|
+
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
|
436
|
+
self._setup_scheduler()
|
|
437
|
+
self.scheduler.last_epoch = self.epoch # do not move
|
|
438
|
+
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
|
438
439
|
self.run_callbacks("on_fit_epoch_end")
|
|
439
440
|
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
|
|
440
441
|
|
|
@@ -463,7 +464,7 @@ class BaseTrainer:
|
|
|
463
464
|
def save_model(self):
|
|
464
465
|
"""Save model training checkpoints with additional metadata."""
|
|
465
466
|
import io
|
|
466
|
-
import pandas as pd # scope for faster
|
|
467
|
+
import pandas as pd # scope for faster 'import ultralytics'
|
|
467
468
|
|
|
468
469
|
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
|
|
469
470
|
buffer = io.BytesIO()
|
|
ultralytics/models/fastsam/prompt.py
CHANGED
|
@@ -4,7 +4,6 @@ import os
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
6
|
import cv2
|
|
7
|
-
import matplotlib.pyplot as plt
|
|
8
7
|
import numpy as np
|
|
9
8
|
import torch
|
|
10
9
|
from PIL import Image
|
|
@@ -118,6 +117,8 @@ class FastSAMPrompt:
|
|
|
118
117
|
retina (bool, optional): Whether to use retina mask. Defaults to False.
|
|
119
118
|
with_contours (bool, optional): Whether to plot contours. Defaults to True.
|
|
120
119
|
"""
|
|
120
|
+
import matplotlib.pyplot as plt
|
|
121
|
+
|
|
121
122
|
pbar = TQDM(annotations, total=len(annotations))
|
|
122
123
|
for ann in pbar:
|
|
123
124
|
result_name = os.path.basename(ann.path)
|
|
@@ -202,6 +203,8 @@ class FastSAMPrompt:
|
|
|
202
203
|
target_height (int, optional): Target height for resizing. Defaults to 960.
|
|
203
204
|
target_width (int, optional): Target width for resizing. Defaults to 960.
|
|
204
205
|
"""
|
|
206
|
+
import matplotlib.pyplot as plt
|
|
207
|
+
|
|
205
208
|
n, h, w = annotation.shape # batch, height, width
|
|
206
209
|
|
|
207
210
|
areas = np.sum(annotation, axis=(1, 2))
|
|
ultralytics/models/sam/predict.py
CHANGED
|
@@ -11,7 +11,6 @@ segmentation tasks.
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
13
13
|
import torch.nn.functional as F
|
|
14
|
-
import torchvision
|
|
15
14
|
|
|
16
15
|
from ultralytics.data.augment import LetterBox
|
|
17
16
|
from ultralytics.engine.predictor import BasePredictor
|
|
@@ -246,6 +245,8 @@ class Predictor(BasePredictor):
|
|
|
246
245
|
Returns:
|
|
247
246
|
(tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
|
|
248
247
|
"""
|
|
248
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
249
|
+
|
|
249
250
|
self.segment_all = True
|
|
250
251
|
ih, iw = im.shape[2:]
|
|
251
252
|
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
|
|
@@ -449,6 +450,8 @@ class Predictor(BasePredictor):
|
|
|
449
450
|
- new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
|
|
450
451
|
- keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
|
|
451
452
|
"""
|
|
453
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
454
|
+
|
|
452
455
|
if len(masks) == 0:
|
|
453
456
|
return masks
|
|
454
457
|
|
|
ultralytics/models/yolo/classify/train.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
import torchvision
|
|
5
4
|
|
|
6
5
|
from ultralytics.data import ClassificationDataset, build_dataloader
|
|
7
6
|
from ultralytics.engine.trainer import BaseTrainer
|
|
@@ -59,6 +58,8 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
59
58
|
|
|
60
59
|
def setup_model(self):
|
|
61
60
|
"""Load, create or download model for any task."""
|
|
61
|
+
import torchvision # scope for faster 'import ultralytics'
|
|
62
|
+
|
|
62
63
|
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
|
63
64
|
return
|
|
64
65
|
|
ultralytics/solutions/heatmap.py
CHANGED
|
@@ -30,6 +30,7 @@ class Heatmap:
|
|
|
30
30
|
self.imw = None
|
|
31
31
|
self.imh = None
|
|
32
32
|
self.im0 = None
|
|
33
|
+
self.tf = 2
|
|
33
34
|
self.view_in_counts = True
|
|
34
35
|
self.view_out_counts = True
|
|
35
36
|
|
|
@@ -56,11 +57,9 @@ class Heatmap:
|
|
|
56
57
|
self.out_counts = 0
|
|
57
58
|
self.count_ids = []
|
|
58
59
|
self.class_wise_count = {}
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
61
|
-
self.line_color = (255, 255, 255)
|
|
60
|
+
self.count_txt_color = (0, 0, 0)
|
|
61
|
+
self.count_bg_color = (255, 255, 255)
|
|
62
62
|
self.cls_txtdisplay_gap = 50
|
|
63
|
-
self.fontsize = 0.6
|
|
64
63
|
|
|
65
64
|
# Decay factor
|
|
66
65
|
self.decay_factor = 0.99
|
|
@@ -79,16 +78,14 @@ class Heatmap:
|
|
|
79
78
|
view_in_counts=True,
|
|
80
79
|
view_out_counts=True,
|
|
81
80
|
count_reg_pts=None,
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
fontsize=0.8,
|
|
85
|
-
line_color=(255, 255, 255),
|
|
81
|
+
count_txt_color=(0, 0, 0),
|
|
82
|
+
count_bg_color=(255, 255, 255),
|
|
86
83
|
count_reg_color=(255, 0, 255),
|
|
87
84
|
region_thickness=5,
|
|
88
85
|
line_dist_thresh=15,
|
|
86
|
+
line_thickness=2,
|
|
89
87
|
decay_factor=0.99,
|
|
90
88
|
shape="circle",
|
|
91
|
-
cls_txtdisplay_gap=50,
|
|
92
89
|
):
|
|
93
90
|
"""
|
|
94
91
|
Configures the heatmap colormap, width, height and display parameters.
|
|
@@ -98,22 +95,21 @@ class Heatmap:
|
|
|
98
95
|
imw (int): The width of the frame.
|
|
99
96
|
imh (int): The height of the frame.
|
|
100
97
|
classes_names (dict): Classes names
|
|
98
|
+
line_thickness (int): Line thickness for bounding boxes.
|
|
101
99
|
heatmap_alpha (float): alpha value for heatmap display
|
|
102
100
|
view_img (bool): Flag indicating frame display
|
|
103
101
|
view_in_counts (bool): Flag to control whether to display the incounts on video stream.
|
|
104
102
|
view_out_counts (bool): Flag to control whether to display the outcounts on video stream.
|
|
105
103
|
count_reg_pts (list): Object counting region points
|
|
106
|
-
count_txt_thickness (int): Text thickness for object counting display
|
|
107
104
|
count_txt_color (RGB color): count text color value
|
|
108
|
-
|
|
109
|
-
line_color (RGB color): count highlighter line color
|
|
105
|
+
count_bg_color (RGB color): count highlighter line color
|
|
110
106
|
count_reg_color (RGB color): Color of object counting region
|
|
111
107
|
region_thickness (int): Object counting Region thickness
|
|
112
108
|
line_dist_thresh (int): Euclidean Distance threshold for line counter
|
|
113
109
|
decay_factor (float): value for removing heatmap area after object passed
|
|
114
110
|
shape (str): Heatmap shape, rect or circle shape supported
|
|
115
|
-
cls_txtdisplay_gap (int): Display gap between each class count
|
|
116
111
|
"""
|
|
112
|
+
self.tf = line_thickness
|
|
117
113
|
self.names = classes_names
|
|
118
114
|
self.imw = imw
|
|
119
115
|
self.imh = imh
|
|
@@ -141,16 +137,13 @@ class Heatmap:
|
|
|
141
137
|
# Heatmap new frame
|
|
142
138
|
self.heatmap = np.zeros((int(self.imh), int(self.imw)), dtype=np.float32)
|
|
143
139
|
|
|
144
|
-
self.count_txt_thickness = count_txt_thickness
|
|
145
140
|
self.count_txt_color = count_txt_color
|
|
146
|
-
self.
|
|
147
|
-
self.line_color = line_color
|
|
141
|
+
self.count_bg_color = count_bg_color
|
|
148
142
|
self.region_color = count_reg_color
|
|
149
143
|
self.region_thickness = region_thickness
|
|
150
144
|
self.decay_factor = decay_factor
|
|
151
145
|
self.line_dist_thresh = line_dist_thresh
|
|
152
146
|
self.shape = shape
|
|
153
|
-
self.cls_txtdisplay_gap = cls_txtdisplay_gap
|
|
154
147
|
|
|
155
148
|
# shape of heatmap, if not selected
|
|
156
149
|
if self.shape not in {"circle", "rect"}:
|
|
@@ -185,7 +178,7 @@ class Heatmap:
|
|
|
185
178
|
return im0
|
|
186
179
|
self.heatmap *= self.decay_factor # decay factor
|
|
187
180
|
self.extract_results(tracks)
|
|
188
|
-
self.annotator = Annotator(self.im0, self.
|
|
181
|
+
self.annotator = Annotator(self.im0, self.tf, None)
|
|
189
182
|
|
|
190
183
|
if self.count_reg_pts is not None:
|
|
191
184
|
# Draw counting region
|
|
@@ -239,11 +232,8 @@ class Heatmap:
|
|
|
239
232
|
|
|
240
233
|
# Count objects using line
|
|
241
234
|
elif len(self.count_reg_pts) == 2:
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
if prev_position is not None and is_inside and track_id not in self.count_ids:
|
|
235
|
+
if prev_position is not None and track_id not in self.count_ids:
|
|
245
236
|
distance = Point(track_line[-1]).distance(self.counting_region)
|
|
246
|
-
|
|
247
237
|
if distance < self.line_dist_thresh and track_id not in self.count_ids:
|
|
248
238
|
self.count_ids.append(track_id)
|
|
249
239
|
|
|
@@ -293,11 +283,8 @@ class Heatmap:
|
|
|
293
283
|
if self.count_reg_pts is not None and label is not None:
|
|
294
284
|
self.annotator.display_counts(
|
|
295
285
|
counts=label,
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
txt_color=self.count_txt_color,
|
|
299
|
-
line_color=self.line_color,
|
|
300
|
-
classwise_txtgap=self.cls_txtdisplay_gap,
|
|
286
|
+
count_txt_color=self.count_txt_color,
|
|
287
|
+
count_bg_color=self.count_bg_color,
|
|
301
288
|
)
|
|
302
289
|
|
|
303
290
|
self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)
|
|
ultralytics/solutions/object_counter.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
3
|
from collections import defaultdict
|
|
4
|
-
|
|
5
4
|
import cv2
|
|
6
5
|
|
|
7
6
|
from ultralytics.utils.checks import check_imshow, check_requirements
|
|
@@ -47,7 +46,7 @@ class ObjectCounter:
|
|
|
47
46
|
self.class_wise_count = {}
|
|
48
47
|
self.count_txt_thickness = 0
|
|
49
48
|
self.count_txt_color = (255, 255, 255)
|
|
50
|
-
self.
|
|
49
|
+
self.count_bg_color = (255, 255, 255)
|
|
51
50
|
self.cls_txtdisplay_gap = 50
|
|
52
51
|
self.fontsize = 0.6
|
|
53
52
|
|
|
@@ -65,16 +64,14 @@ class ObjectCounter:
|
|
|
65
64
|
classes_names,
|
|
66
65
|
reg_pts,
|
|
67
66
|
count_reg_color=(255, 0, 255),
|
|
67
|
+
count_txt_color=(0, 0, 0),
|
|
68
|
+
count_bg_color=(255, 255, 255),
|
|
68
69
|
line_thickness=2,
|
|
69
70
|
track_thickness=2,
|
|
70
71
|
view_img=False,
|
|
71
72
|
view_in_counts=True,
|
|
72
73
|
view_out_counts=True,
|
|
73
74
|
draw_tracks=False,
|
|
74
|
-
count_txt_thickness=3,
|
|
75
|
-
count_txt_color=(255, 255, 255),
|
|
76
|
-
fontsize=0.8,
|
|
77
|
-
line_color=(255, 255, 255),
|
|
78
75
|
track_color=None,
|
|
79
76
|
region_thickness=5,
|
|
80
77
|
line_dist_thresh=15,
|
|
@@ -92,10 +89,8 @@ class ObjectCounter:
|
|
|
92
89
|
classes_names (dict): Classes names
|
|
93
90
|
track_thickness (int): Track thickness
|
|
94
91
|
draw_tracks (Bool): draw tracks
|
|
95
|
-
count_txt_thickness (int): Text thickness for object counting display
|
|
96
92
|
count_txt_color (RGB color): count text color value
|
|
97
|
-
|
|
98
|
-
line_color (RGB color): count highlighter line color
|
|
93
|
+
count_bg_color (RGB color): count highlighter line color
|
|
99
94
|
count_reg_color (RGB color): Color of object counting region
|
|
100
95
|
track_color (RGB color): color for tracks
|
|
101
96
|
region_thickness (int): Object counting Region thickness
|
|
@@ -125,10 +120,8 @@ class ObjectCounter:
|
|
|
125
120
|
|
|
126
121
|
self.names = classes_names
|
|
127
122
|
self.track_color = track_color
|
|
128
|
-
self.count_txt_thickness = count_txt_thickness
|
|
129
123
|
self.count_txt_color = count_txt_color
|
|
130
|
-
self.
|
|
131
|
-
self.line_color = line_color
|
|
124
|
+
self.count_bg_color = count_bg_color
|
|
132
125
|
self.region_color = count_reg_color
|
|
133
126
|
self.region_thickness = region_thickness
|
|
134
127
|
self.line_dist_thresh = line_dist_thresh
|
|
@@ -172,6 +165,9 @@ class ObjectCounter:
|
|
|
172
165
|
# Annotator Init and region drawing
|
|
173
166
|
self.annotator = Annotator(self.im0, self.tf, self.names)
|
|
174
167
|
|
|
168
|
+
# Draw region or line
|
|
169
|
+
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
|
170
|
+
|
|
175
171
|
if tracks[0].boxes.id is not None:
|
|
176
172
|
boxes = tracks[0].boxes.xyxy.cpu()
|
|
177
173
|
clss = tracks[0].boxes.cls.cpu().tolist()
|
|
@@ -220,17 +216,14 @@ class ObjectCounter:
|
|
|
220
216
|
|
|
221
217
|
# Count objects using line
|
|
222
218
|
elif len(self.reg_pts) == 2:
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
if prev_position is not None and is_inside and track_id not in self.count_ids:
|
|
219
|
+
if prev_position is not None and track_id not in self.count_ids:
|
|
226
220
|
distance = Point(track_line[-1]).distance(self.counting_region)
|
|
227
|
-
|
|
228
221
|
if distance < self.line_dist_thresh and track_id not in self.count_ids:
|
|
229
222
|
self.count_ids.append(track_id)
|
|
230
223
|
|
|
231
224
|
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
|
|
232
225
|
self.in_counts += 1
|
|
233
|
-
self.class_wise_count[self.names[cls]]["in"] +=
|
|
226
|
+
self.class_wise_count[self.names[cls]]["in"] += 1  # one inbound crossing per track, matching in_counts += 1
|
|
234
227
|
else:
|
|
235
228
|
self.out_counts += 1
|
|
236
229
|
self.class_wise_count[self.names[cls]]["out"] += 1
|
|
@@ -254,17 +247,13 @@ class ObjectCounter:
|
|
|
254
247
|
if label is not None:
|
|
255
248
|
self.annotator.display_counts(
|
|
256
249
|
counts=label,
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
txt_color=self.count_txt_color,
|
|
260
|
-
line_color=self.line_color,
|
|
261
|
-
classwise_txtgap=self.cls_txtdisplay_gap,
|
|
250
|
+
count_txt_color=self.count_txt_color,
|
|
251
|
+
count_bg_color=self.count_bg_color,
|
|
262
252
|
)
|
|
263
253
|
|
|
264
254
|
def display_frames(self):
|
|
265
255
|
"""Display frame."""
|
|
266
256
|
if self.env_check:
|
|
267
|
-
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
|
268
257
|
cv2.namedWindow(self.window_name)
|
|
269
258
|
if len(self.reg_pts) == 4: # only add mouse event If user drawn region
|
|
270
259
|
cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
|
|
ultralytics/solutions/queue_management.py
CHANGED
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
|
|
5
|
+
import cv2
|
|
6
|
+
|
|
7
|
+
from ultralytics.utils.checks import check_imshow, check_requirements
|
|
8
|
+
from ultralytics.utils.plotting import Annotator, colors
|
|
9
|
+
|
|
10
|
+
check_requirements("shapely>=2.0.0")
|
|
11
|
+
|
|
12
|
+
from shapely.geometry import Point, Polygon
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class QueueManager:
    """
    Manages queue monitoring on a real-time video stream using object tracks.

    Each processed frame is annotated with the queue region, tracked boxes, optional track trails and the number
    of tracked objects currently inside the queue region.
    """

    def __init__(self):
        """Initializes the queue manager with default values for various tracking and counting parameters."""

        # Mouse events
        self.is_drawing = False
        self.selected_point = None

        # Region & Line Information
        self.reg_pts = [(20, 60), (20, 680), (1120, 680), (1120, 60)]
        # Build the default region now so frames can be processed even if set_args() is never called
        self.counting_region = Polygon(self.reg_pts)
        self.region_color = (255, 0, 255)
        self.region_thickness = 5

        # Image and annotation Information
        self.im0 = None
        self.tf = None
        self.view_img = False
        self.view_queue_counts = True
        self.fontsize = 0.6

        self.names = None  # Classes names
        self.annotator = None  # Annotator
        self.window_name = "Ultralytics YOLOv8 Queue Manager"

        # Object counting Information
        self.counts = 0  # tracked objects inside the queue region in the most recently processed frame
        self.count_txt_color = (255, 255, 255)

        # Tracks info
        self.track_history = defaultdict(list)
        self.track_thickness = 2
        self.draw_tracks = False
        self.track_color = None

        # Check if environment support imshow
        self.env_check = check_imshow(warn=True)

    def set_args(
        self,
        classes_names,
        reg_pts,
        line_thickness=2,
        track_thickness=2,
        view_img=False,
        region_color=(255, 0, 255),
        view_queue_counts=True,
        draw_tracks=False,
        count_txt_color=(255, 255, 255),
        track_color=None,
        region_thickness=5,
        fontsize=0.7,
    ):
        """
        Configures the queue manager's display options, annotation styling and queue region.

        Args:
            classes_names (dict): Classes names
            reg_pts (list): Points defining the queue region. At least 3 points are required to form a polygon;
                otherwise (including None) the current region is kept.
            line_thickness (int): Line thickness for bounding boxes.
            track_thickness (int): Track thickness
            view_img (bool): Flag to control whether to display the video stream.
            region_color (RGB color): Color of queue region
            view_queue_counts (bool): Flag to control whether to display the counts on video stream.
            draw_tracks (Bool): draw tracks
            count_txt_color (RGB color): count text color value
            track_color (RGB color): color for tracks
            region_thickness (int): Object counting Region thickness
            fontsize (float): Text display font size
        """
        self.tf = line_thickness
        self.view_img = view_img
        self.view_queue_counts = view_queue_counts
        self.track_thickness = track_thickness
        self.draw_tracks = draw_tracks
        self.region_color = region_color

        if reg_pts is not None and len(reg_pts) >= 3:
            print("Queue region initiated...")
            self.reg_pts = reg_pts
            self.counting_region = Polygon(self.reg_pts)
        else:
            print("Invalid region points provided...")
            print("Using default region now....")
            self.counting_region = Polygon(self.reg_pts)

        self.names = classes_names
        self.track_color = track_color
        self.count_txt_color = count_txt_color
        self.region_thickness = region_thickness
        self.fontsize = fontsize

    def extract_and_process_tracks(self, tracks):
        """
        Annotates the current frame with the queue region, boxes, optional track trails and the queue count.

        The count is recomputed for every frame. It is the number of tracked objects with at least two recorded
        positions whose latest centroid lies inside the queue region. After this call ``self.counts`` holds the
        value for the current frame.

        Args:
            tracks (list): Tracking results for the current frame; only ``tracks[0].boxes`` is read.
        """
        # Per-frame count: reset before counting so the value stays readable after the frame is processed
        self.counts = 0

        # Annotator Init and queue region drawing (always drawn, so the returned frame shows the region)
        self.annotator = Annotator(self.im0, self.tf, self.names)
        self.annotator.draw_region(reg_pts=self.reg_pts, thickness=self.region_thickness, color=self.region_color)

        if tracks[0].boxes.id is not None:
            boxes = tracks[0].boxes.xyxy.cpu()
            clss = tracks[0].boxes.cls.cpu().tolist()
            track_ids = tracks[0].boxes.id.int().cpu().tolist()

            for box, track_id, cls in zip(boxes, track_ids, clss):
                # Draw bounding box
                self.annotator.box_label(box, label=f"{self.names[cls]}#{track_id}", color=colors(int(track_id), True))

                # Record the box centre, keeping only the most recent 30 positions per track
                track_line = self.track_history[track_id]
                track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
                if len(track_line) > 30:
                    track_line.pop(0)

                # Draw track trails
                if self.draw_tracks:
                    self.annotator.draw_centroid_and_tracks(
                        track_line,
                        color=self.track_color if self.track_color else colors(int(track_id), True),
                        track_thickness=self.track_thickness,
                    )

                prev_position = track_line[-2] if len(track_line) > 1 else None

                # Count objects whose current centroid is inside the queue region
                if len(self.reg_pts) >= 3:
                    is_inside = self.counting_region.contains(Point(track_line[-1]))
                    if prev_position is not None and is_inside:
                        self.counts += 1

        if self.view_queue_counts:
            label = "Queue Counts : " + str(self.counts)
            self.annotator.queue_counts_display(
                label,
                points=self.reg_pts,
                region_color=self.region_color,
                txt_color=self.count_txt_color,
                fontsize=self.fontsize,
            )

    def display_frames(self):
        """Shows the annotated frame in an OpenCV window when the environment supports imshow."""
        if self.env_check:
            cv2.namedWindow(self.window_name)
            cv2.imshow(self.window_name, self.im0)
            # Break Window
            if cv2.waitKey(1) & 0xFF == ord("q"):
                return

    def process_queue(self, im0, tracks):
        """
        Main function to start the queue management process.

        Args:
            im0 (ndarray): Current frame from the video stream.
            tracks (list): List of tracks obtained from the object tracking process.

        Returns:
            (ndarray): The annotated frame. The window is shown only when ``view_img`` is enabled.
        """
        self.im0 = im0  # store image
        self.extract_and_process_tracks(tracks)  # draw region even if no objects

        if self.view_img:
            self.display_frames()
        return self.im0
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
if __name__ == "__main__":
    # Running the module directly only constructs a manager with default settings.
    queue_manager = QueueManager()
|
ultralytics/utils/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
3
|
import contextlib
|
|
4
|
+
import importlib.metadata
|
|
4
5
|
import inspect
|
|
5
6
|
import logging.config
|
|
6
7
|
import os
|
|
@@ -42,6 +43,8 @@ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar form
|
|
|
42
43
|
LOGGING_NAME = "ultralytics"
|
|
43
44
|
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
|
|
44
45
|
ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
|
|
46
|
+
PYTHON_VERSION = platform.python_version()
|
|
47
|
+
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
|
|
45
48
|
HELP_MSG = """
|
|
46
49
|
Usage examples for running YOLOv8:
|
|
47
50
|
|
|
@@ -230,37 +233,42 @@ def plt_settings(rcparams=None, backend="Agg"):
|
|
|
230
233
|
return decorator
|
|
231
234
|
|
|
232
235
|
|
|
233
|
-
def set_logging(name=LOGGING_NAME, verbose=True):
|
|
234
|
-
"""Sets up logging for the given name with UTF-8 encoding support
|
|
236
|
+
def set_logging(name="LOGGING_NAME", verbose=True):
|
|
237
|
+
"""Sets up logging for the given name with UTF-8 encoding support, ensuring compatibility across different
|
|
238
|
+
environments.
|
|
239
|
+
"""
|
|
235
240
|
level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings
|
|
236
241
|
|
|
237
|
-
# Configure the console (stdout) encoding to UTF-8
|
|
242
|
+
# Configure the console (stdout) encoding to UTF-8, with checks for compatibility
|
|
238
243
|
formatter = logging.Formatter("%(message)s") # Default formatter
|
|
239
|
-
if WINDOWS and sys.stdout.encoding != "utf-8":
|
|
244
|
+
if WINDOWS and hasattr(sys.stdout, "encoding") and sys.stdout.encoding != "utf-8":
|
|
245
|
+
|
|
246
|
+
class CustomFormatter(logging.Formatter):
|
|
247
|
+
def format(self, record):
|
|
248
|
+
"""Sets up logging with UTF-8 encoding and configurable verbosity."""
|
|
249
|
+
return emojis(super().format(record))
|
|
250
|
+
|
|
240
251
|
try:
|
|
252
|
+
# Attempt to reconfigure stdout to use UTF-8 encoding if possible
|
|
241
253
|
if hasattr(sys.stdout, "reconfigure"):
|
|
242
254
|
sys.stdout.reconfigure(encoding="utf-8")
|
|
255
|
+
# For environments where reconfigure is not available, wrap stdout in a TextIOWrapper
|
|
243
256
|
elif hasattr(sys.stdout, "buffer"):
|
|
244
257
|
import io
|
|
245
258
|
|
|
246
259
|
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
|
|
247
260
|
else:
|
|
248
|
-
|
|
261
|
+
formatter = CustomFormatter("%(message)s")
|
|
249
262
|
except Exception as e:
|
|
250
263
|
print(f"Creating custom formatter for non UTF-8 environments due to {e}")
|
|
264
|
+
formatter = CustomFormatter("%(message)s")
|
|
251
265
|
|
|
252
|
-
|
|
253
|
-
def format(self, record):
|
|
254
|
-
"""Sets up logging with UTF-8 encoding and configurable verbosity."""
|
|
255
|
-
return emojis(super().format(record))
|
|
256
|
-
|
|
257
|
-
formatter = CustomFormatter("%(message)s") # Use CustomFormatter to eliminate UTF-8 output as last recourse
|
|
258
|
-
|
|
259
|
-
# Create and configure the StreamHandler
|
|
266
|
+
# Create and configure the StreamHandler with the appropriate formatter and level
|
|
260
267
|
stream_handler = logging.StreamHandler(sys.stdout)
|
|
261
268
|
stream_handler.setFormatter(formatter)
|
|
262
269
|
stream_handler.setLevel(level)
|
|
263
270
|
|
|
271
|
+
# Set up the logger
|
|
264
272
|
logger = logging.getLogger(name)
|
|
265
273
|
logger.setLevel(level)
|
|
266
274
|
logger.addHandler(stream_handler)
|
|
@@ -471,7 +479,7 @@ def is_online() -> bool:
|
|
|
471
479
|
|
|
472
480
|
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
|
|
473
481
|
try:
|
|
474
|
-
test_connection = socket.create_connection(address=(host,
|
|
482
|
+
test_connection = socket.create_connection(address=(host, 80), timeout=2)
|
|
475
483
|
except (socket.timeout, socket.gaierror, OSError):
|
|
476
484
|
continue
|
|
477
485
|
else:
|
ultralytics/utils/callbacks/clearml.py
CHANGED
|
@@ -7,8 +7,6 @@ try:
|
|
|
7
7
|
assert SETTINGS["clearml"] is True # verify integration is enabled
|
|
8
8
|
import clearml
|
|
9
9
|
from clearml import Task
|
|
10
|
-
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
|
11
|
-
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
|
12
10
|
|
|
13
11
|
assert hasattr(clearml, "__version__") # verify package is not directory
|
|
14
12
|
|
|
@@ -61,8 +59,11 @@ def on_pretrain_routine_start(trainer):
|
|
|
61
59
|
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
|
|
62
60
|
try:
|
|
63
61
|
if task := Task.current_task():
|
|
64
|
-
#
|
|
62
|
+
# WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
|
|
65
63
|
# We are logging these plots and model files manually in the integration
|
|
64
|
+
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
|
65
|
+
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
|
66
|
+
|
|
66
67
|
PatchPyTorchModelIO.update_current_task(None)
|
|
67
68
|
PatchedMatplotlib.update_current_task(None)
|
|
68
69
|
else:
|
|
ultralytics/utils/callbacks/wb.py
CHANGED
|
@@ -9,10 +9,6 @@ try:
|
|
|
9
9
|
import wandb as wb
|
|
10
10
|
|
|
11
11
|
assert hasattr(wb, "__version__") # verify package is not directory
|
|
12
|
-
|
|
13
|
-
import numpy as np
|
|
14
|
-
import pandas as pd
|
|
15
|
-
|
|
16
12
|
_processed_plots = {}
|
|
17
13
|
|
|
18
14
|
except (ImportError, AssertionError):
|
|
@@ -38,7 +34,9 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
|
|
|
38
34
|
Returns:
|
|
39
35
|
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
|
|
40
36
|
"""
|
|
41
|
-
|
|
37
|
+
import pandas # scope for faster 'import ultralytics'
|
|
38
|
+
|
|
39
|
+
df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
|
|
42
40
|
fields = {"x": "x", "y": "y", "class": "class"}
|
|
43
41
|
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
|
|
44
42
|
return wb.plot_table(
|
|
@@ -77,6 +75,8 @@ def _plot_curve(
|
|
|
77
75
|
Note:
|
|
78
76
|
The function leverages the '_custom_table' function to generate the actual visualization.
|
|
79
77
|
"""
|
|
78
|
+
import numpy as np
|
|
79
|
+
|
|
80
80
|
# Create new x
|
|
81
81
|
if names is None:
|
|
82
82
|
names = []
|