dgenerate-ultralytics-headless 8.3.152-py3-none-any.whl → 8.3.154-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/RECORD +30 -30
  3. tests/test_python.py +1 -0
  4. ultralytics/__init__.py +1 -1
  5. ultralytics/cfg/__init__.py +2 -0
  6. ultralytics/engine/predictor.py +1 -1
  7. ultralytics/engine/validator.py +0 -6
  8. ultralytics/models/fastsam/val.py +0 -2
  9. ultralytics/models/rtdetr/val.py +28 -16
  10. ultralytics/models/yolo/classify/val.py +26 -23
  11. ultralytics/models/yolo/detect/train.py +4 -7
  12. ultralytics/models/yolo/detect/val.py +88 -90
  13. ultralytics/models/yolo/obb/val.py +52 -44
  14. ultralytics/models/yolo/pose/train.py +1 -35
  15. ultralytics/models/yolo/pose/val.py +77 -176
  16. ultralytics/models/yolo/segment/train.py +1 -41
  17. ultralytics/models/yolo/segment/val.py +64 -176
  18. ultralytics/models/yolo/yoloe/val.py +2 -1
  19. ultralytics/nn/autobackend.py +2 -2
  20. ultralytics/solutions/ai_gym.py +2 -3
  21. ultralytics/solutions/solutions.py +2 -0
  22. ultralytics/solutions/templates/similarity-search.html +31 -0
  23. ultralytics/utils/callbacks/comet.py +1 -1
  24. ultralytics/utils/metrics.py +152 -307
  25. ultralytics/utils/ops.py +4 -4
  26. ultralytics/utils/plotting.py +31 -56
  27. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/WHEEL +0 -0
  28. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/entry_points.txt +0 -0
  29. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/licenses/LICENSE +0 -0
  30. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/segment/val.py
@@ -2,7 +2,7 @@

 from multiprocessing.pool import ThreadPool
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple

 import numpy as np
 import torch
@@ -11,8 +11,7 @@ import torch.nn.functional as F
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, NUM_THREADS, ops
 from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import SegmentMetrics, mask_iou


 class SegmentationValidator(DetectionValidator):
@@ -47,10 +46,9 @@ class SegmentationValidator(DetectionValidator):
             _callbacks (list, optional): List of callback functions.
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
-        self.plot_masks = None
         self.process = None
         self.args.task = "segment"
-        self.metrics = SegmentMetrics(save_dir=self.save_dir)
+        self.metrics = SegmentMetrics()

     def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -74,12 +72,10 @@ class SegmentationValidator(DetectionValidator):
             model (torch.nn.Module): Model to validate.
         """
         super().init_metrics(model)
-        self.plot_masks = []
         if self.args.save_json:
             check_requirements("pycocotools>=2.0.6")
         # More accurate vs faster
         self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
-        self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])

     def get_desc(self) -> str:
         """Return a formatted description of evaluation metrics."""
@@ -97,7 +93,7 @@ class SegmentationValidator(DetectionValidator):
             "mAP50-95)",
         )

-    def postprocess(self, preds: List[torch.Tensor]) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    def postprocess(self, preds: List[torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
         """
         Post-process YOLO predictions and return output detections with proto.

@@ -105,12 +101,19 @@ class SegmentationValidator(DetectionValidator):
             preds (List[torch.Tensor]): Raw predictions from the model.

         Returns:
-            p (List[torch.Tensor]): Processed detection predictions.
-            proto (torch.Tensor): Prototype masks for segmentation.
+            List[Dict[str, torch.Tensor]]: Processed detection predictions with masks.
         """
-        p = super().postprocess(preds[0])
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
-        return p, proto
+        preds = super().postprocess(preds[0])
+        imgsz = [4 * x for x in proto.shape[2:]]  # get image size from proto
+        for i, pred in enumerate(preds):
+            coefficient = pred.pop("extra")
+            pred["masks"] = (
+                self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
+                if len(coefficient)
+                else torch.zeros((0, imgsz[0], imgsz[1]), dtype=torch.uint8, device=pred["bboxes"].device)
+            )
+        return preds

     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -128,142 +131,56 @@ class SegmentationValidator(DetectionValidator):
             prepared_batch["masks"] = batch["masks"][midx]
         return prepared_batch

-    def _prepare_pred(
-        self, pred: torch.Tensor, pbatch: Dict[str, Any], proto: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions for evaluation by processing bounding boxes and masks.

         Args:
-            pred (torch.Tensor): Raw predictions from the model.
+            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
             pbatch (Dict[str, Any]): Prepared batch information.
-            proto (torch.Tensor): Prototype masks for segmentation.

         Returns:
-            predn (torch.Tensor): Processed bounding box predictions.
-            pred_masks (torch.Tensor): Processed mask predictions.
+            Dict[str, torch.Tensor]: Processed bounding box predictions.
         """
         predn = super()._prepare_pred(pred, pbatch)
-        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
-        return predn, pred_masks
-
-    def update_metrics(self, preds: Tuple[List[torch.Tensor], torch.Tensor], batch: Dict[str, Any]) -> None:
-        """
-        Update metrics with the current batch predictions and targets.
-
-        Args:
-            preds (Tuple[List[torch.Tensor], torch.Tensor]): List of predictions from the model.
-            batch (Dict[str, Any]): Batch data containing ground truth.
-        """
-        for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+        predn["masks"] = pred["masks"]
+        if self.args.save_json and len(predn["masks"]):
+            coco_masks = torch.as_tensor(pred["masks"], dtype=torch.uint8)
+            coco_masks = ops.scale_image(
+                coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+                pbatch["ori_shape"],
+                ratio_pad=pbatch["ratio_pad"],
             )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Masks
-            gt_masks = pbatch.pop("masks")
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_m"] = self._process_batch(
-                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
-                )
-                if self.args.plots:
-                    self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
-            if self.args.plots and self.batch_i < 3:
-                self.plot_masks.append(pred_masks[:50].cpu())  # Limit plotted items for speed
-                if pred_masks.shape[0] > 50:
-                    LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(
-                    predn,
-                    batch["im_file"][si],
-                    ops.scale_image(
-                        pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-                        pbatch["ori_shape"],
-                        ratio_pad=batch["ratio_pad"][si],
-                    ),
-                )
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_masks,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def _process_batch(
-        self,
-        detections: torch.Tensor,
-        gt_bboxes: torch.Tensor,
-        gt_cls: torch.Tensor,
-        pred_masks: Optional[torch.Tensor] = None,
-        gt_masks: Optional[torch.Tensor] = None,
-        overlap: Optional[bool] = False,
-        masks: Optional[bool] = False,
-    ) -> torch.Tensor:
+        predn["coco_masks"] = coco_masks
+        return predn
+
+    def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         Compute correct prediction matrix for a batch based on bounding boxes and optional masks.

         Args:
-            detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
-                associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
-            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
-                Each row is of the format [x1, y1, x2, y2].
-            gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
-            pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
-                match the ground truth masks.
-            gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
-            overlap (bool, optional): Flag indicating if overlapping masks should be considered.
-            masks (bool, optional): Flag indicating if the batch contains mask data.
+            preds (Dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
+            batch (Dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.

         Returns:
-            (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
+            (Dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.

         Notes:
             - If `masks` is True, the function computes IoU between predicted and ground truth masks.
             - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.

         Examples:
-            >>> detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
-            >>> gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
-            >>> gt_cls = torch.tensor([1, 0])
-            >>> correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
+            >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
+            >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
+            >>> correct_preds = validator._process_batch(preds, batch)
         """
-        if masks:
-            if overlap:
+        tp = super()._process_batch(preds, batch)
+        gt_cls, gt_masks = batch["cls"], batch["masks"]
+        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
+            tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        else:
+            pred_masks = preds["masks"]
+            if self.args.overlap_mask:
                 nl = len(gt_cls)
                 index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
                 gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
@@ -272,60 +189,32 @@ class SegmentationValidator(DetectionValidator):
                 gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
                 gt_masks = gt_masks.gt_(0.5)
             iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
-
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
+            tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_m": tp_m})  # update tp with mask IoU
+        return tp

-    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
-        """
-        Plot validation samples with bounding box labels and masks.
-
-        Args:
-            batch (Dict[str, Any]): Batch containing images and annotations.
-            ni (int): Batch index.
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            masks=batch["masks"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
-
-    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
+    def plot_predictions(self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int) -> None:
         """
         Plot batch predictions with masks and bounding boxes.

         Args:
             batch (Dict[str, Any]): Batch containing images and annotations.
-            preds (List[torch.Tensor]): List of predictions from the model.
+            preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.
             ni (int): Batch index.
         """
-        plot_images(
-            batch["img"],
-            *output_to_target(preds[0], max_det=50),  # not set to self.args.max_det due to slow plotting speed
-            torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
-        self.plot_masks.clear()
-
-    def save_one_txt(
-        self, predn: torch.Tensor, pred_masks: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path
-    ) -> None:
+        for p in preds:
+            masks = p["masks"]
+            if masks.shape[0] > 50:
+                LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
+            p["masks"] = torch.as_tensor(masks[:50], dtype=torch.uint8).cpu()
+        super().plot_predictions(batch, preds, ni, max_det=50)  # plot bboxes
+
+    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         Save YOLO detections to a txt file in normalized coordinates in a specific format.

         Args:
             predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
-            pred_masks (torch.Tensor): Predicted masks.
             save_conf (bool): Whether to save confidence scores.
             shape (Tuple[int, int]): Shape of the original image.
             file (Path): File path to save the detections.
@@ -336,18 +225,17 @@ class SegmentationValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            masks=pred_masks,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
         ).save_txt(file, save_conf=save_conf)

-    def pred_to_json(self, predn: torch.Tensor, filename: str, pred_masks: torch.Tensor) -> None:
+    def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
         """
         Save one JSON result for COCO evaluation.

         Args:
-            predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
+            predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
             filename (str): Image filename.
-            pred_masks (numpy.ndarray): Predicted masks.

         Examples:
             >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
@@ -362,18 +250,18 @@ class SegmentationValidator(DetectionValidator):

         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        pred_masks = np.transpose(pred_masks, (2, 0, 1))
+        pred_masks = np.transpose(predn["coco_masks"], (2, 0, 1))
         with ThreadPool(NUM_THREADS) as pool:
             rles = pool.map(single_encode, pred_masks)
-        for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
+        for i, (b, s, c) in enumerate(zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist())):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(c)],
                     "bbox": [round(x, 3) for x in b],
-                    "score": round(p[4], 5),
+                    "score": round(s, 5),
                     "segmentation": rles[i],
                 }
             )
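Across the segment/val.py hunks, `_process_batch` now gets box true-positives from the parent class and adds a `tp_m` entry computed by `mask_iou` over flattened masks. A self-contained re-implementation of that pairwise mask IoU, for illustration only (a sketch of the idea, not the `ultralytics.utils.metrics.mask_iou` source):

```python
import torch

def mask_iou_sketch(gt: torch.Tensor, pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """gt: (M, H*W) and pred: (N, H*W) binary masks -> (M, N) pairwise IoU matrix."""
    inter = torch.matmul(gt.float(), pred.float().T)  # pairwise intersection pixel counts
    union = gt.sum(1)[:, None] + pred.sum(1)[None, :] - inter  # |A| + |B| - |A∩B|
    return inter / (union + eps)

# Flatten (n, H, W) masks to (n, H*W), as the validator does before calling mask_iou
gt = (torch.rand(2, 640, 640) > 0.5).view(2, -1)
pred = (torch.rand(3, 640, 640) > 0.5).view(3, -1)
print(mask_iou_sketch(gt, pred).shape)  # torch.Size([2, 3])
```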
ultralytics/models/yolo/yoloe/val.py
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from copy import deepcopy
+from pathlib import Path
 from typing import Any, Dict, Optional, Union

 import torch
@@ -182,7 +183,7 @@ class YOLOEDetectValidator(DetectionValidator):
         assert load_vp, "Refer data is only used for visual prompt validation."
         self.device = select_device(self.args.device)

-        if isinstance(model, str):
+        if isinstance(model, (str, Path)):
             from ultralytics.nn.tasks import attempt_load_weights

             model = attempt_load_weights(model, device=self.device, inplace=True)
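The widened check matters because callers can pass the weights location as a `pathlib.Path`: under the old `isinstance(model, str)` test a `Path` fell through and was treated as an already-loaded module instead of being routed to `attempt_load_weights`. A quick illustration (the weights filename is hypothetical):

```python
from pathlib import Path

for model in ("yoloe-v8s.pt", Path("yoloe-v8s.pt")):  # hypothetical weights name
    print(type(model).__name__, isinstance(model, str), isinstance(model, (str, Path)))
# str True True
# PosixPath False True  <- only the widened check catches Path objects (WindowsPath on Windows)
```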
ultralytics/nn/autobackend.py
@@ -196,9 +196,9 @@ class AutoBackend(nn.Module):

         # In-memory PyTorch model
         if nn_module:
-            if fuse:
-                weights = weights.fuse(verbose=verbose)  # fuse before move to gpu
             model = weights.to(device)
+            if fuse:
+                model = model.fuse(verbose=verbose)
             if hasattr(model, "kpt_shape"):
                 kpt_shape = model.kpt_shape  # pose-only
             stride = max(int(model.stride.max()), 32)  # model stride
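The reorder means layer fusion now runs after the module has been moved to the target device, and the fused result is assigned to `model` instead of mutating the caller's `weights`. A hedged sketch of the same move-then-fuse pattern on a toy Conv+BN module (`TinyModel` and its `fuse` method are illustrative stand-ins, not Ultralytics code):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Toy stand-in for a fusable model: one Conv2d followed by BatchNorm2d."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, bias=False)
        self.bn = nn.BatchNorm2d(8)

    def fuse(self, verbose: bool = False):
        # Fold the BN statistics into the conv weights (standard Conv+BN fusion)
        self.conv = torch.nn.utils.fuse_conv_bn_eval(self.conv.eval(), self.bn.eval())
        self.bn = nn.Identity()
        return self

    def forward(self, x):
        return self.bn(self.conv(x))

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyModel().to(device)  # move first, as the new AutoBackend code does
model = model.fuse()  # then fuse: the fused parameters live on `device`
print(next(model.parameters()).device)
```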
ultralytics/solutions/ai_gym.py
@@ -76,11 +76,10 @@ class AIGym(BaseSolution):
         self.extract_tracks(im0)  # Extract tracks (bounding boxes, classes, and masks)

         if len(self.boxes):
-            kpt_data = self.tracks.keypoints.data.cpu()  # Avoid repeated .cpu() calls
+            kpt_data = self.tracks.keypoints.data

             for i, k in enumerate(kpt_data):
-                track_id = int(self.track_ids[i])  # get track id
-                state = self.states[track_id]  # get state details
+                state = self.states[self.track_ids[i]]  # get state details
                 # Get keypoints and estimate the angle
                 state["angle"] = annotator.estimate_pose_angle(*[k[int(idx)] for idx in self.kpts])
                 annotator.draw_specific_kpts(k, self.kpts, radius=self.line_width * 3)
ultralytics/solutions/solutions.py
@@ -2,6 +2,7 @@

 import math
 from collections import defaultdict
+from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple

 import cv2
@@ -423,6 +424,7 @@ class SolutionAnnotator(Annotator):
             text_y_offset = rect_y2

     @staticmethod
+    @lru_cache(maxsize=256)
     def estimate_pose_angle(a: List[float], b: List[float], c: List[float]) -> float:
         """
         Calculate the angle between three points for workout monitoring.
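The decorator memoizes repeated angle computations for identical keypoint triples. One caveat worth noting: `functools.lru_cache` only accepts hashable arguments, so callers must pass tuples or tensors rather than plain lists (a list argument raises `TypeError`). A standalone sketch of the pattern using the standard three-point angle formula (an illustrative function, not the annotator method):

```python
import math
from functools import lru_cache

@lru_cache(maxsize=256)
def angle_deg(a: tuple, b: tuple, c: tuple) -> float:
    """Angle at vertex b formed by points a-b-c, in degrees."""
    radians = math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0])
    deg = abs(radians * 180.0 / math.pi)
    return 360.0 - deg if deg > 180.0 else deg

print(angle_deg((0, 0), (1, 0), (1, 1)))  # 90.0
print(angle_deg.cache_info().hits)  # 0 on the first call; identical triples hit the cache
# angle_deg([0, 0], [1, 0], [1, 1]) would raise TypeError: unhashable type: 'list'
```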
ultralytics/solutions/templates/similarity-search.html
@@ -126,6 +126,20 @@
       }
     </style>
   </head>
+  <script>
+    function filterResults(k) {
+      const cards = document.querySelectorAll(".grid .card");
+      cards.forEach((card, idx) => {
+        card.style.display = idx < k ? "block" : "none";
+      });
+      const buttons = document.querySelectorAll(".topk-btn");
+      buttons.forEach((btn) => btn.classList.remove("active"));
+      event.target.classList.add("active");
+    }
+    document.addEventListener("DOMContentLoaded", () => {
+      filterResults(10);
+    });
+  </script>
   <body>
     <div style="text-align: center; margin-bottom: 1rem">
       <img
@@ -146,6 +160,23 @@
           required
         />
         <button type="submit">Search</button>
+        {% if results %}
+        <div class="top-k-buttons">
+          <button type="button" class="topk-btn" onclick="filterResults(5)">
+            Top 5
+          </button>
+          <button
+            type="button"
+            class="topk-btn active"
+            onclick="filterResults(10)"
+          >
+            Top 10
+          </button>
+          <button type="button" class="topk-btn" onclick="filterResults(30)">
+            Top 30
+          </button>
+        </div>
+        {% endif %}
       </form>

       <!-- Search results grid -->
ultralytics/utils/callbacks/comet.py
@@ -457,7 +457,7 @@ def _log_plots(experiment, trainer) -> None:
     >>> _log_plots(experiment, trainer)
     """
     plot_filenames = None
-    if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment":
+    if isinstance(trainer.validator.metrics, SegmentMetrics):
         plot_filenames = [
             trainer.save_dir / f"{prefix}{plots}.png"
             for plots in EVALUATION_PLOT_NAMES
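A hedged note on why the extra `task == "segment"` guard could be dropped (my reading, not stated in the diff): in earlier releases `PoseMetrics` subclassed `SegmentMetrics`, so a bare `isinstance` check also matched pose validators; after the metrics.py refactor in this release the subclass relationship appears to be gone, leaving `isinstance` alone sufficient. A toy illustration of the old pitfall, with the class hierarchy asserted only as an assumption about the pre-refactor code:

```python
# Assumed pre-refactor hierarchy (illustrative only, not the package source)
class SegmentMetrics:
    task = "segment"

class PoseMetrics(SegmentMetrics):  # assumption: the old subclass relationship
    task = "pose"

metrics = PoseMetrics()
print(isinstance(metrics, SegmentMetrics))  # True -> the old code needed the extra task check
```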