matrice-analytics 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrice_analytics/__init__.py +28 -0
- matrice_analytics/boundary_drawing_internal/README.md +305 -0
- matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
- matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
- matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
- matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
- matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
- matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
- matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
- matrice_analytics/post_processing/README.md +455 -0
- matrice_analytics/post_processing/__init__.py +732 -0
- matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
- matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
- matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
- matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
- matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
- matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
- matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
- matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
- matrice_analytics/post_processing/config.py +146 -0
- matrice_analytics/post_processing/core/__init__.py +63 -0
- matrice_analytics/post_processing/core/base.py +704 -0
- matrice_analytics/post_processing/core/config.py +3291 -0
- matrice_analytics/post_processing/core/config_utils.py +925 -0
- matrice_analytics/post_processing/face_reg/__init__.py +43 -0
- matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
- matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
- matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
- matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
- matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
- matrice_analytics/post_processing/ocr/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
- matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
- matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
- matrice_analytics/post_processing/post_processor.py +1175 -0
- matrice_analytics/post_processing/test_cases/__init__.py +1 -0
- matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
- matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
- matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
- matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
- matrice_analytics/post_processing/test_cases/test_config.py +852 -0
- matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
- matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
- matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
- matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
- matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
- matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
- matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
- matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
- matrice_analytics/post_processing/usecases/__init__.py +267 -0
- matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
- matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
- matrice_analytics/post_processing/usecases/age_detection.py +842 -0
- matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
- matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
- matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
- matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
- matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
- matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
- matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
- matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
- matrice_analytics/post_processing/usecases/car_service.py +1601 -0
- matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
- matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
- matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
- matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
- matrice_analytics/post_processing/usecases/color/clip.py +660 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
- matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
- matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
- matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
- matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
- matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
- matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
- matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
- matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
- matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
- matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
- matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
- matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
- matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
- matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
- matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
- matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
- matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
- matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
- matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
- matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
- matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
- matrice_analytics/post_processing/usecases/leaf.py +821 -0
- matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
- matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
- matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
- matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
- matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
- matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
- matrice_analytics/post_processing/usecases/parking.py +787 -0
- matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
- matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
- matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
- matrice_analytics/post_processing/usecases/people_counting.py +706 -0
- matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
- matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
- matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
- matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
- matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
- matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
- matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
- matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
- matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
- matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
- matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
- matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
- matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
- matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
- matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
- matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
- matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
- matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
- matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
- matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
- matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
- matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
- matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
- matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
- matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
- matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
- matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
- matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
- matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
- matrice_analytics/post_processing/utils/__init__.py +150 -0
- matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
- matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
- matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
- matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
- matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
- matrice_analytics/post_processing/utils/color_utils.py +592 -0
- matrice_analytics/post_processing/utils/counting_utils.py +182 -0
- matrice_analytics/post_processing/utils/filter_utils.py +261 -0
- matrice_analytics/post_processing/utils/format_utils.py +293 -0
- matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
- matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
- matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
- matrice_analytics/py.typed +0 -0
- matrice_analytics-0.1.60.dist-info/METADATA +481 -0
- matrice_analytics-0.1.60.dist-info/RECORD +196 -0
- matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
- matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
- matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,553 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Layer blocks used in the OCR model.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from typing import Optional, Union
|
|
7
|
+
import keras
|
|
8
|
+
import numpy as np
|
|
9
|
+
from keras import ops
|
|
10
|
+
|
|
11
|
+
# pylint: disable=too-many-ancestors,abstract-method,attribute-defined-outside-init,arguments-differ
|
|
12
|
+
# pylint: disable=useless-parent-delegation,too-many-instance-attributes,too-many-arguments
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class AddCoords(keras.layers.Layer):
    """Add coords to a tensor, modified from paper: https://arxiv.org/abs/1807.03247

    Concatenates two coordinate channels (``xx`` then ``yy``) to the last axis of the input
    and, when ``with_r`` is True, a third radius channel ``sqrt(xx**2 + yy**2)``. The output
    therefore has ``c + 2`` channels (``c + 3`` with ``with_r``).

    NOTE(review): each coordinate channel is divided by the *other* axis' extent minus one.
    ``xx`` varies along axis 2 (values ``0..y_dim-1``) but is divided by ``x_dim - 1``.
    ``yy`` varies along axis 1 (values ``0..x_dim-1``) but is divided by ``y_dim - 1``.
    The result only spans [-1, 1] when ``x_dim == y_dim``, and a size-1 spatial axis
    divides by zero. This appears to mirror the paper's reference code. Changing it would
    presumably alter the features seen by already-trained models, so it is left unchanged.

    Args:
        with_r: If True, also append the radius channel.
    """

    # NOTE: no **kwargs, so base-layer arguments (name, dtype, ...) cannot be passed here.
    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def build(self, input_shape):
        # Assuming input_shape is (batch, height, width, channels)
        # The spatial sizes are captured statically and used in call(); presumably they
        # must be known (not None) at build time — TODO confirm for dynamic-size inputs.
        self.x_dim = input_shape[1]
        self.y_dim = input_shape[2]

    def call(self, input_tensor):
        """Return ``input_tensor`` with coordinate channels concatenated on the last axis.

        input_tensor: (batch, x_dim, y_dim, c)
        """
        # The batch size is read dynamically so the layer works with an unknown batch dim.
        batch_size_tensor = ops.shape(input_tensor)[0]
        # (batch, x_dim, 1) of ones ...
        xx_ones = ops.ones([batch_size_tensor, self.x_dim])
        xx_ones = ops.expand_dims(xx_ones, -1)
        # ... times (batch, 1, y_dim) holding 0..y_dim-1 ...
        xx_range = ops.tile(ops.expand_dims(ops.arange(self.y_dim), 0), [batch_size_tensor, 1])
        xx_range = ops.expand_dims(xx_range, 1)
        # ... gives (batch, x_dim, y_dim) whose [.., i, j] entry is j; then add a channel axis.
        xx_channel = ops.matmul(xx_ones, xx_range)
        xx_channel = ops.expand_dims(xx_channel, -1)
        # Same construction transposed: (batch, x_dim, 1) range times (batch, 1, y_dim) ones
        # gives (batch, x_dim, y_dim) whose [.., i, j] entry is i.
        yy_ones = ops.ones([batch_size_tensor, self.y_dim])
        yy_ones = ops.expand_dims(yy_ones, 1)
        yy_range = ops.tile(ops.expand_dims(ops.arange(self.x_dim), 0), [batch_size_tensor, 1])

        yy_range = ops.expand_dims(yy_range, -1)
        yy_channel = ops.matmul(yy_range, yy_ones)
        yy_channel = ops.expand_dims(yy_channel, -1)
        # Scale to roughly [0, 1] and then to roughly [-1, 1] (see the class NOTE on the
        # cross-axis divisors).
        xx_channel = ops.cast(xx_channel, "float32") / (self.x_dim - 1)
        yy_channel = ops.cast(yy_channel, "float32") / (self.y_dim - 1)
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1
        ret = ops.concatenate([input_tensor, xx_channel, yy_channel], axis=-1)
        if self.with_r:
            # Optional radial-distance channel.
            rr = ops.sqrt(ops.square(xx_channel) + ops.square(yy_channel))
            ret = ops.concatenate([ret, rr], axis=-1)
        return ret
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class CoordConv2D(keras.layers.Layer):
    """CoordConv2D layer as in the paper, modified from paper: https://arxiv.org/abs/1807.03247

    Appends coordinate channels to the input with ``AddCoords`` and then applies an ordinary
    ``keras.layers.Conv2D``.

    Args:
        with_r: Forwarded to ``AddCoords``; when True a radius channel is appended as well.
        **conv_kwargs: Forwarded verbatim to ``keras.layers.Conv2D`` and flattened back into
            the top-level dict returned by ``get_config``.

    NOTE(review): ``super().__init__()`` receives no arguments, so base-layer keywords such as
    ``name`` or ``dtype`` end up in ``conv_kwargs`` and configure the inner Conv2D instead of
    this wrapper.
    """

    def __init__(self, with_r: bool = False, **conv_kwargs):
        super().__init__()
        self.with_r = with_r
        # Shallow copy of the convolution arguments, echoed by get_config().
        self.conv_kwargs = dict(conv_kwargs)
        # Sub-layers are created in a fixed order: coordinates first, then the convolution.
        self.addcoords = AddCoords(with_r=with_r)
        self.conv = keras.layers.Conv2D(**conv_kwargs)

    def call(self, inputs):
        with_coords = self.addcoords(inputs)
        return self.conv(with_coords)

    def get_config(self):
        config = super().get_config()
        config["with_r"] = self.with_r
        # Conv2D arguments are stored at the top level of the config, next to ``with_r``.
        config.update(self.conv_kwargs)
        return config
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _build_binomial_filter(filter_size: int) -> np.ndarray:
|
|
79
|
+
"""Builds and returns the normalized binomial filter according to `filter_size`."""
|
|
80
|
+
if filter_size == 1:
|
|
81
|
+
binomial_filter = np.array([1.0])
|
|
82
|
+
elif filter_size == 2:
|
|
83
|
+
binomial_filter = np.array([1.0, 1.0])
|
|
84
|
+
elif filter_size == 3:
|
|
85
|
+
binomial_filter = np.array([1.0, 2.0, 1.0])
|
|
86
|
+
elif filter_size == 4:
|
|
87
|
+
binomial_filter = np.array([1.0, 3.0, 3.0, 1.0])
|
|
88
|
+
elif filter_size == 5:
|
|
89
|
+
binomial_filter = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
|
|
90
|
+
elif filter_size == 6:
|
|
91
|
+
binomial_filter = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
|
|
92
|
+
elif filter_size == 7:
|
|
93
|
+
binomial_filter = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
|
|
94
|
+
else:
|
|
95
|
+
raise ValueError(f"Filter size not supported, got {filter_size}")
|
|
96
|
+
|
|
97
|
+
binomial_filter = binomial_filter[:, np.newaxis] * binomial_filter[np.newaxis, :]
|
|
98
|
+
binomial_filter = binomial_filter / np.sum(binomial_filter)
|
|
99
|
+
|
|
100
|
+
return binomial_filter
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class MaxBlurPooling2D(keras.layers.Layer):
    """Anti-aliased max pooling for channels-last 4D inputs.

    A stride-1 max pool of size ``pool_size`` is followed by a fixed (non-trainable)
    depthwise binomial blur applied with stride ``pool_size``. The downsampling therefore
    happens in the blur step rather than in the max pool.

    Args:
        pool_size: Max-pool window size and stride (downsampling factor) of the blur.
        filter_size: Size of the square binomial blur kernel (see ``_build_binomial_filter``).
        padding: Padding mode for both the max pool and the blur (``"same"`` or ``"valid"``).
        **kwargs: Standard ``keras.layers.Layer`` arguments.
    """

    def __init__(self, pool_size: int = 2, filter_size: int = 3, padding: str = "same", **kwargs):
        self.pool_size = pool_size
        self.blur_kernel = None
        self.filter_size = filter_size
        self.padding = padding

        super().__init__(**kwargs)

    def build(self, input_shape):
        # Assumes channels-last input: (batch, height, width, channels).
        binomial_filter = _build_binomial_filter(filter_size=self.filter_size)
        # Every channel gets its own copy of the same blur kernel (depth multiplier 1).
        binomial_filter = np.repeat(binomial_filter, input_shape[3])
        # Maybe this should be channel first/last agnostic
        binomial_filter = np.reshape(
            binomial_filter, (self.filter_size, self.filter_size, input_shape[3], 1)
        )
        blur_init = keras.initializers.constant(binomial_filter)

        self.blur_kernel = self.add_weight(
            name="blur_kernel",
            shape=(self.filter_size, self.filter_size, input_shape[3], 1),
            initializer=blur_init,
            trainable=False,
        )

        super().build(input_shape)

    def call(self, x):
        # Dense (stride-1) max pooling keeps the spatial size for "same" padding ...
        x = ops.max_pool(
            x,
            (self.pool_size, self.pool_size),
            strides=(1, 1),
            padding=self.padding,
        )
        # ... and the strided blur performs the actual downsampling.
        x = ops.depthwise_conv(
            x, self.blur_kernel, padding=self.padding, strides=(self.pool_size, self.pool_size)
        )

        return x

    def compute_output_shape(self, input_shape):
        # Previously hard-coded ``ceil(dim / 2)``, which was wrong for pool_size != 2,
        # for "valid" padding, and crashed on unknown (None) spatial dims.
        return (
            input_shape[0],
            self._pooled_length(input_shape[1]),
            self._pooled_length(input_shape[2]),
            input_shape[3],
        )

    def _pooled_length(self, length):
        """Spatial length after the stride-1 max pool and the strided blur (None stays None)."""
        if length is None:
            return None
        if str(self.padding).lower() == "valid":
            # Stride-1 "valid" max pool shrinks by pool_size - 1; the blur then uses a
            # filter_size kernel with stride pool_size and no padding.
            pooled = length - self.pool_size + 1
            return (pooled - self.filter_size) // self.pool_size + 1
        # "same": the max pool keeps the length, the strided blur yields ceil(length / stride).
        return int(np.ceil(length / self.pool_size))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "pool_size": self.pool_size,
                "filter_size": self.filter_size,
                "padding": self.padding,
            }
        )
        return config
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class SqueezeExcite(keras.layers.Layer):
    """Channel-wise squeeze-and-excitation gate, see https://arxiv.org/abs/1709.01507

    The input is globally average-pooled and passed through a ReLU bottleneck Dense and a
    sigmoid Dense. The resulting per-channel gate multiplies the original input.

    Note: this was taken from https://keras.io/examples/vision/patch_convnet.

    Args:
        ratio: Reduction ratio; the bottleneck width is ``int(channels // ratio)``.
    """

    def __init__(self, ratio: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def get_config(self):
        return {**super().get_config(), "ratio": self.ratio}

    def build(self, input_shape):
        channels = input_shape[-1]
        bottleneck_units = int(channels // self.ratio)
        # Sub-layers are created in a fixed order: squeeze, reduction, excite, multiply.
        self.squeeze = keras.layers.GlobalAveragePooling2D(keepdims=True)
        self.reduction = keras.layers.Dense(
            units=bottleneck_units, activation="relu", use_bias=False
        )
        self.excite = keras.layers.Dense(units=channels, activation="sigmoid", use_bias=False)
        self.multiply = keras.layers.Multiply()

    def call(self, x):
        gate = self.excite(self.reduction(self.squeeze(x)))
        return self.multiply([x, gate])
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class DyT(keras.layers.Layer):
    """
    Dynamic Tanh (DyT): an element-wise drop-in replacement for normalization layers in
    Transformers.

    Computes ``tanh(alpha * x) * weight + bias``. ``alpha`` is a learnable scalar, while
    ``weight`` and ``bias`` are learnable per-channel vectors over the last axis.

    Paper: https://arxiv.org/abs/2503.10622.

    Args:
        alpha_init_value: Initial value of the scalar ``alpha``.
    """

    def __init__(self, alpha_init_value: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.alpha_init_value = alpha_init_value

    def build(self, input_shape):
        num_channels = int(input_shape[-1])
        # Weights are created in a fixed order: alpha, weight, bias.
        self.alpha = self.add_weight(
            name="alpha",
            shape=(),
            initializer=keras.initializers.Constant(self.alpha_init_value),
            trainable=True,
        )
        self.weight = self.add_weight(
            name="weight", shape=(num_channels,), initializer="ones", trainable=True
        )
        self.bias = self.add_weight(
            name="bias", shape=(num_channels,), initializer="zeros", trainable=True
        )
        super().build(input_shape)

    def call(self, x):
        squashed = keras.ops.tanh(self.alpha * x)
        return squashed * self.weight + self.bias

    def get_config(self):
        return {**super().get_config(), "alpha_init_value": self.alpha_init_value}
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def build_norm_layer(norm_type) -> keras.layers.Layer:
    """Create a fresh normalization layer for ``norm_type``.

    Args:
        norm_type: One of ``"layer_norm"``, ``"rms_norm"`` or ``"dyt"``.

    Returns:
        A new, unbuilt Keras layer.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported names.
    """
    # Factories are only invoked for the matching entry, so only one layer is constructed.
    candidates = (
        ("layer_norm", lambda: keras.layers.LayerNormalization(epsilon=1e-5)),
        ("rms_norm", lambda: keras.layers.RMSNormalization(epsilon=1e-5)),
        ("dyt", lambda: DyT(alpha_init_value=0.5)),
    )
    for name, factory in candidates:
        if norm_type == name:
            return factory()
    raise ValueError(f"Unknown norm_type {norm_type}")
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class PositionEmbedding(keras.layers.Layer):
    """Learned absolute position embeddings.

    ``call`` returns the rows of the embedding table for the input's positions, broadcast to
    the input's shape. The input values themselves are not added; callers do that.

    Args:
        sequence_length: Maximum number of positions (rows of the embedding table).
        initializer: Initializer for the embedding table.
        **kwargs: Standard ``keras.layers.Layer`` arguments.

    Raises:
        ValueError: If ``sequence_length`` is None.
    """

    def __init__(self, sequence_length, initializer="glorot_uniform", **kwargs):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an Integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def build(self, input_shape):
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, input_shape[-1]],
            initializer=self.initializer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs, start_index=0):
        input_shape = keras.ops.shape(inputs)
        num_positions = input_shape[-2]
        num_features = input_shape[-1]
        table = keras.ops.convert_to_tensor(self.position_embeddings)
        # The input may be shorter than ``sequence_length``: keep only the rows it needs,
        # starting at ``start_index``.
        table = keras.ops.slice(table, (start_index, 0), (num_positions, num_features))
        return keras.ops.broadcast_to(table, input_shape)

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        return {
            **super().get_config(),
            "sequence_length": self.sequence_length,
            "initializer": keras.initializers.serialize(self.initializer),
        }
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class TokenReducer(keras.layers.Layer):
    """Reduce a token sequence to ``num_tokens`` tokens using learned queries.

    A set of learned query tokens cross-attends over the input tokens, which serve as both
    keys and values.

    Args:
        num_tokens: Number of output tokens (learned query vectors).
        projection_dim: Width of each learned query and of the output tokens; also used as
            ``key_dim`` for every attention head.
        num_heads: Number of attention heads.
        **kwargs: Standard ``keras.layers.Layer`` arguments.
    """

    def __init__(self, num_tokens, projection_dim, num_heads=2, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = num_tokens
        self.projection_dim = projection_dim
        self.num_heads = num_heads
        self.attn = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)

    def build(self, input_shape):
        self.query_tokens = self.add_weight(
            shape=(1, self.num_tokens, self.projection_dim),
            initializer="random_normal",
            trainable=True,
            name="query_tokens",
        )
        # input_shape is expected to be (batch_size, seq_length, feature_dim)
        seq_length = input_shape[1]
        if seq_length is None:
            raise ValueError("Input sequence length must be defined (not None).")
        # Keys/values are the inputs themselves. Size their projections from the actual
        # input feature dim instead of assuming it equals projection_dim; otherwise any
        # other width fails at call time. Falls back to projection_dim if unknown.
        feature_dim = input_shape[-1]
        if feature_dim is None:
            feature_dim = self.projection_dim
        self.attn.build(
            query_shape=(1, self.num_tokens, self.projection_dim),
            value_shape=(1, seq_length, feature_dim),
        )
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.num_tokens, self.projection_dim

    def call(self, inputs):
        """Cross-attend the learned queries over ``inputs``.

        inputs: Tensor of shape (batch_size, seq_length, feature_dim)
        returns: Tensor of shape (batch_size, num_tokens, projection_dim)
        """
        batch_size = keras.ops.shape(inputs)[0]
        # Tile the learned query tokens for each example in the batch.
        query_tokens = keras.ops.tile(self.query_tokens, [batch_size, 1, 1])
        # Perform cross-attention where the queries are the learned tokens and keys/values are
        # the input tokens.
        reduced_tokens = self.attn(query=query_tokens, key=inputs, value=inputs)
        return reduced_tokens

    def get_config(self):
        cfg = super().get_config()
        cfg.update(
            {
                "num_tokens": self.num_tokens,
                "projection_dim": self.projection_dim,
                "num_heads": self.num_heads,
            }
        )
        return cfg
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class StochasticDepth(keras.layers.Layer):
    """Stochastic depth ("drop path"): randomly zeroes whole samples during training.

    During training each sample in the batch is kept with probability ``1 - drop_prob``, and
    kept samples are rescaled by ``1 / (1 - drop_prob)``. Outside training the input is
    returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample. Values ``<= 0`` disable the layer;
            values ``>= 1`` drop every sample (the output is all zeros).
        **kwargs: Standard ``keras.layers.Layer`` arguments.
    """

    def __init__(self, drop_prob: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
        # NOTE(review): every instance starts from the same fixed seed. Stacked
        # StochasticDepth layers may therefore draw identical masks for equally shaped batches.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=None):
        if not training or self.drop_prob <= 0:
            # Inference, or nothing to drop: identity (no random draw needed).
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0:
            # Every sample is dropped; the rescaling below would compute x / 0 * 0 = NaN.
            return keras.ops.zeros_like(x)
        # One mask value per sample, broadcast over all remaining axes.
        shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
        random_tensor = keep_prob + keras.random.uniform(
            shape, 0, 1, seed=self.seed_generator, dtype=x.dtype
        )
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, otherwise 0.
        random_tensor = keras.ops.floor(random_tensor)
        return (x / keep_prob) * random_tensor

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"drop_prob": self.drop_prob})
        return cfg
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class MLP(keras.layers.Layer):
    """Stack of ``Dense -> Dropout`` pairs, one pair per entry of ``hidden_units``.

    Args:
        hidden_units: Output width of each Dense layer, applied in order.
        dropout_rate: Dropout rate applied after every Dense layer.
        activation: Activation used by every Dense layer.
        use_bias: Whether the Dense layers use a bias term.
        **kwargs: Standard ``keras.layers.Layer`` arguments.
    """

    def __init__(
        self,
        hidden_units,
        dropout_rate: float = 0.1,
        activation: str = "gelu",
        use_bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_units = list(hidden_units)
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.use_bias = use_bias

        self.dense_layers = [
            keras.layers.Dense(units, activation=self.activation, use_bias=self.use_bias)
            for units in self.hidden_units
        ]
        self.dropout_layers = [keras.layers.Dropout(self.dropout_rate) for _ in self.hidden_units]

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, training=None):
        x = inputs
        # Plain zip: ``zip(..., strict=True)`` only exists on Python >= 3.10 and raised
        # TypeError on older interpreters. Both lists are built from ``hidden_units`` in
        # __init__, so their lengths always match.
        for dense, drop in zip(self.dense_layers, self.dropout_layers):
            x = dense(x)
            x = drop(x, training=training)
        return x

    def get_config(self):
        cfg = super().get_config()
        cfg.update(
            {
                "hidden_units": self.hidden_units,
                "dropout_rate": self.dropout_rate,
                "activation": self.activation,
                "use_bias": self.use_bias,
            }
        )
        return cfg
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class VocabularyProjection(keras.layers.Layer):
    """
    Project features onto the character vocabulary with a softmax ``Dense`` layer,
    optionally preceded by ``Dropout``.

    Args:
        vocabulary_size: Number of output classes (units of the softmax layer).
        dropout_rate: Dropout applied before the projection; ``None`` disables it.
    """

    def __init__(self, vocabulary_size: int, dropout_rate: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size
        self.dropout_rate = dropout_rate
        if dropout_rate is None:
            self.dropout = None
        else:
            self.dropout = keras.layers.Dropout(dropout_rate)
        self.classifier = keras.layers.Dense(vocabulary_size, activation="softmax")

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, x, training=None):
        features = x if self.dropout is None else self.dropout(x, training=training)
        return self.classifier(features)

    def get_config(self):
        return {
            **super().get_config(),
            "vocabulary_size": self.vocabulary_size,
            "dropout_rate": self.dropout_rate,
        }
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class TransformerBlock(keras.layers.Layer):
    """
    Pre-norm Transformer encoder block.

    Computes ``x = x + DropPath(MHA(Norm1(x)))`` followed by
    ``x = x + DropPath(MLP(Norm2(x)))``.

    Args:
        projection_dim: ``key_dim`` of the multi-head attention layer.
        num_heads: Number of attention heads.
        mlp_units: Hidden widths of the MLP sub-block. The residual add presumably needs the
            last entry to match the input feature size; confirm at the call site.
        attention_dropout: Dropout rate inside the attention layer.
        mlp_dropout: Dropout rate applied after each MLP Dense layer.
        drop_path_rate: Stochastic-depth rate used on both residual branches.
        norm_type: Forwarded to ``build_norm_layer`` for both normalization layers.
        activation: Activation of the MLP Dense layers.
    """

    def __init__(
        self,
        projection_dim: int,
        num_heads: int,
        mlp_units: Sequence[int],
        attention_dropout: float,
        mlp_dropout: float,
        drop_path_rate: float,
        norm_type: Optional[str] = "layer_norm",
        activation: str = "gelu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() does not depend on internal
        # attributes of keras.layers.MultiHeadAttention.
        self.projection_dim = projection_dim
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.mlp_dropout = mlp_dropout
        self.drop_path_rate = drop_path_rate
        self.norm_type = norm_type
        self.activation = activation

        self.norm1 = build_norm_layer(norm_type)
        self.attn = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=attention_dropout
        )
        self.drop1 = StochasticDepth(drop_path_rate)
        self.norm2 = build_norm_layer(norm_type)
        self.mlp = MLP(hidden_units=mlp_units, dropout_rate=mlp_dropout, activation=activation)
        self.drop2 = StochasticDepth(drop_path_rate)

    def build(self, input_shape) -> None:
        super().build(input_shape)

    def call(self, x, training=None):
        # Residual sums use keras.ops.add rather than instantiating a new keras.layers.Add
        # layer on every forward pass.
        # 1. MHA + residual
        y = self.norm1(x)
        y = self.attn(y, y)
        y = self.drop1(y, training=training)
        x = keras.ops.add(x, y)

        # 2. MLP + residual
        y = self.norm2(x)
        y = self.mlp(y, training=training)
        y = self.drop2(y, training=training)
        return keras.ops.add(x, y)

    def get_config(self):
        cfg = super().get_config()
        cfg.update(
            {
                "projection_dim": self.projection_dim,
                "num_heads": self.num_heads,
                "mlp_units": self.mlp.hidden_units,
                "mlp_dropout": self.mlp_dropout,
                "attention_dropout": self.attention_dropout,
                "drop_path_rate": self.drop_path_rate,
                "norm_type": self.norm_type,
                "activation": self.activation,
            }
        )
        return cfg
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
@keras.saving.register_keras_serializable(package="fast_plate_ocr")
class PatchExtractor(keras.layers.Layer):
    """
    Extract non-overlapping patches from an image and flatten them.

    Input is unpacked as ``(batch, height, width, channels)``, i.e. channels-last. The output
    has shape ``(batch, num_patches_h * num_patches_w, patch_size * patch_size * channels)``.

    Modified from https://keras.io/examples/vision/image_classification_with_vision_transformer.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # Only spatial and channel sizes are needed. The batch axis is left to reshape (-1),
        # so a dynamic or unknown batch size is never threaded into the target shape.
        _, height, width, channels = ops.shape(images)

        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size

        patches = keras.ops.image.extract_patches(images, size=self.patch_size)
        return ops.reshape(
            patches,
            (
                -1,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Loss functions for training license plate recognition models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from keras import losses, ops
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def cce_loss(vocabulary_size: int, label_smoothing: float = 0.01):
    """
    Build a categorical cross-entropy loss evaluated per character slot.

    Args:
        vocabulary_size: Number of classes per slot. Both tensors are flattened to
            ``(-1, vocabulary_size)``, so every row is one slot.
        label_smoothing: Label smoothing forwarded to ``categorical_crossentropy``.

    Returns:
        A ``cce(y_true, y_pred)`` callable returning the mean cross-entropy over all slots.
        ``y_pred`` is treated as probabilities (``from_logits=False``).
    """
    flat_shape = (-1, vocabulary_size)

    def cce(y_true, y_pred):
        """
        Mean categorical cross-entropy over every character slot in the batch.
        """
        per_slot_loss = losses.categorical_crossentropy(
            ops.reshape(y_true, flat_shape),
            ops.reshape(y_pred, flat_shape),
            from_logits=False,
            label_smoothing=label_smoothing,
        )
        return ops.mean(per_slot_loss)

    return cce
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def focal_cce_loss(
    vocabulary_size: int,
    alpha: float = 0.25,
    gamma: float = 2.0,
    label_smoothing: float = 0.01,
):
    """
    Build a categorical focal cross-entropy loss evaluated per character slot.

    Args:
        vocabulary_size: Number of classes per slot. Both tensors are flattened to
            ``(-1, vocabulary_size)``, so every row is one slot.
        alpha: Focal-loss ``alpha`` forwarded to ``categorical_focal_crossentropy``.
        gamma: Focal-loss ``gamma`` forwarded to ``categorical_focal_crossentropy``.
        label_smoothing: Label smoothing forwarded to ``categorical_focal_crossentropy``.

    Returns:
        A ``cce(y_true, y_pred)`` callable returning the mean focal loss over all slots.
        ``y_pred`` is treated as probabilities (``from_logits=False``).
    """
    flat_shape = (-1, vocabulary_size)

    def cce(y_true, y_pred):
        """
        Mean categorical focal cross-entropy over every character slot in the batch.
        """
        flat_true = ops.reshape(y_true, flat_shape)
        flat_pred = ops.reshape(y_pred, flat_shape)
        per_slot_loss = losses.categorical_focal_crossentropy(
            flat_true,
            flat_pred,
            alpha=alpha,
            gamma=gamma,
            from_logits=False,
            label_smoothing=label_smoothing,
        )
        return ops.mean(per_slot_loss)

    return cce
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Evaluation metrics for license plate recognition models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from keras import metrics, ops
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def cat_acc_metric(max_plate_slots: int, vocabulary_size: int):
    """
    Build a per-character (slot-level) categorical accuracy metric.

    Args:
        max_plate_slots: Number of character slots per plate.
        vocabulary_size: Number of classes per slot.

    Returns:
        A ``cat_acc(y_true, y_pred)`` callable.
    """
    plate_shape = (-1, max_plate_slots, vocabulary_size)

    def cat_acc(y_true, y_pred):
        """
        Fraction of character slots whose argmax prediction matches the ground truth.

        Unlike ``plate_acc``, a partially correct plate earns partial credit. For example,
        for a 6-slot plate, ground truth ABC123 vs prediction ABC133 scores 5/6 (83.3%)
        rather than 0%. All ``max_plate_slots`` positions, padding included, are averaged.
        """
        slot_hits = metrics.categorical_accuracy(
            ops.reshape(y_true, plate_shape),
            ops.reshape(y_pred, plate_shape),
        )
        return ops.mean(slot_hits)

    return cat_acc
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def plate_acc_metric(max_plate_slots: int, vocabulary_size: int):
    """
    Build a whole-plate accuracy metric.

    Args:
        max_plate_slots: Number of character slots per plate.
        vocabulary_size: Number of classes per slot.

    Returns:
        A ``plate_acc(y_true, y_pred)`` callable.
    """
    plate_shape = (-1, max_plate_slots, vocabulary_size)

    def plate_acc(y_true, y_pred):
        """
        Fraction of plates whose every slot is predicted correctly (argmax match).

        Ground truth 'ABC 123' predicted as 'ABC 123' scores 1; predicted as 'ABD 123' it
        scores 0.
        """
        true_chars = ops.argmax(ops.reshape(y_true, plate_shape), axis=-1)
        pred_probs = ops.cast(ops.reshape(y_pred, plate_shape), dtype="float32")
        pred_chars = ops.argmax(pred_probs, axis=-1)
        slot_matches = ops.equal(true_chars, pred_chars)
        plate_matches = ops.all(slot_matches, axis=-1, keepdims=False)
        return ops.mean(ops.cast(plate_matches, dtype="float32"))

    return plate_acc
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def top_3_k_metric(vocabulary_size: int):
    """
    Build a per-character top-3 categorical accuracy metric.

    Args:
        vocabulary_size: Number of classes per slot.

    Returns:
        A ``top_3_k(y_true, y_pred)`` callable.
    """
    flat_shape = (-1, vocabulary_size)

    def top_3_k(y_true, y_pred):
        """
        Fraction of character slots whose true class is among the 3 highest-probability
        predictions.
        """
        # Flatten to 2-D: one row per character slot.
        flat_true = ops.reshape(y_true, flat_shape)
        flat_pred = ops.cast(ops.reshape(y_pred, flat_shape), dtype="float32")
        slot_hits = metrics.top_k_categorical_accuracy(flat_true, flat_pred, k=3)
        return ops.mean(slot_hits)

    return top_3_k
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def plate_len_acc_metric(
    max_plate_slots: int,
    vocabulary_size: int,
    pad_token_index: int,
):
    """
    Build a plate-length accuracy metric.

    Args:
        max_plate_slots: Number of character slots per plate.
        vocabulary_size: Number of classes per slot.
        pad_token_index: Class index of the padding token; every other class counts as a
            character.

    Returns:
        A ``plate_len_acc(y_true, y_pred)`` callable.
    """
    plate_shape = (-1, max_plate_slots, vocabulary_size)

    def _num_characters(slot_scores):
        # Count the slots whose argmax class is not the padding token, per plate.
        slot_ids = ops.argmax(slot_scores, axis=-1)
        is_character = ops.cast(ops.not_equal(slot_ids, pad_token_index), "int32")
        return ops.sum(is_character, axis=-1)

    def plate_len_acc(y_true, y_pred):
        """
        Fraction of plates whose predicted length (non-pad slot count) equals the
        ground-truth length exactly.
        """
        true_len = _num_characters(ops.reshape(y_true, plate_shape))
        pred_len = _num_characters(ops.reshape(ops.cast(y_pred, "float32"), plate_shape))
        return ops.mean(ops.cast(ops.equal(true_len, pred_len), dtype="float32"))

    return plate_len_acc
|