oriented-det 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- export/__init__.py +9 -0
- export/ort_runtime.py +67 -0
- export/postprocess.py +151 -0
- export/scripts/__init__.py +6 -0
- export/scripts/build_faster_rcnn_savedmodel.py +104 -0
- export/scripts/export_onnx.py +210 -0
- export/scripts/onnx_to_savedmodel.py +35 -0
- export/scripts/predict_savedmodel.py +94 -0
- export/scripts/save_predictions_tf.py +447 -0
- export/scripts/to_tflite.py +32 -0
- export/tests/test_export_onnx_optional.py +82 -0
- export/tests/test_export_wrappers.py +105 -0
- export/tests/test_faster_rcnn_export_parity.py +201 -0
- export/tests/test_ort_runtime.py +41 -0
- export/tf_serving_model.py +96 -0
- export/val_dataset.py +116 -0
- export/wrappers.py +161 -0
- oriented_det/__init__.py +77 -0
- oriented_det/cli/__init__.py +92 -0
- oriented_det/cli/train.py +18 -0
- oriented_det/configs/_base_/augmentation.json +21 -0
- oriented_det/configs/_base_/datasets/dota_le90.json +21 -0
- oriented_det/configs/_base_/fp16.json +5 -0
- oriented_det/configs/_base_/models/oriented_rcnn_r50.json +26 -0
- oriented_det/configs/_base_/models/rotated_faster_rcnn_r50.json +53 -0
- oriented_det/configs/_base_/models/rotated_retinanet_r50.json +30 -0
- oriented_det/configs/_base_/preprocessing.json +8 -0
- oriented_det/configs/_base_/schedules/1x.json +33 -0
- oriented_det/configs/config.schema.json +404 -0
- oriented_det/configs/oriented_rcnn/dota_le90_1x.json +121 -0
- oriented_det/configs/oriented_rcnn/dota_le90_3x.json +11 -0
- oriented_det/configs/rotated_faster_rcnn/dota_le90_1x.json +119 -0
- oriented_det/configs/rotated_faster_rcnn/dota_le90_3x.json +9 -0
- oriented_det/configs/rotated_retinanet/dota_le90_1x.json +107 -0
- oriented_det/configs/rotated_retinanet/dota_le90_3x.json +30 -0
- oriented_det/data/__init__.py +89 -0
- oriented_det/data/airbus_playground.py +611 -0
- oriented_det/data/dota.py +741 -0
- oriented_det/data/dota_classes.py +26 -0
- oriented_det/data/evaluation.py +648 -0
- oriented_det/data/flips.py +115 -0
- oriented_det/data/preprocessing.py +335 -0
- oriented_det/data/tiling.py +399 -0
- oriented_det/data/transforms.py +377 -0
- oriented_det/geometry/__init__.py +8 -0
- oriented_det/geometry/poly.py +127 -0
- oriented_det/geometry/qbox.py +63 -0
- oriented_det/geometry/rbox.py +250 -0
- oriented_det/geometry/transforms.py +266 -0
- oriented_det/models/__init__.py +36 -0
- oriented_det/models/backbones/__init__.py +11 -0
- oriented_det/models/backbones/resnet_fpn.py +79 -0
- oriented_det/models/backbones/utils.py +81 -0
- oriented_det/models/bbox_coder.py +355 -0
- oriented_det/models/faster_rcnn_inference.py +494 -0
- oriented_det/models/horizontal_roi_coder.py +155 -0
- oriented_det/models/oriented_rcnn.py +1256 -0
- oriented_det/models/oriented_roi.py +1664 -0
- oriented_det/models/oriented_rpn.py +2104 -0
- oriented_det/models/rotated_retinanet.py +1030 -0
- oriented_det/models/utils.py +590 -0
- oriented_det/ops/__init__.py +58 -0
- oriented_det/ops/gpu_ops.py +1109 -0
- oriented_det/ops/iou.py +172 -0
- oriented_det/ops/kfiou.py +275 -0
- oriented_det/ops/nms.py +202 -0
- oriented_det/ops/probiou.py +165 -0
- oriented_det/ops/rotated_ops.py +122 -0
- oriented_det/ops/utils.py +257 -0
- oriented_det/pretrained/__init__.py +23 -0
- oriented_det/pretrained/hub.py +249 -0
- oriented_det/pretrained/manifest.json +46 -0
- oriented_det/runtime/__init__.py +29 -0
- oriented_det/runtime/checkpoint.py +274 -0
- oriented_det/runtime/collate.py +348 -0
- oriented_det/runtime/inference.py +1286 -0
- oriented_det/train/__init__.py +102 -0
- oriented_det/train/config.py +872 -0
- oriented_det/train/engine.py +2933 -0
- oriented_det/train/grouped_ce.py +139 -0
- oriented_det/train/piecewise_schedule.py +33 -0
- oriented_det/train/profiler.py +287 -0
- oriented_det/train/utils.py +999 -0
- oriented_det/utils/__init__.py +31 -0
- oriented_det/utils/config.py +376 -0
- oriented_det/utils/device.py +35 -0
- oriented_det/utils/logging.py +163 -0
- oriented_det/utils/progress.py +62 -0
- oriented_det/utils/viz.py +181 -0
- oriented_det-0.1.0.dist-info/METADATA +313 -0
- oriented_det-0.1.0.dist-info/RECORD +115 -0
- oriented_det-0.1.0.dist-info/WHEEL +5 -0
- oriented_det-0.1.0.dist-info/entry_points.txt +2 -0
- oriented_det-0.1.0.dist-info/licenses/LICENSE +202 -0
- oriented_det-0.1.0.dist-info/top_level.txt +3 -0
- tools/__init__.py +1 -0
- tools/app.py +1284 -0
- tools/dataset_stats.py +389 -0
- tools/dota_labels_to_comma.py +132 -0
- tools/free_gpu.py +142 -0
- tools/generate_airbus_playground_csv.py +154 -0
- tools/image_demo.py +200 -0
- tools/lr_finder.py +771 -0
- tools/measure_sampled_riou_error.py +483 -0
- tools/playground_to_dota.py +290 -0
- tools/pretrained_download.py +49 -0
- tools/preview_augmentation.py +510 -0
- tools/publish_checkpoint.py +96 -0
- tools/save_predictions.py +2350 -0
- tools/sync_vendored_configs.py +94 -0
- tools/tile_dota.py +447 -0
- tools/train.py +2324 -0
- tools/train_example.py +244 -0
- tools/train_multi_gpu.py +367 -0
- tools/visualize_boxes.py +183 -0
export/__init__.py
ADDED
export/ort_runtime.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""ONNX Runtime device selection and session cache for TF export inference."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Dict, List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
# Process-wide device override set via set_ort_device(); None means "not set".
_ORT_DEVICE_OVERRIDE: Optional[str] = None
# ORT sessions keyed by (model path, provider tuple); see get_ort_session().
_SESSION_CACHE: Dict[Tuple[str, Tuple[str, ...]], object] = {}


def set_ort_device(device: Optional[str]) -> None:
    """Set the runtime ORT device for this process (``cpu``, ``cuda``, ``auto``).

    The value is stripped and lower-cased. ``None``, ``""`` or a whitespace-only
    string clears the override, so :func:`get_ort_device` falls back to the
    environment / default.
    """
    global _ORT_DEVICE_OVERRIDE
    _ORT_DEVICE_OVERRIDE = (device or "").strip().lower() or None


def get_ort_device() -> str:
    """Return the resolved ORT device string.

    Resolution order:
    1. the override set by :func:`set_ort_device`;
    2. the ``ORIENTED_DET_ORT_DEVICE`` environment variable;
    3. ``"cpu"``.

    An empty or whitespace-only environment value counts as unset.
    """
    if _ORT_DEVICE_OVERRIDE:
        return _ORT_DEVICE_OVERRIDE
    env_value = (os.environ.get("ORIENTED_DET_ORT_DEVICE") or "").strip().lower()
    # Strip before defaulting: "   " must not resolve to "" (an invalid device).
    return env_value or "cpu"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def ort_providers_for_device(device: Optional[str] = None) -> List[str]:
    """Map a device string to an ONNX Runtime ``providers`` list.

    Args:
        device: ``"cpu"``, ``"cuda"`` (alias ``"gpu"``) or ``"auto"``. Case and
            surrounding whitespace are ignored. ``None`` uses :func:`get_ort_device`.

    Returns:
        ``["CUDAExecutionProvider", "CPUExecutionProvider"]`` for ``cuda``, and for
        ``auto`` when CUDA is available. Otherwise ``["CPUExecutionProvider"]``.

    Raises:
        ValueError: The device string is not recognized. This is checked before
            onnxruntime is imported.
        RuntimeError: ``cuda`` was requested but the installed onnxruntime has no
            ``CUDAExecutionProvider``.
    """
    d = (device if device is not None else get_ort_device()).lower().strip()
    if d not in ("cpu", "cuda", "gpu", "auto"):
        # Report the normalized value. When ``device`` is None the bad value came
        # from the override/environment, so ``device`` itself would print as None.
        raise ValueError(f"Unknown ORT device {d!r}; use cpu, cuda, or auto.")

    import onnxruntime as ort

    if d == "cpu":
        return ["CPUExecutionProvider"]

    if "CUDAExecutionProvider" in ort.get_available_providers():
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if d == "auto":
        return ["CPUExecutionProvider"]
    raise RuntimeError(
        "ORT device=cuda requested but CUDAExecutionProvider is not available. "
        "Install onnxruntime-gpu matching your CUDA driver."
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def configure_ort_device(device: Optional[str]) -> List[str]:
    """Optionally install an ORT device override, then resolve the providers.

    Passing ``None`` keeps whatever override or environment setting is already
    active. Returns the :func:`ort_providers_for_device` result for the device
    that is current after the (optional) update.
    """
    if device is None:
        return ort_providers_for_device()
    set_ort_device(device)
    return ort_providers_for_device()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def clear_ort_session_cache() -> None:
    """Forget every cached ORT session (used by tests).

    The next :func:`get_ort_session` call for any model builds a fresh session.
    """
    cache = _SESSION_CACHE
    cache.clear()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_ort_session(onnx_path: str, device: Optional[str] = None):
    """Return a cached ``onnxruntime.InferenceSession`` for ``onnx_path``.

    Sessions are keyed on the absolute model path plus the resolved provider
    list. Relative/absolute spellings of the same file, and ``pathlib.Path``
    arguments, therefore share one session. Switching device builds a new one.

    Args:
        onnx_path: Path to the ``.onnx`` model (``str`` or path-like).
        device: Optional device string; ``None`` uses :func:`get_ort_device`.

    Raises:
        ValueError / RuntimeError: Propagated from :func:`ort_providers_for_device`.
    """
    import onnxruntime as ort

    providers = ort_providers_for_device(device)
    # Use one normalized string for both the cache key and the session. Some
    # onnxruntime versions only accept ``str`` (not ``pathlib.Path``) here.
    path_str = os.path.abspath(os.fspath(onnx_path))
    key = (path_str, tuple(providers))
    session = _SESSION_CACHE.get(key)
    if session is None:
        session = ort.InferenceSession(path_str, providers=providers)
        _SESSION_CACHE[key] = session
    return session
|
export/postprocess.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Post-NMS detection finalization for TF export (exact CPU rotated NMS + score filter)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import torch
|
|
11
|
+
except Exception: # pragma: no cover
|
|
12
|
+
torch = None # type: ignore
|
|
13
|
+
|
|
14
|
+
from export.ort_runtime import get_ort_device, get_ort_session
|
|
15
|
+
from oriented_det.models.faster_rcnn_inference import PreNmsDetections, apply_final_rotated_nms
|
|
16
|
+
from oriented_det.train.utils import effective_score_threshold_for_class_name
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class _NmsConfigView:
|
|
20
|
+
"""Minimal attribute bag for :func:`apply_final_rotated_nms`."""
|
|
21
|
+
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
*,
|
|
25
|
+
nms_class_agnostic: bool,
|
|
26
|
+
final_nms_iou_threshold: float,
|
|
27
|
+
max_detections_per_image: Optional[int],
|
|
28
|
+
final_nms_use_cpu: bool,
|
|
29
|
+
) -> None:
|
|
30
|
+
self.nms_class_agnostic = nms_class_agnostic
|
|
31
|
+
self.final_nms_iou_threshold = final_nms_iou_threshold
|
|
32
|
+
self.max_detections_per_image = max_detections_per_image
|
|
33
|
+
self.final_nms_use_cpu = final_nms_use_cpu
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _score_keep_mask(
    scores: "torch.Tensor",
    labels: "torch.Tensor",
    *,
    score_threshold: float,
    per_class_score_threshold: Optional[Dict[str, float]],
    class_id_to_name: Dict[int, str],
) -> "torch.Tensor":
    """Return a bool mask, on ``scores.device``, of detections meeting their class threshold.

    Thresholds come from ``effective_score_threshold_for_class_name``, looked up
    once per distinct label (assumed to be a pure function of its arguments).
    Labels missing from ``class_id_to_name`` use the name ``class_{id}``.
    """
    labels_cpu = labels.detach().to("cpu", torch.int64)
    unique_ids, inverse = torch.unique(labels_cpu, return_inverse=True)
    per_label = []
    for lid in unique_ids.tolist():
        cname = class_id_to_name.get(lid, f"class_{lid}")
        per_label.append(
            float(
                effective_score_threshold_for_class_name(
                    cname, score_threshold, per_class_score_threshold
                )
            )
        )
    thresholds = torch.tensor(per_label, dtype=torch.float64)[inverse]
    # Compare in float64 so results match a Python ``float(score) >= thr`` check.
    keep = scores.detach().to("cpu", torch.float64) >= thresholds
    return keep.to(scores.device)


def finalize_detections_numpy(
    pre_nms_boxes: np.ndarray,
    pre_nms_scores: np.ndarray,
    pre_nms_labels: np.ndarray,
    pre_nms_count: int,
    *,
    nms_class_agnostic: bool,
    final_nms_iou_threshold: float,
    max_detections_per_image: Optional[int],
    final_nms_use_cpu: bool,
    score_threshold: float,
    per_class_score_threshold: Optional[Dict[str, float]],
    class_id_to_name: Dict[int, str],
    max_output_slots: int,
) -> Tuple[np.ndarray, int]:
    """Run rotated NMS and the production score filter; return padded detections.

    Args:
        pre_nms_boxes, pre_nms_scores, pre_nms_labels: Candidate arrays. Only the
            first ``pre_nms_count`` rows are used. Boxes are presumably
            ``[N, 5]`` (cx, cy, w, h, angle); confirm against the exporter.
        pre_nms_count: Number of valid candidate rows; ``<= 0`` yields no detections.
        nms_class_agnostic, final_nms_iou_threshold, max_detections_per_image,
            final_nms_use_cpu: Forwarded to :func:`apply_final_rotated_nms`.
        score_threshold, per_class_score_threshold: Resolved per class via
            ``effective_score_threshold_for_class_name``. Kept detections must
            have ``score >= threshold``.
        class_id_to_name: Label id -> class name, used for per-class thresholds.
        max_output_slots: Number of rows in the returned array.

    Returns:
        ``(detections, m)``. ``detections`` is a zero-padded float32
        ``[max_output_slots, 7]`` array whose first ``m`` rows hold
        ``box..., score, label``. ``m`` is capped at ``max_output_slots``.

    Raises:
        RuntimeError: torch is not installed.
    """
    if torch is None:
        raise RuntimeError("torch is required for finalize_detections_numpy")

    out = np.zeros((max_output_slots, 7), dtype=np.float32)
    n = int(pre_nms_count)
    if n <= 0:
        return out, 0

    boxes_t = torch.from_numpy(np.asarray(pre_nms_boxes[:n], dtype=np.float32))
    scores_t = torch.from_numpy(np.asarray(pre_nms_scores[:n], dtype=np.float32))
    labels_t = torch.from_numpy(np.asarray(pre_nms_labels[:n], dtype=np.int64))

    nms_view = _NmsConfigView(
        nms_class_agnostic=nms_class_agnostic,
        final_nms_iou_threshold=final_nms_iou_threshold,
        max_detections_per_image=max_detections_per_image,
        final_nms_use_cpu=final_nms_use_cpu,
    )
    final = apply_final_rotated_nms(nms_view, PreNmsDetections(boxes_t, scores_t, labels_t))

    if final.boxes.numel() == 0:
        return out, 0

    final_boxes, final_scores, final_labels = final.boxes, final.scores, final.labels
    if per_class_score_threshold or score_threshold is not None:
        keep_mask = _score_keep_mask(
            final.scores,
            final.labels,
            score_threshold=score_threshold,
            per_class_score_threshold=per_class_score_threshold,
            class_id_to_name=class_id_to_name,
        )
        # The mask must live on each tensor's device (NMS may run off-CPU).
        final_boxes = final_boxes[keep_mask.to(final_boxes.device)]
        final_scores = final_scores[keep_mask.to(final_scores.device)]
        final_labels = final_labels[keep_mask.to(final_labels.device)]

    m = min(int(final_boxes.shape[0]), max_output_slots)
    if m == 0:
        return out, 0

    det = np.column_stack(
        [
            final_boxes[:m].cpu().numpy(),
            final_scores[:m].cpu().numpy().reshape(-1, 1),
            final_labels[:m].cpu().numpy().astype(np.float32).reshape(-1, 1),
        ]
    )
    out[:m] = det
    return out, m
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def build_class_id_to_name(class_names: List[str]) -> Dict[int, str]:
    """Number ``class_names`` as 1-based foreground label ids.

    Id 0 is left unmapped.
    """
    return dict(enumerate(class_names, start=1))
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def ort_pre_nms_to_detections(
    images: "np.ndarray",
    onnx_path: str,
    ort_output_names: List[str],
    finalize_kwargs: Dict[str, Any],
) -> Tuple["np.ndarray", int]:
    """Run the pre-NMS ONNX graph with ONNX Runtime, then finalize detections.

    Intended for TF ``numpy_function`` / the Keras detect bundle.

    Args:
        images: Image array, cast to float32. A 3-D array gets a leading batch
            axis (presumably CHW -> 1CHW).
        onnx_path: Model path, passed to :func:`get_ort_session`.
        ort_output_names: Graph outputs to fetch. They must cover
            ``pre_nms_boxes``, ``pre_nms_scores``, ``pre_nms_labels`` and
            ``pre_nms_count``. When empty, all session outputs are fetched
            under their graph names.
        finalize_kwargs: Keyword arguments for :func:`finalize_detections_numpy`.

    Returns:
        ``(detections, num_detections)`` from :func:`finalize_detections_numpy`.

    Raises:
        KeyError: A required ``pre_nms_*`` output was not produced.
    """
    img = np.asarray(images, dtype=np.float32)
    if img.ndim == 3:
        img = img[np.newaxis, ...]

    sess = get_ort_session(onnx_path)
    # The meta may lack "output_names" (callers pass []). Fall back to the
    # session's own names so the lookups below still find the outputs.
    output_names = list(ort_output_names) or [o.name for o in sess.get_outputs()]
    input_name = sess.get_inputs()[0].name
    outs = sess.run(output_names, {input_name: img})
    name_to_val = dict(zip(output_names, outs))

    required = ("pre_nms_boxes", "pre_nms_scores", "pre_nms_labels", "pre_nms_count")
    missing = [name for name in required if name not in name_to_val]
    if missing:
        raise KeyError(f"ONNX outputs missing {missing}; available: {sorted(name_to_val)}")

    detections, num = finalize_detections_numpy(
        name_to_val["pre_nms_boxes"],
        name_to_val["pre_nms_scores"],
        name_to_val["pre_nms_labels"],
        int(np.asarray(name_to_val["pre_nms_count"]).reshape(-1)[0]),
        **finalize_kwargs,
    )
    return detections, int(num)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def meta_to_finalize_kwargs(meta: Dict[str, Any]) -> Dict[str, Any]:
    """Build :func:`finalize_detections_numpy` keyword arguments from export meta JSON.

    Settings are read from ``meta["production"]``. A missing key *or* an explicit
    ``null`` falls back to these defaults:

    * ``nms_class_agnostic``: False
    * ``final_nms_iou_threshold``: 0.1
    * ``final_nms_use_cpu``: True
    * ``score_threshold``: 0.05

    ``max_detections_per_image`` (also used as ``max_output_slots``) is taken
    from production, then top-level meta, then defaults to 3000.
    """
    prod: Dict[str, Any] = meta.get("production") or {}

    def _prod(key: str, default: Any) -> Any:
        # Treat ``"key": null`` like an absent key, so float()/bool() never
        # see None and the intended default still applies.
        value = prod.get(key)
        return default if value is None else value

    class_names: List[str] = list(meta.get("class_names") or [])
    max_det = int(prod.get("max_detections_per_image") or meta.get("max_detections_per_image") or 3000)
    return {
        "nms_class_agnostic": bool(_prod("nms_class_agnostic", False)),
        "final_nms_iou_threshold": float(_prod("final_nms_iou_threshold", 0.1)),
        "max_detections_per_image": max_det,
        "final_nms_use_cpu": bool(_prod("final_nms_use_cpu", True)),
        "score_threshold": float(_prod("score_threshold", 0.05)),
        "per_class_score_threshold": prod.get("per_class_score_threshold"),
        "class_id_to_name": build_class_id_to_name(class_names),
        "max_output_slots": max_det,
    }
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Build e2e Faster R-CNN SavedModel (Keras + ORT core + exact rotated NMS)."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import json
|
|
8
|
+
import sys
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
# Repository root: export/scripts/<this file> -> export/scripts -> export -> root.
# Prepend it to sys.path so ``export`` and ``oriented_det`` import when this
# file is run directly as a script.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(_REPO_ROOT)]
|
|
15
|
+
|
|
16
|
+
from export.postprocess import meta_to_finalize_kwargs # noqa: E402
|
|
17
|
+
from export.tf_serving_model import FasterRCNNDetectLayer, save_keras_detect_bundle # noqa: E402
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _load_meta(meta_path: Path) -> dict:
|
|
21
|
+
return json.loads(meta_path.read_text(encoding="utf-8"))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _resolve_onnx_path(meta: dict, meta_path: Path, onnx_path: Optional[Path]) -> Path:
|
|
25
|
+
if onnx_path is not None and onnx_path.is_file():
|
|
26
|
+
return onnx_path
|
|
27
|
+
if meta.get("onnx_path"):
|
|
28
|
+
p = Path(meta["onnx_path"])
|
|
29
|
+
if p.is_file():
|
|
30
|
+
return p
|
|
31
|
+
candidate = meta_path.with_suffix("").with_suffix(".onnx")
|
|
32
|
+
if candidate.is_file():
|
|
33
|
+
return candidate
|
|
34
|
+
raise FileNotFoundError(
|
|
35
|
+
"ONNX file not found; pass --onnx or ensure meta onnx_path / sibling .onnx exists."
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def build_savedmodel(
    output_path: Path,
    meta_path: Path,
    *,
    onnx_path: Optional[Path] = None,
) -> None:
    """Write the end-to-end Faster R-CNN detect bundle (Keras model + export meta).

    The Keras graph is a single ``FasterRCNNDetectLayer`` fed a fixed-size
    ``[1, 3, H, W]`` ``images`` input. H and W come from ``meta["input"]["shape"]``.
    The layer receives the ONNX model path, the ONNX output names, and the
    finalize settings derived from ``meta`` by ``meta_to_finalize_kwargs``.

    Args:
        output_path: Bundle output directory.
        meta_path: ``*.export_meta.json`` written by export_onnx.py.
        onnx_path: Optional explicit ONNX file. Otherwise it is taken from the
            meta or the sibling ``.onnx``.

    Raises:
        ValueError: The meta records a mode other than ``faster_rcnn_pre_nms``.
        FileNotFoundError: The ONNX model cannot be located.
    """
    import tensorflow as tf

    meta = _load_meta(meta_path)
    mode = meta.get("mode")
    if mode is not None and mode != "faster_rcnn_pre_nms":
        # Other export modes lack the pre_nms_* outputs this bundle consumes.
        raise ValueError(
            f"{meta_path} was exported with mode={mode!r}; the detect bundle "
            "requires an ONNX exported with --mode faster_rcnn_pre_nms."
        )
    finalize_kwargs = meta_to_finalize_kwargs(meta)
    # Resolve once so the layer and the written meta reference the same file.
    onnx_file = _resolve_onnx_path(meta, meta_path, onnx_path).resolve()

    h = int(meta["input"]["shape"][2])
    w = int(meta["input"]["shape"][3])
    inputs = tf.keras.Input(shape=(3, h, w), batch_size=1, name="images")
    layer = FasterRCNNDetectLayer(
        onnx_path=str(onnx_file),
        ort_output_names=list(meta.get("output_names") or []),
        finalize_kwargs=finalize_kwargs,
        max_output_slots=int(finalize_kwargs["max_output_slots"]),
    )
    layer_out = layer(inputs)
    model = tf.keras.Model(
        inputs=inputs,
        outputs=[layer_out["detections"], layer_out["num_detections"]],
    )

    full_meta = dict(meta)
    full_meta["core_backend"] = "onnxruntime_keras"
    full_meta["onnx_path"] = str(onnx_file)
    full_meta["savedmodel_outputs"] = {
        "detections": {
            "shape": [finalize_kwargs["max_output_slots"], 7],
            "layout": "cx,cy,w,h,angle,score,label",
        },
        "num_detections": {"dtype": "int32"},
    }

    keras_path = save_keras_detect_bundle(model, output_path, full_meta)
    print(f"Wrote detect bundle: {output_path} (keras: {keras_path.name}, core: onnxruntime)")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def main() -> None:
    """CLI: build the detect bundle from an export_onnx.py ``*.export_meta.json``.

    Exits with status 1 and an install hint when TensorFlow is missing.
    """
    parser = argparse.ArgumentParser(description="Build e2e Faster R-CNN SavedModel.")
    parser.add_argument(
        "--tf-core",
        type=Path,
        default=None,
        help="Ignored (kept for CLI compatibility with README pipeline).",
    )
    parser.add_argument("--meta", type=Path, required=True, help="*.export_meta.json from export_onnx.py.")
    parser.add_argument("--onnx", type=Path, default=None, help="ONNX file (default: from meta / sibling).")
    parser.add_argument("--output", type=Path, required=True, help="Output SavedModel directory.")
    args = parser.parse_args()

    # Probe for TensorFlow up front so users get an install hint, not a traceback.
    try:
        import tensorflow  # noqa: F401
    except ImportError as exc:
        print("Install TensorFlow: pip install -r export/requirements-export.txt", file=sys.stderr)
        raise SystemExit(1) from exc

    build_savedmodel(args.output, args.meta, onnx_path=args.onnx)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Export a tensor-only subgraph from oriented-det to ONNX.
|
|
3
|
+
|
|
4
|
+
See export/README.md and export/contract.json for supported modes.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import argparse
|
|
10
|
+
import json
|
|
11
|
+
import sys
|
|
12
|
+
from dataclasses import asdict
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
_REPO_ROOT = Path(__file__).resolve().parents[2]
_EXPORT_DIR = _REPO_ROOT / "export"
# The repo root makes ``oriented_det`` importable. The export dir makes the
# sibling ``wrappers`` module importable as a top-level name. Each entry is
# inserted at the front in this order, so _EXPORT_DIR ends up ahead of _REPO_ROOT.
for _path_entry in (str(_REPO_ROOT), str(_EXPORT_DIR)):
    if _path_entry not in sys.path:
        sys.path.insert(0, _path_entry)
del _path_entry
|
|
24
|
+
|
|
25
|
+
import wrappers as _wrappers # noqa: E402
|
|
26
|
+
from oriented_det.models.oriented_rcnn import RotatedFasterRCNN # noqa: E402
|
|
27
|
+
from oriented_det.models.rotated_retinanet import RotatedRetinaNet # noqa: E402
|
|
28
|
+
from oriented_det.runtime.checkpoint import load_model_from_checkpoint # noqa: E402
|
|
29
|
+
|
|
30
|
+
# Re-export the wrapper classes from export/wrappers.py, which is imported above
# as the top-level module ``wrappers``.
(
    BackboneExportWrapper,
    RetinaNetBackboneHeadExportWrapper,
    RotatedFasterRCNNPreNmsExportWrapper,
) = (
    _wrappers.BackboneExportWrapper,
    _wrappers.RetinaNetBackboneHeadExportWrapper,
    _wrappers.RotatedFasterRCNNPreNmsExportWrapper,
)
|
|
33
|
+
|
|
34
|
+
# ONNX output names for --mode faster_rcnn_pre_nms. main() passes them as
# ``output_names`` to torch.onnx.export, which assigns them positionally to the
# wrapper's outputs, so their order must follow the wrapper's return order.
FASTER_RCNN_PRE_NMS_OUTPUTS = ("pre_nms_boxes", "pre_nms_scores", "pre_nms_labels", "pre_nms_count")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _build_wrapper(
|
|
43
|
+
model: torch.nn.Module,
|
|
44
|
+
mode: str,
|
|
45
|
+
height: int,
|
|
46
|
+
width: int,
|
|
47
|
+
) -> torch.nn.Module:
|
|
48
|
+
mt = type(model).__name__
|
|
49
|
+
if mode == "backbone":
|
|
50
|
+
if not hasattr(model, "backbone"):
|
|
51
|
+
raise ValueError(f"Model {mt} has no 'backbone' attribute.")
|
|
52
|
+
return BackboneExportWrapper(model.backbone)
|
|
53
|
+
if mode == "retinanet_heads":
|
|
54
|
+
if not isinstance(model, RotatedRetinaNet):
|
|
55
|
+
raise ValueError(
|
|
56
|
+
f"retinanet_heads mode requires RotatedRetinaNet, got {mt}. "
|
|
57
|
+
"Use --mode backbone for two-stage models."
|
|
58
|
+
)
|
|
59
|
+
return RetinaNetBackboneHeadExportWrapper(model)
|
|
60
|
+
if mode == "faster_rcnn_pre_nms":
|
|
61
|
+
if not isinstance(model, RotatedFasterRCNN):
|
|
62
|
+
raise ValueError(
|
|
63
|
+
f"faster_rcnn_pre_nms requires RotatedFasterRCNN, got {mt}."
|
|
64
|
+
)
|
|
65
|
+
return RotatedFasterRCNNPreNmsExportWrapper(model, height=height, width=width)
|
|
66
|
+
raise ValueError(f"Unknown mode: {mode}")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _output_names_for_wrapper(
|
|
70
|
+
wrapper: torch.nn.Module,
|
|
71
|
+
mode: str,
|
|
72
|
+
device: torch.device,
|
|
73
|
+
height: int,
|
|
74
|
+
width: int,
|
|
75
|
+
) -> list[str]:
|
|
76
|
+
if mode == "faster_rcnn_pre_nms":
|
|
77
|
+
return list(FASTER_RCNN_PRE_NMS_OUTPUTS)
|
|
78
|
+
|
|
79
|
+
dummy = torch.zeros(1, 3, 128, 128, dtype=torch.float32, device=device)
|
|
80
|
+
with torch.no_grad():
|
|
81
|
+
out = wrapper(dummy)
|
|
82
|
+
names: list[str] = []
|
|
83
|
+
if mode == "backbone":
|
|
84
|
+
for i in range(len(out)):
|
|
85
|
+
names.append(f"fpn_level_{i}")
|
|
86
|
+
return names
|
|
87
|
+
if mode == "retinanet_heads":
|
|
88
|
+
for i in range(len(out) // 2):
|
|
89
|
+
names.append(f"level{i}_cls_logits")
|
|
90
|
+
names.append(f"level{i}_bbox_pred")
|
|
91
|
+
return names
|
|
92
|
+
return [f"out_{i}" for i in range(len(out))]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _production_dict(config: Any) -> dict | None:
|
|
96
|
+
prod = getattr(config, "production", None)
|
|
97
|
+
if prod is None:
|
|
98
|
+
return None
|
|
99
|
+
if hasattr(prod, "__dataclass_fields__"):
|
|
100
|
+
return asdict(prod)
|
|
101
|
+
if isinstance(prod, dict):
|
|
102
|
+
return prod
|
|
103
|
+
return None
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _validate_onnx_ort(onnx_path: Path, dummy: torch.Tensor, output_names: list[str]) -> None:
    """Validate the exported ONNX model and run one CPU ONNX Runtime smoke inference.

    Prints a skip message and returns when ``onnx`` / ``onnxruntime`` are not
    installed. Otherwise it runs ``onnx.checker.check_model`` and feeds ``dummy``
    as ``images``. It then prints each requested output's shape and dtype.
    Checker and runtime errors propagate.
    """
    try:
        import onnx
        import onnxruntime as ort
    except ImportError:
        print("Skipping ORT validation (install onnx and onnxruntime).")
        return

    # Check by path rather than a loaded ModelProto. The checker can then handle
    # models stored with external data (> 2 GB protobuf limit), and the model
    # is not deserialized an extra time.
    onnx.checker.check_model(str(onnx_path))
    sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    feeds = {"images": dummy.detach().cpu().numpy()}
    outs = sess.run(output_names, feeds)
    print(f"ORT smoke OK: {len(outs)} outputs")
    for name, arr in zip(output_names, outs):
        print(f" {name}: shape={arr.shape} dtype={arr.dtype}")
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def main() -> None:
    """CLI: export a checkpoint subgraph to ONNX and write ``*.export_meta.json``.

    Steps:
    1. Load the model from ``--checkpoint`` / ``--config``.
    2. Wrap it for ``--mode`` and trace it with ``torch.onnx.export`` at
       ``[1, 3, height, width]``.
    3. Write the meta JSON next to the ONNX file.
    4. Unless ``--skip-ort`` is given, smoke-test the result with ONNX Runtime.
    """
    parser = argparse.ArgumentParser(description="Export oriented-det subgraph to ONNX.")
    parser.add_argument("--config", type=Path, required=True, help="Training JSON config path.")
    parser.add_argument("--checkpoint", type=Path, required=True, help="Weights .pth path.")
    parser.add_argument("--output", type=Path, required=True, help="Output .onnx file path.")
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument(
        "--mode",
        choices=("backbone", "retinanet_heads", "faster_rcnn_pre_nms"),
        default="backbone",
        help="backbone | retinanet_heads | faster_rcnn_pre_nms (Rotated Faster R-CNN detect).",
    )
    parser.add_argument("--opset", type=int, default=17)
    parser.add_argument(
        "--dynamic-batch",
        action="store_true",
        help="Allow dynamic batch size (N) on input images; H and W stay fixed.",
    )
    parser.add_argument(
        "--skip-ort",
        action="store_true",
        help="Skip onnxruntime validation after export.",
    )
    parser.add_argument("--device", default="cpu", help="Device to trace on (cpu recommended for reproducibility).")
    args = parser.parse_args()

    model, config, class_names = load_model_from_checkpoint(
        str(args.checkpoint),
        str(args.config),
        device=args.device,
    )
    height, width = int(args.height), int(args.width)
    wrapper = _build_wrapper(model, args.mode, height, width)
    wrapper.eval()
    wrapper.to(args.device)

    device = torch.device(args.device)
    out_names = _output_names_for_wrapper(wrapper, args.mode, device, height, width)
    dummy = torch.zeros(1, 3, height, width, dtype=torch.float32, device=device)

    # With --dynamic-batch, axis 0 of the input and of every output is symbolic.
    dynamic_axes = (
        {name: {0: "batch"} for name in ["images", *out_names]}
        if args.dynamic_batch
        else None
    )

    args.output.parent.mkdir(parents=True, exist_ok=True)
    torch.onnx.export(
        wrapper,
        dummy,
        str(args.output),
        input_names=["images"],
        output_names=out_names,
        dynamic_axes=dynamic_axes,
        opset_version=int(args.opset),
        do_constant_folding=True,
    )

    meta: dict[str, Any] = {
        "mode": args.mode,
        "input": {"name": "images", "shape": [1, 3, height, width], "dtype": "float32"},
        "output_names": out_names,
        "opset": args.opset,
        "dynamic_batch": args.dynamic_batch,
        "config": str(args.config),
        "checkpoint": str(args.checkpoint),
        "class_names": class_names or [],
        "num_classes": int(getattr(config, "num_classes", 0) or 0),
        "production": _production_dict(config),
    }
    if args.mode == "faster_rcnn_pre_nms" and isinstance(wrapper, RotatedFasterRCNNPreNmsExportWrapper):
        meta["max_pre_nms_candidates"] = wrapper.max_candidates
        meta["max_detections_per_image"] = (meta.get("production") or {}).get("max_detections_per_image")
    meta["onnx_path"] = str(args.output.resolve())

    meta_path = args.output.with_suffix(".export_meta.json")
    meta_path.write_text(json.dumps(meta, indent=2), encoding="utf-8")
    print(f"Wrote ONNX: {args.output}")
    print(f"Wrote meta: {meta_path}")

    if not args.skip_ort:
        _validate_onnx_ort(args.output, dummy, out_names)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Convert ONNX to TensorFlow SavedModel using onnx2tf."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import shutil
|
|
8
|
+
import subprocess
|
|
9
|
+
import sys
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _find_onnx2tf() -> list[str]:
|
|
14
|
+
exe = shutil.which("onnx2tf")
|
|
15
|
+
if exe:
|
|
16
|
+
return [exe]
|
|
17
|
+
return [sys.executable, "-m", "onnx2tf"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def main() -> None:
    """CLI: convert ``--onnx`` into a TF SavedModel directory with onnx2tf.

    onnx2tf runs as a subprocess with ``-osd`` (output signature defs). A
    non-zero exit raises :class:`subprocess.CalledProcessError`. A missing
    ``--onnx`` file is reported as a CLI usage error before onnx2tf starts.
    """
    p = argparse.ArgumentParser(description="ONNX → SavedModel via onnx2tf.")
    p.add_argument("--onnx", type=Path, required=True)
    p.add_argument("--output", type=Path, required=True, help="Output directory for SavedModel.")
    args = p.parse_args()

    if not args.onnx.is_file():
        # Fail with a clear message instead of an opaque onnx2tf traceback.
        p.error(f"--onnx file not found: {args.onnx}")

    # onnx2tf's tf_converter can fail on some ops (e.g. ScatterND); its
    # flatbuffer_direct output may still be useful for TFLite.
    # Note: build_faster_rcnn_savedmodel.py does not consume this SavedModel.
    # It always runs the ONNX graph through ONNX Runtime (its --tf-core flag
    # is ignored).
    cmd = _find_onnx2tf() + ["-i", str(args.onnx), "-o", str(args.output), "-osd"]
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, check=True)
    print(f"SavedModel directory: {args.output}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Smoke-test an export detect bundle (keras_model.keras) or legacy TF SavedModel."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import json
|
|
8
|
+
import sys
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
# Repository root: export/scripts/<this file> -> export/scripts -> export -> root.
# Prepending it lets the lazily imported ``export.*`` modules resolve when this
# file runs directly as a script.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(_REPO_ROOT)]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _run_keras_bundle(bundle_dir: Path, height: int, width: int, ort_device: str | None) -> None:
    """Load ``keras_model.keras`` from ``bundle_dir`` and run one zero-image inference.

    Args:
        bundle_dir: Detect-bundle directory (``keras_model.keras``, optional
            ``export_meta.json``).
        height, width: Spatial size of the ``[1, 3, H, W]`` zero input.
        ort_device: ``cpu`` / ``cuda`` / ``auto`` for the ONNX Runtime core, or
            ``None`` to keep the current override / environment setting.

    Raises:
        FileNotFoundError: ``keras_model.keras`` is missing.
    """
    import tensorflow as tf

    from export.ort_runtime import configure_ort_device, get_ort_device
    from export.tf_serving_model import load_keras_detect_model

    providers = configure_ort_device(ort_device)
    print(f" ort_device: {get_ort_device()} providers: {providers}")

    keras_path = bundle_dir / "keras_model.keras"
    if not keras_path.is_file():
        raise FileNotFoundError(f"Missing {keras_path}")
    model = load_keras_detect_model(keras_path)
    x = tf.zeros([1, 3, height, width], dtype=tf.float32)
    detections, num_detections = model(x, training=False)
    print(f" detections: shape={detections.shape} dtype={detections.dtype}")
    # num_detections may carry a batch axis (e.g. shape [1]). Take the first
    # element explicitly: int() on an ndim > 0 NumPy array is deprecated.
    num = int(tf.reshape(num_detections, [-1])[0].numpy())
    print(f" num_detections: {num}")
    meta_path = bundle_dir / "export_meta.json"
    if meta_path.is_file():
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        print(f" core_backend: {meta.get('core_backend', 'unknown')}")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _run_saved_model(sm_dir: Path, height: int, width: int) -> None:
    """Load a TF SavedModel and call one signature with zero image input(s).

    Uses ``serving_default`` when present, otherwise the first signature. Every
    floating-point rank-4 keyword input is fed zeros:

    * static dimensions from the signature are used as-is;
    * unknown dimensions fall back to ``[1, 3, height, width]``, or to
      ``[1, height, width, 3]`` when the signature looks channels-last (last dim
      is 3 and dim 1 is not). onnx2tf emits NHWC inputs by default.

    Raises:
        SystemExit: No signatures, or no image-like input could be inferred.
    """
    import tensorflow as tf

    sm = tf.saved_model.load(str(sm_dir))
    sigs = list(getattr(sm, "signatures", {}).keys())
    print("Signatures:", sigs)
    if not sigs:
        raise SystemExit(f"No signatures found in SavedModel {sm_dir}.")
    name = "serving_default" if "serving_default" in sigs else sigs[0]
    fn = sm.signatures[name]

    kwargs = {}
    _pos, kw = fn.structured_input_signature
    for key, spec in (kw or {}).items():
        if spec.dtype.is_floating and spec.shape.rank == 4:
            dims = spec.shape.as_list()
            channels_last = dims[-1] == 3 and dims[1] != 3
            fallback = [1, height, width, 3] if channels_last else [1, 3, height, width]
            shape = [dim if dim is not None else default for dim, default in zip(dims, fallback)]
            kwargs[key] = tf.zeros(shape, dtype=spec.dtype)
    if not kwargs:
        raise SystemExit("Could not infer image input from SavedModel signature.")
    out = fn(**kwargs)
    for k, v in out.items():
        print(f" {k}: shape={v.shape} dtype={v.dtype}")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def main() -> None:
    """CLI: smoke-test a detect bundle (``keras_model.keras``) or a plain TF SavedModel.

    A directory containing ``keras_model.keras`` is run as a Keras bundle.
    Any other directory is loaded as a SavedModel. Exits with status 1 and an
    install hint when TensorFlow is missing.
    """
    parser = argparse.ArgumentParser(description="Smoke-test export detect bundle or SavedModel.")
    parser.add_argument(
        "--saved-model",
        type=Path,
        required=True,
        help="Directory with keras_model.keras (+ export_meta.json) or TF SavedModel.",
    )
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument(
        "--ort-device",
        default=None,
        choices=("cpu", "cuda", "auto"),
        help="ONNX Runtime EP for keras bundle (default: cpu or ORIENTED_DET_ORT_DEVICE).",
    )
    args = parser.parse_args()

    try:
        import tensorflow  # noqa: F401
    except ImportError as exc:
        print("Install TensorFlow: pip install -r export/requirements-export.txt", file=sys.stderr)
        raise SystemExit(1) from exc

    bundle: Path = args.saved_model
    if not (bundle / "keras_model.keras").is_file():
        _run_saved_model(bundle, args.height, args.width)
        return
    _run_keras_bundle(bundle, args.height, args.width, args.ort_device)


if __name__ == "__main__":
    main()
|