kinfer 0.4.3__tar.gz → 0.5.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {kinfer-0.4.3 → kinfer-0.5.3}/Cargo.toml +1 -1
  2. {kinfer-0.4.3 → kinfer-0.5.3}/PKG-INFO +4 -1
  3. kinfer-0.5.3/kinfer/export/__init__.py +3 -0
  4. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/export/jax.py +13 -11
  5. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/export/pytorch.py +4 -14
  6. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/export/serialize.py +18 -27
  7. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/requirements.txt +5 -0
  8. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/rust/Cargo.toml +4 -0
  9. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/rust/src/lib.rs +3 -0
  10. kinfer-0.5.3/kinfer/rust/src/logger.rs +141 -0
  11. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/rust/src/model.rs +126 -104
  12. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/rust/src/runtime.rs +5 -2
  13. kinfer-0.5.3/kinfer/rust/src/types.rs +96 -0
  14. kinfer-0.5.3/kinfer/rust_bindings/src/lib.rs +431 -0
  15. kinfer-0.5.3/kinfer/scripts/plot_ndjson.py +177 -0
  16. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer.egg-info/PKG-INFO +4 -1
  17. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer.egg-info/SOURCES.txt +4 -3
  18. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer.egg-info/requires.txt +3 -0
  19. kinfer-0.5.3/tests/test_common.py +58 -0
  20. {kinfer-0.4.3 → kinfer-0.5.3}/tests/test_jax.py +31 -41
  21. {kinfer-0.4.3 → kinfer-0.5.3}/tests/test_pytorch.py +33 -44
  22. kinfer-0.4.3/kinfer/common/__init__.py +0 -0
  23. kinfer-0.4.3/kinfer/common/types.py +0 -12
  24. kinfer-0.4.3/kinfer/export/__init__.py +0 -0
  25. kinfer-0.4.3/kinfer/export/common.py +0 -44
  26. kinfer-0.4.3/kinfer/rust_bindings/src/lib.rs +0 -360
  27. {kinfer-0.4.3 → kinfer-0.5.3}/.cargo/config.toml +0 -0
  28. {kinfer-0.4.3 → kinfer-0.5.3}/LICENSE +0 -0
  29. {kinfer-0.4.3 → kinfer-0.5.3}/MANIFEST.in +0 -0
  30. {kinfer-0.4.3 → kinfer-0.5.3}/README.md +0 -0
  31. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/__init__.py +0 -0
  32. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/py.typed +0 -0
  33. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/rust_bindings/Cargo.toml +0 -0
  34. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/rust_bindings/pyproject.toml +0 -0
  35. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer/rust_bindings/src/bin/stub_gen.rs +0 -0
  36. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer.egg-info/dependency_links.txt +0 -0
  37. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer.egg-info/not-zip-safe +0 -0
  38. {kinfer-0.4.3 → kinfer-0.5.3}/kinfer.egg-info/top_level.txt +0 -0
  39. {kinfer-0.4.3 → kinfer-0.5.3}/pyproject.toml +0 -0
  40. {kinfer-0.4.3 → kinfer-0.5.3}/setup.cfg +0 -0
  41. {kinfer-0.4.3 → kinfer-0.5.3}/setup.py +0 -0
@@ -8,7 +8,7 @@ resolver = "2"
8
8
 
9
9
  [workspace.package]
10
10
 
11
- version = "0.4.3"
11
+ version = "0.5.3"
12
12
  authors = ["Wesley Maa <wesley@kscale.dev>", "Benjamin Bolte <ben@kscale.dev>"]
13
13
  edition = "2021"
14
14
  description = "K-Scale Inference Library"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kinfer
3
- Version: 0.4.3
3
+ Version: 0.5.3
4
4
  Summary: Tool to make it easier to run a model on a real robot
5
5
  Home-page: https://github.com/kscalelabs/kinfer.git
6
6
  Author: K-Scale Labs
@@ -10,6 +10,9 @@ License-File: LICENSE
10
10
  Requires-Dist: onnx
11
11
  Requires-Dist: onnxruntime==1.20.0
12
12
  Requires-Dist: pydantic
13
+ Requires-Dist: matplotlib
14
+ Requires-Dist: numpy
15
+ Requires-Dist: pathlib
13
16
  Provides-Extra: dev
14
17
  Requires-Dist: black; extra == "dev"
15
18
  Requires-Dist: darglint; extra == "dev"
@@ -0,0 +1,3 @@
1
+ """Defines the export API."""
2
+
3
+ from .serialize import *
@@ -10,20 +10,27 @@ from jax._src.stages import Wrapped
10
10
  from jax.experimental import jax2tf
11
11
  from onnx.onnx_pb import ModelProto
12
12
 
13
- from kinfer.export.common import get_shape
13
+ from kinfer.rust_bindings import PyInputType, PyModelMetadata
14
14
 
15
15
  logger = logging.getLogger(__name__)
16
16
 
17
17
 
18
18
  def export_fn(
19
19
  model: Wrapped,
20
+ metadata: PyModelMetadata,
20
21
  *,
21
- num_joints: int | None = None,
22
- num_commands: int | None = None,
23
- carry_shape: tuple[int, ...] | None = None,
24
22
  opset: int = 13,
25
23
  ) -> ModelProto:
26
- """Export a JAX function to ONNX."""
24
+ """Export a JAX function to ONNX.
25
+
26
+ Args:
27
+ model: The model to export.
28
+ metadata: The metadata for the model.
29
+ opset: The ONNX opset to use.
30
+
31
+ Returns:
32
+ The ONNX model as a `ModelProto`.
33
+ """
27
34
  if not isinstance(model, Wrapped):
28
35
  raise ValueError("Model must be a Wrapped function")
29
36
 
@@ -33,12 +40,7 @@ def export_fn(
33
40
  # Gets the dummy input tensors for exporting the model.
34
41
  tf_args = []
35
42
  for name in input_names:
36
- shape = get_shape(
37
- name,
38
- num_joints=num_joints,
39
- num_commands=num_commands,
40
- carry_shape=carry_shape,
41
- )
43
+ shape = PyInputType(name).get_shape(metadata)
42
44
  tf_args.append(tf.TensorSpec(shape, tf.float32, name=name))
43
45
 
44
46
  finalised_fn = finalise_fn(model)
@@ -12,23 +12,18 @@ import torch
12
12
  from onnx.onnx_pb import ModelProto
13
13
  from torch._C import FunctionSchema
14
14
 
15
- from kinfer.export.common import get_shape
15
+ from kinfer.rust_bindings import PyInputType, PyModelMetadata
16
16
 
17
17
 
18
18
  def export_fn(
19
19
  model: torch.jit.ScriptFunction,
20
- *,
21
- num_joints: int | None = None,
22
- num_commands: int | None = None,
23
- carry_shape: tuple[int, ...] | None = None,
20
+ metadata: PyModelMetadata,
24
21
  ) -> ModelProto:
25
22
  """Exports a PyTorch function to ONNX.
26
23
 
27
24
  Args:
28
25
  model: The model to export.
29
- num_joints: The number of joints in the model.
30
- num_commands: The number of commands in the model.
31
- carry_shape: The shape of the carry tensor.
26
+ metadata: The metadata for the model.
32
27
 
33
28
  Returns:
34
29
  The ONNX model as a `ModelProto`.
@@ -42,12 +37,7 @@ def export_fn(
42
37
  # Gets the dummy input tensors for exporting the model.
43
38
  args = []
44
39
  for name in input_names:
45
- shape = get_shape(
46
- name,
47
- num_joints=num_joints,
48
- num_commands=num_commands,
49
- carry_shape=carry_shape,
50
- )
40
+ shape = PyInputType(name).get_shape(metadata)
51
41
  args.append(torch.zeros(shape))
52
42
 
53
43
  buffer = io.BytesIO()
@@ -4,34 +4,30 @@ __all__ = [
4
4
  "pack",
5
5
  ]
6
6
 
7
-
8
7
  import io
8
+ import logging
9
9
  import tarfile
10
10
 
11
11
  from onnx.onnx_pb import ModelProto
12
12
 
13
- from kinfer.common.types import Metadata
14
- from kinfer.export.common import get_shape
13
+ from kinfer.rust_bindings import PyInputType, PyModelMetadata
14
+
15
+ logger = logging.getLogger(__name__)
15
16
 
16
17
 
17
18
  def pack(
18
19
  init_fn: ModelProto,
19
20
  step_fn: ModelProto,
20
- joint_names: list[str],
21
- num_commands: int | None = None,
22
- carry_shape: tuple[int, ...] | None = None,
21
+ metadata: PyModelMetadata,
23
22
  ) -> bytes:
24
23
  """Packs the initialization function and step function into a directory.
25
24
 
26
25
  Args:
27
26
  init_fn: The initialization function.
28
27
  step_fn: The step function.
29
- joint_names: The list of joint names, in the order that the model
30
- expects them to be provided.
31
- num_commands: The number of commands in the model.
32
- carry_shape: The shape of the carry tensor.
28
+ metadata: The metadata for the model.
33
29
  """
34
- num_joints = len(joint_names)
30
+ num_joints = len(metadata.joint_names) # type: ignore[attr-defined]
35
31
 
36
32
  # Checks the `init` function.
37
33
  if len(init_fn.graph.input) > 0:
@@ -40,20 +36,21 @@ def pack(
40
36
  raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")
41
37
  init_carry = init_fn.graph.output[0]
42
38
  init_carry_shape = tuple(dim.dim_value for dim in init_carry.type.tensor_type.shape.dim)
43
- if carry_shape is not None and init_carry_shape != carry_shape:
44
- raise ValueError(f"Expected carry shape {carry_shape} for output `{init_carry.name}`, got {init_carry_shape}")
39
+
40
+ if metadata.carry_size != init_carry_shape: # type: ignore[attr-defined]
41
+ logger.warning(
42
+ "Updating carry size from %s to %s to match the `init` function",
43
+ metadata.carry_size, # type: ignore[attr-defined]
44
+ init_carry_shape,
45
+ )
46
+ metadata.carry_size = init_carry_shape # type: ignore[attr-defined]
45
47
 
46
48
  # Checks the `step` function.
47
49
  for step_input in step_fn.graph.input:
48
50
  step_input_type = step_input.type.tensor_type
49
51
  shape = tuple(dim.dim_value for dim in step_input_type.shape.dim)
50
- expected_shape = get_shape(
51
- step_input.name,
52
- num_joints=num_joints,
53
- num_commands=num_commands,
54
- carry_shape=carry_shape,
55
- )
56
- if shape != expected_shape:
52
+ expected_shape = PyInputType(step_input.name).get_shape(metadata)
53
+ if shape != tuple(expected_shape):
57
54
  raise ValueError(f"Expected shape {expected_shape} for input `{step_input.name}`, got {shape}")
58
55
 
59
56
  if len(step_fn.graph.output) != 2:
@@ -69,12 +66,6 @@ def pack(
69
66
  if output_carry_shape != init_carry_shape:
70
67
  raise ValueError(f"Expected carry shape {init_carry_shape} for output carry, got {output_carry_shape}")
71
68
 
72
- # Builds the metadata object.
73
- metadata = Metadata(
74
- joint_names=joint_names,
75
- num_commands=num_commands,
76
- )
77
-
78
69
  buffer = io.BytesIO()
79
70
 
80
71
  with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
@@ -86,7 +77,7 @@ def pack(
86
77
 
87
78
  add_file_bytes("init_fn.onnx", init_fn.SerializeToString())
88
79
  add_file_bytes("step_fn.onnx", step_fn.SerializeToString())
89
- add_file_bytes("metadata.json", metadata.model_dump_json().encode("utf-8"))
80
+ add_file_bytes("metadata.json", metadata.to_json().encode("utf-8"))
90
81
 
91
82
  buffer.seek(0)
92
83
 
@@ -6,3 +6,8 @@ onnxruntime==1.20.0
6
6
 
7
7
  # Serialization
8
8
  pydantic
9
+
10
+ # Plotting
11
+ matplotlib
12
+ numpy
13
+ pathlib
@@ -17,10 +17,14 @@ crate-type = ["cdylib", "rlib"]
17
17
  [dependencies]
18
18
 
19
19
  async-trait = "0.1"
20
+ chrono = "0.4.41"
21
+ crossbeam-channel = "0.5.15"
20
22
  flate2 = "1.0"
21
23
  futures-util = "0.3.30"
24
+ log = "0.4"
22
25
  ndarray = "0.16.1"
23
26
  ort = { version = "2.0.0-rc.9", features = [ "load-dynamic" ] }
27
+ ort-sys = { version = "=2.0.0-rc.9" }
24
28
  serde = { version = "1.0", features = ["derive"] }
25
29
  serde_json = "1.0"
26
30
  tar = "0.4"
@@ -1,5 +1,8 @@
1
+ pub mod logger;
1
2
  pub mod model;
2
3
  pub mod runtime;
4
+ pub mod types;
3
5
 
4
6
  pub use model::*;
5
7
  pub use runtime::*;
8
+ pub use types::*;
@@ -0,0 +1,141 @@
1
+ use std::{
2
+ fs::OpenOptions,
3
+ io::{BufWriter, Write},
4
+ path::{Path, PathBuf},
5
+ sync::atomic::{AtomicU64, Ordering},
6
+ thread,
7
+ };
8
+
9
+ use crossbeam_channel::{bounded, Sender};
10
+ use log::{info, warn};
11
+ use serde::Serialize;
12
+
13
+ #[derive(Serialize)]
14
+ struct NdjsonStep<'a> {
15
+ step_id: u64,
16
+ t_us: u64,
17
+ joint_angles: Option<&'a [f32]>,
18
+ joint_vels: Option<&'a [f32]>,
19
+ initial_heading: Option<&'a [f32]>,
20
+ quaternion: Option<&'a [f32]>,
21
+ projected_g: Option<&'a [f32]>,
22
+ accel: Option<&'a [f32]>,
23
+ gyro: Option<&'a [f32]>,
24
+ command: Option<&'a [f32]>,
25
+ output: Option<&'a [f32]>,
26
+ }
27
+
28
/// Capacity, in lines, of the bounded channel between `log_step` and the
/// background writer thread.
///
/// About 1000 entries at 50 Hz is roughly 20 seconds of buffering. When the
/// channel is full, `log_step` drops the line and logs a warning instead of
/// blocking the caller.
const CHANNEL_CAP: usize = 1024;

/// Number of lines the writer thread writes between explicit flushes of its
/// buffered file writer.
///
/// At a 50 Hz control frequency this flushes roughly every 2 seconds.
const FLUSH_EVERY: u64 = 100;
36
+
37
+ pub struct StepLogger {
38
+ tx: Option<Sender<Vec<u8>>>,
39
+ worker: Option<thread::JoinHandle<()>>,
40
+ next_id: AtomicU64,
41
+ }
42
+
43
+ impl StepLogger {
44
+ pub fn new(path: impl AsRef<Path>) -> std::io::Result<Self> {
45
+ let path: PathBuf = path.as_ref().into();
46
+ if let Some(parent) = path.parent() {
47
+ std::fs::create_dir_all(parent)?;
48
+ }
49
+ info!("kinfer: logging to NDJSON: {}", path.display());
50
+
51
+ // I/O objects created here, but moved into the worker thread.
52
+ let file = OpenOptions::new().create(true).append(true).open(&path)?;
53
+ let mut bw = BufWriter::new(file);
54
+
55
+ // Bounded channel -> back-pressure capped at CHANNEL_CAP lines
56
+ let (tx, rx) = bounded::<Vec<u8>>(CHANNEL_CAP);
57
+
58
+ let worker = thread::spawn(move || {
59
+ let mut line_ctr: u64 = 0;
60
+ for msg in rx {
61
+ // drains until all senders dropped
62
+ let _ = bw.write_all(&msg);
63
+ line_ctr += 1;
64
+ if line_ctr % FLUSH_EVERY == 0 {
65
+ let _ = bw.flush();
66
+ }
67
+ }
68
+ // Final flush on graceful shutdown
69
+ let _ = bw.flush();
70
+ });
71
+
72
+ Ok(Self {
73
+ tx: Some(tx),
74
+ worker: Some(worker),
75
+ next_id: AtomicU64::new(0),
76
+ })
77
+ }
78
+
79
+ #[inline]
80
+ fn now_us() -> u128 {
81
+ std::time::SystemTime::now()
82
+ .duration_since(std::time::UNIX_EPOCH)
83
+ .unwrap()
84
+ .as_micros()
85
+ }
86
+
87
+ /// Non-blocking; drops a line if the channel is full.
88
+ #[allow(clippy::too_many_arguments)]
89
+ pub fn log_step(
90
+ &self,
91
+ joint_angles: Option<&[f32]>,
92
+ joint_vels: Option<&[f32]>,
93
+ initial_heading: Option<&[f32]>,
94
+ quaternion: Option<&[f32]>,
95
+ projected_g: Option<&[f32]>,
96
+ accel: Option<&[f32]>,
97
+ gyro: Option<&[f32]>,
98
+ command: Option<&[f32]>,
99
+ output: Option<&[f32]>,
100
+ ) {
101
+ let record = NdjsonStep {
102
+ step_id: self.next_id.fetch_add(1, Ordering::Relaxed),
103
+ t_us: Self::now_us() as u64,
104
+ joint_angles,
105
+ joint_vels,
106
+ initial_heading,
107
+ quaternion,
108
+ projected_g,
109
+ accel,
110
+ gyro,
111
+ command,
112
+ output,
113
+ };
114
+
115
+ // Serialise directly into a Vec<u8>; then push newline and send.
116
+ if let Ok(mut line) = serde_json::to_vec(&record) {
117
+ line.push(b'\n');
118
+ if let Some(tx) = &self.tx {
119
+ if tx.try_send(line).is_err() {
120
+ warn!(
121
+ "kinfer: logging buffer full, dropped message (step_id: {})",
122
+ record.step_id
123
+ );
124
+ }
125
+ }
126
+ }
127
+ }
128
+ }
129
+
130
+ /// Ensure the worker drains and flushes before program exit.
131
+ impl Drop for StepLogger {
132
+ fn drop(&mut self) {
133
+ if let Some(tx) = self.tx.take() {
134
+ drop(tx); // Drop sender to close channel
135
+ }
136
+ // Wait for worker to finish
137
+ if let Some(worker) = self.worker.take() {
138
+ let _ = worker.join();
139
+ }
140
+ }
141
+ }