kinfer 0.3.1__cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +6 -0
- kinfer/export/__init__.py +1 -0
- kinfer/export/pytorch.py +128 -0
- kinfer/inference/__init__.py +1 -0
- kinfer/inference/python.py +92 -0
- kinfer/proto/__init__.py +40 -0
- kinfer/proto/kinfer_pb2.py +103 -0
- kinfer/proto/kinfer_pb2.pyi +1097 -0
- kinfer/py.typed +0 -0
- kinfer/requirements-dev.txt +8 -0
- kinfer/requirements.txt +9 -0
- kinfer/rust/Cargo.toml +20 -0
- kinfer/rust/build.rs +16 -0
- kinfer/rust/src/kinfer_proto.rs +14 -0
- kinfer/rust/src/lib.rs +14 -0
- kinfer/rust/src/main.rs +6 -0
- kinfer/rust/src/model.rs +153 -0
- kinfer/rust/src/onnx_serializer.rs +804 -0
- kinfer/rust/src/serializer.rs +221 -0
- kinfer/rust/src/tests/onnx_serializer_tests.rs +212 -0
- kinfer/rust_bindings/Cargo.toml +19 -0
- kinfer/rust_bindings/pyproject.toml +7 -0
- kinfer/rust_bindings/src/bin/stub_gen.rs +7 -0
- kinfer/rust_bindings/src/lib.rs +17 -0
- kinfer/rust_bindings.cpython-311-aarch64-linux-gnu.so +0 -0
- kinfer/rust_bindings.pyi +7 -0
- kinfer/serialize/__init__.py +36 -0
- kinfer/serialize/base.py +536 -0
- kinfer/serialize/json.py +399 -0
- kinfer/serialize/numpy.py +426 -0
- kinfer/serialize/pytorch.py +402 -0
- kinfer/serialize/schema.py +125 -0
- kinfer/serialize/types.py +17 -0
- kinfer/serialize/utils.py +177 -0
- kinfer-0.3.1.dist-info/LICENSE +21 -0
- kinfer-0.3.1.dist-info/METADATA +57 -0
- kinfer-0.3.1.dist-info/RECORD +39 -0
- kinfer-0.3.1.dist-info/WHEEL +6 -0
- kinfer-0.3.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,177 @@
|
|
1
|
+
"""Utility functions for serializing and deserializing Kinfer values."""
|
2
|
+
|
3
|
+
import math
|
4
|
+
from typing import Any, Collection
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from kinfer import proto as P
|
10
|
+
|
11
|
+
|
12
|
+
def numpy_dtype(dtype: P.DType.ValueType) -> type[np.floating] | type[np.integer]:
    """Translate a Kinfer ``DType`` enum value into a NumPy scalar type.

    Args:
        dtype: The ``P.DType`` value to translate.

    Returns:
        The NumPy scalar type (for example ``np.float32`` or ``np.uint8``)
        corresponding to ``dtype``.

    Raises:
        NotImplementedError: If ``dtype`` is ``P.DType.FP8``.
        ValueError: If ``dtype`` is any other value without a mapping below.
    """
    # FP8 raises NotImplementedError rather than the generic ValueError,
    # so it is handled before the table lookup.
    if dtype == P.DType.FP8:
        raise NotImplementedError("FP8 is not supported")

    scalar_types = {
        P.DType.FP16: np.float16,
        P.DType.FP32: np.float32,
        P.DType.FP64: np.float64,
        P.DType.INT8: np.int8,
        P.DType.INT16: np.int16,
        P.DType.INT32: np.int32,
        P.DType.INT64: np.int64,
        P.DType.UINT8: np.uint8,
        P.DType.UINT16: np.uint16,
        P.DType.UINT32: np.uint32,
        P.DType.UINT64: np.uint64,
    }
    scalar_type = scalar_types.get(dtype)
    if scalar_type is None:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return scalar_type
|
40
|
+
|
41
|
+
|
42
|
+
def pytorch_dtype(dtype: P.DType.ValueType) -> torch.dtype:
|
43
|
+
match dtype:
|
44
|
+
case P.DType.FP8:
|
45
|
+
raise NotImplementedError("FP8 is not supported")
|
46
|
+
case P.DType.FP16:
|
47
|
+
return torch.float16
|
48
|
+
case P.DType.FP32:
|
49
|
+
return torch.float32
|
50
|
+
case P.DType.FP64:
|
51
|
+
return torch.float64
|
52
|
+
case P.DType.INT8:
|
53
|
+
return torch.int8
|
54
|
+
case P.DType.INT16:
|
55
|
+
return torch.int16
|
56
|
+
case P.DType.INT32:
|
57
|
+
return torch.int32
|
58
|
+
case P.DType.INT64:
|
59
|
+
return torch.int64
|
60
|
+
case P.DType.UINT8:
|
61
|
+
return torch.uint8
|
62
|
+
case P.DType.UINT16:
|
63
|
+
return torch.uint16
|
64
|
+
case P.DType.UINT32:
|
65
|
+
return torch.uint32
|
66
|
+
case P.DType.UINT64:
|
67
|
+
return torch.uint64
|
68
|
+
case _:
|
69
|
+
raise ValueError(f"Unsupported dtype: {dtype}")
|
70
|
+
|
71
|
+
|
72
|
+
def parse_bytes(data: bytes, dtype: P.DType.ValueType) -> np.ndarray:
    """Interpret a raw byte buffer as a flat NumPy array of the given Kinfer dtype.

    Args:
        data: Raw bytes holding the packed element values.
        dtype: The ``P.DType`` describing each element; translated via ``numpy_dtype``.

    Returns:
        A one-dimensional array produced by ``np.frombuffer``. That function does
        not copy, so the array is read-only when ``data`` is an immutable ``bytes``.

    Raises:
        NotImplementedError, ValueError: Propagated from ``numpy_dtype`` for
            unsupported dtypes; ``np.frombuffer`` also raises ValueError when the
            buffer length is not a multiple of the element size.
    """
    # NOTE(review): elements are read in the platform's native byte order; confirm
    # the producer of `data` writes with the same endianness.
    element_type = numpy_dtype(dtype)
    return np.frombuffer(data, dtype=element_type)
|
74
|
+
|
75
|
+
|
76
|
+
def dtype_num_bytes(dtype: P.DType.ValueType) -> int:
    """Return the size in bytes of one element of the given Kinfer dtype.

    Args:
        dtype: The ``P.DType`` value to size.

    Returns:
        1, 2, 4 or 8, depending on the element width.

    Raises:
        ValueError: If ``dtype`` is not a recognized ``DType`` value.
    """
    widths = (
        (1, (P.DType.FP8, P.DType.INT8, P.DType.UINT8)),
        (2, (P.DType.FP16, P.DType.INT16, P.DType.UINT16)),
        (4, (P.DType.FP32, P.DType.INT32, P.DType.UINT32)),
        (8, (P.DType.FP64, P.DType.INT64, P.DType.UINT64)),
    )
    for num_bytes, members in widths:
        if dtype in members:
            return num_bytes
    raise ValueError(f"Unsupported dtype: {dtype}")
|
88
|
+
|
89
|
+
|
90
|
+
def dtype_range(dtype: P.DType.ValueType) -> tuple[int, int]:
    """Return the ``(min, max)`` value range associated with a Kinfer dtype.

    Integer types map to their full two's-complement / unsigned representable
    range. Every floating-point type (including FP8) maps to ``(-1, 1)``.

    Args:
        dtype: The ``P.DType`` value to look up.

    Returns:
        An inclusive ``(low, high)`` tuple of Python ints.

    Raises:
        ValueError: If ``dtype`` is not a recognized ``DType`` value.
    """
    # NOTE(review): floats report (-1, 1) rather than their representable range,
    # presumably a normalized-value convention; confirm against callers.
    unit_interval = (-1, 1)
    ranges = {
        P.DType.FP8: unit_interval,
        P.DType.FP16: unit_interval,
        P.DType.FP32: unit_interval,
        P.DType.FP64: unit_interval,
        P.DType.INT8: (-(1 << 7), (1 << 7) - 1),
        P.DType.INT16: (-(1 << 15), (1 << 15) - 1),
        P.DType.INT32: (-(1 << 31), (1 << 31) - 1),
        P.DType.INT64: (-(1 << 63), (1 << 63) - 1),
        P.DType.UINT8: (0, (1 << 8) - 1),
        P.DType.UINT16: (0, (1 << 16) - 1),
        P.DType.UINT32: (0, (1 << 32) - 1),
        P.DType.UINT64: (0, (1 << 64) - 1),
    }
    if dtype not in ranges:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return ranges[dtype]
|
118
|
+
|
119
|
+
|
120
|
+
def convert_torque(
    value: float,
    from_unit: P.JointTorqueUnit.ValueType,
    to_unit: P.JointTorqueUnit.ValueType,
) -> float:
    """Convert a joint torque value between units.

    Only the identity conversion is implemented: when both units are the same
    the value is returned unchanged; any other pair is rejected.

    Args:
        value: The torque value expressed in ``from_unit``.
        from_unit: The unit ``value`` is currently expressed in.
        to_unit: The desired output unit.

    Returns:
        ``value`` unchanged.

    Raises:
        ValueError: If ``from_unit`` differs from ``to_unit``.
    """
    if from_unit != to_unit:
        raise ValueError(f"Unsupported unit: {from_unit}")
    return value
|
128
|
+
|
129
|
+
|
130
|
+
def convert_angular_velocity(
    value: float,
    from_unit: P.JointVelocityUnit.ValueType,
    to_unit: P.JointVelocityUnit.ValueType,
) -> float:
    """Convert a joint angular velocity between degrees/s and radians/s.

    Args:
        value: The angular velocity expressed in ``from_unit``.
        from_unit: The unit ``value`` is currently expressed in.
        to_unit: The desired output unit.

    Returns:
        ``value`` converted to ``to_unit`` (unchanged if the units are equal).

    Raises:
        ValueError: If the source unit is not supported, or if the target unit is
            not the counterpart of the source unit.
    """
    if from_unit == to_unit:
        return value
    # Explicit raises instead of `assert`: assertions are stripped under `python -O`,
    # which would silently return a mis-converted value for an invalid target unit.
    if from_unit == P.JointVelocityUnit.DEGREES_PER_SECOND:
        if to_unit != P.JointVelocityUnit.RADIANS_PER_SECOND:
            raise ValueError(f"Unsupported unit: {to_unit}")
        return value * math.pi / 180
    if from_unit == P.JointVelocityUnit.RADIANS_PER_SECOND:
        if to_unit != P.JointVelocityUnit.DEGREES_PER_SECOND:
            raise ValueError(f"Unsupported unit: {to_unit}")
        return value * 180 / math.pi
    raise ValueError(f"Unsupported unit: {from_unit}")
|
144
|
+
|
145
|
+
|
146
|
+
def convert_angular_position(
    value: float,
    from_unit: P.JointPositionUnit.ValueType,
    to_unit: P.JointPositionUnit.ValueType,
) -> float:
    """Convert a joint angular position between degrees and radians.

    Args:
        value: The angle expressed in ``from_unit``.
        from_unit: The unit ``value`` is currently expressed in.
        to_unit: The desired output unit.

    Returns:
        ``value`` converted to ``to_unit`` (unchanged if the units are equal).

    Raises:
        ValueError: If the source unit is not supported, or if the target unit is
            not the counterpart of the source unit.
    """
    if from_unit == to_unit:
        return value
    # Validate the target unit too; previously any `to_unit` other than `from_unit`
    # was assumed to be the counterpart, silently producing a wrong value.
    if from_unit == P.JointPositionUnit.DEGREES:
        if to_unit != P.JointPositionUnit.RADIANS:
            raise ValueError(f"Unsupported unit: {to_unit}")
        return value * math.pi / 180
    if from_unit == P.JointPositionUnit.RADIANS:
        if to_unit != P.JointPositionUnit.DEGREES:
            raise ValueError(f"Unsupported unit: {to_unit}")
        return value * 180 / math.pi
    raise ValueError(f"Unsupported unit: {from_unit}")
|
158
|
+
|
159
|
+
|
160
|
+
def check_names_match(a_name: str, a: Collection[str], b_name: str, b: Collection[str]) -> None:
    """Ensure two collections contain exactly the same set of names.

    Duplicates and ordering are ignored; only set membership is compared.

    Args:
        a_name: Label for the first collection, used in the error message.
        a: The first collection of names.
        b_name: Label for the second collection, used in the error message.
        b: The second collection of names.

    Raises:
        ValueError: If the name sets differ; the message lists the names found
            only in ``a`` and/or only in ``b``.
    """
    left, right = set(a), set(b)
    if left == right:
        return

    parts = ["Names must match!"]
    missing_from_b = left - right
    if missing_from_b:
        parts.append(f" Only in {a_name}: {missing_from_b}")
    missing_from_a = right - left
    if missing_from_a:
        parts.append(f" Only in {b_name}: {missing_from_a}")
    raise ValueError("".join(parts))
|
172
|
+
|
173
|
+
|
174
|
+
def as_float(value: Any) -> float:  # noqa: ANN401
    """Coerce a numeric scalar to a Python ``float``.

    Accepts Python ``int`` and ``float`` (``bool`` is accepted too, as a subclass
    of ``int``) as well as NumPy integer and floating scalars such as
    ``np.float32`` or ``np.int64``. Those NumPy scalars are not subclasses of the
    Python builtins and were previously rejected despite being numeric.

    Args:
        value: The value to convert.

    Returns:
        ``float(value)``.

    Raises:
        ValueError: If ``value`` is not an int/float or NumPy numeric scalar.
    """
    if not isinstance(value, (float, int, np.floating, np.integer)):
        raise ValueError(f"Value must be a float or int: {value}")
    return float(value)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2023 K-Scale Labs
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
@@ -0,0 +1,57 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: kinfer
|
3
|
+
Version: 0.3.1
|
4
|
+
Summary: Tool to make it easier to run a model on a real robot
|
5
|
+
Home-page: https://github.com/kscalelabs/kinfer.git
|
6
|
+
Author: K-Scale Labs
|
7
|
+
Requires-Python: >=3.11
|
8
|
+
Description-Content-Type: text/markdown
|
9
|
+
License-File: LICENSE
|
10
|
+
Requires-Dist: torch
|
11
|
+
Requires-Dist: onnx
|
12
|
+
Requires-Dist: onnxruntime
|
13
|
+
Requires-Dist: protobuf
|
14
|
+
Provides-Extra: dev
|
15
|
+
Requires-Dist: black; extra == "dev"
|
16
|
+
Requires-Dist: darglint; extra == "dev"
|
17
|
+
Requires-Dist: mypy; extra == "dev"
|
18
|
+
Requires-Dist: mypy-protobuf; extra == "dev"
|
19
|
+
Requires-Dist: pytest; extra == "dev"
|
20
|
+
Requires-Dist: ruff; extra == "dev"
|
21
|
+
|
22
|
+
# kinfer
|
23
|
+
|
24
|
+
This package is designed to support exporting and running inference on PyTorch models.
|
25
|
+
|
26
|
+
## Installation
|
27
|
+
|
28
|
+
```bash
|
29
|
+
pip install kinfer
|
30
|
+
```
|
31
|
+
|
32
|
+
### ONNX Runtime
|
33
|
+
|
34
|
+
On macOS, you can install the latest version of ONNX Runtime with Homebrew:
|
35
|
+
|
36
|
+
```bash
|
37
|
+
brew install onnxruntime
|
38
|
+
```
|
39
|
+
|
40
|
+
You may need to add the directory containing the ONNX Runtime library to your `DYLD_LIBRARY_PATH`:
|
41
|
+
|
42
|
+
```bash
|
43
|
+
$ brew ls onnxruntime
|
44
|
+
/opt/homebrew/Cellar/onnxruntime/1.20.1/include/onnxruntime/ (11 files)
|
45
|
+
/opt/homebrew/Cellar/onnxruntime/1.20.1/lib/libonnxruntime.1.20.1.dylib # <-- This is the binary
|
46
|
+
/opt/homebrew/Cellar/onnxruntime/1.20.1/lib/cmake/ (4 files)
|
47
|
+
/opt/homebrew/Cellar/onnxruntime/1.20.1/lib/pkgconfig/libonnxruntime.pc
|
48
|
+
/opt/homebrew/Cellar/onnxruntime/1.20.1/lib/libonnxruntime.dylib
|
49
|
+
/opt/homebrew/Cellar/onnxruntime/1.20.1/sbom.spdx.json
|
50
|
+
$ export DYLD_LIBRARY_PATH=/opt/homebrew/Cellar/onnxruntime/1.20.1/lib:$DYLD_LIBRARY_PATH
|
51
|
+
```
|
52
|
+
|
53
|
+
### Considerations for Exporting PyTorch Models
|
54
|
+
|
55
|
+
Avoid common names such as `input`, `output`, `state`, `state_tensor`, or `buffer` for the inputs to your model's forward pass.
|
56
|
+
|
57
|
+
ONNX uses some of these names internally. If an input name conflicts with one of them, ONNX appends a `.1`, `.2`, etc. suffix to it, which makes it hard to figure out which `value_name` to use for your kinfer IO values.
|
@@ -0,0 +1,39 @@
|
|
1
|
+
kinfer/requirements-dev.txt,sha256=jZzaEENgXPHkAjbnwdglVCaV_GDkwCWIX2BWX0e3uLc,70
|
2
|
+
kinfer/requirements.txt,sha256=-xI0a63Os_v2mkO0Cy5bA5envglD-2c4mdSeLG4movA,91
|
3
|
+
kinfer/__init__.py,sha256=LSRl3lC7SWBk9or4fT1TYYPP92yhEeRca-3gxG8urM0,131
|
4
|
+
kinfer/rust_bindings.cpython-311-aarch64-linux-gnu.so,sha256=i9LCM4TsWf8H5MemGbrp1VFvJhjPOaFxlgxK-EXT4hk,542400
|
5
|
+
kinfer/rust_bindings.pyi,sha256=KFNalMmwkfoS8BwJfVAz5APaqlW1ryp0xpmnbfud5kc,118
|
6
|
+
kinfer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
+
kinfer/serialize/utils.py,sha256=VzchOS4Wv6zhm1STmth3kGV83ExNtlE0Vio1K3kM_tE,5382
|
8
|
+
kinfer/serialize/json.py,sha256=fpacSJoyILWrPJPqJyqGJ9HpgweLSpNYokCWb7rXqks,15081
|
9
|
+
kinfer/serialize/pytorch.py,sha256=1trcjjdywLh93hLnA-6WB1q558RM7gOopdnluOV6ivk,16169
|
10
|
+
kinfer/serialize/__init__.py,sha256=zUAaf-EpWtDF3ZTneoP0-wS2TjzhFrp5-5xs_yie1ZA,1281
|
11
|
+
kinfer/serialize/schema.py,sha256=xNkY6WQ55sfmZ4QgZy4eg4hxayCEFnAiFxYvyl_ojMQ,4730
|
12
|
+
kinfer/serialize/types.py,sha256=vWc2QqiMBodu0zUx0u69su6h42J3PTpS8ttPYHaEKlA,427
|
13
|
+
kinfer/serialize/base.py,sha256=rCtfvexmhYGHBXAMxXgltMLE4zkshhqDNhaWu009y2c,16390
|
14
|
+
kinfer/serialize/numpy.py,sha256=g9u51h62-aTosNElSnAc5_8OHxCmDQaGWmeIAwSUm40,16146
|
15
|
+
kinfer/proto/kinfer_pb2.pyi,sha256=ogFIJ7p2G-ovFig6I0RfIpH4oZ4dd2MkxPznBXENkb8,34151
|
16
|
+
kinfer/proto/kinfer_pb2.py,sha256=sng4W-ZOdBgCowYK3JcoxoFOmUhPac07YDaoi6noeq0,11968
|
17
|
+
kinfer/proto/__init__.py,sha256=Zz2ApZaHIF375552y54X63XEq3XhZjyujmX_MsNjLmQ,848
|
18
|
+
kinfer/rust_bindings/pyproject.toml,sha256=jLcJuHCnQRh9HWR_R7a9qLHwj6LMBgnHyeKK_DruO1Y,135
|
19
|
+
kinfer/rust_bindings/Cargo.toml,sha256=VZ6MY1QpqUQnDC1ygeSScaf8X-m1zXye6yaODV4RE5o,387
|
20
|
+
kinfer/rust_bindings/src/lib.rs,sha256=rIq7S32Xjv83GXUW1ZNAd5Ex1HwDAYvG0quYJr6VaAQ,405
|
21
|
+
kinfer/rust_bindings/src/bin/stub_gen.rs,sha256=hhoVGnaSfazbSfj5a4x6mPicGPOgWQAfsDmiPej0B6Y,133
|
22
|
+
kinfer/inference/python.py,sha256=RN7bk-41EJjeWaokqMRFXLqw8rpPfLJsrtHmhiIy2OM,3012
|
23
|
+
kinfer/inference/__init__.py,sha256=fT52B78HnErtXSv2pvipsDVsufZlmnMzvR9Ny0UudoE,22
|
24
|
+
kinfer/rust/Cargo.toml,sha256=3dY7WVXzcNDBinkC2NQ4XtRSXSrwID8k6Y404uCi4jo,296
|
25
|
+
kinfer/rust/build.rs,sha256=VTvGhCtvEPbr-w0mRptPyiAF8rPlPnmyFVThb6NUSPc,422
|
26
|
+
kinfer/rust/src/kinfer_proto.rs,sha256=gpgy5amhHroP40aQVSSQ8xlNFMwdvTAy3dkK-V8Z1Ms,801
|
27
|
+
kinfer/rust/src/model.rs,sha256=5IWOkl_yj8Ea7yZkFug18ALwKrOtUAg17d7RphrMP1Y,5243
|
28
|
+
kinfer/rust/src/lib.rs,sha256=t5ca5ZO4Ss41wSroBa5PIJ64ifc4LuEheRqgDs2Ev6A,236
|
29
|
+
kinfer/rust/src/main.rs,sha256=JQZRmmEQ4wKDmprbGrQUpvN7r8F03Asg-bIIhemSxUA,114
|
30
|
+
kinfer/rust/src/serializer.rs,sha256=jaSnb6K8DDynCLdgBU4Nl4Ig8ZrZRdQUXjdZF1Ew7qs,6151
|
31
|
+
kinfer/rust/src/onnx_serializer.rs,sha256=Ti0iLoJi5VPkX_Yq7QPHzmtiJZPXvySUqRw8syAJFw4,29301
|
32
|
+
kinfer/rust/src/tests/onnx_serializer_tests.rs,sha256=KdrysHqhYC27lIdkNTSUJ9CTBs0PN0-8aw2Xe3pqk4E,7481
|
33
|
+
kinfer/export/pytorch.py,sha256=xz8DKT1ZBTYITLGh-ef01eEGvWOzQ-SMInRI8ol1RzM,4529
|
34
|
+
kinfer/export/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
|
35
|
+
kinfer-0.3.1.dist-info/top_level.txt,sha256=6mY_t3PYr3Dm0dpqMk80uSnArbvGfCFkxOh1QWtgDEo,7
|
36
|
+
kinfer-0.3.1.dist-info/RECORD,,
|
37
|
+
kinfer-0.3.1.dist-info/LICENSE,sha256=Qw-Z0XTwS-diSW91e_jLeBPX9zZbAatOJTBLdPHPaC0,1069
|
38
|
+
kinfer-0.3.1.dist-info/METADATA,sha256=lKtjAlL0xrRkB0OqMmpt_f3v0KygFbRiYtUgEAY1Kfw,1884
|
39
|
+
kinfer-0.3.1.dist-info/WHEEL,sha256=ZiHiI0fxbnsGhDML32hrhH3YKU2c-6yRirdNq7QKO5A,153
|
@@ -0,0 +1 @@
|
|
1
|
+
kinfer
|