kinfer 0.0.1.tar.gz → 0.0.2.tar.gz

This diff shows the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
@@ -1,6 +1,6 @@
  MIT License

- Copyright (c) 2023 WT-MM
+ Copyright (c) 2023 K-Scale Labs

  Permission is hereby granted, free of charge, to any person obtaining a copy
  of this software and associated documentation files (the "Software"), to deal
@@ -1,12 +1,15 @@
  Metadata-Version: 2.1
  Name: kinfer
- Version: 0.0.1
+ Version: 0.0.2
  Summary: The kinfer project
  Home-page: https://github.com/kscalelabs/kinfer.git
- Author: WT-MM
- Requires-Python: >=3.11
+ Author: K-Scale Labs
+ Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
+ Requires-Dist: torch
+ Requires-Dist: onnx
+ Requires-Dist: onnxruntime
  Provides-Extra: dev
  Requires-Dist: black; extra == "dev"
  Requires-Dist: darglint; extra == "dev"
@@ -0,0 +1 @@
+ __version__ = "0.0.2"
@@ -0,0 +1,213 @@
+ """PyTorch model export utilities."""
+
+ import inspect
+ import sys
+ from dataclasses import fields, is_dataclass
+ from io import BytesIO
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+ import onnx
+ import onnxruntime as ort  # type: ignore[import-untyped]
+ import torch
+ from torch import nn
+
+
+ def get_model_info(model: nn.Module) -> Dict[str, Any]:
+     """Extract model information including input parameters and their types.
+
+     Args:
+         model: PyTorch model to analyze
+
+     Returns:
+         Dictionary containing model information
+     """
+     # Get model's forward method signature
+     signature = inspect.signature(model.forward)
+
+     # Extract parameter information
+     params_info = {}
+     for name, param in signature.parameters.items():
+         if name == "self":
+             continue
+         params_info[name] = {
+             "annotation": str(param.annotation),
+             "default": None if param.default is param.empty else str(param.default),
+         }
+
+     return {
+         "input_params": params_info,
+         "num_parameters": sum(p.numel() for p in model.parameters()),
+     }
+
+
+ def add_metadata_to_onnx(
+     model_proto: onnx.ModelProto, metadata: Dict[str, Any], config: Optional[object] = None
+ ) -> onnx.ModelProto:
+     """Add metadata to ONNX model.
+
+     Args:
+         model_proto: ONNX model prototype
+         metadata: Dictionary of metadata to add
+         config: Optional configuration dataclass to add to metadata
+
+     Returns:
+         ONNX model with added metadata
+     """
+     # Add model metadata
+     for key, value in metadata.items():
+         meta = model_proto.metadata_props.add()
+         meta.key = key
+         meta.value = str(value)
+
+     # Add configuration if provided
+     if config is not None and is_dataclass(config):
+         for field in fields(config):
+             value = getattr(config, field.name)
+             meta = model_proto.metadata_props.add()
+             meta.key = field.name
+             meta.value = str(value)
+
+     return model_proto
+
+
+ def infer_input_shapes(model: nn.Module) -> Union[torch.Size, List[torch.Size]]:
+     """Infer input shapes from model architecture.
+
+     Args:
+         model: PyTorch model to analyze
+
+     Returns:
+         Single input shape or list of input shapes
+     """
+     # Check if model is Sequential or has Sequential as first child
+     if isinstance(model, nn.Sequential):
+         first_layer = model[0]
+     else:
+         # Get the first immediate child
+         children = list(model.children())
+         first_layer = children[0] if children else None
+
+         # Unwrap if the first child is Sequential
+         if isinstance(first_layer, nn.Sequential):
+             first_layer = first_layer[0]
+     # Check if first layer is a type we can infer from
+     if not isinstance(first_layer, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+         raise ValueError("First layer must be Linear or Conv layer to infer input shape")
+
+     # Get input dimensions
+     if isinstance(first_layer, nn.Linear):
+         return torch.Size([1, first_layer.in_features])
+     elif isinstance(first_layer, nn.Conv1d):
+         raise ValueError("Cannot infer sequence length for Conv1d layer. Please provide input_tensors explicitly.")
+     elif isinstance(first_layer, nn.Conv2d):
+         raise ValueError("Cannot infer image dimensions for Conv2d layer. Please provide input_tensors explicitly.")
+     elif isinstance(first_layer, nn.Conv3d):
+         raise ValueError("Cannot infer volume dimensions for Conv3d layer. Please provide input_tensors explicitly.")
+
+     raise ValueError("Could not infer input shape from model architecture")
+
+
+ def create_example_inputs(model: nn.Module) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+     """Create example input tensors based on model's forward signature and architecture.
+
+     Args:
+         model: PyTorch model to analyze
+
+     Returns:
+         Single tensor or tuple of tensors matching the model's expected input
+     """
+     signature = inspect.signature(model.forward)
+     params = [p for p in signature.parameters.items() if p[0] != "self"]
+
+     # If single parameter (besides self), try to infer shape
+     if len(params) == 1:
+         shape = infer_input_shapes(model)
+         return torch.randn(*shape) if isinstance(shape, torch.Size) else torch.randn(*shape[0])
+
+     # For multiple parameters, try to infer from parameter annotations
+     input_tensors = []
+     for name, param in params:
+         # Try to get shape from annotation
+         if hasattr(param.annotation, "__origin__") and param.annotation.__origin__ is torch.Tensor:
+             # If annotation includes size information (e.g., Tensor[batch_size, channels, height, width])
+             if hasattr(param.annotation, "__args__"):
+                 shape = param.annotation.__args__
+                 input_tensors.append(torch.randn(*shape) if isinstance(shape, torch.Size) else torch.randn(*shape[0]))
+             else:
+                 # Default to a vector if no size info
+                 input_tensors.append(torch.randn(1, 32))
+         else:
+             # Default fallback
+             input_tensors.append(torch.randn(1, 32))
+
+     return tuple(input_tensors)
+
+
+ def export_to_onnx(
+     model: nn.Module,
+     input_tensors: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None,
+     config: Optional[object] = None,
+     save_path: Optional[str] = None,
+ ) -> ort.InferenceSession:
+     """Export PyTorch model to ONNX format with metadata.
+
+     Args:
+         model: PyTorch model to export
+         input_tensors: Optional example input tensors for model tracing. If None, will attempt to infer.
+         config: Optional configuration dataclass to add to metadata
+         save_path: Optional path to save the ONNX model
+
+     Returns:
+         ONNX inference session
+     """
+     # Get model information
+     model_info = get_model_info(model)
+
+     # Create example inputs if not provided
+     if input_tensors is None:
+         try:
+             input_tensors = create_example_inputs(model)
+             model_info["inferred_input_shapes"] = str(
+                 input_tensors.shape if isinstance(input_tensors, torch.Tensor) else [t.shape for t in input_tensors]
+             )
+         except ValueError as e:
+             raise ValueError(
+                 f"Could not automatically infer input shapes. Please provide input_tensors. Error: {str(e)}"
+             )
+
+     # Convert model to JIT if not already
+     if not isinstance(model, torch.jit.ScriptModule):
+         model = torch.jit.script(model)
+
+     # Export model to buffer
+     buffer = BytesIO()
+     if TYPE_CHECKING and sys.version_info >= (3, 11):
+         torch.onnx.export(
+             model,
+             (input_tensors,) if isinstance(input_tensors, torch.Tensor) else input_tensors,
+             buffer,  # type: ignore[arg-type]
+         )
+     else:
+         torch.onnx.export(
+             model,
+             (input_tensors,) if isinstance(input_tensors, torch.Tensor) else input_tensors,
+             buffer,
+         )
+     buffer.seek(0)
+
+     # Load as ONNX model
+     model_proto = onnx.load_model(buffer)
+
+     # Add metadata
+     model_proto = add_metadata_to_onnx(model_proto, model_info, config)
+
+     # Save if path provided
+     if save_path:
+         onnx.save_model(model_proto, save_path)
+
+     # Convert to inference session
+     buffer = BytesIO()
+     onnx.save_model(model_proto, buffer)
+     buffer.seek(0)
+
+     return ort.InferenceSession(buffer.read())
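
For orientation, here is a minimal usage sketch of the export API added above. The Config dataclass, layer sizes, and file name are illustrative only, not part of the package:

from dataclasses import dataclass

import torch

from kinfer.export.pytorch import export_to_onnx


@dataclass
class Config:
    # Hypothetical config: any dataclass's fields are copied into ONNX metadata.
    hidden_size: int = 32


# A model whose first layer is Linear, so export_to_onnx could also infer the
# input shape if input_tensors were omitted (Conv first layers raise instead).
model = torch.nn.Sequential(torch.nn.Linear(4, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))

# Explicit example inputs trace the model; the return value is a ready-to-use
# onnxruntime.InferenceSession, and save_path additionally writes the .onnx file.
session = export_to_onnx(model, input_tensors=torch.randn(1, 4), config=Config(), save_path="model.onnx")
print(session.get_inputs()[0].shape)  # [1, 4]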
@@ -0,0 +1,91 @@
+ """ONNX model inference utilities for Python."""
+
+ from typing import Any, Dict, List, Union
+
+ import numpy as np
+ import onnx
+ import onnxruntime as ort  # type: ignore[import-untyped]
+
+
+ class ONNXModel:
+     """Wrapper for ONNX model inference."""
+
+     def __init__(
+         self,
+         model_path: str,
+     ) -> None:
+         """Initialize ONNX model.
+
+         Args:
+             model_path: Path to ONNX model file
+             config: Optional inference configuration
+         """
+         self.model_path = model_path
+
+         # Load model and create inference session
+         self.model = onnx.load(model_path)
+         self.session = ort.InferenceSession(
+             model_path,
+         )
+
+         # Extract metadata
+         self.metadata = {prop.key: prop.value for prop in self.model.metadata_props}
+
+         # Get input and output details
+         self.input_details = [{"name": x.name, "shape": x.shape, "type": x.type} for x in self.session.get_inputs()]
+         self.output_details = [{"name": x.name, "shape": x.shape, "type": x.type} for x in self.session.get_outputs()]
+
+     def __call__(
+         self, inputs: Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
+     ) -> Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]:
+         """Run inference on input data.
+
+         Args:
+             inputs: Input data as numpy array, dictionary of arrays, or list of arrays
+
+         Returns:
+             Model outputs in the same format as inputs
+         """
+         # Convert single array to dict
+         if isinstance(inputs, np.ndarray):
+             input_dict = {self.input_details[0]["name"]: inputs}
+         # Convert list to dict
+         elif isinstance(inputs, list):
+             input_dict = {detail["name"]: arr for detail, arr in zip(self.input_details, inputs)}
+         else:
+             input_dict = inputs
+
+         # Run inference - pass None to output_names param to get all outputs
+         outputs = self.session.run(None, input_dict)
+
+         # Convert output format to match input
+         if isinstance(inputs, np.ndarray):
+             return outputs[0]
+         elif isinstance(inputs, list):
+             return outputs
+         else:
+             return {detail["name"]: arr for detail, arr in zip(self.output_details, outputs)}
+
+     def get_metadata(self) -> Dict[str, str]:
+         """Get model metadata.
+
+         Returns:
+             Dictionary of metadata key-value pairs
+         """
+         return self.metadata
+
+     def get_input_details(self) -> List[Dict[str, Any]]:
+         """Get input tensor details.
+
+         Returns:
+             List of dictionaries containing input tensor information
+         """
+         return self.input_details
+
+     def get_output_details(self) -> List[Dict[str, Any]]:
+         """Get output tensor details.
+
+         Returns:
+             List of dictionaries containing output tensor information
+         """
+         return self.output_details
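
And the corresponding inference side, a brief sketch that reuses the hypothetical model.onnx written by the export sketch above:

import numpy as np

from kinfer.inference.python import ONNXModel

# Metadata embedded at export time (e.g. the config fields) travels with the file.
model = ONNXModel("model.onnx")
print(model.get_metadata())
print(model.get_input_details())

# Outputs mirror the input format: array in, array out (dict and list likewise).
x = np.random.randn(1, 4).astype(np.float32)
y = model(x)
print(y.shape)  # (1, 1)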
@@ -0,0 +1,5 @@
+ # requirements.txt
+
+ torch
+ onnx
+ onnxruntime
@@ -1,12 +1,15 @@
  Metadata-Version: 2.1
  Name: kinfer
- Version: 0.0.1
+ Version: 0.0.2
  Summary: The kinfer project
  Home-page: https://github.com/kscalelabs/kinfer.git
- Author: WT-MM
- Requires-Python: >=3.11
+ Author: K-Scale Labs
+ Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
+ Requires-Dist: torch
+ Requires-Dist: onnx
+ Requires-Dist: onnxruntime
  Provides-Extra: dev
  Requires-Dist: black; extra == "dev"
  Requires-Dist: darglint; extra == "dev"
@@ -12,4 +12,9 @@ kinfer.egg-info/SOURCES.txt
  kinfer.egg-info/dependency_links.txt
  kinfer.egg-info/requires.txt
  kinfer.egg-info/top_level.txt
- tests/test_dummy.py
+ kinfer/export/__init__.py
+ kinfer/export/pytorch.py
+ kinfer/inference/__init__.py
+ kinfer/inference/python.py
+ tests/test_dummy.py
+ tests/test_infer.py
@@ -1,3 +1,6 @@
+ torch
+ onnx
+ onnxruntime

  [dev]
  black
@@ -46,7 +46,7 @@ target-version = "py310"

  [tool.ruff.lint]

- select = ["ANN", "D", "E", "F", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"]
+ select = ["ANN", "D", "E", "F", "G", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"]

  ignore = [
      "ANN101", "ANN102",
@@ -3,6 +3,7 @@
  """Setup script for the project."""

  import re
+ from typing import List

  from setuptools import setup

@@ -11,11 +12,11 @@ with open("README.md", "r", encoding="utf-8") as f:


  with open("kinfer/requirements.txt", "r", encoding="utf-8") as f:
-     requirements: list[str] = f.read().splitlines()
+     requirements: List[str] = f.read().splitlines()


  with open("kinfer/requirements-dev.txt", "r", encoding="utf-8") as f:
-     requirements_dev: list[str] = f.read().splitlines()
+     requirements_dev: List[str] = f.read().splitlines()


  with open("kinfer/__init__.py", "r", encoding="utf-8") as fh:
@@ -28,11 +29,11 @@ setup(
      name="kinfer",
      version=version,
      description="The kinfer project",
-     author="WT-MM",
+     author="K-Scale Labs",
      url="https://github.com/kscalelabs/kinfer.git",
      long_description=long_description,
      long_description_content_type="text/markdown",
-     python_requires=">=3.11",
+     python_requires=">=3.8",
      install_requires=requirements,
      tests_require=requirements_dev,
      extras_require={"dev": requirements_dev},
@@ -0,0 +1,193 @@
+ """Tests for model inference functionality."""
+
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import numpy as np
+ import pytest
+ import torch
+
+ from kinfer.export.pytorch import export_to_onnx
+ from kinfer.inference.python import ONNXModel
+
+
+ @dataclass
+ class ModelConfig:
+     hidden_size: int = 64
+     num_layers: int = 2
+
+
+ class SimpleModel(torch.nn.Module):
+     """A simple neural network model for demonstration."""
+
+     def __init__(self, config: ModelConfig) -> None:
+         super().__init__()
+         layers = []
+         in_features = 10
+
+         for _ in range(config.num_layers):
+             layers.extend([torch.nn.Linear(in_features, config.hidden_size), torch.nn.ReLU()])
+             in_features = config.hidden_size
+
+         layers.append(torch.nn.Linear(config.hidden_size, 1))
+         self.net = torch.nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.net(x)
+
+
+ @pytest.fixture
+ def model_path(tmp_path: Path) -> str:
+     """Create and export a test model."""
+     # Create and export model
+     config = ModelConfig()
+     model = SimpleModel(config)
+
+     save_path = str(tmp_path / "test_model.onnx")
+     export_to_onnx(model=model, input_tensors=torch.randn(1, 10), config=config, save_path=save_path)
+
+     return save_path
+
+
+ def test_model_loading(model_path: str) -> None:
+     """Test basic model loading functionality."""
+     # Test with default config
+     model = ONNXModel(model_path)
+     assert model is not None
+
+     model = ONNXModel(model_path)
+     assert model is not None
+
+
+ def test_model_metadata(model_path: str) -> None:
+     """Test model metadata extraction."""
+     model = ONNXModel(model_path)
+     metadata = model.get_metadata()
+
+     # Check if config parameters are in metadata
+     assert "hidden_size" in metadata
+     assert "num_layers" in metadata
+     assert metadata["hidden_size"] == "64"
+     assert metadata["num_layers"] == "2"
+
+
+ def test_model_inference(model_path: str) -> None:
+     """Test model inference with different input formats."""
+     model = ONNXModel(model_path)
+
+     # Test with numpy array
+     input_data = np.random.randn(1, 10).astype(np.float32)
+     output = model(input_data)
+     assert isinstance(output, np.ndarray)
+     assert output.shape == (1, 1)
+
+     # Test with dictionary input
+     input_name = model.get_input_details()[0]["name"]
+     output = model({input_name: input_data})
+     assert isinstance(output, dict)
+
+     # Test with list input
+     output = model([input_data])
+     assert isinstance(output, list)
+
+
+ def test_model_details(model_path: str) -> None:
+     """Test input/output detail extraction."""
+     model = ONNXModel(model_path)
+
+     # Check input details
+     input_details = model.get_input_details()
+     assert len(input_details) == 1
+     assert input_details[0]["shape"] == [1, 10]
+
+     # Check output details
+     output_details = model.get_output_details()
+     assert len(output_details) == 1
+     assert output_details[0]["shape"] == [1, 1]
+
+
+ def test_comprehensive_model_workflow(tmp_path: Path) -> None:
+     """Test complete model workflow including export, loading and inference."""
+     # Create and export model
+     config = ModelConfig(hidden_size=64, num_layers=2)
+     model = SimpleModel(config)
+     input_tensor = torch.randn(1, 10)
+
+     save_path = str(tmp_path / "test_model.onnx")
+     export_to_onnx(model=model, input_tensors=input_tensor, config=config, save_path=save_path)
+
+     # Load model for inference
+     onnx_model = ONNXModel(save_path)
+
+     # Test metadata
+     metadata = onnx_model.get_metadata()
+     assert "hidden_size" in metadata
+     assert "num_layers" in metadata
+     assert metadata["hidden_size"] == "64"
+     assert metadata["num_layers"] == "2"
+
+     # Test input/output details
+     input_details = onnx_model.get_input_details()
+     assert len(input_details) == 1
+     assert input_details[0]["shape"] == [1, 10]
+
+     output_details = onnx_model.get_output_details()
+     assert len(output_details) == 1
+     assert output_details[0]["shape"] == [1, 1]
+
+     # Test inference with different input methods
+     input_data = np.random.randn(1, 10).astype(np.float32)
+
+     # Method 1: Direct numpy array input
+     output1 = onnx_model(input_data)
+     assert isinstance(output1, np.ndarray)
+     assert output1.shape == (1, 1)
+
+     # Method 2: Dictionary input
+     input_name = onnx_model.get_input_details()[0]["name"]
+     output2 = onnx_model({input_name: input_data})
+     assert isinstance(output2, dict)
+     assert len(output2) == 1
+     assert list(output2.values())[0].shape == (1, 1)
+
+     # Method 3: List input
+     output3 = onnx_model([input_data])
+     assert isinstance(output3, list)
+     assert len(output3) == 1
+     assert output3[0].shape == (1, 1)
+
+
+ def test_export_with_given_input(tmp_path: Path) -> None:
+     """Test model export with explicitly provided input tensor."""
+     config = ModelConfig()
+     model = SimpleModel(config)
+
+     # Create specific input tensor
+     input_tensor = torch.randn(1, 10)
+
+     save_path = str(tmp_path / "explicit_input_model.onnx")
+     session = export_to_onnx(model=model, input_tensors=input_tensor, config=config, save_path=save_path)
+
+     # Verify input shape matches what we provided
+     inputs = session.get_inputs()
+     assert len(inputs) == 1
+     assert inputs[0].shape == [1, 10]
+
+
+ def test_export_with_inferred_input(tmp_path: Path) -> None:
+     """Test model export with automatically inferred input tensor."""
+     config = ModelConfig()
+     model = SimpleModel(config)
+
+     save_path = str(tmp_path / "inferred_input_model.onnx")
+     session = export_to_onnx(
+         model=model,
+         input_tensors=None,
+         config=config,
+         save_path=save_path,  # Let it infer the input
+     )
+
+     # Verify input shape was correctly inferred
+     inputs = session.get_inputs()
+     assert len(inputs) == 1
+     assert inputs[0].shape == [1, 10]  # Should match the in_features=10 from SimpleModel
@@ -1 +0,0 @@
- __version__ = "0.0.1"
@@ -1 +0,0 @@
- # requirements.txt