kinfer 0.3.2__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kinfer-0.3.2 → kinfer-0.4.1}/.cargo/config.toml +3 -1
- {kinfer-0.3.2 → kinfer-0.4.1}/Cargo.toml +1 -1
- kinfer-0.4.1/PKG-INFO +55 -0
- kinfer-0.4.1/README.md +5 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/kinfer/__init__.py +0 -1
- kinfer-0.4.1/kinfer/common/types.py +12 -0
- kinfer-0.4.1/kinfer/export/__init__.py +0 -0
- kinfer-0.4.1/kinfer/export/common.py +41 -0
- kinfer-0.4.1/kinfer/export/jax.py +53 -0
- kinfer-0.4.1/kinfer/export/pytorch.py +63 -0
- kinfer-0.4.1/kinfer/export/serialize.py +93 -0
- kinfer-0.4.1/kinfer/py.typed +0 -0
- kinfer-0.4.1/kinfer/requirements.txt +8 -0
- kinfer-0.4.1/kinfer/rust/Cargo.toml +32 -0
- kinfer-0.4.1/kinfer/rust/src/lib.rs +5 -0
- kinfer-0.4.1/kinfer/rust/src/model.rs +318 -0
- kinfer-0.4.1/kinfer/rust/src/runtime.rs +104 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/kinfer/rust_bindings/Cargo.toml +8 -1
- kinfer-0.4.1/kinfer/rust_bindings/src/lib.rs +342 -0
- kinfer-0.4.1/kinfer.egg-info/PKG-INFO +55 -0
- kinfer-0.4.1/kinfer.egg-info/SOURCES.txt +33 -0
- kinfer-0.4.1/kinfer.egg-info/requires.txt +35 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/pyproject.toml +1 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/setup.py +27 -19
- kinfer-0.4.1/tests/test_jax.py +114 -0
- kinfer-0.4.1/tests/test_pytorch.py +153 -0
- kinfer-0.3.2/PKG-INFO +0 -57
- kinfer-0.3.2/README.md +0 -36
- kinfer-0.3.2/kinfer/export/__init__.py +0 -1
- kinfer-0.3.2/kinfer/export/pytorch.py +0 -128
- kinfer-0.3.2/kinfer/inference/__init__.py +0 -1
- kinfer-0.3.2/kinfer/inference/python.py +0 -92
- kinfer-0.3.2/kinfer/proto/__init__.py +0 -40
- kinfer-0.3.2/kinfer/proto/kinfer_pb2.py +0 -103
- kinfer-0.3.2/kinfer/proto/kinfer_pb2.pyi +0 -1097
- kinfer-0.3.2/kinfer/requirements-dev.txt +0 -8
- kinfer-0.3.2/kinfer/requirements.txt +0 -9
- kinfer-0.3.2/kinfer/rust/Cargo.toml +0 -20
- kinfer-0.3.2/kinfer/rust/build.rs +0 -16
- kinfer-0.3.2/kinfer/rust/src/kinfer_proto.rs +0 -14
- kinfer-0.3.2/kinfer/rust/src/lib.rs +0 -14
- kinfer-0.3.2/kinfer/rust/src/main.rs +0 -6
- kinfer-0.3.2/kinfer/rust/src/model.rs +0 -153
- kinfer-0.3.2/kinfer/rust/src/onnx_serializer.rs +0 -804
- kinfer-0.3.2/kinfer/rust/src/serializer.rs +0 -221
- kinfer-0.3.2/kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
- kinfer-0.3.2/kinfer/rust_bindings/src/lib.rs +0 -17
- kinfer-0.3.2/kinfer/rust_bindings.pyi +0 -7
- kinfer-0.3.2/kinfer/serialize/__init__.py +0 -36
- kinfer-0.3.2/kinfer/serialize/base.py +0 -536
- kinfer-0.3.2/kinfer/serialize/json.py +0 -399
- kinfer-0.3.2/kinfer/serialize/numpy.py +0 -426
- kinfer-0.3.2/kinfer/serialize/pytorch.py +0 -402
- kinfer-0.3.2/kinfer/serialize/schema.py +0 -125
- kinfer-0.3.2/kinfer/serialize/types.py +0 -17
- kinfer-0.3.2/kinfer/serialize/utils.py +0 -177
- kinfer-0.3.2/kinfer.egg-info/PKG-INFO +0 -57
- kinfer-0.3.2/kinfer.egg-info/SOURCES.txt +0 -48
- kinfer-0.3.2/kinfer.egg-info/requires.txt +0 -12
- kinfer-0.3.2/tests/test_infer.py +0 -101
- kinfer-0.3.2/tests/test_schema.py +0 -229
- {kinfer-0.3.2 → kinfer-0.4.1}/LICENSE +0 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/MANIFEST.in +0 -0
- /kinfer-0.3.2/kinfer/py.typed → /kinfer-0.4.1/kinfer/common/__init__.py +0 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/kinfer/rust_bindings/pyproject.toml +0 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/kinfer/rust_bindings/src/bin/stub_gen.rs +0 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/kinfer.egg-info/dependency_links.txt +0 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/kinfer.egg-info/not-zip-safe +0 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/kinfer.egg-info/top_level.txt +0 -0
- {kinfer-0.3.2 → kinfer-0.4.1}/setup.cfg +0 -0
kinfer-0.4.1/PKG-INFO
ADDED
@@ -0,0 +1,55 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: kinfer
|
3
|
+
Version: 0.4.1
|
4
|
+
Summary: Tool to make it easier to run a model on a real robot
|
5
|
+
Home-page: https://github.com/kscalelabs/kinfer.git
|
6
|
+
Author: K-Scale Labs
|
7
|
+
Requires-Python: >=3.11
|
8
|
+
Description-Content-Type: text/markdown
|
9
|
+
License-File: LICENSE
|
10
|
+
Requires-Dist: onnx
|
11
|
+
Requires-Dist: onnxruntime==1.20.0
|
12
|
+
Requires-Dist: pydantic
|
13
|
+
Provides-Extra: dev
|
14
|
+
Requires-Dist: black; extra == "dev"
|
15
|
+
Requires-Dist: darglint; extra == "dev"
|
16
|
+
Requires-Dist: mypy; extra == "dev"
|
17
|
+
Requires-Dist: pytest; extra == "dev"
|
18
|
+
Requires-Dist: ruff; extra == "dev"
|
19
|
+
Requires-Dist: types-tensorflow; extra == "dev"
|
20
|
+
Provides-Extra: pytorch
|
21
|
+
Requires-Dist: torch; extra == "pytorch"
|
22
|
+
Provides-Extra: jax
|
23
|
+
Requires-Dist: tensorflow; extra == "jax"
|
24
|
+
Requires-Dist: tf2onnx; extra == "jax"
|
25
|
+
Requires-Dist: jax; extra == "jax"
|
26
|
+
Requires-Dist: equinox; extra == "jax"
|
27
|
+
Requires-Dist: numpy<2; extra == "jax"
|
28
|
+
Provides-Extra: all
|
29
|
+
Requires-Dist: black; extra == "all"
|
30
|
+
Requires-Dist: darglint; extra == "all"
|
31
|
+
Requires-Dist: mypy; extra == "all"
|
32
|
+
Requires-Dist: pytest; extra == "all"
|
33
|
+
Requires-Dist: ruff; extra == "all"
|
34
|
+
Requires-Dist: types-tensorflow; extra == "all"
|
35
|
+
Requires-Dist: torch; extra == "all"
|
36
|
+
Requires-Dist: tensorflow; extra == "all"
|
37
|
+
Requires-Dist: tf2onnx; extra == "all"
|
38
|
+
Requires-Dist: jax; extra == "all"
|
39
|
+
Requires-Dist: equinox; extra == "all"
|
40
|
+
Requires-Dist: numpy<2; extra == "all"
|
41
|
+
Dynamic: author
|
42
|
+
Dynamic: description
|
43
|
+
Dynamic: description-content-type
|
44
|
+
Dynamic: home-page
|
45
|
+
Dynamic: license-file
|
46
|
+
Dynamic: provides-extra
|
47
|
+
Dynamic: requires-dist
|
48
|
+
Dynamic: requires-python
|
49
|
+
Dynamic: summary
|
50
|
+
|
51
|
+
# kinfer
|
52
|
+
|
53
|
+
This package is designed to support running real-time robotics models.
|
54
|
+
|
55
|
+
For more information, see the documentation [here](https://docs.kscale.dev/docs/k-infer).
|
kinfer-0.4.1/README.md
ADDED
File without changes
|
@@ -0,0 +1,41 @@
|
|
1
|
+
"""Defines common utilities for exporting models."""
|
2
|
+
|
3
|
+
|
4
|
+
def get_shape(
|
5
|
+
name: str,
|
6
|
+
num_joints: int | None = None,
|
7
|
+
num_commands: int | None = None,
|
8
|
+
carry_shape: tuple[int, ...] | None = None,
|
9
|
+
) -> tuple[int, ...]:
|
10
|
+
match name:
|
11
|
+
case "joint_angles":
|
12
|
+
if num_joints is None:
|
13
|
+
raise ValueError("`num_joints` must be provided when using `joint_angles`")
|
14
|
+
return (num_joints,)
|
15
|
+
|
16
|
+
case "joint_angular_velocities":
|
17
|
+
if num_joints is None:
|
18
|
+
raise ValueError("`num_joints` must be provided when using `joint_angular_velocities`")
|
19
|
+
return (num_joints,)
|
20
|
+
|
21
|
+
case "projected_gravity":
|
22
|
+
return (3,)
|
23
|
+
|
24
|
+
case "accelerometer":
|
25
|
+
return (3,)
|
26
|
+
|
27
|
+
case "gyroscope":
|
28
|
+
return (3,)
|
29
|
+
|
30
|
+
case "command":
|
31
|
+
if num_commands is None:
|
32
|
+
raise ValueError("`num_commands` must be provided when using `command`")
|
33
|
+
return (num_commands,)
|
34
|
+
|
35
|
+
case "carry":
|
36
|
+
if carry_shape is None:
|
37
|
+
raise ValueError("`carry_shape` must be provided for `carry`")
|
38
|
+
return carry_shape
|
39
|
+
|
40
|
+
case _:
|
41
|
+
raise ValueError(f"Unknown tensor name: {name}")
|
@@ -0,0 +1,53 @@
|
|
1
|
+
"""Jax model export utilities."""
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import logging
|
5
|
+
|
6
|
+
import tensorflow as tf
|
7
|
+
import tf2onnx
|
8
|
+
from equinox.internal._finalise_jaxpr import finalise_fn
|
9
|
+
from jax._src.stages import Wrapped
|
10
|
+
from jax.experimental import jax2tf
|
11
|
+
from onnx.onnx_pb import ModelProto
|
12
|
+
|
13
|
+
from kinfer.export.common import get_shape
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
def export_fn(
    model: Wrapped,
    *,
    num_joints: int | None = None,
    num_commands: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
    opset: int = 13,
) -> ModelProto:
    """Export a JAX function to ONNX.

    Each parameter name of ``model`` selects the shape of its placeholder
    (via ``get_shape``), and every placeholder is declared as a float32
    ``tf.TensorSpec``. The function is passed through equinox's
    ``finalise_fn``, converted to TensorFlow with ``jax2tf`` (XLA disabled),
    wrapped in ``tf.function`` and finally converted with ``tf2onnx``.

    Args:
        model: The ``Wrapped`` JAX function to export (presumably the result
            of ``jax.jit`` — confirm against callers).
        num_joints: The number of joints in the model.
        num_commands: The number of commands in the model.
        carry_shape: The shape of the carry tensor.
        opset: The ONNX opset version passed to ``tf2onnx``.

    Returns:
        The ONNX model as a ``ModelProto``.

    Raises:
        ValueError: If ``model`` is not a ``Wrapped`` function, or if
            ``get_shape`` rejects one of its parameter names.
    """
    if not isinstance(model, Wrapped):
        raise ValueError("Model must be a Wrapped function")

    # One float32 placeholder per parameter, in declaration order, named
    # after the parameter so the ONNX graph inputs keep those names.
    input_signature = [
        tf.TensorSpec(
            get_shape(
                param_name,
                num_joints=num_joints,
                num_commands=num_commands,
                carry_shape=carry_shape,
            ),
            tf.float32,
            name=param_name,
        )
        for param_name in inspect.signature(model).parameters
    ]

    converted = jax2tf.convert(finalise_fn(model), enable_xla=False)
    onnx_model, _ = tf2onnx.convert.from_function(
        tf.function(converted),
        input_signature=input_signature,
        opset=opset,
        large_model=False,
    )
    return onnx_model
|
@@ -0,0 +1,63 @@
|
|
1
|
+
"""PyTorch model export utilities."""
|
2
|
+
|
3
|
+
__all__ = [
|
4
|
+
"export_fn",
|
5
|
+
]
|
6
|
+
|
7
|
+
import io
|
8
|
+
from typing import cast
|
9
|
+
|
10
|
+
import onnx
|
11
|
+
import torch
|
12
|
+
from onnx.onnx_pb import ModelProto
|
13
|
+
from torch._C import FunctionSchema
|
14
|
+
|
15
|
+
from kinfer.export.common import get_shape
|
16
|
+
|
17
|
+
|
18
|
+
def export_fn(
    model: torch.jit.ScriptFunction,
    *,
    num_joints: int | None = None,
    num_commands: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
    opset: int | None = None,
) -> ModelProto:
    """Exports a PyTorch function to ONNX.

    Each argument name in the script function's schema selects the shape of
    a zero-filled float tensor (via ``get_shape``); these tensors are passed
    to ``torch.onnx.export`` as the example inputs, and the argument names
    become the ONNX graph input names.

    Args:
        model: The model to export.
        num_joints: The number of joints in the model.
        num_commands: The number of commands in the model.
        carry_shape: The shape of the carry tensor.
        opset: The ONNX opset version to target. ``None`` (the default)
            leaves the choice to ``torch.onnx.export``.

    Returns:
        The ONNX model as a `ModelProto`.

    Raises:
        ValueError: If ``model`` is not a ``torch.jit.ScriptFunction``, or if
            ``get_shape`` rejects one of its argument names.
    """
    if not isinstance(model, torch.jit.ScriptFunction):
        raise ValueError("Model must be a torch.jit.ScriptFunction")

    schema = cast(FunctionSchema, model.schema)
    input_names = [arg.name for arg in schema.arguments]

    # Gets the dummy input tensors for exporting the model.
    args = tuple(
        torch.zeros(
            get_shape(
                name,
                num_joints=num_joints,
                num_commands=num_commands,
                carry_shape=carry_shape,
            )
        )
        for name in input_names
    )

    buffer = io.BytesIO()
    torch.onnx.export(
        model=model,
        f=buffer,  # type: ignore[arg-type]
        args=args,
        input_names=input_names,
        # None is torch's own default, so existing callers are unaffected.
        opset_version=opset,
        external_data=False,
    )
    return onnx.load_from_string(buffer.getvalue())
|
@@ -0,0 +1,93 @@
|
|
1
|
+
"""Functions for serializing and deserializing models."""
|
2
|
+
|
3
|
+
__all__ = [
|
4
|
+
"pack",
|
5
|
+
]
|
6
|
+
|
7
|
+
|
8
|
+
import io
|
9
|
+
import tarfile
|
10
|
+
|
11
|
+
from onnx.onnx_pb import ModelProto
|
12
|
+
|
13
|
+
from kinfer.common.types import Metadata
|
14
|
+
from kinfer.export.common import get_shape
|
15
|
+
|
16
|
+
|
17
|
+
def pack(
    init_fn: ModelProto,
    step_fn: ModelProto,
    joint_names: list[str],
    num_commands: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
) -> bytes:
    """Packs the initialization and step functions into a gzipped tar archive.

    The archive contains ``init_fn.onnx``, ``step_fn.onnx`` and
    ``metadata.json``. Before packing, the ONNX graphs are validated:

    - ``init_fn`` must take no inputs and produce exactly one output (the
      initial carry).
    - Every ``step_fn`` input must have the shape ``get_shape`` returns for
      its name.
    - ``step_fn`` must produce exactly two outputs: actions of shape
      ``(len(joint_names),)`` and a carry with the same shape as the carry
      produced by ``init_fn``.

    Args:
        init_fn: The initialization function.
        step_fn: The step function.
        joint_names: The list of joint names, in the order that the model
            expects them to be provided.
        num_commands: The number of commands in the model.
        carry_shape: The expected shape of the carry tensor. If omitted, the
            shape of ``init_fn``'s output is used as the reference when
            checking ``step_fn``'s ``carry`` input.

    Returns:
        The bytes of the gzip-compressed tar archive.

    Raises:
        ValueError: If either function has unexpected inputs, outputs, or
            tensor shapes.
    """
    num_joints = len(joint_names)

    # Checks the `init` function.
    if len(init_fn.graph.input) > 0:
        raise ValueError(f"`init` function should not have any inputs! Got {len(init_fn.graph.input)}")
    if len(init_fn.graph.output) != 1:
        raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")
    init_carry = init_fn.graph.output[0]
    init_carry_shape = tuple(dim.dim_value for dim in init_carry.type.tensor_type.shape.dim)
    if carry_shape is not None and init_carry_shape != carry_shape:
        raise ValueError(f"Expected carry shape {carry_shape} for output `{init_carry.name}`, got {init_carry_shape}")

    # Checks the `step` function. At this point the init carry shape equals
    # `carry_shape` whenever that was given, so it is the reference for the
    # `carry` input; this lets callers omit `carry_shape` without
    # `get_shape` rejecting the step function's `carry` input.
    for step_input in step_fn.graph.input:
        step_input_type = step_input.type.tensor_type
        shape = tuple(dim.dim_value for dim in step_input_type.shape.dim)
        expected_shape = get_shape(
            step_input.name,
            num_joints=num_joints,
            num_commands=num_commands,
            carry_shape=init_carry_shape,
        )
        if shape != expected_shape:
            raise ValueError(f"Expected shape {expected_shape} for input `{step_input.name}`, got {shape}")

    if len(step_fn.graph.output) != 2:
        raise ValueError(f"Step function must have exactly 2 outputs, got {len(step_fn.graph.output)}")

    # Output 0 is treated as the per-joint actions, output 1 as the new carry.
    output_actions = step_fn.graph.output[0]
    actions_shape = tuple(dim.dim_value for dim in output_actions.type.tensor_type.shape.dim)
    if actions_shape != (num_joints,):
        raise ValueError(f"Expected output shape {num_joints} for output `{output_actions.name}`, got {actions_shape}")

    output_carry = step_fn.graph.output[1]
    output_carry_shape = tuple(dim.dim_value for dim in output_carry.type.tensor_type.shape.dim)
    if output_carry_shape != init_carry_shape:
        raise ValueError(f"Expected carry shape {init_carry_shape} for output carry, got {output_carry_shape}")

    # Builds the metadata object.
    metadata = Metadata(
        joint_names=joint_names,
        num_commands=num_commands,
    )

    buffer = io.BytesIO()

    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:

        def add_file_bytes(name: str, data: bytes) -> None:
            """Adds an in-memory file called `name` containing `data`."""
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))

        add_file_bytes("init_fn.onnx", init_fn.SerializeToString())
        add_file_bytes("step_fn.onnx", step_fn.SerializeToString())
        add_file_bytes("metadata.json", metadata.model_dump_json().encode("utf-8"))

    # The archive is only complete once the `with` block has closed it.
    return buffer.getvalue()
|
File without changes
|
@@ -0,0 +1,32 @@
|
|
1
|
+
[package]
|
2
|
+
|
3
|
+
name = "kinfer"
|
4
|
+
version.workspace = true
|
5
|
+
edition.workspace = true
|
6
|
+
description.workspace = true
|
7
|
+
authors.workspace = true
|
8
|
+
repository.workspace = true
|
9
|
+
license.workspace = true
|
10
|
+
readme.workspace = true
|
11
|
+
|
12
|
+
[lib]
|
13
|
+
|
14
|
+
name = "kinfer"
|
15
|
+
crate-type = ["cdylib", "rlib"]
|
16
|
+
|
17
|
+
[dependencies]
|
18
|
+
|
19
|
+
async-trait = "0.1"
|
20
|
+
flate2 = "1.0"
|
21
|
+
futures-util = "0.3.30"
|
22
|
+
ndarray = "0.16.1"
|
23
|
+
ort = { version = "2.0.0-rc.9", features = [ "load-dynamic" ] }
|
24
|
+
serde = { version = "1.0", features = ["derive"] }
|
25
|
+
serde_json = "1.0"
|
26
|
+
tar = "0.4"
|
27
|
+
thiserror = "1.0"
|
28
|
+
tokio = { version = "1.0", features = ["full"] }
|
29
|
+
|
30
|
+
[dev-dependencies]
|
31
|
+
|
32
|
+
rand = "0.8"
|