aissemble-inference-core 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aissemble_inference_core/__init__.py +19 -0
- aissemble_inference_core/client/__init__.py +23 -0
- aissemble_inference_core/client/builder/__init__.py +36 -0
- aissemble_inference_core/client/builder/inference_builder.py +178 -0
- aissemble_inference_core/client/builder/object_detection_builder.py +190 -0
- aissemble_inference_core/client/builder/raw_inference_builder.py +95 -0
- aissemble_inference_core/client/builder/summarization_builder.py +213 -0
- aissemble_inference_core/client/inference_client.py +158 -0
- aissemble_inference_core/client/oip_adapter.py +211 -0
- aissemble_inference_core/client/predictor.py +75 -0
- aissemble_inference_core/client/registry.py +201 -0
- aissemble_inference_core/client/results/__init__.py +29 -0
- aissemble_inference_core/client/results/object_detection_result.py +155 -0
- aissemble_inference_core/client/results/summarization_result.py +78 -0
- aissemble_inference_core/client/translator.py +57 -0
- aissemble_inference_core/client/translators/__init__.py +34 -0
- aissemble_inference_core/client/translators/_image_utils.py +75 -0
- aissemble_inference_core/client/translators/_tensor_utils.py +89 -0
- aissemble_inference_core/client/translators/object_detection_translator.py +212 -0
- aissemble_inference_core/client/translators/summarization_translator.py +147 -0
- aissemble_inference_core/client/translators/tensorflow_object_detection_translator.py +231 -0
- aissemble_inference_core-1.5.0.dist-info/METADATA +71 -0
- aissemble_inference_core-1.5.0.dist-info/RECORD +26 -0
- aissemble_inference_core-1.5.0.dist-info/WHEEL +4 -0
- aissemble_inference_core-1.5.0.dist-info/entry_points.txt +5 -0
- aissemble_inference_core-1.5.0.dist-info/licenses/LICENSE.txt +201 -0
aissemble_inference_core/client/builder/summarization_builder.py
@@ -0,0 +1,213 @@

###
# #%L
# aiSSEMBLE::Open Inference Protocol::Core
# %%
# Copyright (C) 2024 Booz Allen Hamilton Inc.
# %%
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #L%
###
"""Builder for text summarization inference."""

from typing import Any, Generator

from aissemble_inference_core.client.builder.inference_builder import InferenceBuilder
from aissemble_inference_core.client.predictor import Predictor
from aissemble_inference_core.client.results.summarization_result import (
    SummarizationResult,
)
from aissemble_inference_core.client.translators.summarization_translator import (
    DefaultSummarizationTranslator,
)


class SummarizationPredictor(Predictor[str, SummarizationResult]):
    """Concrete predictor implementation for text summarization."""

    def __init__(
        self,
        adapter: Any,
        translator: DefaultSummarizationTranslator,
    ):
        """Initialize the summarization predictor.

        Args:
            adapter: OIP adapter for communication
            translator: Translator for preprocessing/postprocessing
        """
        self.adapter = adapter
        self.translator = translator

    def predict(self, text: str) -> SummarizationResult:  # noqa: A003
        """Perform summarization on the input text.

        Args:
            text: Text to summarize

        Returns:
            SummarizationResult with generated summary
        """
        request = self.translator.preprocess(text)
        response = self.adapter.infer(request)
        return self.translator.postprocess(response)


class SummarizationBuilder(InferenceBuilder):
    """Task-specific builder for text summarization inference.

    Provides a fluent, natural API for text summarization:
        builder.text("long text...").max_length(100).run()

    Example:
        client = InferenceClient(adapter, endpoint)
        result = (client.summarize("bart-large")
                  .text("Very long article text here...")
                  .max_length(150)
                  .min_length(50)
                  .run())
        print(result.summary)
        print(f"Compressed {result.compression_ratio:.1f}x")
    """

    def __init__(self):
        """Initialize the summarization builder."""
        super().__init__()
        self._text_input: str | None = None
        self._max_length: int | None = None
        self._min_length: int | None = None

    def text(self, text: str) -> "SummarizationBuilder":
        """Set the input text to summarize.

        Args:
            text: The text to summarize

        Returns:
            Self for method chaining

        Raises:
            ValueError: If text is empty or not a string
        """
        if not text or not isinstance(text, str):
            raise ValueError("Text must be a non-empty string")
        self._text_input = text
        return self

    def max_length(self, length: int) -> "SummarizationBuilder":
        """Set maximum summary length.

        Args:
            length: Maximum length in tokens or characters (model-dependent)

        Returns:
            Self for method chaining

        Raises:
            ValueError: If length is not positive
        """
        if length <= 0:
            raise ValueError("Max length must be positive")
        self._max_length = length
        return self

    def min_length(self, length: int) -> "SummarizationBuilder":
        """Set minimum summary length.

        Args:
            length: Minimum length in tokens or characters (model-dependent)

        Returns:
            Self for method chaining

        Raises:
            ValueError: If length is not positive
        """
        if length <= 0:
            raise ValueError("Min length must be positive")
        self._min_length = length
        return self

    def run(self) -> SummarizationResult:
        """Execute the text summarization inference.

        Returns:
            SummarizationResult with generated summary

        Raises:
            ValueError: If required inputs are missing or invalid
        """
        if self._text_input is None:
            raise ValueError("Text input is required. Call .text() first.")

        if self._min_length and self._max_length:
            if self._min_length > self._max_length:
                raise ValueError(
                    f"Min length ({self._min_length}) cannot exceed "
                    f"max length ({self._max_length})"
                )

        # Add length parameters to request if specified
        if self._max_length is not None:
            self._parameters["max_length"] = self._max_length
        if self._min_length is not None:
            self._parameters["min_length"] = self._min_length

        predictor = self.build_predictor()
        return predictor.predict(self._text_input)

    def build_predictor(self) -> Predictor[str, SummarizationResult]:
        """Build the predictor for summarization.

        Returns:
            SummarizationPredictor instance

        Raises:
            ValueError: If adapter is not set
        """
        if self.oip_adapter is None:
            raise ValueError("OipAdapter is required. Call .with_adapter() first.")

        translator = (
            self.translator
            if isinstance(self.translator, DefaultSummarizationTranslator)
            else DefaultSummarizationTranslator()
        )

        return SummarizationPredictor(
            adapter=self.oip_adapter,
            translator=translator,
        )

    def __iter__(self) -> Generator[SummarizationResult, None, None]:
        """Streaming iteration for summarization.

        Not currently implemented for summarization tasks.

        Raises:
            NotImplementedError: Streaming is not yet supported
        """
        raise NotImplementedError(
            "Streaming is not yet supported for summarization tasks"
        )

    def __next__(self) -> SummarizationResult:
        """Return the next streaming result.

        Not currently implemented for summarization tasks.

        Raises:
            NotImplementedError: Streaming is not yet supported
        """
        raise NotImplementedError(
            "Streaming is not yet supported for summarization tasks"
        )
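Taken together, `SummarizationBuilder` validates its inputs, folds `min_length`/`max_length` into the request parameters, and delegates to a `SummarizationPredictor`. A minimal usage sketch, assuming the `HttpOipAdapter` from this package's `oip_adapter` module and an OIP-compliant server at an illustrative `http://localhost:8080` serving a model named `bart-large`:

```python
from aissemble_inference_core.client.builder.summarization_builder import (
    SummarizationBuilder,
)
from aissemble_inference_core.client.oip_adapter import HttpOipAdapter

# Illustrative endpoint and model name; any OIP-compliant runtime works.
adapter = HttpOipAdapter(base_url="http://localhost:8080", model_name="bart-large")

result = (
    SummarizationBuilder()
    .with_adapter(adapter)  # required before run(); see build_predictor()
    .text("Very long article text here...")
    .min_length(50)
    .max_length(150)
    .run()
)
print(result.summary)
```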
aissemble_inference_core/client/inference_client.py
@@ -0,0 +1,158 @@

###
# #%L
# aiSSEMBLE::Open Inference Protocol::Core
# %%
# Copyright (C) 2024 Booz Allen Hamilton Inc.
# %%
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #L%
###
from __future__ import annotations

from typing import Dict, List

from aissemble_inference_core.client.builder.object_detection_builder import (
    ObjectDetectionBuilder,
)
from aissemble_inference_core.client.builder.raw_inference_builder import (
    RawInferenceBuilder,
)
from aissemble_inference_core.client.builder.summarization_builder import (
    SummarizationBuilder,
)
from aissemble_inference_core.client.oip_adapter import OipAdapter
from aissemble_inference_core.client.registry import ModuleRegistry


class InferenceClient:
    """A facade of the entire client library. Supports one-line construction with zero
    configuration ceremony. It wraps and controls invocation of OIP-compliant endpoints.

    ALL task-specific entry points (e.g., detect_object, summarize) are accessible from here.

    The client supports dynamic module discovery via Python entry points. Installed
    OIP modules (e.g., aissemble-inference-yolo) automatically register their builders,
    translators, and runtimes.

    Example:
        from aissemble_inference_core.client import InferenceClient

        client = InferenceClient(adapter, endpoint)

        # List available modules
        print(client.list_available_modules())

        # Use object detection (uses discovered or built-in builder)
        result = client.detect_object("yolov8").image("photo.jpg").run()
    """

    def __init__(self, adapter: OipAdapter, endpoint: str):
        """Initializes the InferenceClient with the given adapter and endpoint.

        Args:
            adapter: The OIP adapter to use for inference.
            endpoint: The endpoint URL for the inference service.
        """
        self.adapter = adapter
        self.endpoint = endpoint
        self._registry = ModuleRegistry.instance()

    def raw(self, model_name: str) -> RawInferenceBuilder:
        """Creates a builder for raw inference.

        Args:
            model_name: The name of the model to use.

        Returns:
            A RawInferenceBuilder instance.
        """
        # TODO: Implement raw inference builder creation
        pass

    def detect_object(self, model_name: str | None = None) -> ObjectDetectionBuilder:
        """Creates a builder for object detection inference.

        If a model-specific builder is registered (e.g., from aissemble-inference-yolo),
        it will be used. Otherwise, falls back to the default ObjectDetectionBuilder.

        Args:
            model_name: Optional name of the model to use

        Returns:
            An ObjectDetectionBuilder instance configured for this client

        Example:
            result = client.detect_object("yolov8").image("photo.jpg").confidence(0.6).run()
        """
        builder = ObjectDetectionBuilder()
        builder = builder.with_adapter(self.adapter)
        if model_name:
            builder = builder.with_model(model_name)
        return builder

    def summarize(self, model_name: str | None = None) -> SummarizationBuilder:
        """Creates a builder for text summarization inference.

        Args:
            model_name: Optional name of the summarization model to use

        Returns:
            A SummarizationBuilder instance configured for this client

        Example:
            result = client.summarize("bart-large") \
                .text("Long article text here...") \
                .max_length(100) \
                .run()
            print(result.summary)
        """
        builder = SummarizationBuilder()
        builder = builder.with_adapter(self.adapter)
        if model_name:
            builder = builder.with_model(model_name)
        return builder

    def list_available_modules(self) -> Dict[str, List[str]]:
        """List all discovered OIP modules.

        Returns:
            Dictionary mapping category to list of available module names.
            Categories: runtimes, translators, builders

        Example:
            modules = client.list_available_modules()
            # {'runtimes': ['yolo'], 'translators': ['yolo', 'object_detection'], ...}
        """
        return self._registry.list_available()

    def get_translator(self, name: str):
        """Get a translator class by name from the registry.

        Args:
            name: The registered name of the translator

        Returns:
            The translator class, or None if not found
        """
        return self._registry.get_translator(name)

    def get_runtime(self, name: str):
        """Get a runtime class by name from the registry.

        Args:
            name: The registered name of the runtime

        Returns:
            The runtime class, or None if not found
        """
        return self._registry.get_runtime(name)
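`InferenceClient` is the intended entry point: it wires a single adapter into every task builder. A short end-to-end sketch, assuming the same illustrative server and model name as above:

```python
from aissemble_inference_core.client import InferenceClient
from aissemble_inference_core.client.oip_adapter import HttpOipAdapter

endpoint = "http://localhost:8080"  # illustrative OIP server
adapter = HttpOipAdapter(base_url=endpoint, model_name="bart-large")
client = InferenceClient(adapter, endpoint)

# Entry-point discovery: installed OIP modules register themselves.
print(client.list_available_modules())

# Task-specific fluent API backed by the shared adapter.
result = (
    client.summarize("bart-large")
    .text("Long article text here...")
    .max_length(100)
    .run()
)
print(result.summary)
```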
aissemble_inference_core/client/oip_adapter.py
@@ -0,0 +1,211 @@

###
# #%L
# aiSSEMBLE::Open Inference Protocol::Core
# %%
# Copyright (C) 2024 Booz Allen Hamilton Inc.
# %%
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #L%
###
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TensorData:
    """Represents tensor data in OIP format.

    Attributes:
        name: Name of the tensor input/output
        shape: Shape of the tensor as a list of integers
        datatype: Data type string (e.g., "FP32", "UINT8", "BYTES")
        data: The actual tensor data (nested list structure)
        parameters: Optional parameters for this tensor
    """

    name: str
    shape: list[int]
    datatype: str
    data: list[Any]
    parameters: dict[str, Any] | None = None


@dataclass
class OipRequest:
    """Represents an OIP inference request compliant with the Open Inference Protocol specification.

    Attributes:
        inputs: List of input tensors (required)
        id: Optional request identifier
        parameters: Optional inference parameters
        outputs: Optional list of requested output names
    """

    inputs: list[TensorData]
    id: str | None = None  # noqa: A003
    parameters: dict[str, Any] = field(default_factory=dict)
    outputs: list[dict[str, Any]] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert the request to a dictionary for JSON serialization."""
        result: dict[str, Any] = {
            "inputs": [
                {
                    "name": inp.name,
                    "shape": inp.shape,
                    "datatype": inp.datatype,
                    "data": inp.data,
                    **({"parameters": inp.parameters} if inp.parameters else {}),
                }
                for inp in self.inputs
            ]
        }
        if self.id:
            result["id"] = self.id
        if self.parameters:
            result["parameters"] = self.parameters
        if self.outputs:
            result["outputs"] = self.outputs
        return result


@dataclass
class OipResponse:
    """Represents an OIP inference response compliant with the Open Inference Protocol specification.

    Attributes:
        model_name: Name of the model that produced this response
        outputs: List of output tensors
        model_version: Optional version of the model
        id: Optional request identifier echo
        parameters: Optional response parameters
    """

    model_name: str
    outputs: list[TensorData]
    model_version: str | None = None
    id: str | None = None  # noqa: A003
    parameters: dict[str, Any] | None = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "OipResponse":
        """Create an OipResponse from a dictionary (JSON deserialization)."""
        outputs = [
            TensorData(
                name=out["name"],
                shape=out["shape"],
                datatype=out["datatype"],
                data=out["data"],
                parameters=out.get("parameters"),
            )
            for out in data["outputs"]
        ]
        return cls(
            model_name=data["model_name"],
            outputs=outputs,
            model_version=data.get("model_version"),
            id=data.get("id"),
            parameters=data.get("parameters"),
        )


@dataclass
class OipHealthStatus:
    """Represents the health status of an OIP endpoint.

    Attributes:
        isLive: Whether the endpoint is alive
        isReady: Whether the endpoint is ready to serve requests
    """

    isLive: bool
    isReady: bool


class OipAdapter:
    """This is the sole class that interacts with OIP endpoints. It is stateless and can be mocked for client testing.
    Implements appropriate backoff, authentication, and metrics capturing.
    """

    def infer(self, request: OipRequest) -> OipResponse:
        """Performs inference using the provided OIP request.

        Args:
            request: The OipRequest object containing inference parameters.

        Returns:
            The OipResponse object.
        """
        raise NotImplementedError


class HttpOipAdapter(OipAdapter):
    """HTTP-based OIP adapter for communicating with OIP-compliant inference servers.

    This adapter sends inference requests to an OIP endpoint using HTTP POST,
    compatible with MLServer and other OIP-compliant runtimes.
    """

    def __init__(self, base_url: str, model_name: str, timeout: float = 30.0):
        """Initialize the HTTP adapter.

        Args:
            base_url: Base URL of the inference server (e.g., "http://localhost:8080")
            model_name: Name of the model to invoke
            timeout: Request timeout in seconds (default: 30.0)
        """
        self.base_url = base_url.rstrip("/")
        self.model_name = model_name
        self.timeout = timeout

    @property
    def inference_url(self) -> str:
        """Get the full inference URL for this model."""
        return f"{self.base_url}/v2/models/{self.model_name}/infer"

    def infer(self, request: OipRequest) -> OipResponse:
        """Perform inference using HTTP POST to the inference server.

        Request parameters are preserved in the response to support stateless translators.

        Args:
            request: The OipRequest containing inference inputs

        Returns:
            OipResponse with model outputs and merged request parameters

        Raises:
            requests.HTTPError: If the server returns an error status
            requests.Timeout: If the request times out
        """
        import requests

        response = requests.post(
            self.inference_url,
            json=request.to_dict(),
            headers={"Content-Type": "application/json"},
            timeout=self.timeout,
        )
        response.raise_for_status()

        oip_response = OipResponse.from_dict(response.json())

        # Merge request parameters into response for stateless translator support
        # Server parameters take precedence over request parameters
        if request.parameters:
            merged_params = dict(request.parameters)
            if oip_response.parameters:
                merged_params.update(oip_response.parameters)
            oip_response.parameters = merged_params

        return oip_response
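These dataclasses define the wire format: `to_dict()` produces the OIP v2 JSON body, and `HttpOipAdapter.infer()` POSTs it to `{base_url}/v2/models/{model_name}/infer`. A sketch of driving the adapter directly, with an illustrative single BYTES tensor (the tensor name, shape, and datatype depend on the serving runtime):

```python
from aissemble_inference_core.client.oip_adapter import (
    HttpOipAdapter,
    OipRequest,
    TensorData,
)

request = OipRequest(
    inputs=[
        TensorData(
            name="text",  # illustrative tensor name
            shape=[1],
            datatype="BYTES",
            data=["Long article text here..."],
        )
    ],
    parameters={"max_length": 100},
)

# request.to_dict() serializes to the OIP v2 body, roughly:
# {"inputs": [{"name": "text", "shape": [1], "datatype": "BYTES",
#              "data": ["Long article text here..."]}],
#  "parameters": {"max_length": 100}}

adapter = HttpOipAdapter("http://localhost:8080", "bart-large")  # illustrative
response = adapter.infer(request)  # POST /v2/models/bart-large/infer
print(response.outputs[0].data)
# response.parameters now includes max_length, merged back for stateless translators.
```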
aissemble_inference_core/client/predictor.py
@@ -0,0 +1,75 @@

###
# #%L
# aiSSEMBLE::Open Inference Protocol::Core
# %%
# Copyright (C) 2024 Booz Allen Hamilton Inc.
# %%
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #L%
###
from __future__ import annotations

from typing import Generic, TypeVar

# Input type variable - represents the type of data fed into the model for inference
T = TypeVar("T")

# Output type variable - represents the type of prediction/result returned by the model
R = TypeVar("R")


class Predictor(Generic[T, R]):
    """
    An interface representing a ready-to-use inference wrapper around a specific model endpoint.

    Inspired by the DJL (Deep Java Library) ``Predictor`` abstraction, this interface defines a
    lightweight, reusable component that encapsulates the invocation of a model while managing
    underlying resources such as connection pools, retries, timeouts, and authentication.

    Implementations are expected to be thread-safe where appropriate and should maintain internal
    state (e.g., HTTP clients, gRPC channels, or cached credentials) for efficient repeated use.
    Typically, one ``Predictor`` instance is created per model endpoint/version and reused across
    multiple inference calls.

    Example usage::

        predictor: Predictor[Image, Classification] = model.new_predictor()
        result = predictor.predict(my_image)

    """

    def predict(self, input: T) -> R:  # noqa: A003
        """
        Perform inference on the given input and return the model's prediction.

        Parameters
        ----------
        input : T
            The input data in the format expected by the wrapped model. The exact type depends
            on the model (e.g., ``numpy.ndarray``, ``PIL.Image.Image``, ``dict[str, Any]``, etc.).

        Returns
        -------
        R
            The prediction result. The concrete return type is determined by the model's output
            schema (e.g., ``Classification``, ``List[float]``, ``dict[str, Any]``, etc.).

        Raises
        ------
        InferenceError
            If the inference request fails after retries are exhausted or due to an unrecoverable
            error (e.g., malformed input, model endpoint unreachable, timeout).
        ValueError
            If the input does not conform to the expected schema or constraints.
        """
        raise NotImplementedError("Subclasses must implement predict()")
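Concrete implementations in this package (such as `SummarizationPredictor` earlier in the diff) pair an adapter with a translator. For the generic contract itself, a self-contained toy sketch, not part of the package:

```python
from aissemble_inference_core.client.predictor import Predictor


class UpperCasePredictor(Predictor[str, str]):
    """Toy Predictor[str, str] that 'predicts' the upper-cased input."""

    def predict(self, input: str) -> str:  # noqa: A003
        if not isinstance(input, str):
            raise ValueError("Expected a string input")
        return input.upper()


predictor: Predictor[str, str] = UpperCasePredictor()
print(predictor.predict("hello"))  # -> "HELLO"
```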