edgefirst-validator 4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. deepview/modelpack/utils/argmax.py +16 -0
  2. edgefirst/validator/__init__.py +1 -0
  3. edgefirst/validator/__main__.py +375 -0
  4. edgefirst/validator/datasets/__init__.py +118 -0
  5. edgefirst/validator/datasets/cache.py +296 -0
  6. edgefirst/validator/datasets/core.py +250 -0
  7. edgefirst/validator/datasets/darknet.py +446 -0
  8. edgefirst/validator/datasets/database.py +1067 -0
  9. edgefirst/validator/datasets/instance/__init__.py +4 -0
  10. edgefirst/validator/datasets/instance/core.py +222 -0
  11. edgefirst/validator/datasets/instance/detection.py +145 -0
  12. edgefirst/validator/datasets/instance/multitask.py +80 -0
  13. edgefirst/validator/datasets/instance/segmentation.py +120 -0
  14. edgefirst/validator/datasets/utils/fetch.py +682 -0
  15. edgefirst/validator/datasets/utils/readers.py +425 -0
  16. edgefirst/validator/datasets/utils/transformations.py +1695 -0
  17. edgefirst/validator/evaluators/__init__.py +17 -0
  18. edgefirst/validator/evaluators/callbacks/__init__.py +3 -0
  19. edgefirst/validator/evaluators/callbacks/core.py +192 -0
  20. edgefirst/validator/evaluators/callbacks/plots.py +900 -0
  21. edgefirst/validator/evaluators/callbacks/studio.py +234 -0
  22. edgefirst/validator/evaluators/core.py +257 -0
  23. edgefirst/validator/evaluators/detection.py +749 -0
  24. edgefirst/validator/evaluators/multitask.py +270 -0
  25. edgefirst/validator/evaluators/parameters/__init__.py +53 -0
  26. edgefirst/validator/evaluators/parameters/core.py +554 -0
  27. edgefirst/validator/evaluators/parameters/dataset.py +239 -0
  28. edgefirst/validator/evaluators/parameters/model.py +338 -0
  29. edgefirst/validator/evaluators/parameters/validation.py +528 -0
  30. edgefirst/validator/evaluators/segmentation.py +729 -0
  31. edgefirst/validator/evaluators/utils/__init__.py +3 -0
  32. edgefirst/validator/evaluators/utils/classify.py +292 -0
  33. edgefirst/validator/evaluators/utils/match.py +262 -0
  34. edgefirst/validator/evaluators/utils/timer.py +132 -0
  35. edgefirst/validator/metrics/__init__.py +9 -0
  36. edgefirst/validator/metrics/data/__init__.py +7 -0
  37. edgefirst/validator/metrics/data/label.py +668 -0
  38. edgefirst/validator/metrics/data/metrics.py +759 -0
  39. edgefirst/validator/metrics/data/plots.py +476 -0
  40. edgefirst/validator/metrics/data/stats.py +507 -0
  41. edgefirst/validator/metrics/detection.py +595 -0
  42. edgefirst/validator/metrics/segmentation.py +173 -0
  43. edgefirst/validator/metrics/utils/math.py +717 -0
  44. edgefirst/validator/publishers/__init__.py +3 -0
  45. edgefirst/validator/publishers/console.py +147 -0
  46. edgefirst/validator/publishers/studio.py +128 -0
  47. edgefirst/validator/publishers/tensorboard.py +119 -0
  48. edgefirst/validator/publishers/utils/logger.py +111 -0
  49. edgefirst/validator/publishers/utils/table.py +403 -0
  50. edgefirst/validator/runners/__init__.py +8 -0
  51. edgefirst/validator/runners/core.py +727 -0
  52. edgefirst/validator/runners/deepviewrt.py +177 -0
  53. edgefirst/validator/runners/hailo.py +263 -0
  54. edgefirst/validator/runners/keras.py +150 -0
  55. edgefirst/validator/runners/kinara.py +265 -0
  56. edgefirst/validator/runners/offline.py +228 -0
  57. edgefirst/validator/runners/onnx.py +241 -0
  58. edgefirst/validator/runners/processing/decode.py +320 -0
  59. edgefirst/validator/runners/processing/dvapi.py +4192 -0
  60. edgefirst/validator/runners/processing/nms.py +637 -0
  61. edgefirst/validator/runners/processing/outputs.py +507 -0
  62. edgefirst/validator/runners/tensorrt.py +321 -0
  63. edgefirst/validator/runners/tflite.py +221 -0
  64. edgefirst/validator/validate.py +843 -0
  65. edgefirst/validator/visualize/__init__.py +3 -0
  66. edgefirst/validator/visualize/detection.py +623 -0
  67. edgefirst/validator/visualize/segmentation.py +281 -0
  68. edgefirst/validator/visualize/utils/plots.py +635 -0
  69. edgefirst_validator-4.2.1.dist-info/METADATA +111 -0
  70. edgefirst_validator-4.2.1.dist-info/RECORD +73 -0
  71. edgefirst_validator-4.2.1.dist-info/WHEEL +5 -0
  72. edgefirst_validator-4.2.1.dist-info/entry_points.txt +2 -0
  73. edgefirst_validator-4.2.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,177 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from time import monotonic_ns as clock_now
5
+ from typing import TYPE_CHECKING, List, Any
6
+
7
+ import numpy as np
8
+
9
+ from edgefirst.validator.publishers.utils.logger import logger
10
+ from edgefirst.validator.runners.processing.outputs import Outputs
11
+ from edgefirst.validator.runners.core import Runner
12
+
13
+ if TYPE_CHECKING:
14
+ from edgefirst.validator.evaluators import ModelParameters, TimerContext
15
+
16
+
17
class DeepViewRTRunner(Runner):
    """
    Loads and runs DeepViewRT models using the VAAL API.

    Parameters
    ----------
    model: List[str]
        This is typically the path to the model backbone and decoder.
        NOTE(review): the value is handed directly to ``os.path.exists``
        and ``Context.load_model``, so it is presumably a single path in
        practice -- confirm against the caller.
    parameters: ModelParameters
        These are the model parameters set from the command line.
    metadata: dict
        The model metadata which contains information for decoding
        the model outputs.
    timer: TimerContext
        A timer object for handling validation timings for the model.

    Raises
    ------
    ImportError
        Raised if the deepview.vaal library is not found.
    EnvironmentError
        Raised if VAAL Context is not found.
    FileNotFoundError
        Raised if the path to the model does not exist.
    NotImplementedError
        Some methods have not been implemented yet.
    """

    def __init__(
        self,
        model: List[str],
        parameters: ModelParameters,
        metadata: dict,
        timer: TimerContext
    ):
        super(DeepViewRTRunner, self).__init__(model, parameters, timer=timer)

        # VAAL is an optional dependency, so it is imported lazily here.
        try:
            import deepview.vaal as vaal  # type: ignore
        except ImportError as error:
            raise ImportError(
                "VAAL library is needed to run DeepViewRT models.") from error

        try:
            self.ctx = vaal.Context(self.parameters.engine)
        except AttributeError as error:
            # Adjacent literals are used instead of a backslash continuation
            # so the source indentation does not leak into the message.
            raise EnvironmentError(
                "Did not find Vaal Context. Try setting the environment "
                "variable VAAL_LIBRARY to the VAAL library.") from error

        # Change because VAAL automatically uses CPU if NPU is unavailable.
        self.parameters.engine = self.ctx.device

        # Forward the NMS related settings to the VAAL context.
        if self.parameters.max_detections is not None:
            self.ctx['max_detection'] = self.parameters.max_detections
        self.ctx['score_threshold'] = self.parameters.score_threshold
        self.ctx['iou_threshold'] = self.parameters.iou_threshold

        if self.parameters.nms in ('standard', 'fast', 'matrix'):
            self.ctx['nms_type'] = self.parameters.nms

        # Map the normalization name onto the vaal.ImageProc member name.
        proc_names = {
            'raw': 'RAW',
            'signed': 'SIGNED_NORM',
            'unsigned': 'UNSIGNED_NORM',
            'whitening': 'WHITENING',
            'imagenet': 'IMAGENET',
        }
        proc_name = proc_names.get(self.parameters.common.norm)
        if proc_name is not None:
            self.ctx['proc'] = getattr(vaal.ImageProc, proc_name)
        else:
            # An unsupported method is only reported; 'proc' is left unset.
            logger(f"Unsupported normalization method: {self.parameters.common.norm}",
                   code="ERROR")

        if not os.path.exists(model):
            raise FileNotFoundError(
                "The model '{}' does not exist.".format(model))

        self.ctx.load_model(model)

        # Labels embedded in the model take precedence when present.
        if len(self.ctx.labels) > 0:
            self.parameters.labels = self.ctx.labels

        outputs = [{"quantization": [out.scales, out.zeros],
                    "shape": out.shape} for out in self.ctx.outputs]
        # Parse the model output details in the metadata.
        self.outputs = Outputs(
            metadata=metadata,
            parameters=parameters,
            outputs=outputs
        )

        if self.parameters.warmup > 0:
            self.warmup()

    def warmup(self):
        """
        Run model warmup.

        Runs the model ``parameters.warmup`` times, logs the total and
        average inference time in milliseconds and resets the timer so
        the warmup iterations do not count towards validation timings.
        """
        logger("Running model warmup...", code="INFO")

        times = []
        for _ in range(self.parameters.warmup):
            start = clock_now()
            self.ctx.run_model()
            elapsed_ns = clock_now() - start
            # Nanoseconds to milliseconds.
            times.append(elapsed_ns * 1e-6)

            outputs = [x.array() for x in self.ctx.outputs]

            # Warmup output postprocessing.
            if outputs:
                self.postprocessing(outputs)

        message = "model warmup took %f ms (%f ms avg)" % (np.sum(times),
                                                           np.average(times))
        logger(message, code="INFO")
        self.timer.reset()

    def run_single_instance(self, image: str) -> Any:
        """
        Run two stage DeepViewRT inference
        on a single image and record the timings.

        Parameters
        ----------
        image: str
            The path to the image. This is used to match the
            annotation to be read.

        Returns
        -------
        Any
            This could either return detection outputs after NMS.
                np.ndarray
                    The prediction bounding boxes.. [[box1], [box2], ...].
                np.ndarray
                    The prediction labels.. [cl1, cl2, ...].
                np.ndarray
                    The prediction confidence scores.. [score, score, ...]
                    normalized between 0 and 1.
            This could also return segmentation masks.
                np.ndarray
        """
        # Preprocessing
        with self.timer.time("input"):
            self.ctx.load_image(image)

        # Inference
        with self.timer.time("inference"):
            self.ctx.run_model()

        outputs = [x.array() for x in self.ctx.outputs]

        # Postprocessing
        return self.postprocessing(outputs)
@@ -0,0 +1,263 @@
1
+ """
2
+ This implementation is currently deprecated!
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from timeit import timeit
8
+ from time import monotonic_ns as clock_now
9
+ from typing import TYPE_CHECKING, Tuple, Union
10
+
11
+ import numpy as np
12
+
13
+ from edgefirst.validator.publishers.utils.logger import logger
14
+ from edgefirst.validator.datasets.utils.transformations import (resize, pad)
15
+ from edgefirst.validator.runners.core import Runner
16
+
17
+ if TYPE_CHECKING:
18
+ from edgefirst.validator.evaluators import Parameters
19
+
20
+
21
class HailoRunner(Runner):
    """
    Runs Hailo models.

    Parameters
    ----------
    model: str
        The path to the model. Only string file paths are accepted;
        any other type raises a ValueError.

    parameters: Parameters
        These are the model parameters set from the command line.

    labels: list
        Unique string labels.

    Raises
    ------
    ImportError
        Raised if hailo_platform library is not installed.
    ValueError
        Raised if the model is not a string file path.
    """

    def __init__(
        self,
        model,
        parameters: Parameters,
        labels: list = None
    ):
        super(HailoRunner, self).__init__(model, parameters, labels)

        # hailo_platform is an optional dependency, so it is imported lazily.
        try:
            from hailo_platform import (  # type: ignore
                HEF, Device, VDevice, HailoStreamInterface, ConfigureParams,
                InputVStreamParams, OutputVStreamParams, FormatType,
                InferVStreams)
        except ImportError as error:
            raise ImportError(
                "hailo_platform library is needed to run Hailo models."
            ) from error

        if not isinstance(model, str):
            raise ValueError("Only string filepaths are supported")

        model = self.validate_model_path(model)
        devices = Device.scan()
        self.hef = HEF(model)
        self.target = VDevice(device_ids=devices)
        configure_params = ConfigureParams.create_from_hef(
            self.hef, interface=HailoStreamInterface.PCIe)
        self.network_group = self.target.configure(
            self.hef, configure_params)[0]
        self.network_group_params = self.network_group.create_params()

        self.input_vstreams_params = InputVStreamParams.make_from_network_group(
            self.network_group, quantized=False, format_type=FormatType.FLOAT32)
        self.output_vstreams_params = OutputVStreamParams.make_from_network_group(
            self.network_group, quantized=False, format_type=FormatType.FLOAT32)

        self.parameters.engine = "hailo"

        if self.parameters.warmup > 0:
            input_vstream_info = self.hef.get_input_vstream_infos()[0]
            with InferVStreams(
                self.network_group,
                self.input_vstreams_params,
                self.output_vstreams_params
            ) as infer_pipeline:
                # Produce a sample image of zeros.
                input_type = "float32" if "float" in self.get_input_type() else "uint32"
                height, width = self.get_input_shape()[1:3]
                image = np.expand_dims(np.zeros((height, width, 3)), 0).astype(
                    np.dtype(input_type))
                # NOTE(review): a second batch axis is added here, giving a
                # 5-D input -- confirm this is what the pipeline expects.
                input_data = {input_vstream_info.name: np.expand_dims(
                    image, axis=0).astype(np.float32)}

                with self.network_group.activate(self.network_group_params):
                    logger("Loading model and warmup...", code="INFO")
                    t = timeit(lambda: infer_pipeline.infer(input_data),
                               number=self.parameters.warmup)
                    logger("model warmup took %f seconds (%f ms avg)" %
                           (t, t * 1000 / self.parameters.warmup), code="INFO")

    def run_single_instance(
        self,
        image: Union[str, np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Produce Hailo inference on one image and records the timings.

        Parameters
        ----------
        image: str or np.ndarray
            The path to the image or NumPy array image. The value is
            passed to the inference pipeline unchanged.

        Returns
        -------
        nmsed_boxes: np.ndarray
            The prediction bounding boxes.. [[box1], [box2], ...].

        nmsed_classes: np.ndarray
            The prediction labels.. [cl1, cl2, ...].

        nmsed_scores: np.ndarray
            The prediction confidence scores.. [score, score, ...]
            normalized between 0 and 1.

        Raises
        ------
        ImportError
            Raised if hailo_platform library is not installed.
        """
        try:
            from hailo_platform import InferVStreams  # type: ignore
        except ImportError as error:
            raise ImportError(
                "hailo_platform library is needed to run Hailo models."
            ) from error

        input_vstream_info = self.hef.get_input_vstream_infos()[0]
        output_name = self.hef.get_output_vstream_infos()[0].name

        with InferVStreams(
            self.network_group,
            self.input_vstreams_params,
            self.output_vstreams_params
        ) as infer_pipeline:
            infer_pipeline.set_nms_iou_threshold(self.parameters.iou_threshold)
            infer_pipeline.set_nms_score_threshold(
                self.parameters.score_threshold)

            # Inference
            start = clock_now()
            input_data = {
                input_vstream_info.name: image
            }
            with self.network_group.activate(self.network_group_params):
                raw_detections = infer_pipeline.infer(input_data)
            infer_ns = clock_now() - start
            # Nanoseconds to milliseconds.
            self.backbone_timings.append(infer_ns * 1e-6)

            # Postprocessing
            # An output with 7 columns refers to batch_id, xmin, ymin, xmax, ymax, cls, score.
            # Otherwise it is batch_size, number of boxes, number of classes
            # which needs external NMS.
            # Decoder and box timings are measured in this function.
            nmsed_boxes, nmsed_classes, nmsed_scores = self.postprocessing(
                raw_detections[output_name][0])

        return nmsed_boxes, nmsed_classes, nmsed_scores

    def postprocessing(
            self, output: list) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Retrieves the boxes, scores and labels.

        Parameters
        ----------
        output: list
            The detections grouped per class: the position of each entry
            is used as the class index and each entry holds rows whose
            first four values are the box and whose fifth is the score.
            The first two and last two box values are swapped pairwise
            on output (presumably [ymin, xmin, ymax, xmax] in, xyxy out
            -- confirm against the Hailo NMS output format).

        Returns
        -------
        nmsed_boxes: np.ndarray
            The prediction bounding boxes.. [[box1], [box2], ...].

        nmsed_classes: np.ndarray
            The prediction labels.. [cl1, cl2, ...], shifted by
            ``parameters.label_offset`` when it is non-zero.

        nmsed_scores: np.ndarray
            The prediction confidence scores.. [score, score, ...]
            normalized between 0 and 1.
        """
        boxes, classes, scores = [], [], []

        for class_index, detections in enumerate(output):
            if len(detections) == 0:
                continue
            # Convert once per class rather than once per row.
            for row in np.array(detections):
                bbox = row[:4]
                boxes.append(
                    np.asarray([bbox[1], bbox[0], bbox[3], bbox[2]]))
                scores.append(row[4])
                classes.append(class_index)

        nmsed_classes = np.asarray(classes)
        if self.parameters.label_offset != 0:
            nmsed_classes += self.parameters.label_offset
        return np.asarray(boxes), nmsed_classes, np.asarray(scores)

    def get_input_type(self) -> str:
        """
        This returns the input type of the model.

        Returns
        -------
        type: str
            The input type of the model: the lowercased text following
            the first '.' in the string form of the first input
            vstream's format type.
        """
        inputs = self.hef.get_input_vstream_infos()
        base_format = str(inputs[0].format.type)
        return base_format[base_format.find('.') + 1:].lower()

    def get_input_shape(self) -> tuple:
        """
        Grabs the model input shape.

        Returns
        -------
        shape: tuple
            The shape of the first input vstream with a leading batch
            size of 1. The warmup code reads height and width from
            positions 1 and 2, i.e. (batch size, height, width, channels).
        """
        inputs = self.hef.get_input_vstream_infos()
        return tuple([1] + list(inputs[0].shape))

    def get_output_type(self) -> str:
        """
        This returns the output type of the model.

        Returns
        -------
        type: str
            The output type of the model: the lowercased text following
            the first '.' in the string form of the first output
            vstream's format type.
        """
        outputs = self.hef.get_output_vstream_infos()
        base_format = str(outputs[0].format.type)
        return base_format[base_format.find('.') + 1:].lower()

    def get_output_shape(self) -> np.ndarray:
        """
        Grabs the model output shape.

        Returns
        --------
        shape: np.ndarray
            The shape reported by the first output vstream.
            Presumably (batch size, boxes, classes) -- confirm.
        """
        outputs = self.hef.get_output_vstream_infos()
        return outputs[0].shape

    def get_image_shape(self) -> tuple:
        """
        Returns the input image shape passed to the model.

        Returns
        -------
        shape: tuple
            The shape as (height, width).
        """
        # Height and width are read from the same positions the warmup
        # code uses, so both agree on the input layout.
        height, width = self.get_input_shape()[1:3]
        return (height, width)
@@ -0,0 +1,150 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ import numpy as np
7
+ from edgefirst.validator.runners.core import Runner
8
+
9
+ if TYPE_CHECKING:
10
+ from edgefirst.validator.evaluators import ModelParameters, TimerContext
11
+
12
+
13
class KerasRunner(Runner):
    """
    Loads and runs the Keras (.h5, .keras) models using the TensorFlow library.

    Parameters
    ----------
    model: str or tf.keras.Model
        The path to the model or the loaded keras model.
    parameters: ModelParameters
        These are the model parameters set from the command line.
    metadata: dict
        The model metadata which contains information for decoding
        the model outputs.
    timer: TimerContext
        A timer object for handling validation timings for the model.

    Raises
    ------
    ImportError
        Raised if the TensorFlow library is not installed.
    FileNotFoundError
        Raised if the path to the model does not exist.
    """

    def __init__(
        self,
        model: Any,
        parameters: ModelParameters,
        metadata: dict,
        timer: TimerContext
    ):
        super(KerasRunner, self).__init__(model, parameters, timer=timer)

        # Load Argmax dependency needed for keras. The name is not used
        # directly; presumably the import makes the custom layer available
        # when models are deserialized -- confirm. It is optional, so a
        # missing module is deliberately ignored.
        try:
            from deepview.modelpack.utils.argmax import Argmax  # noqa: F401
        except ImportError:
            pass

        # Must be set before TensorFlow is imported to silence its C++ logs.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        try:
            import tensorflow as tf  # type: ignore
        except ImportError as error:
            raise ImportError(
                "TensorFlow is needed to run keras models.") from error

        if isinstance(model, str):
            if not os.path.exists(model):
                raise FileNotFoundError(
                    "The model '{}' does not exist.".format(model))

            if os.path.exists(os.path.join(model, "saved_model.pb")):
                # A directory holding a TensorFlow SavedModel.
                self.model = tf.saved_model.load(model)
                signature = self.model.signatures["serving_default"]
                self.input = signature.inputs
                outputs = signature.outputs
            else:
                self.model = tf.keras.models.load_model(model, compile=False)
                outputs = self.model.output
                self.input = self.model.input
        else:
            # An already loaded model was supplied.
            # NOTE(review): assumes it exposes `.input` and `.output` like a
            # Keras model -- confirm for other model objects.
            self.model = model
            outputs = model.output
            self.input = model.input

        self.init_decoder(metadata=metadata, outputs=outputs)

        if self.parameters.warmup > 0:
            self.warmup()

    def run_single_instance(self, image: np.ndarray) -> Any:
        """
        Run Keras inference on a single image and record the timings.

        Parameters
        ----------
        image: np.ndarray
            The input image after being preprocessed.
            Typically this is an RGB image array.

        Returns
        -------
        Any
            This could either return detection outputs after NMS.
                np.ndarray
                    The prediction bounding boxes.. [[box1], [box2], ...].
                np.ndarray
                    The prediction labels.. [cl1, cl2, ...].
                np.ndarray
                    The prediction confidence scores.. [score, score, ...]
                    normalized between 0 and 1.
            This could also return segmentation masks.
                np.ndarray
        """
        # Inference
        with self.timer.time("inference"):
            outputs = self.model(image)

        # Postprocessing
        return self.postprocessing(outputs)

    def _image_input(self) -> Any:
        """
        Pick the image tensor out of ``self.input``.

        Returns
        -------
        Any
            The first 4-dimensional input whose second or last dimension
            equals 3, or the first input when none matches.
        """
        for tensor in self.input:
            shape = tensor.shape
            if len(shape) == 4 and (shape[1] == 3 or shape[-1] == 3):
                return tensor
        return self.input[0]

    def get_input_type(self) -> np.dtype:
        """
        This returns the input type of the model with shape
        (batch size, channels, height, width) or
        (batch size, height, width, channels).

        Returns
        -------
        np.dtype
            The input type of the model. When the model has no usable
            ``input`` attribute, the dtype of the image tensor found in
            ``self.input`` is returned as reported by that tensor.
        """
        try:
            try:
                return self.model.input.dtype.as_numpy_dtype
            except AttributeError:
                return np.dtype(self.model.input.dtype)
        except AttributeError:
            # The model exposes no `.input`; fall back to the stored inputs.
            return self._image_input().dtype

    def get_input_shape(self) -> np.ndarray:
        """
        Grabs the model input shape.

        Returns
        -------
        np.ndarray
            The model input shape (batch size, channels, height, width) or
            (batch size, height, width, channels).
        """
        try:
            return self.model.input.shape
        except AttributeError:
            # The model exposes no `.input`; fall back to the stored inputs.
            return self._image_input().shape