matrice_analytics-0.1.60-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrice_analytics/__init__.py +28 -0
- matrice_analytics/boundary_drawing_internal/README.md +305 -0
- matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
- matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
- matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
- matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
- matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
- matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
- matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
- matrice_analytics/post_processing/README.md +455 -0
- matrice_analytics/post_processing/__init__.py +732 -0
- matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
- matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
- matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
- matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
- matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
- matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
- matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
- matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
- matrice_analytics/post_processing/config.py +146 -0
- matrice_analytics/post_processing/core/__init__.py +63 -0
- matrice_analytics/post_processing/core/base.py +704 -0
- matrice_analytics/post_processing/core/config.py +3291 -0
- matrice_analytics/post_processing/core/config_utils.py +925 -0
- matrice_analytics/post_processing/face_reg/__init__.py +43 -0
- matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
- matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
- matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
- matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
- matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
- matrice_analytics/post_processing/ocr/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
- matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
- matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
- matrice_analytics/post_processing/post_processor.py +1175 -0
- matrice_analytics/post_processing/test_cases/__init__.py +1 -0
- matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
- matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
- matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
- matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
- matrice_analytics/post_processing/test_cases/test_config.py +852 -0
- matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
- matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
- matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
- matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
- matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
- matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
- matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
- matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
- matrice_analytics/post_processing/usecases/__init__.py +267 -0
- matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
- matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
- matrice_analytics/post_processing/usecases/age_detection.py +842 -0
- matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
- matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
- matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
- matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
- matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
- matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
- matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
- matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
- matrice_analytics/post_processing/usecases/car_service.py +1601 -0
- matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
- matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
- matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
- matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
- matrice_analytics/post_processing/usecases/color/clip.py +660 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
- matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
- matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
- matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
- matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
- matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
- matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
- matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
- matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
- matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
- matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
- matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
- matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
- matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
- matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
- matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
- matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
- matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
- matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
- matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
- matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
- matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
- matrice_analytics/post_processing/usecases/leaf.py +821 -0
- matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
- matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
- matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
- matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
- matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
- matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
- matrice_analytics/post_processing/usecases/parking.py +787 -0
- matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
- matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
- matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
- matrice_analytics/post_processing/usecases/people_counting.py +706 -0
- matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
- matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
- matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
- matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
- matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
- matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
- matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
- matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
- matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
- matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
- matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
- matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
- matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
- matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
- matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
- matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
- matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
- matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
- matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
- matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
- matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
- matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
- matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
- matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
- matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
- matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
- matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
- matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
- matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
- matrice_analytics/post_processing/utils/__init__.py +150 -0
- matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
- matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
- matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
- matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
- matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
- matrice_analytics/post_processing/utils/color_utils.py +592 -0
- matrice_analytics/post_processing/utils/counting_utils.py +182 -0
- matrice_analytics/post_processing/utils/filter_utils.py +261 -0
- matrice_analytics/post_processing/utils/format_utils.py +293 -0
- matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
- matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
- matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
- matrice_analytics/py.typed +0 -0
- matrice_analytics-0.1.60.dist-info/METADATA +481 -0
- matrice_analytics-0.1.60.dist-info/RECORD +196 -0
- matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
- matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
- matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py
@@ -0,0 +1,323 @@
```python
"""
ONNX inference module.
"""

from __future__ import annotations

import logging
import pathlib
from collections.abc import Sequence
from typing import Literal, Optional, Union

import numpy as np
import numpy.typing as npt

try:
    import onnxruntime as ort
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "ONNX Runtime is not installed. Run: "
        "pip install 'fast-plate-ocr[onnx]' (or [onnx-gpu], etc.)"
    ) from e
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from fast_plate_ocr.core.process import (
    postprocess_output,
    preprocess_image,
    read_and_resize_plate_image,
    resize_image,
)
from fast_plate_ocr.core.types import BatchArray, BatchOrImgLike, ImgLike, PathLike
from fast_plate_ocr.core.utils import measure_time
from fast_plate_ocr.inference import hub
from fast_plate_ocr.inference.config import PlateOCRConfig
from fast_plate_ocr.inference.hub import OcrModel


def _frame_from(item: ImgLike, cfg: PlateOCRConfig) -> BatchArray:
    """
    Converts a single image-like input into a normalized (H, W, C) NumPy array ready for model
    inference. It handles both file paths and in-memory images. If the input is a file path, the
    image is read and resized using the configuration provided. If it is a NumPy array, it is
    validated and resized accordingly.
    """
    # If it's a path, read and resize
    if isinstance(item, (str, pathlib.PurePath)):
        return read_and_resize_plate_image(
            item,
            img_height=cfg.img_height,
            img_width=cfg.img_width,
            image_color_mode=cfg.image_color_mode,
            keep_aspect_ratio=cfg.keep_aspect_ratio,
            interpolation_method=cfg.interpolation,
            padding_color=cfg.padding_color,
        )

    # Otherwise it must be a NumPy array
    if not isinstance(item, np.ndarray):
        raise TypeError(f"Unsupported element type: {type(item)}")

    # If it has (N, H, W, C) shape we assume it's ready for inference
    if item.ndim == 4:
        return item

    # If it's a single frame, resize accordingly
    return resize_image(
        item,
        cfg.img_height,
        cfg.img_width,
        image_color_mode=cfg.image_color_mode,
        keep_aspect_ratio=cfg.keep_aspect_ratio,
        interpolation_method=cfg.interpolation,
        padding_color=cfg.padding_color,
    )


def _load_image_from_source(source: BatchOrImgLike, cfg: PlateOCRConfig) -> BatchArray:
    """
    Converts an image input or batch of inputs into a 4-D NumPy array (N, H, W, C).

    This utility supports a wide range of input formats, including single images or batches, file
    paths or NumPy arrays. It ensures the result is always a model-ready batch.

    Supported input formats:

    - Single path (`str` or `PathLike`) -> image is read and resized
    - List or tuple of paths -> each image is read and resized
    - Single 2D or 3D NumPy array -> resized and wrapped in a batch
    - List or tuple of NumPy arrays -> each image is resized and batched
    - Single 4D NumPy array with shape (N, H, W, C) -> returned as is

    Args:
        source: A single image or batch of images in path or NumPy array format.
        cfg: The configuration object that defines image preprocessing parameters.

    Returns:
        A 4D NumPy array of shape (N, H, W, C), dtype uint8, ready for model inference.
    """
    if isinstance(source, np.ndarray) and source.ndim == 4:
        return source

    items: Sequence[ImgLike] = (
        source
        if isinstance(source, Sequence)
        and not isinstance(source, (str, pathlib.PurePath, np.ndarray))
        else [source]
    )

    frames: list[BatchArray] = [
        frame
        for item in items
        for frame in (
            _frame_from(item, cfg)  # type: ignore[attr-defined]
            if isinstance(item, np.ndarray) and item.ndim == 4
            else [_frame_from(item, cfg)]
        )
    ]

    return np.stack(frames, axis=0, dtype=np.uint8)


class LicensePlateRecognizer:
    """
    ONNX inference class for performing license plate OCR.
    """

    def __init__(
        self,
        hub_ocr_model: Optional[OcrModel] = None,
        device: Literal["cuda", "cpu", "auto"] = "auto",
        providers: Optional[Sequence[Union[str, tuple[str, dict]]]] = None,
        sess_options: Optional[ort.SessionOptions] = None,
        onnx_model_path: Optional[PathLike] = None,
        plate_config_path: Optional[PathLike] = None,
        force_download: bool = False,
    ) -> None:
        """
        Initializes the `LicensePlateRecognizer` with the specified OCR model and inference device.

        The OCR models currently available from the HUB are:

        - `cct-s-v1-global-model`: OCR model trained with **global** plates data. Based on the
          Compact Convolutional Transformer (CCT) architecture. This is the **S** variant.
        - `cct-xs-v1-global-model`: OCR model trained with **global** plates data. Based on the
          Compact Convolutional Transformer (CCT) architecture. This is the **XS** variant.
        - `argentinian-plates-cnn-model`: OCR for **Argentinian** license plates. Uses a fully
          convolutional architecture.
        - `argentinian-plates-cnn-synth-model`: OCR for **Argentinian** license plates trained with
          synthetic and real data. Uses a fully convolutional architecture.
        - `european-plates-mobile-vit-v2-model`: OCR for **European** license plates. Uses
          MobileViT-2 as the backbone.
        - `global-plates-mobile-vit-v2-model`: OCR for **global** license plates (65+ countries).
          Uses MobileViT-2 as the backbone.

        Args:
            hub_ocr_model: Name of the OCR model to use from the HUB.
            device: Device type for inference. Should be one of ('cpu', 'cuda', 'auto'). In
                'auto' mode, the device is deduced from
                `onnxruntime.get_available_providers()`.
            providers: Optional sequence of providers in order of decreasing precedence. If not
                specified, all available providers are used based on the device argument.
            sess_options: Advanced session options for ONNX Runtime.
            onnx_model_path: Path to an ONNX model file (in case you want to use a custom one).
            plate_config_path: Path to a config file (in case you want to use a custom one).
            force_download: Force the model download, even if it already exists.

        Returns:
            None.
        """
        self.logger = logging.getLogger(__name__)

        if providers is not None:
            self.providers = providers
            self.logger.info("Using custom providers: %s", providers)
        else:
            if device == "cuda":
                self.providers = ["CUDAExecutionProvider"]
            elif device == "cpu":
                self.providers = ["CPUExecutionProvider"]
            elif device == "auto":
                self.providers = ort.get_available_providers()
            else:
                raise ValueError(
                    f"Device should be one of ('cpu', 'cuda', 'auto'). Got '{device}'."
                )

            self.logger.info("Using device '%s' with providers: %s", device, self.providers)

        if onnx_model_path and plate_config_path:
            onnx_model_path = pathlib.Path(onnx_model_path)
            plate_config_path = pathlib.Path(plate_config_path)
            if not onnx_model_path.exists() or not plate_config_path.exists():
                raise FileNotFoundError("Missing model/config file!")
            self.model_name = onnx_model_path.stem
        elif hub_ocr_model:
            self.model_name = hub_ocr_model
            onnx_model_path, plate_config_path = hub.download_model(
                model_name=hub_ocr_model, force_download=force_download
            )
        else:
            raise ValueError(
                "Either provide a model from the HUB or a custom model_path and config_path"
            )

        self.config = PlateOCRConfig.from_yaml(plate_config_path)
        self.model = ort.InferenceSession(
            onnx_model_path, providers=self.providers, sess_options=sess_options
        )

    def benchmark(
        self,
        n_iter: int = 2_500,
        batch_size: int = 1,
        include_processing: bool = False,
        warmup: int = 250,
    ) -> None:
        """
        Run an inference benchmark and pretty-print the results.

        It reports the following metrics:

        * **Average latency per batch** (milliseconds)
        * **Throughput** in *plates / second* (PPS), i.e., how many plates the pipeline can process
          per second at the chosen ``batch_size``.

        Args:
            n_iter: The number of iterations to run the benchmark. This determines how many times
                the inference will be executed to compute the average performance metrics.
            batch_size: Batch size to use for the benchmark.
            include_processing: Indicates whether the benchmark should include preprocessing and
                postprocessing times in the measurement.
            warmup: Number of warmup iterations to run before the benchmark.
        """
        x = np.random.randint(
            0,
            256,
            size=(
                batch_size,
                self.config.img_height,
                self.config.img_width,
                self.config.num_channels,
            ),
            dtype=np.uint8,
        )

        # Warm-up
        for _ in range(warmup):
            if include_processing:
                self.run(x)
            else:
                self.model.run(None, {"input": x})

        # Timed loop
        cum_time = 0.0
        for _ in range(n_iter):
            with measure_time() as time_taken:
                if include_processing:
                    self.run(x)
                else:
                    self.model.run(None, {"input": x})
            cum_time += time_taken()

        avg_time_ms = cum_time / n_iter if n_iter else 0.0
        pps = (1_000 / avg_time_ms) * batch_size if n_iter else 0.0

        console = Console()
        model_info = Panel(
            Text(f"Model: {self.model_name}\nProviders: {self.providers}", style="bold green"),
            title="Model Information",
            border_style="bright_blue",
            expand=False,
        )
        console.print(model_info)
        table = Table(title=f"Benchmark for '{self.model_name}'", border_style="bright_blue")
        table.add_column("Metric", justify="center", style="cyan", no_wrap=True)
        table.add_column("Value", justify="center", style="magenta")

        table.add_row("Batch size", str(batch_size))
        table.add_row("Warm-up iters", str(warmup))
        table.add_row("Timed iterations", str(n_iter))
        table.add_row("Average Time / batch (ms)", f"{avg_time_ms:.4f}")
        table.add_row("Plates per Second (PPS)", f"{pps:.4f}")
        console.print(table)

    def run(
        self,
        source: Union[str, list[str], npt.NDArray, list[npt.NDArray]],
        return_confidence: bool = False,
    ) -> Union[tuple[list[str], npt.NDArray], list[str]]:
        """
        Performs OCR to recognize license plate characters from an image or a list of images.

        Args:
            source: One or more image inputs, which can be:

                - A file path (`str` or `PathLike`) to an image.
                - A list of file paths.
                - A NumPy array of a single image, with shape (H, W), (H, W, 1) or (H, W, 3).
                - A list of NumPy arrays, each representing an image.
                - A 4D NumPy array of shape (N, H, W, C), ready for inference.

                Images will be automatically resized and converted as needed based on the model's
                configuration (including color mode and aspect ratio settings).

            return_confidence: Whether to return confidence scores along with plate predictions.

        Returns:
            A list of recognized license plates (one per image). If `return_confidence` is True,
            also returns a NumPy array of shape `(N, plate_slots)` containing the confidence scores
            for each predicted character.
        """
        x = _load_image_from_source(source, self.config)
        # Preprocess
        x = preprocess_image(x)
        # Run model
        y: list[npt.NDArray] = self.model.run(None, {"input": x})
        # Postprocess model output
        return postprocess_output(
            y[0],
            self.config.max_plate_slots,
            self.config.alphabet,
            return_confidence=return_confidence,
        )
```
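For orientation, here is a minimal usage sketch of the recognizer defined above. It is not part of the package: the import path assumes the vendored tree is importable under the `fast_plate_ocr` name that its own absolute imports use (inside matrice_analytics the actual location is `matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/`), and `plate_crop.jpg` is a hypothetical cropped-plate image.

```python
# Sketch only: run OCR on a cropped plate image with one of the HUB models
# listed in the class docstring. "plate_crop.jpg" is a hypothetical path.
from fast_plate_ocr.inference.plate_recognizer import LicensePlateRecognizer

recognizer = LicensePlateRecognizer(
    hub_ocr_model="cct-xs-v1-global-model",  # downloaded via hub.download_model()
    device="auto",                           # CUDA if available, otherwise CPU
)

# Single image path -> list with one recognized plate string.
plates = recognizer.run("plate_crop.jpg")
print(plates)  # e.g. ["AB123CD"]

# Batch of images with per-character confidence scores.
plates, confidence = recognizer.run(
    ["plate_crop.jpg", "plate_crop.jpg"],
    return_confidence=True,
)
print(confidence.shape)  # (2, max_plate_slots)

# Optional: quick latency/throughput report rendered with rich.
recognizer.benchmark(n_iter=500, batch_size=1)
```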
File without changes: matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed
File without changes: matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py
File without changes: matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py
matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py
@@ -0,0 +1,101 @@
```python
"""
Augmentations used for training the OCR model.
"""

import albumentations as A
import cv2

from fast_plate_ocr.core.types import ImageColorMode

BORDER_COLOR_BLACK: tuple[int, int, int] = (0, 0, 0)


def default_train_augmentation(img_color_mode: ImageColorMode) -> A.Compose:
    """
    Default training augmentation pipeline.
    """
    if img_color_mode == "grayscale":
        return A.Compose(
            [
                A.Affine(
                    translate_percent=(-0.02, 0.02),
                    scale=(0.75, 1.10),
                    rotate=(-15, 15),
                    border_mode=cv2.BORDER_CONSTANT,
                    fill=BORDER_COLOR_BLACK,
                    shear=(0.0, 0.0),
                    p=0.75,
                ),
                A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
                A.GaussianBlur(sigma_limit=(0.2, 0.5), p=0.25),
                A.OneOf(
                    [
                        A.CoarseDropout(
                            num_holes_range=(1, 14),
                            hole_height_range=(1, 5),
                            hole_width_range=(1, 5),
                            p=0.2,
                        ),
                        A.PixelDropout(dropout_prob=0.02, p=0.2),
                        A.GridDropout(ratio=0.3, fill="random", p=0.2),
                    ],
                    p=0.7,
                ),
            ]
        )
    if img_color_mode == "rgb":
        return A.Compose(
            [
                A.Affine(
                    translate_percent=(-0.02, 0.02),
                    scale=(0.75, 1.10),
                    rotate=(-15, 15),
                    border_mode=cv2.BORDER_CONSTANT,
                    fill=BORDER_COLOR_BLACK,
                    shear=(0.0, 0.0),
                    p=0.75,
                ),
                A.RandomBrightnessContrast(brightness_limit=0.10, contrast_limit=0.10, p=0.5),
                A.OneOf(
                    [
                        A.HueSaturationValue(
                            hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.7
                        ),
                        A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.3),
                    ],
                    p=0.3,
                ),
                A.RandomGamma(gamma_limit=(95, 105), p=0.20),
                A.ToGray(p=0.05),
                A.OneOf(
                    [
                        A.GaussianBlur(sigma_limit=(0.2, 0.5), p=0.5),
                        A.MotionBlur(blur_limit=(3, 3), p=0.5),
                    ],
                    p=0.2,
                ),
                A.OneOf(
                    [
                        A.GaussNoise(std_range=(0.01, 0.03), p=0.2),
                        A.MultiplicativeNoise(multiplier=(0.98, 1.02), p=0.1),
                        A.ISONoise(intensity=(0.005, 0.02), p=0.1),
                        A.ImageCompression(quality_range=(55, 90), p=0.1),
                    ],
                    p=0.3,
                ),
                A.OneOf(
                    [
                        A.CoarseDropout(
                            num_holes_range=(1, 14),
                            hole_height_range=(1, 5),
                            hole_width_range=(1, 5),
                            p=0.2,
                        ),
                        A.PixelDropout(dropout_prob=0.02, p=0.3),
                        A.GridDropout(ratio=0.3, fill="random", p=0.3),
                    ],
                    p=0.5,
                ),
            ]
        )
    raise ValueError(f"Unsupported img_color_mode: {img_color_mode!r}. Expected 'grayscale'/'rgb'.")
```
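As a quick illustration of how such an Albumentations pipeline is applied, a short sketch follows; it is not from the package. `plate.jpg` is a hypothetical image path, and the color conversion assumes OpenCV's default BGR loading.

```python
# Sketch only: apply the default RGB training augmentation to one plate image.
import cv2

from fast_plate_ocr.train.data.augmentation import default_train_augmentation

transform = default_train_augmentation("rgb")

# Load a hypothetical plate crop as an (H, W, 3) uint8 RGB array.
image = cv2.cvtColor(cv2.imread("plate.jpg"), cv2.COLOR_BGR2RGB)

# Compose pipelines are called with named arrays and return a dict.
augmented = transform(image=image)["image"]  # same shape, randomly augmented
```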
matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py
@@ -0,0 +1,97 @@
```python
"""
Dataset module.
"""

import math
import os
from typing import Optional, Union

import albumentations as A
import numpy as np
import numpy.typing as npt
import pandas as pd
from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset

from fast_plate_ocr.core.process import read_and_resize_plate_image
from fast_plate_ocr.train.model.config import PlateOCRConfig
from fast_plate_ocr.train.utilities import utils


class PlateRecognitionPyDataset(PyDataset):
    """
    Custom PyDataset for OCR license plate recognition.
    """

    def __init__(
        self,
        annotations_file: Union[str, os.PathLike],
        plate_config: PlateOCRConfig,
        batch_size: int,
        transform: Optional[A.Compose] = None,
        shuffle: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Load annotations
        annotations = pd.read_csv(annotations_file, dtype={"plate_text": str})
        annotations["image_path"] = (
            os.path.dirname(os.path.realpath(annotations_file)) + os.sep + annotations["image_path"]
        )
        # Check that plate lengths do not exceed max_plate_slots.
        assert (annotations["plate_text"].str.len() <= plate_config.max_plate_slots).all(), (
            "Some plates are longer than the configured max_plate_slots. Increase the parameter."
        )
        # Convert the dataframe to a NumPy array
        self.annotations = annotations.to_numpy()

        self.plate_config = plate_config
        self.transform = transform
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Shuffle once at initialization if `shuffle=True`
        self._shuffle_data()

    def __len__(self) -> int:
        return math.ceil(len(self.annotations) / self.batch_size)

    def __getitem__(self, idx: int) -> tuple[npt.NDArray, npt.NDArray]:
        # Determine the indices of the current batch
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.annotations))
        batch = self.annotations[low:high]

        batch_x = []
        batch_y = []
        for image_path, plate_text in batch:
            # Read and process image
            x = read_and_resize_plate_image(
                image_path=image_path,
                img_height=self.plate_config.img_height,
                img_width=self.plate_config.img_width,
                image_color_mode=self.plate_config.image_color_mode,
                keep_aspect_ratio=self.plate_config.keep_aspect_ratio,
                interpolation_method=self.plate_config.interpolation,
                padding_color=self.plate_config.padding_color,
            )
            # Transform target
            y = utils.target_transform(
                plate_text=plate_text,
                max_plate_slots=self.plate_config.max_plate_slots,
                alphabet=self.plate_config.alphabet,
                pad_char=self.plate_config.pad_char,
            )
            # Apply augmentation if provided
            if self.transform:
                x = self.transform(image=x)["image"]
            batch_x.append(x)
            batch_y.append(y)

        return np.array(batch_x), np.array(batch_y)

    def _shuffle_data(self) -> None:
        if self.shuffle:
            np.random.shuffle(self.annotations)

    def on_epoch_begin(self) -> None:
        # Optionally shuffle the dataset at the start of each epoch
        self._shuffle_data()
```
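A hedged sketch of how this dataset would typically be wired up for training follows. `train.csv`, `val.csv`, and `plate_config.yaml` are hypothetical files (the CSV needs the `image_path` and `plate_text` columns read above), and `load_plate_config_from_yaml` comes from the train-side config module shown further below.

```python
# Sketch only: build train/validation datasets for a Keras training loop.
from fast_plate_ocr.train.data.augmentation import default_train_augmentation
from fast_plate_ocr.train.data.dataset import PlateRecognitionPyDataset
from fast_plate_ocr.train.model.config import load_plate_config_from_yaml

plate_config = load_plate_config_from_yaml("plate_config.yaml")

train_ds = PlateRecognitionPyDataset(
    annotations_file="train.csv",
    plate_config=plate_config,
    batch_size=64,
    transform=default_train_augmentation(plate_config.image_color_mode),
    shuffle=True,
)
val_ds = PlateRecognitionPyDataset(
    annotations_file="val.csv",
    plate_config=plate_config,
    batch_size=64,
    transform=None,
    shuffle=False,
)

# One batch: images of shape (B, H, W, C) and the encoded plate targets.
x_batch, y_batch = train_ds[0]
```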
File without changes: matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py
matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py
@@ -0,0 +1,114 @@
```python
"""
License Plate OCR config. This config file defines how license plate images and text should be
preprocessed for OCR model training and inference.
"""

from pathlib import Path
from typing import Annotated, TypeAlias, Union

import annotated_types
import yaml
from pydantic import (
    BaseModel,
    PositiveInt,
    StringConstraints,
    computed_field,
    model_validator,
)

from fast_plate_ocr.core.types import ImageColorMode, ImageInterpolation, PathLike

UInt8: TypeAlias = Annotated[int, annotated_types.Ge(0), annotated_types.Le(255)]
"""
An integer in the range [0, 255], used for color channel values.
"""


class PlateOCRConfig(BaseModel, extra="forbid", frozen=True):
    """
    License Plate OCR model config.
    """

    max_plate_slots: PositiveInt
    """
    Max number of plate slots supported. This represents the number of model classification heads.
    """
    alphabet: str
    """
    The set of all possible characters for the model output.
    """
    pad_char: Annotated[str, StringConstraints(min_length=1, max_length=1)]
    """
    Padding character for plates whose length is smaller than max_plate_slots.
    """
    img_height: PositiveInt
    """
    Image height which is fed to the model.
    """
    img_width: PositiveInt
    """
    Image width which is fed to the model.
    """
    keep_aspect_ratio: bool = False
    """
    Keep the aspect ratio of the input image.
    """
    interpolation: ImageInterpolation = "linear"
    """
    Interpolation method used for resizing the input image.
    """
    image_color_mode: ImageColorMode = "grayscale"
    """
    Input image color mode. Use 'grayscale' for single-channel input or 'rgb' for 3-channel input.
    """
    padding_color: Union[tuple[UInt8, UInt8, UInt8], UInt8] = (114, 114, 114)
    """
    Padding color used when keep_aspect_ratio is True. For grayscale images this should be a single
    integer, and for RGB images it must be a tuple of three integers.
    """

    @computed_field  # type: ignore[misc]
    @property
    def vocabulary_size(self) -> int:
        return len(self.alphabet)

    @computed_field  # type: ignore[misc]
    @property
    def pad_idx(self) -> int:
        return self.alphabet.index(self.pad_char)

    @computed_field  # type: ignore[misc]
    @property
    def num_channels(self) -> int:
        return 3 if self.image_color_mode == "rgb" else 1

    @model_validator(mode="after")
    def check_alphabet_and_pad(self) -> "PlateOCRConfig":
        # `pad_char` must be in the alphabet
        if self.pad_char not in self.alphabet:
            raise ValueError("Pad character must be present in model alphabet.")
        # All chars in the alphabet must be unique
        if len(set(self.alphabet)) != len(self.alphabet):
            raise ValueError("Alphabet must not contain duplicate characters.")
        return self


def load_plate_config_from_yaml(yaml_path: PathLike) -> PlateOCRConfig:
    """
    Reads and parses a YAML file containing the plate configuration.

    Args:
        yaml_path: Path to the YAML file containing the plate config.

    Returns:
        PlateOCRConfig: Parsed and validated plate configuration.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
    """
    if not Path(yaml_path).is_file():
        raise FileNotFoundError(f"Plate config '{yaml_path}' doesn't exist.")
    with open(yaml_path, encoding="utf-8") as f_in:
        yaml_content = yaml.safe_load(f_in)
    config = PlateOCRConfig(**yaml_content)
    return config
```
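To make the validated fields concrete, here is a small sketch (not from the package) that writes a hypothetical plate config to YAML and loads it with the helper above; all field values are made up.

```python
# Sketch only: round-trip a hypothetical plate config through YAML.
import yaml

from fast_plate_ocr.train.model.config import load_plate_config_from_yaml

example_config = {
    "max_plate_slots": 9,
    "alphabet": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_",  # 37 unique characters
    "pad_char": "_",                                       # must appear in the alphabet
    "img_height": 70,
    "img_width": 140,
    "keep_aspect_ratio": False,
    "interpolation": "linear",
    "image_color_mode": "grayscale",
    "padding_color": 114,                                  # single int for grayscale
}

with open("plate_config.yaml", "w", encoding="utf-8") as f_out:
    yaml.safe_dump(example_config, f_out)

cfg = load_plate_config_from_yaml("plate_config.yaml")
print(cfg.vocabulary_size, cfg.pad_idx, cfg.num_channels)  # 37 36 1
```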