oriented-det 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115):
  1. export/__init__.py +9 -0
  2. export/ort_runtime.py +67 -0
  3. export/postprocess.py +151 -0
  4. export/scripts/__init__.py +6 -0
  5. export/scripts/build_faster_rcnn_savedmodel.py +104 -0
  6. export/scripts/export_onnx.py +210 -0
  7. export/scripts/onnx_to_savedmodel.py +35 -0
  8. export/scripts/predict_savedmodel.py +94 -0
  9. export/scripts/save_predictions_tf.py +447 -0
  10. export/scripts/to_tflite.py +32 -0
  11. export/tests/test_export_onnx_optional.py +82 -0
  12. export/tests/test_export_wrappers.py +105 -0
  13. export/tests/test_faster_rcnn_export_parity.py +201 -0
  14. export/tests/test_ort_runtime.py +41 -0
  15. export/tf_serving_model.py +96 -0
  16. export/val_dataset.py +116 -0
  17. export/wrappers.py +161 -0
  18. oriented_det/__init__.py +77 -0
  19. oriented_det/cli/__init__.py +92 -0
  20. oriented_det/cli/train.py +18 -0
  21. oriented_det/configs/_base_/augmentation.json +21 -0
  22. oriented_det/configs/_base_/datasets/dota_le90.json +21 -0
  23. oriented_det/configs/_base_/fp16.json +5 -0
  24. oriented_det/configs/_base_/models/oriented_rcnn_r50.json +26 -0
  25. oriented_det/configs/_base_/models/rotated_faster_rcnn_r50.json +53 -0
  26. oriented_det/configs/_base_/models/rotated_retinanet_r50.json +30 -0
  27. oriented_det/configs/_base_/preprocessing.json +8 -0
  28. oriented_det/configs/_base_/schedules/1x.json +33 -0
  29. oriented_det/configs/config.schema.json +404 -0
  30. oriented_det/configs/oriented_rcnn/dota_le90_1x.json +121 -0
  31. oriented_det/configs/oriented_rcnn/dota_le90_3x.json +11 -0
  32. oriented_det/configs/rotated_faster_rcnn/dota_le90_1x.json +119 -0
  33. oriented_det/configs/rotated_faster_rcnn/dota_le90_3x.json +9 -0
  34. oriented_det/configs/rotated_retinanet/dota_le90_1x.json +107 -0
  35. oriented_det/configs/rotated_retinanet/dota_le90_3x.json +30 -0
  36. oriented_det/data/__init__.py +89 -0
  37. oriented_det/data/airbus_playground.py +611 -0
  38. oriented_det/data/dota.py +741 -0
  39. oriented_det/data/dota_classes.py +26 -0
  40. oriented_det/data/evaluation.py +648 -0
  41. oriented_det/data/flips.py +115 -0
  42. oriented_det/data/preprocessing.py +335 -0
  43. oriented_det/data/tiling.py +399 -0
  44. oriented_det/data/transforms.py +377 -0
  45. oriented_det/geometry/__init__.py +8 -0
  46. oriented_det/geometry/poly.py +127 -0
  47. oriented_det/geometry/qbox.py +63 -0
  48. oriented_det/geometry/rbox.py +250 -0
  49. oriented_det/geometry/transforms.py +266 -0
  50. oriented_det/models/__init__.py +36 -0
  51. oriented_det/models/backbones/__init__.py +11 -0
  52. oriented_det/models/backbones/resnet_fpn.py +79 -0
  53. oriented_det/models/backbones/utils.py +81 -0
  54. oriented_det/models/bbox_coder.py +355 -0
  55. oriented_det/models/faster_rcnn_inference.py +494 -0
  56. oriented_det/models/horizontal_roi_coder.py +155 -0
  57. oriented_det/models/oriented_rcnn.py +1256 -0
  58. oriented_det/models/oriented_roi.py +1664 -0
  59. oriented_det/models/oriented_rpn.py +2104 -0
  60. oriented_det/models/rotated_retinanet.py +1030 -0
  61. oriented_det/models/utils.py +590 -0
  62. oriented_det/ops/__init__.py +58 -0
  63. oriented_det/ops/gpu_ops.py +1109 -0
  64. oriented_det/ops/iou.py +172 -0
  65. oriented_det/ops/kfiou.py +275 -0
  66. oriented_det/ops/nms.py +202 -0
  67. oriented_det/ops/probiou.py +165 -0
  68. oriented_det/ops/rotated_ops.py +122 -0
  69. oriented_det/ops/utils.py +257 -0
  70. oriented_det/pretrained/__init__.py +23 -0
  71. oriented_det/pretrained/hub.py +249 -0
  72. oriented_det/pretrained/manifest.json +46 -0
  73. oriented_det/runtime/__init__.py +29 -0
  74. oriented_det/runtime/checkpoint.py +274 -0
  75. oriented_det/runtime/collate.py +348 -0
  76. oriented_det/runtime/inference.py +1286 -0
  77. oriented_det/train/__init__.py +102 -0
  78. oriented_det/train/config.py +872 -0
  79. oriented_det/train/engine.py +2933 -0
  80. oriented_det/train/grouped_ce.py +139 -0
  81. oriented_det/train/piecewise_schedule.py +33 -0
  82. oriented_det/train/profiler.py +287 -0
  83. oriented_det/train/utils.py +999 -0
  84. oriented_det/utils/__init__.py +31 -0
  85. oriented_det/utils/config.py +376 -0
  86. oriented_det/utils/device.py +35 -0
  87. oriented_det/utils/logging.py +163 -0
  88. oriented_det/utils/progress.py +62 -0
  89. oriented_det/utils/viz.py +181 -0
  90. oriented_det-0.1.0.dist-info/METADATA +313 -0
  91. oriented_det-0.1.0.dist-info/RECORD +115 -0
  92. oriented_det-0.1.0.dist-info/WHEEL +5 -0
  93. oriented_det-0.1.0.dist-info/entry_points.txt +2 -0
  94. oriented_det-0.1.0.dist-info/licenses/LICENSE +202 -0
  95. oriented_det-0.1.0.dist-info/top_level.txt +3 -0
  96. tools/__init__.py +1 -0
  97. tools/app.py +1284 -0
  98. tools/dataset_stats.py +389 -0
  99. tools/dota_labels_to_comma.py +132 -0
  100. tools/free_gpu.py +142 -0
  101. tools/generate_airbus_playground_csv.py +154 -0
  102. tools/image_demo.py +200 -0
  103. tools/lr_finder.py +771 -0
  104. tools/measure_sampled_riou_error.py +483 -0
  105. tools/playground_to_dota.py +290 -0
  106. tools/pretrained_download.py +49 -0
  107. tools/preview_augmentation.py +510 -0
  108. tools/publish_checkpoint.py +96 -0
  109. tools/save_predictions.py +2350 -0
  110. tools/sync_vendored_configs.py +94 -0
  111. tools/tile_dota.py +447 -0
  112. tools/train.py +2324 -0
  113. tools/train_example.py +244 -0
  114. tools/train_multi_gpu.py +367 -0
  115. tools/visualize_boxes.py +183 -0
@@ -0,0 +1,447 @@
1
+ #!/usr/bin/env python3
2
+ """Run validation inference with the TF/Keras export bundle; write predictions.json for metrics."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ _REPO_ROOT = Path(__file__).resolve().parents[2]
16
+ if str(_REPO_ROOT) not in sys.path:
17
+ sys.path.insert(0, str(_REPO_ROOT))
18
+
19
+ import cv2
20
+ import numpy as np
21
+ from PIL import Image
22
+ from tqdm import tqdm
23
+
24
+ from export.ort_runtime import configure_ort_device, get_ort_device
25
+ from export.tf_serving_model import load_keras_detect_model
26
+ from export.val_dataset import collect_split_images
27
+ from oriented_det.data import Detection, GroundTruth
28
+ from oriented_det.geometry import RBox, normalize_le90
29
+ from oriented_det.train.config import (
30
+ TrainingExperimentConfig,
31
+ effective_eval_metric_thresholds,
32
+ get_preprocessing_params,
33
+ resolve_inference_sliding_window_overlap_pixels,
34
+ )
35
+ from oriented_det.utils import tqdm_progress_stream
36
+ from oriented_det.runtime.inference import _preprocess_image_tensor_training_style, get_model_size
37
+ from oriented_det.runtime.checkpoint import load_model_from_checkpoint
38
+ from tools.save_predictions import (
39
+ _annotations_to_ground_truths,
40
+ _resolve_metrics_margin_pixels,
41
+ _rbox_centroid_in_tile_interior,
42
+ load_dota_annotations,
43
+ load_gt_as_ground_truths,
44
+ rbox_to_array,
45
+ run_diagnostics_pipeline,
46
+ )
47
+
48
+
49
def _load_gt_entries(
    img_path: Path,
    label_dir: Optional[Path],
    gt_by_image_path: Optional[Dict[Path, list]],
    class_map: Dict[str, int],
) -> tuple[int, list, list]:
    """Collect the ground truth of one image in JSON-ready and object form.

    Args:
        img_path: Image whose annotations are wanted.
        label_dir: Directory searched for ``<image stem>.txt``; ``None`` means
            there is no label file to read.
        gt_by_image_path: Optional pre-loaded ground truth keyed by image path.
            When given, it is the only source consulted.
        class_map: Class name to class id mapping used for the label-file path.

    Returns:
        ``(num_gt, gt_entries, gt_list)``: the annotation count, a list of
        plain dicts (``bbox``, ``class_name``, ``class_id``, ``difficult``),
        and the ground-truth objects themselves.
    """
    # Pre-loaded annotations take precedence over label files on disk.
    if gt_by_image_path is not None:
        preloaded = gt_by_image_path.get(img_path, [])
        serialised = []
        for truth in preloaded:
            serialised.append(
                {
                    "bbox": rbox_to_array(truth.rbox).tolist(),
                    "class_name": truth.class_name,
                    "class_id": int(truth.class_id),
                    "difficult": int(getattr(truth, "difficult", 0)),
                }
            )
        return len(preloaded), serialised, preloaded

    label_file = None if label_dir is None else label_dir / f"{img_path.stem}.txt"
    try:
        has_labels = label_file is not None and label_file.exists()
        if has_labels:
            boxes, names = load_dota_annotations(str(label_file))
        else:
            boxes = np.array([]).reshape(0, 5)
            names = []

        known = len(names)
        serialised = []
        for idx, box in enumerate(boxes):
            # Boxes without a matching name are reported as "unknown" / -1.
            if idx < known:
                name = names[idx]
                class_id = int(class_map.get(name, -1))
            else:
                name = "unknown"
                class_id = -1
            serialised.append(
                {
                    "bbox": box.tolist(),
                    "class_name": name,
                    "class_id": class_id,
                    "difficult": 0,
                }
            )
        count = len(boxes)
        truths = load_gt_as_ground_truths(label_file, class_map) if has_labels else []
    except Exception as exc:
        # Best effort: a broken label file yields an empty ground truth.
        print(f"Warning: Could not load GT for {img_path.name}: {exc}")
        return 0, [], []
    return count, serialised, truths
93
+
94
+
95
def infer_keras_on_image(
    keras_model,
    pil_image: Image.Image,
    preprocessing: dict,
) -> List[Dict[str, Any]]:
    """Run the Keras detect bundle on a single image.

    Args:
        keras_model: Callable returning ``(detections, count)`` for a batch.
        pil_image: Input image; must fit inside the model canvas.
        preprocessing: Preprocessing parameters (``resize_mode`` and
            ``target_size`` are read here).

    Returns:
        The result of ``remap_detections_to_original`` applied to a list of
        ``{"rbox", "score", "label"}`` dicts, or ``[]`` when the model reports
        no detections.

    Raises:
        NotImplementedError: If the image is larger than the model canvas
            (no sliding-window tiling is implemented here).
    """
    import tensorflow as tf

    image_width, image_height = pil_image.size
    slice_h, slice_w = get_model_size(preprocessing)
    fits_canvas = image_height <= slice_h and image_width <= slice_w
    if not fits_canvas:
        raise NotImplementedError(
            f"Image {image_width}x{image_height} exceeds model canvas {slice_w}x{slice_h}. "
            "TF export preds does not implement sliding-window tiling; use pre-tiled val images."
        )

    prepared = _preprocess_image_tensor_training_style(pil_image, preprocessing)
    model_input = tf.constant(prepared.unsqueeze(0).numpy(), dtype=tf.float32)
    detections_t, num_t = keras_model(model_input, training=False)

    # The model reports how many leading detection rows are valid.
    valid_count = int(num_t.numpy().reshape(-1)[0])
    if valid_count <= 0:
        return []
    valid_rows = detections_t.numpy()[:valid_count]

    from oriented_det.data.preprocessing import build_spatial_meta_from_dims, remap_detections_to_original

    spatial_meta = build_spatial_meta_from_dims(
        preprocessing.get("resize_mode", "fixed"),
        image_width,
        image_height,
        preprocessing.get("target_size", (slice_h, slice_w)),
    )

    def _row_to_detection(row) -> Dict[str, Any]:
        # Each row carries exactly seven values: box geometry, score, label.
        cx, cy, w, h, ang, score, label = map(float, row)
        box = normalize_le90(RBox(cx=cx, cy=cy, width=w, height=h, angle=ang))
        return {"rbox": box, "score": score, "label": int(label)}

    decoded: List[Dict[str, Any]] = [_row_to_detection(row) for row in valid_rows]
    return remap_detections_to_original(decoded, spatial_meta)
131
+
132
+
133
def _class_name_for_label(label: int, class_names: List[str]) -> str:
    """Return the class name for a 1-based label, or ``class_<label>`` if out of range."""
    if class_names and 1 <= label <= len(class_names):
        return class_names[label - 1]
    return f"class_{label}"


def run_tf_inference_and_save(
    *,
    config_path: Path,
    detect_dir: Path,
    output_dir: Optional[Path] = None,
    data_root: Optional[Path] = None,
    data_split: str = "val",
    val_dir: Optional[Path] = None,
    run_diagnostics: bool = True,
    reference_checkpoint: Optional[Path] = None,
    ort_device: Optional[str] = None,
) -> Dict[str, Any]:
    """Run the Keras export bundle over a dataset split and write ``predictions.json``.

    Args:
        config_path: Training experiment config to load.
        detect_dir: Directory that must contain ``keras_model.keras``; an
            optional ``export_meta.json`` next to it is read for metadata.
        output_dir: Where to write results. Defaults to
            ``<detect_dir parent>/predictions/<timestamp>``.
        data_root: Dataset root; falls back to ``config.dataset.data_root``.
        data_split: Name of the split to evaluate.
        val_dir: Forwarded to ``collect_split_images``.
        run_diagnostics: When true, detections and ground truths are collected
            and ``run_diagnostics_pipeline`` is executed.
        reference_checkpoint: Optional path recorded in the metadata; the
            config path is recorded as ``checkpoint`` when it is not given.
        ort_device: Forwarded to ``configure_ort_device``.

    Returns:
        ``{"output_dir": <str>, "metadata": <dict>}``.

    Raises:
        ValueError: If no data root is available from the arguments or config.
        FileNotFoundError: If ``keras_model.keras`` is missing in ``detect_dir``.

    Notes:
        Unreadable images are skipped with a warning. Any exception raised
        while running inference on an image is reported as a warning and the
        image is recorded with zero predictions.
    """
    ort_providers = configure_ort_device(ort_device)
    config = TrainingExperimentConfig.load(config_path)
    class_names = list(config.class_names or [])
    preprocessing = get_preprocessing_params(config)

    if not data_root:
        if getattr(config, "dataset", None) and getattr(config.dataset, "data_root", None):
            data_root = Path(config.dataset.data_root)
        else:
            raise ValueError("data_root required (CLI or config.dataset.data_root).")
    data_root = Path(data_root)

    keras_path = detect_dir / "keras_model.keras"
    if not keras_path.is_file():
        raise FileNotFoundError(f"Missing Keras bundle: {keras_path} (run make export-tf first).")
    keras_model = load_keras_detect_model(keras_path)

    if output_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = detect_dir.parent / "predictions" / timestamp
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cfg_thr_sc, cfg_thr_pc, cfg_thr_iou = effective_eval_metric_thresholds(config)
    score_threshold = cfg_thr_sc
    per_cls_thr = cfg_thr_pc
    # NOTE: the final-NMS IoU threshold used to be resolved here but was never
    # read afterwards, and its float() conversion could raise on a None config
    # value. The dead computation has been dropped.
    production_cfg = getattr(config, "production", None)
    model_cfg = getattr(config, "model", None)
    nms_class_agnostic = bool(
        getattr(production_cfg, "nms_class_agnostic", False)
        or getattr(model_cfg, "nms_class_agnostic", False)
    )
    iou_threshold = float(cfg_thr_iou)
    overlap_pixels = resolve_inference_sliding_window_overlap_pixels(config)
    resolved_metrics_margin_px = _resolve_metrics_margin_pixels(
        margin_pixels=getattr(production_cfg, "ignore_margin_pixels", None),
        overlap_ratio=None,
        overlap_pixels=overlap_pixels,
        preprocessing=preprocessing,
    )

    split_images, label_dir, dataset_format = collect_split_images(
        config, data_root, data_split=data_split, val_dir=val_dir
    )
    print(
        f"TF export inference: {len(split_images)} {data_split} images → {output_dir} "
        f"(ort_device={get_ort_device()}, providers={ort_providers})"
    )

    class_map = {name: i for i, name in enumerate(class_names)} if class_names else {}
    gt_by_image_path = None
    if dataset_format == "airbus_playground":
        # This format keeps annotations in a CSV dataset, so ground truth is
        # pre-loaded per image path instead of being read from label files.
        from oriented_det.data.airbus_playground import AirbusPlaygroundCSVDataset

        ds_config = config.dataset
        airbus_dataset = AirbusPlaygroundCSVDataset(
            data_root=data_root,
            split=data_split,
            annotations_file=ds_config.annotations_file,
            split_file=ds_config.split_file,
            val_split_id=getattr(ds_config, "val_split_id", 0),
            difficult_strategy=ds_config.difficult_strategy,
            allowed_classes=getattr(ds_config, "allowed_classes", None),
            ignore_labels=getattr(ds_config, "ignore_labels", None) or [],
            map_labels=getattr(ds_config, "map_labels", None) or {},
        )
        gt_by_image_path = {}
        for idx in range(len(airbus_dataset)):
            sample = airbus_dataset[idx]
            gt_by_image_path[Path(sample.image_path)] = _annotations_to_ground_truths(
                list(sample.annotations), class_map
            )

    results: List[Dict[str, Any]] = []
    all_detections: Dict[str, list] = {}
    all_ground_truths: Dict[str, list] = {}
    all_scores: List[float] = []
    image_name_by_id: Dict[str, str] = {}

    t0 = time.perf_counter()
    for img_path in tqdm(
        split_images,
        desc=f"TF export {data_split}",
        file=tqdm_progress_stream(),
    ):
        img_name = img_path.name
        image_id = img_path.stem
        image_name_by_id[image_id] = img_name

        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            print(f"Warning: skip unreadable {img_path}")
            continue
        img_h, img_w = img_bgr.shape[:2]
        pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))

        num_gt, gt_entries, gt_list_raw = _load_gt_entries(
            img_path, label_dir, gt_by_image_path, class_map
        )

        try:
            det_dicts = infer_keras_on_image(keras_model, pil, preprocessing)
        except Exception as exc:
            # Best effort: keep the image in the report with no predictions.
            print(f"Warning: inference failed for {img_name}: {exc}")
            det_dicts = []

        all_scores.extend(d["score"] for d in det_dicts)

        if run_diagnostics:
            dets_m = [
                Detection(
                    rbox=d["rbox"],
                    score=d["score"],
                    class_id=d["label"],
                    class_name=_class_name_for_label(d["label"], class_names),
                    image_id=image_id,
                )
                for d in det_dicts
            ]
            gt_m = [
                GroundTruth(
                    rbox=gt.rbox,
                    class_id=gt.class_id,
                    class_name=gt.class_name,
                    difficult=gt.difficult,
                    image_id=image_id,
                )
                for gt in gt_list_raw
            ]
            if resolved_metrics_margin_px > 0 and img_w > 0 and img_h > 0:
                # Metrics only consider boxes whose centroid lies inside the
                # tile interior (border margin excluded).
                dets_m = [
                    d
                    for d in dets_m
                    if _rbox_centroid_in_tile_interior(d.rbox, img_w, img_h, resolved_metrics_margin_px)
                ]
                gt_m = [
                    g
                    for g in gt_m
                    if _rbox_centroid_in_tile_interior(g.rbox, img_w, img_h, resolved_metrics_margin_px)
                ]
            all_detections[image_id] = dets_m
            all_ground_truths[image_id] = gt_m

        results.append(
            {
                "image_name": img_name,
                "image_path": os.path.relpath(img_path, data_root),
                "image_width": int(img_w),
                "image_height": int(img_h),
                "resize_mode": preprocessing.get("resize_mode", "fixed"),
                "target_size": preprocessing.get("target_size", [1024, 1024]),
                "num_gt": int(num_gt),
                "num_pred": len(det_dicts),
                "predictions": [
                    {
                        "bbox": rbox_to_array(d["rbox"]).tolist(),
                        "score": float(d["score"]),
                        "label": int(d["label"]),
                        "class_name": _class_name_for_label(d["label"], class_names),
                    }
                    for d in det_dicts
                ],
                "ground_truths": gt_entries,
                "stats": {"inference_backend": "tensorflow_keras_export"},
            }
        )

    t_infer = time.perf_counter() - t0
    experiment_dir = str(config_path.parent)
    checkpoint_ref = str(reference_checkpoint or config_path)

    diagnostics = None
    analysis = None
    if run_diagnostics:
        diagnostics, analysis = run_diagnostics_pipeline(
            experiment_dir=experiment_dir,
            checkpoint_path=checkpoint_ref,
            config_path=str(config_path),
            data_root=data_root,
            data_split=data_split,
            class_names=class_names,
            score_threshold=float(score_threshold),
            per_cls_thr=per_cls_thr,
            nms_class_agnostic=nms_class_agnostic,
            iou_threshold=iou_threshold,
            pr_iou_threshold=None,
            pr_threshold_min=0.0,
            pr_threshold_max=1.0,
            pr_threshold_step=0.1,
            per_class_threshold_analysis=False,
            resolved_metrics_margin_px=int(resolved_metrics_margin_px),
            all_detections=all_detections,
            all_ground_truths=all_ground_truths,
            results=results,
            all_scores=all_scores,
            image_name_by_id=image_name_by_id,
            sliding_window_positions_total=None,
            window_batch_effective=None,
            t_infer_sec=float(t_infer),
            device="tensorflow",
            output_dir=str(output_dir),
            tile_metrics_csv=None,
        )

    meta_path = detect_dir / "export_meta.json"
    export_meta = {}
    if meta_path.is_file():
        export_meta = json.loads(meta_path.read_text(encoding="utf-8"))

    metadata: Dict[str, Any] = {
        "timestamp": datetime.now().isoformat(),
        "inference_backend": "tensorflow_keras_export",
        "detect_bundle": str(detect_dir.resolve()),
        "keras_model": str(keras_path.resolve()),
        "export_meta_core_backend": export_meta.get("core_backend"),
        "experiment_dir": experiment_dir,
        "checkpoint": checkpoint_ref,
        "config_file": str(config_path),
        "pytorch_reference_checkpoint": str(reference_checkpoint) if reference_checkpoint else None,
        "data_root": str(data_root),
        "data_split": data_split,
        "device": "tensorflow",
        "ort_device": get_ort_device(),
        "ort_providers": ort_providers,
        "class_names": class_names,
        "score_threshold": score_threshold,
        "per_class_score_threshold": per_cls_thr,
        "nms_class_agnostic": nms_class_agnostic,
        "total_images": len(results),
        "total_predictions": sum(r["num_pred"] for r in results),
        "total_ground_truth": sum(r["num_gt"] for r in results),
        "inference_loop_seconds": float(t_infer),
        "bbox_coordinate_space": "image_pixels",
        "metrics_margin_pixels": int(resolved_metrics_margin_px),
        "preprocess_note": "Resize+ToTensor+Normalize (same as PyTorch make preds)",
    }
    if diagnostics is not None:
        metadata["diagnostics"] = diagnostics
    if analysis is not None:
        metadata["best_threshold_f2"] = analysis.get("best_threshold", {})
        metadata["analysis_file"] = analysis.get("artifacts", {}).get(
            "analysis_json", f"analysis_iou{iou_threshold:.2f}.json"
        )
        metadata["pr_iou_threshold"] = analysis.get("iou_threshold", iou_threshold)

    json_path = output_dir / "predictions.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({"metadata": metadata, "results": results}, f, indent=2)

    print(f"Wrote {json_path}")
    return {"output_dir": str(output_dir), "metadata": metadata}
408
+
409
+
410
def main() -> None:
    """Command-line entry point: parse options and run the TF/Keras export inference."""
    parser = argparse.ArgumentParser(description="Val inference via TF/Keras export bundle.")
    parser.add_argument("--config", type=Path, required=True)
    parser.add_argument("--detect-dir", type=Path, required=True, help="Directory with keras_model.keras")
    parser.add_argument("--output-dir", type=Path, default=None, help="Default: export/artifacts/predictions/<ts>")
    parser.add_argument("--data-root", type=Path, default=None)
    parser.add_argument("--data-split", default="val", choices=("train", "val", "test"))
    parser.add_argument("--val-dir", type=Path, default=None)
    parser.add_argument(
        "--reference-checkpoint",
        type=Path,
        default=None,
        help="PyTorch .pth path recorded in metadata for comparison (default: deploy weights path only in meta).",
    )
    parser.add_argument("--no-diagnostics", action="store_true", help="Skip mAP/PR (inference-only JSON).")
    parser.add_argument(
        "--ort-device",
        default=None,
        choices=("cpu", "cuda", "auto"),
        help="ONNX Runtime EP for the exported ONNX core (default: cpu or ORIENTED_DET_ORT_DEVICE).",
    )
    cli = parser.parse_args()

    # The CLI exposes the negative flag; the API takes the positive one.
    call_kwargs = {
        "config_path": cli.config,
        "detect_dir": cli.detect_dir,
        "output_dir": cli.output_dir,
        "data_root": cli.data_root,
        "data_split": cli.data_split,
        "val_dir": cli.val_dir,
        "run_diagnostics": not cli.no_diagnostics,
        "reference_checkpoint": cli.reference_checkpoint,
        "ort_device": cli.ort_device,
    }
    run_tf_inference_and_save(**call_kwargs)


if __name__ == "__main__":
    main()
@@ -0,0 +1,32 @@
1
+ #!/usr/bin/env python3
2
+ """Optional: convert SavedModel to TFLite (float32)."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import sys
8
+ from pathlib import Path
9
+
10
+
11
def main() -> None:
    """Command-line entry point: convert a SavedModel directory to a TFLite file.

    Exits with status 1 (after printing an install hint to stderr) when
    TensorFlow cannot be imported.
    """
    parser = argparse.ArgumentParser(description="SavedModel → TFLite float32.")
    parser.add_argument("--saved-model", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True, help="Output .tflite path")
    cli = parser.parse_args()

    # TensorFlow is imported lazily so argument errors surface without it.
    try:
        import tensorflow as tf
    except ImportError as import_error:
        print("Install TensorFlow: pip install -r export/requirements-export.txt", file=sys.stderr)
        raise SystemExit(1) from import_error

    source_dir = str(cli.saved_model)
    converter = tf.lite.TFLiteConverter.from_saved_model(source_dir)
    # No optimizations are requested for the conversion.
    converter.optimizations = []
    flatbuffer = converter.convert()

    destination = cli.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(flatbuffer)
    print(f"Wrote {destination} ({len(flatbuffer)} bytes)")


if __name__ == "__main__":
    main()
@@ -0,0 +1,82 @@
1
+ """Optional ONNX smoke (each test skips if deps missing)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import io
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import pytest
10
+ import torch
11
+
12
+ _REPO_ROOT = Path(__file__).resolve().parents[2]
13
+ _EXPORT_DIR = Path(__file__).resolve().parents[1]
14
+ if str(_REPO_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(_REPO_ROOT))
16
+ if str(_EXPORT_DIR) not in sys.path:
17
+ sys.path.insert(0, str(_EXPORT_DIR))
18
+
19
+ import wrappers as _wrappers # noqa: E402
20
+ from oriented_det.models.rotated_retinanet import RotatedRetinaNet # noqa: E402
21
+
22
+ BackboneExportWrapper = _wrappers.BackboneExportWrapper
23
+ RetinaNetBackboneHeadExportWrapper = _wrappers.RetinaNetBackboneHeadExportWrapper
24
+
25
+
26
def test_export_backbone_onnx_roundtrip_ort() -> None:
    """Export the backbone wrapper to ONNX, validate it, and run it with ONNX Runtime.

    Skipped when ``onnx`` or ``onnxruntime`` is not installed. Only the number
    of outputs and their shapes are compared against the PyTorch reference.
    """
    # importorskip returns the imported module, so no second import is needed.
    onnx = pytest.importorskip("onnx")
    ort = pytest.importorskip("onnxruntime")

    detector = RotatedRetinaNet(num_classes=2, backbone_name="resnet18", pretrained_backbone=False)
    detector.eval()
    exportable = BackboneExportWrapper(detector.backbone)
    sample = torch.randn(1, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        expected = exportable(sample)

    output_names = [f"fpn_{level}" for level in range(len(expected))]
    stream = io.BytesIO()
    torch.onnx.export(
        exportable,
        sample,
        stream,
        input_names=["images"],
        output_names=output_names,
        opset_version=17,
        do_constant_folding=True,
    )

    stream.seek(0)
    onnx.checker.check_model(onnx.load(stream))

    session = ort.InferenceSession(stream.getvalue(), providers=["CPUExecutionProvider"])
    produced = session.run(None, {"images": sample.numpy()})
    assert len(produced) == len(expected)
    for got, want in zip(produced, expected):
        assert got.shape == tuple(want.shape)
56
+
57
+
58
def test_export_retinanet_heads_onnx_checker() -> None:
    """Export the backbone+heads wrapper to ONNX and run the ONNX model checker.

    Skipped when ``onnx`` is not installed.
    """
    onnx = pytest.importorskip("onnx")

    detector = RotatedRetinaNet(num_classes=2, backbone_name="resnet18", pretrained_backbone=False)
    detector.eval()
    exportable = RetinaNetBackboneHeadExportWrapper(detector)
    sample = torch.randn(1, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        expected = exportable(sample)

    # Outputs are named in (cls, bbox) pairs, one pair per level.
    level_count = len(expected) // 2
    output_names = [
        f"level{level}_{kind}"
        for level in range(level_count)
        for kind in ("cls", "bbox")
    ]

    stream = io.BytesIO()
    torch.onnx.export(
        exportable,
        sample,
        stream,
        input_names=["images"],
        output_names=output_names,
        opset_version=17,
        do_constant_folding=True,
    )
    stream.seek(0)
    exported = onnx.load(stream)
    onnx.checker.check_model(exported)
@@ -0,0 +1,105 @@
1
+ """Tests for export wrappers (no ONNX/TF required)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import pytest
9
+ import torch
10
+
11
+ _REPO_ROOT = Path(__file__).resolve().parents[2]
12
+ _EXPORT_DIR = Path(__file__).resolve().parents[1]
13
+ if str(_REPO_ROOT) not in sys.path:
14
+ sys.path.insert(0, str(_REPO_ROOT))
15
+ if str(_EXPORT_DIR) not in sys.path:
16
+ sys.path.insert(0, str(_EXPORT_DIR))
17
+
18
+ import wrappers as _wrappers # noqa: E402
19
+ from oriented_det import RotatedFasterRCNN # noqa: E402
20
+ from oriented_det.models.rotated_retinanet import RotatedRetinaNet # noqa: E402
21
+
22
+ BackboneExportWrapper = _wrappers.BackboneExportWrapper
23
+ RetinaNetBackboneHeadExportWrapper = _wrappers.RetinaNetBackboneHeadExportWrapper
24
+ RotatedFasterRCNNPreNmsExportWrapper = _wrappers.RotatedFasterRCNNPreNmsExportWrapper
25
+
26
+
27
@pytest.fixture
def tiny_retinanet() -> RotatedRetinaNet:
    """Small ResNet-18 RotatedRetinaNet (3 classes, no pretrained weights) for wrapper tests."""
    settings = {
        "num_classes": 3,
        "backbone_name": "resnet18",
        "pretrained_backbone": False,
        "trainable_layers": 5,
    }
    return RotatedRetinaNet(**settings)
35
+
36
+
37
def test_backbone_wrapper_shapes(tiny_retinanet: RotatedRetinaNet) -> None:
    """The backbone wrapper returns a non-empty tuple of 4-D maps that keep the batch size."""
    batch_size = 2
    wrapper = BackboneExportWrapper(tiny_retinanet.backbone)
    wrapper.eval()
    images = torch.randn(batch_size, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        feature_maps = wrapper(images)

    assert isinstance(feature_maps, tuple)
    assert len(feature_maps) >= 1
    for feature_map in feature_maps:
        assert feature_map.dim() == 4
        assert feature_map.shape[0] == batch_size
48
+
49
+
50
def test_retinanet_heads_wrapper_shapes(tiny_retinanet: RotatedRetinaNet) -> None:
    """Head outputs come in (cls, bbox) pairs with batch size 1 and bbox channels divisible by 5."""
    wrapper = RetinaNetBackboneHeadExportWrapper(tiny_retinanet)
    wrapper.eval()
    images = torch.randn(1, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        head_outputs = wrapper(images)

    assert len(head_outputs) % 2 == 0
    # Even positions hold classification maps, odd positions the box maps.
    for cls_map, box_map in zip(head_outputs[0::2], head_outputs[1::2]):
        assert cls_map.shape[0] == 1
        assert box_map.shape[0] == 1
        assert cls_map.shape[1] > 0
        assert box_map.shape[1] % 5 == 0
61
+
62
+
63
def test_retinanet_heads_batch_unbind_matches_stack(tiny_retinanet: RotatedRetinaNet) -> None:
    """Batched head outputs must equal the per-sample outputs.

    Per the original note, the B>1 path uses torch.unbind inside the
    extract_backbone_features input list.
    """
    wrapper = RetinaNetBackboneHeadExportWrapper(tiny_retinanet)
    wrapper.eval()
    pair = torch.randn(2, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        batched = wrapper(pair)

    # Sanity check: re-stacking the two samples reproduces the batch.
    restacked = torch.stack([pair[0], pair[1]], dim=0)
    assert torch.allclose(restacked, pair)

    with torch.no_grad():
        first_only = wrapper(pair[0:1])
        second_only = wrapper(pair[1:2])

    assert len(batched) == len(first_only) == len(second_only)
    for joint, solo_first, solo_second in zip(batched, first_only, second_only):
        assert joint.shape[0] == 2
        assert torch.allclose(joint[0], solo_first[0])
        assert torch.allclose(joint[1], solo_second[0])
80
+
81
+
82
@pytest.fixture
def tiny_faster_rcnn() -> RotatedFasterRCNN:
    """Small ResNet-18 RotatedFasterRCNN (3 classes, 64 RPN proposals) for wrapper tests."""
    settings = {
        "num_classes": 3,
        "backbone_name": "resnet18",
        "pretrained_backbone": False,
        "trainable_layers": 5,
        "rpn_post_nms_top_n": 64,
        "rpn_pre_nms_top_n": 64,
    }
    return RotatedFasterRCNN(**settings)
92
+
93
+
94
def test_faster_rcnn_pre_nms_wrapper_shapes(tiny_faster_rcnn: RotatedFasterRCNN) -> None:
    """The pre-NMS wrapper returns fixed-size outputs sized by ``max_candidates``."""
    side = 128
    candidate_cap = 32
    wrapper = RotatedFasterRCNNPreNmsExportWrapper(
        tiny_faster_rcnn, height=side, width=side, max_candidates=candidate_cap
    )
    wrapper.eval()
    images = torch.randn(1, 3, side, side, dtype=torch.float32)
    with torch.no_grad():
        boxes, scores, labels, count = wrapper(images)

    assert boxes.shape == (candidate_cap, 5)
    assert scores.shape == (candidate_cap,)
    assert labels.shape == (candidate_cap,)
    # The valid-candidate count is a scalar tensor bounded by the cap.
    assert count.ndim == 0
    assert int(count.item()) <= candidate_cap