ultralytics 8.3.97__py3-none-any.whl → 8.3.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_python.py +56 -0
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/data/augment.py +101 -5
- ultralytics/data/dataset.py +165 -12
- ultralytics/engine/exporter.py +13 -13
- ultralytics/engine/trainer.py +16 -7
- ultralytics/models/__init__.py +2 -2
- ultralytics/models/nas/model.py +1 -0
- ultralytics/models/nas/predict.py +4 -24
- ultralytics/models/nas/val.py +1 -4
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/detect/val.py +6 -1
- ultralytics/models/yolo/model.py +182 -3
- ultralytics/models/yolo/segment/val.py +43 -16
- ultralytics/models/yolo/yoloe/__init__.py +21 -0
- ultralytics/models/yolo/yoloe/predict.py +170 -0
- ultralytics/models/yolo/yoloe/train.py +355 -0
- ultralytics/models/yolo/yoloe/train_seg.py +141 -0
- ultralytics/models/yolo/yoloe/val.py +187 -0
- ultralytics/nn/autobackend.py +3 -2
- ultralytics/nn/modules/__init__.py +18 -1
- ultralytics/nn/modules/block.py +17 -1
- ultralytics/nn/modules/head.py +359 -22
- ultralytics/nn/tasks.py +276 -10
- ultralytics/nn/text_model.py +193 -0
- ultralytics/utils/callbacks/comet.py +3 -6
- ultralytics/utils/downloads.py +6 -2
- ultralytics/utils/instance.py +7 -2
- ultralytics/utils/loss.py +67 -6
- ultralytics/utils/plotting.py +1 -1
- ultralytics/utils/tal.py +1 -1
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/METADATA +69 -67
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/RECORD +41 -31
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0
ultralytics/data/dataset.py
CHANGED
@@ -13,7 +13,7 @@ from PIL import Image
 from torch.utils.data import ConcatDataset
 
 from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
-from ultralytics.utils.ops import resample_segments
+from ultralytics.utils.ops import resample_segments, segments2boxes
 from ultralytics.utils.torch_utils import TORCHVISION_0_18
 
 from .augment import (
@@ -27,6 +27,7 @@ from .augment import (
     v8_transforms,
 )
 from .base import BaseDataset
+from .converter import merge_multi_segment
 from .utils import (
     HELP_URL,
     LOGGER,
@@ -289,12 +290,15 @@ class YOLODataset(BaseDataset):
             (dict): Collated batch with stacked tensors.
         """
         new_batch = {}
+        batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
         keys = batch[0].keys()
         values = list(zip(*[list(b.values()) for b in batch]))
         for i, k in enumerate(keys):
            value = values[i]
-            if k == "img":
+            if k == "img" or k == "text_feats":
                 value = torch.stack(value, 0)
+            elif k == "visuals":
+                value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
             if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                 value = torch.cat(value, 0)
             new_batch[k] = value
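The collate change is the interesting part here: `text_feats` tensors share one shape across a batch and can be stacked like `img`, while `visuals` holds a variable number of visual prompts per image and must be zero-padded to a common length first. A standalone sketch of the distinction (tensor shapes are illustrative, not taken from the package):

import torch
from torch.nn.utils.rnn import pad_sequence

# Per-image visual prompts with different counts (2 vs. 4 prompts of dim 6).
visuals = [torch.ones(2, 6), torch.ones(4, 6)]
# torch.stack(visuals, 0) would raise: stack requires identical shapes.
padded = pad_sequence(visuals, batch_first=True)
print(padded.shape)  # torch.Size([2, 4, 6]), shorter entries zero-padded

# Uniform entries like "img" stack directly.
imgs = [torch.zeros(3, 640, 640), torch.zeros(3, 640, 640)]
print(torch.stack(imgs, 0).shape)  # torch.Size([2, 3, 640, 640])

Sorting each sample's keys first guarantees every dict yields its values in the same order before `zip` transposes them.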
@@ -346,7 +350,9 @@ class YOLOMultiModalDataset(YOLODataset):
         """
         labels = super().update_labels_info(label)
         # NOTE: some categories are concatenated with its synonyms by `/`.
+        # NOTE: and `RandomLoadText` would randomly select one of them if there are multiple words.
         labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
+
         return labels
 
     def build_transforms(self, hyp=None):
@@ -362,9 +368,46 @@ class YOLOMultiModalDataset(YOLODataset):
         transforms = super().build_transforms(hyp)
         if self.augment:
             # NOTE: hard-coded the args for now.
-            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
+            # NOTE: this implementation is different from official yoloe,
+            # the strategy of selecting negative is restricted in one dataset,
+            # while official pre-saved neg embeddings from all datasets at once.
+            transform = RandomLoadText(
+                max_samples=min(self.data["nc"], 80),
+                padding=True,
+                padding_value=self._get_neg_texts(self.category_freq),
+            )
+            transforms.insert(-1, transform)
         return transforms
 
+    @property
+    def category_names(self):
+        """
+        Return category names for the dataset.
+
+        Returns:
+            (Tuple[str]): List of class names.
+        """
+        names = self.data["names"].values()
+        return {n.strip() for name in names for n in name.split("/")}  # category names
+
+    @property
+    def category_freq(self):
+        """Return frequency of each category in the dataset."""
+        texts = [v.split("/") for v in self.data["names"].values()]
+        category_freq = defaultdict(int)
+        for label in self.labels:
+            for c in label["cls"]:  # to check
+                text = texts[int(c)]
+                for t in text:
+                    t = t.strip()
+                    category_freq[t] += 1
+        return category_freq
+
+    @staticmethod
+    def _get_neg_texts(category_freq, threshold=100):
+        """Get negative text samples based on frequency threshold."""
+        return [k for k, v in category_freq.items() if v >= threshold]
+
 
 class GroundingDataset(YOLODataset):
     """
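`RandomLoadText` is now given a `padding_value` drawn from `_get_neg_texts`, so prompt lists are padded with real category names that occur at least `threshold` times rather than with empty strings. The selection logic in isolation (toy data, hypothetical names):

from collections import defaultdict

per_image_names = [["person", "car"], ["person"], ["person", "dog"]]  # toy labels

category_freq = defaultdict(int)
for names in per_image_names:
    for name in names:
        category_freq[name.strip()] += 1

def get_neg_texts(category_freq, threshold=2):
    # Only frequently-seen categories are used as padding negatives.
    return [k for k, v in category_freq.items() if v >= threshold]

print(get_neg_texts(category_freq))  # ['person']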
@@ -386,17 +429,17 @@ class GroundingDataset(YOLODataset):
         >>> len(dataset)  # Number of valid images with annotations
         """
 
-    def __init__(self, *args, task="detect", json_file, **kwargs):
+    def __init__(self, *args, task="detect", json_file="", **kwargs):
         """
         Initialize a GroundingDataset for object detection.
 
         Args:
             json_file (str): Path to the JSON file containing annotations.
-            task (str): Must be 'detect' for GroundingDataset.
+            task (str): Must be 'detect' or 'segment' for GroundingDataset.
             *args (Any): Additional positional arguments for the parent class.
             **kwargs (Any): Additional keyword arguments for the parent class.
         """
-        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
+        assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
         super().__init__(*args, task=task, data={}, **kwargs)
 
@@ -412,14 +455,31 @@ class GroundingDataset(YOLODataset):
         """
         return []
 
-    def get_labels(self):
+    def verify_labels(self, labels):
+        """Verify the number of instances in the dataset matches expected counts."""
+        instance_count = sum(label["bboxes"].shape[0] for label in labels)
+        if "final_mixed_train_no_coco_segm" in self.json_file:
+            assert instance_count == 3662344
+        elif "final_mixed_train_no_coco" in self.json_file:
+            assert instance_count == 3681235
+        elif "final_flickr_separateGT_train_segm" in self.json_file:
+            assert instance_count == 638214
+        elif "final_flickr_separateGT_train" in self.json_file:
+            assert instance_count == 640704
+        else:
+            assert False
+
+    def cache_labels(self, path=Path("./labels.cache")):
         """
         Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.
 
+        Args:
+            path (Path): Path where to save the cache file.
+
         Returns:
-            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
+            (dict): Dictionary containing cached labels and related information.
         """
-        labels = []
+        x = {"labels": []}
         LOGGER.info("Loading annotation file...")
         with open(self.json_file) as f:
             annotations = json.load(f)
@@ -435,6 +495,7 @@ class GroundingDataset(YOLODataset):
                 continue
             self.im_files.append(str(im_file))
             bboxes = []
+            segments = []
             cat2id = {}
             texts = []
             for ann in anns:
@@ -448,7 +509,10 @@ class GroundingDataset(YOLODataset):
                     continue
 
                 caption = img["caption"]
-                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]])
+                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]).lower().strip()
+                if not cat_name:
+                    continue
+
                 if cat_name not in cat2id:
                     cat2id[cat_name] = len(cat2id)
                     texts.append([cat_name])
@@ -456,18 +520,66 @@ class GroundingDataset(YOLODataset):
                 box = [cls] + box.tolist()
                 if box not in bboxes:
                     bboxes.append(box)
+                    if ann.get("segmentation") is not None:
+                        if len(ann["segmentation"]) == 0:
+                            segments.append(box)
+                            continue
+                        elif len(ann["segmentation"]) > 1:
+                            s = merge_multi_segment(ann["segmentation"])
+                            s = (np.concatenate(s, axis=0) / np.array([w, h], dtype=np.float32)).reshape(-1).tolist()
+                        else:
+                            s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
+                            s = (
+                                (np.array(s, dtype=np.float32).reshape(-1, 2) / np.array([w, h], dtype=np.float32))
+                                .reshape(-1)
+                                .tolist()
+                            )
+                        s = [cls] + s
+                        segments.append(s)
             lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
-            labels.append(
+
+            if segments:
+                classes = np.array([x[0] for x in segments], dtype=np.float32)
+                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in segments]  # (cls, xy1...)
+                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+            lb = np.array(lb, dtype=np.float32)
+
+            x["labels"].append(
                 {
                     "im_file": im_file,
                     "shape": (h, w),
                     "cls": lb[:, 0:1],  # n, 1
                     "bboxes": lb[:, 1:],  # n, 4
+                    "segments": segments,
                     "normalized": True,
                     "bbox_format": "xywh",
                     "texts": texts,
                 }
             )
+        x["hash"] = get_hash(self.json_file)
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+        return x
+
+    def get_labels(self):
+        """
+        Load labels from cache or generate them from JSON file.
+
+        Returns:
+            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
+        """
+        cache_path = Path(self.json_file).with_suffix(".cache")
+        try:
+            cache, _ = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash(self.json_file)  # identical hash
+        except (FileNotFoundError, AssertionError, AttributeError):
+            cache, _ = self.cache_labels(cache_path), False  # run cache ops
+        [cache.pop(k) for k in ("hash", "version")]  # remove items
+        labels = cache["labels"]
+        # self.verify_labels(labels)
+        self.im_files = [str(label["im_file"]) for label in labels]
+        if LOCAL_RANK in {-1, 0}:
+            LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
         return labels
 
     def build_transforms(self, hyp=None):
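The new `get_labels` gives `GroundingDataset` the same cache discipline as `YOLODataset`: trust the `.cache` file only if its format version matches and the hash of the source JSON is unchanged, otherwise rebuild. A condensed sketch of that pattern with the loaders passed in as parameters (all helper names here are placeholders):

from pathlib import Path

def load_labels(json_file, load_cache, build_cache, get_hash, version="1.0.3"):
    """Return labels from a .cache file, rebuilding it when stale or missing."""
    cache_path = Path(json_file).with_suffix(".cache")
    try:
        cache = load_cache(cache_path)               # may raise FileNotFoundError
        assert cache["version"] == version           # cache format still current
        assert cache["hash"] == get_hash(json_file)  # source JSON unchanged
    except (FileNotFoundError, AssertionError, AttributeError):
        cache = build_cache(cache_path)              # rebuild and re-save
    return cache["labels"]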
@@ -483,9 +595,38 @@ class GroundingDataset(YOLODataset):
         transforms = super().build_transforms(hyp)
         if self.augment:
             # NOTE: hard-coded the args for now.
-            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
+            # NOTE: this implementation is different from official yoloe,
+            # the strategy of selecting negative is restricted in one dataset,
+            # while official pre-saved neg embeddings from all datasets at once.
+            transform = RandomLoadText(
+                max_samples=80,
+                padding=True,
+                padding_value=self._get_neg_texts(self.category_freq),
+            )
+            transforms.insert(-1, transform)
         return transforms
 
+    @property
+    def category_names(self):
+        """Return unique category names from the dataset."""
+        return {t.strip() for label in self.labels for text in label["texts"] for t in text}
+
+    @property
+    def category_freq(self):
+        """Return frequency of each category in the dataset."""
+        category_freq = defaultdict(int)
+        for label in self.labels:
+            for text in label["texts"]:
+                for t in text:
+                    t = t.strip()
+                    category_freq[t] += 1
+        return category_freq
+
+    @staticmethod
+    def _get_neg_texts(category_freq, threshold=100):
+        """Get negative text samples based on frequency threshold."""
+        return [k for k, v in category_freq.items() if v >= threshold]
+
 
 class YOLOConcatDataset(ConcatDataset):
     """
@@ -516,6 +657,18 @@ class YOLOConcatDataset(ConcatDataset):
         """
         return YOLODataset.collate_fn(batch)
 
+    def close_mosaic(self, hyp):
+        """
+        Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
+
+        Args:
+            hyp (dict): Hyperparameters for transforms.
+        """
+        for dataset in self.datasets:
+            if not hasattr(dataset, "close_mosaic"):
+                continue
+            dataset.close_mosaic(hyp)
+
 
 # TODO: support semantic segmentation
 class SemanticDataset(BaseDataset):
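`close_mosaic` rounds out `YOLOConcatDataset`: trainer hooks that previously assumed a single dataset now fan out across every child that implements them, which is what lets concatenated multi-modal and grounding datasets disable mosaic late in training. The delegation pattern on its own (the class name is a stand-in):

class ConcatWrapper:
    def __init__(self, datasets):
        self.datasets = datasets

    def close_mosaic(self, hyp):
        # Forward the hook to every child dataset that supports it.
        for dataset in self.datasets:
            if hasattr(dataset, "close_mosaic"):
                dataset.close_mosaic(hyp)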
ultralytics/engine/exporter.py
CHANGED
@@ -58,6 +58,7 @@ TensorFlow.js:
 import gc
 import json
 import os
+import re
 import shutil
 import subprocess
 import time
@@ -326,6 +327,7 @@ class Exporter:
                 "See https://docs.ultralytics.com/models/yolo-world for details."
             )
             model.clip_model = None  # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
+
         if self.args.int8 and not self.args.data:
             self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")]  # assign default data
             LOGGER.warning(
@@ -634,7 +636,7 @@ class Exporter:
         # Generate calibration data for integer quantization
         ignored_scope = None
         if isinstance(self.model.model[-1], Detect):
-            # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+            # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect
             head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
             ignored_scope = nncf.IgnoredScope(  # ignore operations
                 patterns=[
@@ -796,12 +798,12 @@ class Exporter:
             LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
             # TODO CoreML Segment and Pose model pipelining
         model = self.model
-
         ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
         ct_model = ct.convert(
             ts,
-            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
+            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],  # expects ct.TensorType
             classifier_config=classifier_config,
+            minimum_deployment_target=ct.target.iOS16,
             convert_to="neuralnetwork" if mlmodel else "mlprogram",
         )
         bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
@@ -1231,17 +1233,15 @@ class Exporter:
 
         LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
 
+        # Install Java>=17
         try:
-            ...
-            if is_sudo_available():
-                c.insert(0, "sudo")
-            subprocess.run(c, check=True)
+            java_output = subprocess.run(["java", "--version"], check=True, capture_output=True).stdout.decode()
+            version_match = re.search(r"(?:openjdk|java) (\d+)", java_output)
+            java_version = int(version_match.group(1)) if version_match else 0
+            assert java_version >= 17, "Java version too old"
+        except (FileNotFoundError, subprocess.CalledProcessError, AssertionError):
+            cmd = (["sudo"] if is_sudo_available() else []) + ["apt", "install", "-y", "default-jre"]
+            subprocess.run(cmd, check=True)
 
         def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
             for batch in dataloader:
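The rewritten IMX check parses the major version out of `java --version` and only falls back to installing `default-jre` when no Java >= 17 is found, instead of matching a fixed version string. The regex step on its own (sample outputs are illustrative):

import re

def java_major_version(output):
    # Accepts both "openjdk NN..." and "java NN..." banners.
    match = re.search(r"(?:openjdk|java) (\d+)", output)
    return int(match.group(1)) if match else 0

print(java_major_version("openjdk 21.0.2 2024-01-16"))    # 21
print(java_major_version("java 17.0.10 2024-01-16 LTS"))  # 17
print(java_major_version("command not found"))            # 0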
ultralytics/engine/trainer.py
CHANGED
@@ -249,6 +249,7 @@ class BaseTrainer:
         )
         always_freeze_names = [".dfl"]  # always freeze these layers
         freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
+        self.freeze_layer_names = freeze_layer_names
         for k, v in self.model.named_parameters():
             # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
             if any(x in k for x in freeze_layer_names):
@@ -350,7 +351,7 @@ class BaseTrainer:
                 warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
                 self.scheduler.step()
 
-            self.model.train()
+            self._model_train()
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
             pbar = enumerate(self.train_loader)
@@ -381,7 +382,8 @@ class BaseTrainer:
                 # Forward
                 with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    self.loss, self.loss_items = self.model(batch)
+                    loss, self.loss_items = self.model(batch)
+                    self.loss = loss.sum()
                     if RANK != -1:
                         self.loss *= world_size
                     self.tloss = (
@@ -496,9 +498,7 @@ class BaseTrainer:
             memory = torch.mps.driver_allocated_memory()
             if fraction:
                 return __import__("psutil").virtual_memory().percent / 100
-        elif self.device.type == "cpu":
-            pass
-        else:
+        elif self.device.type != "cpu":
             memory = torch.cuda.memory_reserved()
             if fraction:
                 total = torch.cuda.get_device_properties(self.device).total_memory
@@ -520,6 +520,14 @@ class BaseTrainer:
 
         return pd.read_csv(self.csv).to_dict(orient="list")
 
+    def _model_train(self):
+        """Set model in training mode."""
+        self.model.train()
+        # Freeze BN stat
+        for n, m in self.model.named_modules():
+            if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
         import io
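`_model_train` exists so that frozen layers stay genuinely frozen: a bare `model.train()` would flip their BatchNorm modules back into training mode and let running statistics drift even though the weights receive no gradient. A minimal reproduction of the idea (the frozen-name prefix is hypothetical):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
freeze_layer_names = ["1"]  # hypothetical: module "1" is the frozen BatchNorm

model.train()  # puts every submodule, including BN, into training mode
for n, m in model.named_modules():
    if any(f in n for f in freeze_layer_names) and isinstance(m, nn.BatchNorm2d):
        m.eval()  # running_mean/running_var no longer update on forward

print(model[1].training)  # False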
@@ -720,7 +728,7 @@ class BaseTrainer:
 
                 # Check that resume data YAML exists, otherwise strip to force re-download of dataset
                 ckpt_args = attempt_load_weights(last).args
-                if not Path(ckpt_args["data"]).exists():
+                if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
                     ckpt_args["data"] = self.args.data
 
                 resume = True
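The resume fix accounts for checkpoints whose `data` argument is an embedded dict rather than a YAML path; calling `Path()` on a dict would raise before the existence check ever ran. The guard in isolation (function name is illustrative):

from pathlib import Path

def needs_data_strip(data):
    # Embedded dict specs are kept; only string paths are checked on disk.
    return not isinstance(data, dict) and not Path(data).exists()

print(needs_data_strip({"train": "images/train"}))  # False: embedded spec kept
print(needs_data_strip("missing/coco8.yaml"))       # True: strip and re-download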
@@ -812,7 +820,8 @@ class BaseTrainer:
                 fullname = f"{module_name}.{param_name}" if module_name else param_name
                 if "bias" in fullname:  # bias (no decay)
                     g[2].append(param)
-                elif isinstance(module, bn):  # weight (no decay)
+                elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
+                    # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
                     g[1].append(param)
                 else:  # weight (with decay)
                     g[0].append(param)
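Parameters named `logit_scale` (the learnable temperature in `ContrastiveHead`/`BNContrastiveHead`) now join norm weights in the no-weight-decay group, since decaying a temperature scalar toward zero would distort the contrastive logits. A toy version of the three-way split, with `BatchNorm2d` standing in for the full `bn` tuple:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
model.register_parameter("logit_scale", nn.Parameter(torch.tensor(1.0)))

g = [[], [], []]  # decayed weights, no-decay weights, biases
for module_name, module in model.named_modules():
    for param_name, param in module.named_parameters(recurse=False):
        fullname = f"{module_name}.{param_name}" if module_name else param_name
        if "bias" in fullname:
            g[2].append(param)
        elif isinstance(module, nn.BatchNorm2d) or "logit_scale" in fullname:
            g[1].append(param)  # norm weights and contrastive temperature
        else:
            g[0].append(param)

optimizer = torch.optim.SGD(
    [
        {"params": g[0], "weight_decay": 5e-4},
        {"params": g[1], "weight_decay": 0.0},
        {"params": g[2], "weight_decay": 0.0},
    ],
    lr=0.01,
    momentum=0.9,
)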
ultralytics/models/__init__.py
CHANGED
@@ -4,6 +4,6 @@ from .fastsam import FastSAM
 from .nas import NAS
 from .rtdetr import RTDETR
 from .sam import SAM
-from .yolo import YOLO, YOLOWorld
+from .yolo import YOLO, YOLOE, YOLOWorld
 
-__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld"  # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "YOLOE"  # allow simpler import
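With `YOLOE` re-exported here, the top-level `from ultralytics import YOLOE` works. A typical text-prompt call, following the pattern in the Ultralytics YOLOE docs for this release (the weights filename may vary):

from ultralytics import YOLOE

model = YOLOE("yoloe-11s-seg.pt")
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))  # text prompts -> class embeddings
results = model.predict("https://ultralytics.com/images/bus.jpg")
results[0].show()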
ultralytics/models/nas/model.py
CHANGED
@@ -81,6 +81,7 @@ class NAS(Model):
         self.model.pt_path = weights  # for export()
         self.model.task = "detect"  # for export()
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
+        self.model.eval()
 
     def info(self, detailed: bool = False, verbose: bool = True):
         """
ultralytics/models/nas/predict.py
CHANGED
@@ -2,16 +2,15 @@
 
 import torch
 
-from ultralytics.engine.predictor import BasePredictor
-from ultralytics.engine.results import Results
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
 from ultralytics.utils import ops
 
 
-class NASPredictor(BasePredictor):
+class NASPredictor(DetectionPredictor):
     """
     Ultralytics YOLO NAS Predictor for object detection.
 
-    This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the
+    This class extends the `DetectionPredictor` from Ultralytics engine and is responsible for post-processing the
     raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
     scaling the bounding boxes to fit the original image dimensions.
@@ -38,23 +37,4 @@ class NASPredictor(BasePredictor):
         # Convert boxes from xyxy to xywh format and concatenate with class scores
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
-
-        # Apply non-maximum suppression to filter overlapping detections
-        preds = ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            classes=self.args.classes,
-        )
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            # Scale bounding boxes to match original image dimensions
-            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
-        return results
+        return super().postprocess(preds, img, orig_imgs)
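After this refactor, `NASPredictor.postprocess` only reshapes the raw super-gradients output into the (batch, 4 + nc, anchors) layout that `DetectionPredictor.postprocess` already consumes, then inherits NMS, box scaling, and `Results` construction from the parent. Calling code is unchanged:

from ultralytics import NAS

model = NAS("yolo_nas_s.pt")  # requires the super-gradients package
results = model.predict("https://ultralytics.com/images/bus.jpg")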
ultralytics/models/nas/val.py
CHANGED
@@ -36,7 +36,4 @@ class NASValidator(DetectionValidator):
         """Apply Non-maximum suppression to prediction outputs."""
         boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding box format from xyxy to xywh
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with scores and permute
-        return super().postprocess(
-            preds,
-            max_time_img=0.5,
-        )
+        return super().postprocess(preds)
ultralytics/models/yolo/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe
 
-from .model import YOLO, YOLOWorld
+from .model import YOLO, YOLOE, YOLOWorld
 
-__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"
ultralytics/models/yolo/detect/val.py
CHANGED
@@ -455,8 +455,13 @@ class DetectionValidator(BaseValidator):
                 val.print_results()  # explicitly call print_results
                 # update mAP50-95 and mAP50
                 stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
-                    val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
+                    val.stats[:2] if self.is_coco else [val.results["AP"], val.results["AP50"]]
                 )
+                if self.is_lvis:
+                    stats["metrics/APr(B)"] = val.results["APr"]
+                    stats["metrics/APc(B)"] = val.results["APc"]
+                    stats["metrics/APf(B)"] = val.results["APf"]
+                    stats["fitness"] = val.results["AP"]
             except Exception as e:
                 LOGGER.warning(f"{pkg} unable to run: {e}")
             return stats
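For LVIS runs, the evaluator's rare/common/frequent AP breakdown is now surfaced in the returned stats, and overall AP overwrites the default fitness so checkpoint selection tracks the LVIS headline metric. The resulting keys, with purely illustrative values:

# Hypothetical stats dict after an LVIS eval run (numbers are made up):
stats = {
    "metrics/mAP50-95(B)": 0.271,  # val.results["AP"]
    "metrics/mAP50(B)": 0.392,     # val.results["AP50"]
    "metrics/APr(B)": 0.182,       # rare classes
    "metrics/APc(B)": 0.260,       # common classes
    "metrics/APf(B)": 0.321,       # frequent classes
    "fitness": 0.271,              # overall AP drives model selection
}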