matrice-analytics 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrice_analytics/__init__.py +28 -0
- matrice_analytics/boundary_drawing_internal/README.md +305 -0
- matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
- matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
- matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
- matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
- matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
- matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
- matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
- matrice_analytics/post_processing/README.md +455 -0
- matrice_analytics/post_processing/__init__.py +732 -0
- matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
- matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
- matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
- matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
- matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
- matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
- matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
- matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
- matrice_analytics/post_processing/config.py +146 -0
- matrice_analytics/post_processing/core/__init__.py +63 -0
- matrice_analytics/post_processing/core/base.py +704 -0
- matrice_analytics/post_processing/core/config.py +3291 -0
- matrice_analytics/post_processing/core/config_utils.py +925 -0
- matrice_analytics/post_processing/face_reg/__init__.py +43 -0
- matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
- matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
- matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
- matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
- matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
- matrice_analytics/post_processing/ocr/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
- matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
- matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
- matrice_analytics/post_processing/post_processor.py +1175 -0
- matrice_analytics/post_processing/test_cases/__init__.py +1 -0
- matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
- matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
- matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
- matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
- matrice_analytics/post_processing/test_cases/test_config.py +852 -0
- matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
- matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
- matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
- matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
- matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
- matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
- matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
- matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
- matrice_analytics/post_processing/usecases/__init__.py +267 -0
- matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
- matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
- matrice_analytics/post_processing/usecases/age_detection.py +842 -0
- matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
- matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
- matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
- matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
- matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
- matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
- matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
- matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
- matrice_analytics/post_processing/usecases/car_service.py +1601 -0
- matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
- matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
- matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
- matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
- matrice_analytics/post_processing/usecases/color/clip.py +660 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
- matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
- matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
- matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
- matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
- matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
- matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
- matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
- matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
- matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
- matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
- matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
- matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
- matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
- matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
- matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
- matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
- matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
- matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
- matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
- matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
- matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
- matrice_analytics/post_processing/usecases/leaf.py +821 -0
- matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
- matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
- matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
- matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
- matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
- matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
- matrice_analytics/post_processing/usecases/parking.py +787 -0
- matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
- matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
- matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
- matrice_analytics/post_processing/usecases/people_counting.py +706 -0
- matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
- matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
- matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
- matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
- matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
- matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
- matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
- matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
- matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
- matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
- matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
- matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
- matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
- matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
- matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
- matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
- matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
- matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
- matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
- matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
- matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
- matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
- matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
- matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
- matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
- matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
- matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
- matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
- matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
- matrice_analytics/post_processing/utils/__init__.py +150 -0
- matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
- matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
- matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
- matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
- matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
- matrice_analytics/post_processing/utils/color_utils.py +592 -0
- matrice_analytics/post_processing/utils/counting_utils.py +182 -0
- matrice_analytics/post_processing/utils/filter_utils.py +261 -0
- matrice_analytics/post_processing/utils/format_utils.py +293 -0
- matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
- matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
- matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
- matrice_analytics/py.typed +0 -0
- matrice_analytics-0.1.60.dist-info/METADATA +481 -0
- matrice_analytics-0.1.60.dist-info/RECORD +196 -0
- matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
- matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
- matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Script for training the License Plate OCR models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import pathlib
|
|
9
|
+
import shutil
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import Literal, Optional
|
|
12
|
+
|
|
13
|
+
import albumentations as A
|
|
14
|
+
import click
|
|
15
|
+
import keras
|
|
16
|
+
from keras.src.callbacks import (
|
|
17
|
+
CSVLogger,
|
|
18
|
+
EarlyStopping,
|
|
19
|
+
ModelCheckpoint,
|
|
20
|
+
SwapEMAWeights,
|
|
21
|
+
TensorBoard,
|
|
22
|
+
TerminateOnNaN,
|
|
23
|
+
)
|
|
24
|
+
from keras.src.optimizers import AdamW
|
|
25
|
+
|
|
26
|
+
import fast_plate_ocr.train.model.model_builders
|
|
27
|
+
from fast_plate_ocr.cli.utils import print_params, print_train_details
|
|
28
|
+
from fast_plate_ocr.train.data.augmentation import (
|
|
29
|
+
default_train_augmentation,
|
|
30
|
+
)
|
|
31
|
+
from fast_plate_ocr.train.data.dataset import PlateRecognitionPyDataset
|
|
32
|
+
from fast_plate_ocr.train.model.config import load_plate_config_from_yaml
|
|
33
|
+
from fast_plate_ocr.train.model.loss import cce_loss, focal_cce_loss
|
|
34
|
+
from fast_plate_ocr.train.model.metric import (
|
|
35
|
+
cat_acc_metric,
|
|
36
|
+
plate_acc_metric,
|
|
37
|
+
plate_len_acc_metric,
|
|
38
|
+
top_3_k_metric,
|
|
39
|
+
)
|
|
40
|
+
from fast_plate_ocr.train.model.model_schema import load_model_config_from_yaml
|
|
41
|
+
|
|
42
|
+
# ruff: noqa: PLR0913
|
|
43
|
+
# pylint: disable=too-many-arguments,too-many-locals
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Every accuracy-style validation metric improves upwards ("max"); the validation loss
# improves downwards ("min"). Insertion order is kept because the CLI exposes these keys,
# in this order, as the allowed values of ``--early-stopping-metric``.
EVAL_METRICS: dict[str, Literal["max", "min", "auto"]] = {
    **dict.fromkeys(
        ("val_plate_acc", "val_cat_acc", "val_top_3_k_acc", "val_plate_len_acc"),
        "max",
    ),
    "val_loss": "min",
}
"""Eval metric to monitor, mapped to the direction in which it improves."""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _build_lr_schedule(
    lr: float, final_lr_factor: float, warmup_fraction: float, total_steps: int
):
    """
    Build a linear-warmup + cosine-decay learning-rate schedule spanning ``total_steps``.

    Args:
        lr: Peak learning rate (reached at the end of warmup, or used from step 0 when
            there is no warmup).
        final_lr_factor: Fraction of ``lr`` remaining at the end of the decay (``alpha``).
        warmup_fraction: Fraction of ``total_steps`` spent linearly warming up from 0.
        total_steps: Total number of optimizer steps of the whole training run.

    Returns:
        A ``keras.optimizers.schedules.CosineDecay`` instance.
    """
    warmup_steps = max(int(warmup_fraction * total_steps), 0)
    # Keras' CosineDecay performs ``warmup_steps`` of warmup *followed by* ``decay_steps`` of
    # decay, so the decay length must exclude the warmup for the learning rate to actually
    # reach ``lr * final_lr_factor`` by the last training step. Clamp to at least 1 so a
    # 100% warmup cannot produce a zero-length decay (division by zero in the cosine term).
    decay_steps = max(total_steps - warmup_steps, 1)
    return keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=0.0 if warmup_steps > 0 else lr,
        decay_steps=decay_steps,
        alpha=final_lr_factor,
        warmup_steps=warmup_steps,
        warmup_target=lr if warmup_steps > 0 else None,
    )


def _build_loss_fn(
    loss: str,
    vocabulary_size: int,
    focal_alpha: float,
    focal_gamma: float,
    label_smoothing: float,
):
    """
    Return the loss callable selected by the ``--loss`` CLI option.

    Args:
        loss: Either ``"cce"`` or ``"focal_cce"``.
        vocabulary_size: Size of the plate character vocabulary.
        focal_alpha: Focal loss alpha (only used for ``"focal_cce"``).
        focal_gamma: Focal loss gamma (only used for ``"focal_cce"``).
        label_smoothing: Amount of label smoothing to apply.

    Raises:
        ValueError: If ``loss`` is not a supported loss name.
    """
    if loss == "cce":
        return cce_loss(vocabulary_size=vocabulary_size, label_smoothing=label_smoothing)
    if loss == "focal_cce":
        return focal_cce_loss(
            vocabulary_size=vocabulary_size,
            alpha=focal_alpha,
            gamma=focal_gamma,
            label_smoothing=label_smoothing,
        )
    raise ValueError(f"Unsupported loss type: {loss}")


def _build_callbacks(
    output_dir: pathlib.Path,
    early_stopping_metric: str,
    early_stopping_patience: int,
    use_ema: bool,
) -> list:
    """
    Create the Keras training callbacks that write into ``output_dir``.

    Args:
        output_dir: Run directory where checkpoints and the CSV log are written.
        early_stopping_metric: Key of ``EVAL_METRICS`` to monitor.
        early_stopping_patience: Epochs without improvement before stopping.
        use_ema: Whether the optimizer keeps EMA weights (adds ``SwapEMAWeights``).

    Returns:
        The list of callbacks, in the order they must run.
    """
    mode = EVAL_METRICS[early_stopping_metric]
    model_file_path = output_dir / "ckpt-epoch_{epoch:02d}-acc_{val_plate_acc:.3f}.keras"
    return [
        # Stop training when early_stopping_metric doesn't improve for X epochs
        EarlyStopping(
            monitor=early_stopping_metric,
            patience=early_stopping_patience,
            mode=mode,
            restore_best_weights=False,
            verbose=1,
        ),
        # To save model checkpoint with EMA weights, we need to place this before `ModelCheckpoint`
        *([SwapEMAWeights(swap_on_epoch=True)] if use_ema else []),
        # We don't use EarlyStopping restore_best_weights=True because it won't restore the best
        # weights when it didn't manage to EarlyStop but finished all epochs
        ModelCheckpoint(output_dir / "last.keras", save_weights_only=False, save_best_only=False),
        ModelCheckpoint(
            model_file_path,
            monitor=early_stopping_metric,
            mode=mode,
            save_weights_only=False,
            save_best_only=True,
            verbose=1,
        ),
        TerminateOnNaN(),
        CSVLogger(str(output_dir / "training_log.csv")),
    ]


@click.command(context_settings={"max_content_width": 120})
@click.option(
    "--model-config-file",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the YAML config that describes the model architecture.",
)
@click.option(
    "--plate-config-file",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the plate YAML config.",
)
@click.option(
    "--annotations",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path pointing to the train annotations CSV file.",
)
@click.option(
    "--val-annotations",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path pointing to the validation annotations CSV file.",
)
@click.option(
    "--validation-freq",
    default=1,
    show_default=True,
    type=int,
    help="Frequency (in epochs) at which to evaluate the validation data.",
)
@click.option(
    "--augmentation-path",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="YAML file pointing to the augmentation pipeline saved with Albumentations.save(...)",
)
@click.option(
    "--lr",
    default=0.001,
    show_default=True,
    type=float,
    help="Initial learning rate.",
)
@click.option(
    "--final-lr-factor",
    default=1e-2,
    show_default=True,
    type=float,
    help="Final learning rate factor for the cosine decay scheduler. It's the fraction of"
    " the initial learning rate that remains after decay.",
)
@click.option(
    "--warmup-fraction",
    default=0.05,
    show_default=True,
    type=float,
    help="Fraction of total training steps to linearly warm up.",
)
@click.option(
    "--weight-decay",
    default=0.001,
    show_default=True,
    type=float,
    help="Weight decay for the AdamW optimizer.",
)
@click.option(
    "--clipnorm",
    default=1.0,
    show_default=True,
    type=float,
    help="Gradient clipping norm value for the AdamW optimizer.",
)
@click.option(
    "--loss",
    default="cce",
    type=click.Choice(["cce", "focal_cce"], case_sensitive=False),
    show_default=True,
    help="Loss function to use during training.",
)
@click.option(
    "--focal-alpha",
    default=0.25,
    show_default=True,
    type=float,
    help="Alpha parameter for focal loss. Applicable only when '--loss' is 'focal_cce'.",
)
@click.option(
    "--focal-gamma",
    default=2.0,
    show_default=True,
    type=float,
    help="Gamma parameter for focal loss. Applicable only when '--loss' is 'focal_cce'.",
)
@click.option(
    "--label-smoothing",
    default=0.01,
    show_default=True,
    type=float,
    help="Amount of label smoothing to apply.",
)
@click.option(
    "--mixed-precision-policy",
    default=None,
    type=click.Choice(["mixed_float16", "mixed_bfloat16", "float32"]),
    help=(
        "Optional mixed precision policy for training. Choose one of: mixed_float16, "
        "mixed_bfloat16, or float32. If not provided, Keras uses its default global policy."
    ),
)
@click.option(
    "--batch-size",
    default=64,
    show_default=True,
    type=int,
    help="Batch size for training.",
)
@click.option(
    "--workers",
    default=1,
    show_default=True,
    type=int,
    help="Number of worker threads/processes for parallel data loading.",
)
@click.option(
    "--use-multiprocessing/--no-use-multiprocessing",
    default=False,
    show_default=True,
    help="Use multiprocessing for data loading.",
)
@click.option(
    "--max-queue-size",
    default=10,
    show_default=True,
    type=int,
    help="Maximum queue size for dataset workers.",
)
@click.option(
    "--output-dir",
    default="./trained_models",
    type=click.Path(dir_okay=True, file_okay=False, path_type=pathlib.Path),
    help="Output directory where model will be saved.",
)
@click.option(
    "--epochs",
    default=300,
    show_default=True,
    type=int,
    help="Number of training epochs.",
)
@click.option(
    "--tensorboard",
    "-t",
    is_flag=True,
    help="Whether to use TensorBoard visualization tool.",
)
@click.option(
    "--tensorboard-dir",
    "-l",
    default="tensorboard_logs",
    show_default=True,
    type=click.Path(path_type=pathlib.Path),
    help="The path of the directory where to save the TensorBoard log files.",
)
@click.option(
    "--early-stopping-patience",
    default=100,
    show_default=True,
    type=int,
    help="Stop training when the early stopping metric doesn't improve for X epochs.",
)
@click.option(
    "--early-stopping-metric",
    default="val_plate_acc",
    show_default=True,
    type=click.Choice(list(EVAL_METRICS), case_sensitive=False),
    help="Metric to monitor for early stopping.",
)
@click.option(
    "--weights-path",
    type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path),
    help="Path to the pretrained model weights file.",
)
@click.option(
    "--use-ema/--no-use-ema",
    default=True,
    show_default=True,
    help=(
        "Whether to use exponential moving averages in the AdamW optimizer. "
        "Defaults to True; use --no-use-ema to disable."
    ),
)
@click.option(
    "--wd-ignore",
    default="bias,layer_norm",
    show_default=True,
    type=str,
    help="Comma-separated list of variable substrings to exclude from weight decay.",
)
@click.option(
    "--seed",
    type=int,
    help="Sets all random seeds (Python, NumPy, and backend framework, e.g. TF).",
)
@print_params(table_title="CLI Training Parameters", c1_title="Parameter", c2_title="Details")
def train(
    model_config_file: pathlib.Path,
    plate_config_file: pathlib.Path,
    annotations: pathlib.Path,
    val_annotations: pathlib.Path,
    validation_freq: int,
    augmentation_path: Optional[pathlib.Path],
    lr: float,
    final_lr_factor: float,
    warmup_fraction: float,
    weight_decay: float,
    clipnorm: float,
    loss: str,
    focal_alpha: float,
    focal_gamma: float,
    label_smoothing: float,
    mixed_precision_policy: Optional[str],
    batch_size: int,
    workers: int,
    use_multiprocessing: bool,
    max_queue_size: int,
    output_dir: pathlib.Path,
    epochs: int,
    tensorboard: bool,
    tensorboard_dir: pathlib.Path,
    early_stopping_patience: int,
    early_stopping_metric: str,
    weights_path: Optional[pathlib.Path],
    use_ema: bool,
    wd_ignore: str,
    seed: Optional[int],
) -> None:
    """
    Train the License Plate OCR model.

    Builds the model described by ``--model-config-file`` for the plate layout given in
    ``--plate-config-file``, optionally loads pretrained weights, and fits it with AdamW under
    a linear-warmup + cosine-decay learning-rate schedule. Checkpoints, copies of both YAML
    configs, the augmentation pipeline, the CLI hyper-parameters and a CSV training log are
    written into a timestamped sub-directory of ``--output-dir``.
    """
    if seed is not None:
        keras.utils.set_random_seed(seed)

    if mixed_precision_policy is not None:
        keras.mixed_precision.set_global_policy(mixed_precision_policy)

    plate_config = load_plate_config_from_yaml(plate_config_file)
    model_config = load_model_config_from_yaml(model_config_file)
    train_augmentation = (
        A.load(augmentation_path, data_format="yaml")
        if augmentation_path
        else default_train_augmentation(img_color_mode=plate_config.image_color_mode)
    )
    print_train_details(train_augmentation, plate_config.model_dump())

    train_dataset = PlateRecognitionPyDataset(
        annotations_file=annotations,
        transform=train_augmentation,
        plate_config=plate_config,
        batch_size=batch_size,
        shuffle=True,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
    )

    val_dataset = PlateRecognitionPyDataset(
        annotations_file=val_annotations,
        plate_config=plate_config,
        batch_size=batch_size,
        shuffle=False,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
    )

    # Build the model (optionally warm-started from pretrained weights)
    model = fast_plate_ocr.train.model.model_builders.build_model(model_config, plate_config)

    if weights_path:
        model.load_weights(weights_path, skip_mismatch=True)

    # ``len(train_dataset)`` is the number of batches per epoch, i.e. optimizer steps per epoch.
    lr_schedule = _build_lr_schedule(
        lr=lr,
        final_lr_factor=final_lr_factor,
        warmup_fraction=warmup_fraction,
        total_steps=epochs * len(train_dataset),
    )
    optimizer = AdamW(lr_schedule, weight_decay=weight_decay, clipnorm=clipnorm, use_ema=use_ema)
    optimizer.exclude_from_weight_decay(
        var_names=[name.strip() for name in wd_ignore.split(",") if name.strip()]
    )

    loss_fn = _build_loss_fn(
        loss,
        vocabulary_size=plate_config.vocabulary_size,
        focal_alpha=focal_alpha,
        focal_gamma=focal_gamma,
        label_smoothing=label_smoothing,
    )

    model.compile(
        loss=loss_fn,
        jit_compile=False,
        optimizer=optimizer,
        metrics=[
            cat_acc_metric(
                max_plate_slots=plate_config.max_plate_slots,
                vocabulary_size=plate_config.vocabulary_size,
            ),
            plate_acc_metric(
                max_plate_slots=plate_config.max_plate_slots,
                vocabulary_size=plate_config.vocabulary_size,
            ),
            top_3_k_metric(vocabulary_size=plate_config.vocabulary_size),
            plate_len_acc_metric(
                max_plate_slots=plate_config.max_plate_slots,
                vocabulary_size=plate_config.vocabulary_size,
                pad_token_index=plate_config.pad_idx,
            ),
        ],
    )

    # A single timestamp so the output dir and the TensorBoard run dir of this run match.
    started_at = datetime.now()
    output_dir /= started_at.strftime("%Y-%m-%d_%H-%M-%S")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save params and configs used for training
    shutil.copy(model_config_file, output_dir / "model_config.yaml")
    shutil.copy(plate_config_file, output_dir / "plate_config.yaml")
    A.save(train_augmentation, output_dir / "train_augmentation.yaml", "yaml")
    cli_params = click.get_current_context().params
    # Keep only the CLI parameters (in signature order); note ``output_dir`` is recorded
    # with the timestamped sub-directory appended above.
    hyper_params = {name: value for name, value in locals().items() if name in cli_params}
    with open(output_dir / "hyper_params.json", "w", encoding="utf-8") as f_out:
        json.dump(hyper_params, f_out, indent=4, default=str)

    callbacks = _build_callbacks(
        output_dir=output_dir,
        early_stopping_metric=early_stopping_metric,
        early_stopping_patience=early_stopping_patience,
        use_ema=use_ema,
    )

    if tensorboard:
        run_dir = tensorboard_dir / started_at.strftime("run_%Y-%m-%d_%H-%M-%S")
        run_dir.mkdir(parents=True, exist_ok=True)
        callbacks.append(TensorBoard(log_dir=run_dir))

    model.fit(
        train_dataset,
        epochs=epochs,
        validation_data=val_dataset,
        callbacks=callbacks,
        validation_freq=validation_freq,
    )


if __name__ == "__main__":
    # click parses ``sys.argv`` and supplies every parameter of ``train``.
    train()
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utils used for the CLI scripts.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
import inspect
|
|
9
|
+
import pathlib
|
|
10
|
+
import random
|
|
11
|
+
from collections.abc import Callable, Sequence
|
|
12
|
+
from functools import wraps
|
|
13
|
+
from typing import Any, Optional
|
|
14
|
+
|
|
15
|
+
import albumentations as A
|
|
16
|
+
import numpy as np
|
|
17
|
+
from rich import box
|
|
18
|
+
from rich.console import Console
|
|
19
|
+
from rich.pretty import Pretty
|
|
20
|
+
from rich.table import Table
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def print_variables_as_table(
    c1_title: str, c2_title: str, title: str = "Variables Table", **kwargs: Any
) -> None:
    """
    Render the given keyword arguments as a two-column table using rich.

    Each keyword becomes one row: its name in the first column and a pretty-printed
    value in the second. `pathlib.Path` values are shown as plain strings.

    Args:
        c1_title (str): Header of the first (name) column.
        c2_title (str): Header of the second (value) column.
        title (str): Caption shown above the table.
        **kwargs (Any): Name/value pairs to display, in insertion order.
    """
    console = Console()
    console.print("\n")

    table = Table(title=title, show_header=True, header_style="bold blue", box=box.ROUNDED)
    # Name column is narrower than the value column.
    for column_title, width in ((c1_title, 20), (c2_title, 60)):
        table.add_column(column_title, min_width=width, justify="left", style="bold")

    for name, val in kwargs.items():
        shown = str(val) if isinstance(val, pathlib.Path) else val
        table.add_row(f"[bold]{name}[/bold]", Pretty(shown))

    console.print(table)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def print_params(
    table_title: str = "Parameters Table", c1_title: str = "Variable", c2_title: str = "Value"
) -> Callable:
    """
    A decorator that prints the parameters of a function in a formatted table
    using the rich library.

    On every call, the arguments are bound to the wrapped function's signature and
    defaults are filled in. The result is printed with `print_variables_as_table`
    before the wrapped function runs.

    Args:
        table_title (str, optional): Title of the table. Defaults to "Parameters Table".
        c1_title (str, optional): Title of the first column. Defaults to "Variable".
        c2_title (str, optional): Title of the second column. Defaults to "Value".

    Returns:
        Callable: A decorator that adds parameter printing to the function it wraps.
    """

    def decorator(func: Callable) -> Callable:
        # The signature of `func` is fixed, so resolve it once here instead of on every call.
        func_signature = inspect.signature(func)

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # `bind` raises TypeError for arguments `func` itself would reject.
            bound_arguments = func_signature.bind(*args, **kwargs)
            bound_arguments.apply_defaults()
            # NOTE(review): print_variables_as_table has parameters c1_title, c2_title and
            # title. A wrapped function whose parameters use those names collides with them
            # here and raises TypeError ("multiple values").
            print_variables_as_table(c1_title, c2_title, table_title, **bound_arguments.arguments)
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def print_train_details(augmentation: A.Compose, config: dict[str, Any]) -> None:
    """
    Print the augmentation pipeline followed by the configuration to the console.

    Each section gets a bold blue heading and a pretty-printed body, and the sections
    are separated by blank lines.

    Args:
        augmentation (A.Compose): Augmentation pipeline to display.
        config (dict[str, Any]): Configuration mapping to display.
    """
    console = Console()
    console.print("\n")
    sections = (
        ("[bold blue]Augmentation Pipeline:[/bold blue]", augmentation),
        ("[bold blue]Configuration:[/bold blue]", config),
    )
    for heading, payload in sections:
        console.print(heading)
        console.print(Pretty(payload))
        console.print("\n")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def requires(*modules: str, pkg_name: Optional[Sequence[str]] = None) -> Callable:
    """
    Decorator that checks if given modules are importable. If not, raises ModuleNotFoundError with
    a hint to install the package(s).

    The check runs each time the decorated function is called, not at decoration time.
    Optional dependencies therefore only matter when the function is actually used.

    Args:
        modules (str): Names of modules to check (via importlib.util.find_spec). A dotted
            name whose parent package is not installed is reported as missing as well.
        pkg_name (Optional[Sequence[str]]): Names of packages to suggest installing. A single
            string is treated as one package name. When omitted or empty, the missing module
            names are suggested instead.

    Returns:
        Callable: A decorator whose wrapper verifies module availability before calling
        the original function.

    Raises:
        ModuleNotFoundError: From the wrapper, when any of `modules` cannot be found.
    """
    # A plain string is itself a Sequence[str]. Joining it directly would print "p k g",
    # so wrap it. An empty string falls back to the missing module names, as before.
    if isinstance(pkg_name, str):
        pkg_name = [pkg_name] if pkg_name else None

    def _is_available(module: str) -> bool:
        """Return True when `module` can be located by the import system."""
        try:
            return importlib.util.find_spec(module) is not None
        except ModuleNotFoundError:
            # For dotted names find_spec imports the parent package first. If the parent
            # is missing it raises instead of returning None.
            return False

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            missing = [m for m in modules if not _is_available(m)]
            if missing:
                pkg_missing = " ".join(pkg_name or missing)
                raise ModuleNotFoundError(
                    f"Cannot run `{fn.__name__}` because {missing!r} "
                    f"is not installed. Please install the required package(s): {pkg_missing}"
                )
            return fn(*args, **kwargs)

        return wrapper

    return decorator
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def seed_everything(seed: int) -> None:
    """
    Make runs reproducible by seeding Python's `random` module and NumPy's global RNG.

    Args:
        seed (int): Value passed to each generator's seeding function.
    """
    # Order matches the original: stdlib first, then NumPy.
    for apply_seed in (random.seed, np.random.seed):
        apply_seed(seed)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Script for validating trained OCR models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import pathlib
|
|
8
|
+
|
|
9
|
+
import click
|
|
10
|
+
|
|
11
|
+
from fast_plate_ocr.train.data.dataset import PlateRecognitionPyDataset
|
|
12
|
+
from fast_plate_ocr.train.model.config import load_plate_config_from_yaml
|
|
13
|
+
from fast_plate_ocr.train.utilities.utils import load_keras_model
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@click.command(context_settings={"max_content_width": 120})
@click.option(
    "-m",
    "--model",
    "model_path",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the saved .keras model.",
)
@click.option(
    "--plate-config-file",
    required=True,
    # dir_okay=False, like the other file options. A directory is rejected by click up
    # front instead of failing later inside the YAML loader.
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path pointing to the model license plate OCR config.",
)
@click.option(
    "-a",
    "--annotations",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Annotations file used for validation.",
)
@click.option(
    "-b",
    "--batch-size",
    default=1,
    show_default=True,
    type=int,
    help="Batch size.",
)
@click.option(
    "--workers",
    default=1,
    show_default=True,
    type=int,
    help="Number of worker threads/processes for parallel data loading via PyDataset.",
)
@click.option(
    "--use-multiprocessing/--no-use-multiprocessing",
    default=False,
    show_default=True,
    help="Whether to use multiprocessing for data loading.",
)
@click.option(
    "--max-queue-size",
    default=10,
    show_default=True,
    type=int,
    help="Maximum number of batches to prefetch for the dataset.",
)
def valid(
    model_path: pathlib.Path,
    plate_config_file: pathlib.Path,
    annotations: pathlib.Path,
    batch_size: int,
    workers: int,
    use_multiprocessing: bool,
    max_queue_size: int,
) -> None:
    """
    Validate the trained OCR model on a labeled set.
    """
    # The plate config is loaded first because both the model loader and the dataset need it.
    plate_config = load_plate_config_from_yaml(plate_config_file)
    model = load_keras_model(model_path, plate_config)
    val_dataset = PlateRecognitionPyDataset(
        annotations_file=annotations,
        plate_config=plate_config,
        batch_size=batch_size,
        # Validation order does not need to be random.
        shuffle=False,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
    )
    model.evaluate(val_dataset)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
if __name__ == "__main__":
    # Invoke the `valid` command only when this file is executed directly as a script.
    valid()
|