kinfer 0.3.3__tar.gz → 0.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. {kinfer-0.3.3 → kinfer-0.4.2}/.cargo/config.toml +3 -1
  2. {kinfer-0.3.3 → kinfer-0.4.2}/Cargo.toml +1 -1
  3. kinfer-0.4.2/PKG-INFO +55 -0
  4. kinfer-0.4.2/README.md +5 -0
  5. kinfer-0.4.2/kinfer/__init__.py +16 -0
  6. kinfer-0.4.2/kinfer/common/types.py +12 -0
  7. kinfer-0.4.2/kinfer/export/common.py +41 -0
  8. kinfer-0.4.2/kinfer/export/jax.py +53 -0
  9. kinfer-0.4.2/kinfer/export/pytorch.py +63 -0
  10. kinfer-0.4.2/kinfer/export/serialize.py +93 -0
  11. kinfer-0.4.2/kinfer/py.typed +0 -0
  12. kinfer-0.4.2/kinfer/requirements.txt +8 -0
  13. kinfer-0.4.2/kinfer/rust/Cargo.toml +32 -0
  14. kinfer-0.4.2/kinfer/rust/src/lib.rs +5 -0
  15. kinfer-0.4.2/kinfer/rust/src/model.rs +318 -0
  16. kinfer-0.4.2/kinfer/rust/src/runtime.rs +104 -0
  17. {kinfer-0.3.3 → kinfer-0.4.2}/kinfer/rust_bindings/Cargo.toml +8 -1
  18. kinfer-0.4.2/kinfer/rust_bindings/src/lib.rs +342 -0
  19. kinfer-0.4.2/kinfer.egg-info/PKG-INFO +55 -0
  20. kinfer-0.4.2/kinfer.egg-info/SOURCES.txt +33 -0
  21. kinfer-0.4.2/kinfer.egg-info/requires.txt +35 -0
  22. {kinfer-0.3.3 → kinfer-0.4.2}/pyproject.toml +1 -0
  23. {kinfer-0.3.3 → kinfer-0.4.2}/setup.py +27 -19
  24. kinfer-0.4.2/tests/test_jax.py +114 -0
  25. kinfer-0.4.2/tests/test_pytorch.py +153 -0
  26. kinfer-0.3.3/PKG-INFO +0 -57
  27. kinfer-0.3.3/README.md +0 -36
  28. kinfer-0.3.3/kinfer/__init__.py +0 -10
  29. kinfer-0.3.3/kinfer/export/pytorch.py +0 -128
  30. kinfer-0.3.3/kinfer/inference/__init__.py +0 -2
  31. kinfer-0.3.3/kinfer/inference/base.py +0 -64
  32. kinfer-0.3.3/kinfer/inference/python.py +0 -66
  33. kinfer-0.3.3/kinfer/proto/__init__.py +0 -40
  34. kinfer-0.3.3/kinfer/proto/kinfer_pb2.py +0 -103
  35. kinfer-0.3.3/kinfer/proto/kinfer_pb2.pyi +0 -1097
  36. kinfer-0.3.3/kinfer/requirements-dev.txt +0 -8
  37. kinfer-0.3.3/kinfer/requirements.txt +0 -9
  38. kinfer-0.3.3/kinfer/rust/Cargo.toml +0 -20
  39. kinfer-0.3.3/kinfer/rust/build.rs +0 -16
  40. kinfer-0.3.3/kinfer/rust/src/kinfer_proto.rs +0 -14
  41. kinfer-0.3.3/kinfer/rust/src/lib.rs +0 -14
  42. kinfer-0.3.3/kinfer/rust/src/main.rs +0 -6
  43. kinfer-0.3.3/kinfer/rust/src/model.rs +0 -153
  44. kinfer-0.3.3/kinfer/rust/src/onnx_serializer.rs +0 -804
  45. kinfer-0.3.3/kinfer/rust/src/serializer.rs +0 -221
  46. kinfer-0.3.3/kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  47. kinfer-0.3.3/kinfer/rust_bindings/src/lib.rs +0 -17
  48. kinfer-0.3.3/kinfer/rust_bindings.pyi +0 -7
  49. kinfer-0.3.3/kinfer/serialize/__init__.py +0 -60
  50. kinfer-0.3.3/kinfer/serialize/base.py +0 -536
  51. kinfer-0.3.3/kinfer/serialize/json.py +0 -399
  52. kinfer-0.3.3/kinfer/serialize/numpy.py +0 -426
  53. kinfer-0.3.3/kinfer/serialize/pytorch.py +0 -402
  54. kinfer-0.3.3/kinfer/serialize/schema.py +0 -125
  55. kinfer-0.3.3/kinfer/serialize/types.py +0 -17
  56. kinfer-0.3.3/kinfer/serialize/utils.py +0 -177
  57. kinfer-0.3.3/kinfer.egg-info/PKG-INFO +0 -57
  58. kinfer-0.3.3/kinfer.egg-info/SOURCES.txt +0 -49
  59. kinfer-0.3.3/kinfer.egg-info/requires.txt +0 -12
  60. kinfer-0.3.3/tests/test_infer.py +0 -101
  61. kinfer-0.3.3/tests/test_schema.py +0 -229
  62. {kinfer-0.3.3 → kinfer-0.4.2}/LICENSE +0 -0
  63. {kinfer-0.3.3 → kinfer-0.4.2}/MANIFEST.in +0 -0
  64. {kinfer-0.3.3/kinfer/export → kinfer-0.4.2/kinfer/common}/__init__.py +0 -0
  65. /kinfer-0.3.3/kinfer/py.typed → /kinfer-0.4.2/kinfer/export/__init__.py +0 -0
  66. {kinfer-0.3.3 → kinfer-0.4.2}/kinfer/rust_bindings/pyproject.toml +0 -0
  67. {kinfer-0.3.3 → kinfer-0.4.2}/kinfer/rust_bindings/src/bin/stub_gen.rs +0 -0
  68. {kinfer-0.3.3 → kinfer-0.4.2}/kinfer.egg-info/dependency_links.txt +0 -0
  69. {kinfer-0.3.3 → kinfer-0.4.2}/kinfer.egg-info/not-zip-safe +0 -0
  70. {kinfer-0.3.3 → kinfer-0.4.2}/kinfer.egg-info/top_level.txt +0 -0
  71. {kinfer-0.3.3 → kinfer-0.4.2}/setup.cfg +0 -0
@@ -1,5 +1,7 @@
1
- [target.aarch64-apple-darwin]
1
+ [target.x86_64-unknown-linux-gnu]
2
+ rustflags = [ "-Clink-args=-Wl,-rpath,\\$ORIGIN" ]
2
3
 
4
+ [target.aarch64-apple-darwin]
3
5
  rustflags = [
4
6
  "-C", "link-arg=-undefined",
5
7
  "-C", "link-arg=dynamic_lookup",
@@ -8,7 +8,7 @@ resolver = "2"
8
8
 
9
9
  [workspace.package]
10
10
 
11
- version = "0.3.3"
11
+ version = "0.4.2"
12
12
  authors = ["Wesley Maa <wesley@kscale.dev>", "Benjamin Bolte <ben@kscale.dev>"]
13
13
  edition = "2021"
14
14
  description = "K-Scale Inference Library"
kinfer-0.4.2/PKG-INFO ADDED
@@ -0,0 +1,55 @@
1
+ Metadata-Version: 2.4
2
+ Name: kinfer
3
+ Version: 0.4.2
4
+ Summary: Tool to make it easier to run a model on a real robot
5
+ Home-page: https://github.com/kscalelabs/kinfer.git
6
+ Author: K-Scale Labs
7
+ Requires-Python: >=3.11
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: onnx
11
+ Requires-Dist: onnxruntime==1.20.0
12
+ Requires-Dist: pydantic
13
+ Provides-Extra: dev
14
+ Requires-Dist: black; extra == "dev"
15
+ Requires-Dist: darglint; extra == "dev"
16
+ Requires-Dist: mypy; extra == "dev"
17
+ Requires-Dist: pytest; extra == "dev"
18
+ Requires-Dist: ruff; extra == "dev"
19
+ Requires-Dist: types-tensorflow; extra == "dev"
20
+ Provides-Extra: pytorch
21
+ Requires-Dist: torch; extra == "pytorch"
22
+ Provides-Extra: jax
23
+ Requires-Dist: tensorflow; extra == "jax"
24
+ Requires-Dist: tf2onnx; extra == "jax"
25
+ Requires-Dist: jax; extra == "jax"
26
+ Requires-Dist: equinox; extra == "jax"
27
+ Requires-Dist: numpy<2; extra == "jax"
28
+ Provides-Extra: all
29
+ Requires-Dist: black; extra == "all"
30
+ Requires-Dist: darglint; extra == "all"
31
+ Requires-Dist: mypy; extra == "all"
32
+ Requires-Dist: pytest; extra == "all"
33
+ Requires-Dist: ruff; extra == "all"
34
+ Requires-Dist: types-tensorflow; extra == "all"
35
+ Requires-Dist: torch; extra == "all"
36
+ Requires-Dist: tensorflow; extra == "all"
37
+ Requires-Dist: tf2onnx; extra == "all"
38
+ Requires-Dist: jax; extra == "all"
39
+ Requires-Dist: equinox; extra == "all"
40
+ Requires-Dist: numpy<2; extra == "all"
41
+ Dynamic: author
42
+ Dynamic: description
43
+ Dynamic: description-content-type
44
+ Dynamic: home-page
45
+ Dynamic: license-file
46
+ Dynamic: provides-extra
47
+ Dynamic: requires-dist
48
+ Dynamic: requires-python
49
+ Dynamic: summary
50
+
51
+ # kinfer
52
+
53
+ This package is designed to support running real-time robotics models.
54
+
55
+ For more information, see the documentation [here](https://docs.kscale.dev/docs/k-infer).
kinfer-0.4.2/README.md ADDED
@@ -0,0 +1,5 @@
1
+ # kinfer
2
+
3
+ This package is designed to support running real-time robotics models.
4
+
5
+ For more information, see the [kinfer documentation](https://docs.kscale.dev/docs/k-infer).
@@ -0,0 +1,16 @@
1
+ """Defines the kinfer API."""
2
+
3
+ import os
4
+
5
if "ORT_DYLIB_PATH" not in os.environ:
    # The Rust crate depends on `ort` with the `load-dynamic` feature, which
    # loads the ONNX Runtime shared library from the path in `ORT_DYLIB_PATH`.
    # When the user has not set it, default to the library bundled in the
    # `capi` directory of the installed `onnxruntime` Python package.
    from pathlib import Path

    import onnxruntime as ort

    # First match only; if nothing matches, the variable is left unset.
    if (
        LIB_PATH := next(
            Path(ort.__file__).parent.joinpath("capi").glob("libonnxruntime.*"),
            None,
        )
    ) is not None:
        os.environ["ORT_DYLIB_PATH"] = LIB_PATH.resolve().as_posix()

# NOTE(review): deliberately imported after the block above, presumably so the
# native bindings see `ORT_DYLIB_PATH` when they load ONNX Runtime.
from .rust_bindings import get_version

__version__ = get_version()
@@ -0,0 +1,12 @@
1
+ """Defines common types."""
2
+
3
+ __all__ = [
4
+ "Metadata",
5
+ ]
6
+
7
+ from pydantic import BaseModel
8
+
9
+
10
class Metadata(BaseModel):
    """Metadata stored next to the exported ONNX functions.

    `kinfer.export.serialize.pack` serializes an instance of this model to
    `metadata.json` inside the packed archive.
    """

    # Joint names, in the order the model expects joint values to be provided.
    joint_names: list[str]
    # Number of commands the model takes; `None` when not specified.
    num_commands: int | None
@@ -0,0 +1,41 @@
1
+ """Defines common utilities for exporting models."""
2
+
3
+
4
+ def get_shape(
5
+ name: str,
6
+ num_joints: int | None = None,
7
+ num_commands: int | None = None,
8
+ carry_shape: tuple[int, ...] | None = None,
9
+ ) -> tuple[int, ...]:
10
+ match name:
11
+ case "joint_angles":
12
+ if num_joints is None:
13
+ raise ValueError("`num_joints` must be provided when using `joint_angles`")
14
+ return (num_joints,)
15
+
16
+ case "joint_angular_velocities":
17
+ if num_joints is None:
18
+ raise ValueError("`num_joints` must be provided when using `joint_angular_velocities`")
19
+ return (num_joints,)
20
+
21
+ case "projected_gravity":
22
+ return (3,)
23
+
24
+ case "accelerometer":
25
+ return (3,)
26
+
27
+ case "gyroscope":
28
+ return (3,)
29
+
30
+ case "command":
31
+ if num_commands is None:
32
+ raise ValueError("`num_commands` must be provided when using `command`")
33
+ return (num_commands,)
34
+
35
+ case "carry":
36
+ if carry_shape is None:
37
+ raise ValueError("`carry_shape` must be provided for `carry`")
38
+ return carry_shape
39
+
40
+ case _:
41
+ raise ValueError(f"Unknown tensor name: {name}")
@@ -0,0 +1,53 @@
1
+ """Jax model export utilities."""
2
+
3
+ import inspect
4
+ import logging
5
+
6
+ import tensorflow as tf
7
+ import tf2onnx
8
+ from equinox.internal._finalise_jaxpr import finalise_fn
9
+ from jax._src.stages import Wrapped
10
+ from jax.experimental import jax2tf
11
+ from onnx.onnx_pb import ModelProto
12
+
13
+ from kinfer.export.common import get_shape
14
+
15
logger = logging.getLogger(__name__)


def export_fn(
    model: Wrapped,
    *,
    num_joints: int | None = None,
    num_commands: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
    opset: int = 13,
) -> ModelProto:
    """Export a JAX function to ONNX.

    Each parameter of ``model`` becomes a float32 ``tf.TensorSpec`` that is
    named after the parameter and shaped by ``get_shape``. The function is
    finalised with equinox and converted to TensorFlow with ``jax2tf`` (XLA
    disabled). The result is then converted to ONNX with ``tf2onnx``.

    Args:
        model: A ``Wrapped`` JAX function (presumably produced by ``jax.jit``).
        num_joints: The number of joints in the model.
        num_commands: The number of commands in the model.
        carry_shape: The shape of the carry tensor.
        opset: The ONNX opset version passed to ``tf2onnx``.

    Returns:
        The ONNX model as a `ModelProto`.

    Raises:
        ValueError: If ``model`` is not ``Wrapped``, or if ``get_shape``
            rejects one of its parameter names.
    """
    if not isinstance(model, Wrapped):
        raise ValueError("Model must be a Wrapped function")

    # Dummy input specs, in the function's parameter order.
    input_signature = [
        tf.TensorSpec(
            get_shape(
                param_name,
                num_joints=num_joints,
                num_commands=num_commands,
                carry_shape=carry_shape,
            ),
            tf.float32,
            name=param_name,
        )
        for param_name in inspect.signature(model).parameters
    ]

    converted = jax2tf.convert(finalise_fn(model), enable_xla=False)

    model_proto, _ = tf2onnx.convert.from_function(
        tf.function(converted),
        input_signature=input_signature,
        opset=opset,
        large_model=False,
    )
    return model_proto
@@ -0,0 +1,63 @@
1
+ """PyTorch model export utilities."""
2
+
3
+ __all__ = [
4
+ "export_fn",
5
+ ]
6
+
7
+ import io
8
+ from typing import cast
9
+
10
+ import onnx
11
+ import torch
12
+ from onnx.onnx_pb import ModelProto
13
+ from torch._C import FunctionSchema
14
+
15
+ from kinfer.export.common import get_shape
16
+
17
+
18
def export_fn(
    model: torch.jit.ScriptFunction,
    *,
    num_joints: int | None = None,
    num_commands: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
) -> ModelProto:
    """Exports a TorchScript function to an in-memory ONNX model.

    One zero-filled dummy tensor is built per argument of ``model``, shaped by
    ``get_shape`` for that argument's name. The argument names are reused as
    the ONNX input names.

    Args:
        model: The scripted function to export.
        num_joints: The number of joints in the model.
        num_commands: The number of commands in the model.
        carry_shape: The shape of the carry tensor.

    Returns:
        The ONNX model as a `ModelProto`.

    Raises:
        ValueError: If ``model`` is not a ``torch.jit.ScriptFunction``, or if
            ``get_shape`` rejects one of its argument names.
    """
    if not isinstance(model, torch.jit.ScriptFunction):
        raise ValueError("Model must be a torch.jit.ScriptFunction")

    arg_names = [argument.name for argument in cast(FunctionSchema, model.schema).arguments]

    # Dummy inputs, in the function's argument order.
    dummy_inputs = tuple(
        torch.zeros(
            get_shape(
                arg_name,
                num_joints=num_joints,
                num_commands=num_commands,
                carry_shape=carry_shape,
            )
        )
        for arg_name in arg_names
    )

    # Export into memory rather than to a file on disk.
    with io.BytesIO() as onnx_buffer:
        torch.onnx.export(
            model=model,
            f=onnx_buffer,  # type: ignore[arg-type]
            args=dummy_inputs,
            input_names=arg_names,
            external_data=False,
        )
        serialized = onnx_buffer.getvalue()

    return onnx.load_from_string(serialized)
@@ -0,0 +1,93 @@
1
+ """Functions for serializing and deserializing models."""
2
+
3
+ __all__ = [
4
+ "pack",
5
+ ]
6
+
7
+
8
+ import io
9
+ import tarfile
10
+
11
+ from onnx.onnx_pb import ModelProto
12
+
13
+ from kinfer.common.types import Metadata
14
+ from kinfer.export.common import get_shape
15
+
16
+
17
def pack(
    init_fn: ModelProto,
    step_fn: ModelProto,
    joint_names: list[str],
    num_commands: int | None = None,
    carry_shape: tuple[int, ...] | None = None,
) -> bytes:
    """Validates and packs the init and step functions into a gzipped tar archive.

    The archive contains ``init_fn.onnx``, ``step_fn.onnx`` and
    ``metadata.json``.

    Args:
        init_fn: The initialization function. Must take no inputs and have
            exactly one output (the initial carry).
        step_fn: The step function. Each input is checked against the shape
            ``get_shape`` gives for its name. It must have exactly two outputs:
            the actions, shaped ``(len(joint_names),)``, and the next carry,
            shaped like the ``init`` output.
        joint_names: The list of joint names, in the order that the model
            expects them to be provided.
        num_commands: The number of commands in the model.
        carry_shape: The shape of the carry tensor. If omitted, the shape of
            the ``init`` function's output is used.

    Returns:
        The bytes of the gzip-compressed tar archive.

    Raises:
        ValueError: If an input or output count or shape does not match what
            is expected, or a step input has a name ``get_shape`` rejects.
    """
    num_joints = len(joint_names)

    # Checks the `init` function: no inputs and a single (carry) output.
    if len(init_fn.graph.input) > 0:
        raise ValueError(f"`init` function should not have any inputs! Got {len(init_fn.graph.input)}")
    if len(init_fn.graph.output) != 1:
        raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")
    init_carry = init_fn.graph.output[0]
    init_carry_shape = tuple(dim.dim_value for dim in init_carry.type.tensor_type.shape.dim)
    if carry_shape is None:
        # No explicit carry shape was given, so use the one `init` produces.
        # This lets a `carry` input on the step function still be validated
        # below instead of failing inside `get_shape`.
        carry_shape = init_carry_shape
    else:
        # Normalise to a tuple (e.g. if a list was passed) so the tuple
        # comparisons below are meaningful.
        carry_shape = tuple(carry_shape)
        if init_carry_shape != carry_shape:
            raise ValueError(
                f"Expected carry shape {carry_shape} for output `{init_carry.name}`, got {init_carry_shape}"
            )

    # Checks the `step` function inputs against the shapes implied by their names.
    for step_input in step_fn.graph.input:
        step_input_type = step_input.type.tensor_type
        shape = tuple(dim.dim_value for dim in step_input_type.shape.dim)
        expected_shape = get_shape(
            step_input.name,
            num_joints=num_joints,
            num_commands=num_commands,
            carry_shape=carry_shape,
        )
        if shape != expected_shape:
            raise ValueError(f"Expected shape {expected_shape} for input `{step_input.name}`, got {shape}")

    if len(step_fn.graph.output) != 2:
        raise ValueError(f"Step function must have exactly 2 outputs, got {len(step_fn.graph.output)}")

    # First output: one action per joint.
    output_actions = step_fn.graph.output[0]
    actions_shape = tuple(dim.dim_value for dim in output_actions.type.tensor_type.shape.dim)
    if actions_shape != (num_joints,):
        raise ValueError(
            f"Expected output shape {(num_joints,)} for output `{output_actions.name}`, got {actions_shape}"
        )

    # Second output: the next carry, which must match what `init` produces.
    output_carry = step_fn.graph.output[1]
    output_carry_shape = tuple(dim.dim_value for dim in output_carry.type.tensor_type.shape.dim)
    if output_carry_shape != init_carry_shape:
        raise ValueError(f"Expected carry shape {init_carry_shape} for output carry, got {output_carry_shape}")

    # Builds the metadata object.
    metadata = Metadata(
        joint_names=joint_names,
        num_commands=num_commands,
    )

    buffer = io.BytesIO()

    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:

        def add_file_bytes(name: str, data: bytes) -> None:
            """Adds an in-memory file called ``name`` with contents ``data``."""
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))

        add_file_bytes("init_fn.onnx", init_fn.SerializeToString())
        add_file_bytes("step_fn.onnx", step_fn.SerializeToString())
        add_file_bytes("metadata.json", metadata.model_dump_json().encode("utf-8"))

    # The gzip stream is only complete once the tar file has been closed.
    return buffer.getvalue()
File without changes
@@ -0,0 +1,8 @@
1
+ # requirements.txt
2
+
3
+ # Machine Learning
4
+ onnx
5
+ onnxruntime==1.20.0
6
+
7
+ # Serialization
8
+ pydantic
@@ -0,0 +1,32 @@
1
+ [package]
2
+
3
+ name = "kinfer"
4
+ version.workspace = true
5
+ edition.workspace = true
6
+ description.workspace = true
7
+ authors.workspace = true
8
+ repository.workspace = true
9
+ license.workspace = true
10
+ readme.workspace = true
11
+
12
+ [lib]
13
+
14
+ name = "kinfer"
15
+ crate-type = ["cdylib", "rlib"]
16
+
17
+ [dependencies]
18
+
19
+ async-trait = "0.1"
20
+ flate2 = "1.0"
21
+ futures-util = "0.3.30"
22
+ ndarray = "0.16.1"
23
+ ort = { version = "2.0.0-rc.9", features = [ "load-dynamic" ] }
24
+ serde = { version = "1.0", features = ["derive"] }
25
+ serde_json = "1.0"
26
+ tar = "0.4"
27
+ thiserror = "1.0"
28
+ tokio = { version = "1.0", features = ["full"] }
29
+
30
+ [dev-dependencies]
31
+
32
+ rand = "0.8"
@@ -0,0 +1,5 @@
1
//! Crate root of the `kinfer` K-Scale Inference Library.
//!
//! Everything public in the `model` and `runtime` modules is re-exported at
//! the crate root, so callers can use `kinfer::<item>` directly.

pub mod model;
pub mod runtime;

// Glob re-exports flatten the module structure for downstream users.
pub use model::*;
pub use runtime::*;