kinfer 0.3.3__cp312-cp312-macosx_11_0_arm64.whl → 0.4.0__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +0 -5
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +11 -0
- kinfer/export/common.py +35 -0
- kinfer/export/jax.py +51 -0
- kinfer/export/pytorch.py +42 -110
- kinfer/export/serialize.py +86 -0
- kinfer/requirements.txt +3 -4
- kinfer/rust/Cargo.toml +8 -6
- kinfer/rust/src/lib.rs +2 -11
- kinfer/rust/src/model.rs +271 -121
- kinfer/rust/src/runtime.rs +104 -0
- kinfer/rust_bindings/Cargo.toml +8 -1
- kinfer/rust_bindings/rust_bindings.pyi +35 -0
- kinfer/rust_bindings/src/lib.rs +310 -1
- kinfer/rust_bindings.cpython-312-darwin.so +0 -0
- kinfer/rust_bindings.pyi +29 -1
- kinfer-0.4.0.dist-info/METADATA +55 -0
- kinfer-0.4.0.dist-info/RECORD +26 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
- kinfer/inference/__init__.py +0 -2
- kinfer/inference/base.py +0 -64
- kinfer/inference/python.py +0 -66
- kinfer/proto/__init__.py +0 -40
- kinfer/proto/kinfer_pb2.py +0 -103
- kinfer/proto/kinfer_pb2.pyi +0 -1097
- kinfer/requirements-dev.txt +0 -8
- kinfer/rust/build.rs +0 -16
- kinfer/rust/src/kinfer_proto.rs +0 -14
- kinfer/rust/src/main.rs +0 -6
- kinfer/rust/src/onnx_serializer.rs +0 -804
- kinfer/rust/src/serializer.rs +0 -221
- kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
- kinfer/serialize/__init__.py +0 -60
- kinfer/serialize/base.py +0 -536
- kinfer/serialize/json.py +0 -399
- kinfer/serialize/numpy.py +0 -426
- kinfer/serialize/pytorch.py +0 -402
- kinfer/serialize/schema.py +0 -125
- kinfer/serialize/types.py +0 -17
- kinfer/serialize/utils.py +0 -177
- kinfer-0.3.3.dist-info/METADATA +0 -57
- kinfer-0.3.3.dist-info/RECORD +0 -40
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
kinfer/__init__.py
CHANGED
@@ -1,10 +1,5 @@
|
|
1
1
|
"""Defines the kinfer API."""
|
2
2
|
|
3
|
-
from . import proto as K
|
4
|
-
from .export.pytorch import export_model, get_model
|
5
|
-
from .inference.base import KModel
|
6
|
-
from .inference.python import ONNXModel
|
7
3
|
from .rust_bindings import get_version
|
8
|
-
from .serialize import get_multi_serializer, get_serializer
|
9
4
|
|
10
5
|
# Package version, as reported by the `rust_bindings` extension module.
__version__ = get_version()
|
kinfer/common/__init__.py
File without changes
|
kinfer/common/types.py
ADDED
kinfer/export/common.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
"""Defines common utilities for exporting models."""
|
2
|
+
|
3
|
+
|
4
|
+
def get_shape(
|
5
|
+
name: str,
|
6
|
+
num_joints: int | None = None,
|
7
|
+
carry_shape: tuple[int, ...] | None = None,
|
8
|
+
) -> tuple[int, ...]:
|
9
|
+
match name:
|
10
|
+
case "joint_angles":
|
11
|
+
if num_joints is None:
|
12
|
+
raise ValueError("`num_joints` must be provided when using `joint_angles`")
|
13
|
+
return (num_joints,)
|
14
|
+
|
15
|
+
case "joint_angular_velocities":
|
16
|
+
if num_joints is None:
|
17
|
+
raise ValueError("`num_joints` must be provided when using `joint_angular_velocities`")
|
18
|
+
return (num_joints,)
|
19
|
+
|
20
|
+
case "projected_gravity":
|
21
|
+
return (3,)
|
22
|
+
|
23
|
+
case "accelerometer":
|
24
|
+
return (3,)
|
25
|
+
|
26
|
+
case "gyroscope":
|
27
|
+
return (3,)
|
28
|
+
|
29
|
+
case "carry":
|
30
|
+
if carry_shape is None:
|
31
|
+
raise ValueError("`carry_shape` must be provided for `carry`")
|
32
|
+
return carry_shape
|
33
|
+
|
34
|
+
case _:
|
35
|
+
raise ValueError(f"Unknown tensor name: {name}")
|
kinfer/export/jax.py
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
"""Jax model export utilities."""
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import logging
|
5
|
+
|
6
|
+
import tensorflow as tf
|
7
|
+
import tf2onnx
|
8
|
+
from equinox.internal._finalise_jaxpr import finalise_fn
|
9
|
+
from jax._src.stages import Wrapped
|
10
|
+
from jax.experimental import jax2tf
|
11
|
+
from onnx.onnx_pb import ModelProto
|
12
|
+
|
13
|
+
from kinfer.export.common import get_shape
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)


def export_fn(
    model: Wrapped,
    *,
    num_joints: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
    opset: int = 13,
) -> ModelProto:
    """Export a JAX function to ONNX.

    The function is converted JAX -> TensorFlow (``jax2tf``) -> ONNX
    (``tf2onnx``). Every parameter of ``model`` becomes a float32 input whose
    shape is looked up from the parameter's name with ``get_shape``.

    Args:
        model: The function to export; must be a ``Wrapped`` instance
            (``jax._src.stages.Wrapped``).
        num_joints: The number of joints; needed for joint-based inputs.
        carry_shape: The shape of the carry tensor; needed for a ``carry`` input.
        opset: The ONNX opset version passed to ``tf2onnx``.

    Returns:
        The ONNX model as a ``ModelProto``.

    Raises:
        ValueError: If ``model`` is not a ``Wrapped`` function, or if
            ``get_shape`` rejects one of its parameter names.
    """
    if not isinstance(model, Wrapped):
        raise ValueError("Model must be a Wrapped function")

    # One named float32 TensorSpec per parameter, in signature order.
    input_specs = [
        tf.TensorSpec(
            get_shape(arg_name, num_joints=num_joints, carry_shape=carry_shape),
            tf.float32,
            name=arg_name,
        )
        for arg_name in inspect.signature(model).parameters
    ]

    # `enable_xla=False` asks jax2tf to avoid XLA-specific TF ops
    # (presumably so that tf2onnx can convert the resulting graph).
    tf_fn = tf.function(jax2tf.convert(finalise_fn(model), enable_xla=False))

    onnx_model, _ = tf2onnx.convert.from_function(
        tf_fn,
        input_signature=input_specs,
        opset=opset,
        large_model=False,
    )
    return onnx_model
|
kinfer/export/pytorch.py
CHANGED
@@ -1,128 +1,60 @@
|
|
1
1
|
"""PyTorch model export utilities."""
|
2
2
|
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
3
|
+
# Public API of this module (the names exported by `import *`).
__all__ = [
    "export_fn",
]
|
6
|
+
|
7
|
+
import io
|
8
|
+
from typing import cast
|
7
9
|
|
8
10
|
import onnx
|
9
|
-
import onnxruntime as ort
|
10
11
|
import torch
|
11
|
-
from
|
12
|
-
|
13
|
-
from kinfer import proto as K
|
14
|
-
from kinfer.serialize.pytorch import PyTorchMultiSerializer
|
15
|
-
from kinfer.serialize.schema import get_dummy_io
|
16
|
-
from kinfer.serialize.utils import check_names_match
|
12
|
+
from onnx.onnx_pb import ModelProto
|
13
|
+
from torch._C import FunctionSchema
|
17
14
|
|
18
|
-
|
15
|
+
from kinfer.export.common import get_shape
|
19
16
|
|
20
17
|
|
21
|
-
def
|
22
|
-
|
18
|
+
def export_fn(
    model: torch.jit.ScriptFunction,
    *,
    num_joints: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
) -> ModelProto:
    """Exports a TorchScript function to ONNX.

    Args:
        model: The TorchScript function to export. Its argument names select
            the placeholder input shapes through ``get_shape``.
        num_joints: The number of joints in the model.
        carry_shape: The shape of the carry tensor.

    Returns:
        The ONNX model as a `ModelProto`.

    Raises:
        ValueError: If ``model`` is not a ``torch.jit.ScriptFunction``, or if
            ``get_shape`` rejects one of its argument names.
    """
    if not isinstance(model, torch.jit.ScriptFunction):
        raise ValueError("Model must be a torch.jit.ScriptFunction")

    fn_schema = cast(FunctionSchema, model.schema)
    arg_names = [argument.name for argument in fn_schema.arguments]

    # Zero-filled placeholder inputs, one per argument, shaped by `get_shape`.
    dummy_inputs = tuple(
        torch.zeros(get_shape(arg_name, num_joints=num_joints, carry_shape=carry_shape))
        for arg_name in arg_names
    )

    # Export into an in-memory buffer, then parse the serialized bytes back.
    with io.BytesIO() as onnx_buffer:
        torch.onnx.export(
            model=model,
            f=onnx_buffer,  # type: ignore[arg-type]
            args=dummy_inputs,
            input_names=arg_names,
            external_data=False,
        )
        serialized = onnx_buffer.getvalue()
    return onnx.load_from_string(serialized)
|
kinfer/export/serialize.py
ADDED
@@ -0,0 +1,86 @@
|
|
1
|
+
"""Functions for serializing and deserializing models."""
|
2
|
+
|
3
|
+
# Public API of this module (the names exported by `import *`).
__all__ = [
    "pack",
]
|
6
|
+
|
7
|
+
|
8
|
+
import io
|
9
|
+
import tarfile
|
10
|
+
|
11
|
+
from onnx.onnx_pb import ModelProto
|
12
|
+
|
13
|
+
from kinfer.common.types import Metadata
|
14
|
+
from kinfer.export.common import get_shape
|
15
|
+
|
16
|
+
|
17
|
+
def pack(
    init_fn: ModelProto,
    step_fn: ModelProto,
    joint_names: list[str],
    carry_shape: tuple[int, ...],
) -> bytes:
    """Packs the initialization and step functions into a gzipped tar archive.

    The archive contains ``init_fn.onnx``, ``step_fn.onnx`` and
    ``metadata.json`` (the serialized ``Metadata`` holding ``joint_names``).

    Args:
        init_fn: The initialization function. Must take no inputs and have
            exactly one output.
        step_fn: The step function. Every input must have the shape that
            ``get_shape`` gives for its name. It must have exactly two
            outputs: the first of shape ``(len(joint_names),)``, the second
            of shape ``carry_shape``.
        joint_names: The list of joint names, in the order that the model
            expects them to be provided.
        carry_shape: The shape of the carry tensor. Any sequence of ints is
            accepted; it is normalized to a tuple.

    Returns:
        The bytes of the ``.tar.gz`` archive.

    Raises:
        ValueError: If either function does not match the expected inputs or
            outputs, or if a step input name is not recognized by ``get_shape``.
    """
    num_joints = len(joint_names)
    # Normalize so that e.g. a list [2, 4] compares equal to the tuple shapes
    # read from the graph below.
    carry_shape = tuple(carry_shape)

    # Checks the `init` function.
    if len(init_fn.graph.input) > 0:
        raise ValueError(f"`init` function should not have any inputs! Got {len(init_fn.graph.input)}")
    if len(init_fn.graph.output) != 1:
        raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")

    # Checks the `step` function inputs against the shapes implied by their names.
    # NOTE(review): only static `dim_value`s are compared; symbolic dims would
    # presumably read as 0 here -- confirm this is intended.
    for step_input in step_fn.graph.input:
        shape = tuple(dim.dim_value for dim in step_input.type.tensor_type.shape.dim)
        expected_shape = get_shape(
            step_input.name,
            num_joints=num_joints,
            carry_shape=carry_shape,
        )
        if shape != expected_shape:
            raise ValueError(f"Expected shape {expected_shape} for input `{step_input.name}`, got {shape}")

    if len(step_fn.graph.output) != 2:
        raise ValueError(f"Step function must have exactly 2 outputs, got {len(step_fn.graph.output)}")

    model_output = step_fn.graph.output[0]
    output_shape = tuple(dim.dim_value for dim in model_output.type.tensor_type.shape.dim)
    expected_output_shape = (num_joints,)
    if output_shape != expected_output_shape:
        raise ValueError(
            f"Expected output shape {expected_output_shape} for output `{model_output.name}`, got {output_shape}"
        )

    model_carry = step_fn.graph.output[1]
    output_carry_shape = tuple(dim.dim_value for dim in model_carry.type.tensor_type.shape.dim)
    if output_carry_shape != carry_shape:
        # Report the shape the model actually produced, not the expected one.
        raise ValueError(
            f"Expected carry shape {carry_shape} for output `{model_carry.name}`, got {output_carry_shape}"
        )

    # Builds the metadata object.
    metadata = Metadata(
        joint_names=joint_names,
    )

    buffer = io.BytesIO()

    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:

        def add_file_bytes(name: str, data: bytes) -> None:
            # Adds an in-memory file named `name` with contents `data`.
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))

        add_file_bytes("init_fn.onnx", init_fn.SerializeToString())
        add_file_bytes("step_fn.onnx", step_fn.SerializeToString())
        add_file_bytes("metadata.json", metadata.model_dump_json().encode("utf-8"))

    # The tar/gzip layers do not close the caller-supplied buffer, so it is
    # still readable here.
    return buffer.getvalue()
|
kinfer/requirements.txt
CHANGED
kinfer/rust/Cargo.toml
CHANGED
@@ -6,15 +6,17 @@ edition = "2021"
|
|
6
6
|
|
7
7
|
[dependencies]
|
8
8
|
|
9
|
+
async-trait = "0.1"
|
10
|
+
flate2 = "1.0"
|
9
11
|
futures-util = "0.3.30"
|
10
12
|
ndarray = "0.16.1"
|
11
|
-
ort = { version = "2.0.0-rc.
|
12
|
-
|
13
|
+
ort = { version = "2.0.0-rc.9", features = [ "load-dynamic" ] }
|
14
|
+
serde = { version = "1.0", features = ["derive"] }
|
15
|
+
serde_json = "1.0"
|
16
|
+
tar = "0.4"
|
13
17
|
thiserror = "1.0"
|
18
|
+
tokio = { version = "1.0", features = ["full"] }
|
14
19
|
|
15
20
|
[dev-dependencies]
|
16
|
-
rand = "0.8"
|
17
|
-
|
18
|
-
[build-dependencies]
|
19
21
|
|
20
|
-
|
22
|
+
rand = "0.8"
|
kinfer/rust/src/lib.rs
CHANGED
@@ -1,14 +1,5 @@
|
|
1
|
-
pub mod kinfer_proto;
|
2
1
|
pub mod model;
|
3
|
-
pub mod
|
4
|
-
pub mod serializer;
|
2
|
+
pub mod runtime;
|
5
3
|
|
6
|
-
pub use kinfer_proto::*;
|
7
4
|
pub use model::*;
|
8
|
-
pub use
|
9
|
-
pub use serializer::*;
|
10
|
-
|
11
|
-
#[cfg(test)]
|
12
|
-
mod tests {
|
13
|
-
mod onnx_serializer_tests;
|
14
|
-
}
|
5
|
+
pub use runtime::*;
|