attr-eomt 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eomt/api.py ADDED
@@ -0,0 +1,253 @@
1
+ """High-level :class:`EoMT` interface — init from a checkpoint or a size, then
2
+ ``train`` / ``val`` / ``predict``.
3
+
4
+ from eomt import EoMT
5
+
6
+ model = EoMT("l") # fresh large model (DINOv2 backbone)
7
+ model.train(data="coco", epochs=50) # COCO auto-downloads if missing
8
+
9
+ model = EoMT("runs/train/eomt-l") # reload a run (everything auto-detected)
10
+ model.val(data="coco")
11
+ results = model.predict("images/", plot=True)
12
+
13
+ The class is a thin orchestration layer over the existing engine functions; the
14
+ trainable network itself is :class:`~eomt.model.EoMTModel`, reachable via
15
+ ``model.model``.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from pathlib import Path
21
+
22
+ import torch
23
+
24
+ from .config import SIZES
25
+ from .data import CocoValImages, load_data_config
26
+ from .engine import evaluate as _evaluate
27
+ from .engine import evaluate_detection as _evaluate_detection
28
+ from .engine import predict as _predict
29
+ from .engine import train as _train
30
+ from .model import build_model, load_dinov2_backbone
31
+ from .serialization import is_hf_ref, load_model, save_checkpoint, wrap_checkpoint
32
+
33
+ #: Named dataset aliases that resolve to a bundled YAML config.
34
+ _DATASET_ALIASES = {"coco": "configs/coco.yaml"}
35
+
36
+
37
def _looks_like_checkpoint(spec: str | Path) -> bool:
    """Decide whether ``spec`` should be treated as a checkpoint reference.

    Returns ``True`` for a Hub reference (as judged by ``is_hf_ref``), any path
    ending in ``.pt`` (existing or not), any existing file, or a directory that
    contains ``best.pt`` / ``last.pt`` either directly or under ``weights/``.
    Everything else yields ``False``.
    """
    if is_hf_ref(spec):
        return True

    candidate = Path(spec)
    if candidate.suffix == ".pt" or candidate.is_file():
        return True
    if not candidate.is_dir():
        return False

    # A run folder qualifies when it (or its weights/ subdir) holds a best/last checkpoint.
    for relative in ("best.pt", "last.pt", "weights/best.pt", "weights/last.pt"):
        if (candidate / relative).is_file():
            return True
    return False
50
+
51
+
52
def _resolve_data(data: str | Path) -> dict:
    """Turn a dataset spec into a loaded dataset config.

    ``data`` is first looked up in ``_DATASET_ALIASES``; when it is not an alias
    it is used as a YAML path as-is. The resulting path must be an existing file,
    otherwise :class:`FileNotFoundError` is raised. The file is then handed to
    ``load_data_config`` and its result is returned.
    """
    key = str(data)
    yaml_path = _DATASET_ALIASES[key] if key in _DATASET_ALIASES else key
    if Path(yaml_path).is_file():
        return load_data_config(yaml_path)
    raise FileNotFoundError(
        f"dataset config not found: {yaml_path!r}. Pass a dataset YAML path or one "
        f"of {sorted(_DATASET_ALIASES)} (run from the repo root for the bundled configs)."
    )
61
+
62
+
63
class EoMT:
    """An EoMT instance-segmentation model with train/val/predict.

    Construct from either a model size (``"s"`` / ``"b"`` / ``"l"`` — a fresh model
    with a pretrained DINOv2 backbone and a randomly initialized head) or a
    checkpoint (a ``.pt`` file or a run/weights folder — size, classes, image size
    and any secondary heads are all auto-detected from the checkpoint).
    """

    def __init__(self, model: str | Path = "l", *, device: str = "auto", pretrained: bool = True, **build_kwargs):
        """Create the wrapper from a checkpoint reference or a size key.

        Args:
            model: A checkpoint reference (see ``_looks_like_checkpoint``) or a key
                of ``SIZES``. Checkpoints are loaded immediately; size-only models
                are built lazily on first access to :attr:`model`.
            device: Device spec forwarded to the loaders; ``""`` / ``"auto"`` are
                resolved when the model is built.
            pretrained: Whether to load the DINOv2 backbone when building from a size.
            **build_kwargs: Extra arguments forwarded to ``build_model``.

        Raises:
            ValueError: If ``model`` is neither a checkpoint nor a known size.
        """
        self.device = device
        self._ckpt: str | None = None
        self._pretrained = pretrained
        self._build_kwargs = build_kwargs

        if isinstance(model, (str, Path)) and _looks_like_checkpoint(model):
            self._model = load_model(model, device=device)
            self._ckpt = str(model)
            self.size = self._model.size
        elif str(model) in SIZES:
            self.size = str(model)
            self._model = None  # built lazily (avoids a wasted DINOv2 load before train())
        else:
            raise ValueError(
                f"{model!r} is neither a known size {tuple(SIZES)} nor an existing checkpoint."
            )

    # ----------------------------------------------------------- huggingface
    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        *,
        filename: str = "model.pt",
        revision: str | None = None,
        device: str = "auto",
    ) -> "EoMT":
        """Load a model from a Hugging Face Hub repo (downloaded once, then cached).

            model = EoMT.from_pretrained("imagra93/eomt-l-coco")

        Equivalent to ``EoMT("hf://<repo_id>/<filename>")``. ``revision`` pins a
        branch, tag or commit; ``filename`` selects the checkpoint inside the repo.
        """
        ref = f"hf://{repo_id}/{filename}"
        # Bypass __init__ so the download can honour ``revision``.
        self = cls.__new__(cls)
        self.device = device
        self._pretrained = True
        self._build_kwargs = {}
        # Resolve via the Hub (cached) and load, recording the ref for repr.
        from .serialization import download_from_hub
        local = download_from_hub(ref, revision=revision)
        self._model = load_model(local, device=device)
        self._ckpt = ref
        self.size = self._model.size
        return self

    def push_to_hub(
        self,
        repo_id: str,
        *,
        filename: str = "model.pt",
        private: bool = True,
        commit_message: str = "Upload EoMT checkpoint",
    ) -> str:
        """Upload this model's weights to a Hugging Face Hub repo (created if absent).

        Writes a self-describing checkpoint (same format as :meth:`save`) and
        uploads it as ``filename``. Returns the repo URL. Requires a Hub token
        (``huggingface-cli login`` or ``HF_TOKEN``). Defaults to a **private** repo.

        Raises:
            ImportError: If ``huggingface_hub`` is not installed.
        """
        try:
            from huggingface_hub import HfApi
        except ImportError as e:  # pragma: no cover
            raise ImportError("push_to_hub needs 'huggingface_hub' (pip install huggingface_hub).") from e

        import tempfile

        api = HfApi()
        api.create_repo(repo_id, repo_type="model", private=private, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmp:
            local = Path(tmp) / filename
            # ``filename`` may contain sub-directories (e.g. "weights/model.pt");
            # make sure the staging path exists before writing the checkpoint.
            local.parent.mkdir(parents=True, exist_ok=True)
            self.save(local)
            api.upload_file(
                path_or_fileobj=str(local),
                path_in_repo=filename,
                repo_id=repo_id,
                repo_type="model",
                commit_message=commit_message,
            )
        return f"https://huggingface.co/{repo_id}"

    # ------------------------------------------------------------------ model
    @staticmethod
    def _resolve_device(device: str) -> str:
        """Map ``""`` / ``"auto"`` to ``"cuda"`` when available, else ``"cpu"``; pass others through."""
        if device in ("", "auto"):
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    @property
    def model(self):
        """The underlying trainable :class:`~eomt.model.EoMTModel` (built on demand)."""
        if self._model is None:
            self._model = build_model(self.size, **self._build_kwargs)
            if self._pretrained:
                load_dinov2_backbone(self._model)
            self._model = self._model.to(self._resolve_device(self.device)).eval()
        return self._model

    # ------------------------------------------------------------------ train
    def train(self, data: str | Path = "coco", *, resume: bool = False, **hp) -> dict:
        """Train on a COCO-format dataset.

        ``data`` is a dataset YAML path or a known alias (``"coco"`` auto-downloads).
        Extra keyword args (``epochs``, ``batch``, ``lr0``, ``aux_w``, …) are passed
        straight through to the training engine. For a checkpoint-initialized model,
        ``resume=True`` continues the original run; otherwise the checkpoint warm-starts
        a fresh run (fine-tune). Returns the engine's result dict and reloads the best
        weights into this object.

        Raises:
            ValueError: If the resolved dataset has no train split.
        """
        cfg = _resolve_data(data)
        if not (cfg["train_images"] and cfg["train_json"]):
            raise ValueError(f"dataset {data!r} has no train split (train_images/train_json).")

        if self._ckpt is not None:
            hp["resume" if resume else "init_weights"] = self._ckpt

        result = _train(
            train_images=cfg["train_images"],
            train_json=cfg["train_json"],
            val_images=cfg["val_images"],
            val_json=cfg["val_json"],
            size=self.size,
            device=self.device,
            **hp,
        )
        best = result.get("best") or result.get("last")
        if best:
            self._model = load_model(best, device=self.device)
            # Keep ``_ckpt`` a plain string, matching ``__init__`` and the annotation.
            self._ckpt = str(best)
        return result

    # -------------------------------------------------------------------- val
    def val(self, data: str | Path = "coco", *, batch: int = 4, workers: int = 4,
            conf_thres: float = 0.0, max_det: int = 100, letterbox: bool | None = None, **kw) -> dict:
        """Evaluate on a dataset's val split, returning COCO mAP metrics.

        Segmentation models report ``segm/*`` (+ ``bbox/*``); detection
        (``family="detect"``) models report ``bbox/*`` only. When ``letterbox`` is
        ``None`` the model's ``preprocess_letterbox`` attribute (default ``False``)
        decides the preprocessing.

        Raises:
            ValueError: If the resolved dataset has no val split.
        """
        cfg = _resolve_data(data)
        if not (cfg["val_images"] and cfg["val_json"]):
            raise ValueError(f"dataset {data!r} has no val split (val_images/val_json).")

        model = self.model
        dev = next(model.parameters()).device
        lb = letterbox if letterbox is not None else bool(getattr(model, "preprocess_letterbox", False))
        val_ds = CocoValImages(cfg["val_images"], cfg["val_json"], imgsz=int(model.image_size), letterbox=lb)
        eval_fn = _evaluate_detection if getattr(model, "family", "instance") == "detect" else _evaluate
        return eval_fn(
            model, val_ds, device=dev, batch_size=batch, num_workers=workers,
            conf_thres=conf_thres, max_det=max_det, **kw,
        )

    # ---------------------------------------------------------------- predict
    def predict(self, source: str | Path, *, plot: bool = False, save: str | None = "runs/predict",
                conf_thres: float = 0.3, max_det: int = 100, mask_thresh: float = 0.5, **kw) -> list[dict]:
        """Run inference on an image or a directory.

        Returns one result dict per image (``boxes`` / ``scores`` / ``classes`` /
        ``masks`` and, for models with secondary heads, ``aux``). With ``plot=True``
        each image is rendered with masks/boxes/labels and saved under ``save``.
        """
        return _predict(
            self.model, str(source), plot=plot, save=save,
            conf_thres=conf_thres, max_det=max_det, mask_thresh=mask_thresh, **kw,
        )

    # ------------------------------------------------------------------- save
    def save(self, path: str | Path) -> None:
        """Write a self-describing checkpoint (reloadable with ``EoMT(path)``)."""
        m = self.model
        ckpt = wrap_checkpoint(
            m.state_dict(),
            size=m.size, nc=m.nc, imgsz=m.image_size, names=m.names,
            task=getattr(m, "family", "instance"),
            aux_heads=m.aux_specs, aux_head_arch=m.aux_head_arch,
            letterbox=bool(getattr(m, "preprocess_letterbox", False)),
            loss_weights=m.loss_weights, num_upscale_blocks=m.num_upscale_blocks,
        )
        save_checkpoint(ckpt, path)

    def __repr__(self) -> str:
        """Show the checkpoint reference when there is one, otherwise the size key."""
        src = f"ckpt={self._ckpt!r}" if self._ckpt else f"size={self.size!r}"
        return f"EoMT({src})"
eomt/aux_cls.py ADDED
@@ -0,0 +1,225 @@
1
+ """Secondary per-instance classification heads ("attributes") for EoMT.
2
+
3
+ The primary task (instance segmentation over ``nc`` classes) is unchanged. Each
4
+ aux head predicts an extra attribute per *detected instance* — typology,
5
+ laterality, severity, … — read from the matched query's embedding.
6
+
7
+ The supervision reuses EoMT's own Hungarian matcher
8
+ (``model.eomt.criterion.matcher``) so each attribute is trained on the **same**
9
+ query→GT assignment the detection loss used. Several specs ⇒ several independent
10
+ heads, summed (optionally weighted) into one scalar.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import torch
16
+ import torch.nn.functional as F # noqa: N812
17
+
18
+
19
@torch.no_grad()
def match_queries(model, out: dict, geom_labels, class_labels):
    """Run the model's own matcher to get query→GT indices (list of ``(src, tgt)``).

    The predicted geometry passed to the matcher is ``out["pred_boxes"]`` when that
    key is present, otherwise ``out["masks_queries_logits"]``. ``geom_labels`` must
    be the matching kind of ground truth (normalized ``cxcywh`` boxes for the
    ``"detect"`` family, instance masks for the ``"instance"`` family). The matcher
    itself is ``model.eomt.criterion.matcher``.
    """
    geom_key = "pred_boxes" if "pred_boxes" in out else "masks_queries_logits"
    predicted_geometry = out[geom_key]
    matcher = model.eomt.criterion.matcher
    return matcher(predicted_geometry, out["class_queries_logits"], geom_labels, class_labels)
31
+
32
+
33
+ def _gather_matched(out: dict, indices):
34
+ """Return ``(matched_feats [N, hidden], [(b, tgt_idx)...])`` for matched queries."""
35
+ q = out["query_embed"] # [B, Q, hidden]
36
+ feats, batch_tgt = [], []
37
+ for b, (src, tgt) in enumerate(indices):
38
+ if src.numel() == 0:
39
+ continue
40
+ feats.append(q[b, src])
41
+ batch_tgt.append((b, tgt))
42
+ if not feats:
43
+ return None, batch_tgt
44
+ return torch.cat(feats), batch_tgt
45
+
46
+
47
@torch.no_grad()
def gate_indices(
    out: dict,
    indices,
    geom_labels,
    class_labels,
    *,
    iou_thr: float = 0.5,
    require_class: bool = False,
):
    """Keep only well-localized (and optionally correctly-classified) matched pairs.

    The Hungarian matcher assigns *every* GT a query, even one whose prediction barely
    overlaps it (common early in training). For attribute supervision we want queries
    that actually localize the instance, so we drop matched ``(src, tgt)`` pairs whose
    predicted↔GT IoU is ``< iou_thr`` (mask IoU for the instance family, box IoU for
    detect). With ``require_class`` we also drop pairs whose predicted primary class ≠
    the GT class. Returns the same ``[(src, tgt), ...]`` structure with each pair
    filtered (possibly to empty); the returned index tensors live on the device of
    ``out["class_queries_logits"]``.

    A no-op (returns ``indices`` unchanged) when ``iou_thr <= 0`` and not
    ``require_class`` — i.e. the pre-gate behaviour.
    """
    if iou_thr <= 0 and not require_class:
        return indices
    is_detect = "pred_boxes" in out
    cls = out["class_queries_logits"]  # [B, Q, C+1]
    dev = cls.device
    gated = []
    for b, (src, tgt) in enumerate(indices):
        src = src.to(dev)
        tgt = tgt.to(dev)
        if src.numel() == 0:
            gated.append((src, tgt))
            continue
        keep = torch.ones(src.numel(), dtype=torch.bool, device=dev)
        if iou_thr > 0:
            keep &= _localization_iou(out, b, src, tgt, geom_labels, is_detect, dev) >= iou_thr
        if require_class:
            pred_cls = cls[b, src].argmax(-1)
            # The labels may live on another device than the logits (e.g. CPU
            # targets with CUDA predictions). Index them with an index on *their*
            # device, then move only the selected entries.
            labels = class_labels[b]
            gt_cls = labels[tgt.to(labels.device)].to(dev)
            keep &= pred_cls == gt_cls
        gated.append((src[keep], tgt[keep]))
    return gated
90
+
91
+
92
+ def _localization_iou(out, b, src, tgt, geom_labels, is_detect, dev):
93
+ """Per-pair predicted↔GT IoU for the matched queries (box IoU or mask IoU)."""
94
+ if is_detect:
95
+ from .box_loss import box_cxcywh_to_xyxy, box_iou
96
+
97
+ pred = box_cxcywh_to_xyxy(out["pred_boxes"][b, src].to(dev)) # [n, 4]
98
+ gt = box_cxcywh_to_xyxy(geom_labels[b][tgt].to(dev).float()) # [n, 4]
99
+ return box_iou(pred, gt)[0].diagonal() # per-matched-pair IoU
100
+ masks = out["masks_queries_logits"] # [B, Q, h, w]
101
+ pm = masks[b, src].sigmoid() > 0.5 # [n, h, w]
102
+ gm = geom_labels[b][tgt].to(dev).float() # [n, H, W] in {0, 1}
103
+ if gm.shape[-2:] != pm.shape[-2:]:
104
+ gm = F.interpolate(gm.unsqueeze(1), size=pm.shape[-2:], mode="nearest").squeeze(1)
105
+ gm = gm > 0.5
106
+ inter = (pm & gm).flatten(1).sum(1).float()
107
+ union = (pm | gm).flatten(1).sum(1).float().clamp(min=1.0)
108
+ return inter / union
109
+
110
+
111
def aux_loss(
    model,
    out: dict,
    mask_labels,
    class_labels,
    aux_labels: dict[str, list[torch.Tensor]],
    weights: dict[str, float] | None = None,
    *,
    indices=None,
    class_weights: dict[str, torch.Tensor] | None = None,
    ignore_index: int = -100,
):
    """Weighted sum of per-attribute CE over matched queries.

    ``aux_labels`` maps head name → per-image label tensors (line-aligned with
    ``class_labels``). Labels equal to ``ignore_index`` (missing / out-of-vocab
    attributes) are skipped — they contribute no loss instead of being trained as
    class 0. ``weights`` scales each head; ``class_weights`` optionally re-weights
    classes within a head (imbalance). Pass ``indices`` to reuse a matching already
    computed this step (avoids re-running the Hungarian matcher); the indices may
    live on any device. Returns ``(total_loss, {name: per_head_loss})``;
    ``total_loss`` is always a tensor, a graph-connected zero when nothing
    contributes.
    """
    if indices is None:
        indices = match_queries(model, out, mask_labels, class_labels)
    matched, batch_tgt = _gather_matched(out, indices)
    if matched is None:  # no query matched in this batch — keep the graph alive
        zero = out["query_embed"].sum() * 0.0
        return zero, {name: zero.detach() for name in model.aux_heads}

    weights = weights or {}
    class_weights = class_weights or {}
    total: torch.Tensor | None = None
    per_head: dict[str, torch.Tensor] = {}
    for name, head in model.aux_heads.items():
        logits = head(matched)  # [N, ns]
        head_labels = aux_labels[name]
        # Index each label tensor with an index on its own device: the indices may
        # come from the matcher (CPU) or from an IoU gate (prediction device).
        gt = torch.cat(
            [head_labels[b][tgt.to(head_labels[b].device)] for (b, tgt) in batch_tgt]
        ).to(logits.device)
        cw = class_weights.get(name)
        if cw is not None:
            cw = cw.to(logits.device, logits.dtype)
        if (gt != ignore_index).any():
            loss = F.cross_entropy(logits, gt, weight=cw, ignore_index=ignore_index)
        else:  # every matched label ignored — keep the graph alive with a zero
            loss = logits.sum() * 0.0
        per_head[name] = loss
        w = float(weights.get(name, 1.0))
        total = w * loss if total is None else total + w * loss
    if total is None:  # model has no aux heads — mirror the no-match return shape
        total = matched.sum() * 0.0
    return total, per_head
158
+
159
+
160
@torch.no_grad()
def aux_accuracy(
    model,
    out: dict,
    mask_labels,
    class_labels,
    aux_labels: dict[str, list[torch.Tensor]],
    *,
    indices=None,
    ignore_index: int = -100,
) -> dict[str, tuple[int, int]]:
    """Top-1 ``{name: (correct, total)}`` on matched queries (the ``typ_acc`` analogue).

    Ignored labels (``ignore_index``) are excluded from both correct and total.
    Pass ``indices`` to reuse a matching already computed this step; the indices
    may live on any device. Every head reports ``(0, 0)`` when no query matched.
    """
    if indices is None:
        indices = match_queries(model, out, mask_labels, class_labels)
    matched, batch_tgt = _gather_matched(out, indices)
    if matched is None:
        return {name: (0, 0) for name in model.aux_heads}
    res: dict[str, tuple[int, int]] = {}
    for name, head in model.aux_heads.items():
        pred = head(matched).argmax(-1)
        head_labels = aux_labels[name]
        # Index each label tensor with an index on its own device (indices may be
        # CPU matcher output or device-resident gated indices).
        gt = torch.cat(
            [head_labels[b][tgt.to(head_labels[b].device)] for (b, tgt) in batch_tgt]
        ).to(pred.device)
        valid = gt != ignore_index
        res[name] = (int(((pred == gt) & valid).sum()), int(valid.sum()))
    return res
188
+
189
+
190
@torch.no_grad()
def aux_accuracy_by_primary(
    model,
    out: dict,
    mask_labels,
    class_labels,
    aux_labels: dict[str, list[torch.Tensor]],
    *,
    indices=None,
    ignore_index: int = -100,
) -> dict[str, dict[int, tuple[int, int]]]:
    """Per-head accuracy bucketed by GT **primary** class.

    Returns ``{head: {primary_class_id: (correct, total)}}`` over matched queries —
    a diagnostic for *which primary classes* the attribute is (in)accurate on. Pass
    ``indices`` (e.g. IoU-gated, but not class-gated, so weak primary classes still
    appear) to control the population; the indices may live on any device. Ignored
    labels are excluded, and every head maps to an empty dict when no query matched.
    """
    if indices is None:
        indices = match_queries(model, out, mask_labels, class_labels)
    matched, batch_tgt = _gather_matched(out, indices)
    res: dict[str, dict[int, tuple[int, int]]] = {name: {} for name in model.aux_heads}
    if matched is None:
        return res
    # Index each label tensor with an index on its own device: gated indices sit on
    # the prediction device while the labels may still be on the CPU.
    prim = torch.cat(
        [class_labels[b][tgt.to(class_labels[b].device)] for (b, tgt) in batch_tgt]
    ).to(matched.device)
    for name, head in model.aux_heads.items():
        pred = head(matched).argmax(-1)
        head_labels = aux_labels[name]
        gt = torch.cat(
            [head_labels[b][tgt.to(head_labels[b].device)] for (b, tgt) in batch_tgt]
        ).to(pred.device)
        valid = gt != ignore_index
        correct = (pred == gt) & valid
        d: dict[int, tuple[int, int]] = {}
        for c in prim[valid].unique().tolist():
            sel = (prim == c) & valid
            d[int(c)] = (int(correct[sel].sum()), int(sel.sum()))
        res[name] = d
    return res
eomt/box_loss.py ADDED
@@ -0,0 +1,196 @@
1
+ """Pure-PyTorch DETR-style detection loss: Hungarian matcher + L1/GIoU + class CE.
2
+
3
+ The detection counterpart to :mod:`eomt.loss` (the mask-classification stack). It
4
+ keeps the **same** query→GT assignment machinery — a Hungarian matcher exposed as
5
+ ``DetectionLoss.matcher`` with a ``(geometry, class_queries_logits, geom_labels,
6
+ class_labels)`` signature and the same list-of-``(src, tgt)`` return — so
7
+ :func:`eomt.aux_cls.match_queries` can drive attribute supervision identically for
8
+ the box head. The only difference from :mod:`eomt.loss` is the geometry term: per-query
9
+ mask BCE + dice is replaced by per-query box **L1 + GIoU** (DETR's box loss).
10
+
11
+ All boxes here are normalized ``cxcywh`` in ``[0, 1]`` (the box head emits sigmoid
12
+ ``cxcywh``; the detection dataset stores GT the same way), so no image size is needed.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import torch
18
+ import torch.nn.functional as F # noqa: N812
19
+ from scipy.optimize import linear_sum_assignment
20
+ from torch import Tensor, nn
21
+
22
+
23
def box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
    """Convert ``(..., 4)`` boxes from ``(cx, cy, w, h)`` to ``(x1, y1, x2, y2)``."""
    cx, cy, w, h = boxes.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)
27
+
28
+
29
+ def _box_area(boxes: Tensor) -> Tensor:
30
+ return (boxes[:, 2] - boxes[:, 0]).clamp(min=0) * (boxes[:, 3] - boxes[:, 1]).clamp(min=0)
31
+
32
+
33
def box_iou(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
    """Pairwise IoU and union for ``xyxy`` boxes ``boxes1 [N, 4]`` and ``boxes2 [M, 4]``.

    Returns ``(iou, union)``, both ``[N, M]``. The union is floored at ``1e-6`` in
    the division only, so degenerate pairs give an IoU of 0 rather than NaN; the
    returned ``union`` is the raw value.
    """
    area1 = _box_area(boxes1)
    area2 = _box_area(boxes2)

    # Intersection rectangle: the innermost corners of every (box1, box2) pair.
    top_left = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])
    overlap = (bottom_right - top_left).clamp(min=0)
    inter = overlap[..., 0] * overlap[..., 1]

    union = area1[:, None] + area2[None, :] - inter
    return inter / union.clamp(min=1e-6), union
44
+
45
+
46
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """Pairwise GIoU ``[N, M]`` between two sets of ``xyxy`` boxes.

    GIoU is the IoU minus the fraction of the smallest enclosing box that is not
    covered by the union (DETR's enclosing-box variant).
    """
    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each (box1, box2) pair.
    enclose_lt = torch.min(boxes1[:, None, :2], boxes2[None, :, :2])
    enclose_rb = torch.max(boxes1[:, None, 2:], boxes2[None, :, 2:])
    enclose_wh = (enclose_rb - enclose_lt).clamp(min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]

    return iou - (enclose_area - union) / enclose_area.clamp(min=1e-6)
54
+
55
+
56
class DetectionHungarianMatcher(nn.Module):
    """One-to-one query↔GT-box assignment minimizing a class + L1 + GIoU cost."""

    def __init__(self, cost_class: float = 1.0, cost_bbox: float = 1.0, cost_giou: float = 1.0):
        """Store the three cost weights.

        Raises:
            ValueError: If all three weights are zero.
        """
        super().__init__()
        if cost_class == 0 and cost_bbox == 0 and cost_giou == 0:
            raise ValueError("All costs can't be 0")
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    @torch.no_grad()
    def forward(
        self,
        pred_boxes: Tensor,
        class_queries_logits: Tensor,
        box_labels: list[Tensor],
        class_labels: list[Tensor],
    ) -> list[tuple[Tensor, Tensor]]:
        """Return, per image, the ``(query_indices, target_indices)`` of the optimal assignment.

        Both index tensors are ``int64``. Images whose ``box_labels`` entry is empty
        get a pair of empty index tensors.
        """
        assignments: list[tuple[Tensor, Tensor]] = []
        for i, query_boxes in enumerate(pred_boxes):
            # Cost math in fp32: cdist/GIoU are not implemented for fp16 on CUDA, and
            # the matcher is called both inside autocast (training) and outside it
            # (eval/aux), where the tensors stay half.
            tgt_boxes = box_labels[i].to(device=pred_boxes.device, dtype=torch.float32)
            tgt_cls = class_labels[i]
            if tgt_boxes.numel() == 0:
                assignments.append(
                    (torch.empty(0, dtype=torch.int64), torch.empty(0, dtype=torch.int64))
                )
                continue

            probs = class_queries_logits[i].float().softmax(-1)  # [Q, C+1]
            out_boxes = query_boxes.float()  # [Q, 4]

            cost_class = -probs[:, tgt_cls]  # [Q, num_tgt]
            cost_bbox = torch.cdist(out_boxes, tgt_boxes, p=1)  # [Q, num_tgt]
            cost_giou = -generalized_box_iou(
                box_cxcywh_to_xyxy(out_boxes), box_cxcywh_to_xyxy(tgt_boxes)
            )
            cost = (
                self.cost_bbox * cost_bbox
                + self.cost_class * cost_class
                + self.cost_giou * cost_giou
            )
            # Bound the costs and zero out NaNs so the solver gets finite input.
            cost = torch.nan_to_num(cost.clamp(min=-1e10, max=1e10), 0)

            rows, cols = linear_sum_assignment(cost.cpu())
            assignments.append(
                (torch.as_tensor(rows, dtype=torch.int64), torch.as_tensor(cols, dtype=torch.int64))
            )
        return assignments
113
+
114
+
115
class DetectionLoss(nn.Module):
    """DETR-style detection loss (class CE + box L1 + GIoU), matching EoMTLoss's API.

    ``weight_dict`` carries ``{"loss_cross_entropy", "loss_bbox", "loss_giou"}``; the
    matcher reuses ``class_weight`` (from the config) and the L1/GIoU weights from
    ``weight_dict``. ``forward`` returns the three *unweighted* loss terms — the caller
    (``EoMTEncoder.get_loss_dict``) applies ``weight_dict``, exactly as for the mask loss.
    """

    def __init__(self, config, weight_dict: dict[str, float]):
        """Read ``num_labels`` / ``no_object_weight`` / ``class_weight`` from ``config`` and build the matcher."""
        super().__init__()
        self.num_labels = config.num_labels
        self.weight_dict = weight_dict

        # Per-class CE weights: 1 for real classes, ``no_object_weight`` for the
        # trailing "no object" slot.
        self.eos_coef = config.no_object_weight
        ce_weights = torch.ones(self.num_labels + 1)
        ce_weights[-1] = self.eos_coef
        self.register_buffer("empty_weight", ce_weights)

        self.matcher = DetectionHungarianMatcher(
            cost_class=config.class_weight,
            cost_bbox=weight_dict["loss_bbox"],
            cost_giou=weight_dict["loss_giou"],
        )

    def _get_predictions_permutation_indices(self, indices):
        """Flatten per-image ``(src, tgt)`` pairs into ``(batch_indices, query_indices)``."""
        batch_parts = []
        query_parts = []
        for image_idx, (src, _) in enumerate(indices):
            batch_parts.append(torch.full_like(src, image_idx))
            query_parts.append(src)
        return torch.cat(batch_parts), torch.cat(query_parts)

    def loss_labels(self, class_queries_logits, class_labels, indices) -> dict[str, Tensor]:
        """Weighted cross-entropy over all queries; unmatched queries target the "no object" class."""
        logits = class_queries_logits
        batch_size, num_queries, _ = logits.shape
        flat_idx = self._get_predictions_permutation_indices(indices)
        matched_classes = torch.cat(
            [labels[tgt] for labels, (_, tgt) in zip(class_labels, indices)]
        ).to(logits.device)
        targets = torch.full(
            (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64,
            device=logits.device,
        )
        targets[flat_idx] = matched_classes
        # cross_entropy expects the class dimension second: [B, C+1, Q].
        loss_ce = F.cross_entropy(logits.transpose(1, 2), targets, weight=self.empty_weight)
        return {"loss_cross_entropy": loss_ce}

    def loss_boxes(self, pred_boxes, box_labels, indices, num_boxes) -> dict[str, Tensor]:
        """L1 and GIoU losses over matched (prediction, GT) box pairs, each divided by ``num_boxes``."""
        flat_idx = self._get_predictions_permutation_indices(indices)
        src_boxes = pred_boxes[flat_idx]  # [N, 4] cxcywh
        tgt_boxes = torch.cat(
            [labels[tgt] for labels, (_, tgt) in zip(box_labels, indices)]
        ).to(src_boxes)  # [N, 4] cxcywh
        if src_boxes.numel() == 0:  # no matched query this batch — keep graph alive
            zero = pred_boxes.sum() * 0.0
            return {"loss_bbox": zero, "loss_giou": zero}

        loss_bbox = F.l1_loss(src_boxes, tgt_boxes, reduction="sum") / num_boxes
        # Only the diagonal of the pairwise matrix pairs a prediction with its own GT.
        pairwise_giou = generalized_box_iou(
            box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(tgt_boxes)
        )
        loss_giou = (1 - pairwise_giou.diagonal()).sum() / num_boxes
        return {"loss_bbox": loss_bbox, "loss_giou": loss_giou}

    def get_num_boxes(self, class_labels, device) -> Tensor:
        """Total number of GT instances in the batch as a float tensor, floored at 1."""
        count = sum(len(labels) for labels in class_labels)
        return torch.as_tensor(count, dtype=torch.float, device=device).clamp(min=1)

    def forward(
        self,
        masks_queries_logits: Tensor,  # carries the predicted boxes here; name kept for symmetry with EoMTLoss
        class_queries_logits: Tensor,
        mask_labels: list[Tensor],  # actually box_labels (cxcywh) for the detect family
        class_labels: list[Tensor],
        auxiliary_predictions=None,  # accepted but not read by this loss
    ) -> dict[str, Tensor]:
        """Match queries to GT boxes and return ``loss_bbox``, ``loss_giou`` and ``loss_cross_entropy``."""
        pred_boxes, box_labels = masks_queries_logits, mask_labels
        indices = self.matcher(pred_boxes, class_queries_logits, box_labels, class_labels)
        num_boxes = self.get_num_boxes(class_labels, device=class_queries_logits.device)
        losses = dict(self.loss_boxes(pred_boxes, box_labels, indices, num_boxes))
        losses.update(self.loss_labels(class_queries_logits, class_labels, indices))
        return losses