matrice-analytics 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrice_analytics/__init__.py +28 -0
- matrice_analytics/boundary_drawing_internal/README.md +305 -0
- matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
- matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
- matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
- matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
- matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
- matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
- matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
- matrice_analytics/post_processing/README.md +455 -0
- matrice_analytics/post_processing/__init__.py +732 -0
- matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
- matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
- matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
- matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
- matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
- matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
- matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
- matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
- matrice_analytics/post_processing/config.py +146 -0
- matrice_analytics/post_processing/core/__init__.py +63 -0
- matrice_analytics/post_processing/core/base.py +704 -0
- matrice_analytics/post_processing/core/config.py +3291 -0
- matrice_analytics/post_processing/core/config_utils.py +925 -0
- matrice_analytics/post_processing/face_reg/__init__.py +43 -0
- matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
- matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
- matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
- matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
- matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
- matrice_analytics/post_processing/ocr/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
- matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
- matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
- matrice_analytics/post_processing/post_processor.py +1175 -0
- matrice_analytics/post_processing/test_cases/__init__.py +1 -0
- matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
- matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
- matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
- matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
- matrice_analytics/post_processing/test_cases/test_config.py +852 -0
- matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
- matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
- matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
- matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
- matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
- matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
- matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
- matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
- matrice_analytics/post_processing/usecases/__init__.py +267 -0
- matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
- matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
- matrice_analytics/post_processing/usecases/age_detection.py +842 -0
- matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
- matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
- matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
- matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
- matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
- matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
- matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
- matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
- matrice_analytics/post_processing/usecases/car_service.py +1601 -0
- matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
- matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
- matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
- matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
- matrice_analytics/post_processing/usecases/color/clip.py +660 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
- matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
- matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
- matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
- matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
- matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
- matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
- matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
- matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
- matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
- matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
- matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
- matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
- matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
- matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
- matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
- matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
- matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
- matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
- matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
- matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
- matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
- matrice_analytics/post_processing/usecases/leaf.py +821 -0
- matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
- matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
- matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
- matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
- matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
- matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
- matrice_analytics/post_processing/usecases/parking.py +787 -0
- matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
- matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
- matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
- matrice_analytics/post_processing/usecases/people_counting.py +706 -0
- matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
- matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
- matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
- matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
- matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
- matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
- matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
- matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
- matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
- matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
- matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
- matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
- matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
- matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
- matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
- matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
- matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
- matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
- matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
- matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
- matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
- matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
- matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
- matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
- matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
- matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
- matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
- matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
- matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
- matrice_analytics/post_processing/utils/__init__.py +150 -0
- matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
- matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
- matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
- matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
- matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
- matrice_analytics/post_processing/utils/color_utils.py +592 -0
- matrice_analytics/post_processing/utils/counting_utils.py +182 -0
- matrice_analytics/post_processing/utils/filter_utils.py +261 -0
- matrice_analytics/post_processing/utils/format_utils.py +293 -0
- matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
- matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
- matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
- matrice_analytics/py.typed +0 -0
- matrice_analytics-0.1.60.dist-info/METADATA +481 -0
- matrice_analytics-0.1.60.dist-info/RECORD +196 -0
- matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
- matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
- matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Display statistics for a `fast-plate-ocr` dataset.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections import Counter
|
|
8
|
+
from collections.abc import Sequence
|
|
9
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
import click
|
|
13
|
+
import pandas as pd
|
|
14
|
+
from PIL import Image, UnidentifiedImageError
|
|
15
|
+
from rich import box
|
|
16
|
+
from rich.console import Console, Group
|
|
17
|
+
from rich.markup import escape
|
|
18
|
+
from rich.panel import Panel
|
|
19
|
+
from rich.table import Table
|
|
20
|
+
from typing import Optional
|
|
21
|
+
|
|
22
|
+
from fast_plate_ocr.train.model.config import load_plate_config_from_yaml
|
|
23
|
+
|
|
24
|
+
# pylint: disable=too-many-locals
|
|
25
|
+
|
|
26
|
+
# Module-level Rich console that ``dataset_stats`` prints its output to.
console = Console()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _header_shape(path: Path) -> tuple[bool, Optional[tuple[int, int]]]:
    """
    Read an image's header and report its dimensions.

    :param path: Path to the image file.
    :return: ``(True, (height, width))`` when Pillow can open and verify the file,
        otherwise ``(False, None)``.
    """
    try:
        with Image.open(path) as im:
            # ``verify`` checks the file for corruption without decoding the image data;
            # ``size`` was already parsed from the header by ``Image.open``.
            im.verify()
            w, h = im.size
            return True, (h, w)
    except (UnidentifiedImageError, OSError, SyntaxError):
        # Pillow reports some corrupt files (e.g. a PNG chunk with a bad CRC) from
        # ``verify`` as ``SyntaxError`` rather than ``OSError``. Treat those as unreadable
        # too, so one broken image does not abort the whole statistics run.
        return False, None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _compact_table(title: str, values: Sequence[float]) -> Table:
    """
    Summarize ``values`` as a one-row Rich table of descriptive statistics.

    Columns: count, mean, std, min, max and the 5th/50th/95th percentiles, each
    formatted with two decimals. Statistics pandas cannot compute (for example on an
    empty sequence, or the std of a single value) are shown as ``-``.
    """
    summary = pd.Series(values, dtype="float64").describe(percentiles=[0.05, 0.5, 0.95])
    columns = ("count", "mean", "std", "min", "max", "5%", "50%", "95%")

    table = Table(title=title, box=box.MINIMAL_DOUBLE_HEAD, pad_edge=False, expand=False)
    cells = []
    for name in columns:
        table.add_column(name, justify="right", style="bold")
        stat = summary[name]
        cells.append("-" if pd.isna(stat) else f"{stat:.2f}")
    table.add_row(*cells)
    return table
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@click.command(context_settings={"max_content_width": 120})
@click.option(
    "--annotations",
    "-a",
    required=True,
    type=click.Path(exists=True, dir_okay=False, file_okay=True, path_type=Path),
    help="CSV with image_path and plate_text columns.",
)
@click.option(
    "--plate-config-file",
    "-c",
    required=True,
    type=click.Path(exists=True, dir_okay=False, file_okay=True, path_type=Path),
    help="YAML config so we know alphabet/pad char.",
)
@click.option(
    "--top-chars",
    default=10,
    show_default=True,
    type=int,
    help="Show N most frequent characters.",
)
@click.option(
    "--workers",
    default=8,
    show_default=True,
    type=int,
    help="Parallel header reads (0 disables threading).",
)
def dataset_stats(annotations: Path, plate_config_file: Path, top_chars: int, workers: int) -> None:
    """
    Display statistics for a `fast-plate-ocr` dataset.
    """
    # NOTE: the docstring above doubles as the CLI ``--help`` text.
    plate_config = load_plate_config_from_yaml(plate_config_file)

    # Read every column as text, for two reasons:
    # - All-digit plates (e.g. "0123") would otherwise be parsed as integers, which drops
    #   leading zeros and breaks the ``.str`` accessor below.
    # - Plates such as "NA"/"NULL", or empty cells, would otherwise become NaN and crash
    #   the ``"".join``.
    df_annots = pd.read_csv(annotations, dtype=str, keep_default_na=False)
    missing_cols = {"image_path", "plate_text"}.difference(df_annots.columns)
    if missing_cols:
        raise click.UsageError(
            f"{annotations} is missing required column(s): {', '.join(sorted(missing_cols))}"
        )

    # Relative image paths are resolved against the CSV's own directory.
    root = annotations.parent
    df_annots["image_path"] = df_annots["image_path"].apply(lambda p: str((root / p).resolve()))

    # Plate lengths and char frequencies
    plate_lengths = df_annots["plate_text"].str.len().tolist()
    char_counter: Counter[str] = Counter("".join(df_annots["plate_text"].tolist()))
    # Drop the pad char before ranking so the table still lists ``top_chars`` entries.
    char_counter.pop(plate_config.pad_char, None)

    # File extension counts
    ext_counter = Counter(df_annots["image_path"].apply(lambda p: Path(p).suffix.lower()))

    # Image header dimensions, read in a thread pool when more than one worker is requested
    paths = [Path(p) for p in df_annots["image_path"].tolist()]
    if workers > 1:
        with ThreadPoolExecutor(max_workers=workers) as ex:
            dims = list(ex.map(_header_shape, paths))
    else:
        dims = [_header_shape(p) for p in paths]

    valid_dims = [dims_pair for ok, dims_pair in dims if ok and dims_pair is not None]
    n_unreadable = len(dims) - len(valid_dims)

    heights = [h for h, _ in valid_dims]
    widths = [w for _, w in valid_dims]
    aspects = [w / h for h, w in valid_dims if h > 0]

    # Build tables
    tbl_len = _compact_table("Plate Lengths", plate_lengths)
    tbl_h = _compact_table("Image Height", heights)
    tbl_w = _compact_table("Image Width", widths)
    tbl_ar = _compact_table("Aspect Ratio", aspects)

    # Extension table
    tbl_ext = Table(title="Extensions", box=box.MINIMAL_DOUBLE_HEAD, pad_edge=False)
    tbl_ext.add_column("Ext", style="bold", justify="left")
    tbl_ext.add_column("Count", justify="right")
    for ext, cnt in ext_counter.most_common():
        tbl_ext.add_row(ext or "<none>", str(cnt))

    # Character freq table (escape() keeps chars like "[" from being read as Rich markup)
    tbl_char = Table(title=f"Top {top_chars} Chars", box=box.MINIMAL_DOUBLE_HEAD, pad_edge=False)
    tbl_char.add_column("Char", style="bold")
    tbl_char.add_column("Count", justify="right")
    for ch, cnt in char_counter.most_common(top_chars):
        tbl_char.add_row(escape(ch), str(cnt))

    group = Group(tbl_len, tbl_h, tbl_w, tbl_ar, tbl_ext, tbl_char)
    console.print(
        Panel.fit(group, title="Dataset Statistics", border_style="green", box=box.SQUARE)
    )

    # Make images silently excluded from the size statistics visible to the user.
    if n_unreadable:
        console.print(
            f"[yellow]Warning:[/yellow] {n_unreadable} image(s) could not be read and were "
            "excluded from the image size statistics."
        )
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# Allow running this module directly as a standalone CLI script.
if __name__ == "__main__":
    dataset_stats()
|
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Script for exporting the trained Keras models to other formats.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import pathlib
|
|
9
|
+
import shutil
|
|
10
|
+
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
|
11
|
+
|
|
12
|
+
import click
|
|
13
|
+
import keras
|
|
14
|
+
import numpy as np
|
|
15
|
+
from numpy.typing import DTypeLike
|
|
16
|
+
|
|
17
|
+
from fast_plate_ocr.cli.utils import requires
|
|
18
|
+
from fast_plate_ocr.core.types import TensorDataFormat
|
|
19
|
+
from fast_plate_ocr.core.utils import log_time_taken
|
|
20
|
+
from fast_plate_ocr.train.model.config import (
|
|
21
|
+
PlateOCRConfig,
|
|
22
|
+
load_plate_config_from_yaml,
|
|
23
|
+
)
|
|
24
|
+
from fast_plate_ocr.train.utilities.utils import load_keras_model
|
|
25
|
+
from typing import Optional
|
|
26
|
+
|
|
27
|
+
# ruff: noqa: PLC0415
|
|
28
|
+
# pylint: disable=too-many-arguments,too-many-locals,import-outside-toplevel
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _dummy_input(b: int, h: int, w: int, n_c: int, dtype: DTypeLike = np.uint8) -> np.ndarray:
    """
    Build a random NHWC tensor whose entries are integers in [0, 255].

    :param b: Batch size.
    :param h: Height.
    :param w: Width.
    :param n_c: Number of channels.
    :param dtype: Output dtype. Values are integer-valued in [0, 255] whatever the dtype.
    :return: Array of shape ``(b, h, w, n_c)`` with the requested dtype.
    """
    return np.random.randint(0, 256, size=(b, h, w, n_c)).astype(dtype)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _validate_prediction(
    keras_model: keras.Model,
    exported_predict,
    x: np.ndarray,
    target: str,
    rtol: float = 1e-4,
    atol: float = 1e-4,
) -> None:
    """
    Compare Keras and an exported backend on a single forward pass.

    Results are only logged: a mismatch produces ``logging.warning`` and nothing is raised.

    :param keras_model: Reference model; only its ``predict`` method is used.
    :param exported_predict: Callable mapping ``x`` to the exported backend's output.
    :param x: Input batch fed to both models.
    :param target: Backend name used in log messages (e.g. ``"ONNX"``).
    :param rtol: Relative tolerance passed to ``np.allclose``.
    :param atol: Absolute tolerance passed to ``np.allclose``.
    """
    keras_out = np.asarray(keras_model.predict(x, verbose=0))
    exported_out = np.asarray(exported_predict(x))

    if keras_out.shape != exported_out.shape:
        # Check shapes explicitly. ``np.allclose`` would otherwise either broadcast
        # (reporting a false match) or raise ``ValueError`` on incompatible shapes.
        logging.warning(
            "%s output shape %s differs from Keras output shape %s.",
            target.upper(),
            exported_out.shape,
            keras_out.shape,
        )
    elif not np.allclose(keras_out, exported_out, rtol=rtol, atol=atol):
        # Compute the diff in float64 so unsigned/low-precision dtypes cannot wrap or round.
        max_abs_diff = float(
            np.max(np.abs(keras_out.astype(np.float64) - exported_out.astype(np.float64)))
        )
        logging.warning(
            "%s output deviates from Keras beyond tolerance (max abs diff: %.3g).",
            target.upper(),
            max_abs_diff,
        )
    else:
        logging.info("%s output matches Keras ✔", target.upper())
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _make_output_path(
    model_path: pathlib.Path, new_ext: str, save_dir: Optional[pathlib.Path] = None
) -> pathlib.Path:
    """
    Compute where an exported model should be written, clearing any previous export.

    The model filename gets ``new_ext`` as its suffix and is placed in ``save_dir`` when
    given, otherwise next to ``model_path``. Any file or directory already at that
    location is deleted.

    :param model_path: Path to the source model file.
    :param new_ext: New suffix, including the leading dot (e.g. ``".onnx"``).
    :param save_dir: Optional directory that receives the exported model.
    :return: Path the exporter should write to.
    """
    renamed = model_path.with_suffix(new_ext)
    destination = renamed if save_dir is None else save_dir / renamed.name

    if not destination.exists():
        return destination

    logging.info("Overwriting existing %s", destination)
    # Some export targets (e.g. a CoreML ``.mlpackage``) are directories, not files.
    if destination.is_dir():
        shutil.rmtree(destination)
    else:
        destination.unlink()
    return destination
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _prepare_model_for_onnx_export(
    model: keras.Model,
    plate_config: PlateOCRConfig,
    dynamic_batch: bool,
    input_dtype: str,
    data_format: TensorDataFormat,
):
    """
    Adapt a Keras model's input layout for ONNX export.

    For ``data_format == "channels_first"`` the model is wrapped: a new NxCxHxW input is
    permuted to NxHxWxC before reaching ``model``. Any other value keeps ``model`` as is,
    with an NxHxWxC input.

    :return: Tuple ``(export_model, spec_shape, dummy_input)``. ``spec_shape`` is the
        input shape with a leading batch dimension (``None`` when ``dynamic_batch``,
        else ``1``). ``dummy_input`` is a random batch of one with values in [0, 255],
        cast to ``input_dtype``.
    """
    height = plate_config.img_height
    width = plate_config.img_width
    channels = plate_config.num_channels

    if data_format != "channels_first":
        # Channels last (NxHxWxC) is what the model already consumes.
        inp_shape = (height, width, channels)
        export_model = model
    else:
        inp_shape = (channels, height, width)
        x_in = keras.Input(shape=inp_shape, dtype=input_dtype, name="input_nchw")
        # Permute axes exclude the batch dim and are 1-indexed: (C, H, W) -> (H, W, C)
        nhwc = keras.layers.Permute((2, 3, 1))(x_in)
        export_model = keras.Model(x_in, model(nhwc), name=f"{model.name}_nchw")

    spec_shape = (None if dynamic_batch else 1, *inp_shape)
    dummy_input = np.random.randint(0, 256, size=(1, *inp_shape)).astype(input_dtype)
    return export_model, spec_shape, dummy_input
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@requires("onnx", "onnxruntime", "onnxslim")
def export_onnx(
    model: keras.Model,
    plate_config: PlateOCRConfig,
    out_file: pathlib.Path,
    simplify: bool,
    dynamic_batch: bool,
    skip_validation: bool = False,
    onnx_input_dtype: str = "uint8",
    onnx_data_format: TensorDataFormat = "channels_last",
) -> None:
    """
    Export ``model`` to ONNX at ``out_file``, optionally simplified and validated.

    :param model: Trained Keras model.
    :param plate_config: Plate OCR config providing the input height/width/channels.
    :param out_file: Destination ``.onnx`` path.
    :param simplify: Run ``onnxslim`` on the exported graph before saving.
    :param dynamic_batch: Use a dynamic (``None``) batch dimension instead of a fixed 1.
    :param skip_validation: Skip comparing ONNX Runtime outputs against Keras.
    :param onnx_input_dtype: Data type of the ONNX model input.
    :param onnx_data_format: ``"channels_last"`` (NHWC) or ``"channels_first"`` (NCHW) input.
    """
    import onnxruntime as rt

    export_model, spec_shape, dummy_input = _prepare_model_for_onnx_export(
        model, plate_config, dynamic_batch, onnx_input_dtype, onnx_data_format
    )
    spec = [keras.InputSpec(name="input", shape=spec_shape, dtype=onnx_input_dtype)]

    # Write into a temporary *directory* rather than an open NamedTemporaryFile. On
    # Windows an open NamedTemporaryFile cannot be reopened by name, so the exporter
    # could not write to it.
    with TemporaryDirectory() as tmp_dir:
        tmp_onnx = pathlib.Path(tmp_dir) / "model.onnx"
        export_model.export(str(tmp_onnx), format="onnx", verbose=False, input_signature=spec)

        if simplify:
            import onnx
            import onnxslim

            logging.info("Simplifying ONNX ...")
            model_simp = onnxslim.slim(onnx.load(str(tmp_onnx)))
            onnx.save(model_simp, str(out_file))
        else:
            shutil.copy(tmp_onnx, out_file)

    # Load the newly converted ONNX model. Pass a str because older onnxruntime
    # releases only accept str/bytes, not pathlib.Path.
    sess = rt.InferenceSession(str(out_file))
    input_name = sess.get_inputs()[0].name
    output_names = [o.name for o in sess.get_outputs()]

    def _predict(x: np.ndarray):
        # Only the first output is compared against Keras.
        return sess.run(output_names, {input_name: x})[0]

    if skip_validation:
        logging.info("Skipping ONNX validation.")
    else:
        _validate_prediction(export_model, _predict, dummy_input, "ONNX")

    with log_time_taken("ONNX inference time"):
        _predict(dummy_input)

    logging.info("Saved ONNX model to %s", out_file)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@requires("tensorflow")
def export_tflite(
    model: keras.Model,
    plate_config: PlateOCRConfig,
    out_file: pathlib.Path,
    skip_validation: bool = False,
) -> None:
    """
    Export ``model`` to TFLite through a temporary TF SavedModel, optionally validating it.

    The converter runs with ``tf.lite.Optimize.DEFAULT``, so validation against Keras
    uses looser tolerances (``atol=rtol=5e-3``) than the other backends.

    :param model: Trained Keras model.
    :param plate_config: Plate OCR config providing the input height/width/channels.
    :param out_file: Destination ``.tflite`` path.
    :param skip_validation: Skip comparing TFLite outputs against Keras.
    """
    import tensorflow as tf

    with TemporaryDirectory() as tmp_dir:
        model.export(tmp_dir, format="tf_saved_model")

        converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

        tflite_bytes = converter.convert()
    out_file.write_bytes(tflite_bytes)

    if skip_validation:
        logging.info("Skipping TFLite validation.")
    else:

        class _TFLiteRunner:
            """Callable wrapper around a TFLite interpreter (first input/output only)."""

            def __init__(self, path):
                self.interp = tf.lite.Interpreter(str(path))
                self.interp.allocate_tensors()
                inp_details = self.interp.get_input_details()[0]
                self.inp = inp_details["index"]
                self.inp_dtype = inp_details["dtype"]
                self.out = self.interp.get_output_details()[0]["index"]

            def __call__(self, x: np.ndarray):
                # ``set_tensor`` rejects arrays whose dtype differs from the model input.
                # The dummy values are integers in [0, 255], so this cast is lossless for
                # uint8 and floating-point inputs.
                self.interp.set_tensor(self.inp, np.asarray(x, dtype=self.inp_dtype))
                self.interp.invoke()
                return self.interp.get_tensor(self.out)

        _validate_prediction(
            model,
            _TFLiteRunner(out_file),
            _dummy_input(
                1,
                plate_config.img_height,
                plate_config.img_width,
                plate_config.num_channels,
                np.float32,
            ),
            "TFLite",
            atol=5e-3,
            rtol=5e-3,
        )

    logging.info("Saved TFLite model to %s", out_file)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@requires("coremltools", "tensorflow")
def export_coreml(
    model: keras.Model,
    plate_config: PlateOCRConfig,
    out_file: pathlib.Path,
    skip_validation: bool = False,
) -> None:
    """
    Export ``model`` to a CoreML ML Program and optionally validate it against Keras.

    The Keras model is first exported to a temporary TensorFlow SavedModel. Its
    ``serving_default`` signature is then converted with a fixed float32 input of shape
    ``(1, img_height, img_width, num_channels)``.

    Validation runs a CoreML prediction, which coremltools only supports on macOS. On
    other platforms the model is still saved, but validation is skipped with a warning.

    :param model: Trained Keras model.
    :param plate_config: Plate OCR config providing the input height/width/channels.
    :param out_file: Destination path (``.mlpackage``).
    :param skip_validation: Skip comparing CoreML outputs against Keras.
    """
    import platform

    import coremltools as ct
    import tensorflow as tf

    with TemporaryDirectory() as tmp_dir:
        model.export(tmp_dir, format="tf_saved_model")
        loaded = tf.saved_model.load(tmp_dir)
        func = loaded.signatures["serving_default"]

        ct_inputs = [
            ct.TensorType(
                shape=(
                    1,
                    plate_config.img_height,
                    plate_config.img_width,
                    plate_config.num_channels,
                ),
                dtype=np.float32,
            )
        ]
        mlmodel = ct.convert(
            [func],
            source="tensorflow",
            convert_to="mlprogram",
            inputs=ct_inputs,
        )
        mlmodel.save(str(out_file))

    if skip_validation:
        logging.info("Skipping CoreML validation.")
    elif platform.system() != "Darwin":
        # Conversion and saving work anywhere, but MLModel.predict needs macOS. Skip
        # validation instead of crashing after the model was already written.
        logging.warning("CoreML prediction is only supported on macOS; skipping CoreML validation.")
    else:
        # Reload from disk so the saved artifact itself is what gets validated.
        saved_model = ct.models.MLModel(str(out_file))

        spec = saved_model.get_spec()
        input_name = spec.description.input[0].name
        output_name = spec.description.output[0].name

        def _predict(x: np.ndarray):
            return saved_model.predict({input_name: x})[output_name]

        _validate_prediction(
            model,
            _predict,
            _dummy_input(
                1,
                plate_config.img_height,
                plate_config.img_width,
                plate_config.num_channels,
                np.float32,
            ),
            "CoreML",
        )

    logging.info("Saved CoreML model to %s", out_file)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@click.command(context_settings={"max_content_width": 120})
@click.option(
    "-m",
    "--model",
    "model_path",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the saved .keras model.",
)
@click.option(
    "-f",
    "--format",
    "export_format",
    type=click.Choice(["onnx", "tflite", "coreml"], case_sensitive=False),
    default="onnx",
    show_default=True,
    help="Target export format.",
)
@click.option(
    "--simplify/--no-simplify",
    default=True,
    show_default=True,
    help="Simplify ONNX model using onnxslim (only applies when format is ONNX).",
)
@click.option(
    "--plate-config-file",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the model OCR config YAML.",
)
@click.option(
    "--save-dir",
    required=False,
    type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=pathlib.Path),
    help="Directory to save the exported model. Defaults to model's directory.",
)
@click.option(
    "--dynamic-batch/--no-dynamic-batch",
    default=True,
    show_default=True,
    help="Enable dynamic batch size (only applies to ONNX format).",
)
@click.option(
    "--skip-validation/--no-skip-validation",
    default=False,
    show_default=True,
    help="Skip the post-export inference validation step.",
)
@click.option(
    "--onnx-input-dtype",
    type=click.Choice(["uint8", "float32"], case_sensitive=False),
    default="uint8",
    show_default=True,
    help="Data type of the ONNX model input.",
)
@click.option(
    "--onnx-data-format",
    type=click.Choice(["channels_last", "channels_first"], case_sensitive=False),
    default="channels_last",
    show_default=True,
    help=(
        "Data format of the input tensor. It can be either "
        "'channels_last' (NHWC) or 'channels_first' (NCHW)."
    ),
)
def export(  # noqa: PLR0913
    model_path: pathlib.Path,
    export_format: str,
    simplify: bool,
    plate_config_file: pathlib.Path,
    save_dir: Optional[pathlib.Path],
    dynamic_batch: bool,
    skip_validation: bool,
    onnx_input_dtype: str,
    onnx_data_format: TensorDataFormat,
) -> None:
    """
    Export Keras models to other formats.
    """
    # NOTE: the docstring above doubles as the CLI ``--help`` text.
    # ``save_dir`` is None when --save-dir is omitted; the export lands next to the model.
    plate_config = load_plate_config_from_yaml(plate_config_file)
    model = load_keras_model(model_path, plate_config)

    if export_format == "onnx":
        out_file = _make_output_path(model_path, ".onnx", save_dir)
        export_onnx(
            model=model,
            plate_config=plate_config,
            out_file=out_file,
            simplify=simplify,
            dynamic_batch=dynamic_batch,
            skip_validation=skip_validation,
            onnx_input_dtype=onnx_input_dtype,
            onnx_data_format=onnx_data_format,
        )
    elif export_format == "tflite":
        out_file = _make_output_path(model_path, ".tflite", save_dir)
        # TFLite doesn't seem to support dynamic batch size
        # See: https://ai.google.dev/edge/litert/inference#run-inference
        export_tflite(
            model=model,
            plate_config=plate_config,
            out_file=out_file,
            # Honor --skip-validation for TFLite exactly as for the other formats.
            skip_validation=skip_validation,
        )
    elif export_format == "coreml":
        out_file = _make_output_path(model_path, ".mlpackage", save_dir)
        export_coreml(
            model=model,
            plate_config=plate_config,
            out_file=out_file,
            skip_validation=skip_validation,
        )
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
# Allow running this module directly as a standalone CLI script.
if __name__ == "__main__":
    export()
|