matrice-analytics 0.1.3__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of matrice-analytics might be problematic. Click here for more details.
- matrice_analytics/post_processing/advanced_tracker/matching.py +3 -3
- matrice_analytics/post_processing/advanced_tracker/strack.py +1 -1
- matrice_analytics/post_processing/config.py +4 -0
- matrice_analytics/post_processing/core/config.py +115 -12
- matrice_analytics/post_processing/face_reg/compare_similarity.py +5 -5
- matrice_analytics/post_processing/face_reg/embedding_manager.py +109 -8
- matrice_analytics/post_processing/face_reg/face_recognition.py +157 -61
- matrice_analytics/post_processing/face_reg/face_recognition_client.py +339 -88
- matrice_analytics/post_processing/face_reg/people_activity_logging.py +67 -29
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
- matrice_analytics/post_processing/ocr/postprocessing.py +0 -1
- matrice_analytics/post_processing/post_processor.py +32 -11
- matrice_analytics/post_processing/usecases/color/clip.py +42 -8
- matrice_analytics/post_processing/usecases/color/color_mapper.py +2 -2
- matrice_analytics/post_processing/usecases/color_detection.py +50 -129
- matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +41 -386
- matrice_analytics/post_processing/usecases/flare_analysis.py +1 -56
- matrice_analytics/post_processing/usecases/license_plate_detection.py +476 -202
- matrice_analytics/post_processing/usecases/license_plate_monitoring.py +351 -26
- matrice_analytics/post_processing/usecases/people_counting.py +408 -1431
- matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
- matrice_analytics/post_processing/usecases/vehicle_monitoring.py +39 -10
- matrice_analytics/post_processing/utils/__init__.py +8 -8
- {matrice_analytics-0.1.3.dist-info → matrice_analytics-0.1.32.dist-info}/METADATA +1 -1
- {matrice_analytics-0.1.3.dist-info → matrice_analytics-0.1.32.dist-info}/RECORD +61 -26
- {matrice_analytics-0.1.3.dist-info → matrice_analytics-0.1.32.dist-info}/WHEEL +0 -0
- {matrice_analytics-0.1.3.dist-info → matrice_analytics-0.1.32.dist-info}/licenses/LICENSE.txt +0 -0
- {matrice_analytics-0.1.3.dist-info → matrice_analytics-0.1.32.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Script for exporting the trained Keras models to other formats.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import pathlib
|
|
9
|
+
import shutil
|
|
10
|
+
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
|
11
|
+
|
|
12
|
+
import click
|
|
13
|
+
import keras
|
|
14
|
+
import numpy as np
|
|
15
|
+
from numpy.typing import DTypeLike
|
|
16
|
+
|
|
17
|
+
from fast_plate_ocr.cli.utils import requires
|
|
18
|
+
from fast_plate_ocr.core.types import TensorDataFormat
|
|
19
|
+
from fast_plate_ocr.core.utils import log_time_taken
|
|
20
|
+
from fast_plate_ocr.train.model.config import (
|
|
21
|
+
PlateOCRConfig,
|
|
22
|
+
load_plate_config_from_yaml,
|
|
23
|
+
)
|
|
24
|
+
from fast_plate_ocr.train.utilities.utils import load_keras_model
|
|
25
|
+
from typing import Optional
|
|
26
|
+
|
|
27
|
+
# ruff: noqa: PLC0415
|
|
28
|
+
# pylint: disable=too-many-arguments,too-many-locals,import-outside-toplevel
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _dummy_input(b: int, h: int, w: int, n_c: int, dtype: DTypeLike = np.uint8) -> np.ndarray:
    """
    Build a random image-like tensor used to validate exported models.

    Values are integers drawn uniformly from [0, 255] and then cast to ``dtype``.

    :param b: Batch size.
    :param h: Image height.
    :param w: Image width.
    :param n_c: Number of channels.
    :param dtype: Dtype of the returned array (defaults to ``np.uint8``).
    :return: Array of shape ``(b, h, w, n_c)`` (channels-last layout).
    """
    return np.random.randint(0, 256, size=(b, h, w, n_c)).astype(dtype)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _validate_prediction(
    keras_model: keras.Model,
    exported_predict,
    x: np.ndarray,
    target: str,
    rtol: float = 1e-4,
    atol: float = 1e-4,
) -> None:
    """
    Compare one forward pass of the Keras model against the exported backend.

    A warning is logged when the two outputs are not element-wise close according to
    ``np.allclose`` with the given tolerances; otherwise an info message confirms the match.

    :param keras_model: Reference Keras model; its ``predict`` is called with ``verbose=0``.
    :param exported_predict: Callable that takes ``x`` and returns the exported model's output.
    :param x: Input batch fed to both models.
    :param target: Backend name used in the log messages (upper-cased).
    :param rtol: Relative tolerance for the comparison.
    :param atol: Absolute tolerance for the comparison.
    """
    label = target.upper()
    reference = keras_model.predict(x, verbose=0)
    candidate = exported_predict(x)

    outputs_match = np.allclose(reference, candidate, rtol=rtol, atol=atol)
    if outputs_match:
        logging.info("%s output matches Keras ✔", label)
        return
    logging.warning("%s output deviates from Keras beyond tolerance.", label)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _make_output_path(
    model_path: pathlib.Path, new_ext: str, save_dir: Optional[pathlib.Path] = None
) -> pathlib.Path:
    """
    Build the output path for an exported model and clear anything already at that path.

    The filename is ``model_path`` with its suffix replaced by ``new_ext`` (e.g. ``model.keras``
    -> ``model.onnx``). It is placed next to the model, or inside ``save_dir`` when one is given.

    Note: If something already exists at the output path it is deleted. Real directories (e.g. a
    previous ``.mlpackage``) are removed recursively. Files and symlinks are unlinked, including
    symlinks that point to directories and dangling symlinks.

    :param model_path: Path to the model file.
    :param new_ext: Extension (including the leading dot) that replaces the model's suffix.
    :param save_dir: Directory to save the exported model. Defaults to the model's directory.
    :return: Path to the output file.
    """
    out_file = model_path.with_suffix(new_ext)
    if save_dir is not None:
        out_file = save_dir / out_file.name

    # `exists()` follows symlinks, so a dangling symlink would otherwise be left in place.
    if out_file.exists() or out_file.is_symlink():
        logging.info("Overwriting existing %s", out_file)
        # `shutil.rmtree` refuses to operate on a symlink, so only recurse into real directories.
        if out_file.is_dir() and not out_file.is_symlink():
            shutil.rmtree(out_file)
        else:
            out_file.unlink()

    return out_file
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _prepare_model_for_onnx_export(
    model: keras.Model,
    plate_config: PlateOCRConfig,
    dynamic_batch: bool,
    input_dtype: str,
    data_format: TensorDataFormat,
):
    """
    Get a Keras model ready for ONNX export, adapting its input layout when requested.

    With ``data_format == "channels_first"`` the model is wrapped in a new ``keras.Model`` whose
    NxCxHxW input goes through a ``Permute((2, 3, 1))`` layer before reaching the original model.
    The original model expects NxHxWxC. For any other value the original model is returned as-is.

    :param model: Keras model to export.
    :param plate_config: OCR config providing ``img_height``, ``img_width`` and ``num_channels``.
    :param dynamic_batch: When True the batch dimension of the spec shape is ``None``, else ``1``.
    :param input_dtype: Dtype for the export input and for the returned sample tensor.
    :param data_format: ``"channels_first"`` to wrap the model; anything else keeps it unchanged.
    :return: Tuple ``(export_model, spec_shape, dummy_input)``. ``spec_shape`` is
        ``(batch_dim, *input_shape)``. ``dummy_input`` is a random ``(1, *input_shape)`` array
        with values in [0, 255] cast to ``input_dtype``.
    """
    height = plate_config.img_height
    width = plate_config.img_width
    channels = plate_config.num_channels

    if data_format != "channels_first":
        # Channels-last (NxHxWxC) is what the model already consumes: no wrapper needed.
        input_shape = (height, width, channels)
        wrapped = model
    else:
        # Accept NxCxHxW and permute it to NxHxWxC before calling the original model.
        input_shape = (channels, height, width)
        nchw_input = keras.Input(shape=input_shape, dtype=input_dtype, name="input_nchw")
        to_nhwc = keras.layers.Permute((2, 3, 1))
        wrapped = keras.Model(nchw_input, model(to_nhwc(nchw_input)), name=f"{model.name}_nchw")

    leading_dim = 1 if not dynamic_batch else None
    sample = np.random.randint(0, 256, size=(1,) + input_shape).astype(input_dtype)
    return wrapped, (leading_dim,) + input_shape, sample
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@requires("onnx", "onnxruntime", "onnxslim")
def export_onnx(
    model: keras.Model,
    plate_config: PlateOCRConfig,
    out_file: pathlib.Path,
    simplify: bool,
    dynamic_batch: bool,
    skip_validation: bool = False,
    onnx_input_dtype: str = "uint8",
    onnx_data_format: TensorDataFormat = "channels_last",
) -> None:
    """
    Export a Keras model to ONNX, optionally simplify it, and sanity-check the result.

    :param model: Keras model to export.
    :param plate_config: OCR config providing the input image height, width and channel count.
    :param out_file: Destination ``.onnx`` path.
    :param simplify: When True, run the exported graph through ``onnxslim`` before saving.
    :param dynamic_batch: When True the exported input has a dynamic (``None``) batch dimension.
    :param skip_validation: When True, skip the Keras-vs-ONNX output comparison.
    :param onnx_input_dtype: Dtype of the ONNX model input (e.g. ``"uint8"``/``"float32"``).
    :param onnx_data_format: ``"channels_last"`` (NHWC) or ``"channels_first"`` (NCHW) input.
    """
    import onnxruntime as rt

    export_model, spec_shape, dummy_input = _prepare_model_for_onnx_export(
        model, plate_config, dynamic_batch, onnx_input_dtype, onnx_data_format
    )
    spec = [keras.InputSpec(name="input", shape=spec_shape, dtype=onnx_input_dtype)]

    # Use a temporary *directory* instead of an open NamedTemporaryFile. On Windows an open
    # NamedTemporaryFile cannot be reopened by name, which would break the Keras exporter.
    with TemporaryDirectory() as tmp_dir:
        tmp_path = pathlib.Path(tmp_dir) / "model.onnx"
        export_model.export(str(tmp_path), format="onnx", verbose=False, input_signature=spec)

        if simplify:
            import onnx
            import onnxslim

            logging.info("Simplifying ONNX ...")
            model_simp = onnxslim.slim(onnx.load(str(tmp_path)))
            onnx.save(model_simp, str(out_file))
        else:
            shutil.copy(tmp_path, out_file)

    # Load the newly converted ONNX model. The path is passed as `str` because older onnxruntime
    # releases reject `pathlib.Path`. Providers are explicit because GPU-enabled builds raise when
    # none are given. CPU is enough for a correctness check.
    sess = rt.InferenceSession(str(out_file), providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name
    output_names = [o.name for o in sess.get_outputs()]

    def _predict(x: np.ndarray):
        # Only the first output is compared against Keras.
        return sess.run(output_names, {input_name: x})[0]

    if skip_validation:
        logging.info("Skipping ONNX validation.")
    else:
        _validate_prediction(export_model, _predict, dummy_input, "ONNX")

    with log_time_taken("ONNX inference time"):
        _predict(dummy_input)

    logging.info("Saved ONNX model to %s", out_file)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@requires("tensorflow")
def export_tflite(
    model: keras.Model,
    plate_config: PlateOCRConfig,
    out_file: pathlib.Path,
    skip_validation: bool = False,
) -> None:
    """
    Export a Keras model to TFLite via a temporary SavedModel, then sanity-check the result.

    :param model: Keras model to export.
    :param plate_config: OCR config providing the image height, width and channel count used to
        build the validation input.
    :param out_file: Destination ``.tflite`` path; the converted flatbuffer bytes are written here.
    :param skip_validation: When True, skip the Keras-vs-TFLite output comparison.
    """
    import tensorflow as tf

    with TemporaryDirectory() as tmp_dir:
        model.export(tmp_dir, format="tf_saved_model")

        converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

        tflite_bytes = converter.convert()
        out_file.write_bytes(tflite_bytes)

    if skip_validation:
        logging.info("Skipping TFLite validation.")
        logging.info("Saved TFLite model to %s", out_file)
        return

    class _TFLiteRunner:
        """Callable wrapper around a TFLite interpreter using its first input and first output."""

        def __init__(self, path):
            self.interp = tf.lite.Interpreter(model_path=str(path))
            self.interp.allocate_tensors()
            inp_details = self.interp.get_input_details()[0]
            self.inp = inp_details["index"]
            self.inp_dtype = inp_details["dtype"]
            self.out = self.interp.get_output_details()[0]["index"]

        def __call__(self, x: np.ndarray):
            # `set_tensor` rejects arrays whose dtype differs from the model input (e.g. a
            # uint8-input model fed float32 data). The validation values are integers in
            # [0, 255], so casting is lossless.
            self.interp.set_tensor(self.inp, np.asarray(x, dtype=self.inp_dtype))
            self.interp.invoke()
            return self.interp.get_tensor(self.out)

    tfl_runner = _TFLiteRunner(out_file)
    _validate_prediction(
        model,
        tfl_runner,
        _dummy_input(
            1,
            plate_config.img_height,
            plate_config.img_width,
            plate_config.num_channels,
            np.float32,
        ),
        "TFLite",
        # Looser than the default 1e-4. `Optimize.DEFAULT` presumably quantizes the model,
        # so some numeric drift is expected.
        atol=5e-3,
        rtol=5e-3,
    )
    logging.info("Saved TFLite model to %s", out_file)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@requires("coremltools", "tensorflow")
def export_coreml(
    model: keras.Model,
    plate_config: PlateOCRConfig,
    out_file: pathlib.Path,
    skip_validation: bool = False,
) -> None:
    """
    Export a Keras model to a CoreML ``mlprogram`` via a temporary SavedModel.

    Validation against Keras runs only on macOS, because coremltools supports ``MLModel.predict``
    only there. On other platforms a warning is logged and validation is skipped.

    :param model: Keras model to export.
    :param plate_config: OCR config providing the fixed input shape
        ``(1, img_height, img_width, num_channels)``.
    :param out_file: Destination ``.mlpackage`` path.
    :param skip_validation: When True, skip the Keras-vs-CoreML output comparison.
    """
    import platform

    import coremltools as ct
    import tensorflow as tf

    input_shape = (
        1,
        plate_config.img_height,
        plate_config.img_width,
        plate_config.num_channels,
    )

    with TemporaryDirectory() as tmp_dir:
        model.export(tmp_dir, format="tf_saved_model")
        loaded = tf.saved_model.load(tmp_dir)
        func = loaded.signatures["serving_default"]

        ct_inputs = [ct.TensorType(shape=input_shape, dtype=np.float32)]
        mlmodel = ct.convert(
            [func],
            source="tensorflow",
            convert_to="mlprogram",
            inputs=ct_inputs,
        )
        mlmodel.save(str(out_file))

    if skip_validation:
        logging.info("Skipping CoreML validation.")
    elif platform.system() != "Darwin":
        # The model is already saved; don't fail the export just because we can't run it here.
        logging.warning(
            "CoreML validation requires macOS; skipping validation on %s.", platform.system()
        )
    else:
        reloaded = ct.models.MLModel(str(out_file))
        spec = reloaded.get_spec()
        input_name = spec.description.input[0].name
        output_name = spec.description.output[0].name

        def _predict(x: np.ndarray):
            # Only the first declared output is compared against Keras.
            return reloaded.predict({input_name: x})[output_name]

        _validate_prediction(
            model,
            _predict,
            _dummy_input(
                1,
                plate_config.img_height,
                plate_config.img_width,
                plate_config.num_channels,
                np.float32,
            ),
            "CoreML",
        )

    logging.info("Saved CoreML model to %s", out_file)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@click.command(context_settings={"max_content_width": 120})
@click.option(
    "-m",
    "--model",
    "model_path",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the saved .keras model.",
)
@click.option(
    "-f",
    "--format",
    "export_format",
    type=click.Choice(["onnx", "tflite", "coreml"], case_sensitive=False),
    default="onnx",
    show_default=True,
    help="Target export format.",
)
@click.option(
    "--simplify/--no-simplify",
    default=True,
    show_default=True,
    help="Simplify ONNX model using onnxslim (only applies when format is ONNX).",
)
@click.option(
    "--plate-config-file",
    required=True,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path),
    help="Path to the model OCR config YAML.",
)
@click.option(
    "--save-dir",
    required=False,
    type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=pathlib.Path),
    help="Directory to save the exported model. Defaults to model's directory.",
)
@click.option(
    "--dynamic-batch/--no-dynamic-batch",
    default=True,
    show_default=True,
    help="Enable dynamic batch size (only applies to ONNX format).",
)
@click.option(
    "--skip-validation/--no-skip-validation",
    default=False,
    show_default=True,
    help="Skip the post-export inference validation step.",
)
@click.option(
    "--onnx-input-dtype",
    type=click.Choice(["uint8", "float32"], case_sensitive=False),
    default="uint8",
    show_default=True,
    help="Data type of the ONNX model input.",
)
@click.option(
    "--onnx-data-format",
    type=click.Choice(["channels_last", "channels_first"], case_sensitive=False),
    default="channels_last",
    show_default=True,
    help=(
        "Data format of the input tensor. It can be either "
        "'channels_last' (NHWC) or 'channels_first' (NCHW)."
    ),
)
def export(  # noqa: PLR0913
    model_path: pathlib.Path,
    export_format: str,
    simplify: bool,
    plate_config_file: pathlib.Path,
    save_dir: Optional[pathlib.Path],
    dynamic_batch: bool,
    skip_validation: bool,
    onnx_input_dtype: str,
    onnx_data_format: TensorDataFormat,
) -> None:
    """
    Export Keras models to other formats.
    """
    # NOTE: click uses the docstring above as the command's --help text, so keep it user-facing.
    # Implementation: load the plate config and Keras model, pick the output path for the chosen
    # format (existing files there are deleted), then dispatch to the format-specific exporter.
    plate_config = load_plate_config_from_yaml(plate_config_file)
    model = load_keras_model(model_path, plate_config)

    if export_format == "onnx":
        out_file = _make_output_path(model_path, ".onnx", save_dir)
        export_onnx(
            model=model,
            plate_config=plate_config,
            out_file=out_file,
            simplify=simplify,
            dynamic_batch=dynamic_batch,
            skip_validation=skip_validation,
            onnx_input_dtype=onnx_input_dtype,
            onnx_data_format=onnx_data_format,
        )
    elif export_format == "tflite":
        out_file = _make_output_path(model_path, ".tflite", save_dir)
        # TFLite doesn't seem to support dynamic batch size
        # See: https://ai.google.dev/edge/litert/inference#run-inference
        export_tflite(
            model=model,
            plate_config=plate_config,
            out_file=out_file,
            # Forward the flag so --skip-validation is honored for TFLite like the other formats.
            skip_validation=skip_validation,
        )
    elif export_format == "coreml":
        out_file = _make_output_path(model_path, ".mlpackage", save_dir)
        export_coreml(
            model=model,
            plate_config=plate_config,
            out_file=out_file,
            skip_validation=skip_validation,
        )


if __name__ == "__main__":
    export()
|