kinfer 0.5.4__cp311-cp311-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +16 -0
- kinfer/export/__init__.py +3 -0
- kinfer/export/jax.py +55 -0
- kinfer/export/pytorch.py +53 -0
- kinfer/export/serialize.py +84 -0
- kinfer/py.typed +0 -0
- kinfer/requirements.txt +13 -0
- kinfer/rust/Cargo.toml +36 -0
- kinfer/rust/src/lib.rs +8 -0
- kinfer/rust/src/logger.rs +141 -0
- kinfer/rust/src/model.rs +354 -0
- kinfer/rust/src/runtime.rs +107 -0
- kinfer/rust/src/types.rs +96 -0
- kinfer/rust_bindings/Cargo.toml +26 -0
- kinfer/rust_bindings/pyproject.toml +7 -0
- kinfer/rust_bindings/rust_bindings.pyi +46 -0
- kinfer/rust_bindings/src/bin/stub_gen.rs +7 -0
- kinfer/rust_bindings/src/lib.rs +486 -0
- kinfer/rust_bindings.cpython-311-x86_64-linux-gnu.so +0 -0
- kinfer/rust_bindings.pyi +46 -0
- kinfer/scripts/plot_ndjson.py +177 -0
- kinfer-0.5.4.dist-info/METADATA +63 -0
- kinfer-0.5.4.dist-info/RECORD +26 -0
- kinfer-0.5.4.dist-info/WHEEL +5 -0
- kinfer-0.5.4.dist-info/licenses/LICENSE +21 -0
- kinfer-0.5.4.dist-info/top_level.txt +1 -0
kinfer/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
"""Defines the kinfer API."""

import os

# The Rust crate enables the `ort` crate's `load-dynamic` feature, so ONNX Runtime
# is loaded from a shared library at runtime. If the caller has not already set
# `ORT_DYLIB_PATH`, point it at the library shipped inside the `onnxruntime`
# Python package.
if "ORT_DYLIB_PATH" not in os.environ:
    from pathlib import Path

    import onnxruntime as ort

    # First file matching `libonnxruntime.*` in onnxruntime's `capi` directory,
    # or None if nothing matches.
    # NOTE(review): glob order is filesystem-dependent if several files match, and
    # a Windows `onnxruntime.dll` would not match this pattern -- confirm intended.
    LIB_PATH = next((Path(ort.__file__).parent / "capi").glob("libonnxruntime.*"), None)
    if LIB_PATH is not None:
        # Absolute, symlink-resolved path in POSIX form.
        os.environ["ORT_DYLIB_PATH"] = LIB_PATH.resolve().as_posix()

# Imported only after the environment setup above; presumably the native
# extension reads `ORT_DYLIB_PATH` when ONNX Runtime is first loaded.
from .rust_bindings import get_version

# Package version as reported by the compiled Rust bindings.
__version__ = get_version()
|
kinfer/export/jax.py
ADDED
@@ -0,0 +1,55 @@
|
|
1
|
+
"""Jax model export utilities."""
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import logging
|
5
|
+
|
6
|
+
import tensorflow as tf
|
7
|
+
import tf2onnx
|
8
|
+
from equinox.internal._finalise_jaxpr import finalise_fn
|
9
|
+
from jax._src.stages import Wrapped
|
10
|
+
from jax.experimental import jax2tf
|
11
|
+
from onnx.onnx_pb import ModelProto
|
12
|
+
|
13
|
+
from kinfer.rust_bindings import PyInputType, PyModelMetadata
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)


def export_fn(
    model: Wrapped,
    metadata: PyModelMetadata,
    *,
    opset: int = 13,
) -> ModelProto:
    """Converts a wrapped (e.g. ``jax.jit``-ed) JAX function into an ONNX graph.

    The function goes JAX -> TensorFlow (``jax2tf``) -> ONNX (``tf2onnx``). Every
    parameter of ``model`` becomes a float32 graph input. The input is named after
    the parameter, and its shape comes from ``metadata`` for the matching input type.

    Args:
        model: The wrapped JAX function to export.
        metadata: Model metadata used to look up the shape of each input.
        opset: The ONNX opset version handed to ``tf2onnx``.

    Returns:
        The exported ONNX model as a ``ModelProto``.

    Raises:
        ValueError: If ``model`` is not a ``Wrapped`` function.
    """
    if not isinstance(model, Wrapped):
        raise ValueError("Model must be a Wrapped function")

    # One placeholder spec per parameter, in signature order. The parameter name
    # selects the input type (and so the shape); presumably `PyInputType` rejects
    # names it does not recognise -- confirm against the bindings.
    input_signature = [
        tf.TensorSpec(PyInputType(arg_name).get_shape(metadata), tf.float32, name=arg_name)
        for arg_name in inspect.signature(model).parameters
    ]

    # enable_xla=False asks jax2tf to avoid XLA-specific TF ops, presumably so
    # that tf2onnx can translate the resulting graph.
    tf_graph_fn = tf.function(jax2tf.convert(finalise_fn(model), enable_xla=False))

    onnx_model, _ = tf2onnx.convert.from_function(
        tf_graph_fn,
        input_signature=input_signature,
        opset=opset,
        large_model=False,
    )
    return onnx_model
|
kinfer/export/pytorch.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
"""PyTorch model export utilities."""
|
2
|
+
|
3
|
+
__all__ = [
|
4
|
+
"export_fn",
|
5
|
+
]
|
6
|
+
|
7
|
+
import io
|
8
|
+
from typing import cast
|
9
|
+
|
10
|
+
import onnx
|
11
|
+
import torch
|
12
|
+
from onnx.onnx_pb import ModelProto
|
13
|
+
from torch._C import FunctionSchema
|
14
|
+
|
15
|
+
from kinfer.rust_bindings import PyInputType, PyModelMetadata
|
16
|
+
|
17
|
+
|
18
|
+
def export_fn(
    model: torch.jit.ScriptFunction,
    metadata: PyModelMetadata,
) -> ModelProto:
    """Exports a TorchScript function to an in-memory ONNX model.

    Each argument of the scripted function becomes an ONNX graph input with the
    same name. A zero-filled dummy tensor, shaped from ``metadata`` for the
    matching input type, is used to trace the export.

    Args:
        model: The ``torch.jit.script``-ed function to export.
        metadata: Model metadata used to look up the shape of each input.

    Returns:
        The exported ONNX model as a ``ModelProto``.

    Raises:
        ValueError: If ``model`` is not a ``torch.jit.ScriptFunction``.
    """
    if not isinstance(model, torch.jit.ScriptFunction):
        raise ValueError("Model must be a torch.jit.ScriptFunction")

    # The scripted function's argument names double as the ONNX input names.
    input_names = [arg.name for arg in cast(FunctionSchema, model.schema).arguments]
    dummy_inputs = tuple(torch.zeros(PyInputType(name).get_shape(metadata)) for name in input_names)

    # Export into memory rather than to a file, then parse the bytes back.
    with io.BytesIO() as buffer:
        torch.onnx.export(
            model=model,
            f=buffer,  # type: ignore[arg-type]
            args=dummy_inputs,
            input_names=input_names,
            external_data=False,
        )
        serialized = buffer.getvalue()
    return onnx.load_from_string(serialized)
|
kinfer/export/serialize.py
ADDED
@@ -0,0 +1,84 @@
|
|
1
|
+
"""Functions for serializing and deserializing models."""
|
2
|
+
|
3
|
+
__all__ = [
|
4
|
+
"pack",
|
5
|
+
]
|
6
|
+
|
7
|
+
import io
|
8
|
+
import logging
|
9
|
+
import tarfile
|
10
|
+
|
11
|
+
from onnx.onnx_pb import ModelProto
|
12
|
+
|
13
|
+
from kinfer.rust_bindings import PyInputType, PyModelMetadata
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)


def _runtime_inputs(model: ModelProto) -> list:
    """Returns the graph inputs of ``model`` that are not initializers (weights).

    ONNX IR < 4 requires every initializer to also be listed in ``graph.input``, and
    exporters may list them there anyway. Such entries are not runtime inputs.
    """
    initializer_names = {init.name for init in model.graph.initializer}
    return [value for value in model.graph.input if value.name not in initializer_names]


def pack(
    init_fn: ModelProto,
    step_fn: ModelProto,
    metadata: PyModelMetadata,
) -> bytes:
    """Validates the `init` and `step` ONNX graphs and packs them into a gzipped tar.

    The archive contains ``init_fn.onnx``, ``step_fn.onnx`` and ``metadata.json``.

    Args:
        init_fn: The initialization function. It must take no runtime inputs and
            return exactly one output, the initial carry.
        step_fn: The step function. Each runtime input must have the shape the
            metadata gives for its input type. It must return exactly two outputs:
            the actions, with shape ``(num_joints,)``, and the next carry, with the
            same shape as the `init` output.
        metadata: The metadata for the model. If its ``carry_size`` differs from
            the `init` output shape, it is updated in place (with a warning).

    Returns:
        The bytes of the gzipped tar archive.

    Raises:
        ValueError: If either graph does not match the expected signature.
    """
    num_joints = len(metadata.joint_names)  # type: ignore[attr-defined]

    # Checks the `init` function.
    init_inputs = _runtime_inputs(init_fn)
    if len(init_inputs) > 0:
        raise ValueError(f"`init` function should not have any inputs! Got {len(init_inputs)}")
    if len(init_fn.graph.output) != 1:
        raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")
    init_carry = init_fn.graph.output[0]
    init_carry_shape = tuple(dim.dim_value for dim in init_carry.type.tensor_type.shape.dim)

    # The bindings may return `carry_size` as a list. A list never equals a tuple,
    # so normalise before comparing (as the step-input check does below);
    # otherwise this warning would fire even when the sizes match.
    current_carry = metadata.carry_size  # type: ignore[attr-defined]
    if current_carry is None or tuple(current_carry) != init_carry_shape:
        logger.warning(
            "Updating carry size from %s to %s to match the `init` function",
            current_carry,
            init_carry_shape,
        )
        metadata.carry_size = init_carry_shape  # type: ignore[attr-defined]

    # Checks the `step` function.
    for step_input in _runtime_inputs(step_fn):
        shape = tuple(dim.dim_value for dim in step_input.type.tensor_type.shape.dim)
        expected_shape = PyInputType(step_input.name).get_shape(metadata)
        if shape != tuple(expected_shape):
            raise ValueError(f"Expected shape {expected_shape} for input `{step_input.name}`, got {shape}")

    if len(step_fn.graph.output) != 2:
        raise ValueError(f"Step function must have exactly 2 outputs, got {len(step_fn.graph.output)}")

    output_actions = step_fn.graph.output[0]
    actions_shape = tuple(dim.dim_value for dim in output_actions.type.tensor_type.shape.dim)
    if actions_shape != (num_joints,):
        raise ValueError(f"Expected output shape {num_joints} for output `{output_actions.name}`, got {actions_shape}")

    output_carry = step_fn.graph.output[1]
    output_carry_shape = tuple(dim.dim_value for dim in output_carry.type.tensor_type.shape.dim)
    if output_carry_shape != init_carry_shape:
        raise ValueError(f"Expected carry shape {init_carry_shape} for output carry, got {output_carry_shape}")

    buffer = io.BytesIO()
    # tarfile does not close a caller-supplied fileobj, so `buffer` stays readable.
    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
        for name, data in (
            ("init_fn.onnx", init_fn.SerializeToString()),
            ("step_fn.onnx", step_fn.SerializeToString()),
            ("metadata.json", metadata.to_json().encode("utf-8")),
        ):
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))

    return buffer.getvalue()
|
kinfer/py.typed
ADDED
File without changes
|
kinfer/requirements.txt
ADDED
kinfer/rust/Cargo.toml
ADDED
@@ -0,0 +1,36 @@
|
|
1
|
+
[package]

name = "kinfer"
# The remaining package fields are inherited from the workspace root manifest.
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
repository.workspace = true
license.workspace = true
readme.workspace = true

[lib]

name = "kinfer"
# `cdylib` builds a C-ABI shared library. `rlib` lets other Rust crates link
# against this one (presumably the sibling `rust_bindings` crate -- confirm).
crate-type = ["cdylib", "rlib"]

[dependencies]

async-trait = "0.1"
chrono = "0.4.41"
crossbeam-channel = "0.5.15"
flate2 = "1.0"
futures-util = "0.3.30"
log = "0.4"
ndarray = "0.16.1"
# `load-dynamic`: ONNX Runtime is loaded from a shared library at runtime
# (located via `ORT_DYLIB_PATH`, which the Python package sets on import)
# instead of being linked at build time.
ort = { version = "2.0.0-rc.9", features = [ "load-dynamic" ] }
# Pinned exactly (`=`) to the same prerelease as `ort` above.
ort-sys = { version = "=2.0.0-rc.9" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tar = "0.4"
thiserror = "1.0"
tokio = { version = "1.0", features = ["full"] }

[dev-dependencies]

rand = "0.8"
|
kinfer/rust/src/logger.rs
ADDED
@@ -0,0 +1,141 @@
|
|
1
|
+
use std::{
|
2
|
+
fs::OpenOptions,
|
3
|
+
io::{BufWriter, Write},
|
4
|
+
path::{Path, PathBuf},
|
5
|
+
sync::atomic::{AtomicU64, Ordering},
|
6
|
+
thread,
|
7
|
+
};
|
8
|
+
|
9
|
+
use crossbeam_channel::{bounded, Sender};
|
10
|
+
use log::{info, warn};
|
11
|
+
use serde::Serialize;
|
12
|
+
|
13
|
+
#[derive(Serialize)]
|
14
|
+
struct NdjsonStep<'a> {
|
15
|
+
step_id: u64,
|
16
|
+
t_us: u64,
|
17
|
+
joint_angles: Option<&'a [f32]>,
|
18
|
+
joint_vels: Option<&'a [f32]>,
|
19
|
+
initial_heading: Option<&'a [f32]>,
|
20
|
+
quaternion: Option<&'a [f32]>,
|
21
|
+
projected_g: Option<&'a [f32]>,
|
22
|
+
accel: Option<&'a [f32]>,
|
23
|
+
gyro: Option<&'a [f32]>,
|
24
|
+
command: Option<&'a [f32]>,
|
25
|
+
output: Option<&'a [f32]>,
|
26
|
+
}
|
27
|
+
|
28
|
+
// Channel capacity for non-blocking logging.
|
29
|
+
// ~1000 entires at 50Hz is 20 seconds of buffering.
|
30
|
+
// Warns if messages are dropped due to full buffer.
|
31
|
+
const CHANNEL_CAP: usize = 1024;
|
32
|
+
|
33
|
+
// Flush buffered writes every 100 log entries.
|
34
|
+
// At 50Hz control frequency, this flushes every 2 seconds.
|
35
|
+
const FLUSH_EVERY: u64 = 100;
|
36
|
+
|
37
|
+
pub struct StepLogger {
|
38
|
+
tx: Option<Sender<Vec<u8>>>,
|
39
|
+
worker: Option<thread::JoinHandle<()>>,
|
40
|
+
next_id: AtomicU64,
|
41
|
+
}
|
42
|
+
|
43
|
+
impl StepLogger {
|
44
|
+
pub fn new(path: impl AsRef<Path>) -> std::io::Result<Self> {
|
45
|
+
let path: PathBuf = path.as_ref().into();
|
46
|
+
if let Some(parent) = path.parent() {
|
47
|
+
std::fs::create_dir_all(parent)?;
|
48
|
+
}
|
49
|
+
info!("kinfer: logging to NDJSON: {}", path.display());
|
50
|
+
|
51
|
+
// I/O objects created here, but moved into the worker thread.
|
52
|
+
let file = OpenOptions::new().create(true).append(true).open(&path)?;
|
53
|
+
let mut bw = BufWriter::new(file);
|
54
|
+
|
55
|
+
// Bounded channel -> back-pressure capped at CHANNEL_CAP lines
|
56
|
+
let (tx, rx) = bounded::<Vec<u8>>(CHANNEL_CAP);
|
57
|
+
|
58
|
+
let worker = thread::spawn(move || {
|
59
|
+
let mut line_ctr: u64 = 0;
|
60
|
+
for msg in rx {
|
61
|
+
// drains until all senders dropped
|
62
|
+
let _ = bw.write_all(&msg);
|
63
|
+
line_ctr += 1;
|
64
|
+
if line_ctr % FLUSH_EVERY == 0 {
|
65
|
+
let _ = bw.flush();
|
66
|
+
}
|
67
|
+
}
|
68
|
+
// Final flush on graceful shutdown
|
69
|
+
let _ = bw.flush();
|
70
|
+
});
|
71
|
+
|
72
|
+
Ok(Self {
|
73
|
+
tx: Some(tx),
|
74
|
+
worker: Some(worker),
|
75
|
+
next_id: AtomicU64::new(0),
|
76
|
+
})
|
77
|
+
}
|
78
|
+
|
79
|
+
#[inline]
|
80
|
+
fn now_us() -> u128 {
|
81
|
+
std::time::SystemTime::now()
|
82
|
+
.duration_since(std::time::UNIX_EPOCH)
|
83
|
+
.unwrap()
|
84
|
+
.as_micros()
|
85
|
+
}
|
86
|
+
|
87
|
+
/// Non-blocking; drops a line if the channel is full.
|
88
|
+
#[allow(clippy::too_many_arguments)]
|
89
|
+
pub fn log_step(
|
90
|
+
&self,
|
91
|
+
joint_angles: Option<&[f32]>,
|
92
|
+
joint_vels: Option<&[f32]>,
|
93
|
+
initial_heading: Option<&[f32]>,
|
94
|
+
quaternion: Option<&[f32]>,
|
95
|
+
projected_g: Option<&[f32]>,
|
96
|
+
accel: Option<&[f32]>,
|
97
|
+
gyro: Option<&[f32]>,
|
98
|
+
command: Option<&[f32]>,
|
99
|
+
output: Option<&[f32]>,
|
100
|
+
) {
|
101
|
+
let record = NdjsonStep {
|
102
|
+
step_id: self.next_id.fetch_add(1, Ordering::Relaxed),
|
103
|
+
t_us: Self::now_us() as u64,
|
104
|
+
joint_angles,
|
105
|
+
joint_vels,
|
106
|
+
initial_heading,
|
107
|
+
quaternion,
|
108
|
+
projected_g,
|
109
|
+
accel,
|
110
|
+
gyro,
|
111
|
+
command,
|
112
|
+
output,
|
113
|
+
};
|
114
|
+
|
115
|
+
// Serialise directly into a Vec<u8>; then push newline and send.
|
116
|
+
if let Ok(mut line) = serde_json::to_vec(&record) {
|
117
|
+
line.push(b'\n');
|
118
|
+
if let Some(tx) = &self.tx {
|
119
|
+
if tx.try_send(line).is_err() {
|
120
|
+
warn!(
|
121
|
+
"kinfer: logging buffer full, dropped message (step_id: {})",
|
122
|
+
record.step_id
|
123
|
+
);
|
124
|
+
}
|
125
|
+
}
|
126
|
+
}
|
127
|
+
}
|
128
|
+
}
|
129
|
+
|
130
|
+
/// Ensure the worker drains and flushes before program exit.
|
131
|
+
impl Drop for StepLogger {
|
132
|
+
fn drop(&mut self) {
|
133
|
+
if let Some(tx) = self.tx.take() {
|
134
|
+
drop(tx); // Drop sender to close channel
|
135
|
+
}
|
136
|
+
// Wait for worker to finish
|
137
|
+
if let Some(worker) = self.worker.take() {
|
138
|
+
let _ = worker.join();
|
139
|
+
}
|
140
|
+
}
|
141
|
+
}
|