kinfer 0.4.2__cp311-cp311-macosx_11_0_arm64.whl → 0.4.3__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kinfer/export/common.py CHANGED
@@ -37,5 +37,8 @@ def get_shape(
37
37
  raise ValueError("`carry_shape` must be provided for `carry`")
38
38
  return carry_shape
39
39
 
40
+ case "time":
41
+ return (1,)
42
+
40
43
  case _:
41
44
  raise ValueError(f"Unknown tensor name: {name}")
kinfer/rust/src/model.rs CHANGED
@@ -47,6 +47,7 @@ pub trait ModelProvider: Send + Sync {
47
47
  async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError>;
48
48
  async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
49
49
  async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError>;
50
+ async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError>;
50
51
  async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
51
52
  async fn take_action(
52
53
  &self,
@@ -188,6 +189,15 @@ impl ModelRunner {
188
189
  .into());
189
190
  }
190
191
  }
192
+ "time" => {
193
+ if *dims != vec![1] {
194
+ return Err(format!(
195
+ "Expected shape [1] for input `{}`, got {:?}",
196
+ input.name, dims
197
+ )
198
+ .into());
199
+ }
200
+ }
191
201
  "carry" => {
192
202
  if dims != carry_shape {
193
203
  return Err(format!(
@@ -269,6 +279,7 @@ impl ModelRunner {
269
279
  "gyroscope" => futures.push(self.provider.get_gyroscope()),
270
280
  "command" => futures.push(self.provider.get_command()),
271
281
  "carry" => futures.push(self.provider.get_carry(carry.clone())),
282
+ "time" => futures.push(self.provider.get_time()),
272
283
  _ => return Err(format!("Unknown input name: {}", name).into()),
273
284
  }
274
285
  }
@@ -3,8 +3,8 @@ use std::sync::Arc;
3
3
  use tokio::runtime::Runtime;
4
4
 
5
5
  use crate::model::{ModelError, ModelRunner};
6
- use std::time::{Duration, Instant};
7
- use tokio::time::sleep;
6
+ use std::time::Duration;
7
+ use tokio::time::interval;
8
8
 
9
9
  pub struct ModelRuntime {
10
10
  model_runner: Arc<ModelRunner>,
@@ -61,7 +61,10 @@ impl ModelRuntime {
61
61
  .get_joint_angles()
62
62
  .await
63
63
  .map_err(|e| ModelError::Provider(e.to_string()))?;
64
- let mut last_time = Instant::now();
64
+
65
+ // Wait for the first tick, since it happens immediately.
66
+ let mut interval = interval(dt);
67
+ interval.tick().await;
65
68
 
66
69
  while running.load(Ordering::Relaxed) {
67
70
  let (output, next_carry) = model_runner
@@ -80,10 +83,7 @@ impl ModelRuntime {
80
83
  .take_action(interp_joint_positions * magnitude_factor)
81
84
  .await
82
85
  .map_err(|e| ModelError::Provider(e.to_string()))?;
83
- last_time = last_time + dt;
84
- if let Some(sleep_duration) = last_time.checked_duration_since(Instant::now()) {
85
- sleep(sleep_duration).await;
86
- }
86
+ interval.tick().await;
87
87
  }
88
88
 
89
89
  joint_positions = output;
@@ -14,6 +14,7 @@ class ModelProviderABC:
14
14
  def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
15
15
  def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
16
  def get_command(self) -> numpy.typing.NDArray[numpy.float32]: ...
17
+ def get_time(self) -> numpy.typing.NDArray[numpy.float32]: ...
17
18
  def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
18
19
 
19
20
  class PyModelProvider:
@@ -10,7 +10,7 @@ use pyo3_stub_gen::define_stub_info_gatherer;
10
10
  use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
11
11
  use std::sync::Arc;
12
12
  use std::sync::Mutex;
13
-
13
+ use std::time::Instant;
14
14
  #[pyfunction]
15
15
  #[gen_stub_pyfunction]
16
16
  fn get_version() -> String {
@@ -73,6 +73,10 @@ impl ModelProviderABC {
73
73
  Err(PyNotImplementedError::new_err("Must override get_command"))
74
74
  }
75
75
 
76
+ fn get_time<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
77
+ Err(PyNotImplementedError::new_err("Must override get_time"))
78
+ }
79
+
76
80
  fn take_action<'py>(
77
81
  &self,
78
82
  joint_names: Vec<String>,
@@ -92,13 +96,17 @@ impl ModelProviderABC {
92
96
  #[derive(Clone)]
93
97
  struct PyModelProvider {
94
98
  obj: Arc<Py<ModelProviderABC>>,
99
+ start_time: Instant,
95
100
  }
96
101
 
97
102
  #[pymethods]
98
103
  impl PyModelProvider {
99
104
  #[new]
100
105
  fn new(obj: Py<ModelProviderABC>) -> Self {
101
- Self { obj: Arc::new(obj) }
106
+ Self {
107
+ obj: Arc::new(obj),
108
+ start_time: Instant::now(),
109
+ }
102
110
  }
103
111
  }
104
112
 
@@ -182,6 +190,18 @@ impl ModelProvider for PyModelProvider {
182
190
  Ok(args)
183
191
  }
184
192
 
193
+ async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError> {
194
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
195
+ let obj = self.obj.clone();
196
+ let args = ();
197
+ let result = obj.call_method(py, "get_time", args, None)?;
198
+ let array = result.extract::<Vec<f32>>(py)?;
199
+ Ok(Array::from_vec(array).into_dyn())
200
+ })
201
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
202
+ Ok(args)
203
+ }
204
+
185
205
  async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
186
206
  Ok(carry)
187
207
  }
@@ -217,9 +237,7 @@ struct PyModelRunner {
217
237
  impl PyModelRunner {
218
238
  #[new]
219
239
  fn new(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
220
- let input_provider = Arc::new(PyModelProvider {
221
- obj: Arc::new(provider),
222
- });
240
+ let input_provider = Arc::new(PyModelProvider::new(provider));
223
241
 
224
242
  let runner = tokio::runtime::Runtime::new().unwrap().block_on(async {
225
243
  ModelRunner::new(model_path, input_provider)
Binary file
kinfer/rust_bindings.pyi CHANGED
@@ -14,6 +14,7 @@ class ModelProviderABC:
14
14
  def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
15
15
  def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
16
  def get_command(self) -> numpy.typing.NDArray[numpy.float32]: ...
17
+ def get_time(self) -> numpy.typing.NDArray[numpy.float32]: ...
17
18
  def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
18
19
 
19
20
  class PyModelProvider:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kinfer
3
- Version: 0.4.2
3
+ Version: 0.4.3
4
4
  Summary: Tool to make it easier to run a model on a real robot
5
5
  Home-page: https://github.com/kscalelabs/kinfer.git
6
6
  Author: K-Scale Labs
@@ -21,7 +21,7 @@ Provides-Extra: pytorch
21
21
  Requires-Dist: torch; extra == "pytorch"
22
22
  Provides-Extra: jax
23
23
  Requires-Dist: tensorflow; extra == "jax"
24
- Requires-Dist: tf2onnx; extra == "jax"
24
+ Requires-Dist: tf2onnx>=1.16.0; extra == "jax"
25
25
  Requires-Dist: jax; extra == "jax"
26
26
  Requires-Dist: equinox; extra == "jax"
27
27
  Requires-Dist: numpy<2; extra == "jax"
@@ -34,7 +34,7 @@ Requires-Dist: ruff; extra == "all"
34
34
  Requires-Dist: types-tensorflow; extra == "all"
35
35
  Requires-Dist: torch; extra == "all"
36
36
  Requires-Dist: tensorflow; extra == "all"
37
- Requires-Dist: tf2onnx; extra == "all"
37
+ Requires-Dist: tf2onnx>=1.16.0; extra == "all"
38
38
  Requires-Dist: jax; extra == "all"
39
39
  Requires-Dist: equinox; extra == "all"
40
40
  Requires-Dist: numpy<2; extra == "all"
@@ -1,26 +1,26 @@
1
- kinfer/rust_bindings.cpython-311-darwin.so,sha256=fHPKAukypHHVpNxF5OaE_l8IdxAliPUkGXVEhrO7yh8,1933008
1
+ kinfer/rust_bindings.cpython-311-darwin.so,sha256=ZUWCReCTDKhzqBteg3k4nJ-15ubYFMCSBYQdADqcvd0,1953904
2
2
  kinfer/requirements.txt,sha256=j08HO4ptA5afuj99j8FlAP2qla5Zf4_OiEBtgAmF7Jg,90
3
3
  kinfer/__init__.py,sha256=i5da6ND827Cgn8PFKzDCmEBk14ptQasLa_9fdof4Y9c,398
4
4
  kinfer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- kinfer/rust_bindings.pyi,sha256=9u0_TuHwV9xQH3O5OKbYfUZ_Uj6bsUwEOW0bfZujRMQ,1709
5
+ kinfer/rust_bindings.pyi,sha256=lyBw-t_9ONYwVSaUTeeYDWpBbhKod1wf9Dvc4MgnC1g,1776
6
6
  kinfer/rust/Cargo.toml,sha256=uxjvuM1Sm82UgCdHW3R6PbbiJjXvfE4VCLvjGPqp2mY,606
7
- kinfer/rust/src/runtime.rs,sha256=eDflFSIa4IxOmG3SInwrHrYc02UzEANZMkK-liA7pVA,3407
7
+ kinfer/rust/src/runtime.rs,sha256=l9wG0XbZmpHSqho3SfGuW54P7YXs8eXZ365kOj7UqUE,3321
8
8
  kinfer/rust/src/lib.rs,sha256=Z3dWdhKhqhWqPJee6vWde4ptqquOYW6W9wB0wyaKCyk,71
9
- kinfer/rust/src/model.rs,sha256=NqxWjSDivkN2mLAGCqytjfYHlx8dpP4zqoatx7pRWg4,11419
9
+ kinfer/rust/src/model.rs,sha256=zY_XmkCZOq8VhVE_LpgTxKDrV-FceY8tnkkqS1tFa8c,11890
10
10
  kinfer/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  kinfer/common/types.py,sha256=jiuCfoTEumOJm4VZ03VoFVIL5PMJ1tCG7b2Kr66qJQA,176
12
12
  kinfer/export/serialize.py,sha256=7rcSrXawVnJinxeJp26wovUwoR04BnJ5CZAV9vwKUUo,3474
13
13
  kinfer/export/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  kinfer/export/jax.py,sha256=gZWWUJKb3MPEtfNYeKfPnypeK0elJwGuDYfpCW0GqNY,1413
15
15
  kinfer/export/pytorch.py,sha256=rwHGOTcd08IeL5-HZ_7JVUMPMHleiHI8-TkShAoRrYM,1612
16
- kinfer/export/common.py,sha256=QWcLXY-jzZCbrVPWN7gFGLE5RbnZHLhuuvoH9agEvwk,1244
16
+ kinfer/export/common.py,sha256=AYN_gwmxmSlZ7BuswGUKYcY8uVUCh9UyXnWw-k8RVys,1290
17
17
  kinfer/rust_bindings/Cargo.toml,sha256=i1RGB9VNd9Q4FJ6gGwjZJQYo8DBBvpVWf3GJ95EfVgM,637
18
18
  kinfer/rust_bindings/pyproject.toml,sha256=jLcJuHCnQRh9HWR_R7a9qLHwj6LMBgnHyeKK_DruO1Y,135
19
- kinfer/rust_bindings/rust_bindings.pyi,sha256=9u0_TuHwV9xQH3O5OKbYfUZ_Uj6bsUwEOW0bfZujRMQ,1709
20
- kinfer/rust_bindings/src/lib.rs,sha256=ydTTyx68d1fZDzc6kjFpU9ePDQa9IBD8Gf2rPy-kutI,10840
19
+ kinfer/rust_bindings/rust_bindings.pyi,sha256=lyBw-t_9ONYwVSaUTeeYDWpBbhKod1wf9Dvc4MgnC1g,1776
20
+ kinfer/rust_bindings/src/lib.rs,sha256=pjsRej2BxTsecG9Y7GdBpW6GGR-pgbBk3yrKqEEQ3K0,11547
21
21
  kinfer/rust_bindings/src/bin/stub_gen.rs,sha256=hhoVGnaSfazbSfj5a4x6mPicGPOgWQAfsDmiPej0B6Y,133
22
- kinfer-0.4.2.dist-info/RECORD,,
23
- kinfer-0.4.2.dist-info/WHEEL,sha256=rtKwvZSAzWV03G3Ircwq_TRBlj2DQz4ocNXl0bd9DbU,136
24
- kinfer-0.4.2.dist-info/top_level.txt,sha256=6mY_t3PYr3Dm0dpqMk80uSnArbvGfCFkxOh1QWtgDEo,7
25
- kinfer-0.4.2.dist-info/METADATA,sha256=utM93j1dUATyhQF8h2ittLqKCOwCAt_M-pnJ68dODh4,1745
26
- kinfer-0.4.2.dist-info/licenses/LICENSE,sha256=Qw-Z0XTwS-diSW91e_jLeBPX9zZbAatOJTBLdPHPaC0,1069
22
+ kinfer-0.4.3.dist-info/RECORD,,
23
+ kinfer-0.4.3.dist-info/WHEEL,sha256=3lrG374qykB8NIZqJ4ApcNfFLoAGB9tcfdD4_UMfO40,136
24
+ kinfer-0.4.3.dist-info/top_level.txt,sha256=6mY_t3PYr3Dm0dpqMk80uSnArbvGfCFkxOh1QWtgDEo,7
25
+ kinfer-0.4.3.dist-info/METADATA,sha256=DZKaMDT-DToGkw_Ge_nSYBoo3ZAzxatlv6sR0x9kv-A,1761
26
+ kinfer-0.4.3.dist-info/licenses/LICENSE,sha256=Qw-Z0XTwS-diSW91e_jLeBPX9zZbAatOJTBLdPHPaC0,1069
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.4.0)
2
+ Generator: setuptools (80.7.1)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp311-cp311-macosx_11_0_arm64
5
5
  Generator: delocate 0.13.0