matrice-analytics 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. matrice_analytics/__init__.py +28 -0
  2. matrice_analytics/boundary_drawing_internal/README.md +305 -0
  3. matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
  4. matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
  5. matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
  6. matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
  7. matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
  8. matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
  9. matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
  10. matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
  11. matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
  12. matrice_analytics/post_processing/README.md +455 -0
  13. matrice_analytics/post_processing/__init__.py +732 -0
  14. matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
  15. matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
  16. matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
  17. matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
  18. matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
  19. matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
  20. matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
  21. matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
  22. matrice_analytics/post_processing/config.py +146 -0
  23. matrice_analytics/post_processing/core/__init__.py +63 -0
  24. matrice_analytics/post_processing/core/base.py +704 -0
  25. matrice_analytics/post_processing/core/config.py +3291 -0
  26. matrice_analytics/post_processing/core/config_utils.py +925 -0
  27. matrice_analytics/post_processing/face_reg/__init__.py +43 -0
  28. matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
  29. matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
  30. matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
  31. matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
  32. matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
  33. matrice_analytics/post_processing/ocr/__init__.py +0 -0
  34. matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
  35. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
  36. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
  37. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
  38. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
  39. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
  40. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
  41. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
  42. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
  43. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
  44. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
  45. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
  46. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
  47. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
  48. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
  49. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
  50. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
  51. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
  52. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
  53. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
  54. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
  55. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
  56. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
  57. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
  58. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
  59. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
  60. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
  61. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
  62. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
  63. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
  64. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
  65. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
  66. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
  67. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
  68. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
  69. matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
  70. matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
  71. matrice_analytics/post_processing/post_processor.py +1175 -0
  72. matrice_analytics/post_processing/test_cases/__init__.py +1 -0
  73. matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
  74. matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
  75. matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
  76. matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
  77. matrice_analytics/post_processing/test_cases/test_config.py +852 -0
  78. matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
  79. matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
  80. matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
  81. matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
  82. matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
  83. matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
  84. matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
  85. matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
  86. matrice_analytics/post_processing/usecases/__init__.py +267 -0
  87. matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
  88. matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
  89. matrice_analytics/post_processing/usecases/age_detection.py +842 -0
  90. matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
  91. matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
  92. matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
  93. matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
  94. matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
  95. matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
  96. matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
  97. matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
  98. matrice_analytics/post_processing/usecases/car_service.py +1601 -0
  99. matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
  100. matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
  101. matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
  102. matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
  103. matrice_analytics/post_processing/usecases/color/clip.py +660 -0
  104. matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
  105. matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
  106. matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
  107. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
  108. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
  109. matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
  110. matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
  111. matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
  112. matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
  113. matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
  114. matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
  115. matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
  116. matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
  117. matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
  118. matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
  119. matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
  120. matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
  121. matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
  122. matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
  123. matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
  124. matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
  125. matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
  126. matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
  127. matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
  128. matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
  129. matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
  130. matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
  131. matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
  132. matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
  133. matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
  134. matrice_analytics/post_processing/usecases/leaf.py +821 -0
  135. matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
  136. matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
  137. matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
  138. matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
  139. matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
  140. matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
  141. matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
  142. matrice_analytics/post_processing/usecases/parking.py +787 -0
  143. matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
  144. matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
  145. matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
  146. matrice_analytics/post_processing/usecases/people_counting.py +706 -0
  147. matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
  148. matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
  149. matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
  150. matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
  151. matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
  152. matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
  153. matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
  154. matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
  155. matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
  156. matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
  157. matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
  158. matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
  159. matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
  160. matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
  161. matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
  162. matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
  163. matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
  164. matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
  165. matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
  166. matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
  167. matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
  168. matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
  169. matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
  170. matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
  171. matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
  172. matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
  173. matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
  174. matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
  175. matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
  176. matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
  177. matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
  178. matrice_analytics/post_processing/utils/__init__.py +150 -0
  179. matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
  180. matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
  181. matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
  182. matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
  183. matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
  184. matrice_analytics/post_processing/utils/color_utils.py +592 -0
  185. matrice_analytics/post_processing/utils/counting_utils.py +182 -0
  186. matrice_analytics/post_processing/utils/filter_utils.py +261 -0
  187. matrice_analytics/post_processing/utils/format_utils.py +293 -0
  188. matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
  189. matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
  190. matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
  191. matrice_analytics/py.typed +0 -0
  192. matrice_analytics-0.1.60.dist-info/METADATA +481 -0
  193. matrice_analytics-0.1.60.dist-info/RECORD +196 -0
  194. matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
  195. matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
  196. matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
@@ -0,0 +1,553 @@
1
+ """
2
+ Layer blocks used in the OCR model.
3
+ """
4
+
5
+ from collections.abc import Sequence
6
+ from typing import Optional, Union
7
+ import keras
8
+ import numpy as np
9
+ from keras import ops
10
+
11
+ # pylint: disable=too-many-ancestors,abstract-method,attribute-defined-outside-init,arguments-differ
12
+ # pylint: disable=useless-parent-delegation,too-many-instance-attributes,too-many-arguments
13
+
14
+
15
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class AddCoords(keras.layers.Layer):
    """Add coords to a tensor, modified from paper: https://arxiv.org/abs/1807.03247

    Args:
        with_r: When true, an extra channel ``sqrt(xx**2 + yy**2)`` is appended after the two
            coordinate channels.
        **kwargs: Base ``keras.layers.Layer`` keyword arguments (e.g. ``name``, ``dtype``,
            ``trainable``). Accepting them lets ``from_config`` rebuild the layer.
    """

    def __init__(self, with_r=False, **kwargs):
        super().__init__(**kwargs)
        self.with_r = with_r

    def build(self, input_shape):
        # Assuming input_shape is (batch, height, width, channels)
        self.x_dim = input_shape[1]
        self.y_dim = input_shape[2]

    def call(self, input_tensor):
        """
        input_tensor: (batch, x_dim, y_dim, c)

        Returns the input with the coordinate channels concatenated on the last axis.
        """
        batch_size_tensor = ops.shape(input_tensor)[0]

        # (batch, x_dim, 1) @ (batch, 1, y_dim) -> xx_channel[b, i, j] == j
        xx_ones = ops.ones([batch_size_tensor, self.x_dim])
        xx_ones = ops.expand_dims(xx_ones, -1)
        xx_range = ops.tile(ops.expand_dims(ops.arange(self.y_dim), 0), [batch_size_tensor, 1])
        xx_range = ops.expand_dims(xx_range, 1)
        xx_channel = ops.matmul(xx_ones, xx_range)
        xx_channel = ops.expand_dims(xx_channel, -1)

        # (batch, x_dim, 1) @ (batch, 1, y_dim) -> yy_channel[b, i, j] == i
        yy_ones = ops.ones([batch_size_tensor, self.y_dim])
        yy_ones = ops.expand_dims(yy_ones, 1)
        yy_range = ops.tile(ops.expand_dims(ops.arange(self.x_dim), 0), [batch_size_tensor, 1])
        yy_range = ops.expand_dims(yy_range, -1)
        yy_channel = ops.matmul(yy_range, yy_ones)
        yy_channel = ops.expand_dims(yy_channel, -1)

        # NOTE(review): xx_channel counts 0..y_dim-1 but is divided by (x_dim - 1), and
        # yy_channel counts 0..x_dim-1 but is divided by (y_dim - 1); the result only spans
        # [-1, 1] when x_dim == y_dim. Kept as-is because changing it changes the layer output.
        xx_channel = ops.cast(xx_channel, "float32") / (self.x_dim - 1)
        yy_channel = ops.cast(yy_channel, "float32") / (self.y_dim - 1)
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        ret = ops.concatenate([input_tensor, xx_channel, yy_channel], axis=-1)
        if self.with_r:
            rr = ops.sqrt(ops.square(xx_channel) + ops.square(yy_channel))
            ret = ops.concatenate([ret, rr], axis=-1)
        return ret

    def get_config(self):
        """Return the base layer config extended with ``with_r``."""
        config = super().get_config()
        config.update({"with_r": self.with_r})
        return config
55
+
56
+
57
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class CoordConv2D(keras.layers.Layer):
    """CoordConv2D layer as in the paper, modified from paper: https://arxiv.org/abs/1807.03247

    Coordinate channels are appended by ``AddCoords`` and the result is fed to a regular
    ``keras.layers.Conv2D`` built from ``conv_kwargs``.
    """

    def __init__(self, with_r: bool = False, **conv_kwargs):
        super().__init__()
        self.with_r = with_r
        # Keep our own shallow copy so get_config() can replay the Conv2D arguments.
        self.conv_kwargs = dict(conv_kwargs)
        self.addcoords = AddCoords(with_r=with_r)
        self.conv = keras.layers.Conv2D(**conv_kwargs)

    def call(self, inputs):
        """Append coordinate channels to ``inputs`` and convolve the result."""
        return self.conv(self.addcoords(inputs))

    def get_config(self):
        """Return the base config merged with ``with_r`` and the stored Conv2D kwargs."""
        return {**super().get_config(), "with_r": self.with_r, **self.conv_kwargs}
76
+
77
+
78
def _build_binomial_filter(filter_size: int) -> np.ndarray:
    """Return the normalized 2-D binomial kernel of side `filter_size`.

    The 1-D binomial row is looked up for sizes 1 through 7, turned into a 2-D kernel through
    an outer product with itself, and scaled so that all entries sum to one.

    Raises:
        ValueError: If `filter_size` is not one of the supported sizes (1-7).
    """
    binomial_rows = (
        (1.0,),
        (1.0, 1.0),
        (1.0, 2.0, 1.0),
        (1.0, 3.0, 3.0, 1.0),
        (1.0, 4.0, 6.0, 4.0, 1.0),
        (1.0, 5.0, 10.0, 10.0, 5.0, 1.0),
        (1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0),
    )
    # Each row has exactly as many coefficients as the filter size it serves.
    coefficients = next((row for row in binomial_rows if filter_size == len(row)), None)
    if coefficients is None:
        raise ValueError(f"Filter size not supported, got {filter_size}")

    row_vector = np.array(coefficients)
    kernel = np.outer(row_vector, row_vector)
    return kernel / np.sum(kernel)
101
+
102
+
103
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class MaxBlurPooling2D(keras.layers.Layer):
    """Stride-1 max pooling followed by a fixed binomial blur applied with stride `pool_size`.

    Args:
        pool_size: Window of the max pool and stride of the blur convolution.
        filter_size: Side length of the binomial blur kernel (see `_build_binomial_filter`).
        padding: Padding mode forwarded to both the max pool and the blur convolution.
        **kwargs: Base ``keras.layers.Layer`` keyword arguments.
    """

    def __init__(self, pool_size: int = 2, filter_size: int = 3, padding: str = "same", **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.blur_kernel = None
        self.filter_size = filter_size
        self.padding = padding

    def build(self, input_shape):
        """Create the non-trainable depthwise blur kernel, one copy per input channel."""
        binomial_filter = _build_binomial_filter(filter_size=self.filter_size)
        # np.repeat flattens the kernel and repeats every coefficient once per channel, so the
        # reshape below yields the same 2-D kernel for each channel.
        binomial_filter = np.repeat(binomial_filter, input_shape[3])
        # Maybe this should be channel first/last agnostic
        binomial_filter = np.reshape(
            binomial_filter, (self.filter_size, self.filter_size, input_shape[3], 1)
        )
        blur_init = keras.initializers.Constant(binomial_filter)

        self.blur_kernel = self.add_weight(
            name="blur_kernel",
            shape=(self.filter_size, self.filter_size, input_shape[3], 1),
            initializer=blur_init,
            trainable=False,
        )

        super().build(input_shape)

    def call(self, x):
        """Max-pool with stride 1, then blur and downsample with stride `pool_size`."""
        x = ops.max_pool(
            x,
            (self.pool_size, self.pool_size),
            strides=(1, 1),
            padding=self.padding,
        )
        x = ops.depthwise_conv(
            x, self.blur_kernel, padding=self.padding, strides=(self.pool_size, self.pool_size)
        )

        return x

    def _downsampled_length(self, length):
        """Spatial length after `call` for one axis; ``None`` stays ``None``."""
        if length is None:
            return None
        if str(self.padding).lower() == "same":
            # The stride-1 pool keeps the length; the strided blur yields ceil(length / stride).
            return -(-length // self.pool_size)
        # "valid": the pool trims pool_size - 1, then the blur applies the usual strided formula.
        after_pool = length - self.pool_size + 1
        return (after_pool - self.filter_size) // self.pool_size + 1

    def compute_output_shape(self, input_shape):
        """Return the output shape for a (batch, height, width, channels) input shape."""
        return (
            input_shape[0],
            self._downsampled_length(input_shape[1]),
            self._downsampled_length(input_shape[2]),
            input_shape[3],
        )

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "pool_size": self.pool_size,
                "filter_size": self.filter_size,
                "padding": self.padding,
            }
        )
        return config
162
+
163
+
164
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class SqueezeExcite(keras.layers.Layer):
    """
    Squeeze-and-excitation channel gating, see https://arxiv.org/abs/1709.01507

    Note: this was taken from https://keras.io/examples/vision/patch_convnet.
    """

    def __init__(self, ratio: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        """Create the pooling, bottleneck and gating sub-layers for the input channel count."""
        num_channels = input_shape[-1]
        bottleneck_units = int(num_channels // self.ratio)

        self.squeeze = keras.layers.GlobalAveragePooling2D(keepdims=True)
        self.reduction = keras.layers.Dense(
            units=bottleneck_units, activation="relu", use_bias=False
        )
        self.excite = keras.layers.Dense(units=num_channels, activation="sigmoid", use_bias=False)
        self.multiply = keras.layers.Multiply()

    def call(self, x):
        """Scale ``x`` by the sigmoid gates computed from its globally pooled features."""
        gates = self.excite(self.reduction(self.squeeze(x)))
        return self.multiply([x, gates])

    def get_config(self):
        """Return the base config extended with ``ratio``."""
        return {**super().get_config(), "ratio": self.ratio}
199
+
200
+
201
@keras.utils.register_keras_serializable(package="fast_plate_ocr")
class DyT(keras.layers.Layer):
    """
    Dynamic Tanh (DyT): an element-wise ``tanh(alpha * x) * weight + bias`` intended as a
    drop-in replacement for normalization layers in Transformers.

    Paper: https://arxiv.org/abs/2503.10622.
    """

    def __init__(self, alpha_init_value: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.alpha_init_value = alpha_init_value

    def build(self, input_shape):
        """Create the scalar ``alpha`` and the per-channel ``weight`` and ``bias``."""
        num_channels = int(input_shape[-1])
        per_channel_shape = (num_channels,)

        # A single learnable scalar shared by every element.
        self.alpha = self.add_weight(
            name="alpha",
            shape=(),
            initializer=keras.initializers.Constant(self.alpha_init_value),
            trainable=True,
        )
        self.weight = self.add_weight(
            name="weight", shape=per_channel_shape, initializer="ones", trainable=True
        )
        self.bias = self.add_weight(
            name="bias", shape=per_channel_shape, initializer="zeros", trainable=True
        )

        super().build(input_shape)

    def call(self, x):
        """Apply the scaled tanh followed by the per-channel affine transform."""
        squashed = ops.tanh(self.alpha * x)
        return squashed * self.weight + self.bias

    def get_config(self):
        """Return the base config extended with ``alpha_init_value``."""
        return {**super().get_config(), "alpha_init_value": self.alpha_init_value}
248
+
249
+
250
def build_norm_layer(norm_type) -> keras.layers.Layer:
    """Create a fresh normalization layer for ``norm_type``.

    Supported values are ``"layer_norm"``, ``"rms_norm"`` and ``"dyt"``.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported values.
    """
    # Factories are lazy so a layer class is only looked up when it is actually requested.
    factories = {
        "layer_norm": lambda: keras.layers.LayerNormalization(epsilon=1e-5),
        "rms_norm": lambda: keras.layers.RMSNormalization(epsilon=1e-5),
        "dyt": lambda: DyT(alpha_init_value=0.5),
    }
    factory = factories.get(norm_type) if isinstance(norm_type, str) else None
    if factory is None:
        raise ValueError(f"Unknown norm_type {norm_type}")
    return factory()
258
+
259
+
260
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class PositionEmbedding(keras.layers.Layer):
    """Learned position embeddings broadcast to the shape of the input.

    Args:
        sequence_length: Number of positions held by the embedding table; must not be ``None``.
        initializer: Initializer (or its identifier) for the embedding table.
        **kwargs: Base ``keras.layers.Layer`` keyword arguments.
    """

    def __init__(
        self,
        sequence_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an Integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def build(self, input_shape):
        """Create the ``(sequence_length, feature_size)`` embedding table."""
        table_shape = [self.sequence_length, input_shape[-1]]
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=table_shape,
            initializer=self.initializer,
            trainable=True,
        )

        super().build(input_shape)

    def call(self, inputs, start_index=0):
        """Return the embeddings for the input positions, broadcast to the input's shape."""
        input_shape = ops.shape(inputs)
        num_positions = input_shape[-2]
        num_features = input_shape[-1]

        # The input may be shorter than the table, so take only the rows that are needed,
        # starting at `start_index`.
        table = ops.convert_to_tensor(self.position_embeddings)
        window = ops.slice(table, (start_index, 0), (num_positions, num_features))
        return ops.broadcast_to(window, input_shape)

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Return the base config extended with the sequence length and the initializer."""
        return {
            **super().get_config(),
            "sequence_length": self.sequence_length,
            "initializer": keras.initializers.serialize(self.initializer),
        }
311
+
312
+
313
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class TokenReducer(keras.layers.Layer):
    """Reduce a token sequence to ``num_tokens`` tokens via cross-attention.

    A learned set of query tokens attends over the input tokens, which serve as keys and values.

    Args:
        num_tokens: Number of learned query tokens, i.e. the output sequence length.
        projection_dim: Feature size of the query tokens; also used as the attention key size.
        num_heads: Number of attention heads.
        **kwargs: Base ``keras.layers.Layer`` keyword arguments.
    """

    def __init__(self, num_tokens, projection_dim, num_heads=2, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = num_tokens
        self.projection_dim = projection_dim
        self.num_heads = num_heads
        self.attn = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)

    def build(self, input_shape):
        """Create the learned queries and build the attention layer for the input length."""
        query_shape = (1, self.num_tokens, self.projection_dim)
        self.query_tokens = self.add_weight(
            name="query_tokens",
            shape=query_shape,
            initializer="random_normal",
            trainable=True,
        )
        # input_shape is assumed to be (batch_size, seq_length, projection_dim)
        input_length = input_shape[1]
        if input_length is None:
            raise ValueError("Input sequence length must be defined (not None).")
        self.attn.build(
            query_shape=query_shape,
            value_shape=(1, input_length, self.projection_dim),
        )
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        """Return ``(batch, num_tokens, projection_dim)``."""
        return (input_shape[0], self.num_tokens, self.projection_dim)

    def call(self, inputs):
        """
        inputs: Tensor of shape (batch_size, seq_length, projection_dim)
        returns: Tensor of shape (batch_size, num_tokens, projection_dim)
        """
        # Repeat the learned queries once per example in the batch.
        num_examples = ops.shape(inputs)[0]
        queries = ops.tile(self.query_tokens, [num_examples, 1, 1])
        # Cross-attention: learned tokens are the queries, input tokens are keys and values.
        return self.attn(query=queries, key=inputs, value=inputs)

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        return {
            **super().get_config(),
            "num_tokens": self.num_tokens,
            "projection_dim": self.projection_dim,
            "num_heads": self.num_heads,
        }
365
+
366
+
367
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class StochasticDepth(keras.layers.Layer):
    """Randomly zero whole examples during training and rescale the ones that are kept.

    Args:
        drop_prob: Probability with which an example's tensor is zeroed during training.
        **kwargs: Base ``keras.layers.Layer`` keyword arguments.
    """

    def __init__(self, drop_prob: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=None):
        """Return ``x`` untouched outside training; otherwise apply the per-example mask."""
        if not training:
            return x

        keep_prob = 1 - self.drop_prob
        # One random draw per example, broadcast over every remaining axis.
        mask_shape = (ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
        noise = keras.random.uniform(mask_shape, 0, 1, seed=self.seed_generator, dtype=x.dtype)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob and 0 otherwise.
        keep_mask = ops.floor(keep_prob + noise)
        return (x / keep_prob) * keep_mask

    def get_config(self):
        """Return the base config extended with ``drop_prob``."""
        return {**super().get_config(), "drop_prob": self.drop_prob}
389
+
390
+
391
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class MLP(keras.layers.Layer):
    """Stack of ``Dense`` layers, each followed by ``Dropout``.

    Args:
        hidden_units: Iterable with the number of units of each dense layer, in order.
        dropout_rate: Dropout rate applied after every dense layer.
        activation: Activation used by every dense layer.
        use_bias: Whether the dense layers use a bias term.
        **kwargs: Base ``keras.layers.Layer`` keyword arguments.
    """

    def __init__(
        self,
        hidden_units,
        dropout_rate: float = 0.1,
        activation: str = "gelu",
        use_bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_units = list(hidden_units)
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.use_bias = use_bias

        self.dense_layers = [
            keras.layers.Dense(units, activation=self.activation, use_bias=self.use_bias)
            for units in self.hidden_units
        ]
        self.dropout_layers = [keras.layers.Dropout(self.dropout_rate) for _ in self.hidden_units]

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, training=None):
        """Run ``inputs`` through every dense/dropout pair in order."""
        x = inputs
        # Both lists are built from `hidden_units` in __init__, so their lengths always match.
        # `zip(strict=True)` is therefore unnecessary, and it needs Python 3.10+.
        for dense, drop in zip(self.dense_layers, self.dropout_layers):
            x = dense(x)
            x = drop(x, training=training)
        return x

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        cfg = super().get_config()
        cfg.update(
            {
                "hidden_units": self.hidden_units,
                "dropout_rate": self.dropout_rate,
                "activation": self.activation,
                "use_bias": self.use_bias,
            }
        )
        return cfg
434
+
435
+
436
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class VocabularyProjection(keras.layers.Layer):
    """Optional dropout followed by a softmax ``Dense`` projection onto the vocabulary.

    Args:
        vocabulary_size: Number of units of the softmax classifier.
        dropout_rate: Dropout rate applied before the classifier; ``None`` disables dropout.
        **kwargs: Base ``keras.layers.Layer`` keyword arguments.
    """

    def __init__(self, vocabulary_size: int, dropout_rate: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size
        self.dropout_rate = dropout_rate
        if self.dropout_rate is None:
            self.dropout = None
        else:
            self.dropout = keras.layers.Dropout(self.dropout_rate)
        self.classifier = keras.layers.Dense(self.vocabulary_size, activation="softmax")

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, x, training=None):
        """Apply dropout when configured, then return the softmax probabilities."""
        features = x if self.dropout is None else self.dropout(x, training=training)
        return self.classifier(features)

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        return {
            **super().get_config(),
            "vocabulary_size": self.vocabulary_size,
            "dropout_rate": self.dropout_rate,
        }
459
+
460
+
461
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class TransformerBlock(keras.layers.Layer):
    """Pre-norm Transformer block: self-attention and an MLP, each wrapped in a residual.

    Args:
        projection_dim: Key size passed to the multi-head attention layer.
        num_heads: Number of attention heads.
        mlp_units: Units of each dense layer of the MLP.
        attention_dropout: Dropout rate inside the attention layer.
        mlp_dropout: Dropout rate inside the MLP.
        drop_path_rate: Drop probability of both ``StochasticDepth`` layers.
        norm_type: Normalization selector understood by ``build_norm_layer``.
        activation: Activation used by the MLP.
        **kwargs: Base ``keras.layers.Layer`` keyword arguments.
    """

    def __init__(
        self,
        projection_dim: int,
        num_heads: int,
        mlp_units: Sequence[int],
        attention_dropout: float,
        mlp_dropout: float,
        drop_path_rate: float,
        norm_type: Optional[str] = "layer_norm",
        activation: str = "gelu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.norm_type = norm_type
        self.activation = activation

        self.norm1 = build_norm_layer(norm_type)
        self.attn = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=attention_dropout
        )
        self.drop1 = StochasticDepth(drop_path_rate)
        self.norm2 = build_norm_layer(norm_type)
        self.mlp = MLP(hidden_units=mlp_units, dropout_rate=mlp_dropout, activation=activation)
        self.drop2 = StochasticDepth(drop_path_rate)

    def build(self, input_shape) -> None:
        super().build(input_shape)

    def call(self, x, training=None):
        """Apply the attention and MLP sub-blocks, each added back onto its input."""
        # The residual sums use `ops.add` so that no new `keras.layers.Add` layer object is
        # created on every forward pass.

        # 1. MHA + residual
        y = self.norm1(x)
        y = self.attn(y, y)
        y = self.drop1(y, training=training)
        x = ops.add(x, y)

        # 2. MLP + residual
        y = self.norm2(x)
        y = self.mlp(y, training=training)
        y = self.drop2(y, training=training)
        return ops.add(x, y)

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        cfg = super().get_config()
        cfg.update(
            {
                "projection_dim": self.attn.key_dim,
                "num_heads": self.attn.num_heads,
                "mlp_units": self.mlp.hidden_units,
                "mlp_dropout": self.mlp.dropout_rate,
                "attention_dropout": self.attn.dropout,
                "drop_path_rate": self.drop1.drop_prob,
                "norm_type": self.norm_type,
                "activation": self.activation,
            }
        )
        return cfg
519
+
520
+
521
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class PatchExtractor(keras.layers.Layer):
    """
    Split an image into patches of side ``patch_size`` and flatten each patch.

    Modified from https://keras.io/examples/vision/image_classification_with_vision_transformer.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return patches reshaped to ``(batch, num_patches, patch_size * patch_size * c)``."""
        batch_size, height, width, channels = ops.shape(images)

        patch_count = (height // self.patch_size) * (width // self.patch_size)
        values_per_patch = self.patch_size * self.patch_size * channels

        extracted = ops.image.extract_patches(images, size=self.patch_size)
        return ops.reshape(extracted, (batch_size, patch_count, values_per_patch))

    def get_config(self):
        """Return the base config extended with ``patch_size``."""
        return {**super().get_config(), "patch_size": self.patch_size}
@@ -0,0 +1,55 @@
1
+ """
2
+ Loss functions for training license plate recognition models.
3
+ """
4
+
5
+ from keras import losses, ops
6
+
7
+
8
def cce_loss(vocabulary_size: int, label_smoothing: float = 0.01):
    """
    Build a categorical cross-entropy loss function.

    The returned callable reshapes both tensors to rows of width ``vocabulary_size`` and
    returns the mean categorical cross-entropy over those rows, applying ``label_smoothing``.
    """
    flat_shape = (-1, vocabulary_size)

    def cce(y_true, y_pred):
        """
        Mean categorical cross-entropy between ``y_true`` and ``y_pred``.
        """
        targets = ops.reshape(y_true, newshape=flat_shape)
        # ``from_logits=False`` below: predictions are treated as probabilities.
        probs = ops.reshape(y_pred, newshape=flat_shape)
        per_row = losses.categorical_crossentropy(
            targets, probs, from_logits=False, label_smoothing=label_smoothing
        )
        return ops.mean(per_row)

    return cce
26
+
27
+
28
def focal_cce_loss(
    vocabulary_size: int,
    alpha: float = 0.25,
    gamma: float = 2.0,
    label_smoothing: float = 0.01,
):
    """
    Build a categorical focal cross-entropy loss function.

    The returned callable reshapes both tensors to rows of width ``vocabulary_size`` and
    returns the mean focal cross-entropy over those rows, using the given ``alpha``,
    ``gamma`` and ``label_smoothing``.
    """
    flat_shape = (-1, vocabulary_size)
    # Fixed options for every call of the returned loss.
    focal_options = {
        "alpha": alpha,
        "gamma": gamma,
        "from_logits": False,
        "label_smoothing": label_smoothing,
    }

    def cce(y_true, y_pred):
        """
        Mean focal categorical cross-entropy between ``y_true`` and ``y_pred``.
        """
        targets = ops.reshape(y_true, newshape=flat_shape)
        probs = ops.reshape(y_pred, newshape=flat_shape)
        per_row = losses.categorical_focal_crossentropy(targets, probs, **focal_options)
        return ops.mean(per_row)

    return cce
@@ -0,0 +1,86 @@
1
+ """
2
+ Evaluation metrics for license plate recognition models.
3
+ """
4
+
5
+ from keras import metrics, ops
6
+
7
+
8
def cat_acc_metric(max_plate_slots: int, vocabulary_size: int):
    """
    Build the per-character categorical accuracy metric.
    """
    slot_shape = (-1, max_plate_slots, vocabulary_size)

    def cat_acc(y_true, y_pred):
        """
        Categorical accuracy averaged over every character slot. For example, if the correct
        label is ABC123 and ABC133 is predicted, the score is not 0% as with plate_acc (which
        needs the whole plate right) but 83.3% (5/6).
        """
        targets = ops.reshape(y_true, newshape=slot_shape)
        predictions = ops.reshape(y_pred, newshape=slot_shape)
        per_slot = metrics.categorical_accuracy(targets, predictions)
        return ops.mean(per_slot)

    return cat_acc
24
+
25
+
26
def plate_acc_metric(max_plate_slots: int, vocabulary_size: int):
    """
    Build the whole-plate accuracy metric.
    """
    slot_shape = (-1, max_plate_slots, vocabulary_size)

    def plate_acc(y_true, y_pred):
        """
        Fraction of plates predicted correctly in every slot. For a single plate with ground
        truth 'ABC 123', the prediction 'ABC 123' scores 1 and the prediction 'ABD 123' scores 0.
        """
        targets = ops.reshape(y_true, newshape=slot_shape)
        predictions = ops.cast(ops.reshape(y_pred, newshape=slot_shape), dtype="float32")

        # Compare the highest-scoring index of each slot.
        true_chars = ops.argmax(targets, axis=-1)
        pred_chars = ops.argmax(predictions, axis=-1)
        slot_hits = ops.equal(true_chars, pred_chars)

        # A plate counts only when all of its slots match.
        plate_hits = ops.all(slot_hits, axis=-1, keepdims=False)
        return ops.mean(ops.cast(plate_hits, dtype="float32"))

    return plate_acc
44
+
45
+
46
def top_3_k_metric(vocabulary_size: int):
    """
    Build the top-3 categorical accuracy metric.
    """
    flat_shape = (-1, vocabulary_size)

    def top_3_k(y_true, y_pred):
        """
        How often the true character is among the 3 predictions with the highest probability.
        """
        # Flatten both tensors to 2-D rows of width ``vocabulary_size``.
        targets = ops.reshape(y_true, newshape=flat_shape)
        predictions = ops.cast(ops.reshape(y_pred, newshape=flat_shape), dtype="float32")
        in_top_3 = metrics.top_k_categorical_accuracy(targets, predictions, k=3)
        return ops.mean(in_top_3)

    return top_3_k
63
+
64
+
65
def plate_len_acc_metric(
    max_plate_slots: int,
    vocabulary_size: int,
    pad_token_index: int,
):
    """
    Build the plate-length accuracy metric.
    """
    slot_shape = (-1, max_plate_slots, vocabulary_size)

    def _plate_length(scores):
        """Count, per plate, the slots whose argmax differs from ``pad_token_index``."""
        is_char = ops.not_equal(ops.argmax(scores, axis=-1), pad_token_index)
        return ops.sum(ops.cast(is_char, "int32"), axis=-1)

    def plate_len_acc(y_true, y_pred):
        """
        Share of plates whose predicted length equals the ground-truth length exactly.
        """
        targets = ops.reshape(y_true, slot_shape)
        predictions = ops.reshape(ops.cast(y_pred, "float32"), slot_shape)
        same_length = ops.equal(_plate_length(targets), _plate_length(predictions))
        return ops.mean(ops.cast(same_length, dtype="float32"))

    return plate_len_acc