aissemble-inference-core 1.5.0 (aissemble_inference_core-1.5.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aissemble_inference_core/__init__.py +19 -0
- aissemble_inference_core/client/__init__.py +23 -0
- aissemble_inference_core/client/builder/__init__.py +36 -0
- aissemble_inference_core/client/builder/inference_builder.py +178 -0
- aissemble_inference_core/client/builder/object_detection_builder.py +190 -0
- aissemble_inference_core/client/builder/raw_inference_builder.py +95 -0
- aissemble_inference_core/client/builder/summarization_builder.py +213 -0
- aissemble_inference_core/client/inference_client.py +158 -0
- aissemble_inference_core/client/oip_adapter.py +211 -0
- aissemble_inference_core/client/predictor.py +75 -0
- aissemble_inference_core/client/registry.py +201 -0
- aissemble_inference_core/client/results/__init__.py +29 -0
- aissemble_inference_core/client/results/object_detection_result.py +155 -0
- aissemble_inference_core/client/results/summarization_result.py +78 -0
- aissemble_inference_core/client/translator.py +57 -0
- aissemble_inference_core/client/translators/__init__.py +34 -0
- aissemble_inference_core/client/translators/_image_utils.py +75 -0
- aissemble_inference_core/client/translators/_tensor_utils.py +89 -0
- aissemble_inference_core/client/translators/object_detection_translator.py +212 -0
- aissemble_inference_core/client/translators/summarization_translator.py +147 -0
- aissemble_inference_core/client/translators/tensorflow_object_detection_translator.py +231 -0
- aissemble_inference_core-1.5.0.dist-info/METADATA +71 -0
- aissemble_inference_core-1.5.0.dist-info/RECORD +26 -0
- aissemble_inference_core-1.5.0.dist-info/WHEEL +4 -0
- aissemble_inference_core-1.5.0.dist-info/entry_points.txt +5 -0
- aissemble_inference_core-1.5.0.dist-info/licenses/LICENSE.txt +201 -0
aissemble_inference_core/client/translators/_tensor_utils.py
@@ -0,0 +1,89 @@
###
# #%L
# aiSSEMBLE::Open Inference Protocol::Core
# %%
# Copyright (C) 2024 Booz Allen Hamilton Inc.
# %%
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #L%
###
"""Shared utilities for tensor validation and extraction."""

from typing import Any

from aissemble_inference_core.client.oip_adapter import TensorData


def validate_required_tensors(
    outputs: dict[str, TensorData], required: list[str], format_name: str
) -> None:
    """Validate that all required tensors are present in outputs.

    Args:
        outputs: Dictionary mapping tensor names to TensorData
        required: List of required tensor names
        format_name: Name of the format for error messages (e.g., "TensorFlow")

    Raises:
        ValueError: If any required tensors are missing
    """
    missing = [name for name in required if name not in outputs]
    if missing:
        raise ValueError(
            f"{format_name} format requires tensors {missing} but server returned: "
            f"{list(outputs.keys())}"
        )


def validate_tensor_lengths_match(
    *tensors: list[Any], names: list[str] | None = None
) -> None:
    """Validate that all tensors have the same length.

    Args:
        *tensors: Variable number of tensor data lists
        names: Optional names for error messages

    Raises:
        ValueError: If tensor lengths don't match
    """
    lengths = [len(t) for t in tensors]
    if len(set(lengths)) > 1:
        if names:
            details = ", ".join(f"{len(t)} {name}" for t, name in zip(tensors, names))
            raise ValueError(f"Tensor length mismatch: {details}. Expected all equal.")
        else:
            raise ValueError(f"Tensor length mismatch: {lengths}. Expected all equal.")


def unbatch_tensor(tensor: TensorData, num_items: int | None = None) -> list[Any]:
    """Extract data from batched tensor [1, N, ...] -> [N, ...] and apply limit.

    Args:
        tensor: TensorData with potentially batched data
        num_items: Optional limit on number of items to extract

    Returns:
        Unbatched list of items
    """
    data = tensor.data

    # Unbatch if [1, N] or [1, N, M] shape
    if isinstance(data, list) and data and isinstance(data[0], list):
        data = data[0]

    # Apply limit if specified
    if num_items is not None:
        data = data[:num_items]

    return data
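For reference, a minimal sketch (not part of the package) of how these three helpers compose. The TensorData keyword construction mirrors its use in the translators below; the literal tensor values are illustrative only.

# Hypothetical usage sketch of the _tensor_utils helpers above.
from aissemble_inference_core.client.oip_adapter import TensorData
from aissemble_inference_core.client.translators._tensor_utils import (
    unbatch_tensor,
    validate_required_tensors,
    validate_tensor_lengths_match,
)

outputs = {
    "scores": TensorData(name="scores", shape=[1, 3], datatype="FP32",
                         data=[[0.9, 0.7, 0.2]]),
    "labels": TensorData(name="labels", shape=[1, 3], datatype="INT64",
                         data=[[1, 2, 3]]),
}

# Passes silently because both required tensors are present; a missing name
# would raise ValueError listing the names the server actually returned.
validate_required_tensors(outputs, ["scores", "labels"], "Example")

scores = unbatch_tensor(outputs["scores"], num_items=2)  # -> [0.9, 0.7]
labels = unbatch_tensor(outputs["labels"], num_items=2)  # -> [1, 2]
validate_tensor_lengths_match(scores, labels, names=["scores", "labels"])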
aissemble_inference_core/client/translators/object_detection_translator.py
@@ -0,0 +1,212 @@
###
# #%L
# aiSSEMBLE::Open Inference Protocol::Core
# %%
# Copyright (C) 2024 Booz Allen Hamilton Inc.
# %%
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #L%
###
from typing import Any

from aissemble_inference_core.client.oip_adapter import (
    OipRequest,
    OipResponse,
    TensorData,
)
from aissemble_inference_core.client.results import (
    BoundingBox,
    Detection,
    ObjectDetectionResult,
)
from aissemble_inference_core.client.translator import Translator
from aissemble_inference_core.client.translators._image_utils import (
    encode_image_for_oip,
)
from aissemble_inference_core.client.translators._tensor_utils import (
    validate_required_tensors,
    validate_tensor_lengths_match,
)


class DefaultObjectDetectionTranslator(Translator[Any, ObjectDetectionResult]):
    """A reasonable default translator for object detection payloads.

    This translator handles:
    - Auto-encoding images to base64 bytes for OIP compatibility
    - Converting various image formats (PIL, numpy, file paths)
    - Parsing bounding box outputs in common formats
    - Creating ObjectDetectionResult with proper metadata

    The translator expects the model to output:
    - bboxes: [N, 4] tensor with coordinates (x1, y1, x2, y2)
    - labels: [N] tensor with class labels
    - scores: [N] tensor with confidence scores

    Thread Safety:
        This translator is thread-safe and stateless. Image dimensions are stored in
        request parameters and retrieved from response parameters.
    """

    def __init__(
        self,
        input_name: str = "image",
        bbox_output_name: str = "bboxes",
        label_output_name: str = "labels",
        score_output_name: str = "scores",
    ):
        """Initialize the translator with configurable tensor names.

        Args:
            input_name: Name of the input tensor (default: "image")
            bbox_output_name: Name of bounding box output tensor (default: "bboxes")
            label_output_name: Name of label output tensor (default: "labels")
            score_output_name: Name of score output tensor (default: "scores")
        """
        self.input_name = input_name
        self.bbox_output_name = bbox_output_name
        self.label_output_name = label_output_name
        self.score_output_name = score_output_name

    def preprocess(self, input_data: Any) -> OipRequest:  # noqa: A003
        """Preprocess image input into an OipRequest.

        Accepts various input formats:
        - PIL Image
        - numpy array
        - file path (string)
        - bytes

        Args:
            input_data: Image data in supported format

        Returns:
            OipRequest with encoded image tensor and dimensions in parameters
        """
        image_bytes, width, height = encode_image_for_oip(input_data)

        tensor = TensorData(
            name=self.input_name,
            shape=[1, len(image_bytes)],
            datatype="BYTES",
            data=[[image_bytes]],
        )

        # Store image dimensions in request parameters for stateless operation
        return OipRequest(
            inputs=[tensor],
            parameters={"_image_width": width, "_image_height": height},
        )

    def postprocess(self, response: OipResponse) -> ObjectDetectionResult:
        """Postprocess OipResponse into ObjectDetectionResult.

        Args:
            response: OIP response containing detection outputs and dimensions

        Returns:
            ObjectDetectionResult with parsed detections

        Raises:
            ValueError: If required tensors are missing or dimensions not available
        """
        outputs = {out.name: out for out in response.outputs}

        # Validate required tensors are present
        required = [
            self.bbox_output_name,
            self.label_output_name,
            self.score_output_name,
        ]
        validate_required_tensors(outputs, required, "Default object detection")

        # Get image dimensions from response parameters
        params = response.parameters or {}
        image_width = params.get("_image_width")
        image_height = params.get("_image_height")

        if image_width is None or image_height is None:
            raise ValueError(
                "Image dimensions not found in response parameters. "
                "Ensure the OIP adapter preserves request parameters in the response."
            )

        bbox_tensor = outputs[self.bbox_output_name]
        bboxes = self._extract_bboxes(bbox_tensor)
        labels = self._extract_tensor_data(outputs[self.label_output_name])
        scores = self._extract_tensor_data(outputs[self.score_output_name])

        # Validate all tensors have same length
        validate_tensor_lengths_match(
            bboxes, labels, scores, names=["bboxes", "labels", "scores"]
        )

        detections = []
        for bbox, label, score in zip(bboxes, labels, scores):
            detection = Detection(
                bbox=BoundingBox(
                    x1=float(bbox[0]),
                    y1=float(bbox[1]),
                    x2=float(bbox[2]),
                    y2=float(bbox[3]),
                ),
                label=str(label),
                confidence=float(score),
            )
            detections.append(detection)

        return ObjectDetectionResult(
            detections=detections,
            image_width=image_width,
            image_height=image_height,
        )

    def _extract_bboxes(self, tensor: TensorData) -> list[list[float]]:
        """Extract bounding boxes from tensor, reshaping if necessary.

        OIP may flatten [N, 4] tensor data into a flat list. This method
        reconstructs the bounding box structure based on the tensor shape.

        Args:
            tensor: TensorData with bounding box data

        Returns:
            List of [x1, y1, x2, y2] coordinate lists
        """
        data = self._extract_tensor_data(tensor)
        if not data:
            return []

        if isinstance(data[0], list):
            return data

        shape = tensor.shape
        if len(shape) == 2 and shape[1] == 4:
            num_boxes = shape[0]
            return [data[i * 4 : (i + 1) * 4] for i in range(num_boxes)]

        return data

    def _extract_tensor_data(self, tensor: TensorData) -> list[Any]:
        """Extract the actual data from a tensor, flattening if necessary.

        Args:
            tensor: TensorData object

        Returns:
            Flattened list of tensor values
        """
        data = tensor.data
        if isinstance(data, list) and len(data) > 0 and isinstance(data[0], list):
            return data[0]
        return data
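A minimal round-trip sketch (not part of the package) for this translator. Constructing OipResponse with model_name=/outputs=/parameters= keywords is an assumption inferred from how OipRequest is built and how the response is read above; check oip_adapter.py for the real constructor. The image path and tensor values are made up.

# Hypothetical round trip through DefaultObjectDetectionTranslator.
from aissemble_inference_core.client.oip_adapter import OipResponse, TensorData
from aissemble_inference_core.client.translators.object_detection_translator import (
    DefaultObjectDetectionTranslator,
)

translator = DefaultObjectDetectionTranslator()
request = translator.preprocess("dog.jpg")  # hypothetical image file

# Pretend the server echoed the request parameters and returned one detection;
# bboxes arrives flattened, so _extract_bboxes reshapes it via shape=[1, 4].
response = OipResponse(
    model_name="detector",
    outputs=[
        TensorData(name="bboxes", shape=[1, 4], datatype="FP32",
                   data=[10.0, 20.0, 110.0, 220.0]),
        TensorData(name="labels", shape=[1], datatype="BYTES", data=["dog"]),
        TensorData(name="scores", shape=[1], datatype="FP32", data=[0.92]),
    ],
    parameters=request.parameters,  # preserves _image_width/_image_height
)

result = translator.postprocess(response)
print(result.detections[0].label, result.detections[0].confidence)  # dog 0.92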
aissemble_inference_core/client/translators/summarization_translator.py
@@ -0,0 +1,147 @@
###
# #%L
# aiSSEMBLE::Open Inference Protocol::Core
# %%
# Copyright (C) 2024 Booz Allen Hamilton Inc.
# %%
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #L%
###
"""Translator for text summarization tasks."""

from aissemble_inference_core.client.oip_adapter import (
    OipRequest,
    OipResponse,
    TensorData,
)
from aissemble_inference_core.client.results.summarization_result import (
    SummarizationResult,
)
from aissemble_inference_core.client.translator import Translator


class DefaultSummarizationTranslator(Translator[str, SummarizationResult]):
    """Default translator for text summarization inference.

    This translator converts input text to OIP format and parses OIP responses
    into SummarizationResult objects. It is the only component aware of the
    OIP JSON schema and tensor details for summarization.

    The translator is configurable to work with different model implementations
    that may use different tensor names or parameter conventions.
    """

    def __init__(
        self,
        input_name: str = "text",
        output_name: str = "summary",
    ):
        """Initialize the summarization translator.

        Args:
            input_name: Name of the input tensor (default: "text")
            output_name: Name of the output tensor (default: "summary")
        """
        self.input_name = input_name
        self.output_name = output_name
        self._original_text: str = ""

    def preprocess(self, text: str) -> OipRequest:
        """Convert input text to OIP request format.

        Args:
            text: The text to summarize

        Returns:
            OipRequest with text as BYTES tensor

        Raises:
            ValueError: If text is empty or invalid
        """
        if not text or not isinstance(text, str):
            raise ValueError("Input text must be a non-empty string")

        self._original_text = text

        # Encode text as UTF-8 bytes
        text_bytes = text.encode("utf-8")
        text_str = text_bytes.decode("utf-8")

        tensor = TensorData(
            name=self.input_name,
            shape=[1],
            datatype="BYTES",
            data=[text_str],
        )

        return OipRequest(inputs=[tensor])

    def postprocess(self, response: OipResponse) -> SummarizationResult:
        """Convert OIP response to SummarizationResult.

        Args:
            response: OIP response containing summary tensor

        Returns:
            SummarizationResult with extracted summary and metadata

        Raises:
            ValueError: If response format is invalid
        """
        outputs = {out.name: out for out in response.outputs}

        if self.output_name not in outputs:
            raise ValueError(
                f"Expected output tensor '{self.output_name}' not found in response. "
                f"Available: {list(outputs.keys())}"
            )

        summary_tensor = outputs[self.output_name]
        summary_data = self._extract_text(summary_tensor)

        if not summary_data:
            raise ValueError("Summary output is empty")

        return SummarizationResult(
            summary=summary_data,
            original_length=len(self._original_text),
            summary_length=len(summary_data),
            model_name=response.model_name,
        )

    def _extract_text(self, tensor: TensorData) -> str:
        """Extract text string from tensor data.

        Args:
            tensor: TensorData containing text

        Returns:
            Extracted text string
        """
        if not tensor.data:
            return ""

        # Handle nested list structure [[text]] or flat [text]
        data = tensor.data
        if isinstance(data, list) and len(data) > 0:
            if isinstance(data[0], list) and len(data[0]) > 0:
                text = data[0][0]
            else:
                text = data[0]

            if isinstance(text, str):
                return text
            elif isinstance(text, bytes):
                return text.decode("utf-8")

        return str(data)
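A summarization round-trip sketch (not part of the package), with the same assumed OipResponse constructor as the object-detection example; the summary text is made up. Note that, unlike the object-detection translators, this one records the original text on the instance (self._original_text), so sharing one instance across concurrent requests is not thread-safe.

# Hypothetical round trip through DefaultSummarizationTranslator.
from aissemble_inference_core.client.oip_adapter import OipResponse, TensorData
from aissemble_inference_core.client.translators.summarization_translator import (
    DefaultSummarizationTranslator,
)

translator = DefaultSummarizationTranslator()
request = translator.preprocess("A long article about inference protocols...")

# Pretend the server returned a single BYTES tensor named "summary"
response = OipResponse(
    model_name="summarizer",
    outputs=[
        TensorData(name="summary", shape=[1], datatype="BYTES",
                   data=["Short summary."]),
    ],
)

result = translator.postprocess(response)
print(result.summary, result.original_length, result.summary_length)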
aissemble_inference_core/client/translators/tensorflow_object_detection_translator.py
@@ -0,0 +1,231 @@
###
# #%L
# aiSSEMBLE::Open Inference Protocol::Core
# %%
# Copyright (C) 2024 Booz Allen Hamilton Inc.
# %%
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #L%
###
"""TensorFlow Object Detection API translator for OIP object detection.

This translator handles the TensorFlow Object Detection API / TensorFlow Serving
tensor format, commonly used with models exported from TensorFlow Model Zoo.

Key format characteristics:
- Normalized coordinates [0, 1] instead of pixel coordinates
- TensorFlow Serving-style tensor names (detection_boxes, detection_classes, etc.)
- Batched shape conventions [1, N, 4] instead of [N, 4]
- Integer class IDs requiring mapping to label names
- Coordinate order: (ymin, xmin, ymax, xmax) instead of (x1, y1, x2, y2)

Despite these differences from the default YOLO-style format, end users use the
SAME client API - the translator handles all tensor-level complexity invisibly.

Thread Safety:
    This translator is thread-safe and stateless. Image dimensions are stored in
    request parameters and retrieved from response parameters.
"""

import logging
from typing import Any

from aissemble_inference_core.client.oip_adapter import (
    OipRequest,
    OipResponse,
    TensorData,
)
from aissemble_inference_core.client.results import (
    BoundingBox,
    Detection,
    ObjectDetectionResult,
)
from aissemble_inference_core.client.translator import Translator
from aissemble_inference_core.client.translators._image_utils import (
    encode_image_for_oip,
)
from aissemble_inference_core.client.translators._tensor_utils import (
    unbatch_tensor,
    validate_required_tensors,
    validate_tensor_lengths_match,
)

logger = logging.getLogger(__name__)


class TensorFlowObjectDetectionTranslator(Translator[Any, ObjectDetectionResult]):
    """Translator for TensorFlow Object Detection API / TensorFlow Serving format.

    This translator handles the standard TensorFlow Object Detection API output format,
    commonly used by models exported from TensorFlow Model Zoo and served via TensorFlow
    Serving, KServe, or other TF-compatible inference backends.

    Expected output tensors:
    - detection_boxes: [1, N, 4] with NORMALIZED coordinates (ymin, xmin, ymax, xmax)
    - detection_classes: [1, N] with integer class IDs
    - detection_scores: [1, N] with confidence scores
    - num_detections: [1] with count of valid detections

    Key differences from default YOLO-style format:
    1. Normalized coords [0,1] vs pixel coords
    2. Different coord order: (ymin, xmin, ymax, xmax) vs (x1, y1, x2, y2)
    3. Batched shape [1, N, 4] vs unbatched [N, 4]
    4. Integer class IDs vs string labels
    5. Extra num_detections tensor

    Despite these differences, users interact with the same ObjectDetectionResult
    type, demonstrating complete tensor abstraction.
    """

    def __init__(
        self,
        class_names: dict[int, str] | None = None,
    ):
        """Initialize translator with class ID to name mapping.

        Args:
            class_names: Optional mapping of class IDs to human-readable names.
                If None, uses class IDs as labels (e.g., "class_17").
        """
        self._class_names = class_names or {}

    def preprocess(self, input_data: Any) -> OipRequest:
        """Preprocess image input into OipRequest.

        Accepts PIL Image, numpy array, file path, or bytes.

        Args:
            input_data: Image data in supported format

        Returns:
            OipRequest with encoded image tensor and dimensions in parameters
        """
        image_bytes, width, height = encode_image_for_oip(input_data)

        tensor = TensorData(
            name="image",
            shape=[1, len(image_bytes)],
            datatype="BYTES",
            data=[[image_bytes]],
        )

        # Store image dimensions in request parameters for stateless operation
        return OipRequest(
            inputs=[tensor],
            parameters={"_image_width": width, "_image_height": height},
        )

    def postprocess(self, response: OipResponse) -> ObjectDetectionResult:
        """Postprocess OipResponse into ObjectDetectionResult.

        Handles the TensorFlow Object Detection API tensor format:
        - Validates required tensors are present
        - Extracts batch dimension [1, N, ...] -> [N, ...]
        - Converts normalized coords to pixel coords
        - Reorders coords from (ymin, xmin, ymax, xmax) to (x1, y1, x2, y2)
        - Maps class IDs to names
        - Respects num_detections count

        Args:
            response: OIP response with TensorFlow format tensors and dimensions

        Returns:
            ObjectDetectionResult - same as default translator!

        Raises:
            ValueError: If required tensors are missing or malformed
        """
        outputs = {out.name: out for out in response.outputs}

        # Validate required tensors are present
        required = ["detection_boxes", "detection_classes", "detection_scores"]
        validate_required_tensors(outputs, required, "TensorFlow Object Detection API")

        # Get image dimensions from response parameters
        params = response.parameters or {}
        image_width = params.get("_image_width")
        image_height = params.get("_image_height")

        if image_width is None or image_height is None:
            raise ValueError(
                "Image dimensions not found in response parameters. "
                "Ensure the OIP adapter preserves request parameters in the response."
            )

        # Extract num_detections if present, otherwise use all detections
        num_detections = self._get_num_detections(outputs)

        # Extract and unbatch tensors [1, N] -> [N]
        boxes = unbatch_tensor(outputs["detection_boxes"], num_detections)
        class_ids = [
            int(cid)
            for cid in unbatch_tensor(outputs["detection_classes"], num_detections)
        ]
        scores = [
            float(s)
            for s in unbatch_tensor(outputs["detection_scores"], num_detections)
        ]

        # Validate all tensors have same length
        validate_tensor_lengths_match(
            boxes, class_ids, scores, names=["boxes", "classes", "scores"]
        )

        detections = []
        for box, class_id, score in zip(boxes, class_ids, scores):
            # Validate normalized coordinate range
            ymin_norm, xmin_norm, ymax_norm, xmax_norm = box
            if not all(0 <= coord <= 1 for coord in box):
                logger.warning(
                    f"Normalized coordinates outside [0,1] range: {box}. "
                    f"Server may be misconfigured or model may be wrong type."
                )

            # Convert normalized coords to pixel coords and reorder
            x1 = xmin_norm * image_width
            y1 = ymin_norm * image_height
            x2 = xmax_norm * image_width
            y2 = ymax_norm * image_height

            # Map class ID to name
            label = self._class_names.get(class_id, f"class_{class_id}")

            detection = Detection(
                bbox=BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2),
                label=label,
                confidence=score,
            )
            detections.append(detection)

        return ObjectDetectionResult(
            detections=detections,
            image_width=image_width,
            image_height=image_height,
        )

    def _get_num_detections(self, outputs: dict[str, TensorData]) -> int | None:
        """Extract number of valid detections from num_detections tensor if present.

        Args:
            outputs: Dictionary of output tensors

        Returns:
            Number of detections or None if tensor not present
        """
        if "num_detections" not in outputs:
            return None

        num_det_data = outputs["num_detections"].data
        if isinstance(num_det_data, list):
            return int(num_det_data[0]) if num_det_data else None
        return int(num_det_data)
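A final sketch (not part of the package) showing the TensorFlow-format tensors mapping to the same ObjectDetectionResult type. OipResponse construction is assumed as in the earlier examples, and the image, class ID, and tensor values are made up. For a hypothetical 640x480 image, the first box converts as x1 = 0.2 * 640 = 128, y1 = 0.1 * 480 = 48, x2 = 0.6 * 640 = 384, y2 = 0.5 * 480 = 240.

# Hypothetical round trip through TensorFlowObjectDetectionTranslator.
from aissemble_inference_core.client.oip_adapter import OipResponse, TensorData
from aissemble_inference_core.client.translators.tensorflow_object_detection_translator import (
    TensorFlowObjectDetectionTranslator,
)

translator = TensorFlowObjectDetectionTranslator(class_names={18: "dog"})
request = translator.preprocess("dog.jpg")  # hypothetical 640x480 image

response = OipResponse(
    model_name="tf-detector",
    outputs=[
        # Batched [1, N, 4], normalized (ymin, xmin, ymax, xmax) order
        TensorData(name="detection_boxes", shape=[1, 2, 4], datatype="FP32",
                   data=[[[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]]]),
        TensorData(name="detection_classes", shape=[1, 2], datatype="FP32",
                   data=[[18.0, 1.0]]),
        TensorData(name="detection_scores", shape=[1, 2], datatype="FP32",
                   data=[[0.95, 0.05]]),
        # Only the first detection is valid; the second is padding
        TensorData(name="num_detections", shape=[1], datatype="FP32", data=[1.0]),
    ],
    parameters=request.parameters,  # preserves _image_width/_image_height
)

result = translator.postprocess(response)
det = result.detections[0]
print(det.label, det.bbox.x1, det.bbox.y1, det.bbox.x2, det.bbox.y2)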