kinfer 0.4.1__tar.gz → 0.4.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {kinfer-0.4.1 → kinfer-0.4.3}/Cargo.toml +1 -1
- {kinfer-0.4.1 → kinfer-0.4.3}/PKG-INFO +3 -3
- kinfer-0.4.3/kinfer/__init__.py +16 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/export/common.py +3 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/rust/src/model.rs +11 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/rust/src/runtime.rs +7 -7
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/rust_bindings/src/lib.rs +23 -5
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer.egg-info/PKG-INFO +3 -3
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer.egg-info/requires.txt +2 -2
- {kinfer-0.4.1 → kinfer-0.4.3}/setup.py +1 -1
- {kinfer-0.4.1 → kinfer-0.4.3}/tests/test_jax.py +14 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/tests/test_pytorch.py +6 -0
- kinfer-0.4.1/kinfer/__init__.py +0 -5
- {kinfer-0.4.1 → kinfer-0.4.3}/.cargo/config.toml +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/LICENSE +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/MANIFEST.in +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/README.md +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/common/__init__.py +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/common/types.py +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/export/__init__.py +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/export/jax.py +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/export/pytorch.py +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/export/serialize.py +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/py.typed +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/requirements.txt +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/rust/Cargo.toml +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/rust/src/lib.rs +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/rust_bindings/Cargo.toml +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/rust_bindings/pyproject.toml +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer/rust_bindings/src/bin/stub_gen.rs +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer.egg-info/SOURCES.txt +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer.egg-info/dependency_links.txt +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer.egg-info/not-zip-safe +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/kinfer.egg-info/top_level.txt +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/pyproject.toml +0 -0
- {kinfer-0.4.1 → kinfer-0.4.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: kinfer
|
3
|
-
Version: 0.4.
|
3
|
+
Version: 0.4.3
|
4
4
|
Summary: Tool to make it easier to run a model on a real robot
|
5
5
|
Home-page: https://github.com/kscalelabs/kinfer.git
|
6
6
|
Author: K-Scale Labs
|
@@ -21,7 +21,7 @@ Provides-Extra: pytorch
|
|
21
21
|
Requires-Dist: torch; extra == "pytorch"
|
22
22
|
Provides-Extra: jax
|
23
23
|
Requires-Dist: tensorflow; extra == "jax"
|
24
|
-
Requires-Dist: tf2onnx; extra == "jax"
|
24
|
+
Requires-Dist: tf2onnx>=1.16.0; extra == "jax"
|
25
25
|
Requires-Dist: jax; extra == "jax"
|
26
26
|
Requires-Dist: equinox; extra == "jax"
|
27
27
|
Requires-Dist: numpy<2; extra == "jax"
|
@@ -34,7 +34,7 @@ Requires-Dist: ruff; extra == "all"
|
|
34
34
|
Requires-Dist: types-tensorflow; extra == "all"
|
35
35
|
Requires-Dist: torch; extra == "all"
|
36
36
|
Requires-Dist: tensorflow; extra == "all"
|
37
|
-
Requires-Dist: tf2onnx; extra == "all"
|
37
|
+
Requires-Dist: tf2onnx>=1.16.0; extra == "all"
|
38
38
|
Requires-Dist: jax; extra == "all"
|
39
39
|
Requires-Dist: equinox; extra == "all"
|
40
40
|
Requires-Dist: numpy<2; extra == "all"
|
@@ -0,0 +1,16 @@
|
|
1
|
+
"""Defines the kinfer API."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
|
5
|
+
if "ORT_DYLIB_PATH" not in os.environ:
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
import onnxruntime as ort
|
9
|
+
|
10
|
+
LIB_PATH = next((Path(ort.__file__).parent / "capi").glob("libonnxruntime.*"), None)
|
11
|
+
if LIB_PATH is not None:
|
12
|
+
os.environ["ORT_DYLIB_PATH"] = LIB_PATH.resolve().as_posix()
|
13
|
+
|
14
|
+
from .rust_bindings import get_version
|
15
|
+
|
16
|
+
__version__ = get_version()
|
@@ -47,6 +47,7 @@ pub trait ModelProvider: Send + Sync {
|
|
47
47
|
async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError>;
|
48
48
|
async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
|
49
49
|
async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError>;
|
50
|
+
async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError>;
|
50
51
|
async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
|
51
52
|
async fn take_action(
|
52
53
|
&self,
|
@@ -188,6 +189,15 @@ impl ModelRunner {
|
|
188
189
|
.into());
|
189
190
|
}
|
190
191
|
}
|
192
|
+
"time" => {
|
193
|
+
if *dims != vec![1] {
|
194
|
+
return Err(format!(
|
195
|
+
"Expected shape [1] for input `{}`, got {:?}",
|
196
|
+
input.name, dims
|
197
|
+
)
|
198
|
+
.into());
|
199
|
+
}
|
200
|
+
}
|
191
201
|
"carry" => {
|
192
202
|
if dims != carry_shape {
|
193
203
|
return Err(format!(
|
@@ -269,6 +279,7 @@ impl ModelRunner {
|
|
269
279
|
"gyroscope" => futures.push(self.provider.get_gyroscope()),
|
270
280
|
"command" => futures.push(self.provider.get_command()),
|
271
281
|
"carry" => futures.push(self.provider.get_carry(carry.clone())),
|
282
|
+
"time" => futures.push(self.provider.get_time()),
|
272
283
|
_ => return Err(format!("Unknown input name: {}", name).into()),
|
273
284
|
}
|
274
285
|
}
|
@@ -3,8 +3,8 @@ use std::sync::Arc;
|
|
3
3
|
use tokio::runtime::Runtime;
|
4
4
|
|
5
5
|
use crate::model::{ModelError, ModelRunner};
|
6
|
-
use std::time::
|
7
|
-
use tokio::time::
|
6
|
+
use std::time::Duration;
|
7
|
+
use tokio::time::interval;
|
8
8
|
|
9
9
|
pub struct ModelRuntime {
|
10
10
|
model_runner: Arc<ModelRunner>,
|
@@ -61,7 +61,10 @@ impl ModelRuntime {
|
|
61
61
|
.get_joint_angles()
|
62
62
|
.await
|
63
63
|
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
64
|
-
|
64
|
+
|
65
|
+
// Wait for the first tick, since it happens immediately.
|
66
|
+
let mut interval = interval(dt);
|
67
|
+
interval.tick().await;
|
65
68
|
|
66
69
|
while running.load(Ordering::Relaxed) {
|
67
70
|
let (output, next_carry) = model_runner
|
@@ -80,10 +83,7 @@ impl ModelRuntime {
|
|
80
83
|
.take_action(interp_joint_positions * magnitude_factor)
|
81
84
|
.await
|
82
85
|
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
83
|
-
|
84
|
-
if let Some(sleep_duration) = last_time.checked_duration_since(Instant::now()) {
|
85
|
-
sleep(sleep_duration).await;
|
86
|
-
}
|
86
|
+
interval.tick().await;
|
87
87
|
}
|
88
88
|
|
89
89
|
joint_positions = output;
|
@@ -10,7 +10,7 @@ use pyo3_stub_gen::define_stub_info_gatherer;
|
|
10
10
|
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
|
11
11
|
use std::sync::Arc;
|
12
12
|
use std::sync::Mutex;
|
13
|
-
|
13
|
+
use std::time::Instant;
|
14
14
|
#[pyfunction]
|
15
15
|
#[gen_stub_pyfunction]
|
16
16
|
fn get_version() -> String {
|
@@ -73,6 +73,10 @@ impl ModelProviderABC {
|
|
73
73
|
Err(PyNotImplementedError::new_err("Must override get_command"))
|
74
74
|
}
|
75
75
|
|
76
|
+
fn get_time<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
|
77
|
+
Err(PyNotImplementedError::new_err("Must override get_time"))
|
78
|
+
}
|
79
|
+
|
76
80
|
fn take_action<'py>(
|
77
81
|
&self,
|
78
82
|
joint_names: Vec<String>,
|
@@ -92,13 +96,17 @@ impl ModelProviderABC {
|
|
92
96
|
#[derive(Clone)]
|
93
97
|
struct PyModelProvider {
|
94
98
|
obj: Arc<Py<ModelProviderABC>>,
|
99
|
+
start_time: Instant,
|
95
100
|
}
|
96
101
|
|
97
102
|
#[pymethods]
|
98
103
|
impl PyModelProvider {
|
99
104
|
#[new]
|
100
105
|
fn new(obj: Py<ModelProviderABC>) -> Self {
|
101
|
-
Self {
|
106
|
+
Self {
|
107
|
+
obj: Arc::new(obj),
|
108
|
+
start_time: Instant::now(),
|
109
|
+
}
|
102
110
|
}
|
103
111
|
}
|
104
112
|
|
@@ -182,6 +190,18 @@ impl ModelProvider for PyModelProvider {
|
|
182
190
|
Ok(args)
|
183
191
|
}
|
184
192
|
|
193
|
+
async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
194
|
+
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
195
|
+
let obj = self.obj.clone();
|
196
|
+
let args = ();
|
197
|
+
let result = obj.call_method(py, "get_time", args, None)?;
|
198
|
+
let array = result.extract::<Vec<f32>>(py)?;
|
199
|
+
Ok(Array::from_vec(array).into_dyn())
|
200
|
+
})
|
201
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
202
|
+
Ok(args)
|
203
|
+
}
|
204
|
+
|
185
205
|
async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
|
186
206
|
Ok(carry)
|
187
207
|
}
|
@@ -217,9 +237,7 @@ struct PyModelRunner {
|
|
217
237
|
impl PyModelRunner {
|
218
238
|
#[new]
|
219
239
|
fn new(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
|
220
|
-
let input_provider = Arc::new(PyModelProvider
|
221
|
-
obj: Arc::new(provider),
|
222
|
-
});
|
240
|
+
let input_provider = Arc::new(PyModelProvider::new(provider));
|
223
241
|
|
224
242
|
let runner = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
225
243
|
ModelRunner::new(model_path, input_provider)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: kinfer
|
3
|
-
Version: 0.4.
|
3
|
+
Version: 0.4.3
|
4
4
|
Summary: Tool to make it easier to run a model on a real robot
|
5
5
|
Home-page: https://github.com/kscalelabs/kinfer.git
|
6
6
|
Author: K-Scale Labs
|
@@ -21,7 +21,7 @@ Provides-Extra: pytorch
|
|
21
21
|
Requires-Dist: torch; extra == "pytorch"
|
22
22
|
Provides-Extra: jax
|
23
23
|
Requires-Dist: tensorflow; extra == "jax"
|
24
|
-
Requires-Dist: tf2onnx; extra == "jax"
|
24
|
+
Requires-Dist: tf2onnx>=1.16.0; extra == "jax"
|
25
25
|
Requires-Dist: jax; extra == "jax"
|
26
26
|
Requires-Dist: equinox; extra == "jax"
|
27
27
|
Requires-Dist: numpy<2; extra == "jax"
|
@@ -34,7 +34,7 @@ Requires-Dist: ruff; extra == "all"
|
|
34
34
|
Requires-Dist: types-tensorflow; extra == "all"
|
35
35
|
Requires-Dist: torch; extra == "all"
|
36
36
|
Requires-Dist: tensorflow; extra == "all"
|
37
|
-
Requires-Dist: tf2onnx; extra == "all"
|
37
|
+
Requires-Dist: tf2onnx>=1.16.0; extra == "all"
|
38
38
|
Requires-Dist: jax; extra == "all"
|
39
39
|
Requires-Dist: equinox; extra == "all"
|
40
40
|
Requires-Dist: numpy<2; extra == "all"
|
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|
18
18
|
JOINT_NAMES = ["left_arm", "right_arm", "left_leg", "right_leg"]
|
19
19
|
NUM_JOINTS = len(JOINT_NAMES)
|
20
20
|
CARRY_SIZE = 10
|
21
|
+
NUM_COMMANDS = 4
|
21
22
|
|
22
23
|
|
23
24
|
@jax.jit
|
@@ -32,6 +33,8 @@ def step_fn(
|
|
32
33
|
projected_gravity: jnp.ndarray,
|
33
34
|
accelerometer: jnp.ndarray,
|
34
35
|
gyroscope: jnp.ndarray,
|
36
|
+
command: jnp.ndarray,
|
37
|
+
time: jnp.ndarray,
|
35
38
|
carry: jnp.ndarray,
|
36
39
|
) -> tuple[jnp.ndarray, jnp.ndarray]:
|
37
40
|
output = (
|
@@ -40,6 +43,9 @@ def step_fn(
|
|
40
43
|
+ projected_gravity.mean()
|
41
44
|
+ accelerometer.mean()
|
42
45
|
+ gyroscope.mean()
|
46
|
+
+ command.mean()
|
47
|
+
+ jnp.cos(time).mean()
|
48
|
+
+ jnp.sin(time).mean()
|
43
49
|
+ carry.mean()
|
44
50
|
) * joint_angles
|
45
51
|
next_carry = carry + 1
|
@@ -54,6 +60,7 @@ def test_export(tmpdir: Path) -> None:
|
|
54
60
|
step_fn_onnx = export_fn(
|
55
61
|
model=step_fn,
|
56
62
|
num_joints=NUM_JOINTS,
|
63
|
+
num_commands=NUM_COMMANDS,
|
57
64
|
carry_shape=(CARRY_SIZE,),
|
58
65
|
)
|
59
66
|
|
@@ -61,6 +68,7 @@ def test_export(tmpdir: Path) -> None:
|
|
61
68
|
init_fn_onnx,
|
62
69
|
step_fn_onnx,
|
63
70
|
joint_names=JOINT_NAMES,
|
71
|
+
num_commands=NUM_COMMANDS,
|
64
72
|
carry_shape=(CARRY_SIZE,),
|
65
73
|
)
|
66
74
|
|
@@ -88,6 +96,12 @@ def test_export(tmpdir: Path) -> None:
|
|
88
96
|
def get_gyroscope(self) -> np.ndarray:
|
89
97
|
return np.random.randn(3)
|
90
98
|
|
99
|
+
def get_command(self) -> np.ndarray:
|
100
|
+
return np.random.randn(NUM_COMMANDS)
|
101
|
+
|
102
|
+
def get_time(self) -> np.ndarray:
|
103
|
+
return np.random.randn(1)
|
104
|
+
|
91
105
|
def take_action(self, joint_names: Sequence[str], action: np.ndarray) -> None:
|
92
106
|
assert joint_names == JOINT_NAMES
|
93
107
|
assert action.shape == (NUM_JOINTS,)
|
@@ -38,6 +38,7 @@ def step_fn(
|
|
38
38
|
accelerometer: Tensor,
|
39
39
|
gyroscope: Tensor,
|
40
40
|
command: Tensor,
|
41
|
+
time: Tensor,
|
41
42
|
carry: Tensor,
|
42
43
|
) -> tuple[Tensor, Tensor]:
|
43
44
|
output = (
|
@@ -47,6 +48,8 @@ def step_fn(
|
|
47
48
|
+ accelerometer.mean()
|
48
49
|
+ gyroscope.mean()
|
49
50
|
+ command.mean()
|
51
|
+
+ torch.cos(time).mean()
|
52
|
+
+ torch.sin(time).mean()
|
50
53
|
+ carry.mean()
|
51
54
|
) * joint_angles
|
52
55
|
next_carry = carry + 1
|
@@ -120,6 +123,9 @@ def test_export(tmpdir: Path) -> None:
|
|
120
123
|
def get_command(self) -> np.ndarray:
|
121
124
|
return np.random.randn(NUM_COMMANDS)
|
122
125
|
|
126
|
+
def get_time(self) -> np.ndarray:
|
127
|
+
return np.random.randn(1)
|
128
|
+
|
123
129
|
def take_action(self, joint_names: Sequence[str], action: np.ndarray) -> None:
|
124
130
|
assert joint_names == JOINT_NAMES
|
125
131
|
assert action.shape == (NUM_JOINTS,)
|
kinfer-0.4.1/kinfer/__init__.py
DELETED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|