attr-eomt 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- attr_eomt-0.1.0.dist-info/METADATA +307 -0
- attr_eomt-0.1.0.dist-info/RECORD +21 -0
- attr_eomt-0.1.0.dist-info/WHEEL +4 -0
- attr_eomt-0.1.0.dist-info/licenses/LICENSE +201 -0
- eomt/__init__.py +55 -0
- eomt/api.py +253 -0
- eomt/aux_cls.py +225 -0
- eomt/box_loss.py +196 -0
- eomt/config.py +230 -0
- eomt/ema.py +67 -0
- eomt/engine/__init__.py +7 -0
- eomt/engine/predict.py +117 -0
- eomt/engine/train.py +847 -0
- eomt/engine/validate.py +391 -0
- eomt/loss.py +279 -0
- eomt/model.py +786 -0
- eomt/plotting.py +157 -0
- eomt/postprocess.py +253 -0
- eomt/preprocess.py +55 -0
- eomt/serialization.py +308 -0
- eomt/visualize.py +98 -0
eomt/api.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""High-level :class:`EoMT` interface — init from a checkpoint or a size, then
|
|
2
|
+
``train`` / ``val`` / ``predict``.
|
|
3
|
+
|
|
4
|
+
from eomt import EoMT
|
|
5
|
+
|
|
6
|
+
model = EoMT("l") # fresh large model (DINOv2 backbone)
|
|
7
|
+
model.train(data="coco", epochs=50) # COCO auto-downloads if missing
|
|
8
|
+
|
|
9
|
+
model = EoMT("runs/train/eomt-l") # reload a run (everything auto-detected)
|
|
10
|
+
model.val(data="coco")
|
|
11
|
+
results = model.predict("images/", plot=True)
|
|
12
|
+
|
|
13
|
+
The class is a thin orchestration layer over the existing engine functions; the
|
|
14
|
+
trainable network itself is :class:`~eomt.model.EoMTModel`, reachable via
|
|
15
|
+
``model.model``.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
from .config import SIZES
|
|
25
|
+
from .data import CocoValImages, load_data_config
|
|
26
|
+
from .engine import evaluate as _evaluate
|
|
27
|
+
from .engine import evaluate_detection as _evaluate_detection
|
|
28
|
+
from .engine import predict as _predict
|
|
29
|
+
from .engine import train as _train
|
|
30
|
+
from .model import build_model, load_dinov2_backbone
|
|
31
|
+
from .serialization import is_hf_ref, load_model, save_checkpoint, wrap_checkpoint
|
|
32
|
+
|
|
33
|
+
#: Named dataset aliases that resolve to a bundled YAML config.
|
|
34
|
+
_DATASET_ALIASES = {"coco": "configs/coco.yaml"}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _looks_like_checkpoint(spec: str | Path) -> bool:
    """Return True when ``spec`` names a loadable checkpoint rather than a model size.

    Accepted forms:
    - a Hugging Face Hub reference (as recognized by ``is_hf_ref``);
    - any path ending in ``.pt``, whether or not it exists yet;
    - an existing file;
    - a run/weights folder that holds ``best.pt`` / ``last.pt``, either directly
      or under a ``weights/`` subdirectory.
    """
    if is_hf_ref(spec):
        return True
    path = Path(spec)
    if path.suffix == ".pt" or path.is_file():
        return True
    if not path.is_dir():
        return False
    # A run folder counts only if it (or its weights/ subdir) has a best/last checkpoint.
    candidates = ("best.pt", "last.pt", "weights/best.pt", "weights/last.pt")
    return any(path.joinpath(rel).is_file() for rel in candidates)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _resolve_data(data: str | Path) -> dict:
    """Turn a dataset spec into the config dict returned by ``load_data_config``.

    ``data`` is either a key of ``_DATASET_ALIASES`` (e.g. ``"coco"``) or a path to
    a dataset YAML file. Path handling is delegated to ``load_data_config``.

    Raises:
        FileNotFoundError: if the resolved YAML path is not an existing file.
    """
    spec = str(data)
    yaml_path = _DATASET_ALIASES[spec] if spec in _DATASET_ALIASES else spec
    if Path(yaml_path).is_file():
        return load_data_config(yaml_path)
    raise FileNotFoundError(
        f"dataset config not found: {yaml_path!r}. Pass a dataset YAML path or one "
        f"of {sorted(_DATASET_ALIASES)} (run from the repo root for the bundled configs)."
    )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class EoMT:
|
|
64
|
+
"""An EoMT instance-segmentation model with train/val/predict.
|
|
65
|
+
|
|
66
|
+
Construct from either a model size (``"s"`` / ``"b"`` / ``"l"`` — a fresh model
|
|
67
|
+
with a pretrained DINOv2 backbone and a randomly initialized head) or a
|
|
68
|
+
checkpoint (a ``.pt`` file or a run/weights folder — size, classes, image size
|
|
69
|
+
and any secondary heads are all auto-detected from the checkpoint).
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
def __init__(self, model: str | Path = "l", *, device: str = "auto", pretrained: bool = True, **build_kwargs):
|
|
73
|
+
self.device = device
|
|
74
|
+
self._ckpt: str | None = None
|
|
75
|
+
self._pretrained = pretrained
|
|
76
|
+
self._build_kwargs = build_kwargs
|
|
77
|
+
|
|
78
|
+
if isinstance(model, (str, Path)) and _looks_like_checkpoint(model):
|
|
79
|
+
self._model = load_model(model, device=device)
|
|
80
|
+
self._ckpt = str(model)
|
|
81
|
+
self.size = self._model.size
|
|
82
|
+
elif str(model) in SIZES:
|
|
83
|
+
self.size = str(model)
|
|
84
|
+
self._model = None # built lazily (avoids a wasted DINOv2 load before train())
|
|
85
|
+
else:
|
|
86
|
+
raise ValueError(
|
|
87
|
+
f"{model!r} is neither a known size {tuple(SIZES)} nor an existing checkpoint."
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
# ----------------------------------------------------------- huggingface
|
|
91
|
+
@classmethod
|
|
92
|
+
def from_pretrained(
|
|
93
|
+
cls,
|
|
94
|
+
repo_id: str,
|
|
95
|
+
*,
|
|
96
|
+
filename: str = "model.pt",
|
|
97
|
+
revision: str | None = None,
|
|
98
|
+
device: str = "auto",
|
|
99
|
+
) -> "EoMT":
|
|
100
|
+
"""Load a model from a Hugging Face Hub repo (downloaded once, then cached).
|
|
101
|
+
|
|
102
|
+
model = EoMT.from_pretrained("imagra93/eomt-l-coco")
|
|
103
|
+
|
|
104
|
+
Equivalent to ``EoMT("hf://<repo_id>/<filename>")``. ``revision`` pins a
|
|
105
|
+
branch, tag or commit; ``filename`` selects the checkpoint inside the repo.
|
|
106
|
+
"""
|
|
107
|
+
ref = f"hf://{repo_id}/{filename}"
|
|
108
|
+
self = cls.__new__(cls)
|
|
109
|
+
self.device = device
|
|
110
|
+
self._pretrained = True
|
|
111
|
+
self._build_kwargs = {}
|
|
112
|
+
# Resolve via the Hub (cached) and load, recording the ref for repr.
|
|
113
|
+
from .serialization import download_from_hub
|
|
114
|
+
local = download_from_hub(ref, revision=revision)
|
|
115
|
+
self._model = load_model(local, device=device)
|
|
116
|
+
self._ckpt = ref
|
|
117
|
+
self.size = self._model.size
|
|
118
|
+
return self
|
|
119
|
+
|
|
120
|
+
def push_to_hub(
|
|
121
|
+
self,
|
|
122
|
+
repo_id: str,
|
|
123
|
+
*,
|
|
124
|
+
filename: str = "model.pt",
|
|
125
|
+
private: bool = True,
|
|
126
|
+
commit_message: str = "Upload EoMT checkpoint",
|
|
127
|
+
) -> str:
|
|
128
|
+
"""Upload this model's weights to a Hugging Face Hub repo (created if absent).
|
|
129
|
+
|
|
130
|
+
Writes a self-describing checkpoint (same format as :meth:`save`) and
|
|
131
|
+
uploads it as ``filename``. Returns the repo URL. Requires a Hub token
|
|
132
|
+
(``huggingface-cli login`` or ``HF_TOKEN``). Defaults to a **private** repo.
|
|
133
|
+
"""
|
|
134
|
+
try:
|
|
135
|
+
from huggingface_hub import HfApi
|
|
136
|
+
except ImportError as e: # pragma: no cover
|
|
137
|
+
raise ImportError("push_to_hub needs 'huggingface_hub' (pip install huggingface_hub).") from e
|
|
138
|
+
|
|
139
|
+
import tempfile
|
|
140
|
+
|
|
141
|
+
api = HfApi()
|
|
142
|
+
api.create_repo(repo_id, repo_type="model", private=private, exist_ok=True)
|
|
143
|
+
with tempfile.TemporaryDirectory() as tmp:
|
|
144
|
+
local = Path(tmp) / filename
|
|
145
|
+
self.save(local)
|
|
146
|
+
api.upload_file(
|
|
147
|
+
path_or_fileobj=str(local),
|
|
148
|
+
path_in_repo=filename,
|
|
149
|
+
repo_id=repo_id,
|
|
150
|
+
repo_type="model",
|
|
151
|
+
commit_message=commit_message,
|
|
152
|
+
)
|
|
153
|
+
return f"https://huggingface.co/{repo_id}"
|
|
154
|
+
|
|
155
|
+
# ------------------------------------------------------------------ model
|
|
156
|
+
@property
|
|
157
|
+
def model(self):
|
|
158
|
+
"""The underlying trainable :class:`~eomt.model.EoMTModel` (built on demand)."""
|
|
159
|
+
if self._model is None:
|
|
160
|
+
self._model = build_model(self.size, **self._build_kwargs)
|
|
161
|
+
if self._pretrained:
|
|
162
|
+
load_dinov2_backbone(self._model)
|
|
163
|
+
dev = "cuda" if (self.device in ("", "auto") and torch.cuda.is_available()) else \
|
|
164
|
+
("cpu" if self.device in ("", "auto") else self.device)
|
|
165
|
+
self._model = self._model.to(dev).eval()
|
|
166
|
+
return self._model
|
|
167
|
+
|
|
168
|
+
# ------------------------------------------------------------------ train
|
|
169
|
+
def train(self, data: str | Path = "coco", *, resume: bool = False, **hp) -> dict:
|
|
170
|
+
"""Train on a COCO-format dataset.
|
|
171
|
+
|
|
172
|
+
``data`` is a dataset YAML path or a known alias (``"coco"`` auto-downloads).
|
|
173
|
+
Extra keyword args (``epochs``, ``batch``, ``lr0``, ``aux_w``, …) are passed
|
|
174
|
+
straight through to the training engine. For a checkpoint-initialized model,
|
|
175
|
+
``resume=True`` continues the original run; otherwise the checkpoint warm-starts
|
|
176
|
+
a fresh run (fine-tune). Returns the engine's result dict and reloads the best
|
|
177
|
+
weights into this object.
|
|
178
|
+
"""
|
|
179
|
+
cfg = _resolve_data(data)
|
|
180
|
+
if not (cfg["train_images"] and cfg["train_json"]):
|
|
181
|
+
raise ValueError(f"dataset {data!r} has no train split (train_images/train_json).")
|
|
182
|
+
|
|
183
|
+
if self._ckpt is not None:
|
|
184
|
+
hp["resume" if resume else "init_weights"] = self._ckpt
|
|
185
|
+
|
|
186
|
+
result = _train(
|
|
187
|
+
train_images=cfg["train_images"],
|
|
188
|
+
train_json=cfg["train_json"],
|
|
189
|
+
val_images=cfg["val_images"],
|
|
190
|
+
val_json=cfg["val_json"],
|
|
191
|
+
size=self.size,
|
|
192
|
+
device=self.device,
|
|
193
|
+
**hp,
|
|
194
|
+
)
|
|
195
|
+
best = result.get("best") or result.get("last")
|
|
196
|
+
if best:
|
|
197
|
+
self._model = load_model(best, device=self.device)
|
|
198
|
+
self._ckpt = best
|
|
199
|
+
return result
|
|
200
|
+
|
|
201
|
+
# -------------------------------------------------------------------- val
|
|
202
|
+
def val(self, data: str | Path = "coco", *, batch: int = 4, workers: int = 4,
|
|
203
|
+
conf_thres: float = 0.0, max_det: int = 100, letterbox: bool | None = None, **kw) -> dict:
|
|
204
|
+
"""Evaluate on a dataset's val split, returning COCO mAP metrics.
|
|
205
|
+
|
|
206
|
+
Segmentation models report ``segm/*`` (+ ``bbox/*``); detection
|
|
207
|
+
(``family="detect"``) models report ``bbox/*`` only.
|
|
208
|
+
"""
|
|
209
|
+
cfg = _resolve_data(data)
|
|
210
|
+
if not (cfg["val_images"] and cfg["val_json"]):
|
|
211
|
+
raise ValueError(f"dataset {data!r} has no val split (val_images/val_json).")
|
|
212
|
+
|
|
213
|
+
model = self.model
|
|
214
|
+
dev = next(model.parameters()).device
|
|
215
|
+
lb = letterbox if letterbox is not None else bool(getattr(model, "preprocess_letterbox", False))
|
|
216
|
+
val_ds = CocoValImages(cfg["val_images"], cfg["val_json"], imgsz=int(model.image_size), letterbox=lb)
|
|
217
|
+
eval_fn = _evaluate_detection if getattr(model, "family", "instance") == "detect" else _evaluate
|
|
218
|
+
return eval_fn(
|
|
219
|
+
model, val_ds, device=dev, batch_size=batch, num_workers=workers,
|
|
220
|
+
conf_thres=conf_thres, max_det=max_det, **kw,
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
# ---------------------------------------------------------------- predict
|
|
224
|
+
def predict(self, source: str | Path, *, plot: bool = False, save: str | None = "runs/predict",
|
|
225
|
+
conf_thres: float = 0.3, max_det: int = 100, mask_thresh: float = 0.5, **kw) -> list[dict]:
|
|
226
|
+
"""Run inference on an image or a directory.
|
|
227
|
+
|
|
228
|
+
Returns one result dict per image (``boxes`` / ``scores`` / ``classes`` /
|
|
229
|
+
``masks`` and, for models with secondary heads, ``aux``). With ``plot=True``
|
|
230
|
+
each image is rendered with masks/boxes/labels and saved under ``save``.
|
|
231
|
+
"""
|
|
232
|
+
return _predict(
|
|
233
|
+
self.model, str(source), plot=plot, save=save,
|
|
234
|
+
conf_thres=conf_thres, max_det=max_det, mask_thresh=mask_thresh, **kw,
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
# ------------------------------------------------------------------- save
|
|
238
|
+
def save(self, path: str | Path) -> None:
|
|
239
|
+
"""Write a self-describing checkpoint (reloadable with ``EoMT(path)``)."""
|
|
240
|
+
m = self.model
|
|
241
|
+
ckpt = wrap_checkpoint(
|
|
242
|
+
m.state_dict(),
|
|
243
|
+
size=m.size, nc=m.nc, imgsz=m.image_size, names=m.names,
|
|
244
|
+
task=getattr(m, "family", "instance"),
|
|
245
|
+
aux_heads=m.aux_specs, aux_head_arch=m.aux_head_arch,
|
|
246
|
+
letterbox=bool(getattr(m, "preprocess_letterbox", False)),
|
|
247
|
+
loss_weights=m.loss_weights, num_upscale_blocks=m.num_upscale_blocks,
|
|
248
|
+
)
|
|
249
|
+
save_checkpoint(ckpt, path)
|
|
250
|
+
|
|
251
|
+
def __repr__(self) -> str:
|
|
252
|
+
src = f"ckpt={self._ckpt!r}" if self._ckpt else f"size={self.size!r}"
|
|
253
|
+
return f"EoMT({src})"
|
eomt/aux_cls.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""Secondary per-instance classification heads ("attributes") for EoMT.
|
|
2
|
+
|
|
3
|
+
The primary task (instance segmentation over ``nc`` classes) is unchanged. Each
|
|
4
|
+
aux head predicts an extra attribute per *detected instance* — typology,
|
|
5
|
+
laterality, severity, … — read from the matched query's embedding.
|
|
6
|
+
|
|
7
|
+
The supervision reuses EoMT's own Hungarian matcher
|
|
8
|
+
(``model.eomt.criterion.matcher``) so each attribute is trained on the **same**
|
|
9
|
+
query→GT assignment the detection loss used. Several specs ⇒ several independent
|
|
10
|
+
heads, summed (optionally weighted) into one scalar.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn.functional as F # noqa: N812
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@torch.no_grad()
def match_queries(model, out: dict, geom_labels, class_labels):
    """Run EoMT's own Hungarian matcher and return its ``[(src, tgt), ...]`` indices.

    ``geom_labels`` is the GT geometry for the model's family:
    - instance masks for ``"instance"`` models;
    - normalized ``cxcywh`` boxes for ``"detect"`` models.

    The prediction handed to ``model.eomt.criterion.matcher`` follows what the
    model emitted: ``out["pred_boxes"]`` when that key is present, otherwise
    ``out["masks_queries_logits"]``.
    """
    if "pred_boxes" in out:
        predicted_geometry = out["pred_boxes"]
    else:
        predicted_geometry = out["masks_queries_logits"]
    matcher = model.eomt.criterion.matcher
    return matcher(predicted_geometry, out["class_queries_logits"], geom_labels, class_labels)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _gather_matched(out: dict, indices):
|
|
34
|
+
"""Return ``(matched_feats [N, hidden], [(b, tgt_idx)...])`` for matched queries."""
|
|
35
|
+
q = out["query_embed"] # [B, Q, hidden]
|
|
36
|
+
feats, batch_tgt = [], []
|
|
37
|
+
for b, (src, tgt) in enumerate(indices):
|
|
38
|
+
if src.numel() == 0:
|
|
39
|
+
continue
|
|
40
|
+
feats.append(q[b, src])
|
|
41
|
+
batch_tgt.append((b, tgt))
|
|
42
|
+
if not feats:
|
|
43
|
+
return None, batch_tgt
|
|
44
|
+
return torch.cat(feats), batch_tgt
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@torch.no_grad()
def gate_indices(
    out: dict,
    indices,
    geom_labels,
    class_labels,
    *,
    iou_thr: float = 0.5,
    require_class: bool = False,
):
    """Keep only well-localized (and optionally correctly-classified) matched pairs.

    The Hungarian matcher assigns *every* GT a query, even one whose prediction
    barely overlaps it (common early in training). For attribute supervision we
    want queries that actually localize the instance, so matched ``(src, tgt)``
    pairs are dropped when their predicted↔GT IoU is ``< iou_thr``. This is mask
    IoU for the instance family and box IoU for detect.

    With ``require_class``, pairs whose predicted primary class (argmax of
    ``class_queries_logits``) differs from the GT class are also dropped.

    Returns the same ``[(src, tgt), ...]`` structure with each pair filtered
    (possibly to empty). The filtered index tensors stay on the device of the
    input indices, which the matcher produces on CPU. This keeps them usable for
    indexing label tensors wherever those live.

    A no-op (returns ``indices`` unchanged) when ``iou_thr <= 0`` and not
    ``require_class``, i.e. the pre-gate behaviour.
    """
    if iou_thr <= 0 and not require_class:
        return indices
    is_detect = "pred_boxes" in out
    cls = out["class_queries_logits"]  # [B, Q, C+1]
    dev = cls.device
    gated = []
    for b, (src, tgt) in enumerate(indices):
        if src.numel() == 0:
            gated.append((src, tgt))
            continue
        keep = torch.ones(src.numel(), dtype=torch.bool, device=dev)
        if iou_thr > 0:
            keep &= _localization_iou(out, b, src, tgt, geom_labels, is_detect, dev) >= iou_thr
        if require_class:
            pred_cls = cls[b, src.to(dev)].argmax(-1)
            labels = class_labels[b]
            # Index on the labels' own device: a CPU tensor cannot be indexed with CUDA indices.
            gt_cls = labels[tgt.to(labels.device)].to(dev)
            keep &= pred_cls == gt_cls
        gated.append((src[keep.to(src.device)], tgt[keep.to(tgt.device)]))
    return gated
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _localization_iou(out, b, src, tgt, geom_labels, is_detect, dev):
|
|
93
|
+
"""Per-pair predicted↔GT IoU for the matched queries (box IoU or mask IoU)."""
|
|
94
|
+
if is_detect:
|
|
95
|
+
from .box_loss import box_cxcywh_to_xyxy, box_iou
|
|
96
|
+
|
|
97
|
+
pred = box_cxcywh_to_xyxy(out["pred_boxes"][b, src].to(dev)) # [n, 4]
|
|
98
|
+
gt = box_cxcywh_to_xyxy(geom_labels[b][tgt].to(dev).float()) # [n, 4]
|
|
99
|
+
return box_iou(pred, gt)[0].diagonal() # per-matched-pair IoU
|
|
100
|
+
masks = out["masks_queries_logits"] # [B, Q, h, w]
|
|
101
|
+
pm = masks[b, src].sigmoid() > 0.5 # [n, h, w]
|
|
102
|
+
gm = geom_labels[b][tgt].to(dev).float() # [n, H, W] in {0, 1}
|
|
103
|
+
if gm.shape[-2:] != pm.shape[-2:]:
|
|
104
|
+
gm = F.interpolate(gm.unsqueeze(1), size=pm.shape[-2:], mode="nearest").squeeze(1)
|
|
105
|
+
gm = gm > 0.5
|
|
106
|
+
inter = (pm & gm).flatten(1).sum(1).float()
|
|
107
|
+
union = (pm | gm).flatten(1).sum(1).float().clamp(min=1.0)
|
|
108
|
+
return inter / union
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def aux_loss(
    model,
    out: dict,
    mask_labels,
    class_labels,
    aux_labels: dict[str, list[torch.Tensor]],
    weights: dict[str, float] | None = None,
    *,
    indices=None,
    class_weights: dict[str, torch.Tensor] | None = None,
    ignore_index: int = -100,
):
    """Weighted sum of per-attribute CE over matched queries.

    ``aux_labels`` maps head name → per-image label tensors (line-aligned with
    ``class_labels``). Labels equal to ``ignore_index`` (missing / out-of-vocab
    attributes) are skipped: they contribute no loss instead of being trained as
    class 0.

    ``weights`` scales each head (default 1.0). ``class_weights`` optionally
    re-weights classes within a head (imbalance). Pass ``indices`` to reuse a
    matching already computed this step (avoids re-running the Hungarian
    matcher). Index tensors may be on any device.

    Returns ``(total_loss, {name: per_head_loss})``. ``total_loss`` is always a
    tensor: a graph-connected zero when nothing matched or the model has no aux
    heads.
    """
    if indices is None:
        indices = match_queries(model, out, mask_labels, class_labels)
    matched, batch_tgt = _gather_matched(out, indices)
    if matched is None:  # no query matched in this batch — keep the graph alive
        zero = out["query_embed"].sum() * 0.0
        return zero, {name: zero.detach() for name in model.aux_heads}

    weights = weights or {}
    class_weights = class_weights or {}
    total: torch.Tensor | None = None
    per_head: dict[str, torch.Tensor] = {}
    for name, head in model.aux_heads.items():
        logits = head(matched)  # [N, ns]
        labels = aux_labels[name]
        # Index each label tensor on its own device (CPU tensors reject CUDA indices).
        gt = torch.cat([labels[b][tgt.to(labels[b].device)] for b, tgt in batch_tgt]).to(logits.device)
        cw = class_weights.get(name)
        if cw is not None:
            cw = cw.to(logits.device, logits.dtype)
        if (gt != ignore_index).any():
            loss = F.cross_entropy(logits, gt, weight=cw, ignore_index=ignore_index)
        else:  # every matched label ignored — keep the graph alive with a zero
            loss = logits.sum() * 0.0
        per_head[name] = loss
        w = float(weights.get(name, 1.0))
        total = w * loss if total is None else total + w * loss
    if total is None:  # model has no aux heads: return a zero tensor, not None
        total = matched.sum() * 0.0
    return total, per_head
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@torch.no_grad()
def aux_accuracy(
    model,
    out: dict,
    mask_labels,
    class_labels,
    aux_labels: dict[str, list[torch.Tensor]],
    *,
    indices=None,
    ignore_index: int = -100,
) -> dict[str, tuple[int, int]]:
    """Top-1 ``{name: (correct, total)}`` on matched queries (the ``typ_acc`` analogue).

    Ignored labels (``ignore_index``) are excluded from both correct and total.
    Pass ``indices`` to reuse a matching already computed this step. Index
    tensors may be on any device: each one is moved to its label tensor's device
    before gathering.
    """
    if indices is None:
        indices = match_queries(model, out, mask_labels, class_labels)
    matched, batch_tgt = _gather_matched(out, indices)
    if matched is None:
        return {name: (0, 0) for name in model.aux_heads}
    res: dict[str, tuple[int, int]] = {}
    for name, head in model.aux_heads.items():
        pred = head(matched).argmax(-1)
        labels = aux_labels[name]
        gt = torch.cat([labels[b][tgt.to(labels[b].device)] for b, tgt in batch_tgt]).to(pred.device)
        valid = gt != ignore_index
        res[name] = (int(((pred == gt) & valid).sum()), int(valid.sum()))
    return res
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@torch.no_grad()
def aux_accuracy_by_primary(
    model,
    out: dict,
    mask_labels,
    class_labels,
    aux_labels: dict[str, list[torch.Tensor]],
    *,
    indices=None,
    ignore_index: int = -100,
) -> dict[str, dict[int, tuple[int, int]]]:
    """Per-head accuracy bucketed by GT **primary** class.

    Returns ``{head: {primary_class_id: (correct, total)}}`` over matched queries.
    This is a diagnostic for *which primary classes* the attribute is
    (in)accurate on.

    Pass ``indices`` to control the population, e.g. IoU-gated but not
    class-gated, so weak primary classes still appear. Ignored labels are
    excluded. Index tensors may be on any device: each one is moved to its label
    tensor's device before gathering.
    """
    if indices is None:
        indices = match_queries(model, out, mask_labels, class_labels)
    matched, batch_tgt = _gather_matched(out, indices)
    res: dict[str, dict[int, tuple[int, int]]] = {name: {} for name in model.aux_heads}
    if matched is None:
        return res
    prim = torch.cat(
        [class_labels[b][tgt.to(class_labels[b].device)] for b, tgt in batch_tgt]
    ).to(matched.device)
    for name, head in model.aux_heads.items():
        pred = head(matched).argmax(-1)
        labels = aux_labels[name]
        gt = torch.cat([labels[b][tgt.to(labels[b].device)] for b, tgt in batch_tgt]).to(pred.device)
        valid = gt != ignore_index
        correct = (pred == gt) & valid
        d: dict[int, tuple[int, int]] = {}
        for c in prim[valid].unique().tolist():
            sel = (prim == c) & valid
            d[int(c)] = (int(correct[sel].sum()), int(sel.sum()))
        res[name] = d
    return res
|
eomt/box_loss.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""Pure-PyTorch DETR-style detection loss: Hungarian matcher + L1/GIoU + class CE.
|
|
2
|
+
|
|
3
|
+
The detection counterpart to :mod:`eomt.loss` (the mask-classification stack). It
|
|
4
|
+
keeps the **same** query→GT assignment machinery — a Hungarian matcher exposed as
|
|
5
|
+
``DetectionLoss.matcher`` with a ``(geometry, class_queries_logits, geom_labels,
|
|
6
|
+
class_labels)`` signature and the same list-of-``(src, tgt)`` return — so
|
|
7
|
+
:func:`eomt.aux_cls.match_queries` can drive attribute supervision identically for
|
|
8
|
+
the box head. The only difference from :mod:`eomt.loss` is the geometry term: per-query
|
|
9
|
+
mask BCE + dice is replaced by per-query box **L1 + GIoU** (DETR's box loss).
|
|
10
|
+
|
|
11
|
+
All boxes here are normalized ``cxcywh`` in ``[0, 1]`` (the box head emits sigmoid
|
|
12
|
+
``cxcywh``; the detection dataset stores GT the same way), so no image size is needed.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import torch.nn.functional as F # noqa: N812
|
|
19
|
+
from scipy.optimize import linear_sum_assignment
|
|
20
|
+
from torch import Tensor, nn
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
    """Convert ``(..., 4)`` center-format boxes ``(cx, cy, w, h)`` to corners ``(x1, y1, x2, y2)``."""
    cx, cy, w, h = boxes.unbind(dim=-1)
    half_w, half_h = 0.5 * w, 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _box_area(boxes: Tensor) -> Tensor:
|
|
30
|
+
return (boxes[:, 2] - boxes[:, 0]).clamp(min=0) * (boxes[:, 3] - boxes[:, 1]).clamp(min=0)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def box_iou(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
    """Pairwise IoU and union of two ``xyxy`` box sets.

    For inputs ``[N, 4]`` and ``[M, 4]``, returns ``(iou, union)``, each ``[N, M]``.
    The union is clamped to ``1e-6`` only inside the division, so zero-area pairs
    give IoU 0 instead of NaN. The returned ``union`` itself is unclamped.
    """
    top_left = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])
    overlap = (bottom_right - top_left).clamp(min=0)
    inter = overlap[..., 0] * overlap[..., 1]
    union = _box_area(boxes1)[:, None] + _box_area(boxes2)[None, :] - inter
    return inter / union.clamp(min=1e-6), union
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """Pairwise generalized IoU of two ``xyxy`` box sets, shape ``[N, M]``.

    GIoU = IoU - (|C| - |A ∪ B|) / |C|, where C is the smallest box enclosing both
    (DETR's formulation). ``|C|`` is clamped to ``1e-6`` in the division.
    """
    iou, union = box_iou(boxes1, boxes2)
    enclose_lt = torch.min(boxes1[:, None, :2], boxes2[None, :, :2])
    enclose_rb = torch.max(boxes1[:, None, 2:], boxes2[None, :, 2:])
    enclose_wh = (enclose_rb - enclose_lt).clamp(min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
    penalty = (enclose_area - union) / enclose_area.clamp(min=1e-6)
    return iou - penalty
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class DetectionHungarianMatcher(nn.Module):
    """1-to-1 query↔GT-box assignment minimizing a class + L1 + GIoU cost.

    The three ``cost_*`` values weight the respective terms. At least one of them
    must be non-zero.
    """

    def __init__(self, cost_class: float = 1.0, cost_bbox: float = 1.0, cost_giou: float = 1.0):
        super().__init__()
        if all(c == 0 for c in (cost_class, cost_bbox, cost_giou)):
            raise ValueError("All costs can't be 0")
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    @torch.no_grad()
    def forward(
        self,
        pred_boxes: Tensor,
        class_queries_logits: Tensor,
        box_labels: list[Tensor],
        class_labels: list[Tensor],
    ) -> list[tuple[Tensor, Tensor]]:
        """Match each image's queries to its GT boxes.

        Returns one ``(query_idx, gt_idx)`` pair of int64 CPU tensors per image.
        Images without GT boxes get two empty tensors.
        """
        matches: list[tuple[Tensor, Tensor]] = []
        for img in range(pred_boxes.shape[0]):
            # The cost math runs in fp32. cdist/GIoU are not implemented for fp16 on
            # CUDA, and this matcher is called both inside autocast (training) and
            # outside it (eval/aux), where the tensors stay half.
            gt_boxes = box_labels[img].to(device=pred_boxes.device, dtype=torch.float32)
            if gt_boxes.numel() == 0:
                matches.append((torch.empty(0, dtype=torch.int64), torch.empty(0, dtype=torch.int64)))
                continue
            probs = class_queries_logits[img].float().softmax(dim=-1)  # [Q, C+1]
            boxes = pred_boxes[img].float()  # [Q, 4]

            class_term = self.cost_class * -probs[:, class_labels[img]]  # [Q, num_tgt]
            l1_term = self.cost_bbox * torch.cdist(boxes, gt_boxes, p=1)  # [Q, num_tgt]
            giou_term = self.cost_giou * -generalized_box_iou(
                box_cxcywh_to_xyxy(boxes), box_cxcywh_to_xyxy(gt_boxes)
            )
            cost = l1_term + class_term + giou_term
            # Make the cost finite for SciPy: NaN -> 0, infinities and extremes -> ±1e10.
            cost = torch.nan_to_num(cost, nan=0.0, posinf=1e10, neginf=-1e10).clamp(min=-1e10, max=1e10)

            rows, cols = linear_sum_assignment(cost.cpu())
            matches.append(
                (torch.as_tensor(rows, dtype=torch.int64), torch.as_tensor(cols, dtype=torch.int64))
            )
        return matches
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class DetectionLoss(nn.Module):
    """DETR-style detection criterion (class CE + box L1 + GIoU) with EoMTLoss's call signature.

    Inputs:
    - ``weight_dict`` must provide ``"loss_bbox"`` and ``"loss_giou"``; these also
      serve as the matcher's L1 / GIoU costs.
    - ``config.class_weight`` is the matcher's class cost.
    - ``config.no_object_weight`` weights the trailing "no object" class in the CE.

    ``forward`` returns the three *unweighted* terms ``loss_bbox``, ``loss_giou``
    and ``loss_cross_entropy``. Applying ``weight_dict`` is left to the caller
    (``EoMTEncoder.get_loss_dict``), exactly as for the mask loss.
    """

    def __init__(self, config, weight_dict: dict[str, float]):
        super().__init__()
        self.num_labels = config.num_labels
        self.weight_dict = weight_dict

        self.eos_coef = config.no_object_weight
        # CE class weights: 1 for each real class, eos_coef for the no-object slot at index num_labels.
        class_weight = torch.ones(self.num_labels + 1)
        class_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", class_weight)

        self.matcher = DetectionHungarianMatcher(
            cost_class=config.class_weight,
            cost_bbox=weight_dict["loss_bbox"],
            cost_giou=weight_dict["loss_giou"],
        )

    def _get_predictions_permutation_indices(self, indices):
        """Flatten per-image matches into ``(batch_idx, query_idx)`` for advanced indexing."""
        sources = [src for src, _ in indices]
        batch_idx = torch.cat([torch.full_like(src, img) for img, src in enumerate(sources)])
        return batch_idx, torch.cat(sources)

    def loss_labels(self, class_queries_logits, class_labels, indices) -> dict[str, Tensor]:
        """Cross-entropy over all queries; unmatched queries target the no-object class."""
        logits = class_queries_logits
        n_img, n_queries, _ = logits.shape
        matched_idx = self._get_predictions_permutation_indices(indices)
        matched_classes = torch.cat(
            [labels[gt] for labels, (_, gt) in zip(class_labels, indices)]
        ).to(logits.device)
        targets = torch.full(
            (n_img, n_queries), self.num_labels, dtype=torch.int64, device=logits.device
        )
        targets[matched_idx] = matched_classes
        # The class axis must be dim 1 for cross_entropy: [B, C+1, Q] vs targets [B, Q].
        loss = F.cross_entropy(logits.transpose(1, 2), targets, weight=self.empty_weight)
        return {"loss_cross_entropy": loss}

    def loss_boxes(self, pred_boxes, box_labels, indices, num_boxes) -> dict[str, Tensor]:
        """L1 and ``1 - GIoU`` over matched boxes, each summed and divided by ``num_boxes``."""
        matched_pred = pred_boxes[self._get_predictions_permutation_indices(indices)]  # [N, 4] cxcywh
        matched_gt = torch.cat(
            [boxes[gt] for boxes, (_, gt) in zip(box_labels, indices)]
        ).to(matched_pred)  # [N, 4] cxcywh
        if matched_pred.numel() == 0:
            # Nothing matched: return a zero that still depends on pred_boxes so backward() works.
            zero = pred_boxes.sum() * 0.0
            return {"loss_bbox": zero, "loss_giou": zero}

        l1 = F.l1_loss(matched_pred, matched_gt, reduction="sum") / num_boxes
        pair_giou = generalized_box_iou(
            box_cxcywh_to_xyxy(matched_pred), box_cxcywh_to_xyxy(matched_gt)
        ).diagonal()  # GIoU of each matched (pred, gt) pair
        giou_loss = (1 - pair_giou).sum() / num_boxes
        return {"loss_bbox": l1, "loss_giou": giou_loss}

    def get_num_boxes(self, class_labels, device) -> Tensor:
        """Total GT count across the batch as a float tensor, floored at 1 so divisions are safe."""
        total = sum(map(len, class_labels))
        return torch.as_tensor(total, dtype=torch.float, device=device).clamp(min=1)

    def forward(
        self,
        masks_queries_logits: Tensor,  # unused; kept for signature symmetry with EoMTLoss
        class_queries_logits: Tensor,
        mask_labels: list[Tensor],  # actually box_labels (cxcywh) for the detect family
        class_labels: list[Tensor],
        auxiliary_predictions=None,
    ) -> dict[str, Tensor]:
        """Match queries to GT boxes, then return the unweighted box and class losses.

        In this family the "mask" slots carry boxes: ``masks_queries_logits``
        holds the predicted ``cxcywh`` boxes and ``mask_labels`` holds the GT
        boxes. ``auxiliary_predictions`` is accepted for signature compatibility
        and ignored.
        """
        pred_boxes = masks_queries_logits
        box_labels = mask_labels
        indices = self.matcher(pred_boxes, class_queries_logits, box_labels, class_labels)
        num_boxes = self.get_num_boxes(class_labels, device=class_queries_logits.device)
        losses = self.loss_boxes(pred_boxes, box_labels, indices, num_boxes)
        losses.update(self.loss_labels(class_queries_logits, class_labels, indices))
        return losses
|