oriented-det 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. export/__init__.py +9 -0
  2. export/ort_runtime.py +67 -0
  3. export/postprocess.py +151 -0
  4. export/scripts/__init__.py +6 -0
  5. export/scripts/build_faster_rcnn_savedmodel.py +104 -0
  6. export/scripts/export_onnx.py +210 -0
  7. export/scripts/onnx_to_savedmodel.py +35 -0
  8. export/scripts/predict_savedmodel.py +94 -0
  9. export/scripts/save_predictions_tf.py +447 -0
  10. export/scripts/to_tflite.py +32 -0
  11. export/tests/test_export_onnx_optional.py +82 -0
  12. export/tests/test_export_wrappers.py +105 -0
  13. export/tests/test_faster_rcnn_export_parity.py +201 -0
  14. export/tests/test_ort_runtime.py +41 -0
  15. export/tf_serving_model.py +96 -0
  16. export/val_dataset.py +116 -0
  17. export/wrappers.py +161 -0
  18. oriented_det/__init__.py +77 -0
  19. oriented_det/cli/__init__.py +92 -0
  20. oriented_det/cli/train.py +18 -0
  21. oriented_det/configs/_base_/augmentation.json +21 -0
  22. oriented_det/configs/_base_/datasets/dota_le90.json +21 -0
  23. oriented_det/configs/_base_/fp16.json +5 -0
  24. oriented_det/configs/_base_/models/oriented_rcnn_r50.json +26 -0
  25. oriented_det/configs/_base_/models/rotated_faster_rcnn_r50.json +53 -0
  26. oriented_det/configs/_base_/models/rotated_retinanet_r50.json +30 -0
  27. oriented_det/configs/_base_/preprocessing.json +8 -0
  28. oriented_det/configs/_base_/schedules/1x.json +33 -0
  29. oriented_det/configs/config.schema.json +404 -0
  30. oriented_det/configs/oriented_rcnn/dota_le90_1x.json +121 -0
  31. oriented_det/configs/oriented_rcnn/dota_le90_3x.json +11 -0
  32. oriented_det/configs/rotated_faster_rcnn/dota_le90_1x.json +119 -0
  33. oriented_det/configs/rotated_faster_rcnn/dota_le90_3x.json +9 -0
  34. oriented_det/configs/rotated_retinanet/dota_le90_1x.json +107 -0
  35. oriented_det/configs/rotated_retinanet/dota_le90_3x.json +30 -0
  36. oriented_det/data/__init__.py +89 -0
  37. oriented_det/data/airbus_playground.py +611 -0
  38. oriented_det/data/dota.py +741 -0
  39. oriented_det/data/dota_classes.py +26 -0
  40. oriented_det/data/evaluation.py +648 -0
  41. oriented_det/data/flips.py +115 -0
  42. oriented_det/data/preprocessing.py +335 -0
  43. oriented_det/data/tiling.py +399 -0
  44. oriented_det/data/transforms.py +377 -0
  45. oriented_det/geometry/__init__.py +8 -0
  46. oriented_det/geometry/poly.py +127 -0
  47. oriented_det/geometry/qbox.py +63 -0
  48. oriented_det/geometry/rbox.py +250 -0
  49. oriented_det/geometry/transforms.py +266 -0
  50. oriented_det/models/__init__.py +36 -0
  51. oriented_det/models/backbones/__init__.py +11 -0
  52. oriented_det/models/backbones/resnet_fpn.py +79 -0
  53. oriented_det/models/backbones/utils.py +81 -0
  54. oriented_det/models/bbox_coder.py +355 -0
  55. oriented_det/models/faster_rcnn_inference.py +494 -0
  56. oriented_det/models/horizontal_roi_coder.py +155 -0
  57. oriented_det/models/oriented_rcnn.py +1256 -0
  58. oriented_det/models/oriented_roi.py +1664 -0
  59. oriented_det/models/oriented_rpn.py +2104 -0
  60. oriented_det/models/rotated_retinanet.py +1030 -0
  61. oriented_det/models/utils.py +590 -0
  62. oriented_det/ops/__init__.py +58 -0
  63. oriented_det/ops/gpu_ops.py +1109 -0
  64. oriented_det/ops/iou.py +172 -0
  65. oriented_det/ops/kfiou.py +275 -0
  66. oriented_det/ops/nms.py +202 -0
  67. oriented_det/ops/probiou.py +165 -0
  68. oriented_det/ops/rotated_ops.py +122 -0
  69. oriented_det/ops/utils.py +257 -0
  70. oriented_det/pretrained/__init__.py +23 -0
  71. oriented_det/pretrained/hub.py +249 -0
  72. oriented_det/pretrained/manifest.json +46 -0
  73. oriented_det/runtime/__init__.py +29 -0
  74. oriented_det/runtime/checkpoint.py +274 -0
  75. oriented_det/runtime/collate.py +348 -0
  76. oriented_det/runtime/inference.py +1286 -0
  77. oriented_det/train/__init__.py +102 -0
  78. oriented_det/train/config.py +872 -0
  79. oriented_det/train/engine.py +2933 -0
  80. oriented_det/train/grouped_ce.py +139 -0
  81. oriented_det/train/piecewise_schedule.py +33 -0
  82. oriented_det/train/profiler.py +287 -0
  83. oriented_det/train/utils.py +999 -0
  84. oriented_det/utils/__init__.py +31 -0
  85. oriented_det/utils/config.py +376 -0
  86. oriented_det/utils/device.py +35 -0
  87. oriented_det/utils/logging.py +163 -0
  88. oriented_det/utils/progress.py +62 -0
  89. oriented_det/utils/viz.py +181 -0
  90. oriented_det-0.1.0.dist-info/METADATA +313 -0
  91. oriented_det-0.1.0.dist-info/RECORD +115 -0
  92. oriented_det-0.1.0.dist-info/WHEEL +5 -0
  93. oriented_det-0.1.0.dist-info/entry_points.txt +2 -0
  94. oriented_det-0.1.0.dist-info/licenses/LICENSE +202 -0
  95. oriented_det-0.1.0.dist-info/top_level.txt +3 -0
  96. tools/__init__.py +1 -0
  97. tools/app.py +1284 -0
  98. tools/dataset_stats.py +389 -0
  99. tools/dota_labels_to_comma.py +132 -0
  100. tools/free_gpu.py +142 -0
  101. tools/generate_airbus_playground_csv.py +154 -0
  102. tools/image_demo.py +200 -0
  103. tools/lr_finder.py +771 -0
  104. tools/measure_sampled_riou_error.py +483 -0
  105. tools/playground_to_dota.py +290 -0
  106. tools/pretrained_download.py +49 -0
  107. tools/preview_augmentation.py +510 -0
  108. tools/publish_checkpoint.py +96 -0
  109. tools/save_predictions.py +2350 -0
  110. tools/sync_vendored_configs.py +94 -0
  111. tools/tile_dota.py +447 -0
  112. tools/train.py +2324 -0
  113. tools/train_example.py +244 -0
  114. tools/train_multi_gpu.py +367 -0
  115. tools/visualize_boxes.py +183 -0
export/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """Export utilities for converting OrientedDet models to other runtimes.
2
+
3
+ This package is used by the `odet export-onnx` CLI command.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
# ``from export import *`` re-exports nothing; callers import submodules
# such as ``export.postprocess`` explicitly.
__all__ = []
9
+
export/ort_runtime.py ADDED
@@ -0,0 +1,67 @@
1
+ """ONNX Runtime device selection and session cache for TF export inference."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
# Process-wide device override written by set_ort_device(). When falsy,
# get_ort_device() falls back to ORIENTED_DET_ORT_DEVICE and then "cpu".
_ORT_DEVICE_OVERRIDE: Optional[str] = None

# InferenceSession objects keyed by (ONNX path string, providers tuple).
_SESSION_CACHE: Dict[Tuple[str, Tuple[str, ...]], object] = {}
10
+
11
+
12
def set_ort_device(device: Optional[str]) -> None:
    """Record a process-wide ORT device override.

    Args:
        device: Device name such as ``"cpu"``, ``"cuda"`` or ``"auto"``. It is
            stored lower-cased and stripped, and is not validated here. A falsy
            value (``None`` or ``""``) clears the override, so
            :func:`get_ort_device` consults the environment again.
    """
    global _ORT_DEVICE_OVERRIDE
    if not device:
        _ORT_DEVICE_OVERRIDE = None
        return
    _ORT_DEVICE_OVERRIDE = device.lower().strip()
16
+
17
+
18
def get_ort_device() -> str:
    """Return the effective ORT device name.

    Precedence:

    1. The override set by :func:`set_ort_device`, if truthy.
    2. The ``ORIENTED_DET_ORT_DEVICE`` environment variable, if non-empty.
    3. ``"cpu"``.

    The environment or default value is lower-cased and stripped.
    """
    override = _ORT_DEVICE_OVERRIDE
    if override:
        return override
    from_env = os.environ.get("ORIENTED_DET_ORT_DEVICE")
    if not from_env:
        from_env = "cpu"
    return from_env.lower().strip()
23
+
24
+
25
def ort_providers_for_device(device: Optional[str] = None) -> List[str]:
    """Map a device name to an ONNX Runtime ``providers`` list.

    Args:
        device: ``"cpu"``, ``"cuda"`` (alias ``"gpu"``) or ``"auto"``. Matching
            is case- and whitespace-insensitive. ``None`` uses
            :func:`get_ort_device`.

    Returns:
        Providers in priority order. CUDA selections keep
        ``CPUExecutionProvider`` as a fallback. ``"auto"`` picks CUDA only if
        onnxruntime reports it as available.

    Raises:
        ValueError: The (resolved) device name is not supported.
        RuntimeError: ``cuda``/``gpu`` was requested but onnxruntime does not
            expose ``CUDAExecutionProvider``.
    """
    d = (device if device is not None else get_ort_device()).lower().strip()
    if d not in ("cpu", "cuda", "gpu", "auto"):
        # Report the resolved name. When ``device`` is None the bad value came
        # from the override or env var, and ``device!r`` would only say None.
        raise ValueError(f"Unknown ORT device {d!r}; use cpu, cuda, or auto.")
    if d == "cpu":
        # CPU needs no provider query, so don't require onnxruntime here.
        return ["CPUExecutionProvider"]

    import onnxruntime as ort

    cuda_available = "CUDAExecutionProvider" in ort.get_available_providers()
    if d == "auto":
        if cuda_available:
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]
        return ["CPUExecutionProvider"]

    # d is "cuda" or "gpu": CUDA was explicitly requested.
    if not cuda_available:
        raise RuntimeError(
            "ORT device=cuda requested but CUDAExecutionProvider is not available. "
            "Install onnxruntime-gpu matching your CUDA driver."
        )
    return ["CUDAExecutionProvider", "CPUExecutionProvider"]
44
+
45
+
46
def configure_ort_device(device: Optional[str]) -> List[str]:
    """Optionally set the ORT device override, then resolve the providers.

    Args:
        device: Passed to :func:`set_ort_device` when not ``None``. ``None``
            leaves the current override untouched.

    Returns:
        The providers list from :func:`ort_providers_for_device` for the
        now-effective device.
    """
    if device is None:
        return ort_providers_for_device()
    set_ort_device(device)
    return ort_providers_for_device()
52
+
53
+
54
def clear_ort_session_cache() -> None:
    """Forget every cached ``InferenceSession`` (mainly for test isolation).

    Only the cache entries are dropped. Sessions still referenced elsewhere
    stay alive.
    """
    _SESSION_CACHE.clear()
57
+
58
+
59
def get_ort_session(onnx_path: str, device: Optional[str] = None):
    """Return a cached ``onnxruntime.InferenceSession`` for ``onnx_path``.

    Sessions are cached per process, keyed by the absolute model path and the
    resolved providers. The same model requested through different relative
    spellings, or after a working-directory change, therefore reuses one
    session.

    Args:
        onnx_path: Path to the ``.onnx`` file (``str`` or path-like).
        device: Device name for :func:`ort_providers_for_device`. ``None``
            uses the process default.
    """
    import onnxruntime as ort

    providers = ort_providers_for_device(device)
    # Normalise to one absolute string. It serves as the stable cache key and
    # is also what InferenceSession receives, so the two can't disagree.
    path = os.path.abspath(str(onnx_path))
    key = (path, tuple(providers))
    session = _SESSION_CACHE.get(key)
    if session is None:
        session = ort.InferenceSession(path, providers=providers)
        _SESSION_CACHE[key] = session
    return session
export/postprocess.py ADDED
@@ -0,0 +1,151 @@
1
+ """Post-NMS detection finalization for TF export (exact CPU rotated NMS + score filter)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import numpy as np
8
+
9
+ try:
10
+ import torch
11
+ except Exception: # pragma: no cover
12
+ torch = None # type: ignore
13
+
14
+ from export.ort_runtime import get_ort_device, get_ort_session
15
+ from oriented_det.models.faster_rcnn_inference import PreNmsDetections, apply_final_rotated_nms
16
+ from oriented_det.train.utils import effective_score_threshold_for_class_name
17
+
18
+
19
class _NmsConfigView:
    """Plain attribute holder handed to :func:`apply_final_rotated_nms` as its config.

    Stores exactly the four final-NMS settings given to the constructor. All
    constructor arguments are keyword-only.
    """

    nms_class_agnostic: bool
    final_nms_iou_threshold: float
    max_detections_per_image: Optional[int]
    final_nms_use_cpu: bool

    def __init__(
        self,
        *,
        nms_class_agnostic: bool,
        final_nms_iou_threshold: float,
        max_detections_per_image: Optional[int],
        final_nms_use_cpu: bool,
    ) -> None:
        self.final_nms_use_cpu = final_nms_use_cpu
        self.final_nms_iou_threshold = final_nms_iou_threshold
        self.nms_class_agnostic = nms_class_agnostic
        self.max_detections_per_image = max_detections_per_image
34
+
35
+
36
def finalize_detections_numpy(
    pre_nms_boxes: np.ndarray,
    pre_nms_scores: np.ndarray,
    pre_nms_labels: np.ndarray,
    pre_nms_count: int,
    *,
    nms_class_agnostic: bool,
    final_nms_iou_threshold: float,
    max_detections_per_image: Optional[int],
    final_nms_use_cpu: bool,
    score_threshold: float,
    per_class_score_threshold: Optional[Dict[str, float]],
    class_id_to_name: Dict[int, str],
    max_output_slots: int,
) -> Tuple[np.ndarray, int]:
    """Run rotated NMS + production score filter; return padded ``[max_output_slots, 7]``.

    Args:
        pre_nms_boxes, pre_nms_scores, pre_nms_labels: Candidate arrays. Only
            the first ``pre_nms_count`` rows are used. Boxes and scores are
            cast to float32 and labels to int64. Boxes presumably have 5
            columns (7 output columns minus score and label); confirm against
            the ONNX export wrapper.
        pre_nms_count: Number of valid candidate rows. ``<= 0`` short-circuits.
        nms_class_agnostic, final_nms_iou_threshold, max_detections_per_image,
        final_nms_use_cpu: Forwarded to :func:`apply_final_rotated_nms`.
        score_threshold, per_class_score_threshold: Resolved per class name via
            :func:`effective_score_threshold_for_class_name`. Detections with
            ``score >= threshold`` are kept.
        class_id_to_name: Label id to class name. Unknown ids map to
            ``class_{id}``.
        max_output_slots: Number of output rows. Surplus detections are dropped.

    Returns:
        ``(out, m)``. ``out`` is a zero-padded float32 array of shape
        ``[max_output_slots, 7]`` whose first ``m`` rows are
        ``[box..., score, label]``, with the label stored as float.

    Raises:
        RuntimeError: torch is not importable.
    """
    if torch is None:
        raise RuntimeError("torch is required for finalize_detections_numpy")

    out = np.zeros((max_output_slots, 7), dtype=np.float32)
    n = int(pre_nms_count)
    if n <= 0:
        return out, 0

    boxes_t = torch.from_numpy(np.asarray(pre_nms_boxes[:n], dtype=np.float32))
    scores_t = torch.from_numpy(np.asarray(pre_nms_scores[:n], dtype=np.float32))
    labels_t = torch.from_numpy(np.asarray(pre_nms_labels[:n], dtype=np.int64))

    nms_view = _NmsConfigView(
        nms_class_agnostic=nms_class_agnostic,
        final_nms_iou_threshold=final_nms_iou_threshold,
        max_detections_per_image=max_detections_per_image,
        final_nms_use_cpu=final_nms_use_cpu,
    )
    final = apply_final_rotated_nms(nms_view, PreNmsDetections(boxes_t, scores_t, labels_t))
    if final.boxes.numel() == 0:
        return out, 0

    # Move results to CPU up front. The keep-mask below is built on CPU and the
    # output is numpy anyway, so this avoids a CPU-mask / device-tensor
    # indexing mismatch in case final NMS returned tensors on another device.
    boxes = final.boxes.detach().cpu()
    scores = final.scores.detach().cpu()
    labels = final.labels.detach().cpu()

    if per_class_score_threshold or score_threshold is not None:
        label_ids = [int(v) for v in labels.tolist()]
        # Resolve each distinct class's threshold once rather than per detection.
        thr_by_label = {
            lid: effective_score_threshold_for_class_name(
                class_id_to_name.get(lid, f"class_{lid}"),
                score_threshold,
                per_class_score_threshold,
            )
            for lid in set(label_ids)
        }
        thresholds = torch.tensor([thr_by_label[lid] for lid in label_ids], dtype=torch.float64)
        # The float64 compare reproduces ``float(score) >= thr`` exactly:
        # float32 -> float64 is lossless and thr is a Python float.
        keep = scores.to(torch.float64) >= thresholds
        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    m = min(int(boxes.shape[0]), max_output_slots)
    if m == 0:
        return out, 0

    out[:m] = np.column_stack(
        [
            boxes[:m].numpy(),
            scores[:m].numpy().reshape(-1, 1),
            labels[:m].numpy().astype(np.float32).reshape(-1, 1),
        ]
    )
    return out, m
104
+
105
+
106
def build_class_id_to_name(class_names: List[str]) -> Dict[int, str]:
    """Return ``{label_id: name}`` with foreground label ids starting at 1.

    Id 0 is not assigned; presumably it is the background label (confirm
    against the model's label convention).
    """
    return dict(enumerate(class_names, start=1))
109
+
110
+
111
def ort_pre_nms_to_detections(
    images: "np.ndarray",
    onnx_path: str,
    ort_output_names: List[str],
    finalize_kwargs: Dict[str, Any],
) -> Tuple["np.ndarray", int]:
    """Run the pre-NMS ONNX model with ONNX Runtime, then finalize detections.

    Intended as the Python body behind a TF ``numpy_function`` / Keras bundle.

    Args:
        images: Image batch, or a single image; a leading batch axis is added
            when ``images.ndim == 3``. Cast to float32 and fed to the
            session's first input.
        onnx_path: ONNX model path. The session comes from
            :func:`get_ort_session` (cached, process-default ORT device).
        ort_output_names: Output names to request. If empty, every graph output
            is requested under its own name.
        finalize_kwargs: Keyword arguments for
            :func:`finalize_detections_numpy`, e.g. from
            :func:`meta_to_finalize_kwargs`.

    Returns:
        ``(detections, num)`` as produced by :func:`finalize_detections_numpy`.

    Raises:
        KeyError: The fetched outputs lack one of the ``pre_nms_*`` tensors.
    """
    img = np.asarray(images, dtype=np.float32)
    if img.ndim == 3:
        img = img[np.newaxis, ...]
    sess = get_ort_session(onnx_path)
    input_name = sess.get_inputs()[0].name

    # With an empty name list ORT returns all outputs. Take the names from the
    # session so the zip below lines values up with the right keys.
    output_names = list(ort_output_names) or [o.name for o in sess.get_outputs()]
    outs = sess.run(output_names, {input_name: img})
    name_to_val = dict(zip(output_names, outs))

    required = ("pre_nms_boxes", "pre_nms_scores", "pre_nms_labels", "pre_nms_count")
    missing = [name for name in required if name not in name_to_val]
    if missing:
        raise KeyError(
            f"ONNX outputs {sorted(name_to_val)} lack required {missing}; "
            "was the model exported with --mode faster_rcnn_pre_nms?"
        )

    detections, num = finalize_detections_numpy(
        name_to_val["pre_nms_boxes"],
        name_to_val["pre_nms_scores"],
        name_to_val["pre_nms_labels"],
        # pre_nms_count may arrive as a 0-d or 1-element array.
        int(np.asarray(name_to_val["pre_nms_count"]).reshape(-1)[0]),
        **finalize_kwargs,
    )
    return detections, int(num)
135
+
136
+
137
def meta_to_finalize_kwargs(meta: Dict[str, Any]) -> Dict[str, Any]:
    """Build :func:`finalize_detections_numpy` keyword arguments from export meta JSON.

    Reads ``meta["production"]`` (may be missing or null) and
    ``meta["class_names"]``. Missing settings fall back to:

    - class-agnostic NMS off;
    - final NMS IoU 0.1;
    - CPU NMS on;
    - score threshold 0.05;
    - 3000 max detections.

    A null numeric setting is treated like a missing one. ``max_output_slots``
    equals the resolved detection cap.
    """
    prod = meta.get("production") or {}

    def _num(key: str, default: float) -> float:
        # A dataclass-derived ``production`` dict can carry explicit nulls, and
        # float(None) would raise, so null falls back to the default.
        value = prod.get(key)
        return float(default if value is None else value)

    class_names: List[str] = list(meta.get("class_names") or [])
    # A falsy value (None or 0) at either level falls through to 3000.
    max_det = int(prod.get("max_detections_per_image") or meta.get("max_detections_per_image") or 3000)
    return {
        "nms_class_agnostic": bool(prod.get("nms_class_agnostic", False)),
        "final_nms_iou_threshold": _num("final_nms_iou_threshold", 0.1),
        "max_detections_per_image": max_det,
        "final_nms_use_cpu": bool(prod.get("final_nms_use_cpu", True)),
        "score_threshold": _num("score_threshold", 0.05),
        "per_class_score_threshold": prod.get("per_class_score_threshold"),
        "class_id_to_name": build_class_id_to_name(class_names),
        "max_output_slots": max_det,
    }
@@ -0,0 +1,6 @@
1
+ """Export scripts used by the `odet` CLI."""
2
+
3
+ from __future__ import annotations
4
+
5
# ``from export.scripts import *`` re-exports nothing; the scripts are
# standalone command-line entry points.
__all__ = []
6
+
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/env python3
2
+ """Build e2e Faster R-CNN SavedModel (Keras + ORT core + exact rotated NMS)."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import sys
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
# export/scripts/<this file> -> parents[2] is the repository root.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    # Prepend so the in-repo ``export`` package is found when run as a script.
    sys.path.insert(0, str(_REPO_ROOT))
15
+
16
+ from export.postprocess import meta_to_finalize_kwargs # noqa: E402
17
+ from export.tf_serving_model import FasterRCNNDetectLayer, save_keras_detect_bundle # noqa: E402
18
+
19
+
20
def _load_meta(meta_path: Path) -> dict:
    """Parse the UTF-8 ``*.export_meta.json`` file at ``meta_path``."""
    raw = meta_path.read_text(encoding="utf-8")
    return json.loads(raw)
22
+
23
+
24
def _resolve_onnx_path(meta: dict, meta_path: Path, onnx_path: Optional[Path]) -> Path:
    """Locate the ONNX model for a detect bundle.

    Resolution order:

    1. ``onnx_path`` (``--onnx``) when given. It must exist; a missing
       explicit path is an error rather than silently using another model.
    2. ``meta["onnx_path"]``, if that file exists.
    3. The sibling of ``meta_path``: ``model.export_meta.json`` -> ``model.onnx``.

    Raises:
        FileNotFoundError: The explicit path is missing, or no candidate exists.
            The message lists the paths that were tried.
    """
    if onnx_path is not None:
        if onnx_path.is_file():
            return onnx_path
        raise FileNotFoundError(f"--onnx file not found: {onnx_path}")

    tried = []
    if meta.get("onnx_path"):
        p = Path(meta["onnx_path"])
        if p.is_file():
            return p
        tried.append(str(p))

    # "model.export_meta.json" -> "model.export_meta" -> "model.onnx"
    candidate = meta_path.with_suffix("").with_suffix(".onnx")
    if candidate.is_file():
        return candidate
    tried.append(str(candidate))

    raise FileNotFoundError(
        "ONNX file not found; pass --onnx or ensure meta onnx_path / sibling .onnx exists. "
        f"Tried: {', '.join(tried)}"
    )
37
+
38
+
39
def build_savedmodel(
    output_path: Path,
    meta_path: Path,
    *,
    onnx_path: Optional[Path] = None,
) -> None:
    """Build the Keras detect bundle (ORT core + exact rotated NMS) for a Faster R-CNN export.

    Args:
        output_path: Bundle directory passed to ``save_keras_detect_bundle``.
        meta_path: ``*.export_meta.json`` written by ``export_onnx.py``.
        onnx_path: Explicit ONNX file. Otherwise it is resolved from the meta
            or from the ``.onnx`` sibling of ``meta_path``.

    Raises:
        ValueError: The meta records an export ``mode`` other than
            ``faster_rcnn_pre_nms``; the detect layer needs the pre-NMS outputs.
        FileNotFoundError: No ONNX file could be located.
    """
    import tensorflow as tf

    meta = _load_meta(meta_path)
    mode = meta.get("mode")
    if mode is not None and mode != "faster_rcnn_pre_nms":
        # Fail now rather than at first inference, when finalization would
        # find no ``pre_nms_*`` outputs.
        raise ValueError(
            f"{meta_path} was exported with --mode {mode}; the detect bundle "
            "requires an ONNX exported with --mode faster_rcnn_pre_nms."
        )
    finalize_kwargs = meta_to_finalize_kwargs(meta)
    onnx_file = _resolve_onnx_path(meta, meta_path, onnx_path)

    # meta["input"]["shape"] is [N, C, H, W], as written by export_onnx.py.
    h = int(meta["input"]["shape"][2])
    w = int(meta["input"]["shape"][3])
    inputs = tf.keras.Input(shape=(3, h, w), batch_size=1, name="images")
    layer = FasterRCNNDetectLayer(
        onnx_path=str(onnx_file.resolve()),
        ort_output_names=list(meta.get("output_names") or []),
        finalize_kwargs=finalize_kwargs,
        max_output_slots=int(finalize_kwargs["max_output_slots"]),
    )
    layer_out = layer(inputs)
    model = tf.keras.Model(
        inputs=inputs,
        outputs=[layer_out["detections"], layer_out["num_detections"]],
    )

    full_meta = dict(meta)
    full_meta["core_backend"] = "onnxruntime_keras"
    full_meta["onnx_path"] = str(onnx_file)
    full_meta["savedmodel_outputs"] = {
        "detections": {
            "shape": [finalize_kwargs["max_output_slots"], 7],
            "layout": "cx,cy,w,h,angle,score,label",
        },
        "num_detections": {"dtype": "int32"},
    }

    keras_path = save_keras_detect_bundle(model, output_path, full_meta)
    print(f"Wrote detect bundle: {output_path} (keras: {keras_path.name}, core: onnxruntime)")
79
+
80
+
81
def main() -> None:
    """CLI entry point: parse arguments, check TensorFlow, then build the bundle.

    Exits with status 1 when TensorFlow cannot be imported.
    """
    parser = argparse.ArgumentParser(description="Build e2e Faster R-CNN SavedModel.")
    parser.add_argument(
        "--tf-core",
        type=Path,
        default=None,
        help="Ignored (kept for CLI compatibility with README pipeline).",
    )
    parser.add_argument(
        "--meta",
        type=Path,
        required=True,
        help="*.export_meta.json from export_onnx.py.",
    )
    parser.add_argument(
        "--onnx",
        type=Path,
        default=None,
        help="ONNX file (default: from meta / sibling).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output SavedModel directory.",
    )
    args = parser.parse_args()

    try:
        import tensorflow  # noqa: F401  (availability check only)
    except ImportError as exc:
        print("Install TensorFlow: pip install -r export/requirements-export.txt", file=sys.stderr)
        raise SystemExit(1) from exc

    build_savedmodel(args.output, args.meta, onnx_path=args.onnx)
101
+
102
+
103
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
@@ -0,0 +1,210 @@
1
+ #!/usr/bin/env python3
2
+ """Export a tensor-only subgraph from oriented-det to ONNX.
3
+
4
+ See export/README.md and export/contract.json for supported modes.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import json
11
+ import sys
12
+ from dataclasses import asdict
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ import torch
17
+
18
# export/scripts/<this file> -> parents[2] is the repository root.
_REPO_ROOT = Path(__file__).resolve().parents[2]
_EXPORT_DIR = _REPO_ROOT / "export"

# Both entries are prepended. The repo root serves ``oriented_det.*`` imports;
# export/ serves the bare ``import wrappers``. export/ ends up first.
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
if str(_EXPORT_DIR) not in sys.path:
    sys.path.insert(0, str(_EXPORT_DIR))
24
+
25
+ import wrappers as _wrappers # noqa: E402
26
+ from oriented_det.models.oriented_rcnn import RotatedFasterRCNN # noqa: E402
27
+ from oriented_det.models.rotated_retinanet import RotatedRetinaNet # noqa: E402
28
+ from oriented_det.runtime.checkpoint import load_model_from_checkpoint # noqa: E402
29
+
30
# Module-level aliases for the export wrapper classes from ``wrappers``.
(
    BackboneExportWrapper,
    RetinaNetBackboneHeadExportWrapper,
    RotatedFasterRCNNPreNmsExportWrapper,
) = (
    _wrappers.BackboneExportWrapper,
    _wrappers.RetinaNetBackboneHeadExportWrapper,
    _wrappers.RotatedFasterRCNNPreNmsExportWrapper,
)
33
+
34
# ONNX output names for ``--mode faster_rcnn_pre_nms``, in graph output
# order. Downstream finalization looks these tensors up by name.
FASTER_RCNN_PRE_NMS_OUTPUTS = (
    "pre_nms_boxes", "pre_nms_scores",
    "pre_nms_labels", "pre_nms_count",
)
40
+
41
+
42
def _build_wrapper(
    model: torch.nn.Module,
    mode: str,
    height: int,
    width: int,
) -> torch.nn.Module:
    """Wrap ``model`` in the export wrapper that matches ``mode``.

    Args:
        model: Loaded oriented-det model.
        mode: ``backbone``, ``retinanet_heads`` or ``faster_rcnn_pre_nms``.
        height, width: Export input size. Only forwarded for
            ``faster_rcnn_pre_nms``.

    Raises:
        ValueError: Unknown ``mode``, or the model type does not support it.
    """
    model_type = type(model).__name__

    if mode == "backbone":
        if hasattr(model, "backbone"):
            return BackboneExportWrapper(model.backbone)
        raise ValueError(f"Model {model_type} has no 'backbone' attribute.")

    if mode == "retinanet_heads":
        if isinstance(model, RotatedRetinaNet):
            return RetinaNetBackboneHeadExportWrapper(model)
        raise ValueError(
            f"retinanet_heads mode requires RotatedRetinaNet, got {model_type}. "
            "Use --mode backbone for two-stage models."
        )

    if mode == "faster_rcnn_pre_nms":
        if isinstance(model, RotatedFasterRCNN):
            return RotatedFasterRCNNPreNmsExportWrapper(model, height=height, width=width)
        raise ValueError(
            f"faster_rcnn_pre_nms requires RotatedFasterRCNN, got {model_type}."
        )

    raise ValueError(f"Unknown mode: {mode}")
67
+
68
+
69
def _output_names_for_wrapper(
    wrapper: torch.nn.Module,
    mode: str,
    device: torch.device,
    height: int,
    width: int,
) -> list[str]:
    """Return ONNX output names for ``wrapper`` exported in ``mode``.

    ``faster_rcnn_pre_nms`` has fixed names. Every other mode runs ``wrapper``
    once on a zero image of the export size (``1 x 3 x height x width``) and
    names the returned tensors:

    * ``backbone``: ``fpn_level_{i}`` per output.
    * ``retinanet_heads``: ``level{i}_cls_logits``, ``level{i}_bbox_pred``
      pairs. This assumes the wrapper interleaves (cls, bbox) per level;
      TODO confirm against the wrapper.
    * anything else: ``out_{i}``.
    """
    if mode == "faster_rcnn_pre_nms":
        return list(FASTER_RCNN_PRE_NMS_OUTPUTS)

    # Probe at the real export size so the output count matches what
    # torch.onnx.export will trace for this height/width.
    dummy = torch.zeros(1, 3, height, width, dtype=torch.float32, device=device)
    with torch.no_grad():
        out = wrapper(dummy)
    n_out = len(out)

    if mode == "backbone":
        return [f"fpn_level_{i}" for i in range(n_out)]
    if mode == "retinanet_heads":
        names: list[str] = []
        for i in range(n_out // 2):
            names.extend((f"level{i}_cls_logits", f"level{i}_bbox_pred"))
        return names
    return [f"out_{i}" for i in range(n_out)]
93
+
94
+
95
def _production_dict(config: Any) -> dict | None:
    """Return ``config.production`` as a plain dict for the export meta JSON.

    - Dataclass instances are converted with :func:`dataclasses.asdict`.
    - Dicts are returned as-is (not copied).
    - Anything else, including a missing attribute, yields ``None``.
    """
    prod = getattr(config, "production", None)
    if hasattr(prod, "__dataclass_fields__"):
        return asdict(prod)
    return prod if isinstance(prod, dict) else None
104
+
105
+
106
def _validate_onnx_ort(onnx_path: Path, dummy: torch.Tensor, output_names: list[str]) -> None:
    """Check the exported ONNX file and smoke-run it once with ORT on CPU.

    Best-effort: if ``onnx`` or ``onnxruntime`` is not installed, a notice is
    printed and no validation happens. ``dummy`` is fed as the ``images`` input
    and the shape and dtype of each requested output are printed.
    """
    try:
        import onnx
        import onnxruntime as ort
    except ImportError:
        print("Skipping ORT validation (install onnx and onnxruntime).")
        return

    model_proto = onnx.load(str(onnx_path))
    onnx.checker.check_model(model_proto)
    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    results = session.run(output_names, {"images": dummy.detach().cpu().numpy()})
    print(f"ORT smoke OK: {len(results)} outputs")
    for out_name, value in zip(output_names, results):
        print(f" {out_name}: shape={value.shape} dtype={value.dtype}")
121
+
122
+
123
def main() -> None:
    """CLI entry point: export an oriented-det subgraph to ONNX plus a meta JSON.

    Writes the ``--output`` ONNX file and, beside it, the same path with its
    suffix replaced by ``.export_meta.json``. Unless ``--skip-ort`` is given,
    the model is then smoke-tested with onnxruntime on CPU.
    """
    p = argparse.ArgumentParser(description="Export oriented-det subgraph to ONNX.")
    p.add_argument("--config", type=Path, required=True, help="Training JSON config path.")
    p.add_argument("--checkpoint", type=Path, required=True, help="Weights .pth path.")
    p.add_argument("--output", type=Path, required=True, help="Output .onnx file path.")
    p.add_argument("--height", type=int, default=1024)
    p.add_argument("--width", type=int, default=1024)
    p.add_argument(
        "--mode",
        choices=("backbone", "retinanet_heads", "faster_rcnn_pre_nms"),
        default="backbone",
        help="backbone | retinanet_heads | faster_rcnn_pre_nms (Rotated Faster R-CNN detect).",
    )
    p.add_argument("--opset", type=int, default=17)
    p.add_argument(
        "--dynamic-batch",
        action="store_true",
        help="Allow dynamic batch size (N) on input images; H and W stay fixed.",
    )
    p.add_argument(
        "--skip-ort",
        action="store_true",
        help="Skip onnxruntime validation after export.",
    )
    p.add_argument("--device", default="cpu", help="Device to trace on (cpu recommended for reproducibility).")
    args = p.parse_args()

    # Reject this before the (expensive) model load. The pre-NMS outputs are
    # per-image candidate lists: finalization slices rows [:pre_nms_count] and
    # reads a single count. Marking their axis 0 as "batch" would be wrong.
    if args.dynamic_batch and args.mode == "faster_rcnn_pre_nms":
        p.error(
            "--dynamic-batch is not supported with --mode faster_rcnn_pre_nms: "
            "its pre_nms_* outputs are per-image candidate lists, so axis 0 is "
            "not a batch axis."
        )

    model, config, class_names = load_model_from_checkpoint(
        str(args.checkpoint),
        str(args.config),
        device=args.device,
    )
    h, w = int(args.height), int(args.width)
    wrapper = _build_wrapper(model, args.mode, h, w)
    wrapper.eval()
    wrapper.to(args.device)

    dev = torch.device(args.device)
    out_names = _output_names_for_wrapper(wrapper, args.mode, dev, h, w)
    dummy = torch.zeros(1, 3, h, w, dtype=torch.float32, device=dev)

    dynamic_axes = None
    if args.dynamic_batch:
        # Axis 0 of the input and, by assumption, of every output is the batch.
        dynamic_axes = {"images": {0: "batch"}}
        for name in out_names:
            dynamic_axes[name] = {0: "batch"}

    args.output.parent.mkdir(parents=True, exist_ok=True)
    torch.onnx.export(
        wrapper,
        dummy,
        str(args.output),
        input_names=["images"],
        output_names=out_names,
        dynamic_axes=dynamic_axes,
        opset_version=int(args.opset),
        do_constant_folding=True,
    )

    meta: dict[str, Any] = {
        "mode": args.mode,
        "input": {"name": "images", "shape": [1, 3, h, w], "dtype": "float32"},
        "output_names": out_names,
        "opset": args.opset,
        "dynamic_batch": args.dynamic_batch,
        "config": str(args.config),
        "checkpoint": str(args.checkpoint),
        "class_names": class_names or [],
        "num_classes": int(getattr(config, "num_classes", 0) or 0),
        "production": _production_dict(config),
    }
    if args.mode == "faster_rcnn_pre_nms" and isinstance(wrapper, RotatedFasterRCNNPreNmsExportWrapper):
        meta["max_pre_nms_candidates"] = wrapper.max_candidates
        prod = meta.get("production") or {}
        meta["max_detections_per_image"] = prod.get("max_detections_per_image")

    meta["onnx_path"] = str(args.output.resolve())
    meta_path = args.output.with_suffix(".export_meta.json")
    meta_path.write_text(json.dumps(meta, indent=2), encoding="utf-8")
    print(f"Wrote ONNX: {args.output}")
    print(f"Wrote meta: {meta_path}")

    if not args.skip_ort:
        _validate_onnx_ort(args.output, dummy, out_names)
207
+
208
+
209
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
@@ -0,0 +1,35 @@
1
+ #!/usr/bin/env python3
2
+ """Convert ONNX to TensorFlow SavedModel using onnx2tf."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import shutil
8
+ import subprocess
9
+ import sys
10
+ from pathlib import Path
11
+
12
+
13
def _find_onnx2tf() -> list[str]:
    """Return the argv prefix for invoking onnx2tf.

    Prefers an ``onnx2tf`` executable on ``PATH``. Otherwise the module is run
    with the current interpreter (``python -m onnx2tf``).
    """
    executable = shutil.which("onnx2tf")
    return [executable] if executable else [sys.executable, "-m", "onnx2tf"]
18
+
19
+
20
def main() -> None:
    """CLI entry point: convert ``--onnx`` into a TF SavedModel directory via onnx2tf.

    Raises:
        subprocess.CalledProcessError: onnx2tf exited with a non-zero status.
    """
    p = argparse.ArgumentParser(description="ONNX → SavedModel via onnx2tf.")
    p.add_argument("--onnx", type=Path, required=True)
    p.add_argument("--output", type=Path, required=True, help="Output directory for SavedModel.")
    args = p.parse_args()

    if not args.onnx.is_file():
        # Fail fast with a clear CLI error instead of an onnx2tf traceback.
        p.error(f"--onnx file not found: {args.onnx}")

    # onnx2tf's tf_converter can fail on some ops (e.g. ScatterND); its
    # flatbuffer_direct output may still be useful for TFLite.
    # NOTE: build_faster_rcnn_savedmodel.py does not consume this output. Its
    # --tf-core flag is ignored and it always builds an ONNX Runtime-backed
    # Keras bundle.
    cmd = _find_onnx2tf() + ["-i", str(args.onnx), "-o", str(args.output), "-osd"]
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, check=True)
    print(f"SavedModel directory: {args.output}")
32
+
33
+
34
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
@@ -0,0 +1,94 @@
1
+ #!/usr/bin/env python3
2
+ """Smoke-test an export detect bundle (keras_model.keras) or legacy TF SavedModel."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import sys
9
+ from pathlib import Path
10
+
11
# export/scripts/<this file> -> parents[2] is the repository root.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    # Prepend so ``export.*`` imports inside the runners resolve to this repo.
    sys.path.insert(0, str(_REPO_ROOT))
14
+
15
+
16
def _run_keras_bundle(bundle_dir: Path, height: int, width: int, ort_device: str | None) -> None:
    """Load ``keras_model.keras`` from ``bundle_dir`` and run it once on a zero image.

    Args:
        bundle_dir: Detect-bundle directory.
        height, width: Spatial size of the zero NCHW input.
        ort_device: ORT device override (``cpu``/``cuda``/``auto``). ``None``
            keeps the process default.

    Raises:
        FileNotFoundError: ``bundle_dir`` contains no ``keras_model.keras``.
    """
    import numpy as np
    import tensorflow as tf

    from export.ort_runtime import configure_ort_device, get_ort_device
    from export.tf_serving_model import load_keras_detect_model

    providers = configure_ort_device(ort_device)
    print(f" ort_device: {get_ort_device()} providers: {providers}")

    keras_path = bundle_dir / "keras_model.keras"
    if not keras_path.is_file():
        raise FileNotFoundError(f"Missing {keras_path}")
    model = load_keras_detect_model(keras_path)
    # NCHW input, matching the bundle's Input(shape=(3, h, w)).
    x = tf.zeros([1, 3, height, width], dtype=tf.float32)
    detections, num_detections = model(x, training=False)
    print(f" detections: shape={detections.shape} dtype={detections.dtype}")
    # num_detections may carry a batch axis. Flatten before taking the scalar,
    # because int() on an ndim>0 array is deprecated in NumPy.
    num = int(np.asarray(num_detections.numpy()).reshape(-1)[0])
    print(f" num_detections: {num}")

    meta_path = bundle_dir / "export_meta.json"
    if meta_path.is_file():
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        print(f" core_backend: {meta.get('core_backend', 'unknown')}")
37
+
38
+
39
def _run_saved_model(sm_dir: Path, height: int, width: int) -> None:
    """Load a TF SavedModel and run one signature on zero images.

    Uses ``serving_default`` if present, otherwise the first signature. Every
    floating-point, rank-4 keyword input is fed zeros shaped by
    :func:`_dummy_image_shape`. Output shapes and dtypes are printed.

    Raises:
        SystemExit: The SavedModel has no signatures, or no image-like input
            could be inferred.
    """
    import tensorflow as tf

    sm = tf.saved_model.load(str(sm_dir))
    sigs = list(getattr(sm, "signatures", {}).keys())
    print("Signatures:", sigs)
    if not sigs:
        raise SystemExit(f"SavedModel at {sm_dir} exposes no signatures.")
    name = "serving_default" if "serving_default" in sigs else sigs[0]
    fn = sm.signatures[name]

    kwargs = {}
    _pos, kw = fn.structured_input_signature
    for key, spec in (kw or {}).items():
        if spec.dtype.is_floating and spec.shape.rank == 4:
            shape = _dummy_image_shape(spec.shape.as_list(), height, width)
            kwargs[key] = tf.zeros(shape, dtype=spec.dtype)
    if not kwargs:
        raise SystemExit("Could not infer image input from SavedModel signature.")
    out = fn(**kwargs)
    for k, v in out.items():
        print(f" {k}: shape={v.shape} dtype={v.dtype}")


def _dummy_image_shape(dims: list, height: int, width: int) -> list[int]:
    """Pick the zero-image layout from a rank-4 signature shape (dims may be None).

    If the last dim is 3 and dim 1 is not, the input is treated as NHWC
    ``[1, H, W, 3]`` (onnx2tf typically emits NHWC). Otherwise it is treated
    as NCHW ``[1, 3, H, W]``.
    """
    if dims[-1] == 3 and dims[1] != 3:
        return [1, height, width, 3]
    return [1, 3, height, width]
60
+
61
+
62
def main() -> None:
    """CLI entry point: smoke-test a Keras detect bundle or a plain SavedModel.

    A directory containing ``keras_model.keras`` is run as a detect bundle.
    Anything else goes through ``tf.saved_model.load``. Exits with status 1
    when TensorFlow cannot be imported.
    """
    parser = argparse.ArgumentParser(description="Smoke-test export detect bundle or SavedModel.")
    parser.add_argument(
        "--saved-model",
        type=Path,
        required=True,
        help="Directory with keras_model.keras (+ export_meta.json) or TF SavedModel.",
    )
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument(
        "--ort-device",
        default=None,
        choices=("cpu", "cuda", "auto"),
        help="ONNX Runtime EP for keras bundle (default: cpu or ORIENTED_DET_ORT_DEVICE).",
    )
    args = parser.parse_args()

    try:
        import tensorflow  # noqa: F401  (availability check only)
    except ImportError as exc:
        print("Install TensorFlow: pip install -r export/requirements-export.txt", file=sys.stderr)
        raise SystemExit(1) from exc

    bundle_dir = args.saved_model
    is_keras_bundle = (bundle_dir / "keras_model.keras").is_file()
    if is_keras_bundle:
        _run_keras_bundle(bundle_dir, args.height, args.width, args.ort_device)
        return
    _run_saved_model(bundle_dir, args.height, args.width)
91
+
92
+
93
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()