matrice-analytics 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. matrice_analytics/__init__.py +28 -0
  2. matrice_analytics/boundary_drawing_internal/README.md +305 -0
  3. matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
  4. matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
  5. matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
  6. matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
  7. matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
  8. matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
  9. matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
  10. matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
  11. matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
  12. matrice_analytics/post_processing/README.md +455 -0
  13. matrice_analytics/post_processing/__init__.py +732 -0
  14. matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
  15. matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
  16. matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
  17. matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
  18. matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
  19. matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
  20. matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
  21. matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
  22. matrice_analytics/post_processing/config.py +146 -0
  23. matrice_analytics/post_processing/core/__init__.py +63 -0
  24. matrice_analytics/post_processing/core/base.py +704 -0
  25. matrice_analytics/post_processing/core/config.py +3291 -0
  26. matrice_analytics/post_processing/core/config_utils.py +925 -0
  27. matrice_analytics/post_processing/face_reg/__init__.py +43 -0
  28. matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
  29. matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
  30. matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
  31. matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
  32. matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
  33. matrice_analytics/post_processing/ocr/__init__.py +0 -0
  34. matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
  35. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
  36. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
  37. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
  38. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
  39. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
  40. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
  41. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
  42. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
  43. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
  44. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
  45. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
  46. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
  47. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
  48. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
  49. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
  50. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
  51. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
  52. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
  53. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
  54. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
  55. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
  56. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
  57. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
  58. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
  59. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
  60. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
  61. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
  62. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
  63. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
  64. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
  65. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
  66. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
  67. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
  68. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
  69. matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
  70. matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
  71. matrice_analytics/post_processing/post_processor.py +1175 -0
  72. matrice_analytics/post_processing/test_cases/__init__.py +1 -0
  73. matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
  74. matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
  75. matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
  76. matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
  77. matrice_analytics/post_processing/test_cases/test_config.py +852 -0
  78. matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
  79. matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
  80. matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
  81. matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
  82. matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
  83. matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
  84. matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
  85. matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
  86. matrice_analytics/post_processing/usecases/__init__.py +267 -0
  87. matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
  88. matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
  89. matrice_analytics/post_processing/usecases/age_detection.py +842 -0
  90. matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
  91. matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
  92. matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
  93. matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
  94. matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
  95. matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
  96. matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
  97. matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
  98. matrice_analytics/post_processing/usecases/car_service.py +1601 -0
  99. matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
  100. matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
  101. matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
  102. matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
  103. matrice_analytics/post_processing/usecases/color/clip.py +660 -0
  104. matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
  105. matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
  106. matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
  107. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
  108. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
  109. matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
  110. matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
  111. matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
  112. matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
  113. matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
  114. matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
  115. matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
  116. matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
  117. matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
  118. matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
  119. matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
  120. matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
  121. matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
  122. matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
  123. matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
  124. matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
  125. matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
  126. matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
  127. matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
  128. matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
  129. matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
  130. matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
  131. matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
  132. matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
  133. matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
  134. matrice_analytics/post_processing/usecases/leaf.py +821 -0
  135. matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
  136. matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
  137. matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
  138. matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
  139. matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
  140. matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
  141. matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
  142. matrice_analytics/post_processing/usecases/parking.py +787 -0
  143. matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
  144. matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
  145. matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
  146. matrice_analytics/post_processing/usecases/people_counting.py +706 -0
  147. matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
  148. matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
  149. matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
  150. matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
  151. matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
  152. matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
  153. matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
  154. matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
  155. matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
  156. matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
  157. matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
  158. matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
  159. matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
  160. matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
  161. matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
  162. matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
  163. matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
  164. matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
  165. matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
  166. matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
  167. matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
  168. matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
  169. matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
  170. matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
  171. matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
  172. matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
  173. matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
  174. matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
  175. matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
  176. matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
  177. matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
  178. matrice_analytics/post_processing/utils/__init__.py +150 -0
  179. matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
  180. matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
  181. matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
  182. matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
  183. matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
  184. matrice_analytics/post_processing/utils/color_utils.py +592 -0
  185. matrice_analytics/post_processing/utils/counting_utils.py +182 -0
  186. matrice_analytics/post_processing/utils/filter_utils.py +261 -0
  187. matrice_analytics/post_processing/utils/format_utils.py +293 -0
  188. matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
  189. matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
  190. matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
  191. matrice_analytics/py.typed +0 -0
  192. matrice_analytics-0.1.60.dist-info/METADATA +481 -0
  193. matrice_analytics-0.1.60.dist-info/RECORD +196 -0
  194. matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
  195. matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
  196. matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
@@ -0,0 +1,447 @@
1
+ """
2
+ Script for training the License Plate OCR models.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ import pathlib
9
+ import shutil
10
+ from datetime import datetime
11
+ from typing import Literal, Optional
12
+
13
+ import albumentations as A
14
+ import click
15
+ import keras
16
+ from keras.src.callbacks import (
17
+ CSVLogger,
18
+ EarlyStopping,
19
+ ModelCheckpoint,
20
+ SwapEMAWeights,
21
+ TensorBoard,
22
+ TerminateOnNaN,
23
+ )
24
+ from keras.src.optimizers import AdamW
25
+
26
+ import fast_plate_ocr.train.model.model_builders
27
+ from fast_plate_ocr.cli.utils import print_params, print_train_details
28
+ from fast_plate_ocr.train.data.augmentation import (
29
+ default_train_augmentation,
30
+ )
31
+ from fast_plate_ocr.train.data.dataset import PlateRecognitionPyDataset
32
+ from fast_plate_ocr.train.model.config import load_plate_config_from_yaml
33
+ from fast_plate_ocr.train.model.loss import cce_loss, focal_cce_loss
34
+ from fast_plate_ocr.train.model.metric import (
35
+ cat_acc_metric,
36
+ plate_acc_metric,
37
+ plate_len_acc_metric,
38
+ top_3_k_metric,
39
+ )
40
+ from fast_plate_ocr.train.model.model_schema import load_model_config_from_yaml
41
+
42
+ # ruff: noqa: PLR0913
43
+ # pylint: disable=too-many-arguments,too-many-locals
44
+
45
+
46
EVAL_METRICS: dict[str, Literal["max", "min", "auto"]] = {
    # Accuracy-style metrics: larger values are better, so monitor them in "max" mode.
    **dict.fromkeys(
        ("val_plate_acc", "val_cat_acc", "val_top_3_k_acc", "val_plate_len_acc"),
        "max",
    ),
    # The validation loss is better when smaller, so it is monitored in "min" mode.
    "val_loss": "min",
}
"""Eval metric to monitor."""
54
+
55
+
56
@click.command(context_settings={"max_content_width": 120})
@click.option(
    "--model-config-file",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the YAML config that describes the model architecture.",
)
@click.option(
    "--plate-config-file",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the plate YAML config.",
)
@click.option(
    "--annotations",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path pointing to the train annotations CSV file.",
)
@click.option(
    "--val-annotations",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path pointing to the validation annotations CSV file.",
)
@click.option(
    "--validation-freq",
    default=1,
    show_default=True,
    type=int,
    help="Frequency (in epochs) at which to evaluate the validation data.",
)
@click.option(
    "--augmentation-path",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="YAML file pointing to the augmentation pipeline saved with Albumentations.save(...)",
)
@click.option(
    "--lr",
    default=0.001,
    show_default=True,
    type=float,
    help="Initial learning rate.",
)
@click.option(
    "--final-lr-factor",
    default=1e-2,
    show_default=True,
    type=float,
    help="Final learning rate factor for the cosine decay scheduler. It's the fraction of"
    " the initial learning rate that remains after decay.",
)
@click.option(
    "--warmup-fraction",
    default=0.05,
    show_default=True,
    type=float,
    help="Fraction of total training steps to linearly warm up.",
)
@click.option(
    "--weight-decay",
    default=0.001,
    show_default=True,
    type=float,
    help="Weight decay for the AdamW optimizer.",
)
@click.option(
    "--clipnorm",
    default=1.0,
    show_default=True,
    type=float,
    help="Gradient clipping norm value for the AdamW optimizer.",
)
@click.option(
    "--loss",
    default="cce",
    type=click.Choice(["cce", "focal_cce"], case_sensitive=False),
    show_default=True,
    help="Loss function to use during training.",
)
@click.option(
    "--focal-alpha",
    default=0.25,
    show_default=True,
    type=float,
    help="Alpha parameter for focal loss. Applicable only when '--loss' is 'focal_cce'.",
)
@click.option(
    "--focal-gamma",
    default=2.0,
    show_default=True,
    type=float,
    help="Gamma parameter for focal loss. Applicable only when '--loss' is 'focal_cce'.",
)
@click.option(
    "--label-smoothing",
    default=0.01,
    show_default=True,
    type=float,
    help="Amount of label smoothing to apply.",
)
@click.option(
    "--mixed-precision-policy",
    default=None,
    type=click.Choice(["mixed_float16", "mixed_bfloat16", "float32"]),
    help=(
        "Optional mixed precision policy for training. Choose one of: mixed_float16, "
        "mixed_bfloat16, or float32. If not provided, Keras uses its default global policy."
    ),
)
@click.option(
    "--batch-size",
    default=64,
    show_default=True,
    type=int,
    help="Batch size for training.",
)
@click.option(
    "--workers",
    default=1,
    show_default=True,
    type=int,
    help="Number of worker threads/processes for parallel data loading.",
)
@click.option(
    "--use-multiprocessing/--no-use-multiprocessing",
    default=False,
    show_default=True,
    help="Use multiprocessing for data loading.",
)
@click.option(
    "--max-queue-size",
    default=10,
    show_default=True,
    type=int,
    help="Maximum queue size for dataset workers.",
)
@click.option(
    "--output-dir",
    default="./trained_models",
    type=click.Path(dir_okay=True, file_okay=False, path_type=pathlib.Path),
    help="Output directory where model will be saved.",
)
@click.option(
    "--epochs",
    default=300,
    show_default=True,
    type=int,
    help="Number of training epochs.",
)
@click.option(
    "--tensorboard",
    "-t",
    is_flag=True,
    help="Whether to use TensorBoard visualization tool.",
)
@click.option(
    "--tensorboard-dir",
    "-l",
    default="tensorboard_logs",
    show_default=True,
    type=click.Path(path_type=pathlib.Path),
    help="The path of the directory where to save the TensorBoard log files.",
)
@click.option(
    "--early-stopping-patience",
    default=100,
    show_default=True,
    type=int,
    help="Stop training when the early stopping metric doesn't improve for X epochs.",
)
@click.option(
    "--early-stopping-metric",
    default="val_plate_acc",
    show_default=True,
    type=click.Choice(list(EVAL_METRICS), case_sensitive=False),
    help="Metric to monitor for early stopping.",
)
@click.option(
    "--weights-path",
    type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path),
    help="Path to the pretrained model weights file.",
)
@click.option(
    "--use-ema/--no-use-ema",
    default=True,
    show_default=True,
    help=(
        "Whether to use exponential moving averages in the AdamW optimizer. "
        "Defaults to True; use --no-use-ema to disable."
    ),
)
@click.option(
    "--wd-ignore",
    default="bias,layer_norm",
    show_default=True,
    type=str,
    help="Comma-separated list of variable substrings to exclude from weight decay.",
)
@click.option(
    "--seed",
    type=int,
    help="Sets all random seeds (Python, NumPy, and backend framework, e.g. TF).",
)
@print_params(table_title="CLI Training Parameters", c1_title="Parameter", c2_title="Details")
def train(
    model_config_file: pathlib.Path,
    plate_config_file: pathlib.Path,
    annotations: pathlib.Path,
    val_annotations: pathlib.Path,
    validation_freq: int,
    augmentation_path: Optional[pathlib.Path],
    lr: float,
    final_lr_factor: float,
    warmup_fraction: float,
    weight_decay: float,
    clipnorm: float,
    loss: str,
    focal_alpha: float,
    focal_gamma: float,
    label_smoothing: float,
    mixed_precision_policy: Optional[str],
    batch_size: int,
    workers: int,
    use_multiprocessing: bool,
    max_queue_size: int,
    output_dir: pathlib.Path,
    epochs: int,
    tensorboard: bool,
    tensorboard_dir: pathlib.Path,
    early_stopping_patience: int,
    early_stopping_metric: str,
    weights_path: Optional[pathlib.Path],
    use_ema: bool,
    wd_ignore: str,
    seed: Optional[int],
) -> None:
    """
    Train the License Plate OCR model.
    \f
    Loads the plate and model YAML configs and builds the train/validation datasets. It then
    compiles the model with an AdamW optimizer, driven by a (optionally warmed-up) cosine
    decay schedule, and with the selected loss. Finally it fits the model. Checkpoints, the
    configs and hyper-parameters used, and a CSV training log are written into a timestamped
    sub-directory of ``output_dir``. Every argument is supplied by the click options above.

    Raises:
        ValueError: If ``loss`` is not one of the supported loss names.
    """
    if seed is not None:
        keras.utils.set_random_seed(seed)

    if mixed_precision_policy is not None:
        keras.mixed_precision.set_global_policy(mixed_precision_policy)

    plate_config = load_plate_config_from_yaml(plate_config_file)
    model_config = load_model_config_from_yaml(model_config_file)
    train_augmentation = (
        A.load(augmentation_path, data_format="yaml")
        if augmentation_path
        else default_train_augmentation(img_color_mode=plate_config.image_color_mode)
    )
    print_train_details(train_augmentation, plate_config.model_dump())

    train_dataset = PlateRecognitionPyDataset(
        annotations_file=annotations,
        transform=train_augmentation,
        plate_config=plate_config,
        batch_size=batch_size,
        shuffle=True,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
    )

    # Validation data is not augmented and is not shuffled.
    val_dataset = PlateRecognitionPyDataset(
        annotations_file=val_annotations,
        plate_config=plate_config,
        batch_size=batch_size,
        shuffle=False,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
    )

    # Train
    model = fast_plate_ocr.train.model.model_builders.build_model(model_config, plate_config)

    if weights_path:
        # skip_mismatch=True: layers whose shapes differ from the checkpoint are skipped
        # instead of raising.
        model.load_weights(weights_path, skip_mismatch=True)

    total_steps = epochs * len(train_dataset)
    warmup_steps = int(warmup_fraction * total_steps)
    # Keras' CosineDecay first runs `warmup_steps` of linear warmup and then `decay_steps`
    # of cosine decay. Decaying over `total_steps` would stretch the schedule past the end
    # of training, and the LR would never reach `final_lr_factor * lr`. So only the steps
    # left after warmup are decayed. The clamp avoids a zero-length decay phase when the
    # warmup covers the whole run.
    decay_steps = max(total_steps - warmup_steps, 1)

    cosine_decay = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=0.0 if warmup_steps > 0 else lr,
        decay_steps=decay_steps,
        alpha=final_lr_factor,
        warmup_steps=warmup_steps,
        warmup_target=lr if warmup_steps > 0 else None,
    )

    optimizer = AdamW(cosine_decay, weight_decay=weight_decay, clipnorm=clipnorm, use_ema=use_ema)
    optimizer.exclude_from_weight_decay(
        var_names=[name.strip() for name in wd_ignore.split(",") if name.strip()]
    )

    if loss == "cce":
        loss_fn = cce_loss(
            vocabulary_size=plate_config.vocabulary_size, label_smoothing=label_smoothing
        )
    elif loss == "focal_cce":
        loss_fn = focal_cce_loss(
            vocabulary_size=plate_config.vocabulary_size,
            alpha=focal_alpha,
            gamma=focal_gamma,
            label_smoothing=label_smoothing,
        )
    else:
        raise ValueError(f"Unsupported loss type: {loss}")

    model.compile(
        loss=loss_fn,
        jit_compile=False,
        optimizer=optimizer,
        metrics=[
            cat_acc_metric(
                max_plate_slots=plate_config.max_plate_slots,
                vocabulary_size=plate_config.vocabulary_size,
            ),
            plate_acc_metric(
                max_plate_slots=plate_config.max_plate_slots,
                vocabulary_size=plate_config.vocabulary_size,
            ),
            top_3_k_metric(vocabulary_size=plate_config.vocabulary_size),
            plate_len_acc_metric(
                max_plate_slots=plate_config.max_plate_slots,
                vocabulary_size=plate_config.vocabulary_size,
                pad_token_index=plate_config.pad_idx,
            ),
        ],
    )

    # Each run gets its own timestamped directory, so repeated runs never overwrite each other.
    output_dir /= datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): this template reads `val_plate_acc` from the epoch logs. With
    # --validation-freq > 1, epochs without validation may not provide it. Confirm that
    # ModelCheckpoint tolerates that in the installed Keras version.
    model_file_path = output_dir / "ckpt-epoch_{epoch:02d}-acc_{val_plate_acc:.3f}.keras"

    # Save params and configs used for training
    shutil.copy(model_config_file, output_dir / "model_config.yaml")
    shutil.copy(plate_config_file, output_dir / "plate_config.yaml")
    A.save(train_augmentation, output_dir / "train_augmentation.yaml", "yaml")
    cli_param_names = click.get_current_context().params
    with open(output_dir / "hyper_params.json", "w", encoding="utf-8") as f_out:
        # Dump the current value of every CLI parameter. `output_dir` already points to the
        # timestamped run directory at this point. default=str serializes Paths and similar.
        json.dump(
            {k: v for k, v in locals().items() if k in cli_param_names},
            f_out,
            indent=4,
            default=str,
        )

    callbacks = [
        # Stop training when early_stopping_metric doesn't improve for X epochs
        EarlyStopping(
            monitor=early_stopping_metric,
            patience=early_stopping_patience,
            mode=EVAL_METRICS[early_stopping_metric],
            restore_best_weights=False,
            verbose=1,
        ),
        # To save model checkpoint with EMA weights, we need to place this before `ModelCheckpoint`
        *([SwapEMAWeights(swap_on_epoch=True)] if use_ema else []),
        # We don't use EarlyStopping restore_best_weights=True because it won't restore the best
        # weights when it didn't manage to EarlyStop but finished all epochs
        ModelCheckpoint(output_dir / "last.keras", save_weights_only=False, save_best_only=False),
        ModelCheckpoint(
            model_file_path,
            monitor=early_stopping_metric,
            mode=EVAL_METRICS[early_stopping_metric],
            save_weights_only=False,
            save_best_only=True,
            verbose=1,
        ),
        TerminateOnNaN(),
        CSVLogger(str(output_dir / "training_log.csv")),
    ]

    if tensorboard:
        run_dir = tensorboard_dir / datetime.now().strftime("run_%Y-%m-%d_%H-%M-%S")
        run_dir.mkdir(parents=True, exist_ok=True)
        callbacks.append(TensorBoard(log_dir=run_dir))

    model.fit(
        train_dataset,
        epochs=epochs,
        validation_data=val_dataset,
        callbacks=callbacks,
        validation_freq=validation_freq,
    )
444
+
445
+
446
if __name__ == "__main__":
    # Invoked with no arguments: click parses the command line and supplies every option to `train`.
    train()
@@ -0,0 +1,129 @@
1
+ """
2
+ Utils used for the CLI scripts.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+ import inspect
9
+ import pathlib
10
+ import random
11
+ from collections.abc import Callable, Sequence
12
+ from functools import wraps
13
+ from typing import Any, Optional
14
+
15
+ import albumentations as A
16
+ import numpy as np
17
+ from rich import box
18
+ from rich.console import Console
19
+ from rich.pretty import Pretty
20
+ from rich.table import Table
21
+
22
+
23
def print_variables_as_table(
    c1_title: str, c2_title: str, title: str = "Variables Table", **kwargs: Any
) -> None:
    """
    Render keyword arguments as a two-column table on the terminal using ``rich``.

    ``pathlib.Path`` values are converted to plain strings before display; every
    value is wrapped in ``rich.pretty.Pretty``. Rows appear in the order the keyword
    arguments were passed.

    Args:
        c1_title (str): Header of the left (name) column.
        c2_title (str): Header of the right (value) column.
        title (str): Caption displayed above the table.
        **kwargs (Any): Name/value pairs; each pair becomes one row.
    """
    console = Console()
    console.print("\n")

    table = Table(title=title, show_header=True, header_style="bold blue", box=box.ROUNDED)
    # Both columns share alignment and style; only the header text and width differ.
    for header, width in ((c1_title, 20), (c2_title, 60)):
        table.add_column(header, min_width=width, justify="left", style="bold")

    for name, raw in kwargs.items():
        shown = str(raw) if isinstance(raw, pathlib.Path) else raw
        table.add_row(f"[bold]{name}[/bold]", Pretty(shown))

    console.print(table)
47
+
48
+
49
def print_params(
    table_title: str = "Parameters Table", c1_title: str = "Variable", c2_title: str = "Value"
) -> Callable:
    """
    Build a decorator that prints a function's arguments as a ``rich`` table before running it.

    On every call the arguments are bound to the wrapped function's signature, any
    omitted parameters are filled in with their defaults, and the resulting
    name/value pairs are handed to ``print_variables_as_table``. The wrapped function
    is then invoked with the original arguments and its result is returned unchanged.

    NOTE(review): the bound arguments are forwarded as ``**kwargs`` to
    ``print_variables_as_table``, which has its own parameters named ``c1_title``,
    ``c2_title`` and ``title``. Decorating a function that has a parameter with one of
    those names would therefore raise ``TypeError`` (multiple values for argument).

    Args:
        table_title (str, optional): Title of the table. Defaults to "Parameters Table".
        c1_title (str, optional): Title of the first column. Defaults to "Variable".
        c2_title (str, optional): Title of the second column. Defaults to "Value".

    Returns:
        Callable: A decorator that adds parameter printing to the function it wraps.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # The signature is resolved per call, so a non-introspectable callable
            # only fails when it is actually invoked.
            bound = inspect.signature(func).bind(*args, **kwargs)
            bound.apply_defaults()
            print_variables_as_table(c1_title, c2_title, table_title, **dict(bound.arguments))
            return func(*args, **kwargs)

        return wrapper

    return decorator
78
+
79
+
80
def print_train_details(augmentation: A.Compose, config: dict[str, Any]) -> None:
    """
    Pretty-print the augmentation pipeline and the training configuration with ``rich``.

    Output starts with a blank-line separator. Each section then gets a bold blue
    heading, its ``Pretty``-rendered value, and another blank-line separator.

    Args:
        augmentation (A.Compose): Albumentations pipeline to display.
        config (dict[str, Any]): Configuration mapping to display.
    """
    console = Console()
    sections = (
        ("[bold blue]Augmentation Pipeline:[/bold blue]", augmentation),
        ("[bold blue]Configuration:[/bold blue]", config),
    )
    console.print("\n")
    for heading, payload in sections:
        console.print(heading)
        console.print(Pretty(payload))
        console.print("\n")
89
+
90
+
91
def requires(*modules: str, pkg_name: Optional[Sequence[str]] = None) -> Callable:
    """
    Decorator that checks if given modules are importable. If not, raises ModuleNotFoundError with
    a hint to install the package(s).

    The check runs on every call of the decorated function, before the function body.

    Args:
        modules (str): Names of modules to check (via importlib.util.find_spec). Dotted names
            are supported; a submodule whose parent package is missing counts as missing.
        pkg_name (Optional[Sequence[str]]): Names of packages to suggest installing. A single
            string is treated as one package name. When empty or None, the missing module
            names are suggested instead.

    Returns:
        Callable: A decorator that adds the module-availability check to the function it wraps.

    Raises:
        ModuleNotFoundError: Raised by the decorated function when any module is unavailable.
    """
    # A bare string is itself a Sequence[str]; joining it directly would space out its
    # characters ("foo" -> "f o o"), so wrap a non-empty string into a one-item list.
    suggested = [pkg_name] if isinstance(pkg_name, str) and pkg_name else pkg_name

    def _is_available(module: str) -> bool:
        """Return True when `module` can be located by the import system."""
        try:
            return importlib.util.find_spec(module) is not None
        except ModuleNotFoundError:
            # For dotted names find_spec imports the parent package first; a missing
            # parent raises here instead of returning None.
            return False

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            missing = [m for m in modules if not _is_available(m)]
            if missing:
                pkg_missing = " ".join(suggested or missing)
                raise ModuleNotFoundError(
                    f"Cannot run `{fn.__name__}` because {missing!r} "
                    f"is not installed. Please install the required package(s): {pkg_missing}"
                )
            return fn(*args, **kwargs)

        return wrapper

    return decorator
119
+
120
+
121
def seed_everything(seed: int) -> None:
    """
    Make Python's ``random`` module and NumPy's global RNG reproducible.

    Only these two generators are seeded by this function, in that order; no other
    random number generator is touched here.

    Args:
        seed (int): Value passed to ``random.seed`` and then to ``np.random.seed``.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
@@ -0,0 +1,93 @@
1
+ """
2
+ Script for validating trained OCR models.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import pathlib
8
+
9
+ import click
10
+
11
+ from fast_plate_ocr.train.data.dataset import PlateRecognitionPyDataset
12
+ from fast_plate_ocr.train.model.config import load_plate_config_from_yaml
13
+ from fast_plate_ocr.train.utilities.utils import load_keras_model
14
+
15
+
16
+ @click.command(context_settings={"max_content_width": 120})
17
+ @click.option(
18
+ "-m",
19
+ "--model",
20
+ "model_path",
21
+ required=True,
22
+ type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
23
+ help="Path to the saved .keras model.",
24
+ )
25
+ @click.option(
26
+ "--plate-config-file",
27
+ required=True,
28
+ type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path),
29
+ help="Path pointing to the model license plate OCR config.",
30
+ )
31
+ @click.option(
32
+ "-a",
33
+ "--annotations",
34
+ required=True,
35
+ type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
36
+ help="Annotations file used for validation.",
37
+ )
38
+ @click.option(
39
+ "-b",
40
+ "--batch-size",
41
+ default=1,
42
+ show_default=True,
43
+ type=int,
44
+ help="Batch size.",
45
+ )
46
+ @click.option(
47
+ "--workers",
48
+ default=1,
49
+ show_default=True,
50
+ type=int,
51
+ help="Number of worker threads/processes for parallel data loading via PyDataset.",
52
+ )
53
+ @click.option(
54
+ "--use-multiprocessing/--no-use-multiprocessing",
55
+ default=False,
56
+ show_default=True,
57
+ help="Whether to use multiprocessing for data loading.",
58
+ )
59
+ @click.option(
60
+ "--max-queue-size",
61
+ default=10,
62
+ show_default=True,
63
+ type=int,
64
+ help="Maximum number of batches to prefetch for the dataset.",
65
+ )
66
+ def valid(
67
+ model_path: pathlib.Path,
68
+ plate_config_file: pathlib.Path,
69
+ annotations: pathlib.Path,
70
+ batch_size: int,
71
+ workers: int,
72
+ use_multiprocessing: bool,
73
+ max_queue_size: int,
74
+ ) -> None:
75
+ """
76
+ Validate the trained OCR model on a labeled set.
77
+ """
78
+ plate_config = load_plate_config_from_yaml(plate_config_file)
79
+ model = load_keras_model(model_path, plate_config)
80
+ val_dataset = PlateRecognitionPyDataset(
81
+ annotations_file=annotations,
82
+ plate_config=plate_config,
83
+ batch_size=batch_size,
84
+ shuffle=False,
85
+ workers=workers,
86
+ use_multiprocessing=use_multiprocessing,
87
+ max_queue_size=max_queue_size,
88
+ )
89
+ model.evaluate(val_dataset)
90
+
91
+
92
# Script entry point: run the `valid` click command when this file is executed directly.
if __name__ == "__main__":
    valid()