kinfer 0.3.2__cp312-cp312-macosx_11_0_arm64.whl → 0.4.0__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +0 -1
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +11 -0
- kinfer/export/__init__.py +0 -1
- kinfer/export/common.py +35 -0
- kinfer/export/jax.py +51 -0
- kinfer/export/pytorch.py +42 -110
- kinfer/export/serialize.py +86 -0
- kinfer/requirements.txt +3 -4
- kinfer/rust/Cargo.toml +8 -6
- kinfer/rust/src/lib.rs +2 -11
- kinfer/rust/src/model.rs +271 -121
- kinfer/rust/src/runtime.rs +104 -0
- kinfer/rust_bindings/Cargo.toml +8 -1
- kinfer/rust_bindings/rust_bindings.pyi +35 -0
- kinfer/rust_bindings/src/lib.rs +310 -1
- kinfer/rust_bindings.cpython-312-darwin.so +0 -0
- kinfer/rust_bindings.pyi +29 -1
- kinfer-0.4.0.dist-info/METADATA +55 -0
- kinfer-0.4.0.dist-info/RECORD +26 -0
- {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
- kinfer/inference/__init__.py +0 -1
- kinfer/inference/python.py +0 -92
- kinfer/proto/__init__.py +0 -40
- kinfer/proto/kinfer_pb2.py +0 -103
- kinfer/proto/kinfer_pb2.pyi +0 -1097
- kinfer/requirements-dev.txt +0 -8
- kinfer/rust/build.rs +0 -16
- kinfer/rust/src/kinfer_proto.rs +0 -14
- kinfer/rust/src/main.rs +0 -6
- kinfer/rust/src/onnx_serializer.rs +0 -804
- kinfer/rust/src/serializer.rs +0 -221
- kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
- kinfer/serialize/__init__.py +0 -36
- kinfer/serialize/base.py +0 -536
- kinfer/serialize/json.py +0 -399
- kinfer/serialize/numpy.py +0 -426
- kinfer/serialize/pytorch.py +0 -402
- kinfer/serialize/schema.py +0 -125
- kinfer/serialize/types.py +0 -17
- kinfer/serialize/utils.py +0 -177
- kinfer-0.3.2.dist-info/METADATA +0 -57
- kinfer-0.3.2.dist-info/RECORD +0 -39
- {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
- {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
kinfer/rust_bindings/src/lib.rs
CHANGED
@@ -1,6 +1,15 @@
|
|
1
|
+
use async_trait::async_trait;
|
2
|
+
use kinfer::model::{ModelError, ModelProvider, ModelRunner};
|
3
|
+
use kinfer::runtime::ModelRuntime;
|
4
|
+
use ndarray::{Array, Ix1, IxDyn};
|
5
|
+
use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
|
6
|
+
use pyo3::exceptions::PyNotImplementedError;
|
1
7
|
use pyo3::prelude::*;
|
8
|
+
use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
|
2
9
|
use pyo3_stub_gen::define_stub_info_gatherer;
|
3
|
-
use pyo3_stub_gen::derive::gen_stub_pyfunction;
|
10
|
+
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
|
11
|
+
use std::sync::Arc;
|
12
|
+
use std::sync::Mutex;
|
4
13
|
|
5
14
|
#[pyfunction]
|
6
15
|
#[gen_stub_pyfunction]
|
@@ -8,9 +17,309 @@ fn get_version() -> String {
|
|
8
17
|
env!("CARGO_PKG_VERSION").to_string()
|
9
18
|
}
|
10
19
|
|
20
|
+
#[pyclass(subclass)]
|
21
|
+
#[gen_stub_pyclass]
|
22
|
+
pub struct ModelProviderABC;
|
23
|
+
|
24
|
+
#[gen_stub_pymethods]
|
25
|
+
#[pymethods]
|
26
|
+
impl ModelProviderABC {
|
27
|
+
#[new]
|
28
|
+
fn new() -> Self {
|
29
|
+
ModelProviderABC
|
30
|
+
}
|
31
|
+
|
32
|
+
fn get_joint_angles<'py>(
|
33
|
+
&self,
|
34
|
+
joint_names: Vec<String>,
|
35
|
+
) -> PyResult<Bound<'py, PyArray1<f32>>> {
|
36
|
+
let n = joint_names.len();
|
37
|
+
Err(PyNotImplementedError::new_err(format!(
|
38
|
+
"Must override get_joint_angles with {} joint names",
|
39
|
+
n
|
40
|
+
)))
|
41
|
+
}
|
42
|
+
|
43
|
+
fn get_joint_angular_velocities<'py>(
|
44
|
+
&self,
|
45
|
+
joint_names: Vec<String>,
|
46
|
+
) -> PyResult<Bound<'py, PyArray1<f32>>> {
|
47
|
+
let n = joint_names.len();
|
48
|
+
Err(PyNotImplementedError::new_err(format!(
|
49
|
+
"Must override get_joint_angular_velocities with {} joint names",
|
50
|
+
n
|
51
|
+
)))
|
52
|
+
}
|
53
|
+
|
54
|
+
fn get_projected_gravity<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
|
55
|
+
Err(PyNotImplementedError::new_err(
|
56
|
+
"Must override get_projected_gravity",
|
57
|
+
))
|
58
|
+
}
|
59
|
+
|
60
|
+
fn get_accelerometer<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
|
61
|
+
Err(PyNotImplementedError::new_err(
|
62
|
+
"Must override get_accelerometer",
|
63
|
+
))
|
64
|
+
}
|
65
|
+
|
66
|
+
fn get_gyroscope<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
|
67
|
+
Err(PyNotImplementedError::new_err(
|
68
|
+
"Must override get_gyroscope",
|
69
|
+
))
|
70
|
+
}
|
71
|
+
|
72
|
+
fn take_action<'py>(
|
73
|
+
&self,
|
74
|
+
joint_names: Vec<String>,
|
75
|
+
action: Bound<'py, PyArray1<f32>>,
|
76
|
+
) -> PyResult<()> {
|
77
|
+
let n = action.len()?;
|
78
|
+
assert_eq!(joint_names.len(), n);
|
79
|
+
Err(PyNotImplementedError::new_err(format!(
|
80
|
+
"Must override take_action with {} action",
|
81
|
+
n
|
82
|
+
)))
|
83
|
+
}
|
84
|
+
}
|
85
|
+
|
86
|
+
#[gen_stub_pyclass]
|
87
|
+
#[pyclass]
|
88
|
+
#[derive(Clone)]
|
89
|
+
struct PyModelProvider {
|
90
|
+
obj: Arc<Py<ModelProviderABC>>,
|
91
|
+
}
|
92
|
+
|
93
|
+
#[pymethods]
|
94
|
+
impl PyModelProvider {
|
95
|
+
#[new]
|
96
|
+
fn new(obj: Py<ModelProviderABC>) -> Self {
|
97
|
+
Self { obj: Arc::new(obj) }
|
98
|
+
}
|
99
|
+
}
|
100
|
+
|
101
|
+
#[async_trait]
|
102
|
+
impl ModelProvider for PyModelProvider {
|
103
|
+
async fn get_joint_angles(
|
104
|
+
&self,
|
105
|
+
joint_names: &[String],
|
106
|
+
) -> Result<Array<f32, IxDyn>, ModelError> {
|
107
|
+
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
108
|
+
let obj = self.obj.clone();
|
109
|
+
let args = (joint_names,);
|
110
|
+
let result = obj.call_method(py, "get_joint_angles", args, None)?;
|
111
|
+
let array = result.extract::<Vec<f32>>(py)?;
|
112
|
+
Ok(Array::from_vec(array).into_dyn())
|
113
|
+
})
|
114
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
115
|
+
Ok(args)
|
116
|
+
}
|
117
|
+
|
118
|
+
async fn get_joint_angular_velocities(
|
119
|
+
&self,
|
120
|
+
joint_names: &[String],
|
121
|
+
) -> Result<Array<f32, IxDyn>, ModelError> {
|
122
|
+
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
123
|
+
let obj = self.obj.clone();
|
124
|
+
let args = (joint_names,);
|
125
|
+
let result = obj.call_method(py, "get_joint_angular_velocities", args, None)?;
|
126
|
+
let array = result.extract::<Vec<f32>>(py)?;
|
127
|
+
Ok(Array::from_vec(array).into_dyn())
|
128
|
+
})
|
129
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
130
|
+
Ok(args)
|
131
|
+
}
|
132
|
+
|
133
|
+
async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
134
|
+
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
135
|
+
let obj = self.obj.clone();
|
136
|
+
let args = ();
|
137
|
+
let result = obj.call_method(py, "get_projected_gravity", args, None)?;
|
138
|
+
let array = result.extract::<Vec<f32>>(py)?;
|
139
|
+
Ok(Array::from_vec(array).into_dyn())
|
140
|
+
})
|
141
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
142
|
+
Ok(args)
|
143
|
+
}
|
144
|
+
|
145
|
+
async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
146
|
+
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
147
|
+
let obj = self.obj.clone();
|
148
|
+
let args = ();
|
149
|
+
let result = obj.call_method(py, "get_accelerometer", args, None)?;
|
150
|
+
let array = result.extract::<Vec<f32>>(py)?;
|
151
|
+
Ok(Array::from_vec(array).into_dyn())
|
152
|
+
})
|
153
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
154
|
+
Ok(args)
|
155
|
+
}
|
156
|
+
|
157
|
+
async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
158
|
+
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
159
|
+
let obj = self.obj.clone();
|
160
|
+
let args = ();
|
161
|
+
let result = obj.call_method(py, "get_gyroscope", args, None)?;
|
162
|
+
let array = result.extract::<Vec<f32>>(py)?;
|
163
|
+
Ok(Array::from_vec(array).into_dyn())
|
164
|
+
})
|
165
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
166
|
+
Ok(args)
|
167
|
+
}
|
168
|
+
|
169
|
+
async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
|
170
|
+
Ok(carry)
|
171
|
+
}
|
172
|
+
|
173
|
+
async fn take_action(
|
174
|
+
&self,
|
175
|
+
joint_names: Vec<String>,
|
176
|
+
action: Array<f32, IxDyn>,
|
177
|
+
) -> Result<(), ModelError> {
|
178
|
+
Python::with_gil(|py| -> PyResult<()> {
|
179
|
+
let obj = self.obj.clone();
|
180
|
+
let action_1d = action
|
181
|
+
.into_dimensionality::<Ix1>()
|
182
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
183
|
+
let args = (joint_names, PyArray1::from_array(py, &action_1d));
|
184
|
+
obj.call_method(py, "take_action", args, None)?;
|
185
|
+
Ok(())
|
186
|
+
})
|
187
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
188
|
+
Ok(())
|
189
|
+
}
|
190
|
+
}
|
191
|
+
|
192
|
+
#[gen_stub_pyclass]
|
193
|
+
#[pyclass]
|
194
|
+
#[derive(Clone)]
|
195
|
+
struct PyModelRunner {
|
196
|
+
runner: Arc<ModelRunner>,
|
197
|
+
}
|
198
|
+
|
199
|
+
#[gen_stub_pymethods]
|
200
|
+
#[pymethods]
|
201
|
+
impl PyModelRunner {
|
202
|
+
#[new]
|
203
|
+
fn new(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
|
204
|
+
let input_provider = Arc::new(PyModelProvider {
|
205
|
+
obj: Arc::new(provider),
|
206
|
+
});
|
207
|
+
|
208
|
+
let runner = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
209
|
+
ModelRunner::new(model_path, input_provider)
|
210
|
+
.await
|
211
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
212
|
+
})?;
|
213
|
+
|
214
|
+
Ok(Self {
|
215
|
+
runner: Arc::new(runner),
|
216
|
+
})
|
217
|
+
}
|
218
|
+
|
219
|
+
fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
|
220
|
+
let runner = self.runner.clone();
|
221
|
+
let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
222
|
+
runner
|
223
|
+
.init()
|
224
|
+
.await
|
225
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
226
|
+
})?;
|
227
|
+
|
228
|
+
Python::with_gil(|py| {
|
229
|
+
let array = numpy::PyArray::from_array(py, &result);
|
230
|
+
Ok(array.into())
|
231
|
+
})
|
232
|
+
}
|
233
|
+
|
234
|
+
fn step(
|
235
|
+
&self,
|
236
|
+
carry: Py<PyArrayDyn<f32>>,
|
237
|
+
) -> PyResult<(Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>)> {
|
238
|
+
let runner = self.runner.clone();
|
239
|
+
let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
240
|
+
let carry_array = carry.bind(py);
|
241
|
+
Ok(carry_array.to_owned_array())
|
242
|
+
})?;
|
243
|
+
|
244
|
+
let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
|
245
|
+
runner
|
246
|
+
.step(carry_array)
|
247
|
+
.await
|
248
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
249
|
+
})?;
|
250
|
+
|
251
|
+
Python::with_gil(|py| {
|
252
|
+
let (output, carry) = result;
|
253
|
+
let output_array = numpy::PyArray::from_array(py, &output);
|
254
|
+
let carry_array = numpy::PyArray::from_array(py, &carry);
|
255
|
+
Ok((output_array.into(), carry_array.into()))
|
256
|
+
})
|
257
|
+
}
|
258
|
+
|
259
|
+
fn take_action(&self, action: Py<PyArrayDyn<f32>>) -> PyResult<()> {
|
260
|
+
let runner = self.runner.clone();
|
261
|
+
let action_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
262
|
+
let action_array = action.bind(py);
|
263
|
+
Ok(action_array.to_owned_array())
|
264
|
+
})?;
|
265
|
+
|
266
|
+
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
267
|
+
runner
|
268
|
+
.take_action(action_array)
|
269
|
+
.await
|
270
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
271
|
+
})?;
|
272
|
+
|
273
|
+
Ok(())
|
274
|
+
}
|
275
|
+
}
|
276
|
+
|
277
|
+
#[gen_stub_pyclass]
|
278
|
+
#[pyclass]
|
279
|
+
#[derive(Clone)]
|
280
|
+
struct PyModelRuntime {
|
281
|
+
runtime: Arc<Mutex<ModelRuntime>>,
|
282
|
+
}
|
283
|
+
|
284
|
+
#[gen_stub_pymethods]
|
285
|
+
#[pymethods]
|
286
|
+
impl PyModelRuntime {
|
287
|
+
#[new]
|
288
|
+
fn new(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
|
289
|
+
Ok(Self {
|
290
|
+
runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
|
291
|
+
})
|
292
|
+
}
|
293
|
+
|
294
|
+
fn set_slowdown_factor(&self, slowdown_factor: i32) {
|
295
|
+
let mut runtime = self.runtime.lock().unwrap();
|
296
|
+
runtime.set_slowdown_factor(slowdown_factor);
|
297
|
+
}
|
298
|
+
|
299
|
+
fn set_magnitude_factor(&self, magnitude_factor: f32) {
|
300
|
+
let mut runtime = self.runtime.lock().unwrap();
|
301
|
+
runtime.set_magnitude_factor(magnitude_factor);
|
302
|
+
}
|
303
|
+
|
304
|
+
fn start(&self) -> PyResult<()> {
|
305
|
+
let mut runtime = self.runtime.lock().unwrap();
|
306
|
+
runtime
|
307
|
+
.start()
|
308
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
309
|
+
}
|
310
|
+
|
311
|
+
fn stop(&self) {
|
312
|
+
let mut runtime = self.runtime.lock().unwrap();
|
313
|
+
runtime.stop();
|
314
|
+
}
|
315
|
+
}
|
316
|
+
|
11
317
|
#[pymodule]
|
12
318
|
fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
|
13
319
|
m.add_function(wrap_pyfunction!(get_version, m)?)?;
|
320
|
+
m.add_class::<ModelProviderABC>()?;
|
321
|
+
m.add_class::<PyModelRunner>()?;
|
322
|
+
m.add_class::<PyModelRuntime>()?;
|
14
323
|
Ok(())
|
15
324
|
}
|
16
325
|
|
Binary file
|
kinfer/rust_bindings.pyi
CHANGED
@@ -1,7 +1,35 @@
|
|
1
1
|
# This file is automatically generated by pyo3_stub_gen
|
2
2
|
# ruff: noqa: E501, F401
|
3
3
|
|
4
|
+
import builtins
|
5
|
+
import numpy
|
6
|
+
import numpy.typing
|
7
|
+
import typing
|
4
8
|
|
5
|
-
|
9
|
+
class ModelProviderABC:
    # Base class for Python-side providers. Subclass it and override every method:
    # the base implementations only raise NotImplementedError.
    def __new__(cls) -> ModelProviderABC: ...
    def get_joint_angles(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
    def get_joint_angular_velocities(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
    def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
    def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
    def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
    # `action` is expected to hold one value per entry in `joint_names`.
    def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...

class PyModelProvider:
    # Rust-side adapter wrapping a ModelProviderABC instance.
    ...

class PyModelRunner:
    # Loads the model at `model_path`; inputs come from, and actions go to, `provider`.
    def __new__(cls, model_path:builtins.str, provider:ModelProviderABC) -> PyModelRunner: ...
    def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
    # Returns (output, carry); presumably the returned carry feeds the next step call.
    def step(self, carry:numpy.typing.NDArray[numpy.float32]) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
    def take_action(self, action:numpy.typing.NDArray[numpy.float32]) -> None: ...

class PyModelRuntime:
    # `dt` is passed unchanged to the Rust ModelRuntime; its unit is defined there.
    def __new__(cls, model_runner:PyModelRunner, dt:builtins.int) -> PyModelRuntime: ...
    def set_slowdown_factor(self, slowdown_factor:builtins.int) -> None: ...
    def set_magnitude_factor(self, magnitude_factor:builtins.float) -> None: ...
    def start(self) -> None: ...
    def stop(self) -> None: ...

# Version of the compiled Rust extension crate.
def get_version() -> builtins.str: ...
|
35
|
+
|
@@ -0,0 +1,55 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: kinfer
|
3
|
+
Version: 0.4.0
|
4
|
+
Summary: Tool to make it easier to run a model on a real robot
|
5
|
+
Home-page: https://github.com/kscalelabs/kinfer.git
|
6
|
+
Author: K-Scale Labs
|
7
|
+
Requires-Python: >=3.11
|
8
|
+
Description-Content-Type: text/markdown
|
9
|
+
License-File: LICENSE
|
10
|
+
Requires-Dist: onnx
|
11
|
+
Requires-Dist: onnxruntime==1.20.0
|
12
|
+
Requires-Dist: pydantic
|
13
|
+
Provides-Extra: dev
|
14
|
+
Requires-Dist: black; extra == "dev"
|
15
|
+
Requires-Dist: darglint; extra == "dev"
|
16
|
+
Requires-Dist: mypy; extra == "dev"
|
17
|
+
Requires-Dist: pytest; extra == "dev"
|
18
|
+
Requires-Dist: ruff; extra == "dev"
|
19
|
+
Requires-Dist: types-tensorflow; extra == "dev"
|
20
|
+
Provides-Extra: pytorch
|
21
|
+
Requires-Dist: torch; extra == "pytorch"
|
22
|
+
Provides-Extra: jax
|
23
|
+
Requires-Dist: tensorflow; extra == "jax"
|
24
|
+
Requires-Dist: tf2onnx; extra == "jax"
|
25
|
+
Requires-Dist: jax; extra == "jax"
|
26
|
+
Requires-Dist: equinox; extra == "jax"
|
27
|
+
Requires-Dist: numpy<2; extra == "jax"
|
28
|
+
Provides-Extra: all
|
29
|
+
Requires-Dist: black; extra == "all"
|
30
|
+
Requires-Dist: darglint; extra == "all"
|
31
|
+
Requires-Dist: mypy; extra == "all"
|
32
|
+
Requires-Dist: pytest; extra == "all"
|
33
|
+
Requires-Dist: ruff; extra == "all"
|
34
|
+
Requires-Dist: types-tensorflow; extra == "all"
|
35
|
+
Requires-Dist: torch; extra == "all"
|
36
|
+
Requires-Dist: tensorflow; extra == "all"
|
37
|
+
Requires-Dist: tf2onnx; extra == "all"
|
38
|
+
Requires-Dist: jax; extra == "all"
|
39
|
+
Requires-Dist: equinox; extra == "all"
|
40
|
+
Requires-Dist: numpy<2; extra == "all"
|
41
|
+
Dynamic: author
|
42
|
+
Dynamic: description
|
43
|
+
Dynamic: description-content-type
|
44
|
+
Dynamic: home-page
|
45
|
+
Dynamic: license-file
|
46
|
+
Dynamic: provides-extra
|
47
|
+
Dynamic: requires-dist
|
48
|
+
Dynamic: requires-python
|
49
|
+
Dynamic: summary
|
50
|
+
|
51
|
+
# kinfer
|
52
|
+
|
53
|
+
This package is designed to support running real-time robotics models.
|
54
|
+
|
55
|
+
For more information, see the documentation [here](https://docs.kscale.dev/docs/k-infer).
|
@@ -0,0 +1,26 @@
|
|
1
|
+
kinfer/requirements.txt,sha256=j08HO4ptA5afuj99j8FlAP2qla5Zf4_OiEBtgAmF7Jg,90
|
2
|
+
kinfer/__init__.py,sha256=YbtJIepEE4pbjYbdgFCV-rBP90AnQpBfDaprkflBmEE,99
|
3
|
+
kinfer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
kinfer/rust_bindings.pyi,sha256=oQKlW_bRw_c7fgVMbqiscPJ57pli25DkTCA_odsMOSU,1639
|
5
|
+
kinfer/rust_bindings.cpython-312-darwin.so,sha256=2qlbCxfS5Dsxdw077wpnPjt0WQhjm5CRtcA4S7aTG8A,1912656
|
6
|
+
kinfer/rust/Cargo.toml,sha256=0SLLbjtoODLrSMcyPdg22Ora-JNdaCqbMKb1c5OQblU,404
|
7
|
+
kinfer/rust/src/runtime.rs,sha256=eDflFSIa4IxOmG3SInwrHrYc02UzEANZMkK-liA7pVA,3407
|
8
|
+
kinfer/rust/src/lib.rs,sha256=Z3dWdhKhqhWqPJee6vWde4ptqquOYW6W9wB0wyaKCyk,71
|
9
|
+
kinfer/rust/src/model.rs,sha256=pODQzD-roXpCJ7-kzaE6uULgqHgsHKb3oQjv-uKtCYk,10668
|
10
|
+
kinfer/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
+
kinfer/common/types.py,sha256=Yf3c8Tui7vWguXMmxxa_QynwydCPu9RxrJS8JJMDuKg,147
|
12
|
+
kinfer/export/serialize.py,sha256=pOxXdcWMoaNfnAUhgzHUTb7GtZvn0MYtKskIarHzAiY,3007
|
13
|
+
kinfer/export/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
+
kinfer/export/jax.py,sha256=nbZB7Bxed4MQmfBZrH_Ml6PnkHQoyq5jdOT01DxiZDA,1337
|
15
|
+
kinfer/export/pytorch.py,sha256=twGpcX6OaJgiiwQ0teo_4rrCFTrCqu-ULH2UEHl8auc,1477
|
16
|
+
kinfer/export/common.py,sha256=ZiOAtehb2_ATg7t3tWzxbHUkGhtJ8vqjlaaw94-Iux4,1021
|
17
|
+
kinfer/rust_bindings/Cargo.toml,sha256=i1RGB9VNd9Q4FJ6gGwjZJQYo8DBBvpVWf3GJ95EfVgM,637
|
18
|
+
kinfer/rust_bindings/pyproject.toml,sha256=jLcJuHCnQRh9HWR_R7a9qLHwj6LMBgnHyeKK_DruO1Y,135
|
19
|
+
kinfer/rust_bindings/rust_bindings.pyi,sha256=oQKlW_bRw_c7fgVMbqiscPJ57pli25DkTCA_odsMOSU,1639
|
20
|
+
kinfer/rust_bindings/src/lib.rs,sha256=lKYr6Imu0Z-XxDKCguJLfrhIuai6S2bEhRAe__9sFnc,10196
|
21
|
+
kinfer/rust_bindings/src/bin/stub_gen.rs,sha256=hhoVGnaSfazbSfj5a4x6mPicGPOgWQAfsDmiPej0B6Y,133
|
22
|
+
kinfer-0.4.0.dist-info/RECORD,,
|
23
|
+
kinfer-0.4.0.dist-info/WHEEL,sha256=mP9bWt4ASeNWfyg7GBBbGbsOVFgblaN5WklJcvrSjIE,136
|
24
|
+
kinfer-0.4.0.dist-info/top_level.txt,sha256=6mY_t3PYr3Dm0dpqMk80uSnArbvGfCFkxOh1QWtgDEo,7
|
25
|
+
kinfer-0.4.0.dist-info/METADATA,sha256=2gGRmUwgmAoYK8kRG5qTtteIMrdaUYW7Edvpk_32yas,1745
|
26
|
+
kinfer-0.4.0.dist-info/licenses/LICENSE,sha256=Qw-Z0XTwS-diSW91e_jLeBPX9zZbAatOJTBLdPHPaC0,1069
|
kinfer/inference/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1
|
-
from .python import *
|
kinfer/inference/python.py
DELETED
@@ -1,92 +0,0 @@
|
|
1
|
-
"""ONNX model inference utilities for Python."""
|
2
|
-
|
3
|
-
import base64
|
4
|
-
from pathlib import Path
|
5
|
-
|
6
|
-
import onnx
|
7
|
-
import onnxruntime as ort
|
8
|
-
|
9
|
-
from kinfer import proto as P
|
10
|
-
from kinfer.export.pytorch import KINFER_METADATA_KEY
|
11
|
-
from kinfer.serialize.numpy import NumpyMultiSerializer
|
12
|
-
|
13
|
-
|
14
|
-
def _read_schema(model: onnx.ModelProto) -> P.ModelSchema:
    """Decode the kinfer ``ModelSchema`` stored in an ONNX model's metadata.

    The schema is read from the ``KINFER_METADATA_KEY`` metadata entry, whose value is
    base64-decoded and parsed as a serialized ``ModelSchema`` protobuf.

    Args:
        model: Loaded ONNX model whose ``metadata_props`` are searched.

    Returns:
        The parsed model schema.

    Raises:
        ValueError: If the kinfer entry cannot be decoded or parsed; if it is absent
            and some other metadata key is present; or if it is absent entirely.
    """
    unknown_keys: list[str] = []
    for prop in model.metadata_props:
        if prop.key != KINFER_METADATA_KEY:
            # Keep scanning: the kinfer entry may come after unrelated keys. Rejecting on
            # the first unknown key made the result depend on metadata order.
            unknown_keys.append(prop.key)
            continue
        try:
            schema_bytes = base64.b64decode(prop.value)
            schema = P.ModelSchema()
            schema.ParseFromString(schema_bytes)
            return schema
        except Exception as e:
            raise ValueError("Failed to parse kinfer_metadata value") from e

    if unknown_keys:
        raise ValueError(f"Found arbitrary metadata key {unknown_keys[0]}")
    raise ValueError(f"{KINFER_METADATA_KEY} not found in model metadata")
|
28
|
-
|
29
|
-
|
30
|
-
class ONNXModel:
    """Runs an ONNX model whose inputs and outputs are described by an embedded kinfer schema.

    The schema read from the model's metadata determines how ``P.IO`` messages are
    converted to and from the arrays the ONNX runtime session consumes and produces.
    """

    def __init__(self: "ONNXModel", model_path: str | Path) -> None:
        """Load the model file, open an inference session, and build the IO serializers.

        Args:
            model_path: Path to ONNX model file
        """
        self.model_path = model_path

        # The ONNX proto is loaded separately from the session because the kinfer
        # schema is read from its metadata_props.
        self.model = onnx.load(model_path)
        self.session = ort.InferenceSession(model_path)
        self._schema = _read_schema(self.model)

        input_schema = self._schema.input_schema
        output_schema = self._schema.output_schema
        self._input_serializer = NumpyMultiSerializer(input_schema)
        self._output_serializer = NumpyMultiSerializer(output_schema)

    def __call__(self: "ONNXModel", inputs: P.IO) -> P.IO:
        """Serialize ``inputs``, run the session, and deserialize what it returns.

        Args:
            inputs: Input data, matching the input schema.

        Returns:
            Model outputs, matching the output schema.
        """
        feeds = self._input_serializer.serialize_io(inputs, as_dict=True)
        # Passing None as output names requests every model output.
        raw_outputs = self.session.run(None, feeds)
        return self._output_serializer.deserialize_io(raw_outputs)

    @property
    def input_schema(self: "ONNXModel") -> P.IOSchema:
        """The schema describing this model's inputs."""
        return self._schema.input_schema

    @property
    def output_schema(self: "ONNXModel") -> P.IOSchema:
        """The schema describing this model's outputs."""
        return self._schema.output_schema

    @property
    def schema_input_keys(self: "ONNXModel") -> list[str]:
        """Collect the value names declared in the input schema.

        Returns:
            List of value names from input schema.
        """
        names: list[str] = []
        for entry in self._schema.input_schema.values:
            names.append(entry.value_name)
        return names

    @property
    def schema_output_keys(self: "ONNXModel") -> list[str]:
        """Collect the value names declared in the output schema.

        Returns:
            List of value names from output schema.
        """
        names: list[str] = []
        for entry in self._schema.output_schema.values:
            names.append(entry.value_name)
        return names
|
kinfer/proto/__init__.py
DELETED
@@ -1,40 +0,0 @@
|
|
1
|
-
"""Defines helper types for the protocol buffers."""
|
2
|
-
|
3
|
-
from .kinfer_pb2 import (
|
4
|
-
IO,
|
5
|
-
AudioFrameSchema,
|
6
|
-
AudioFrameValue,
|
7
|
-
CameraFrameSchema,
|
8
|
-
CameraFrameValue,
|
9
|
-
DType,
|
10
|
-
ImuAccelerometerValue,
|
11
|
-
ImuGyroscopeValue,
|
12
|
-
ImuMagnetometerValue,
|
13
|
-
ImuSchema,
|
14
|
-
ImuValue,
|
15
|
-
IOSchema,
|
16
|
-
JointCommandsSchema,
|
17
|
-
JointCommandsValue,
|
18
|
-
JointCommandValue,
|
19
|
-
JointPositionsSchema,
|
20
|
-
JointPositionsValue,
|
21
|
-
JointPositionUnit,
|
22
|
-
JointPositionValue,
|
23
|
-
JointTorquesSchema,
|
24
|
-
JointTorquesValue,
|
25
|
-
JointTorqueUnit,
|
26
|
-
JointTorqueValue,
|
27
|
-
JointVelocitiesSchema,
|
28
|
-
JointVelocitiesValue,
|
29
|
-
JointVelocityUnit,
|
30
|
-
JointVelocityValue,
|
31
|
-
ModelSchema,
|
32
|
-
StateTensorSchema,
|
33
|
-
StateTensorValue,
|
34
|
-
TimestampSchema,
|
35
|
-
TimestampValue,
|
36
|
-
Value,
|
37
|
-
ValueSchema,
|
38
|
-
VectorCommandSchema,
|
39
|
-
VectorCommandValue,
|
40
|
-
)
|