kinfer 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kinfer-0.0.1 → kinfer-0.0.2}/LICENSE +1 -1
- {kinfer-0.0.1/kinfer.egg-info → kinfer-0.0.2}/PKG-INFO +6 -3
- kinfer-0.0.2/kinfer/__init__.py +1 -0
- kinfer-0.0.2/kinfer/export/pytorch.py +213 -0
- kinfer-0.0.2/kinfer/inference/__init__.py +0 -0
- kinfer-0.0.2/kinfer/inference/python.py +91 -0
- kinfer-0.0.2/kinfer/py.typed +0 -0
- kinfer-0.0.2/kinfer/requirements.txt +5 -0
- {kinfer-0.0.1 → kinfer-0.0.2/kinfer.egg-info}/PKG-INFO +6 -3
- {kinfer-0.0.1 → kinfer-0.0.2}/kinfer.egg-info/SOURCES.txt +6 -1
- {kinfer-0.0.1 → kinfer-0.0.2}/kinfer.egg-info/requires.txt +3 -0
- {kinfer-0.0.1 → kinfer-0.0.2}/pyproject.toml +1 -1
- {kinfer-0.0.1 → kinfer-0.0.2}/setup.py +5 -4
- kinfer-0.0.2/tests/test_infer.py +193 -0
- kinfer-0.0.1/kinfer/__init__.py +0 -1
- kinfer-0.0.1/kinfer/requirements.txt +0 -1
- {kinfer-0.0.1 → kinfer-0.0.2}/MANIFEST.in +0 -0
- {kinfer-0.0.1 → kinfer-0.0.2}/README.md +0 -0
- /kinfer-0.0.1/kinfer/py.typed → /kinfer-0.0.2/kinfer/export/__init__.py +0 -0
- {kinfer-0.0.1 → kinfer-0.0.2}/kinfer/requirements-dev.txt +0 -0
- {kinfer-0.0.1 → kinfer-0.0.2}/kinfer.egg-info/dependency_links.txt +0 -0
- {kinfer-0.0.1 → kinfer-0.0.2}/kinfer.egg-info/top_level.txt +0 -0
- {kinfer-0.0.1 → kinfer-0.0.2}/setup.cfg +0 -0
- {kinfer-0.0.1 → kinfer-0.0.2}/tests/test_dummy.py +0 -0
@@ -1,12 +1,15 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: kinfer
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.2
|
4
4
|
Summary: The kinfer project
|
5
5
|
Home-page: https://github.com/kscalelabs/kinfer.git
|
6
|
-
Author:
|
7
|
-
Requires-Python: >=3.
|
6
|
+
Author: K-Scale Labs
|
7
|
+
Requires-Python: >=3.8
|
8
8
|
Description-Content-Type: text/markdown
|
9
9
|
License-File: LICENSE
|
10
|
+
Requires-Dist: torch
|
11
|
+
Requires-Dist: onnx
|
12
|
+
Requires-Dist: onnxruntime
|
10
13
|
Provides-Extra: dev
|
11
14
|
Requires-Dist: black; extra == "dev"
|
12
15
|
Requires-Dist: darglint; extra == "dev"
|
@@ -0,0 +1 @@
|
|
1
|
+
__version__ = "0.0.2"
|
@@ -0,0 +1,213 @@
|
|
1
|
+
"""PyTorch model export utilities."""
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import sys
|
5
|
+
from dataclasses import fields, is_dataclass
|
6
|
+
from io import BytesIO
|
7
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
8
|
+
|
9
|
+
import onnx
|
10
|
+
import onnxruntime as ort # type: ignore[import-untyped]
|
11
|
+
import torch
|
12
|
+
from torch import nn
|
13
|
+
|
14
|
+
|
15
|
+
def get_model_info(model: nn.Module) -> Dict[str, Any]:
    """Summarise a model's ``forward`` signature and its parameter count.

    Args:
        model: PyTorch model to analyze

    Returns:
        Dictionary with ``"input_params"`` (mapping each ``forward`` argument
        name to its stringified annotation and default, ``None`` when there is
        no default) and ``"num_parameters"`` (total element count across all
        of ``model.parameters()``).
    """

    def describe(param: inspect.Parameter) -> Dict[str, Any]:
        has_default = param.default is not param.empty
        return {
            "annotation": str(param.annotation),
            "default": str(param.default) if has_default else None,
        }

    forward_params = inspect.signature(model.forward).parameters
    # ``self`` is normally already bound away; the filter is kept for safety.
    input_params = {name: describe(p) for name, p in forward_params.items() if name != "self"}

    total = 0
    for tensor in model.parameters():
        total += tensor.numel()

    return {"input_params": input_params, "num_parameters": total}
|
41
|
+
|
42
|
+
|
43
|
+
def add_metadata_to_onnx(
    model_proto: onnx.ModelProto, metadata: Dict[str, Any], config: Optional[object] = None
) -> onnx.ModelProto:
    """Add metadata to ONNX model.

    Every value is stored as a string in ``model_proto.metadata_props``. If a
    key is already present (on the incoming proto, or because a config field
    shares a name with a ``metadata`` key), its value is overwritten rather
    than duplicated. The ONNX checker rejects models with repeated
    ``metadata_props`` keys. Later writes win, with ``config`` fields applied
    after ``metadata``.

    Args:
        model_proto: ONNX model prototype (modified in place)
        metadata: Dictionary of metadata to add
        config: Optional configuration dataclass to add to metadata

    Returns:
        ONNX model with added metadata (the same object as ``model_proto``)
    """
    # Index existing entries so repeated keys update in place.
    entries = {prop.key: prop for prop in model_proto.metadata_props}

    def _set(key: Any, value: Any) -> None:
        # Add or overwrite a single metadata entry.
        key = str(key)
        entry = entries.get(key)
        if entry is None:
            entry = model_proto.metadata_props.add()
            entry.key = key
            entries[key] = entry
        entry.value = str(value)

    # Add model metadata
    for key, value in metadata.items():
        _set(key, value)

    # Add configuration if provided
    if config is not None and is_dataclass(config):
        for field in fields(config):
            _set(field.name, getattr(config, field.name))

    return model_proto
|
71
|
+
|
72
|
+
|
73
|
+
def infer_input_shapes(model: nn.Module) -> Union[torch.Size, List[torch.Size]]:
    """Infer input shapes from model architecture.

    Finds the model's first layer: the model itself if it is an
    ``nn.Sequential``, otherwise its first registered child. It then descends
    through any nesting of ``nn.Sequential`` containers. Only ``nn.Linear``
    yields a shape. Convolutional layers do not record spatial or sequence
    dimensions, so they are rejected with a specific message.

    Note: "first" means first *registered* child, which is not guaranteed to
    be the first module applied in ``forward``.

    Args:
        model: PyTorch model to analyze

    Returns:
        Single input shape or list of input shapes. Currently always a single
        ``torch.Size([1, in_features])``.

    Raises:
        ValueError: If there is no first layer (e.g. an empty Sequential), the
            first layer is not Linear/Conv, or it is a Conv layer whose input
            dimensions cannot be inferred.
    """
    first_layer: Optional[nn.Module] = model
    if not isinstance(model, nn.Sequential):
        # Get the first immediate child (None if the model has no children)
        first_layer = next(iter(model.children()), None)

    # Unwrap (possibly nested) Sequential containers. An empty one yields None,
    # which is reported below as a ValueError rather than an IndexError.
    while isinstance(first_layer, nn.Sequential):
        first_layer = first_layer[0] if len(first_layer) > 0 else None

    if isinstance(first_layer, nn.Linear):
        return torch.Size([1, first_layer.in_features])
    if isinstance(first_layer, nn.Conv1d):
        raise ValueError("Cannot infer sequence length for Conv1d layer. Please provide input_tensors explicitly.")
    if isinstance(first_layer, nn.Conv2d):
        raise ValueError("Cannot infer image dimensions for Conv2d layer. Please provide input_tensors explicitly.")
    if isinstance(first_layer, nn.Conv3d):
        raise ValueError("Cannot infer volume dimensions for Conv3d layer. Please provide input_tensors explicitly.")

    raise ValueError("First layer must be Linear or Conv layer to infer input shape")
|
108
|
+
|
109
|
+
|
110
|
+
def _annotation_dims(annotation: Any) -> Optional[Tuple[int, ...]]:
    """Return concrete integer dims from a ``Tensor[...]``-style annotation, or None."""
    if getattr(annotation, "__origin__", None) is not torch.Tensor:
        return None
    dims = getattr(annotation, "__args__", None)
    if not isinstance(dims, (tuple, list)) or not dims:
        return None
    # Accept both the flat form ``(2, 3)`` and the nested form ``((2, 3),)``.
    if len(dims) == 1 and isinstance(dims[0], (tuple, list, torch.Size)):
        dims = tuple(dims[0])
    if all(isinstance(d, int) and not isinstance(d, bool) for d in dims):
        return tuple(dims)
    return None


def create_example_inputs(model: nn.Module) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Create example input tensors based on model's forward signature and architecture.

    With exactly one ``forward`` argument, the shape comes from
    ``infer_input_shapes``. With any other number of arguments, each one gets
    a random tensor. Its shape comes from the integer dims of a
    ``Tensor[...]``-style annotation when present, otherwise ``(1, 32)``.

    Args:
        model: PyTorch model to analyze

    Returns:
        Single tensor or tuple of tensors matching the model's expected input

    Raises:
        ValueError: Propagated from ``infer_input_shapes`` in the
            single-argument case when the shape cannot be inferred.
    """
    signature = inspect.signature(model.forward)
    params = [p for p in signature.parameters.items() if p[0] != "self"]

    # If single parameter (besides self), try to infer shape
    if len(params) == 1:
        shape = infer_input_shapes(model)
        return torch.randn(*shape) if isinstance(shape, torch.Size) else torch.randn(*shape[0])

    # For multiple parameters, try to infer from parameter annotations
    input_tensors: List[torch.Tensor] = []
    for _, param in params:
        dims = _annotation_dims(param.annotation)
        # Default to a (1, 32) vector when the annotation gives no usable size info
        input_tensors.append(torch.randn(*dims) if dims is not None else torch.randn(1, 32))

    return tuple(input_tensors)
|
144
|
+
|
145
|
+
|
146
|
+
def export_to_onnx(
    model: nn.Module,
    input_tensors: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None,
    config: Optional[object] = None,
    save_path: Optional[str] = None,
) -> ort.InferenceSession:
    """Export PyTorch model to ONNX format with metadata.

    The model is scripted with ``torch.jit.script`` if it is not already a
    ``ScriptModule``. It is then exported to an in-memory ONNX proto,
    annotated with ``get_model_info`` output and the optional config fields,
    and optionally written to ``save_path``.

    Args:
        model: PyTorch model to export
        input_tensors: Optional example input tensors for model tracing. If None, will attempt to infer.
        config: Optional configuration dataclass to add to metadata
        save_path: Optional path to save the ONNX model

    Returns:
        ONNX inference session

    Raises:
        ValueError: If ``input_tensors`` is None and the input shapes cannot
            be inferred from the model.
    """
    # Get model information
    model_info = get_model_info(model)

    # Create example inputs if not provided
    if input_tensors is None:
        try:
            input_tensors = create_example_inputs(model)
        except ValueError as e:
            raise ValueError(
                f"Could not automatically infer input shapes. Please provide input_tensors. Error: {str(e)}"
            ) from e
        model_info["inferred_input_shapes"] = str(
            input_tensors.shape if isinstance(input_tensors, torch.Tensor) else [t.shape for t in input_tensors]
        )

    # Convert model to JIT if not already
    if not isinstance(model, torch.jit.ScriptModule):
        model = torch.jit.script(model)

    # torch.onnx.export expects a tuple of positional args.
    export_args = (input_tensors,) if isinstance(input_tensors, torch.Tensor) else input_tensors

    # Export model to buffer. The TYPE_CHECKING split exists only so the
    # type-ignore is applied where mypy needs it; at runtime the else branch runs.
    buffer = BytesIO()
    if TYPE_CHECKING and sys.version_info >= (3, 11):
        torch.onnx.export(
            model,
            export_args,
            buffer,  # type: ignore[arg-type]
        )
    else:
        torch.onnx.export(
            model,
            export_args,
            buffer,
        )
    buffer.seek(0)

    # Load as ONNX model and add metadata
    model_proto = onnx.load_model(buffer)
    model_proto = add_metadata_to_onnx(model_proto, model_info, config)

    # Save if path provided
    if save_path:
        onnx.save_model(model_proto, save_path)

    # Hand the serialized proto straight to onnxruntime (no second BytesIO round-trip).
    return ort.InferenceSession(model_proto.SerializeToString())
|
File without changes
|
@@ -0,0 +1,91 @@
|
|
1
|
+
"""ONNX model inference utilities for Python."""
|
2
|
+
|
3
|
+
from typing import Any, Dict, List, Union
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import onnx
|
7
|
+
import onnxruntime as ort # type: ignore[import-untyped]
|
8
|
+
|
9
|
+
|
10
|
+
class ONNXModel:
    """Wrapper for ONNX model inference.

    The ONNX file is loaded both as an ``onnx`` proto (``self.model``, used
    for metadata) and as an onnxruntime session (``self.session``, used for
    execution).
    """

    def __init__(
        self,
        model_path: str,
    ) -> None:
        """Initialize ONNX model.

        Args:
            model_path: Path to ONNX model file
        """
        self.model_path = model_path

        # Load model and create inference session
        self.model = onnx.load(model_path)
        self.session = ort.InferenceSession(
            model_path,
        )

        # Extract metadata (string key -> string value)
        self.metadata = {prop.key: prop.value for prop in self.model.metadata_props}

        # Get input and output details
        self.input_details = [{"name": x.name, "shape": x.shape, "type": x.type} for x in self.session.get_inputs()]
        self.output_details = [{"name": x.name, "shape": x.shape, "type": x.type} for x in self.session.get_outputs()]

    def __call__(
        self, inputs: Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
    ) -> Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]:
        """Run inference on input data.

        Args:
            inputs: Input data in one of three forms.
                - A single numpy array, fed to the first model input.
                - A list (or tuple) of arrays, matched positionally to the
                  model inputs.
                - A dictionary of arrays keyed by input name, passed through
                  unchanged.

        Returns:
            Model outputs in a shape matching the input form.
                - A single array for a single-array input.
                - A list of arrays for a list/tuple input.
                - A dictionary keyed by output name for a dictionary input.

        Raises:
            ValueError: If a list/tuple input does not contain exactly one
                array per model input.
        """
        is_sequence = isinstance(inputs, (list, tuple))

        # Convert single array to dict
        if isinstance(inputs, np.ndarray):
            input_dict = {self.input_details[0]["name"]: inputs}
        # Convert list/tuple to dict; zip would otherwise silently drop extras
        elif is_sequence:
            if len(inputs) != len(self.input_details):
                raise ValueError(f"Expected {len(self.input_details)} input arrays, got {len(inputs)}")
            input_dict = {detail["name"]: arr for detail, arr in zip(self.input_details, inputs)}
        else:
            input_dict = inputs

        # Run inference - pass None to output_names param to get all outputs
        outputs = self.session.run(None, input_dict)

        # Convert output format to match input
        if isinstance(inputs, np.ndarray):
            return outputs[0]
        if is_sequence:
            return outputs
        return {detail["name"]: arr for detail, arr in zip(self.output_details, outputs)}

    def get_metadata(self) -> Dict[str, str]:
        """Get model metadata.

        Returns:
            Dictionary of metadata key-value pairs
        """
        return self.metadata

    def get_input_details(self) -> List[Dict[str, Any]]:
        """Get input tensor details.

        Returns:
            List of dictionaries containing input tensor information
        """
        return self.input_details

    def get_output_details(self) -> List[Dict[str, Any]]:
        """Get output tensor details.

        Returns:
            List of dictionaries containing output tensor information
        """
        return self.output_details
|
File without changes
|
@@ -1,12 +1,15 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: kinfer
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.2
|
4
4
|
Summary: The kinfer project
|
5
5
|
Home-page: https://github.com/kscalelabs/kinfer.git
|
6
|
-
Author:
|
7
|
-
Requires-Python: >=3.
|
6
|
+
Author: K-Scale Labs
|
7
|
+
Requires-Python: >=3.8
|
8
8
|
Description-Content-Type: text/markdown
|
9
9
|
License-File: LICENSE
|
10
|
+
Requires-Dist: torch
|
11
|
+
Requires-Dist: onnx
|
12
|
+
Requires-Dist: onnxruntime
|
10
13
|
Provides-Extra: dev
|
11
14
|
Requires-Dist: black; extra == "dev"
|
12
15
|
Requires-Dist: darglint; extra == "dev"
|
@@ -12,4 +12,9 @@ kinfer.egg-info/SOURCES.txt
|
|
12
12
|
kinfer.egg-info/dependency_links.txt
|
13
13
|
kinfer.egg-info/requires.txt
|
14
14
|
kinfer.egg-info/top_level.txt
|
15
|
-
|
15
|
+
kinfer/export/__init__.py
|
16
|
+
kinfer/export/pytorch.py
|
17
|
+
kinfer/inference/__init__.py
|
18
|
+
kinfer/inference/python.py
|
19
|
+
tests/test_dummy.py
|
20
|
+
tests/test_infer.py
|
@@ -46,7 +46,7 @@ target-version = "py310"
|
|
46
46
|
|
47
47
|
[tool.ruff.lint]
|
48
48
|
|
49
|
-
select = ["ANN", "D", "E", "F", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"]
|
49
|
+
select = ["ANN", "D", "E", "F", "G", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"]
|
50
50
|
|
51
51
|
ignore = [
|
52
52
|
"ANN101", "ANN102",
|
@@ -3,6 +3,7 @@
|
|
3
3
|
"""Setup script for the project."""
|
4
4
|
|
5
5
|
import re
|
6
|
+
from typing import List
|
6
7
|
|
7
8
|
from setuptools import setup
|
8
9
|
|
@@ -11,11 +12,11 @@ with open("README.md", "r", encoding="utf-8") as f:
|
|
11
12
|
|
12
13
|
|
13
14
|
with open("kinfer/requirements.txt", "r", encoding="utf-8") as f:
|
14
|
-
requirements:
|
15
|
+
requirements: List[str] = f.read().splitlines()
|
15
16
|
|
16
17
|
|
17
18
|
with open("kinfer/requirements-dev.txt", "r", encoding="utf-8") as f:
|
18
|
-
requirements_dev:
|
19
|
+
requirements_dev: List[str] = f.read().splitlines()
|
19
20
|
|
20
21
|
|
21
22
|
with open("kinfer/__init__.py", "r", encoding="utf-8") as fh:
|
@@ -28,11 +29,11 @@ setup(
|
|
28
29
|
name="kinfer",
|
29
30
|
version=version,
|
30
31
|
description="The kinfer project",
|
31
|
-
author="
|
32
|
+
author="K-Scale Labs",
|
32
33
|
url="https://github.com/kscalelabs/kinfer.git",
|
33
34
|
long_description=long_description,
|
34
35
|
long_description_content_type="text/markdown",
|
35
|
-
python_requires=">=3.
|
36
|
+
python_requires=">=3.8",
|
36
37
|
install_requires=requirements,
|
37
38
|
tests_require=requirements_dev,
|
38
39
|
extras_require={"dev": requirements_dev},
|
@@ -0,0 +1,193 @@
|
|
1
|
+
"""Tests for model inference functionality."""
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import pytest
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from kinfer.export.pytorch import export_to_onnx
|
11
|
+
from kinfer.inference.python import ONNXModel
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
class ModelConfig:
    """Hyperparameters for the test ``SimpleModel``, also exported as ONNX metadata."""

    hidden_size: int = 64  # width of each hidden Linear layer
    num_layers: int = 2  # number of Linear + ReLU hidden stages
|
18
|
+
|
19
|
+
|
20
|
+
class SimpleModel(torch.nn.Module):
    """A simple neural network model for demonstration.

    Maps a 10-feature input through ``config.num_layers`` Linear+ReLU stages
    of width ``config.hidden_size``, followed by a final Linear layer with a
    single output.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        layers = []
        in_features = 10

        for _ in range(config.num_layers):
            layers.extend([torch.nn.Linear(in_features, config.hidden_size), torch.nn.ReLU()])
            in_features = config.hidden_size

        # Use the running width so num_layers == 0 still builds a valid 10 -> 1 head.
        layers.append(torch.nn.Linear(in_features, 1))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x``."""
        return self.net(x)
|
37
|
+
|
38
|
+
|
39
|
+
@pytest.fixture
def model_path(tmp_path: Path) -> str:
    """Export a default ``SimpleModel`` to a temporary ONNX file and return its path."""
    cfg = ModelConfig()
    onnx_file = str(tmp_path / "test_model.onnx")
    export_to_onnx(
        model=SimpleModel(cfg),
        input_tensors=torch.randn(1, 10),
        config=cfg,
        save_path=onnx_file,
    )
    return onnx_file
|
50
|
+
|
51
|
+
|
52
|
+
def test_model_loading(model_path: str) -> None:
    """Test basic model loading functionality."""
    model = ONNXModel(model_path)
    assert model is not None
    # The wrapper should remember where it was loaded from.
    assert model.model_path == model_path
|
60
|
+
|
61
|
+
|
62
|
+
def test_model_metadata(model_path: str) -> None:
    """Config dataclass fields should round-trip into ONNX metadata as strings."""
    metadata = ONNXModel(model_path).get_metadata()
    expected = {"hidden_size": "64", "num_layers": "2"}
    for key, value in expected.items():
        assert key in metadata
        assert metadata[key] == value
|
72
|
+
|
73
|
+
|
74
|
+
def test_model_inference(model_path: str) -> None:
    """Exercise the array, dict and list calling conventions of ``ONNXModel``."""
    onnx_model = ONNXModel(model_path)
    sample = np.random.randn(1, 10).astype(np.float32)

    # Bare array in -> bare array out.
    array_out = onnx_model(sample)
    assert isinstance(array_out, np.ndarray)
    assert array_out.shape == (1, 1)

    # Mapping keyed by input name -> mapping keyed by output name.
    first_input = onnx_model.get_input_details()[0]["name"]
    assert isinstance(onnx_model({first_input: sample}), dict)

    # Positional list -> list of outputs.
    assert isinstance(onnx_model([sample]), list)
|
92
|
+
|
93
|
+
|
94
|
+
def test_model_details(model_path: str) -> None:
    """Input/output details should reflect the exported (1, 10) -> (1, 1) graph."""
    onnx_model = ONNXModel(model_path)
    checks = (
        (onnx_model.get_input_details(), [1, 10]),
        (onnx_model.get_output_details(), [1, 1]),
    )
    for details, expected_shape in checks:
        assert len(details) == 1
        assert details[0]["shape"] == expected_shape
|
107
|
+
|
108
|
+
|
109
|
+
def test_comprehensive_model_workflow(tmp_path: Path) -> None:
    """End-to-end: export a model, reload it, and check metadata, I/O details and inference."""
    config = ModelConfig(hidden_size=64, num_layers=2)
    model = SimpleModel(config)
    example = torch.randn(1, 10)
    onnx_file = str(tmp_path / "test_model.onnx")
    export_to_onnx(model=model, input_tensors=example, config=config, save_path=onnx_file)

    loaded = ONNXModel(onnx_file)

    # Metadata: config fields are stored as strings.
    meta = loaded.get_metadata()
    assert "hidden_size" in meta
    assert "num_layers" in meta
    assert meta["hidden_size"] == "64"
    assert meta["num_layers"] == "2"

    # Graph signature.
    graph_inputs = loaded.get_input_details()
    assert len(graph_inputs) == 1
    assert graph_inputs[0]["shape"] == [1, 10]

    graph_outputs = loaded.get_output_details()
    assert len(graph_outputs) == 1
    assert graph_outputs[0]["shape"] == [1, 1]

    sample = np.random.randn(1, 10).astype(np.float32)

    # 1) Bare numpy array.
    as_array = loaded(sample)
    assert isinstance(as_array, np.ndarray)
    assert as_array.shape == (1, 1)

    # 2) Dict keyed by input name.
    input_name = loaded.get_input_details()[0]["name"]
    as_dict = loaded({input_name: sample})
    assert isinstance(as_dict, dict)
    assert len(as_dict) == 1
    assert next(iter(as_dict.values())).shape == (1, 1)

    # 3) Positional list.
    as_list = loaded([sample])
    assert isinstance(as_list, list)
    assert len(as_list) == 1
    assert as_list[0].shape == (1, 1)
|
158
|
+
|
159
|
+
|
160
|
+
def test_export_with_given_input(tmp_path: Path) -> None:
    """An explicitly supplied example tensor should fix the exported input shape."""
    config = ModelConfig()
    session = export_to_onnx(
        model=SimpleModel(config),
        input_tensors=torch.randn(1, 10),
        config=config,
        save_path=str(tmp_path / "explicit_input_model.onnx"),
    )

    graph_inputs = session.get_inputs()
    assert len(graph_inputs) == 1
    assert graph_inputs[0].shape == [1, 10]
|
175
|
+
|
176
|
+
|
177
|
+
def test_export_with_inferred_input(tmp_path: Path) -> None:
    """Without example inputs, export should infer a (1, 10) input from the first Linear layer."""
    config = ModelConfig()
    session = export_to_onnx(
        model=SimpleModel(config),
        input_tensors=None,  # let export_to_onnx infer the example input
        config=config,
        save_path=str(tmp_path / "inferred_input_model.onnx"),
    )

    graph_inputs = session.get_inputs()
    assert len(graph_inputs) == 1
    assert graph_inputs[0].shape == [1, 10]  # matches in_features=10 in SimpleModel
|
kinfer-0.0.1/kinfer/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1
|
-
__version__ = "0.0.1"
|
@@ -1 +0,0 @@
|
|
1
|
-
# requirements.txt
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|