oriented-det 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- export/__init__.py +9 -0
- export/ort_runtime.py +67 -0
- export/postprocess.py +151 -0
- export/scripts/__init__.py +6 -0
- export/scripts/build_faster_rcnn_savedmodel.py +104 -0
- export/scripts/export_onnx.py +210 -0
- export/scripts/onnx_to_savedmodel.py +35 -0
- export/scripts/predict_savedmodel.py +94 -0
- export/scripts/save_predictions_tf.py +447 -0
- export/scripts/to_tflite.py +32 -0
- export/tests/test_export_onnx_optional.py +82 -0
- export/tests/test_export_wrappers.py +105 -0
- export/tests/test_faster_rcnn_export_parity.py +201 -0
- export/tests/test_ort_runtime.py +41 -0
- export/tf_serving_model.py +96 -0
- export/val_dataset.py +116 -0
- export/wrappers.py +161 -0
- oriented_det/__init__.py +77 -0
- oriented_det/cli/__init__.py +92 -0
- oriented_det/cli/train.py +18 -0
- oriented_det/configs/_base_/augmentation.json +21 -0
- oriented_det/configs/_base_/datasets/dota_le90.json +21 -0
- oriented_det/configs/_base_/fp16.json +5 -0
- oriented_det/configs/_base_/models/oriented_rcnn_r50.json +26 -0
- oriented_det/configs/_base_/models/rotated_faster_rcnn_r50.json +53 -0
- oriented_det/configs/_base_/models/rotated_retinanet_r50.json +30 -0
- oriented_det/configs/_base_/preprocessing.json +8 -0
- oriented_det/configs/_base_/schedules/1x.json +33 -0
- oriented_det/configs/config.schema.json +404 -0
- oriented_det/configs/oriented_rcnn/dota_le90_1x.json +121 -0
- oriented_det/configs/oriented_rcnn/dota_le90_3x.json +11 -0
- oriented_det/configs/rotated_faster_rcnn/dota_le90_1x.json +119 -0
- oriented_det/configs/rotated_faster_rcnn/dota_le90_3x.json +9 -0
- oriented_det/configs/rotated_retinanet/dota_le90_1x.json +107 -0
- oriented_det/configs/rotated_retinanet/dota_le90_3x.json +30 -0
- oriented_det/data/__init__.py +89 -0
- oriented_det/data/airbus_playground.py +611 -0
- oriented_det/data/dota.py +741 -0
- oriented_det/data/dota_classes.py +26 -0
- oriented_det/data/evaluation.py +648 -0
- oriented_det/data/flips.py +115 -0
- oriented_det/data/preprocessing.py +335 -0
- oriented_det/data/tiling.py +399 -0
- oriented_det/data/transforms.py +377 -0
- oriented_det/geometry/__init__.py +8 -0
- oriented_det/geometry/poly.py +127 -0
- oriented_det/geometry/qbox.py +63 -0
- oriented_det/geometry/rbox.py +250 -0
- oriented_det/geometry/transforms.py +266 -0
- oriented_det/models/__init__.py +36 -0
- oriented_det/models/backbones/__init__.py +11 -0
- oriented_det/models/backbones/resnet_fpn.py +79 -0
- oriented_det/models/backbones/utils.py +81 -0
- oriented_det/models/bbox_coder.py +355 -0
- oriented_det/models/faster_rcnn_inference.py +494 -0
- oriented_det/models/horizontal_roi_coder.py +155 -0
- oriented_det/models/oriented_rcnn.py +1256 -0
- oriented_det/models/oriented_roi.py +1664 -0
- oriented_det/models/oriented_rpn.py +2104 -0
- oriented_det/models/rotated_retinanet.py +1030 -0
- oriented_det/models/utils.py +590 -0
- oriented_det/ops/__init__.py +58 -0
- oriented_det/ops/gpu_ops.py +1109 -0
- oriented_det/ops/iou.py +172 -0
- oriented_det/ops/kfiou.py +275 -0
- oriented_det/ops/nms.py +202 -0
- oriented_det/ops/probiou.py +165 -0
- oriented_det/ops/rotated_ops.py +122 -0
- oriented_det/ops/utils.py +257 -0
- oriented_det/pretrained/__init__.py +23 -0
- oriented_det/pretrained/hub.py +249 -0
- oriented_det/pretrained/manifest.json +46 -0
- oriented_det/runtime/__init__.py +29 -0
- oriented_det/runtime/checkpoint.py +274 -0
- oriented_det/runtime/collate.py +348 -0
- oriented_det/runtime/inference.py +1286 -0
- oriented_det/train/__init__.py +102 -0
- oriented_det/train/config.py +872 -0
- oriented_det/train/engine.py +2933 -0
- oriented_det/train/grouped_ce.py +139 -0
- oriented_det/train/piecewise_schedule.py +33 -0
- oriented_det/train/profiler.py +287 -0
- oriented_det/train/utils.py +999 -0
- oriented_det/utils/__init__.py +31 -0
- oriented_det/utils/config.py +376 -0
- oriented_det/utils/device.py +35 -0
- oriented_det/utils/logging.py +163 -0
- oriented_det/utils/progress.py +62 -0
- oriented_det/utils/viz.py +181 -0
- oriented_det-0.1.0.dist-info/METADATA +313 -0
- oriented_det-0.1.0.dist-info/RECORD +115 -0
- oriented_det-0.1.0.dist-info/WHEEL +5 -0
- oriented_det-0.1.0.dist-info/entry_points.txt +2 -0
- oriented_det-0.1.0.dist-info/licenses/LICENSE +202 -0
- oriented_det-0.1.0.dist-info/top_level.txt +3 -0
- tools/__init__.py +1 -0
- tools/app.py +1284 -0
- tools/dataset_stats.py +389 -0
- tools/dota_labels_to_comma.py +132 -0
- tools/free_gpu.py +142 -0
- tools/generate_airbus_playground_csv.py +154 -0
- tools/image_demo.py +200 -0
- tools/lr_finder.py +771 -0
- tools/measure_sampled_riou_error.py +483 -0
- tools/playground_to_dota.py +290 -0
- tools/pretrained_download.py +49 -0
- tools/preview_augmentation.py +510 -0
- tools/publish_checkpoint.py +96 -0
- tools/save_predictions.py +2350 -0
- tools/sync_vendored_configs.py +94 -0
- tools/tile_dota.py +447 -0
- tools/train.py +2324 -0
- tools/train_example.py +244 -0
- tools/train_multi_gpu.py +367 -0
- tools/visualize_boxes.py +183 -0
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Run validation inference with the TF/Keras export bundle; write predictions.json for metrics."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import time
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Dict, List, Optional
|
|
14
|
+
|
|
15
|
+
# When executed directly as a script, the repository root is not on the import
# path; prepend it once so the ``export``, ``oriented_det`` and ``tools``
# packages imported below resolve.
_REPO_ROOT = Path(__file__).resolve().parents[2]
_repo_root_entry = str(_REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
del _repo_root_entry
|
|
18
|
+
|
|
19
|
+
import cv2
|
|
20
|
+
import numpy as np
|
|
21
|
+
from PIL import Image
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
|
|
24
|
+
from export.ort_runtime import configure_ort_device, get_ort_device
|
|
25
|
+
from export.tf_serving_model import load_keras_detect_model
|
|
26
|
+
from export.val_dataset import collect_split_images
|
|
27
|
+
from oriented_det.data import Detection, GroundTruth
|
|
28
|
+
from oriented_det.geometry import RBox, normalize_le90
|
|
29
|
+
from oriented_det.train.config import (
|
|
30
|
+
TrainingExperimentConfig,
|
|
31
|
+
effective_eval_metric_thresholds,
|
|
32
|
+
get_preprocessing_params,
|
|
33
|
+
resolve_inference_sliding_window_overlap_pixels,
|
|
34
|
+
)
|
|
35
|
+
from oriented_det.utils import tqdm_progress_stream
|
|
36
|
+
from oriented_det.runtime.inference import _preprocess_image_tensor_training_style, get_model_size
|
|
37
|
+
from oriented_det.runtime.checkpoint import load_model_from_checkpoint
|
|
38
|
+
from tools.save_predictions import (
|
|
39
|
+
_annotations_to_ground_truths,
|
|
40
|
+
_resolve_metrics_margin_pixels,
|
|
41
|
+
_rbox_centroid_in_tile_interior,
|
|
42
|
+
load_dota_annotations,
|
|
43
|
+
load_gt_as_ground_truths,
|
|
44
|
+
rbox_to_array,
|
|
45
|
+
run_diagnostics_pipeline,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _load_gt_entries(
    img_path: Path,
    label_dir: Optional[Path],
    gt_by_image_path: Optional[Dict[Path, list]],
    class_map: Dict[str, int],
) -> tuple[int, list, list]:
    """Gather ground truth for one image, both as JSON entries and as GT objects.

    Two sources are supported:

    * ``gt_by_image_path`` (a pre-built index): ground truth is looked up by
      ``img_path``; an image missing from the index has no ground truth.
    * otherwise a DOTA-style label file ``<label_dir>/<img_path.stem>.txt``;
      a missing ``label_dir`` or label file means no ground truth.

    Args:
        img_path: Image whose ground truth is requested.
        label_dir: Directory holding DOTA ``.txt`` label files, or ``None``.
        gt_by_image_path: Optional pre-built index of ground-truth objects.
        class_map: Class name -> class id mapping used for label files.

    Returns:
        ``(num_gt, gt_entries, gt_list)``. ``gt_entries`` are JSON-serialisable
        dicts with ``bbox``/``class_name``/``class_id``/``difficult`` keys.
        ``gt_list`` holds the ground-truth objects later used for metrics.

    Any exception raised while reading a label file is reported with a printed
    warning, and the image is then treated as having no ground truth.
    """
    if gt_by_image_path is not None:
        # NOTE(review): lookup is by exact Path equality; relative vs absolute
        # paths between the index and ``img_path`` would silently miss.
        indexed = gt_by_image_path.get(img_path, [])
        serialised = []
        for gt in indexed:
            serialised.append(
                {
                    "bbox": rbox_to_array(gt.rbox).tolist(),
                    "class_name": gt.class_name,
                    "class_id": int(gt.class_id),
                    "difficult": int(getattr(gt, "difficult", 0)),
                }
            )
        return len(indexed), serialised, indexed

    label_path = None if label_dir is None else label_dir / f"{img_path.stem}.txt"
    try:
        has_labels = label_path is not None and label_path.exists()
        if has_labels:
            boxes, names = load_dota_annotations(str(label_path))
        else:
            boxes, names = np.array([]).reshape(0, 5), []
        count = len(boxes)
        serialised = []
        for idx, box in enumerate(boxes):
            if idx < len(names):
                name = names[idx]
                class_id = int(class_map.get(name, -1))
            else:
                name, class_id = "unknown", -1
            # Difficulty flags from the label file are not carried into these
            # JSON entries; they are always written as 0.
            serialised.append(
                {"bbox": box.tolist(), "class_name": name, "class_id": class_id, "difficult": 0}
            )
        ground_truths = load_gt_as_ground_truths(label_path, class_map) if has_labels else []
    except Exception as exc:
        print(f"Warning: Could not load GT for {img_path.name}: {exc}")
        return 0, [], []
    return count, serialised, ground_truths
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def infer_keras_on_image(
    keras_model,
    pil_image: Image.Image,
    preprocessing: dict,
) -> List[Dict[str, Any]]:
    """Run the Keras detect bundle on one image; return list of {rbox, score, label}.

    The image is preprocessed exactly like the PyTorch path
    (``_preprocess_image_tensor_training_style``), fed as a batch of one, and the
    resulting boxes are mapped back to original-image pixels with
    ``remap_detections_to_original``.

    Args:
        keras_model: Loaded detect bundle, called as
            ``keras_model(batch, training=False)`` and returning
            ``(detections, num_detections)``. Detection rows are
            ``(cx, cy, w, h, angle, score, label)``.
        pil_image: Input image; must fit inside the model canvas.
        preprocessing: Preprocessing parameters (``resize_mode``, ``target_size``, ...).

    Returns:
        List of ``{"rbox": RBox, "score": float, "label": int}`` dicts with
        le90-normalised boxes in original-image coordinates (empty when the model
        reports no detections).

    Raises:
        NotImplementedError: if the image is larger than the model canvas; the TF
            export path has no sliding-window tiling.
    """
    import tensorflow as tf

    image_width, image_height = pil_image.size
    slice_h, slice_w = get_model_size(preprocessing)
    if image_height > slice_h or image_width > slice_w:
        raise NotImplementedError(
            f"Image {image_width}x{image_height} exceeds model canvas {slice_w}x{slice_h}. "
            "TF export preds does not implement sliding-window tiling; use pre-tiled val images."
        )

    tensor = _preprocess_image_tensor_training_style(pil_image, preprocessing)
    batch = tf.constant(tensor.unsqueeze(0).numpy(), dtype=tf.float32)
    detections_t, num_t = keras_model(batch, training=False)
    # The count may carry a batch axis, hence reshape(-1)[0].
    n = int(num_t.numpy().reshape(-1)[0])
    if n <= 0:
        return []

    det = np.asarray(detections_t.numpy())
    # Flatten any leading batch axis so rows are always 7-wide detection records.
    # An unbatched (max_det, 7) array is unchanged; a batched (1, max_det, 7)
    # array would otherwise be sliced along the batch axis by ``[:n]``.
    det = det.reshape(-1, det.shape[-1])[:n]
    from oriented_det.data.preprocessing import build_spatial_meta_from_dims, remap_detections_to_original

    mode = preprocessing.get("resize_mode", "fixed")
    ts = preprocessing.get("target_size", (slice_h, slice_w))
    meta = build_spatial_meta_from_dims(mode, image_width, image_height, ts)

    model_dets: List[Dict[str, Any]] = []
    for row in det:
        cx, cy, w, h, ang, score, label = [float(x) for x in row]
        rbox = normalize_le90(RBox(cx=cx, cy=cy, width=w, height=h, angle=ang))
        model_dets.append({"rbox": rbox, "score": score, "label": int(label)})
    return remap_detections_to_original(model_dets, meta)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _label_to_class_name(label: int, class_names: List[str]) -> str:
    """Return the class name for a 1-based detector ``label``.

    Falls back to ``"class_<label>"`` when ``class_names`` is empty or ``label``
    lies outside ``1..len(class_names)``.
    """
    if class_names and 1 <= label <= len(class_names):
        return class_names[label - 1]
    return f"class_{label}"


def _build_airbus_gt_index(
    config: TrainingExperimentConfig,
    data_root: Path,
    data_split: str,
    class_map: Dict[str, int],
) -> Dict[Path, list]:
    """Load the Airbus-playground split and index its ground truth by image path.

    Dataset options come from ``config.dataset``; annotations are converted with
    ``_annotations_to_ground_truths`` using ``class_map``.
    """
    from oriented_det.data.airbus_playground import AirbusPlaygroundCSVDataset

    ds_config = config.dataset
    airbus_dataset = AirbusPlaygroundCSVDataset(
        data_root=data_root,
        split=data_split,
        annotations_file=ds_config.annotations_file,
        split_file=ds_config.split_file,
        val_split_id=getattr(ds_config, "val_split_id", 0),
        difficult_strategy=ds_config.difficult_strategy,
        allowed_classes=getattr(ds_config, "allowed_classes", None),
        ignore_labels=getattr(ds_config, "ignore_labels", None) or [],
        map_labels=getattr(ds_config, "map_labels", None) or {},
    )
    gt_by_image_path: Dict[Path, list] = {}
    for idx in range(len(airbus_dataset)):
        sample = airbus_dataset[idx]
        gt_by_image_path[Path(sample.image_path)] = _annotations_to_ground_truths(
            list(sample.annotations), class_map
        )
    return gt_by_image_path


def run_tf_inference_and_save(
    *,
    config_path: Path,
    detect_dir: Path,
    output_dir: Optional[Path] = None,
    data_root: Optional[Path] = None,
    data_split: str = "val",
    val_dir: Optional[Path] = None,
    run_diagnostics: bool = True,
    reference_checkpoint: Optional[Path] = None,
    ort_device: Optional[str] = None,
) -> Dict[str, Any]:
    """Run the exported Keras detect bundle over a split and write ``predictions.json``.

    Loads ``<detect_dir>/keras_model.keras`` and runs it image by image over the
    images returned by ``collect_split_images``. Per-image predictions and ground
    truth are written to ``<output_dir>/predictions.json``. When
    ``run_diagnostics`` is set, detections and ground truth go to
    ``run_diagnostics_pipeline``, its results are stored in the metadata, and
    both are first optionally restricted to the tile interior by the resolved
    metrics margin.

    Args:
        config_path: Training experiment config; provides class names,
            preprocessing and evaluation thresholds. ``str`` is accepted.
        detect_dir: Export directory with ``keras_model.keras`` and optionally
            ``export_meta.json``. ``str`` is accepted.
        output_dir: Destination directory; defaults to
            ``<detect_dir parent>/predictions/<timestamp>``.
        data_root: Dataset root; falls back to ``config.dataset.data_root``.
        data_split: Split name passed to ``collect_split_images``.
        val_dir: Optional image directory passed to ``collect_split_images``.
        run_diagnostics: Whether to compute metrics via ``run_diagnostics_pipeline``.
        reference_checkpoint: Recorded in metadata and used as the checkpoint
            reference for diagnostics (``config_path`` is used when omitted).
        ort_device: Device passed to ``configure_ort_device``.

    Returns:
        ``{"output_dir": str, "metadata": dict}``.

    Raises:
        ValueError: if no ``data_root`` is given and the config provides none.
        FileNotFoundError: if ``keras_model.keras`` is missing.
    """
    # Accept plain strings from programmatic callers: ``detect_dir / ...`` and
    # ``config_path.parent`` below require Path objects.
    config_path = Path(config_path)
    detect_dir = Path(detect_dir)

    ort_providers = configure_ort_device(ort_device)
    config = TrainingExperimentConfig.load(config_path)
    class_names = list(config.class_names or [])
    preprocessing = get_preprocessing_params(config)

    if not data_root:
        if getattr(config, "dataset", None) and getattr(config.dataset, "data_root", None):
            data_root = Path(config.dataset.data_root)
        else:
            raise ValueError("data_root required (CLI or config.dataset.data_root).")
    data_root = Path(data_root)

    keras_path = detect_dir / "keras_model.keras"
    if not keras_path.is_file():
        raise FileNotFoundError(f"Missing Keras bundle: {keras_path} (run make export-tf first).")
    keras_model = load_keras_detect_model(keras_path)

    if output_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = detect_dir.parent / "predictions" / timestamp
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    score_threshold, per_cls_thr, cfg_thr_iou = effective_eval_metric_thresholds(config)
    # Detections are taken from the exported bundle as-is; this function applies
    # no NMS of its own, so no NMS IoU threshold is resolved here.
    # NOTE(review): because of ``or``, a falsy ``production.nms_class_agnostic``
    # cannot override a truthy ``model.nms_class_agnostic`` — confirm intended.
    nms_class_agnostic = bool(
        getattr(getattr(config, "production", None), "nms_class_agnostic", False)
        or getattr(getattr(config, "model", None), "nms_class_agnostic", False)
    )
    iou_threshold = float(cfg_thr_iou)
    overlap_pixels = resolve_inference_sliding_window_overlap_pixels(config)
    resolved_metrics_margin_px = _resolve_metrics_margin_pixels(
        margin_pixels=getattr(getattr(config, "production", None), "ignore_margin_pixels", None),
        overlap_ratio=None,
        overlap_pixels=overlap_pixels,
        preprocessing=preprocessing,
    )

    split_images, label_dir, dataset_format = collect_split_images(
        config, data_root, data_split=data_split, val_dir=val_dir
    )
    print(
        f"TF export inference: {len(split_images)} {data_split} images → {output_dir} "
        f"(ort_device={get_ort_device()}, providers={ort_providers})"
    )

    # NOTE(review): class_map ids are 0-based while detector labels are treated as
    # 1-based (see _label_to_class_name) — confirm the GT loaders reconcile this.
    class_map = {name: i for i, name in enumerate(class_names)} if class_names else {}
    gt_by_image_path = None
    if dataset_format == "airbus_playground":
        gt_by_image_path = _build_airbus_gt_index(config, data_root, data_split, class_map)

    results: List[Dict[str, Any]] = []
    all_detections: Dict[str, list] = {}
    all_ground_truths: Dict[str, list] = {}
    all_scores: List[float] = []
    image_name_by_id: Dict[str, str] = {}

    t0 = time.perf_counter()
    for img_path in tqdm(
        split_images,
        desc=f"TF export {data_split}",
        file=tqdm_progress_stream(),
    ):
        img_name = img_path.name
        image_id = img_path.stem
        image_name_by_id[image_id] = img_name

        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            print(f"Warning: skip unreadable {img_path}")
            continue
        img_h, img_w = img_bgr.shape[:2]
        pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))

        num_gt, gt_entries, gt_list_raw = _load_gt_entries(
            img_path, label_dir, gt_by_image_path, class_map
        )

        # Best effort: a failing image is logged and recorded with no predictions.
        try:
            det_dicts = infer_keras_on_image(keras_model, pil, preprocessing)
        except Exception as exc:
            print(f"Warning: inference failed for {img_name}: {exc}")
            det_dicts = []

        pred_scores = [d["score"] for d in det_dicts]
        pred_labels = [d["label"] for d in det_dicts]
        rboxes = [d["rbox"] for d in det_dicts]
        all_scores.extend(pred_scores)

        if run_diagnostics:
            dets_m = [
                Detection(
                    rbox=d["rbox"],
                    score=d["score"],
                    class_id=d["label"],
                    class_name=_label_to_class_name(d["label"], class_names),
                    image_id=image_id,
                )
                for d in det_dicts
            ]
            gt_m = [
                GroundTruth(
                    rbox=gt.rbox,
                    class_id=gt.class_id,
                    class_name=gt.class_name,
                    difficult=gt.difficult,
                    image_id=image_id,
                )
                for gt in gt_list_raw
            ]
            # Keep only objects whose centroid lies inside the tile interior so
            # metrics ignore boxes cut by the margin.
            if resolved_metrics_margin_px > 0 and img_w > 0 and img_h > 0:
                dets_m = [
                    d
                    for d in dets_m
                    if _rbox_centroid_in_tile_interior(d.rbox, img_w, img_h, resolved_metrics_margin_px)
                ]
                gt_m = [
                    g
                    for g in gt_m
                    if _rbox_centroid_in_tile_interior(g.rbox, img_w, img_h, resolved_metrics_margin_px)
                ]
            all_detections[image_id] = dets_m
            all_ground_truths[image_id] = gt_m

        pred_boxes_array = [rbox_to_array(rb).tolist() for rb in rboxes]
        results.append(
            {
                "image_name": img_name,
                "image_path": os.path.relpath(img_path, data_root),
                "image_width": int(img_w),
                "image_height": int(img_h),
                "resize_mode": preprocessing.get("resize_mode", "fixed"),
                "target_size": preprocessing.get("target_size", [1024, 1024]),
                "num_gt": int(num_gt),
                "num_pred": int(len(rboxes)),
                "predictions": [
                    {
                        "bbox": box,
                        "score": float(score),
                        "label": int(label),
                        "class_name": _label_to_class_name(label, class_names),
                    }
                    for box, score, label in zip(pred_boxes_array, pred_scores, pred_labels)
                ],
                "ground_truths": gt_entries,
                "stats": {"inference_backend": "tensorflow_keras_export"},
            }
        )

    t_infer = time.perf_counter() - t0
    experiment_dir = str(config_path.parent)
    checkpoint_ref = str(reference_checkpoint or config_path)

    diagnostics = None
    analysis = None
    if run_diagnostics:
        diagnostics, analysis = run_diagnostics_pipeline(
            experiment_dir=experiment_dir,
            checkpoint_path=checkpoint_ref,
            config_path=str(config_path),
            data_root=data_root,
            data_split=data_split,
            class_names=class_names,
            score_threshold=float(score_threshold),
            per_cls_thr=per_cls_thr,
            nms_class_agnostic=nms_class_agnostic,
            iou_threshold=iou_threshold,
            pr_iou_threshold=None,
            pr_threshold_min=0.0,
            pr_threshold_max=1.0,
            pr_threshold_step=0.1,
            per_class_threshold_analysis=False,
            resolved_metrics_margin_px=int(resolved_metrics_margin_px),
            all_detections=all_detections,
            all_ground_truths=all_ground_truths,
            results=results,
            all_scores=all_scores,
            image_name_by_id=image_name_by_id,
            sliding_window_positions_total=None,
            window_batch_effective=None,
            t_infer_sec=float(t_infer),
            device="tensorflow",
            output_dir=str(output_dir),
            tile_metrics_csv=None,
        )

    meta_path = detect_dir / "export_meta.json"
    export_meta = {}
    if meta_path.is_file():
        export_meta = json.loads(meta_path.read_text(encoding="utf-8"))

    metadata: Dict[str, Any] = {
        "timestamp": datetime.now().isoformat(),
        "inference_backend": "tensorflow_keras_export",
        "detect_bundle": str(detect_dir.resolve()),
        "keras_model": str(keras_path.resolve()),
        "export_meta_core_backend": export_meta.get("core_backend"),
        "experiment_dir": experiment_dir,
        "checkpoint": checkpoint_ref,
        "config_file": str(config_path),
        "pytorch_reference_checkpoint": str(reference_checkpoint) if reference_checkpoint else None,
        "data_root": str(data_root),
        "data_split": data_split,
        "device": "tensorflow",
        "ort_device": get_ort_device(),
        "ort_providers": ort_providers,
        "class_names": class_names,
        "score_threshold": score_threshold,
        "per_class_score_threshold": per_cls_thr,
        "nms_class_agnostic": nms_class_agnostic,
        "total_images": len(results),
        "total_predictions": sum(r["num_pred"] for r in results),
        "total_ground_truth": sum(r["num_gt"] for r in results),
        "inference_loop_seconds": float(t_infer),
        "bbox_coordinate_space": "image_pixels",
        "metrics_margin_pixels": int(resolved_metrics_margin_px),
        "preprocess_note": "Resize+ToTensor+Normalize (same as PyTorch make preds)",
    }
    if diagnostics is not None:
        metadata["diagnostics"] = diagnostics
    if analysis is not None:
        metadata["best_threshold_f2"] = analysis.get("best_threshold", {})
        metadata["analysis_file"] = analysis.get("artifacts", {}).get(
            "analysis_json", f"analysis_iou{iou_threshold:.2f}.json"
        )
        metadata["pr_iou_threshold"] = analysis.get("iou_threshold", iou_threshold)

    json_path = output_dir / "predictions.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({"metadata": metadata, "results": results}, f, indent=2)

    print(f"Wrote {json_path}")
    return {"output_dir": str(output_dir), "metadata": metadata}
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for TF/Keras export validation inference."""
    parser = argparse.ArgumentParser(description="Val inference via TF/Keras export bundle.")
    add = parser.add_argument
    add("--config", type=Path, required=True)
    add("--detect-dir", type=Path, required=True, help="Directory with keras_model.keras")
    add("--output-dir", type=Path, default=None, help="Default: export/artifacts/predictions/<ts>")
    add("--data-root", type=Path, default=None)
    add("--data-split", default="val", choices=("train", "val", "test"))
    add("--val-dir", type=Path, default=None)
    add(
        "--reference-checkpoint",
        type=Path,
        default=None,
        help="PyTorch .pth path recorded in metadata for comparison (default: deploy weights path only in meta).",
    )
    add("--no-diagnostics", action="store_true", help="Skip mAP/PR (inference-only JSON).")
    add(
        "--ort-device",
        default=None,
        choices=("cpu", "cuda", "auto"),
        help="ONNX Runtime EP for the exported ONNX core (default: cpu or ORIENTED_DET_ORT_DEVICE).",
    )
    return parser


def main() -> None:
    """Parse CLI arguments and forward them to ``run_tf_inference_and_save``."""
    opts = _build_arg_parser().parse_args()
    run_tf_inference_and_save(
        config_path=opts.config,
        detect_dir=opts.detect_dir,
        output_dir=opts.output_dir,
        data_root=opts.data_root,
        data_split=opts.data_split,
        val_dir=opts.val_dir,
        run_diagnostics=not opts.no_diagnostics,
        reference_checkpoint=opts.reference_checkpoint,
        ort_device=opts.ort_device,
    )


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Optional: convert SavedModel to TFLite (float32)."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main() -> None:
    """Convert a TensorFlow SavedModel directory into a float32 TFLite model.

    No converter optimizations are enabled. If TensorFlow is not importable, an
    install hint is printed to stderr and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description="SavedModel → TFLite float32.")
    parser.add_argument("--saved-model", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True, help="Output .tflite path")
    opts = parser.parse_args()

    try:
        import tensorflow as tf
    except ImportError as err:
        print("Install TensorFlow: pip install -r export/requirements-export.txt", file=sys.stderr)
        raise SystemExit(1) from err

    converter = tf.lite.TFLiteConverter.from_saved_model(str(opts.saved_model))
    converter.optimizations = []  # explicitly none: keep the float32 model as-is
    flatbuffer = converter.convert()

    destination: Path = opts.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(flatbuffer)
    print(f"Wrote {destination} ({len(flatbuffer)} bytes)")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Optional ONNX smoke (each test skips if deps missing)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import io
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import pytest
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
# Make the repository root (for ``oriented_det``) and the ``export`` directory
# (for the bare ``wrappers`` module) importable; ``export`` ends up first.
_REPO_ROOT = Path(__file__).resolve().parents[2]
_EXPORT_DIR = Path(__file__).resolve().parents[1]
for _entry in (str(_REPO_ROOT), str(_EXPORT_DIR)):
    if _entry not in sys.path:
        sys.path.insert(0, _entry)
del _entry
|
|
18
|
+
|
|
19
|
+
import wrappers as _wrappers # noqa: E402
|
|
20
|
+
from oriented_det.models.rotated_retinanet import RotatedRetinaNet # noqa: E402
|
|
21
|
+
|
|
22
|
+
# Bind the export wrapper classes under their own names for the tests below.
BackboneExportWrapper, RetinaNetBackboneHeadExportWrapper = (
    _wrappers.BackboneExportWrapper,
    _wrappers.RetinaNetBackboneHeadExportWrapper,
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_export_backbone_onnx_roundtrip_ort() -> None:
    """Export backbone+FPN to ONNX and check ONNX Runtime reproduces PyTorch.

    Skips when ``onnx`` or ``onnxruntime`` is not installed. Besides the
    structural checks (ONNX checker, output count and shapes), ORT outputs are
    compared numerically with the eager PyTorch outputs. This catches a graph
    that exports cleanly but computes something different.
    """
    onnx = pytest.importorskip("onnx")
    ort = pytest.importorskip("onnxruntime")

    model = RotatedRetinaNet(num_classes=2, backbone_name="resnet18", pretrained_backbone=False)
    model.eval()
    wrap = BackboneExportWrapper(model.backbone)
    x = torch.randn(1, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        ref = wrap(x)
    buf = io.BytesIO()
    torch.onnx.export(
        wrap,
        x,
        buf,
        input_names=["images"],
        output_names=[f"fpn_{i}" for i in range(len(ref))],
        opset_version=17,
        do_constant_folding=True,
    )
    buf.seek(0)
    onnx_model = onnx.load(buf)
    onnx.checker.check_model(onnx_model)
    sess = ort.InferenceSession(buf.getvalue(), providers=["CPUExecutionProvider"])
    out = sess.run(None, {"images": x.numpy()})
    assert len(out) == len(ref)
    for ort_out, torch_out in zip(out, ref):
        assert ort_out.shape == tuple(torch_out.shape)
        # Float32 kernels differ slightly between PyTorch and ORT, so compare
        # with a tolerance rather than exact equality.
        torch.testing.assert_close(torch.from_numpy(ort_out), torch_out, rtol=1e-3, atol=1e-4)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_export_retinanet_heads_onnx_checker() -> None:
    """Export backbone + RetinaNet heads to ONNX and validate it with the checker.

    Outputs are named pairwise ``level{i}_cls`` / ``level{i}_bbox``, one pair per
    level. Skips when ``onnx`` is not installed.
    """
    onnx = pytest.importorskip("onnx")

    model = RotatedRetinaNet(num_classes=2, backbone_name="resnet18", pretrained_backbone=False)
    model.eval()
    wrapper = RetinaNetBackboneHeadExportWrapper(model)
    dummy = torch.randn(1, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        reference = wrapper(dummy)

    output_names = [
        name
        for level in range(len(reference) // 2)
        for name in (f"level{level}_cls", f"level{level}_bbox")
    ]
    stream = io.BytesIO()
    torch.onnx.export(
        wrapper,
        dummy,
        stream,
        input_names=["images"],
        output_names=output_names,
        opset_version=17,
        do_constant_folding=True,
    )
    stream.seek(0)
    onnx.checker.check_model(onnx.load(stream))
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Tests for export wrappers (no ONNX/TF required)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
# Make the repository root (for ``oriented_det``) and the ``export`` directory
# (for the bare ``wrappers`` module) importable; ``export`` ends up first.
_REPO_ROOT = Path(__file__).resolve().parents[2]
_EXPORT_DIR = Path(__file__).resolve().parents[1]
for _entry in (str(_REPO_ROOT), str(_EXPORT_DIR)):
    if _entry not in sys.path:
        sys.path.insert(0, _entry)
del _entry
|
|
17
|
+
|
|
18
|
+
import wrappers as _wrappers # noqa: E402
|
|
19
|
+
from oriented_det import RotatedFasterRCNN # noqa: E402
|
|
20
|
+
from oriented_det.models.rotated_retinanet import RotatedRetinaNet # noqa: E402
|
|
21
|
+
|
|
22
|
+
# Bind the export wrapper classes under their own names for the tests below.
(
    BackboneExportWrapper,
    RetinaNetBackboneHeadExportWrapper,
    RotatedFasterRCNNPreNmsExportWrapper,
) = (
    _wrappers.BackboneExportWrapper,
    _wrappers.RetinaNetBackboneHeadExportWrapper,
    _wrappers.RotatedFasterRCNNPreNmsExportWrapper,
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.fixture
def tiny_retinanet() -> RotatedRetinaNet:
    """Small 3-class RetinaNet on ResNet-18, built without pretrained backbone weights."""
    model_kwargs = {
        "num_classes": 3,
        "backbone_name": "resnet18",
        "pretrained_backbone": False,
        "trainable_layers": 5,
    }
    return RotatedRetinaNet(**model_kwargs)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_backbone_wrapper_shapes(tiny_retinanet: RotatedRetinaNet) -> None:
    """The backbone wrapper returns a non-empty tuple of 4-D maps keeping the batch size."""
    wrapper = BackboneExportWrapper(tiny_retinanet.backbone)
    wrapper.eval()
    batch = torch.randn(2, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        features = wrapper(batch)

    assert isinstance(features, tuple)
    assert len(features) >= 1
    for feature_map in features:
        assert feature_map.dim() == 4
        assert feature_map.shape[0] == 2
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def test_retinanet_heads_wrapper_shapes(tiny_retinanet: RotatedRetinaNet) -> None:
    """Head outputs come in (cls, bbox) pairs; bbox channels are a multiple of 5."""
    wrapper = RetinaNetBackboneHeadExportWrapper(tiny_retinanet)
    wrapper.eval()
    image = torch.randn(1, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        outputs = wrapper(image)

    assert len(outputs) % 2 == 0
    for cls_out, box_out in zip(outputs[0::2], outputs[1::2]):
        assert cls_out.shape[0] == 1 and box_out.shape[0] == 1
        assert cls_out.shape[1] > 0 and box_out.shape[1] % 5 == 0
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_retinanet_heads_batch_unbind_matches_stack(tiny_retinanet: RotatedRetinaNet) -> None:
    """B>1 path uses torch.unbind inside extract_backbone_features input list.

    Running a batch of two images must give, for every output tensor, the same
    values as running each image on its own with batch size 1.
    """
    w = RetinaNetBackboneHeadExportWrapper(tiny_retinanet)
    w.eval()
    x2 = torch.randn(2, 3, 128, 128, dtype=torch.float32)
    with torch.no_grad():
        batched = w(x2)
        single0 = w(x2[0:1])
        single1 = w(x2[1:2])

    assert len(batched) == len(single0) == len(single1)
    for batched_out, out0, out1 in zip(batched, single0, single1):
        assert batched_out.shape[0] == 2
        # Batched and per-sample convolutions may take different CPU kernel paths;
        # compare with explicit float32 tolerances instead of the near-exact
        # torch.allclose defaults. assert_close also checks shapes match.
        torch.testing.assert_close(batched_out[0], out0[0], rtol=1e-4, atol=1e-5)
        torch.testing.assert_close(batched_out[1], out1[0], rtol=1e-4, atol=1e-5)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.fixture
def tiny_faster_rcnn() -> RotatedFasterRCNN:
    """Small 3-class rotated Faster R-CNN on ResNet-18 with 64 RPN proposals pre/post NMS."""
    model_kwargs = {
        "num_classes": 3,
        "backbone_name": "resnet18",
        "pretrained_backbone": False,
        "trainable_layers": 5,
        "rpn_post_nms_top_n": 64,
        "rpn_pre_nms_top_n": 64,
    }
    return RotatedFasterRCNN(**model_kwargs)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def test_faster_rcnn_pre_nms_wrapper_shapes(tiny_faster_rcnn: RotatedFasterRCNN) -> None:
    """Pre-NMS wrapper returns ``max_candidates`` rows plus a scalar count not above it."""
    side = 128
    max_candidates = 32
    wrapper = RotatedFasterRCNNPreNmsExportWrapper(
        tiny_faster_rcnn, height=side, width=side, max_candidates=max_candidates
    )
    wrapper.eval()
    image = torch.randn(1, 3, side, side, dtype=torch.float32)
    with torch.no_grad():
        boxes, scores, labels, count = wrapper(image)

    assert boxes.shape == (max_candidates, 5)
    assert scores.shape == (max_candidates,)
    assert labels.shape == (max_candidates,)
    assert count.ndim == 0
    assert int(count.item()) <= max_candidates
|