matrice-analytics 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrice_analytics/__init__.py +28 -0
- matrice_analytics/boundary_drawing_internal/README.md +305 -0
- matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
- matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
- matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
- matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
- matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
- matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
- matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
- matrice_analytics/post_processing/README.md +455 -0
- matrice_analytics/post_processing/__init__.py +732 -0
- matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
- matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
- matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
- matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
- matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
- matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
- matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
- matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
- matrice_analytics/post_processing/config.py +146 -0
- matrice_analytics/post_processing/core/__init__.py +63 -0
- matrice_analytics/post_processing/core/base.py +704 -0
- matrice_analytics/post_processing/core/config.py +3291 -0
- matrice_analytics/post_processing/core/config_utils.py +925 -0
- matrice_analytics/post_processing/face_reg/__init__.py +43 -0
- matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
- matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
- matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
- matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
- matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
- matrice_analytics/post_processing/ocr/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
- matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
- matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
- matrice_analytics/post_processing/post_processor.py +1175 -0
- matrice_analytics/post_processing/test_cases/__init__.py +1 -0
- matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
- matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
- matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
- matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
- matrice_analytics/post_processing/test_cases/test_config.py +852 -0
- matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
- matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
- matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
- matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
- matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
- matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
- matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
- matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
- matrice_analytics/post_processing/usecases/__init__.py +267 -0
- matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
- matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
- matrice_analytics/post_processing/usecases/age_detection.py +842 -0
- matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
- matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
- matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
- matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
- matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
- matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
- matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
- matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
- matrice_analytics/post_processing/usecases/car_service.py +1601 -0
- matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
- matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
- matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
- matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
- matrice_analytics/post_processing/usecases/color/clip.py +660 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
- matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
- matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
- matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
- matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
- matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
- matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
- matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
- matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
- matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
- matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
- matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
- matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
- matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
- matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
- matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
- matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
- matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
- matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
- matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
- matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
- matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
- matrice_analytics/post_processing/usecases/leaf.py +821 -0
- matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
- matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
- matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
- matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
- matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
- matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
- matrice_analytics/post_processing/usecases/parking.py +787 -0
- matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
- matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
- matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
- matrice_analytics/post_processing/usecases/people_counting.py +706 -0
- matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
- matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
- matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
- matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
- matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
- matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
- matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
- matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
- matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
- matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
- matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
- matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
- matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
- matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
- matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
- matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
- matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
- matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
- matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
- matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
- matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
- matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
- matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
- matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
- matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
- matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
- matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
- matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
- matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
- matrice_analytics/post_processing/utils/__init__.py +150 -0
- matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
- matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
- matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
- matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
- matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
- matrice_analytics/post_processing/utils/color_utils.py +592 -0
- matrice_analytics/post_processing/utils/counting_utils.py +182 -0
- matrice_analytics/post_processing/utils/filter_utils.py +261 -0
- matrice_analytics/post_processing/utils/format_utils.py +293 -0
- matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
- matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
- matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
- matrice_analytics/py.typed +0 -0
- matrice_analytics-0.1.60.dist-info/METADATA +481 -0
- matrice_analytics-0.1.60.dist-info/RECORD +196 -0
- matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
- matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
- matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions for processing model input/output.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from typing import Union
|
|
9
|
+
import cv2
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from fast_plate_ocr.core.types import (
|
|
13
|
+
ImageColorMode,
|
|
14
|
+
ImageInterpolation,
|
|
15
|
+
PaddingColor,
|
|
16
|
+
PathLike,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Keys mirror the ``ImageInterpolation`` literal names; values are the OpenCV flags passed
# as ``interpolation=`` to ``cv2.resize``.
INTERPOLATION_MAP: dict[ImageInterpolation, int] = dict(
    nearest=cv2.INTER_NEAREST,
    linear=cv2.INTER_LINEAR,
    cubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos4=cv2.INTER_LANCZOS4,
)
"""Mapping from interpolation method name to OpenCV constant."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def read_plate_image(
    image_path: PathLike,
    image_color_mode: ImageColorMode = "grayscale",
) -> np.ndarray:
    """
    Load an image file from disk as a NumPy array.

    Any ``image_color_mode`` other than ``"rgb"`` is treated as grayscale.

    Args:
        image_path: Location of the image file (string or path-like).
        image_color_mode: ``"rgb"`` to load three channels in RGB order, otherwise the image
            is loaded as a single grayscale channel. Defaults to ``"grayscale"``.

    Returns:
        A ``(H, W)`` array for grayscale or an ``(H, W, 3)`` RGB array.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If OpenCV cannot decode the file.
    """
    path_str = str(image_path)
    if not os.path.exists(path_str):
        raise FileNotFoundError(f"Image not found: {path_str}")

    want_rgb = image_color_mode == "rgb"
    read_flag = cv2.IMREAD_COLOR if want_rgb else cv2.IMREAD_GRAYSCALE
    decoded = cv2.imread(path_str, read_flag)
    if decoded is None:
        raise ValueError(f"Failed to decode image: {path_str}")

    # cv2.imread returns colour images in BGR order; swap to RGB for callers.
    return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB) if want_rgb else decoded
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def resize_image(
    img: np.ndarray,
    img_height: int,
    img_width: int,
    image_color_mode: ImageColorMode = "grayscale",
    keep_aspect_ratio: bool = False,
    interpolation_method: ImageInterpolation = "linear",
    padding_color: PaddingColor = (114, 114, 114),
) -> np.ndarray:
    """
    Resizes an in-memory image with optional aspect-ratio preservation and padding.

    Args:
        img: Input image.
        img_height: Target image height.
        img_width: Target image width.
        image_color_mode: Output colour mode, ``"grayscale"`` or ``"rgb"``.
        keep_aspect_ratio: If ``True``, maintain the original aspect ratio using letter-box
            padding. Defaults to ``False``.
        interpolation_method: Interpolation method used for resizing. Defaults to ``"linear"``.
        padding_color: Padding colour: a scalar, or a sequence (tuple, list or array) of values.
            For grayscale output only the first element of a sequence is used. For RGB output a
            sequence must have exactly 3 elements and a scalar is repeated 3 times. Defaults to
            ``(114, 114, 114)``.

    Returns:
        The resized image with shape ``(H, W, C)`` (a channel axis is added for grayscale).

    Raises:
        KeyError: If ``interpolation_method`` is not a key of ``INTERPOLATION_MAP``.
        ValueError: If ``padding_color`` length is not 3 for RGB output, or if
            ``image_color_mode`` is neither ``"grayscale"`` nor ``"rgb"`` when padding.
    """
    # pylint: disable=too-many-locals

    interpolation = INTERPOLATION_MAP[interpolation_method]

    if not keep_aspect_ratio:
        img = cv2.resize(img, (img_width, img_height), interpolation=interpolation)

    else:
        orig_h, orig_w = img.shape[:2]
        # Scale ratio (new / old) - the limiting dimension wins so the whole image fits
        r = min(img_height / orig_h, img_width / orig_w)
        # Size of the resized (unpadded) image. Each side is clamped to at least 1 px: with an
        # extreme aspect ratio a side can otherwise round down to 0, which cv2.resize rejects.
        new_unpad_w = max(1, round(orig_w * r))
        new_unpad_h = max(1, round(orig_h * r))
        # Resize if necessary
        if (orig_w, orig_h) != (new_unpad_w, new_unpad_h):
            img = cv2.resize(img, (new_unpad_w, new_unpad_h), interpolation=interpolation)
        # Padding on each side. dw/dh are multiples of 0.5; the -/+0.1 nudges send an odd
        # leftover pixel to the bottom/right so each pair sums to the total padding needed.
        dw, dh = (img_width - new_unpad_w) / 2, (img_height - new_unpad_h) / 2
        top, bottom, left, right = (
            round(dh - 0.1),
            round(dh + 0.1),
            round(dw - 0.1),
            round(dw + 0.1),
        )
        # np.ndim is > 0 for tuples, lists (e.g. values parsed from YAML) and arrays, and 0 for
        # plain scalars, so every sequence type is normalised the same way.
        is_sequence = np.ndim(padding_color) > 0
        border_color: PaddingColor
        # Ensure padding colour matches channel count
        if image_color_mode == "grayscale":
            if is_sequence:
                border_color = int(padding_color[0])  # type: ignore[index]
            else:
                border_color = int(padding_color)  # type: ignore[arg-type]
        elif image_color_mode == "rgb":
            if is_sequence:
                if len(padding_color) != 3:  # type: ignore[arg-type]
                    raise ValueError("padding_color must be length-3 for RGB images")
                border_color = tuple(int(c) for c in padding_color)  # type: ignore[assignment,union-attr]
            else:
                border_color = (int(padding_color),) * 3  # type: ignore[arg-type]
        else:
            raise ValueError(f"Unsupported image_color_mode: {image_color_mode!r}")
        img = cv2.copyMakeBorder(
            img,
            top,
            bottom,
            left,
            right,
            borderType=cv2.BORDER_CONSTANT,
            value=border_color,  # type: ignore[arg-type]
        )
    # Add channel axis for gray so output is HxWxC
    if image_color_mode == "grayscale" and img.ndim == 2:
        img = np.expand_dims(img, axis=-1)

    return img
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def read_and_resize_plate_image(
    image_path: PathLike,
    img_height: int,
    img_width: int,
    image_color_mode: ImageColorMode = "grayscale",
    keep_aspect_ratio: bool = False,
    interpolation_method: ImageInterpolation = "linear",
    padding_color: PaddingColor = (114, 114, 114),
) -> np.ndarray:
    """
    Load an image file and resize it to the model's input size in one step.

    This wraps ``read_plate_image`` followed by ``resize_image``, passing the same colour mode
    to both.

    Args:
        image_path: Location of the image file.
        img_height: Target height in pixels.
        img_width: Target width in pixels.
        image_color_mode: ``"grayscale"`` or ``"rgb"``. Defaults to ``"grayscale"``.
        keep_aspect_ratio: Letter-box instead of stretching when ``True``. Defaults to
            ``False``.
        interpolation_method: Resize interpolation name. Defaults to ``"linear"``.
        padding_color: Fill colour for letter-box borders. Defaults to ``(114, 114, 114)``.

    Returns:
        The resized (and possibly padded) image with shape ``(H, W, C)``.
    """
    loaded = read_plate_image(image_path, image_color_mode=image_color_mode)
    return resize_image(
        loaded,
        img_height=img_height,
        img_width=img_width,
        image_color_mode=image_color_mode,
        keep_aspect_ratio=keep_aspect_ratio,
        interpolation_method=interpolation_method,
        padding_color=padding_color,
    )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def preprocess_image(images: np.ndarray) -> np.ndarray:
    """
    Shape and cast image data into the batch layout the model consumes.

    No pixel normalisation happens here (the model does it); this only guarantees a leading
    batch axis and a ``uint8`` dtype. Values are cast with ``astype``, which always returns a
    new array.

    Args:
        images: A single image ``(H, W, C)`` or a batch ``(N, H, W, C)``.

    Returns:
        A ``uint8`` array of shape ``(N, H, W, C)``.

    Raises:
        ValueError: If ``images`` is neither 3-D nor 4-D.
    """
    rank = images.ndim
    if rank not in (3, 4):
        raise ValueError("Expected input of shape (N, H, W, C).")

    # A lone (H, W, C) image becomes a batch of one.
    batch = images if rank == 4 else np.expand_dims(images, axis=0)
    return batch.astype(np.uint8)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def postprocess_output(
    model_output: np.ndarray,
    max_plate_slots: int,
    model_alphabet: str,
    return_confidence: bool = False,
) -> Union[list[str], tuple[list[str], np.ndarray]]:
    """
    Decodes model predictions into licence-plate strings.

    Each character slot is decoded independently by taking the highest-scoring alphabet entry.
    Padding characters (if part of the alphabet) are kept in the returned strings.

    Args:
        model_output: Raw output tensor from the model. Its total size must be divisible by
            ``max_plate_slots * len(model_alphabet)``; it is reshaped to
            ``(N, max_plate_slots, len(model_alphabet))``.
        max_plate_slots: Maximum number of character positions.
        model_alphabet: Alphabet used by the model (one character per class index).
        return_confidence: If ``True``, also return per-character confidence scores.
            Defaults to ``False``.

    Returns:
        If ``return_confidence`` is ``False``: a list of decoded plate strings.
        If ``True``: a two-tuple ``(plates, probs)`` where

        * ``plates`` is the list of decoded strings, and
        * ``probs`` is an array of shape ``(N, max_plate_slots)`` with the corresponding
          confidence scores (the maximum score of each slot).

        An empty batch (``N == 0``) yields an empty list (and an empty ``probs`` array).
    """
    predictions = model_output.reshape((-1, max_plate_slots, len(model_alphabet)))
    prediction_indices = np.argmax(predictions, axis=-1)
    alphabet_array = np.array(list(model_alphabet))
    plate_chars = alphabet_array[prediction_indices]
    # A plain comprehension instead of np.apply_along_axis: apply_along_axis raises ValueError
    # when the batch dimension is 0, whereas this simply returns [].
    plates: list[str] = ["".join(row) for row in plate_chars.tolist()]
    if return_confidence:
        probs = np.max(predictions, axis=-1)
        return plates, probs
    return plates
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Common custom types used across the lib.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from collections.abc import Sequence
|
|
9
|
+
from typing import Literal, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from numpy import typing as npt
|
|
13
|
+
|
|
14
|
+
# ``collections.abc.Sequence`` only became subscriptable in Python 3.9 (PEP 585), so evaluating
# ``Sequence[ImgLike]`` below at import time raises TypeError on Python 3.8. ``typing.Sequence``
# is subscriptable on every supported version.
from typing import Sequence as _TypingSequence  # pylint: disable=wrong-import-position

ImageInterpolation = Literal["nearest", "linear", "cubic", "area", "lanczos4"]
"""Interpolation method used for resizing the input image."""
ImageColorMode = Literal["grayscale", "rgb"]
"""
Input image color mode. Use ``grayscale`` for single-channel input or ``rgb`` for 3-channel input.
"""
PaddingColor = Union[Tuple[int, int, int], int]
"""Padding colour for letterboxing (only used when keeping image aspect ratio)."""
PathLike = Union[str, os.PathLike]
"""Path-like objects."""
ImgLike = Union[PathLike, npt.NDArray[np.uint8]]
"""Image-like objects, including paths to image files and NumPy arrays of images."""
BatchOrImgLike = Union[ImgLike, _TypingSequence[ImgLike]]
"""
Image-like objects, including paths to image files and NumPy arrays of images, or a batch of images.
"""
BatchArray = npt.NDArray[np.uint8]
"""Numpy array of images, representing a batch of images."""
TensorDataFormat = Literal["channels_last", "channels_first"]
"""
Data format of the input tensor. It can be either ``channels_last`` or ``channels_first``.
``channels_last`` corresponds to inputs with shape ``(batch, height, width, channels)``, while
``channels_first`` corresponds to inputs with shape ``(batch, channels, height, width)``.
"""
KerasDtypes = Literal[
    "float16",
    "float32",
    "float64",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
    "int8",
    "int16",
    "int32",
    "int64",
    "bfloat16",
    "bool",
    "string",
    "float8_e4m3fn",
    "float8_e5m2",
    "complex64",
    "complex128",
]
"""
Keras data types supported by the library.
"""
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Common utilities used across the package.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import time
|
|
10
|
+
from collections.abc import Callable, Iterator
|
|
11
|
+
from contextlib import contextmanager
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import IO, Any, Optional, Union
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@contextmanager
def log_time_taken(process_name: str) -> Iterator[None]:
    """
    Time the enclosed block and emit the duration as an INFO log record.

    The record is logged even if the block raises; the exception still propagates.

    Usage:
        ```python
        with log_time_taken("process_name"):
            # Code snippet to be timed
        ```

    Args:
        process_name: Label included in the log message.
    """
    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms: float = (time.perf_counter() - started) * 1_000
        logging.getLogger(__name__).info(
            "Computation time of '%s' = %.3fms", process_name, elapsed_ms
        )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@contextmanager
def measure_time() -> Iterator[Callable[[], float]]:
    """
    A context manager for measuring execution time (in milliseconds) within its code block.

    Usage:
        ```python
        with measure_time() as timer:
            # Code snippet to be timed
        print(f"Code took: {timer()} ms")
        ```

    Returns:
        A function that returns the elapsed time in milliseconds. While the block is still
        running it returns ``0.0``. Once the block exits, whether normally or by raising, it
        returns the block's duration.
    """
    start_time = end_time = time.perf_counter()
    try:
        # The lambda reads ``end_time`` through its closure, so it sees the value assigned in
        # ``finally`` once the block has finished.
        yield lambda: (end_time - start_time) * 1_000
    finally:
        # Runs on exceptions too, so a caller that catches the error can still read the timer.
        end_time = time.perf_counter()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@contextmanager
def safe_write(
    file: Union[str, os.PathLike[str]],
    mode: str = "wb",
    encoding: Optional[str] = None,
    **kwargs: Any,
) -> Iterator[IO]:
    """
    Context manager for safe file writing.

    Opens the specified file for writing and yields a file object. If an exception is raised
    inside the ``with`` block, the file is closed and removed, and then the exception is
    re-raised. If ``open()`` itself fails, its exception propagates and nothing is removed. A
    pre-existing file (e.g. the one that makes mode ``"x"`` fail) is therefore left untouched.

    Note: the cleanup removes the whole path, so in append mode a failure also discards content
    that existed before the call.

    Args:
        file: Path to the file to write.
        mode: File open mode (e.g. ``"wb"``, ``"w"``, etc.). Defaults to ``"wb"``.
        encoding: Encoding to use (for text modes). Ignored in binary mode.
        **kwargs: Additional arguments passed to ``open()``.

    Returns:
        A writable file object.
    """
    # Open outside the try: if opening fails we have not created or written anything, so there
    # is nothing to clean up, and unlinking the path could delete an unrelated existing file.
    handle = open(file, mode, encoding=encoding, **kwargs)  # pylint: disable=consider-using-with
    try:
        # ``with`` closes the handle before the except clause unlinks the path.
        with handle:
            yield handle
    except Exception:
        Path(file).unlink(missing_ok=True)
        # Bare ``raise`` re-raises the original exception unchanged.
        raise
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model config reading/parsing for doing inference.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Union
|
|
9
|
+
import yaml
|
|
10
|
+
|
|
11
|
+
from fast_plate_ocr.core.types import ImageColorMode, ImageInterpolation, PathLike
|
|
12
|
+
|
|
13
|
+
# pylint: disable=duplicate-code
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class PlateOCRConfig:  # pylint: disable=too-many-instance-attributes
    """
    Plate OCR Config used for inference.

    This dataclass is used to read and parse the config file used for training the OCR model.
    We prefer to keep the inference package with minimal dependencies and avoid using Pydantic here.
    """

    max_plate_slots: int
    """
    Max number of plate slots supported. This represents the number of model classification heads.
    """
    alphabet: str
    """
    The set of all possible characters the model can output.
    """
    pad_char: str
    """
    Padding character for plates whose length is smaller than MAX_PLATE_SLOTS.
    """
    img_height: int
    """
    Image height which is fed to the model.
    """
    img_width: int
    """
    Image width which is fed to the model.
    """
    keep_aspect_ratio: bool = False
    """
    Keep aspect ratio of the input image.
    """
    interpolation: ImageInterpolation = "linear"
    """
    Interpolation method used for resizing the input image.
    """
    image_color_mode: ImageColorMode = "grayscale"
    """
    Input image color mode. Use 'grayscale' for single-channel input or 'rgb' for 3-channel input.
    """
    padding_color: Union[tuple[int, int, int], int] = (114, 114, 114)
    """
    Padding color used when keep_aspect_ratio is True. For grayscale images, this should be a single
    integer and for RGB images, this must be a tuple of three integers.
    """

    @property
    def vocabulary_size(self) -> int:
        """Number of characters in ``alphabet``."""
        return len(self.alphabet)

    @property
    def pad_idx(self) -> int:
        """Index of ``pad_char`` within ``alphabet`` (``ValueError`` if it is not present)."""
        return self.alphabet.index(self.pad_char)

    @property
    def num_channels(self) -> int:
        """Number of input channels: 3 when ``image_color_mode`` is ``'rgb'``, otherwise 1."""
        return 3 if self.image_color_mode == "rgb" else 1

    @classmethod
    def from_yaml(cls, path: PathLike) -> "PlateOCRConfig":
        """
        Read and parse a yaml containing the Plate OCR config.

        Raises:
            TypeError: If the YAML document is not a mapping (e.g. an empty file), or it has
                unknown or missing config keys.
        """
        with open(path, encoding="utf-8") as f_in:
            data = yaml.safe_load(f_in)
        return cls._from_dict(data, source=str(path))

    @classmethod
    def _from_dict(cls, data: object, source: str = "<dict>") -> "PlateOCRConfig":
        """
        Build a config from an already-parsed mapping of field names to values.

        YAML has no tuple type, so a list ``padding_color`` is converted to a tuple. This keeps
        the value consistent with its annotation and keeps the frozen dataclass hashable.
        """
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a mapping of config fields in {source}, got {type(data).__name__}"
            )
        kwargs = dict(data)  # copy so the caller's mapping is not mutated
        if isinstance(kwargs.get("padding_color"), list):
            kwargs["padding_color"] = tuple(kwargs["padding_color"])
        return cls(**kwargs)
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions used for doing inference with the OCR models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import pathlib
|
|
9
|
+
import shutil
|
|
10
|
+
import urllib.request
|
|
11
|
+
from http import HTTPStatus
|
|
12
|
+
from typing import Literal, Tuple, Dict, Optional
|
|
13
|
+
|
|
14
|
+
from tqdm.asyncio import tqdm
|
|
15
|
+
|
|
16
|
+
from fast_plate_ocr.core.utils import safe_write
|
|
17
|
+
|
|
18
|
+
BASE_URL: str = "https://github.com/ankandrew/cnn-ocr-lp/releases/download"
OcrModel = Literal[
    "cct-s-v1-global-model",
    "cct-xs-v1-global-model",
    "cct-s-relu-v1-global-model",
    "cct-xs-relu-v1-global-model",
    "argentinian-plates-cnn-model",
    "argentinian-plates-cnn-synth-model",
    "european-plates-mobile-vit-v2-model",
    "global-plates-mobile-vit-v2-model",
]
"""Available OCR models for doing inference."""


# Every artifact is published under the same "arg-plates" release tag, so only the file
# names differ per model: (ONNX weights file, plate config YAML file).
_RELEASE_URL: str = f"{BASE_URL}/arg-plates"
_MODEL_ARTIFACTS: Dict[OcrModel, Tuple[str, str]] = {
    "cct-s-v1-global-model": ("cct_s_v1_global.onnx", "cct_s_v1_global_plate_config.yaml"),
    "cct-xs-v1-global-model": ("cct_xs_v1_global.onnx", "cct_xs_v1_global_plate_config.yaml"),
    "argentinian-plates-cnn-model": ("arg_cnn_ocr.onnx", "arg_cnn_ocr_config.yaml"),
    # The synthetic-data variant reuses the regular Argentinian CNN config file.
    "argentinian-plates-cnn-synth-model": ("arg_cnn_ocr_synth.onnx", "arg_cnn_ocr_config.yaml"),
    "european-plates-mobile-vit-v2-model": (
        "european_mobile_vit_v2_ocr.onnx",
        "european_mobile_vit_v2_ocr_config.yaml",
    ),
    "global-plates-mobile-vit-v2-model": (
        "global_mobile_vit_v2_ocr.onnx",
        "global_mobile_vit_v2_ocr_config.yaml",
    ),
    "cct-s-relu-v1-global-model": (
        "cct_s_relu_v1_global.onnx",
        "cct_s_relu_v1_global_plate_config.yaml",
    ),
    "cct-xs-relu-v1-global-model": (
        "cct_xs_relu_v1_global.onnx",
        "cct_xs_relu_v1_global_plate_config.yaml",
    ),
}

AVAILABLE_ONNX_MODELS: Dict[OcrModel, Tuple[str, str]] = {
    model: (f"{_RELEASE_URL}/{onnx_name}", f"{_RELEASE_URL}/{config_name}")
    for model, (onnx_name, config_name) in _MODEL_ARTIFACTS.items()
}
"""Dictionary of available OCR models and their URLs."""

MODEL_CACHE_DIR: pathlib.Path = pathlib.Path.home().joinpath(".cache", "fast-plate-ocr")
"""Default location where models will be stored."""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _download_with_progress(url: str, filename: pathlib.Path) -> None:
    """
    Download utility function with progress bar.

    The payload is streamed into a sibling ``<name>.part`` file, which is moved onto
    ``filename`` only after the transfer has completed. An interrupted download therefore
    never leaves a truncated file at the final location. ``download_model`` treats any
    existing file there as a complete cached artifact and would otherwise skip re-downloading it.

    :param url: URL of the model to download.
    :param filename: Where to save the OCR model.
    :raises ValueError: If the server responds with a status other than 200 OK.
    """
    tmp_filename = filename.with_name(filename.name + ".part")
    with urllib.request.urlopen(url) as response, safe_write(tmp_filename, mode="wb") as out_file:
        # ``status`` replaces the deprecated ``getcode()``; same value.
        if response.status != HTTPStatus.OK:
            raise ValueError(f"Failed to download file from {url}. Status code: {response.status}")

        file_size = int(response.headers.get("Content-Length", 0))
        desc = f"Downloading {filename.name}"

        with tqdm.wrapattr(out_file, "write", total=file_size, desc=desc) as f_out:
            shutil.copyfileobj(response, f_out)

    # Both paths share a directory, so this is an atomic rename that overwrites any old file.
    tmp_filename.replace(filename)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def download_model(
    model_name: OcrModel,
    save_directory: Optional[pathlib.Path] = None,
    force_download: bool = False,
) -> tuple[pathlib.Path, pathlib.Path]:
    """
    Download an OCR model and the config to a given directory.

    Args:
        model_name: Which model to download.
        save_directory: Directory to save the OCR model. It should point to a folder.
            A ``str`` path is accepted too and converted to ``pathlib.Path``. If not
            supplied, this will point to '~/.cache/fast-plate-ocr/<model_name>'.
        force_download: Force and download the model if it already exists in
            `save_directory`.

    Returns:
        A tuple consisting of (model_downloaded_path, config_downloaded_path).

    Raises:
        ValueError: If ``model_name`` is not a known model, or ``save_directory`` points
            to an existing file.
    """
    if model_name not in AVAILABLE_ONNX_MODELS:
        available_models = ", ".join(AVAILABLE_ONNX_MODELS.keys())
        raise ValueError(f"Unknown model {model_name}. Use one of [{available_models}]")

    if save_directory is None:
        save_directory = MODEL_CACHE_DIR / model_name
    else:
        # Normalise str input so the pathlib calls below work.
        save_directory = pathlib.Path(save_directory)
        if save_directory.is_file():
            raise ValueError(f"Expected a directory, but got {save_directory}")

    save_directory.mkdir(parents=True, exist_ok=True)

    model_url, plate_config_url = AVAILABLE_ONNX_MODELS[model_name]
    # Local file names are the last path segment of each release URL.
    model_filename = save_directory / model_url.split("/")[-1]
    plate_config_filename = save_directory / plate_config_url.split("/")[-1]

    if not force_download and model_filename.is_file() and plate_config_filename.is_file():
        logging.info(
            "Skipping download of '%s' model, already exists at %s",
            model_name,
            save_directory,
        )
        return model_filename, plate_config_filename

    # Download the model if not present or if we want to force the download
    if force_download or not model_filename.is_file():
        logging.info("Downloading model to %s", model_filename)
        _download_with_progress(url=model_url, filename=model_filename)

    # Same for the config
    if force_download or not plate_config_filename.is_file():
        logging.info("Downloading config to %s", plate_config_filename)
        _download_with_progress(url=plate_config_url, filename=plate_config_filename)

    return model_filename, plate_config_filename
|