kinfer 0.4.0__cp312-cp312-macosx_11_0_arm64.whl → 0.4.2__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kinfer/__init__.py CHANGED
@@ -1,5 +1,16 @@
1
1
  """Defines the kinfer API."""
2
2
 
3
+ import os
4
+
5
+ if "ORT_DYLIB_PATH" not in os.environ:
6
+ from pathlib import Path
7
+
8
+ import onnxruntime as ort
9
+
10
+ LIB_PATH = next((Path(ort.__file__).parent / "capi").glob("libonnxruntime.*"), None)
11
+ if LIB_PATH is not None:
12
+ os.environ["ORT_DYLIB_PATH"] = LIB_PATH.resolve().as_posix()
13
+
3
14
  from .rust_bindings import get_version
4
15
 
5
16
  __version__ = get_version()
kinfer/common/types.py CHANGED
@@ -9,3 +9,4 @@ from pydantic import BaseModel
9
9
 
10
10
  class Metadata(BaseModel):
11
11
  joint_names: list[str]
12
+ num_commands: int | None
kinfer/export/common.py CHANGED
@@ -4,6 +4,7 @@
4
4
  def get_shape(
5
5
  name: str,
6
6
  num_joints: int | None = None,
7
+ num_commands: int | None = None,
7
8
  carry_shape: tuple[int, ...] | None = None,
8
9
  ) -> tuple[int, ...]:
9
10
  match name:
@@ -26,6 +27,11 @@ def get_shape(
26
27
  case "gyroscope":
27
28
  return (3,)
28
29
 
30
+ case "command":
31
+ if num_commands is None:
32
+ raise ValueError("`num_commands` must be provided when using `command`")
33
+ return (num_commands,)
34
+
29
35
  case "carry":
30
36
  if carry_shape is None:
31
37
  raise ValueError("`carry_shape` must be provided for `carry`")
kinfer/export/jax.py CHANGED
@@ -19,6 +19,7 @@ def export_fn(
19
19
  model: Wrapped,
20
20
  *,
21
21
  num_joints: int | None = None,
22
+ num_commands: int | None = None,
22
23
  carry_shape: tuple[int, ...] | None = None,
23
24
  opset: int = 13,
24
25
  ) -> ModelProto:
@@ -35,6 +36,7 @@ def export_fn(
35
36
  shape = get_shape(
36
37
  name,
37
38
  num_joints=num_joints,
39
+ num_commands=num_commands,
38
40
  carry_shape=carry_shape,
39
41
  )
40
42
  tf_args.append(tf.TensorSpec(shape, tf.float32, name=name))
kinfer/export/pytorch.py CHANGED
@@ -19,6 +19,7 @@ def export_fn(
19
19
  model: torch.jit.ScriptFunction,
20
20
  *,
21
21
  num_joints: int | None = None,
22
+ num_commands: int | None = None,
22
23
  carry_shape: tuple[int, ...] | None = None,
23
24
  ) -> ModelProto:
24
25
  """Exports a PyTorch function to ONNX.
@@ -26,6 +27,7 @@ def export_fn(
26
27
  Args:
27
28
  model: The model to export.
28
29
  num_joints: The number of joints in the model.
30
+ num_commands: The number of commands in the model.
29
31
  carry_shape: The shape of the carry tensor.
30
32
 
31
33
  Returns:
@@ -43,6 +45,7 @@ def export_fn(
43
45
  shape = get_shape(
44
46
  name,
45
47
  num_joints=num_joints,
48
+ num_commands=num_commands,
46
49
  carry_shape=carry_shape,
47
50
  )
48
51
  args.append(torch.zeros(shape))
@@ -18,7 +18,8 @@ def pack(
18
18
  init_fn: ModelProto,
19
19
  step_fn: ModelProto,
20
20
  joint_names: list[str],
21
- carry_shape: tuple[int, ...],
21
+ num_commands: int | None = None,
22
+ carry_shape: tuple[int, ...] | None = None,
22
23
  ) -> bytes:
23
24
  """Packs the initialization function and step function into a directory.
24
25
 
@@ -27,8 +28,8 @@ def pack(
27
28
  step_fn: The step function.
28
29
  joint_names: The list of joint names, in the order that the model
29
30
  expects them to be provided.
31
+ num_commands: The number of commands in the model.
30
32
  carry_shape: The shape of the carry tensor.
31
- root_dir: The root directory of the model.
32
33
  """
33
34
  num_joints = len(joint_names)
34
35
 
@@ -37,6 +38,10 @@ def pack(
37
38
  raise ValueError(f"`init` function should not have any inputs! Got {len(init_fn.graph.input)}")
38
39
  if len(init_fn.graph.output) != 1:
39
40
  raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")
41
+ init_carry = init_fn.graph.output[0]
42
+ init_carry_shape = tuple(dim.dim_value for dim in init_carry.type.tensor_type.shape.dim)
43
+ if carry_shape is not None and init_carry_shape != carry_shape:
44
+ raise ValueError(f"Expected carry shape {carry_shape} for output `{init_carry.name}`, got {init_carry_shape}")
40
45
 
41
46
  # Checks the `step` function.
42
47
  for step_input in step_fn.graph.input:
@@ -45,6 +50,7 @@ def pack(
45
50
  expected_shape = get_shape(
46
51
  step_input.name,
47
52
  num_joints=num_joints,
53
+ num_commands=num_commands,
48
54
  carry_shape=carry_shape,
49
55
  )
50
56
  if shape != expected_shape:
@@ -53,19 +59,20 @@ def pack(
53
59
  if len(step_fn.graph.output) != 2:
54
60
  raise ValueError(f"Step function must have exactly 2 outputs, got {len(step_fn.graph.output)}")
55
61
 
56
- model_output = step_fn.graph.output[0]
57
- output_shape = tuple(dim.dim_value for dim in model_output.type.tensor_type.shape.dim)
58
- if output_shape != (num_joints,):
59
- raise ValueError(f"Expected output shape {num_joints} for output `{model_output.name}`, got {output_shape}")
62
+ output_actions = step_fn.graph.output[0]
63
+ actions_shape = tuple(dim.dim_value for dim in output_actions.type.tensor_type.shape.dim)
64
+ if actions_shape != (num_joints,):
65
+ raise ValueError(f"Expected output shape {num_joints} for output `{output_actions.name}`, got {actions_shape}")
60
66
 
61
- model_carry = step_fn.graph.output[1]
62
- output_carry_shape = tuple(dim.dim_value for dim in model_carry.type.tensor_type.shape.dim)
63
- if output_carry_shape != carry_shape:
64
- raise ValueError(f"Expected carry shape {carry_shape} for output `{model_carry.name}`, got {carry_shape}")
67
+ output_carry = step_fn.graph.output[1]
68
+ output_carry_shape = tuple(dim.dim_value for dim in output_carry.type.tensor_type.shape.dim)
69
+ if output_carry_shape != init_carry_shape:
70
+ raise ValueError(f"Expected carry shape {init_carry_shape} for output carry, got {output_carry_shape}")
65
71
 
66
72
  # Builds the metadata object.
67
73
  metadata = Metadata(
68
74
  joint_names=joint_names,
75
+ num_commands=num_commands,
69
76
  )
70
77
 
71
78
  buffer = io.BytesIO()
kinfer/rust/Cargo.toml CHANGED
@@ -1,8 +1,18 @@
1
1
  [package]
2
2
 
3
3
  name = "kinfer"
4
- version = "0.1.0"
5
- edition = "2021"
4
+ version.workspace = true
5
+ edition.workspace = true
6
+ description.workspace = true
7
+ authors.workspace = true
8
+ repository.workspace = true
9
+ license.workspace = true
10
+ readme.workspace = true
11
+
12
+ [lib]
13
+
14
+ name = "kinfer"
15
+ crate-type = ["cdylib", "rlib"]
6
16
 
7
17
  [dependencies]
8
18
 
kinfer/rust/src/model.rs CHANGED
@@ -16,6 +16,7 @@ use tokio::io::AsyncReadExt;
16
16
  #[derive(Debug, Deserialize)]
17
17
  struct ModelMetadata {
18
18
  joint_names: Vec<String>,
19
+ num_commands: Option<usize>,
19
20
  }
20
21
 
21
22
  impl ModelMetadata {
@@ -45,6 +46,7 @@ pub trait ModelProvider: Send + Sync {
45
46
  async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError>;
46
47
  async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError>;
47
48
  async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
49
+ async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError>;
48
50
  async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
49
51
  async fn take_action(
50
52
  &self,
@@ -134,7 +136,7 @@ impl ModelRunner {
134
136
  .to_vec();
135
137
 
136
138
  // Validate step_fn inputs and outputs
137
- Self::validate_step_fn(&step_session, metadata.joint_names.len(), &carry_shape)?;
139
+ Self::validate_step_fn(&step_session, &metadata, &carry_shape)?;
138
140
 
139
141
  Ok(Self {
140
142
  init_session,
@@ -146,7 +148,7 @@ impl ModelRunner {
146
148
 
147
149
  fn validate_step_fn(
148
150
  session: &Session,
149
- num_joints: usize,
151
+ metadata: &ModelMetadata,
150
152
  carry_shape: &[i64],
151
153
  ) -> Result<(), Box<dyn std::error::Error>> {
152
154
  // Validate inputs
@@ -158,6 +160,7 @@ impl ModelRunner {
158
160
 
159
161
  match input.name.as_str() {
160
162
  "joint_angles" | "joint_angular_velocities" => {
163
+ let num_joints = metadata.joint_names.len();
161
164
  if *dims != vec![num_joints as i64] {
162
165
  return Err(format!(
163
166
  "Expected shape [{num_joints}] for input `{}`, got {:?}",
@@ -175,6 +178,16 @@ impl ModelRunner {
175
178
  .into());
176
179
  }
177
180
  }
181
+ "command" => {
182
+ let num_commands = metadata.num_commands.ok_or("num_commands is not set")?;
183
+ if *dims != vec![num_commands as i64] {
184
+ return Err(format!(
185
+ "Expected shape [{num_commands}] for input `{}`, got {:?}",
186
+ input.name, dims
187
+ )
188
+ .into());
189
+ }
190
+ }
178
191
  "carry" => {
179
192
  if dims != carry_shape {
180
193
  return Err(format!(
@@ -197,6 +210,7 @@ impl ModelRunner {
197
210
  .output_type
198
211
  .tensor_dimensions()
199
212
  .ok_or("Missing tensor type")?;
213
+ let num_joints = metadata.joint_names.len();
200
214
  if *output_shape != vec![num_joints as i64] {
201
215
  return Err(format!(
202
216
  "Expected output shape [{num_joints}], got {:?}",
@@ -253,6 +267,7 @@ impl ModelRunner {
253
267
  "projected_gravity" => futures.push(self.provider.get_projected_gravity()),
254
268
  "accelerometer" => futures.push(self.provider.get_accelerometer()),
255
269
  "gyroscope" => futures.push(self.provider.get_gyroscope()),
270
+ "command" => futures.push(self.provider.get_command()),
256
271
  "carry" => futures.push(self.provider.get_carry(carry.clone())),
257
272
  _ => return Err(format!("Unknown input name: {}", name).into()),
258
273
  }
@@ -13,6 +13,7 @@ class ModelProviderABC:
13
13
  def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
14
14
  def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
15
15
  def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
+ def get_command(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
17
  def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
17
18
 
18
19
  class PyModelProvider:
@@ -69,6 +69,10 @@ impl ModelProviderABC {
69
69
  ))
70
70
  }
71
71
 
72
+ fn get_command<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
73
+ Err(PyNotImplementedError::new_err("Must override get_command"))
74
+ }
75
+
72
76
  fn take_action<'py>(
73
77
  &self,
74
78
  joint_names: Vec<String>,
@@ -166,6 +170,18 @@ impl ModelProvider for PyModelProvider {
166
170
  Ok(args)
167
171
  }
168
172
 
173
+ async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError> {
174
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
175
+ let obj = self.obj.clone();
176
+ let args = ();
177
+ let result = obj.call_method(py, "get_command", args, None)?;
178
+ let array = result.extract::<Vec<f32>>(py)?;
179
+ Ok(Array::from_vec(array).into_dyn())
180
+ })
181
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
182
+ Ok(args)
183
+ }
184
+
169
185
  async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
170
186
  Ok(carry)
171
187
  }
Binary file
kinfer/rust_bindings.pyi CHANGED
@@ -13,6 +13,7 @@ class ModelProviderABC:
13
13
  def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
14
14
  def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
15
15
  def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
+ def get_command(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
17
  def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
17
18
 
18
19
  class PyModelProvider:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kinfer
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: Tool to make it easier to run a model on a real robot
5
5
  Home-page: https://github.com/kscalelabs/kinfer.git
6
6
  Author: K-Scale Labs
@@ -0,0 +1,26 @@
1
+ kinfer/requirements.txt,sha256=j08HO4ptA5afuj99j8FlAP2qla5Zf4_OiEBtgAmF7Jg,90
2
+ kinfer/__init__.py,sha256=i5da6ND827Cgn8PFKzDCmEBk14ptQasLa_9fdof4Y9c,398
3
+ kinfer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ kinfer/rust_bindings.pyi,sha256=9u0_TuHwV9xQH3O5OKbYfUZ_Uj6bsUwEOW0bfZujRMQ,1709
5
+ kinfer/rust_bindings.cpython-312-darwin.so,sha256=efrQ9tXPm6oqyEDDRBE5nciU9qdUKjq53USCt2-g4Qg,1930976
6
+ kinfer/rust/Cargo.toml,sha256=uxjvuM1Sm82UgCdHW3R6PbbiJjXvfE4VCLvjGPqp2mY,606
7
+ kinfer/rust/src/runtime.rs,sha256=eDflFSIa4IxOmG3SInwrHrYc02UzEANZMkK-liA7pVA,3407
8
+ kinfer/rust/src/lib.rs,sha256=Z3dWdhKhqhWqPJee6vWde4ptqquOYW6W9wB0wyaKCyk,71
9
+ kinfer/rust/src/model.rs,sha256=NqxWjSDivkN2mLAGCqytjfYHlx8dpP4zqoatx7pRWg4,11419
10
+ kinfer/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ kinfer/common/types.py,sha256=jiuCfoTEumOJm4VZ03VoFVIL5PMJ1tCG7b2Kr66qJQA,176
12
+ kinfer/export/serialize.py,sha256=7rcSrXawVnJinxeJp26wovUwoR04BnJ5CZAV9vwKUUo,3474
13
+ kinfer/export/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ kinfer/export/jax.py,sha256=gZWWUJKb3MPEtfNYeKfPnypeK0elJwGuDYfpCW0GqNY,1413
15
+ kinfer/export/pytorch.py,sha256=rwHGOTcd08IeL5-HZ_7JVUMPMHleiHI8-TkShAoRrYM,1612
16
+ kinfer/export/common.py,sha256=QWcLXY-jzZCbrVPWN7gFGLE5RbnZHLhuuvoH9agEvwk,1244
17
+ kinfer/rust_bindings/Cargo.toml,sha256=i1RGB9VNd9Q4FJ6gGwjZJQYo8DBBvpVWf3GJ95EfVgM,637
18
+ kinfer/rust_bindings/pyproject.toml,sha256=jLcJuHCnQRh9HWR_R7a9qLHwj6LMBgnHyeKK_DruO1Y,135
19
+ kinfer/rust_bindings/rust_bindings.pyi,sha256=9u0_TuHwV9xQH3O5OKbYfUZ_Uj6bsUwEOW0bfZujRMQ,1709
20
+ kinfer/rust_bindings/src/lib.rs,sha256=ydTTyx68d1fZDzc6kjFpU9ePDQa9IBD8Gf2rPy-kutI,10840
21
+ kinfer/rust_bindings/src/bin/stub_gen.rs,sha256=hhoVGnaSfazbSfj5a4x6mPicGPOgWQAfsDmiPej0B6Y,133
22
+ kinfer-0.4.2.dist-info/RECORD,,
23
+ kinfer-0.4.2.dist-info/WHEEL,sha256=mP9bWt4ASeNWfyg7GBBbGbsOVFgblaN5WklJcvrSjIE,136
24
+ kinfer-0.4.2.dist-info/top_level.txt,sha256=6mY_t3PYr3Dm0dpqMk80uSnArbvGfCFkxOh1QWtgDEo,7
25
+ kinfer-0.4.2.dist-info/METADATA,sha256=utM93j1dUATyhQF8h2ittLqKCOwCAt_M-pnJ68dODh4,1745
26
+ kinfer-0.4.2.dist-info/licenses/LICENSE,sha256=Qw-Z0XTwS-diSW91e_jLeBPX9zZbAatOJTBLdPHPaC0,1069
@@ -1,26 +0,0 @@
1
- kinfer/requirements.txt,sha256=j08HO4ptA5afuj99j8FlAP2qla5Zf4_OiEBtgAmF7Jg,90
2
- kinfer/__init__.py,sha256=YbtJIepEE4pbjYbdgFCV-rBP90AnQpBfDaprkflBmEE,99
3
- kinfer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- kinfer/rust_bindings.pyi,sha256=oQKlW_bRw_c7fgVMbqiscPJ57pli25DkTCA_odsMOSU,1639
5
- kinfer/rust_bindings.cpython-312-darwin.so,sha256=2qlbCxfS5Dsxdw077wpnPjt0WQhjm5CRtcA4S7aTG8A,1912656
6
- kinfer/rust/Cargo.toml,sha256=0SLLbjtoODLrSMcyPdg22Ora-JNdaCqbMKb1c5OQblU,404
7
- kinfer/rust/src/runtime.rs,sha256=eDflFSIa4IxOmG3SInwrHrYc02UzEANZMkK-liA7pVA,3407
8
- kinfer/rust/src/lib.rs,sha256=Z3dWdhKhqhWqPJee6vWde4ptqquOYW6W9wB0wyaKCyk,71
9
- kinfer/rust/src/model.rs,sha256=pODQzD-roXpCJ7-kzaE6uULgqHgsHKb3oQjv-uKtCYk,10668
10
- kinfer/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- kinfer/common/types.py,sha256=Yf3c8Tui7vWguXMmxxa_QynwydCPu9RxrJS8JJMDuKg,147
12
- kinfer/export/serialize.py,sha256=pOxXdcWMoaNfnAUhgzHUTb7GtZvn0MYtKskIarHzAiY,3007
13
- kinfer/export/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- kinfer/export/jax.py,sha256=nbZB7Bxed4MQmfBZrH_Ml6PnkHQoyq5jdOT01DxiZDA,1337
15
- kinfer/export/pytorch.py,sha256=twGpcX6OaJgiiwQ0teo_4rrCFTrCqu-ULH2UEHl8auc,1477
16
- kinfer/export/common.py,sha256=ZiOAtehb2_ATg7t3tWzxbHUkGhtJ8vqjlaaw94-Iux4,1021
17
- kinfer/rust_bindings/Cargo.toml,sha256=i1RGB9VNd9Q4FJ6gGwjZJQYo8DBBvpVWf3GJ95EfVgM,637
18
- kinfer/rust_bindings/pyproject.toml,sha256=jLcJuHCnQRh9HWR_R7a9qLHwj6LMBgnHyeKK_DruO1Y,135
19
- kinfer/rust_bindings/rust_bindings.pyi,sha256=oQKlW_bRw_c7fgVMbqiscPJ57pli25DkTCA_odsMOSU,1639
20
- kinfer/rust_bindings/src/lib.rs,sha256=lKYr6Imu0Z-XxDKCguJLfrhIuai6S2bEhRAe__9sFnc,10196
21
- kinfer/rust_bindings/src/bin/stub_gen.rs,sha256=hhoVGnaSfazbSfj5a4x6mPicGPOgWQAfsDmiPej0B6Y,133
22
- kinfer-0.4.0.dist-info/RECORD,,
23
- kinfer-0.4.0.dist-info/WHEEL,sha256=mP9bWt4ASeNWfyg7GBBbGbsOVFgblaN5WklJcvrSjIE,136
24
- kinfer-0.4.0.dist-info/top_level.txt,sha256=6mY_t3PYr3Dm0dpqMk80uSnArbvGfCFkxOh1QWtgDEo,7
25
- kinfer-0.4.0.dist-info/METADATA,sha256=2gGRmUwgmAoYK8kRG5qTtteIMrdaUYW7Edvpk_32yas,1745
26
- kinfer-0.4.0.dist-info/licenses/LICENSE,sha256=Qw-Z0XTwS-diSW91e_jLeBPX9zZbAatOJTBLdPHPaC0,1069
File without changes