kinfer-0.4.2.tar.gz → kinfer-0.4.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {kinfer-0.4.2 → kinfer-0.4.3}/Cargo.toml +1 -1
  2. {kinfer-0.4.2 → kinfer-0.4.3}/PKG-INFO +3 -3
  3. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/export/common.py +3 -0
  4. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/rust/src/model.rs +11 -0
  5. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/rust/src/runtime.rs +7 -7
  6. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/rust_bindings/src/lib.rs +23 -5
  7. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer.egg-info/PKG-INFO +3 -3
  8. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer.egg-info/requires.txt +2 -2
  9. {kinfer-0.4.2 → kinfer-0.4.3}/setup.py +1 -1
  10. {kinfer-0.4.2 → kinfer-0.4.3}/tests/test_jax.py +14 -0
  11. {kinfer-0.4.2 → kinfer-0.4.3}/tests/test_pytorch.py +6 -0
  12. {kinfer-0.4.2 → kinfer-0.4.3}/.cargo/config.toml +0 -0
  13. {kinfer-0.4.2 → kinfer-0.4.3}/LICENSE +0 -0
  14. {kinfer-0.4.2 → kinfer-0.4.3}/MANIFEST.in +0 -0
  15. {kinfer-0.4.2 → kinfer-0.4.3}/README.md +0 -0
  16. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/__init__.py +0 -0
  17. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/common/__init__.py +0 -0
  18. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/common/types.py +0 -0
  19. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/export/__init__.py +0 -0
  20. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/export/jax.py +0 -0
  21. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/export/pytorch.py +0 -0
  22. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/export/serialize.py +0 -0
  23. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/py.typed +0 -0
  24. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/requirements.txt +0 -0
  25. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/rust/Cargo.toml +0 -0
  26. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/rust/src/lib.rs +0 -0
  27. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/rust_bindings/Cargo.toml +0 -0
  28. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/rust_bindings/pyproject.toml +0 -0
  29. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer/rust_bindings/src/bin/stub_gen.rs +0 -0
  30. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer.egg-info/SOURCES.txt +0 -0
  31. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer.egg-info/dependency_links.txt +0 -0
  32. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer.egg-info/not-zip-safe +0 -0
  33. {kinfer-0.4.2 → kinfer-0.4.3}/kinfer.egg-info/top_level.txt +0 -0
  34. {kinfer-0.4.2 → kinfer-0.4.3}/pyproject.toml +0 -0
  35. {kinfer-0.4.2 → kinfer-0.4.3}/setup.cfg +0 -0
@@ -8,7 +8,7 @@ resolver = "2"
8
8
 
9
9
  [workspace.package]
10
10
 
11
- version = "0.4.2"
11
+ version = "0.4.3"
12
12
  authors = ["Wesley Maa <wesley@kscale.dev>", "Benjamin Bolte <ben@kscale.dev>"]
13
13
  edition = "2021"
14
14
  description = "K-Scale Inference Library"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kinfer
3
- Version: 0.4.2
3
+ Version: 0.4.3
4
4
  Summary: Tool to make it easier to run a model on a real robot
5
5
  Home-page: https://github.com/kscalelabs/kinfer.git
6
6
  Author: K-Scale Labs
@@ -21,7 +21,7 @@ Provides-Extra: pytorch
21
21
  Requires-Dist: torch; extra == "pytorch"
22
22
  Provides-Extra: jax
23
23
  Requires-Dist: tensorflow; extra == "jax"
24
- Requires-Dist: tf2onnx; extra == "jax"
24
+ Requires-Dist: tf2onnx>=1.16.0; extra == "jax"
25
25
  Requires-Dist: jax; extra == "jax"
26
26
  Requires-Dist: equinox; extra == "jax"
27
27
  Requires-Dist: numpy<2; extra == "jax"
@@ -34,7 +34,7 @@ Requires-Dist: ruff; extra == "all"
34
34
  Requires-Dist: types-tensorflow; extra == "all"
35
35
  Requires-Dist: torch; extra == "all"
36
36
  Requires-Dist: tensorflow; extra == "all"
37
- Requires-Dist: tf2onnx; extra == "all"
37
+ Requires-Dist: tf2onnx>=1.16.0; extra == "all"
38
38
  Requires-Dist: jax; extra == "all"
39
39
  Requires-Dist: equinox; extra == "all"
40
40
  Requires-Dist: numpy<2; extra == "all"
@@ -37,5 +37,8 @@ def get_shape(
37
37
  raise ValueError("`carry_shape` must be provided for `carry`")
38
38
  return carry_shape
39
39
 
40
+ case "time":
41
+ return (1,)
42
+
40
43
  case _:
41
44
  raise ValueError(f"Unknown tensor name: {name}")
@@ -47,6 +47,7 @@ pub trait ModelProvider: Send + Sync {
47
47
  async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError>;
48
48
  async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
49
49
  async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError>;
50
+ async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError>;
50
51
  async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
51
52
  async fn take_action(
52
53
  &self,
@@ -188,6 +189,15 @@ impl ModelRunner {
188
189
  .into());
189
190
  }
190
191
  }
192
+ "time" => {
193
+ if *dims != vec![1] {
194
+ return Err(format!(
195
+ "Expected shape [1] for input `{}`, got {:?}",
196
+ input.name, dims
197
+ )
198
+ .into());
199
+ }
200
+ }
191
201
  "carry" => {
192
202
  if dims != carry_shape {
193
203
  return Err(format!(
@@ -269,6 +279,7 @@ impl ModelRunner {
269
279
  "gyroscope" => futures.push(self.provider.get_gyroscope()),
270
280
  "command" => futures.push(self.provider.get_command()),
271
281
  "carry" => futures.push(self.provider.get_carry(carry.clone())),
282
+ "time" => futures.push(self.provider.get_time()),
272
283
  _ => return Err(format!("Unknown input name: {}", name).into()),
273
284
  }
274
285
  }
@@ -3,8 +3,8 @@ use std::sync::Arc;
3
3
  use tokio::runtime::Runtime;
4
4
 
5
5
  use crate::model::{ModelError, ModelRunner};
6
- use std::time::{Duration, Instant};
7
- use tokio::time::sleep;
6
+ use std::time::Duration;
7
+ use tokio::time::interval;
8
8
 
9
9
  pub struct ModelRuntime {
10
10
  model_runner: Arc<ModelRunner>,
@@ -61,7 +61,10 @@ impl ModelRuntime {
61
61
  .get_joint_angles()
62
62
  .await
63
63
  .map_err(|e| ModelError::Provider(e.to_string()))?;
64
- let mut last_time = Instant::now();
64
+
65
+ // Wait for the first tick, since it happens immediately.
66
+ let mut interval = interval(dt);
67
+ interval.tick().await;
65
68
 
66
69
  while running.load(Ordering::Relaxed) {
67
70
  let (output, next_carry) = model_runner
@@ -80,10 +83,7 @@ impl ModelRuntime {
80
83
  .take_action(interp_joint_positions * magnitude_factor)
81
84
  .await
82
85
  .map_err(|e| ModelError::Provider(e.to_string()))?;
83
- last_time = last_time + dt;
84
- if let Some(sleep_duration) = last_time.checked_duration_since(Instant::now()) {
85
- sleep(sleep_duration).await;
86
- }
86
+ interval.tick().await;
87
87
  }
88
88
 
89
89
  joint_positions = output;
@@ -10,7 +10,7 @@ use pyo3_stub_gen::define_stub_info_gatherer;
10
10
  use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
11
11
  use std::sync::Arc;
12
12
  use std::sync::Mutex;
13
-
13
+ use std::time::Instant;
14
14
  #[pyfunction]
15
15
  #[gen_stub_pyfunction]
16
16
  fn get_version() -> String {
@@ -73,6 +73,10 @@ impl ModelProviderABC {
73
73
  Err(PyNotImplementedError::new_err("Must override get_command"))
74
74
  }
75
75
 
76
+ fn get_time<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
77
+ Err(PyNotImplementedError::new_err("Must override get_time"))
78
+ }
79
+
76
80
  fn take_action<'py>(
77
81
  &self,
78
82
  joint_names: Vec<String>,
@@ -92,13 +96,17 @@ impl ModelProviderABC {
92
96
  #[derive(Clone)]
93
97
  struct PyModelProvider {
94
98
  obj: Arc<Py<ModelProviderABC>>,
99
+ start_time: Instant,
95
100
  }
96
101
 
97
102
  #[pymethods]
98
103
  impl PyModelProvider {
99
104
  #[new]
100
105
  fn new(obj: Py<ModelProviderABC>) -> Self {
101
- Self { obj: Arc::new(obj) }
106
+ Self {
107
+ obj: Arc::new(obj),
108
+ start_time: Instant::now(),
109
+ }
102
110
  }
103
111
  }
104
112
 
@@ -182,6 +190,18 @@ impl ModelProvider for PyModelProvider {
182
190
  Ok(args)
183
191
  }
184
192
 
193
+ async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError> {
194
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
195
+ let obj = self.obj.clone();
196
+ let args = ();
197
+ let result = obj.call_method(py, "get_time", args, None)?;
198
+ let array = result.extract::<Vec<f32>>(py)?;
199
+ Ok(Array::from_vec(array).into_dyn())
200
+ })
201
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
202
+ Ok(args)
203
+ }
204
+
185
205
  async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
186
206
  Ok(carry)
187
207
  }
@@ -217,9 +237,7 @@ struct PyModelRunner {
217
237
  impl PyModelRunner {
218
238
  #[new]
219
239
  fn new(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
220
- let input_provider = Arc::new(PyModelProvider {
221
- obj: Arc::new(provider),
222
- });
240
+ let input_provider = Arc::new(PyModelProvider::new(provider));
223
241
 
224
242
  let runner = tokio::runtime::Runtime::new().unwrap().block_on(async {
225
243
  ModelRunner::new(model_path, input_provider)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kinfer
3
- Version: 0.4.2
3
+ Version: 0.4.3
4
4
  Summary: Tool to make it easier to run a model on a real robot
5
5
  Home-page: https://github.com/kscalelabs/kinfer.git
6
6
  Author: K-Scale Labs
@@ -21,7 +21,7 @@ Provides-Extra: pytorch
21
21
  Requires-Dist: torch; extra == "pytorch"
22
22
  Provides-Extra: jax
23
23
  Requires-Dist: tensorflow; extra == "jax"
24
- Requires-Dist: tf2onnx; extra == "jax"
24
+ Requires-Dist: tf2onnx>=1.16.0; extra == "jax"
25
25
  Requires-Dist: jax; extra == "jax"
26
26
  Requires-Dist: equinox; extra == "jax"
27
27
  Requires-Dist: numpy<2; extra == "jax"
@@ -34,7 +34,7 @@ Requires-Dist: ruff; extra == "all"
34
34
  Requires-Dist: types-tensorflow; extra == "all"
35
35
  Requires-Dist: torch; extra == "all"
36
36
  Requires-Dist: tensorflow; extra == "all"
37
- Requires-Dist: tf2onnx; extra == "all"
37
+ Requires-Dist: tf2onnx>=1.16.0; extra == "all"
38
38
  Requires-Dist: jax; extra == "all"
39
39
  Requires-Dist: equinox; extra == "all"
40
40
  Requires-Dist: numpy<2; extra == "all"
@@ -11,7 +11,7 @@ ruff
11
11
  types-tensorflow
12
12
  torch
13
13
  tensorflow
14
- tf2onnx
14
+ tf2onnx>=1.16.0
15
15
  jax
16
16
  equinox
17
17
  numpy<2
@@ -26,7 +26,7 @@ types-tensorflow
26
26
 
27
27
  [jax]
28
28
  tensorflow
29
- tf2onnx
29
+ tf2onnx>=1.16.0
30
30
  jax
31
31
  equinox
32
32
  numpy<2
@@ -39,7 +39,7 @@ requirements_pytorch = [
39
39
 
40
40
  requirements_jax = [
41
41
  "tensorflow",
42
- "tf2onnx",
42
+ "tf2onnx>=1.16.0",
43
43
  "jax",
44
44
  "equinox",
45
45
  "numpy<2",
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
18
18
  JOINT_NAMES = ["left_arm", "right_arm", "left_leg", "right_leg"]
19
19
  NUM_JOINTS = len(JOINT_NAMES)
20
20
  CARRY_SIZE = 10
21
+ NUM_COMMANDS = 4
21
22
 
22
23
 
23
24
  @jax.jit
@@ -32,6 +33,8 @@ def step_fn(
32
33
  projected_gravity: jnp.ndarray,
33
34
  accelerometer: jnp.ndarray,
34
35
  gyroscope: jnp.ndarray,
36
+ command: jnp.ndarray,
37
+ time: jnp.ndarray,
35
38
  carry: jnp.ndarray,
36
39
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
37
40
  output = (
@@ -40,6 +43,9 @@ def step_fn(
40
43
  + projected_gravity.mean()
41
44
  + accelerometer.mean()
42
45
  + gyroscope.mean()
46
+ + command.mean()
47
+ + jnp.cos(time).mean()
48
+ + jnp.sin(time).mean()
43
49
  + carry.mean()
44
50
  ) * joint_angles
45
51
  next_carry = carry + 1
@@ -54,6 +60,7 @@ def test_export(tmpdir: Path) -> None:
54
60
  step_fn_onnx = export_fn(
55
61
  model=step_fn,
56
62
  num_joints=NUM_JOINTS,
63
+ num_commands=NUM_COMMANDS,
57
64
  carry_shape=(CARRY_SIZE,),
58
65
  )
59
66
 
@@ -61,6 +68,7 @@ def test_export(tmpdir: Path) -> None:
61
68
  init_fn_onnx,
62
69
  step_fn_onnx,
63
70
  joint_names=JOINT_NAMES,
71
+ num_commands=NUM_COMMANDS,
64
72
  carry_shape=(CARRY_SIZE,),
65
73
  )
66
74
 
@@ -88,6 +96,12 @@ def test_export(tmpdir: Path) -> None:
88
96
  def get_gyroscope(self) -> np.ndarray:
89
97
  return np.random.randn(3)
90
98
 
99
+ def get_command(self) -> np.ndarray:
100
+ return np.random.randn(NUM_COMMANDS)
101
+
102
+ def get_time(self) -> np.ndarray:
103
+ return np.random.randn(1)
104
+
91
105
  def take_action(self, joint_names: Sequence[str], action: np.ndarray) -> None:
92
106
  assert joint_names == JOINT_NAMES
93
107
  assert action.shape == (NUM_JOINTS,)
@@ -38,6 +38,7 @@ def step_fn(
38
38
  accelerometer: Tensor,
39
39
  gyroscope: Tensor,
40
40
  command: Tensor,
41
+ time: Tensor,
41
42
  carry: Tensor,
42
43
  ) -> tuple[Tensor, Tensor]:
43
44
  output = (
@@ -47,6 +48,8 @@ def step_fn(
47
48
  + accelerometer.mean()
48
49
  + gyroscope.mean()
49
50
  + command.mean()
51
+ + torch.cos(time).mean()
52
+ + torch.sin(time).mean()
50
53
  + carry.mean()
51
54
  ) * joint_angles
52
55
  next_carry = carry + 1
@@ -120,6 +123,9 @@ def test_export(tmpdir: Path) -> None:
120
123
  def get_command(self) -> np.ndarray:
121
124
  return np.random.randn(NUM_COMMANDS)
122
125
 
126
+ def get_time(self) -> np.ndarray:
127
+ return np.random.randn(1)
128
+
123
129
  def take_action(self, joint_names: Sequence[str], action: np.ndarray) -> None:
124
130
  assert joint_names == JOINT_NAMES
125
131
  assert action.shape == (NUM_JOINTS,)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes