matrice-analytics 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. matrice_analytics/__init__.py +28 -0
  2. matrice_analytics/boundary_drawing_internal/README.md +305 -0
  3. matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
  4. matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
  5. matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
  6. matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
  7. matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
  8. matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
  9. matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
  10. matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
  11. matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
  12. matrice_analytics/post_processing/README.md +455 -0
  13. matrice_analytics/post_processing/__init__.py +732 -0
  14. matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
  15. matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
  16. matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
  17. matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
  18. matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
  19. matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
  20. matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
  21. matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
  22. matrice_analytics/post_processing/config.py +146 -0
  23. matrice_analytics/post_processing/core/__init__.py +63 -0
  24. matrice_analytics/post_processing/core/base.py +704 -0
  25. matrice_analytics/post_processing/core/config.py +3291 -0
  26. matrice_analytics/post_processing/core/config_utils.py +925 -0
  27. matrice_analytics/post_processing/face_reg/__init__.py +43 -0
  28. matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
  29. matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
  30. matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
  31. matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
  32. matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
  33. matrice_analytics/post_processing/ocr/__init__.py +0 -0
  34. matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
  35. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
  36. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
  37. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
  38. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
  39. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
  40. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
  41. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
  42. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
  43. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
  44. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
  45. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
  46. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
  47. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
  48. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
  49. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
  50. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
  51. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
  52. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
  53. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
  54. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
  55. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
  56. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
  57. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
  58. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
  59. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
  60. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
  61. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
  62. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
  63. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
  64. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
  65. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
  66. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
  67. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
  68. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
  69. matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
  70. matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
  71. matrice_analytics/post_processing/post_processor.py +1175 -0
  72. matrice_analytics/post_processing/test_cases/__init__.py +1 -0
  73. matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
  74. matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
  75. matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
  76. matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
  77. matrice_analytics/post_processing/test_cases/test_config.py +852 -0
  78. matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
  79. matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
  80. matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
  81. matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
  82. matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
  83. matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
  84. matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
  85. matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
  86. matrice_analytics/post_processing/usecases/__init__.py +267 -0
  87. matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
  88. matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
  89. matrice_analytics/post_processing/usecases/age_detection.py +842 -0
  90. matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
  91. matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
  92. matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
  93. matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
  94. matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
  95. matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
  96. matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
  97. matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
  98. matrice_analytics/post_processing/usecases/car_service.py +1601 -0
  99. matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
  100. matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
  101. matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
  102. matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
  103. matrice_analytics/post_processing/usecases/color/clip.py +660 -0
  104. matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
  105. matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
  106. matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
  107. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
  108. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
  109. matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
  110. matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
  111. matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
  112. matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
  113. matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
  114. matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
  115. matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
  116. matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
  117. matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
  118. matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
  119. matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
  120. matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
  121. matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
  122. matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
  123. matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
  124. matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
  125. matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
  126. matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
  127. matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
  128. matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
  129. matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
  130. matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
  131. matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
  132. matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
  133. matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
  134. matrice_analytics/post_processing/usecases/leaf.py +821 -0
  135. matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
  136. matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
  137. matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
  138. matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
  139. matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
  140. matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
  141. matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
  142. matrice_analytics/post_processing/usecases/parking.py +787 -0
  143. matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
  144. matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
  145. matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
  146. matrice_analytics/post_processing/usecases/people_counting.py +706 -0
  147. matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
  148. matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
  149. matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
  150. matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
  151. matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
  152. matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
  153. matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
  154. matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
  155. matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
  156. matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
  157. matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
  158. matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
  159. matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
  160. matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
  161. matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
  162. matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
  163. matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
  164. matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
  165. matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
  166. matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
  167. matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
  168. matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
  169. matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
  170. matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
  171. matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
  172. matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
  173. matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
  174. matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
  175. matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
  176. matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
  177. matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
  178. matrice_analytics/post_processing/utils/__init__.py +150 -0
  179. matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
  180. matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
  181. matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
  182. matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
  183. matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
  184. matrice_analytics/post_processing/utils/color_utils.py +592 -0
  185. matrice_analytics/post_processing/utils/counting_utils.py +182 -0
  186. matrice_analytics/post_processing/utils/filter_utils.py +261 -0
  187. matrice_analytics/post_processing/utils/format_utils.py +293 -0
  188. matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
  189. matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
  190. matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
  191. matrice_analytics/py.typed +0 -0
  192. matrice_analytics-0.1.60.dist-info/METADATA +481 -0
  193. matrice_analytics-0.1.60.dist-info/RECORD +196 -0
  194. matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
  195. matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
  196. matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
@@ -0,0 +1,323 @@
1
+ """
2
+ ONNX inference module.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ import pathlib
9
+ from collections.abc import Sequence
10
+ from typing import Literal, Union, Optional
11
+
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+
15
+ try:
16
+ import onnxruntime as ort
17
+ except ModuleNotFoundError as e:
18
+ raise ModuleNotFoundError(
19
+ "ONNX Runtime is not installed. Run: "
20
+ "pip install 'fast-plate-ocr[onnx]' (or [onnx-gpu], etc.)"
21
+ ) from e
22
+ from rich.console import Console
23
+ from rich.panel import Panel
24
+ from rich.table import Table
25
+ from rich.text import Text
26
+
27
+ from fast_plate_ocr.core.process import (
28
+ postprocess_output,
29
+ preprocess_image,
30
+ read_and_resize_plate_image,
31
+ resize_image,
32
+ )
33
+ from fast_plate_ocr.core.types import BatchArray, BatchOrImgLike, ImgLike, PathLike
34
+ from fast_plate_ocr.core.utils import measure_time
35
+ from fast_plate_ocr.inference import hub
36
+ from fast_plate_ocr.inference.config import PlateOCRConfig
37
+ from fast_plate_ocr.inference.hub import OcrModel
38
+
39
+
40
def _frame_from(item: ImgLike, cfg: PlateOCRConfig) -> BatchArray:
    """
    Turn one image-like input into an array the model pipeline can consume.

    Dispatch rules:

    - ``str`` / ``pathlib.PurePath``: the image is read from disk and resized.
    - 4-D ``np.ndarray``: assumed to already be an (N, H, W, C) batch and handed back untouched.
    - any other ``np.ndarray``: treated as a single frame and resized.

    Args:
        item: A path to an image or an in-memory NumPy image.
        cfg: Configuration providing the target size, color mode, aspect-ratio handling,
            interpolation method and padding color.

    Returns:
        The resized image, or ``item`` itself when it is already 4-D.

    Raises:
        TypeError: If ``item`` is neither a path nor a NumPy array.
    """
    is_path = isinstance(item, (str, pathlib.PurePath))

    if not is_path:
        # Anything that is not a path has to be an in-memory array
        if not isinstance(item, np.ndarray):
            raise TypeError(f"Unsupported element type: {type(item)}")
        # A 4-D array is taken to be a ready-made batch, nothing to do
        if item.ndim == 4:
            return item

    # The configuration is only consulted once we know a resize is really needed
    target_height = cfg.img_height
    target_width = cfg.img_width
    resize_options = {
        "image_color_mode": cfg.image_color_mode,
        "keep_aspect_ratio": cfg.keep_aspect_ratio,
        "interpolation_method": cfg.interpolation,
        "padding_color": cfg.padding_color,
    }

    if is_path:
        return read_and_resize_plate_image(
            item,
            img_height=target_height,
            img_width=target_width,
            **resize_options,
        )

    return resize_image(item, target_height, target_width, **resize_options)
77
+
78
+
79
def _load_image_from_source(source: BatchOrImgLike, cfg: PlateOCRConfig) -> BatchArray:
    """
    Build a 4-D (N, H, W, C) batch out of whatever image input the caller supplied.

    Accepted inputs:
        - A single path (`str` or `PathLike`) -> read and resized
        - A list or tuple of paths -> each one read and resized
        - A single 2D or 3D NumPy array -> resized and wrapped in a batch
        - A list or tuple of NumPy arrays -> each one resized, then batched
        - A single 4D NumPy array of shape (N, H, W, C) -> returned unchanged

    Args:
        source: One image, or a batch of images, given as paths or NumPy arrays.
        cfg: Configuration object holding the image preprocessing parameters.

    Returns:
        A 4D NumPy array of shape (N, H, W, C). Unless a 4D array was passed in (which is
        returned as is), the frames are stacked with dtype uint8.
    """
    # Fast path: an already batched array needs no work at all
    if isinstance(source, np.ndarray) and source.ndim == 4:
        return source

    # Strings, paths and arrays count as single items even though some of them are sequences
    is_collection = isinstance(source, Sequence) and not isinstance(
        source, (str, pathlib.PurePath, np.ndarray)
    )
    items: Sequence[ImgLike] = source if is_collection else [source]

    frames: list[BatchArray] = []
    for item in items:
        converted = _frame_from(item, cfg)
        if isinstance(item, np.ndarray) and item.ndim == 4:
            # A 4-D element contributes each of its frames individually
            frames.extend(converted)  # type: ignore[attr-defined]
        else:
            frames.append(converted)

    return np.stack(frames, axis=0, dtype=np.uint8)
121
+
122
+
123
class LicensePlateRecognizer:
    """
    ONNX inference class for performing license plates OCR.

    An instance holds the ONNX Runtime session (``self.model``), the plate configuration
    (``self.config``), the execution providers in use (``self.providers``) and the model name
    (``self.model_name``).
    """

    def __init__(
        self,
        hub_ocr_model: Optional[OcrModel] = None,
        device: Literal["cuda", "cpu", "auto"] = "auto",
        providers: Optional[Sequence[Union[str, tuple[str, dict]]]] = None,
        sess_options: Optional[ort.SessionOptions] = None,
        onnx_model_path: Optional[PathLike] = None,
        plate_config_path: Optional[PathLike] = None,
        force_download: bool = False,
    ) -> None:
        """
        Initializes the `LicensePlateRecognizer` with the specified OCR model and inference device.

        The current OCR models available from the HUB are:

        - `cct-s-v1-global-model`: OCR model trained with **global** plates data. Based on Compact
            Convolutional Transformer (CCT) architecture. This is the **S** variant.
        - `cct-xs-v1-global-model`: OCR model trained with **global** plates data. Based on Compact
            Convolutional Transformer (CCT) architecture. This is the **XS** variant.
        - `argentinian-plates-cnn-model`: OCR for **Argentinian** license plates. Uses fully conv
            architecture.
        - `argentinian-plates-cnn-synth-model`: OCR for **Argentinian** license plates trained with
            synthetic and real data. Uses fully conv architecture.
        - `european-plates-mobile-vit-v2-model`: OCR for **European** license plates. Uses
            MobileVIT-2 for the backbone.
        - `global-plates-mobile-vit-v2-model`: OCR for **global** license plates (+65 countries).
            Uses MobileVIT-2 for the backbone.

        Args:
            hub_ocr_model: Name of the OCR model to use from the HUB.
            device: Device type for inference. Should be one of ('cpu', 'cuda', 'auto'). If
                'auto' mode, the device will be deduced from
                `onnxruntime.get_available_providers()`.
            providers: Optional sequence of providers in order of decreasing precedence. If not
                specified, all available providers are used based on the device argument.
            sess_options: Advanced session options for ONNX Runtime.
            onnx_model_path: Path to ONNX model file to use (In case you want to use a custom one).
            plate_config_path: Path to config file to use (In case you want to use a custom one).
            force_download: Force and download the model, even if it already exists.

        Returns:
            None.

        Raises:
            ValueError: If `device` is not one of the supported values (only checked when
                `providers` is not given), or if neither a HUB model nor both custom paths
                were supplied.
            FileNotFoundError: If custom paths were supplied but one of the files is missing.
        """
        self.logger = logging.getLogger(__name__)

        # Explicit providers take precedence; `device` is only consulted when none are given
        if providers is not None:
            self.providers = providers
            self.logger.info("Using custom providers: %s", providers)
        else:
            if device == "cuda":
                self.providers = ["CUDAExecutionProvider"]
            elif device == "cpu":
                self.providers = ["CPUExecutionProvider"]
            elif device == "auto":
                self.providers = ort.get_available_providers()
            else:
                raise ValueError(
                    f"Device should be one of ('cpu', 'cuda', 'auto'). Got '{device}'."
                )

            self.logger.info("Using device '%s' with providers: %s", device, self.providers)

        # A custom model is used only when BOTH paths are given; otherwise fall back to the HUB
        if onnx_model_path and plate_config_path:
            onnx_model_path = pathlib.Path(onnx_model_path)
            plate_config_path = pathlib.Path(plate_config_path)
            if not onnx_model_path.exists() or not plate_config_path.exists():
                raise FileNotFoundError("Missing model/config file!")
            self.model_name = onnx_model_path.stem
        elif hub_ocr_model:
            self.model_name = hub_ocr_model
            onnx_model_path, plate_config_path = hub.download_model(
                model_name=hub_ocr_model, force_download=force_download
            )
        else:
            raise ValueError(
                "Either provide a model from the HUB or a custom model_path and config_path"
            )

        self.config = PlateOCRConfig.from_yaml(plate_config_path)
        self.model = ort.InferenceSession(
            onnx_model_path, providers=self.providers, sess_options=sess_options
        )

    def benchmark(
        self,
        n_iter: int = 2_500,
        batch_size: int = 1,
        include_processing: bool = False,
        warmup: int = 250,
    ) -> None:
        """
        Run an inference benchmark and pretty print the results.

        It reports the following metrics:

        * **Average latency per batch** (milliseconds)
        * **Throughput** in *plates / second* (PPS), i.e., how many plates the pipeline can process
          per second at the chosen ``batch_size``.

        Both metrics are reported as 0.0 when no timed iteration ran or when the measured time
        is zero.

        Args:
            n_iter: The number of iterations to run the benchmark. This determines how many times
                the inference will be executed to compute the average performance metrics.
            batch_size : Batch size to use for the benchmark.
            include_processing: Indicates whether the benchmark should include preprocessing and
                postprocessing times in the measurement.
            warmup: Number of warmup iterations to run before the benchmark.
        """
        # Random uint8 batch matching the input shape described by the model configuration
        x = np.random.randint(
            0,
            256,
            size=(
                batch_size,
                self.config.img_height,
                self.config.img_width,
                self.config.num_channels,
            ),
            dtype=np.uint8,
        )

        # Warm-up
        for _ in range(warmup):
            if include_processing:
                self.run(x)
            else:
                self.model.run(None, {"input": x})

        # Timed loop
        # NOTE(review): `measure_time` is assumed to report elapsed milliseconds, since the
        # accumulated value is labelled "ms" in the table below - confirm against core.utils.
        cum_time = 0.0
        for _ in range(n_iter):
            with measure_time() as time_taken:
                if include_processing:
                    self.run(x)
                else:
                    self.model.run(None, {"input": x})
            cum_time += time_taken()

        avg_time_ms = cum_time / n_iter if n_iter else 0.0
        # Guard on the average itself: it can be zero even when iterations ran (e.g. coarse
        # timer resolution), which would otherwise raise ZeroDivisionError.
        pps = (1_000 / avg_time_ms) * batch_size if avg_time_ms > 0 else 0.0

        console = Console()
        model_info = Panel(
            Text(f"Model: {self.model_name}\nProviders: {self.providers}", style="bold green"),
            title="Model Information",
            border_style="bright_blue",
            expand=False,
        )
        console.print(model_info)
        table = Table(title=f"Benchmark for '{self.model_name}'", border_style="bright_blue")
        table.add_column("Metric", justify="center", style="cyan", no_wrap=True)
        table.add_column("Value", justify="center", style="magenta")

        table.add_row("Batch size", str(batch_size))
        table.add_row("Warm-up iters", str(warmup))
        table.add_row("Timed iterations", str(n_iter))
        table.add_row("Average Time / batch (ms)", f"{avg_time_ms:.4f}")
        table.add_row("Plates per Second (PPS)", f"{pps:.4f}")
        console.print(table)

    def run(
        self,
        source: Union[str, list[str], npt.NDArray, list[npt.NDArray]],
        return_confidence: bool = False,
    ) -> Union[tuple[list[str], npt.NDArray], list[str]]:
        """
        Performs OCR to recognize license plate characters from an image or a list of images.

        Args:
            source: One or more image inputs, which can be:

                - A file path (`str` or `PathLike`) to an image.
                - A list of file paths.
                - A NumPy array of a single image, with shape (H, W), (H, W, 1) or (H, W, 3).
                - A list of NumPy arrays, each representing an image.
                - A 4D NumPy array of shape (N, H, W, C), ready for inference.

                Images will be automatically resized and converted as needed based on the model's
                configuration (including color mode and aspect ratio settings).

            return_confidence: Whether to return confidence scores along with plate predictions.

        Returns:
            A list of recognized license plates (one per image). If `return_confidence` is True,
            also returns a NumPy array of shape `(N, plate_slots)` containing the confidence scores
            for each predicted character.
        """
        x = _load_image_from_source(source, self.config)
        # Preprocess
        x = preprocess_image(x)
        # Run model
        y: list[npt.NDArray] = self.model.run(None, {"input": x})
        # Postprocess model output (only the first output of the session is used)
        return postprocess_output(
            y[0],
            self.config.max_plate_slots,
            self.config.alphabet,
            return_confidence=return_confidence,
        )
@@ -0,0 +1,101 @@
1
+ """
2
+ Augmentations used for training the OCR model.
3
+ """
4
+
5
+ import albumentations as A
6
+ import cv2
7
+
8
+ from fast_plate_ocr.core.types import ImageColorMode
9
+
10
# Fill value (black) passed as `fill` to the affine transforms that use cv2.BORDER_CONSTANT.
BORDER_COLOR_BLACK: tuple[int, int, int] = (0, 0, 0)
11
+
12
+
13
def default_train_augmentation(img_color_mode: ImageColorMode) -> A.Compose:
    """
    Build the default augmentation pipeline used for training.

    Args:
        img_color_mode: Either ``"grayscale"`` or ``"rgb"``; selects which pipeline is built.

    Returns:
        The composed albumentations pipeline for the requested color mode.

    Raises:
        ValueError: If ``img_color_mode`` is neither ``"grayscale"`` nor ``"rgb"``.
    """
    if img_color_mode == "grayscale":
        # Geometric jitter, with black fill for the pixels exposed by the transform
        gray_affine = A.Affine(
            translate_percent=(-0.02, 0.02),
            scale=(0.75, 1.10),
            rotate=(-15, 15),
            border_mode=cv2.BORDER_CONSTANT,
            fill=BORDER_COLOR_BLACK,
            shear=(0.0, 0.0),
            p=0.75,
        )
        gray_brightness = A.RandomBrightnessContrast(
            brightness_limit=0.1, contrast_limit=0.1, p=1.0
        )
        gray_blur = A.GaussianBlur(sigma_limit=(0.2, 0.5), p=0.25)
        # At most one dropout flavour is applied per sample
        gray_dropout = A.OneOf(
            [
                A.CoarseDropout(
                    num_holes_range=(1, 14),
                    hole_height_range=(1, 5),
                    hole_width_range=(1, 5),
                    p=0.2,
                ),
                A.PixelDropout(dropout_prob=0.02, p=0.2),
                A.GridDropout(ratio=0.3, fill="random", p=0.2),
            ],
            p=0.7,
        )
        return A.Compose([gray_affine, gray_brightness, gray_blur, gray_dropout])

    if img_color_mode == "rgb":
        # Geometric jitter, with black fill for the pixels exposed by the transform
        rgb_affine = A.Affine(
            translate_percent=(-0.02, 0.02),
            scale=(0.75, 1.10),
            rotate=(-15, 15),
            border_mode=cv2.BORDER_CONSTANT,
            fill=BORDER_COLOR_BLACK,
            shear=(0.0, 0.0),
            p=0.75,
        )
        rgb_brightness = A.RandomBrightnessContrast(
            brightness_limit=0.10, contrast_limit=0.10, p=0.5
        )
        # Color perturbations
        rgb_color_shift = A.OneOf(
            [
                A.HueSaturationValue(
                    hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.7
                ),
                A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.3),
            ],
            p=0.3,
        )
        rgb_gamma = A.RandomGamma(gamma_limit=(95, 105), p=0.20)
        rgb_to_gray = A.ToGray(p=0.05)
        # Blur
        rgb_blur = A.OneOf(
            [
                A.GaussianBlur(sigma_limit=(0.2, 0.5), p=0.5),
                A.MotionBlur(blur_limit=(3, 3), p=0.5),
            ],
            p=0.2,
        )
        # Noise and compression artifacts
        rgb_noise = A.OneOf(
            [
                A.GaussNoise(std_range=(0.01, 0.03), p=0.2),
                A.MultiplicativeNoise(multiplier=(0.98, 1.02), p=0.1),
                A.ISONoise(intensity=(0.005, 0.02), p=0.1),
                A.ImageCompression(quality_range=(55, 90), p=0.1),
            ],
            p=0.3,
        )
        # Dropout
        rgb_dropout = A.OneOf(
            [
                A.CoarseDropout(
                    num_holes_range=(1, 14),
                    hole_height_range=(1, 5),
                    hole_width_range=(1, 5),
                    p=0.2,
                ),
                A.PixelDropout(dropout_prob=0.02, p=0.3),
                A.GridDropout(ratio=0.3, fill="random", p=0.3),
            ],
            p=0.5,
        )
        return A.Compose(
            [
                rgb_affine,
                rgb_brightness,
                rgb_color_shift,
                rgb_gamma,
                rgb_to_gray,
                rgb_blur,
                rgb_noise,
                rgb_dropout,
            ]
        )

    raise ValueError(f"Unsupported img_color_mode: {img_color_mode!r}. Expected 'grayscale'/'rgb'.")
@@ -0,0 +1,97 @@
1
+ """
2
+ Dataset module.
3
+ """
4
+
5
+ import math
6
+ import os
7
+ from typing import Union
8
+ import albumentations as A
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+ import pandas as pd
12
+ from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset
13
+
14
+ from fast_plate_ocr.core.process import read_and_resize_plate_image
15
+ from fast_plate_ocr.train.model.config import PlateOCRConfig
16
+ from fast_plate_ocr.train.utilities import utils
17
+
18
+
19
class PlateRecognitionPyDataset(PyDataset):
    """
    Custom PyDataset for OCR license plate recognition.

    Every item of the dataset is one batch: an array of preprocessed plate images together with
    an array of the corresponding encoded plate texts, both built from the rows of a CSV
    annotations file.
    """

    def __init__(
        self,
        annotations_file: Union[str, os.PathLike],
        plate_config: PlateOCRConfig,
        batch_size: int,
        # Spelled `Union[..., None]` instead of `Optional[...]`: this module only imports `Union`
        # from `typing`, and the annotation is evaluated when the `def` statement runs, so an
        # un-imported `Optional` would raise NameError as soon as the module is imported.
        transform: Union[A.Compose, None] = None,
        shuffle: bool = True,
        **kwargs,
    ) -> None:
        """
        Load the annotations and prepare the dataset for batching.

        Args:
            annotations_file: CSV file with the `image_path` and `plate_text` columns. Rows are
                later unpacked as `(image_path, plate_text)`, so the file is expected to contain
                exactly these two columns in this order. Image paths are resolved relative to
                the directory that contains the annotations file.
            plate_config: Plate config providing the image preprocessing and text encoding
                parameters.
            batch_size: Number of samples per batch; the last batch may be smaller.
            transform: Optional albumentations pipeline applied to each image after resizing.
            shuffle: Whether to shuffle the annotations at construction time and at the
                beginning of every epoch.
            **kwargs: Forwarded unchanged to `PyDataset.__init__`.

        Raises:
            AssertionError: If any plate text is longer than `plate_config.max_plate_slots`.
        """
        super().__init__(**kwargs)
        # Load annotations
        annotations = pd.read_csv(annotations_file, dtype={"plate_text": str})
        annotations["image_path"] = (
            os.path.dirname(os.path.realpath(annotations_file)) + os.sep + annotations["image_path"]
        )
        # Check that plate lengths do not exceed max_plate_slots.
        assert (annotations["plate_text"].str.len() <= plate_config.max_plate_slots).all(), (
            "Plates are longer than max_plate_slots specified param. Change the parameter."
        )
        # Convert the dataframe to a NumPy array
        self.annotations = annotations.to_numpy()

        self.plate_config = plate_config
        self.transform = transform
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Shuffle once at initialization if `shuffle=True`
        self._shuffle_data()

    def __len__(self) -> int:
        """Return the number of batches, counting a trailing partial batch as one."""
        return math.ceil(len(self.annotations) / self.batch_size)

    def __getitem__(self, idx: int) -> tuple[npt.NDArray, npt.NDArray]:
        """
        Build the batch at position `idx`.

        Args:
            idx: Zero-based batch index.

        Returns:
            A `(batch_x, batch_y)` tuple of arrays holding the processed images and the encoded
            plate texts of the batch, in the same row order.
        """
        # Determine the idx-es of current batch
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.annotations))
        batch = self.annotations[low:high]

        batch_x = []
        batch_y = []
        for image_path, plate_text in batch:
            # Read and process image
            x = read_and_resize_plate_image(
                image_path=image_path,
                img_height=self.plate_config.img_height,
                img_width=self.plate_config.img_width,
                image_color_mode=self.plate_config.image_color_mode,
                keep_aspect_ratio=self.plate_config.keep_aspect_ratio,
                interpolation_method=self.plate_config.interpolation,
                padding_color=self.plate_config.padding_color,
            )
            # Transform target
            y = utils.target_transform(
                plate_text=plate_text,
                max_plate_slots=self.plate_config.max_plate_slots,
                alphabet=self.plate_config.alphabet,
                pad_char=self.plate_config.pad_char,
            )
            # Apply augmentation if provided
            if self.transform:
                x = self.transform(image=x)["image"]
            batch_x.append(x)
            batch_y.append(y)

        return np.array(batch_x), np.array(batch_y)

    def _shuffle_data(self) -> None:
        """Shuffle the annotation rows in place when shuffling is enabled; otherwise do nothing."""
        if self.shuffle:
            np.random.shuffle(self.annotations)

    def on_epoch_begin(self) -> None:
        """Reshuffle the annotations at the start of each epoch (no-op when `shuffle` is False)."""
        # Optionally shuffle the dataset at the start of each epoch
        self._shuffle_data()
@@ -0,0 +1,114 @@
1
+ """
2
+ License Plate OCR config. This config file defines how license plate images and text should be
3
+ preprocessed for OCR model training and inference.
4
+ """
5
+
6
+ from pathlib import Path
7
+ from typing import Annotated, TypeAlias, Union
8
+
9
+ import annotated_types
10
+ import yaml
11
+ from pydantic import (
12
+ BaseModel,
13
+ PositiveInt,
14
+ StringConstraints,
15
+ computed_field,
16
+ model_validator,
17
+ )
18
+
19
+ from fast_plate_ocr.core.types import ImageColorMode, ImageInterpolation, PathLike
20
+
21
# `int` annotated with inclusive lower/upper bounds (Ge(0), Le(255)) from `annotated_types`.
UInt8: TypeAlias = Annotated[int, annotated_types.Ge(0), annotated_types.Le(255)]
"""
An integer in the range [0, 255], used for color channel values.
"""
25
+
26
+
27
class PlateOCRConfig(BaseModel, extra="forbid", frozen=True):
    """
    Model License Plate OCR config.
    """

    max_plate_slots: PositiveInt
    """
    Maximum number of plate slots supported, i.e. the number of model classification heads.
    """
    alphabet: str
    """
    The full set of characters the model can output.
    """
    pad_char: Annotated[str, StringConstraints(min_length=1, max_length=1)]
    """
    Single character used to pad plates that are shorter than `max_plate_slots`.
    """
    img_height: PositiveInt
    """
    Height of the image fed to the model.
    """
    img_width: PositiveInt
    """
    Width of the image fed to the model.
    """
    keep_aspect_ratio: bool = False
    """
    Whether to keep the aspect ratio of the input image.
    """
    interpolation: ImageInterpolation = "linear"
    """
    Interpolation method used when resizing the input image.
    """
    image_color_mode: ImageColorMode = "grayscale"
    """
    Input image color mode: 'grayscale' for single-channel input, 'rgb' for 3-channel input.
    """
    padding_color: Union[tuple[UInt8, UInt8, UInt8], UInt8] = (114, 114, 114)
    """
    Padding color used when keep_aspect_ratio is True. For grayscale images this should be a
    single integer; for RGB images it must be a tuple of three integers.
    """

    # NOTE: the computed fields below are documented with comments rather than docstrings so the
    # generated model schema keeps exactly the descriptions it had before.

    @computed_field  # type: ignore[misc]
    @property
    def vocabulary_size(self) -> int:
        # Number of characters in the alphabet.
        alphabet_chars = self.alphabet
        return len(alphabet_chars)

    @computed_field  # type: ignore[misc]
    @property
    def pad_idx(self) -> int:
        # Position of the pad character inside the alphabet.
        alphabet_chars = self.alphabet
        return alphabet_chars.index(self.pad_char)

    @computed_field  # type: ignore[misc]
    @property
    def num_channels(self) -> int:
        # Three channels for 'rgb'; any other color mode is treated as single-channel.
        if self.image_color_mode == "rgb":
            return 3
        return 1

    @model_validator(mode="after")
    def check_alphabet_and_pad(self) -> "PlateOCRConfig":
        # Cross-field checks that need both `alphabet` and `pad_char` to be set.
        alphabet_chars = self.alphabet

        # The pad character has to be one of the alphabet characters.
        pad_is_known = self.pad_char in alphabet_chars
        if not pad_is_known:
            raise ValueError("Pad character must be present in model alphabet.")

        # Every alphabet character may appear only once.
        has_repeats = len(alphabet_chars) != len(set(alphabet_chars))
        if has_repeats:
            raise ValueError("Alphabet must not contain duplicate characters.")

        return self
94
+
95
+
96
def load_plate_config_from_yaml(yaml_path: PathLike) -> PlateOCRConfig:
    """
    Reads and parses a YAML file containing the plate configuration.

    Args:
        yaml_path: Path to the YAML file containing the plate config.

    Returns:
        PlateOCRConfig: Parsed and validated plate configuration.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
        TypeError: If the YAML document is empty or its top level is not a mapping.
    """
    if not Path(yaml_path).is_file():
        raise FileNotFoundError(f"Plate config '{yaml_path}' doesn't exist.")
    with open(yaml_path, encoding="utf-8") as f_in:
        yaml_content = yaml.safe_load(f_in)
    # An empty file is loaded as None, and a top-level list or scalar is not a mapping either.
    # Unpacking those with `**` already failed with a TypeError, but one that only mentioned
    # "argument after ** must be a mapping"; keep the exception type and name the actual problem.
    if not isinstance(yaml_content, dict):
        raise TypeError(
            f"Plate config '{yaml_path}' must contain a top-level mapping, "
            f"got {type(yaml_content).__name__}."
        )
    # Field-level validation (types, constraints, unknown keys) is done by the model itself.
    return PlateOCRConfig(**yaml_content)