ultralytics 8.3.153__py3-none-any.whl → 8.3.154__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_python.py +1 -0
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +2 -0
- ultralytics/engine/predictor.py +1 -1
- ultralytics/engine/validator.py +0 -6
- ultralytics/models/fastsam/val.py +0 -2
- ultralytics/models/rtdetr/val.py +28 -16
- ultralytics/models/yolo/classify/val.py +26 -23
- ultralytics/models/yolo/detect/train.py +4 -7
- ultralytics/models/yolo/detect/val.py +88 -90
- ultralytics/models/yolo/obb/val.py +52 -44
- ultralytics/models/yolo/pose/train.py +1 -35
- ultralytics/models/yolo/pose/val.py +77 -176
- ultralytics/models/yolo/segment/train.py +1 -41
- ultralytics/models/yolo/segment/val.py +64 -176
- ultralytics/models/yolo/yoloe/val.py +2 -1
- ultralytics/nn/autobackend.py +2 -2
- ultralytics/solutions/ai_gym.py +2 -3
- ultralytics/solutions/solutions.py +2 -0
- ultralytics/solutions/templates/similarity-search.html +31 -0
- ultralytics/utils/callbacks/comet.py +1 -1
- ultralytics/utils/metrics.py +146 -317
- ultralytics/utils/ops.py +4 -4
- ultralytics/utils/plotting.py +31 -56
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/METADATA +1 -1
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/RECORD +30 -30
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
from multiprocessing.pool import ThreadPool
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import Any, Dict, List,
|
5
|
+
from typing import Any, Dict, List, Tuple
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
import torch
|
@@ -11,8 +11,7 @@ import torch.nn.functional as F
|
|
11
11
|
from ultralytics.models.yolo.detect import DetectionValidator
|
12
12
|
from ultralytics.utils import LOGGER, NUM_THREADS, ops
|
13
13
|
from ultralytics.utils.checks import check_requirements
|
14
|
-
from ultralytics.utils.metrics import SegmentMetrics,
|
15
|
-
from ultralytics.utils.plotting import output_to_target, plot_images
|
14
|
+
from ultralytics.utils.metrics import SegmentMetrics, mask_iou
|
16
15
|
|
17
16
|
|
18
17
|
class SegmentationValidator(DetectionValidator):
|
@@ -47,10 +46,9 @@ class SegmentationValidator(DetectionValidator):
|
|
47
46
|
_callbacks (list, optional): List of callback functions.
|
48
47
|
"""
|
49
48
|
super().__init__(dataloader, save_dir, args, _callbacks)
|
50
|
-
self.plot_masks = None
|
51
49
|
self.process = None
|
52
50
|
self.args.task = "segment"
|
53
|
-
self.metrics = SegmentMetrics(
|
51
|
+
self.metrics = SegmentMetrics()
|
54
52
|
|
55
53
|
def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
56
54
|
"""
|
@@ -74,12 +72,10 @@ class SegmentationValidator(DetectionValidator):
|
|
74
72
|
model (torch.nn.Module): Model to validate.
|
75
73
|
"""
|
76
74
|
super().init_metrics(model)
|
77
|
-
self.plot_masks = []
|
78
75
|
if self.args.save_json:
|
79
76
|
check_requirements("pycocotools>=2.0.6")
|
80
77
|
# More accurate vs faster
|
81
78
|
self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
|
82
|
-
self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
|
83
79
|
|
84
80
|
def get_desc(self) -> str:
|
85
81
|
"""Return a formatted description of evaluation metrics."""
|
@@ -97,7 +93,7 @@ class SegmentationValidator(DetectionValidator):
|
|
97
93
|
"mAP50-95)",
|
98
94
|
)
|
99
95
|
|
100
|
-
def postprocess(self, preds: List[torch.Tensor]) ->
|
96
|
+
def postprocess(self, preds: List[torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
|
101
97
|
"""
|
102
98
|
Post-process YOLO predictions and return output detections with proto.
|
103
99
|
|
@@ -105,12 +101,19 @@ class SegmentationValidator(DetectionValidator):
|
|
105
101
|
preds (List[torch.Tensor]): Raw predictions from the model.
|
106
102
|
|
107
103
|
Returns:
|
108
|
-
|
109
|
-
proto (torch.Tensor): Prototype masks for segmentation.
|
104
|
+
List[Dict[str, torch.Tensor]]: Processed detection predictions with masks.
|
110
105
|
"""
|
111
|
-
p = super().postprocess(preds[0])
|
112
106
|
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
113
|
-
|
107
|
+
preds = super().postprocess(preds[0])
|
108
|
+
imgsz = [4 * x for x in proto.shape[2:]] # get image size from proto
|
109
|
+
for i, pred in enumerate(preds):
|
110
|
+
coefficient = pred.pop("extra")
|
111
|
+
pred["masks"] = (
|
112
|
+
self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
|
113
|
+
if len(coefficient)
|
114
|
+
else torch.zeros((0, imgsz[0], imgsz[1]), dtype=torch.uint8, device=pred["bboxes"].device)
|
115
|
+
)
|
116
|
+
return preds
|
114
117
|
|
115
118
|
def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
|
116
119
|
"""
|
@@ -128,142 +131,56 @@ class SegmentationValidator(DetectionValidator):
|
|
128
131
|
prepared_batch["masks"] = batch["masks"][midx]
|
129
132
|
return prepared_batch
|
130
133
|
|
131
|
-
def _prepare_pred(
|
132
|
-
self, pred: torch.Tensor, pbatch: Dict[str, Any], proto: torch.Tensor
|
133
|
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
134
|
+
def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
134
135
|
"""
|
135
136
|
Prepare predictions for evaluation by processing bounding boxes and masks.
|
136
137
|
|
137
138
|
Args:
|
138
|
-
pred (torch.Tensor):
|
139
|
+
pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
|
139
140
|
pbatch (Dict[str, Any]): Prepared batch information.
|
140
|
-
proto (torch.Tensor): Prototype masks for segmentation.
|
141
141
|
|
142
142
|
Returns:
|
143
|
-
|
144
|
-
pred_masks (torch.Tensor): Processed mask predictions.
|
143
|
+
Dict[str, torch.Tensor]: Processed bounding box predictions.
|
145
144
|
"""
|
146
145
|
predn = super()._prepare_pred(pred, pbatch)
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
Args:
|
155
|
-
preds (Tuple[List[torch.Tensor], torch.Tensor]): List of predictions from the model.
|
156
|
-
batch (Dict[str, Any]): Batch data containing ground truth.
|
157
|
-
"""
|
158
|
-
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
159
|
-
self.seen += 1
|
160
|
-
npr = len(pred)
|
161
|
-
stat = dict(
|
162
|
-
conf=torch.zeros(0, device=self.device),
|
163
|
-
pred_cls=torch.zeros(0, device=self.device),
|
164
|
-
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
|
165
|
-
tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
|
146
|
+
predn["masks"] = pred["masks"]
|
147
|
+
if self.args.save_json and len(predn["masks"]):
|
148
|
+
coco_masks = torch.as_tensor(pred["masks"], dtype=torch.uint8)
|
149
|
+
coco_masks = ops.scale_image(
|
150
|
+
coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
|
151
|
+
pbatch["ori_shape"],
|
152
|
+
ratio_pad=pbatch["ratio_pad"],
|
166
153
|
)
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
stat["target_img"] = cls.unique()
|
172
|
-
if npr == 0:
|
173
|
-
if nl:
|
174
|
-
for k in self.stats.keys():
|
175
|
-
self.stats[k].append(stat[k])
|
176
|
-
if self.args.plots:
|
177
|
-
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
|
178
|
-
continue
|
179
|
-
|
180
|
-
# Masks
|
181
|
-
gt_masks = pbatch.pop("masks")
|
182
|
-
# Predictions
|
183
|
-
if self.args.single_cls:
|
184
|
-
pred[:, 5] = 0
|
185
|
-
predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
|
186
|
-
stat["conf"] = predn[:, 4]
|
187
|
-
stat["pred_cls"] = predn[:, 5]
|
188
|
-
|
189
|
-
# Evaluate
|
190
|
-
if nl:
|
191
|
-
stat["tp"] = self._process_batch(predn, bbox, cls)
|
192
|
-
stat["tp_m"] = self._process_batch(
|
193
|
-
predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
|
194
|
-
)
|
195
|
-
if self.args.plots:
|
196
|
-
self.confusion_matrix.process_batch(predn, bbox, cls)
|
197
|
-
|
198
|
-
for k in self.stats.keys():
|
199
|
-
self.stats[k].append(stat[k])
|
200
|
-
|
201
|
-
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
202
|
-
if self.args.plots and self.batch_i < 3:
|
203
|
-
self.plot_masks.append(pred_masks[:50].cpu()) # Limit plotted items for speed
|
204
|
-
if pred_masks.shape[0] > 50:
|
205
|
-
LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
|
206
|
-
|
207
|
-
# Save
|
208
|
-
if self.args.save_json:
|
209
|
-
self.pred_to_json(
|
210
|
-
predn,
|
211
|
-
batch["im_file"][si],
|
212
|
-
ops.scale_image(
|
213
|
-
pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
|
214
|
-
pbatch["ori_shape"],
|
215
|
-
ratio_pad=batch["ratio_pad"][si],
|
216
|
-
),
|
217
|
-
)
|
218
|
-
if self.args.save_txt:
|
219
|
-
self.save_one_txt(
|
220
|
-
predn,
|
221
|
-
pred_masks,
|
222
|
-
self.args.save_conf,
|
223
|
-
pbatch["ori_shape"],
|
224
|
-
self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
|
225
|
-
)
|
226
|
-
|
227
|
-
def _process_batch(
|
228
|
-
self,
|
229
|
-
detections: torch.Tensor,
|
230
|
-
gt_bboxes: torch.Tensor,
|
231
|
-
gt_cls: torch.Tensor,
|
232
|
-
pred_masks: Optional[torch.Tensor] = None,
|
233
|
-
gt_masks: Optional[torch.Tensor] = None,
|
234
|
-
overlap: Optional[bool] = False,
|
235
|
-
masks: Optional[bool] = False,
|
236
|
-
) -> torch.Tensor:
|
154
|
+
predn["coco_masks"] = coco_masks
|
155
|
+
return predn
|
156
|
+
|
157
|
+
def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
|
237
158
|
"""
|
238
159
|
Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
239
160
|
|
240
161
|
Args:
|
241
|
-
|
242
|
-
|
243
|
-
gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
|
244
|
-
Each row is of the format [x1, y1, x2, y2].
|
245
|
-
gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
|
246
|
-
pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
|
247
|
-
match the ground truth masks.
|
248
|
-
gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
|
249
|
-
overlap (bool, optional): Flag indicating if overlapping masks should be considered.
|
250
|
-
masks (bool, optional): Flag indicating if the batch contains mask data.
|
162
|
+
preds (Dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
|
163
|
+
batch (Dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.
|
251
164
|
|
252
165
|
Returns:
|
253
|
-
(
|
166
|
+
(Dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
|
254
167
|
|
255
168
|
Notes:
|
256
169
|
- If `masks` is True, the function computes IoU between predicted and ground truth masks.
|
257
170
|
- If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
|
258
171
|
|
259
172
|
Examples:
|
260
|
-
>>>
|
261
|
-
>>>
|
262
|
-
>>>
|
263
|
-
>>> correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
|
173
|
+
>>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
|
174
|
+
>>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
|
175
|
+
>>> correct_preds = validator._process_batch(preds, batch)
|
264
176
|
"""
|
265
|
-
|
266
|
-
|
177
|
+
tp = super()._process_batch(preds, batch)
|
178
|
+
gt_cls, gt_masks = batch["cls"], batch["masks"]
|
179
|
+
if len(gt_cls) == 0 or len(preds["cls"]) == 0:
|
180
|
+
tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
|
181
|
+
else:
|
182
|
+
pred_masks = preds["masks"]
|
183
|
+
if self.args.overlap_mask:
|
267
184
|
nl = len(gt_cls)
|
268
185
|
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
|
269
186
|
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
|
@@ -272,60 +189,32 @@ class SegmentationValidator(DetectionValidator):
|
|
272
189
|
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
|
273
190
|
gt_masks = gt_masks.gt_(0.5)
|
274
191
|
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
192
|
+
tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
|
193
|
+
tp.update({"tp_m": tp_m}) # update tp with mask IoU
|
194
|
+
return tp
|
279
195
|
|
280
|
-
def
|
281
|
-
"""
|
282
|
-
Plot validation samples with bounding box labels and masks.
|
283
|
-
|
284
|
-
Args:
|
285
|
-
batch (Dict[str, Any]): Batch containing images and annotations.
|
286
|
-
ni (int): Batch index.
|
287
|
-
"""
|
288
|
-
plot_images(
|
289
|
-
batch["img"],
|
290
|
-
batch["batch_idx"],
|
291
|
-
batch["cls"].squeeze(-1),
|
292
|
-
batch["bboxes"],
|
293
|
-
masks=batch["masks"],
|
294
|
-
paths=batch["im_file"],
|
295
|
-
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
296
|
-
names=self.names,
|
297
|
-
on_plot=self.on_plot,
|
298
|
-
)
|
299
|
-
|
300
|
-
def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
|
196
|
+
def plot_predictions(self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int) -> None:
|
301
197
|
"""
|
302
198
|
Plot batch predictions with masks and bounding boxes.
|
303
199
|
|
304
200
|
Args:
|
305
201
|
batch (Dict[str, Any]): Batch containing images and annotations.
|
306
|
-
preds (List[torch.Tensor]): List of predictions from the model.
|
202
|
+
preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
|
307
203
|
ni (int): Batch index.
|
308
204
|
"""
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
) # pred
|
318
|
-
self.plot_masks.clear()
|
319
|
-
|
320
|
-
def save_one_txt(
|
321
|
-
self, predn: torch.Tensor, pred_masks: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path
|
322
|
-
) -> None:
|
205
|
+
for p in preds:
|
206
|
+
masks = p["masks"]
|
207
|
+
if masks.shape[0] > 50:
|
208
|
+
LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
|
209
|
+
p["masks"] = torch.as_tensor(masks[:50], dtype=torch.uint8).cpu()
|
210
|
+
super().plot_predictions(batch, preds, ni, max_det=50) # plot bboxes
|
211
|
+
|
212
|
+
def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
|
323
213
|
"""
|
324
214
|
Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
325
215
|
|
326
216
|
Args:
|
327
217
|
predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
|
328
|
-
pred_masks (torch.Tensor): Predicted masks.
|
329
218
|
save_conf (bool): Whether to save confidence scores.
|
330
219
|
shape (Tuple[int, int]): Shape of the original image.
|
331
220
|
file (Path): File path to save the detections.
|
@@ -336,18 +225,17 @@ class SegmentationValidator(DetectionValidator):
|
|
336
225
|
np.zeros((shape[0], shape[1]), dtype=np.uint8),
|
337
226
|
path=None,
|
338
227
|
names=self.names,
|
339
|
-
boxes=predn[
|
340
|
-
masks=
|
228
|
+
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
|
229
|
+
masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
|
341
230
|
).save_txt(file, save_conf=save_conf)
|
342
231
|
|
343
|
-
def pred_to_json(self, predn: torch.Tensor, filename: str
|
232
|
+
def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
|
344
233
|
"""
|
345
234
|
Save one JSON result for COCO evaluation.
|
346
235
|
|
347
236
|
Args:
|
348
|
-
predn (torch.Tensor): Predictions
|
237
|
+
predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
|
349
238
|
filename (str): Image filename.
|
350
|
-
pred_masks (numpy.ndarray): Predicted masks.
|
351
239
|
|
352
240
|
Examples:
|
353
241
|
>>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
|
@@ -362,18 +250,18 @@ class SegmentationValidator(DetectionValidator):
|
|
362
250
|
|
363
251
|
stem = Path(filename).stem
|
364
252
|
image_id = int(stem) if stem.isnumeric() else stem
|
365
|
-
box = ops.xyxy2xywh(predn[
|
253
|
+
box = ops.xyxy2xywh(predn["bboxes"]) # xywh
|
366
254
|
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
367
|
-
pred_masks = np.transpose(
|
255
|
+
pred_masks = np.transpose(predn["coco_masks"], (2, 0, 1))
|
368
256
|
with ThreadPool(NUM_THREADS) as pool:
|
369
257
|
rles = pool.map(single_encode, pred_masks)
|
370
|
-
for i, (
|
258
|
+
for i, (b, s, c) in enumerate(zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist())):
|
371
259
|
self.jdict.append(
|
372
260
|
{
|
373
261
|
"image_id": image_id,
|
374
|
-
"category_id": self.class_map[int(
|
262
|
+
"category_id": self.class_map[int(c)],
|
375
263
|
"bbox": [round(x, 3) for x in b],
|
376
|
-
"score": round(
|
264
|
+
"score": round(s, 5),
|
377
265
|
"segmentation": rles[i],
|
378
266
|
}
|
379
267
|
)
|
@@ -1,6 +1,7 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
from copy import deepcopy
|
4
|
+
from pathlib import Path
|
4
5
|
from typing import Any, Dict, Optional, Union
|
5
6
|
|
6
7
|
import torch
|
@@ -182,7 +183,7 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
182
183
|
assert load_vp, "Refer data is only used for visual prompt validation."
|
183
184
|
self.device = select_device(self.args.device)
|
184
185
|
|
185
|
-
if isinstance(model, str):
|
186
|
+
if isinstance(model, (str, Path)):
|
186
187
|
from ultralytics.nn.tasks import attempt_load_weights
|
187
188
|
|
188
189
|
model = attempt_load_weights(model, device=self.device, inplace=True)
|
ultralytics/nn/autobackend.py
CHANGED
@@ -196,9 +196,9 @@ class AutoBackend(nn.Module):
|
|
196
196
|
|
197
197
|
# In-memory PyTorch model
|
198
198
|
if nn_module:
|
199
|
-
if fuse:
|
200
|
-
weights = weights.fuse(verbose=verbose) # fuse before move to gpu
|
201
199
|
model = weights.to(device)
|
200
|
+
if fuse:
|
201
|
+
model = model.fuse(verbose=verbose)
|
202
202
|
if hasattr(model, "kpt_shape"):
|
203
203
|
kpt_shape = model.kpt_shape # pose-only
|
204
204
|
stride = max(int(model.stride.max()), 32) # model stride
|
ultralytics/solutions/ai_gym.py
CHANGED
@@ -76,11 +76,10 @@ class AIGym(BaseSolution):
|
|
76
76
|
self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks)
|
77
77
|
|
78
78
|
if len(self.boxes):
|
79
|
-
kpt_data = self.tracks.keypoints.data
|
79
|
+
kpt_data = self.tracks.keypoints.data
|
80
80
|
|
81
81
|
for i, k in enumerate(kpt_data):
|
82
|
-
|
83
|
-
state = self.states[track_id] # get state details
|
82
|
+
state = self.states[self.track_ids[i]] # get state details
|
84
83
|
# Get keypoints and estimate the angle
|
85
84
|
state["angle"] = annotator.estimate_pose_angle(*[k[int(idx)] for idx in self.kpts])
|
86
85
|
annotator.draw_specific_kpts(k, self.kpts, radius=self.line_width * 3)
|
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
import math
|
4
4
|
from collections import defaultdict
|
5
|
+
from functools import lru_cache
|
5
6
|
from typing import Any, Dict, List, Optional, Tuple
|
6
7
|
|
7
8
|
import cv2
|
@@ -423,6 +424,7 @@ class SolutionAnnotator(Annotator):
|
|
423
424
|
text_y_offset = rect_y2
|
424
425
|
|
425
426
|
@staticmethod
|
427
|
+
@lru_cache(maxsize=256)
|
426
428
|
def estimate_pose_angle(a: List[float], b: List[float], c: List[float]) -> float:
|
427
429
|
"""
|
428
430
|
Calculate the angle between three points for workout monitoring.
|
@@ -126,6 +126,20 @@
|
|
126
126
|
}
|
127
127
|
</style>
|
128
128
|
</head>
|
129
|
+
<script>
|
130
|
+
function filterResults(k) {
|
131
|
+
const cards = document.querySelectorAll(".grid .card");
|
132
|
+
cards.forEach((card, idx) => {
|
133
|
+
card.style.display = idx < k ? "block" : "none";
|
134
|
+
});
|
135
|
+
const buttons = document.querySelectorAll(".topk-btn");
|
136
|
+
buttons.forEach((btn) => btn.classList.remove("active"));
|
137
|
+
event.target.classList.add("active");
|
138
|
+
}
|
139
|
+
document.addEventListener("DOMContentLoaded", () => {
|
140
|
+
filterResults(10);
|
141
|
+
});
|
142
|
+
</script>
|
129
143
|
<body>
|
130
144
|
<div style="text-align: center; margin-bottom: 1rem">
|
131
145
|
<img
|
@@ -146,6 +160,23 @@
|
|
146
160
|
required
|
147
161
|
/>
|
148
162
|
<button type="submit">Search</button>
|
163
|
+
{% if results %}
|
164
|
+
<div class="top-k-buttons">
|
165
|
+
<button type="button" class="topk-btn" onclick="filterResults(5)">
|
166
|
+
Top 5
|
167
|
+
</button>
|
168
|
+
<button
|
169
|
+
type="button"
|
170
|
+
class="topk-btn active"
|
171
|
+
onclick="filterResults(10)"
|
172
|
+
>
|
173
|
+
Top 10
|
174
|
+
</button>
|
175
|
+
<button type="button" class="topk-btn" onclick="filterResults(30)">
|
176
|
+
Top 30
|
177
|
+
</button>
|
178
|
+
</div>
|
179
|
+
{% endif %}
|
149
180
|
</form>
|
150
181
|
|
151
182
|
<!-- Search results grid -->
|
@@ -457,7 +457,7 @@ def _log_plots(experiment, trainer) -> None:
|
|
457
457
|
>>> _log_plots(experiment, trainer)
|
458
458
|
"""
|
459
459
|
plot_filenames = None
|
460
|
-
if isinstance(trainer.validator.metrics, SegmentMetrics)
|
460
|
+
if isinstance(trainer.validator.metrics, SegmentMetrics):
|
461
461
|
plot_filenames = [
|
462
462
|
trainer.save_dir / f"{prefix}{plots}.png"
|
463
463
|
for plots in EVALUATION_PLOT_NAMES
|