kinfer 0.3.3__cp311-cp311-macosx_11_0_arm64.whl → 0.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kinfer/__init__.py +0 -5
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +11 -0
  4. kinfer/export/common.py +35 -0
  5. kinfer/export/jax.py +51 -0
  6. kinfer/export/pytorch.py +42 -110
  7. kinfer/export/serialize.py +86 -0
  8. kinfer/requirements.txt +3 -4
  9. kinfer/rust/Cargo.toml +8 -6
  10. kinfer/rust/src/lib.rs +2 -11
  11. kinfer/rust/src/model.rs +271 -121
  12. kinfer/rust/src/runtime.rs +104 -0
  13. kinfer/rust_bindings/Cargo.toml +8 -1
  14. kinfer/rust_bindings/rust_bindings.pyi +35 -0
  15. kinfer/rust_bindings/src/lib.rs +310 -1
  16. kinfer/rust_bindings.cpython-311-darwin.so +0 -0
  17. kinfer/rust_bindings.pyi +29 -1
  18. kinfer-0.4.0.dist-info/METADATA +55 -0
  19. kinfer-0.4.0.dist-info/RECORD +26 -0
  20. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
  21. kinfer/inference/__init__.py +0 -2
  22. kinfer/inference/base.py +0 -64
  23. kinfer/inference/python.py +0 -66
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -60
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.3.dist-info/METADATA +0 -57
  43. kinfer-0.3.3.dist-info/RECORD +0 -40
  44. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
kinfer/__init__.py CHANGED
@@ -1,10 +1,5 @@
1
1
  """Defines the kinfer API."""
2
2
 
3
- from . import proto as K
4
- from .export.pytorch import export_model, get_model
5
- from .inference.base import KModel
6
- from .inference.python import ONNXModel
7
3
  from .rust_bindings import get_version
8
- from .serialize import get_multi_serializer, get_serializer
9
4
 
10
5
  __version__ = get_version()
kinfer/common/__init__.py ADDED
File without changes
kinfer/common/types.py ADDED
@@ -0,0 +1,11 @@
1
+ """Defines common types."""
2
+
3
+ __all__ = [
4
+ "Metadata",
5
+ ]
6
+
7
+ from pydantic import BaseModel
8
+
9
+
10
class Metadata(BaseModel):
    """Metadata describing an exported kinfer model.

    Serialized to JSON (via pydantic) and stored next to the model's ONNX
    graphs when the model is packed.
    """

    # Joint names, in the order that the model expects them to be provided.
    joint_names: list[str]
kinfer/export/common.py ADDED
@@ -0,0 +1,35 @@
1
+ """Defines common utilities for exporting models."""
2
+
3
+
4
+ def get_shape(
5
+ name: str,
6
+ num_joints: int | None = None,
7
+ carry_shape: tuple[int, ...] | None = None,
8
+ ) -> tuple[int, ...]:
9
+ match name:
10
+ case "joint_angles":
11
+ if num_joints is None:
12
+ raise ValueError("`num_joints` must be provided when using `joint_angles`")
13
+ return (num_joints,)
14
+
15
+ case "joint_angular_velocities":
16
+ if num_joints is None:
17
+ raise ValueError("`num_joints` must be provided when using `joint_angular_velocities`")
18
+ return (num_joints,)
19
+
20
+ case "projected_gravity":
21
+ return (3,)
22
+
23
+ case "accelerometer":
24
+ return (3,)
25
+
26
+ case "gyroscope":
27
+ return (3,)
28
+
29
+ case "carry":
30
+ if carry_shape is None:
31
+ raise ValueError("`carry_shape` must be provided for `carry`")
32
+ return carry_shape
33
+
34
+ case _:
35
+ raise ValueError(f"Unknown tensor name: {name}")
kinfer/export/jax.py ADDED
@@ -0,0 +1,51 @@
1
+ """Jax model export utilities."""
2
+
3
+ import inspect
4
+ import logging
5
+
6
+ import tensorflow as tf
7
+ import tf2onnx
8
+ from equinox.internal._finalise_jaxpr import finalise_fn
9
+ from jax._src.stages import Wrapped
10
+ from jax.experimental import jax2tf
11
+ from onnx.onnx_pb import ModelProto
12
+
13
+ from kinfer.export.common import get_shape
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def export_fn(
    model: Wrapped,
    *,
    num_joints: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
    opset: int = 13,
) -> ModelProto:
    """Converts a JAX function to an ONNX model.

    The function is finalised with equinox, converted to TensorFlow with
    ``jax2tf`` and then to ONNX with ``tf2onnx``. One float32 input spec is
    built per parameter of ``model``, named after that parameter, with its
    shape looked up by name via ``get_shape``.

    Args:
        model: The JAX function to export; must be a ``Wrapped`` object.
        num_joints: The number of joints, needed for joint-indexed inputs.
        carry_shape: The shape of the carry tensor, needed for a ``carry`` input.
        opset: The ONNX opset version passed to ``tf2onnx``.

    Returns:
        The converted ONNX model.

    Raises:
        ValueError: If ``model`` is not a ``Wrapped`` function, or if a
            parameter name is unknown / missing its shape information.
    """
    if not isinstance(model, Wrapped):
        raise ValueError("Model must be a Wrapped function")

    # One float32 spec per parameter, named after the parameter.
    input_specs = [
        tf.TensorSpec(
            get_shape(param_name, num_joints=num_joints, carry_shape=carry_shape),
            tf.float32,
            name=param_name,
        )
        for param_name in inspect.signature(model).parameters
    ]

    # NOTE(review): enable_xla=False presumably restricts jax2tf to plain TF
    # ops that tf2onnx can translate -- confirm.
    converted = jax2tf.convert(finalise_fn(model), enable_xla=False)

    model_proto, _ = tf2onnx.convert.from_function(
        tf.function(converted),
        input_signature=input_specs,
        opset=opset,
        large_model=False,
    )
    return model_proto
kinfer/export/pytorch.py CHANGED
@@ -1,128 +1,60 @@
1
1
  """PyTorch model export utilities."""
2
2
 
3
- import base64
4
- import inspect
5
- from io import BytesIO
6
- from typing import Sequence
3
+ __all__ = [
4
+ "export_fn",
5
+ ]
6
+
7
+ import io
8
+ from typing import cast
7
9
 
8
10
  import onnx
9
- import onnxruntime as ort
10
11
  import torch
11
- from torch import Tensor
12
-
13
- from kinfer import proto as K
14
- from kinfer.serialize.pytorch import PyTorchMultiSerializer
15
- from kinfer.serialize.schema import get_dummy_io
16
- from kinfer.serialize.utils import check_names_match
12
+ from onnx.onnx_pb import ModelProto
13
+ from torch._C import FunctionSchema
17
14
 
18
- KINFER_METADATA_KEY = "kinfer_metadata"
15
+ from kinfer.export.common import get_shape
19
16
 
20
17
 
21
def export_fn(
    model: torch.jit.ScriptFunction,
    *,
    num_joints: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
    opset: int | None = None,
) -> ModelProto:
    """Exports a scripted PyTorch function to ONNX.

    One zero-filled dummy tensor is created per argument of ``model``, with
    its shape looked up by argument name via ``get_shape``. The argument
    names are also used as the ONNX input names.

    Args:
        model: The model to export; must be a ``torch.jit.ScriptFunction``.
        num_joints: The number of joints in the model.
        carry_shape: The shape of the carry tensor.
        opset: ONNX opset version, forwarded to ``torch.onnx.export`` as
            ``opset_version``. ``None`` (the default) keeps PyTorch's default.

    Returns:
        The ONNX model as a `ModelProto`.

    Raises:
        ValueError: If ``model`` is not a ``torch.jit.ScriptFunction``, or if
            an argument name is unknown / missing its shape information.
    """
    if not isinstance(model, torch.jit.ScriptFunction):
        raise ValueError("Model must be a torch.jit.ScriptFunction")

    schema = cast(FunctionSchema, model.schema)
    input_names = [arg.name for arg in schema.arguments]

    # Gets the dummy input tensors for exporting the model.
    dummy_args = tuple(
        torch.zeros(get_shape(name, num_joints=num_joints, carry_shape=carry_shape)) for name in input_names
    )

    buffer = io.BytesIO()
    torch.onnx.export(
        model=model,
        f=buffer,  # type: ignore[arg-type]
        args=dummy_args,
        input_names=input_names,
        opset_version=opset,
        # Presumably keeps all weights inside the single in-memory buffer
        # rather than writing side files -- confirm against torch docs.
        external_data=False,
    )
    return onnx.load_from_string(buffer.getvalue())
kinfer/export/serialize.py ADDED
@@ -0,0 +1,86 @@
1
+ """Functions for serializing and deserializing models."""
2
+
3
+ __all__ = [
4
+ "pack",
5
+ ]
6
+
7
+
8
+ import io
9
+ import tarfile
10
+
11
+ from onnx.onnx_pb import ModelProto
12
+
13
+ from kinfer.common.types import Metadata
14
+ from kinfer.export.common import get_shape
15
+
16
+
17
def pack(
    init_fn: ModelProto,
    step_fn: ModelProto,
    joint_names: list[str],
    carry_shape: tuple[int, ...],
) -> bytes:
    """Validates the init and step functions and packs them into a gzipped tar archive.

    The archive contains ``init_fn.onnx`` and ``step_fn.onnx`` (the serialized
    models) and ``metadata.json`` (the joint names).

    Args:
        init_fn: The initialization function. It must have no inputs and
            exactly one output.
        step_fn: The step function. Each input must have the shape that
            ``get_shape`` returns for its name, and it must have exactly two
            outputs: one of shape ``(num_joints,)`` followed by the carry
            output of shape ``carry_shape``.
        joint_names: The list of joint names, in the order that the model
            expects them to be provided.
        carry_shape: The shape of the carry tensor.

    Returns:
        The bytes of the gzip-compressed tar archive.

    Raises:
        ValueError: If either function's inputs or outputs do not match the
            expected signature.
    """
    num_joints = len(joint_names)
    # Normalize so a list such as `[4]` compares equal to the tuple shapes
    # read from the graphs below.
    carry_shape = tuple(carry_shape)

    # Checks the `init` function.
    if len(init_fn.graph.input) > 0:
        raise ValueError(f"`init` function should not have any inputs! Got {len(init_fn.graph.input)}")
    if len(init_fn.graph.output) != 1:
        raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")

    # Checks the `step` function inputs against the shapes expected for their names.
    # NOTE(review): a symbolic dimension (`dim_param`) presumably reads as
    # `dim_value == 0` here and is therefore rejected -- confirm whether
    # dynamic shapes should be supported.
    for step_input in step_fn.graph.input:
        shape = tuple(dim.dim_value for dim in step_input.type.tensor_type.shape.dim)
        expected_shape = get_shape(
            step_input.name,
            num_joints=num_joints,
            carry_shape=carry_shape,
        )
        if shape != expected_shape:
            raise ValueError(f"Expected shape {expected_shape} for input `{step_input.name}`, got {shape}")

    if len(step_fn.graph.output) != 2:
        raise ValueError(f"Step function must have exactly 2 outputs, got {len(step_fn.graph.output)}")

    # First output: one value per joint.
    model_output = step_fn.graph.output[0]
    output_shape = tuple(dim.dim_value for dim in model_output.type.tensor_type.shape.dim)
    expected_output_shape = (num_joints,)
    if output_shape != expected_output_shape:
        raise ValueError(
            f"Expected output shape {expected_output_shape} for output `{model_output.name}`, got {output_shape}"
        )

    # Second output: the carry, which must match `carry_shape`.
    model_carry = step_fn.graph.output[1]
    output_carry_shape = tuple(dim.dim_value for dim in model_carry.type.tensor_type.shape.dim)
    if output_carry_shape != carry_shape:
        raise ValueError(
            f"Expected carry shape {carry_shape} for output `{model_carry.name}`, got {output_carry_shape}"
        )

    # Builds the metadata object.
    metadata = Metadata(joint_names=joint_names)

    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:

        def add_file_bytes(name: str, data: bytes) -> None:
            """Adds an in-memory file called `name` containing `data` to the archive."""
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))

        add_file_bytes("init_fn.onnx", init_fn.SerializeToString())
        add_file_bytes("step_fn.onnx", step_fn.SerializeToString())
        add_file_bytes("metadata.json", metadata.model_dump_json().encode("utf-8"))

    # Read only after `tar` is closed, so the archive has been fully written.
    return buffer.getvalue()
kinfer/requirements.txt CHANGED
@@ -1,9 +1,8 @@
1
1
  # requirements.txt
2
2
 
3
3
  # Machine Learning
4
- torch
5
4
  onnx
6
- onnxruntime
5
+ onnxruntime==1.20.0
7
6
 
8
- # Protocol Buffers
9
- protobuf
7
+ # Serialization
8
+ pydantic
kinfer/rust/Cargo.toml CHANGED
@@ -6,15 +6,17 @@ edition = "2021"
6
6
 
7
7
  [dependencies]
8
8
 
9
+ async-trait = "0.1"
10
+ flate2 = "1.0"
9
11
  futures-util = "0.3.30"
10
12
  ndarray = "0.16.1"
11
- ort = { version = "2.0.0-rc.6", features = [ "load-dynamic" ] }
12
- prost = "0.12"
13
+ ort = { version = "2.0.0-rc.9", features = [ "load-dynamic" ] }
14
+ serde = { version = "1.0", features = ["derive"] }
15
+ serde_json = "1.0"
16
+ tar = "0.4"
13
17
  thiserror = "1.0"
18
+ tokio = { version = "1.0", features = ["full"] }
14
19
 
15
20
  [dev-dependencies]
16
- rand = "0.8"
17
-
18
- [build-dependencies]
19
21
 
20
- prost-build = "0.12"
22
+ rand = "0.8"
kinfer/rust/src/lib.rs CHANGED
@@ -1,14 +1,5 @@
1
- pub mod kinfer_proto;
2
1
  pub mod model;
3
- pub mod onnx_serializer;
4
- pub mod serializer;
2
+ pub mod runtime;
5
3
 
6
- pub use kinfer_proto::*;
7
4
  pub use model::*;
8
- pub use onnx_serializer::*;
9
- pub use serializer::*;
10
-
11
- #[cfg(test)]
12
- mod tests {
13
- mod onnx_serializer_tests;
14
- }
5
+ pub use runtime::*;