kinfer 0.3.2__cp312-cp312-macosx_11_0_arm64.whl → 0.4.0__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. kinfer/__init__.py +0 -1
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +11 -0
  4. kinfer/export/__init__.py +0 -1
  5. kinfer/export/common.py +35 -0
  6. kinfer/export/jax.py +51 -0
  7. kinfer/export/pytorch.py +42 -110
  8. kinfer/export/serialize.py +86 -0
  9. kinfer/requirements.txt +3 -4
  10. kinfer/rust/Cargo.toml +8 -6
  11. kinfer/rust/src/lib.rs +2 -11
  12. kinfer/rust/src/model.rs +271 -121
  13. kinfer/rust/src/runtime.rs +104 -0
  14. kinfer/rust_bindings/Cargo.toml +8 -1
  15. kinfer/rust_bindings/rust_bindings.pyi +35 -0
  16. kinfer/rust_bindings/src/lib.rs +310 -1
  17. kinfer/rust_bindings.cpython-312-darwin.so +0 -0
  18. kinfer/rust_bindings.pyi +29 -1
  19. kinfer-0.4.0.dist-info/METADATA +55 -0
  20. kinfer-0.4.0.dist-info/RECORD +26 -0
  21. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
  22. kinfer/inference/__init__.py +0 -1
  23. kinfer/inference/python.py +0 -92
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -36
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.2.dist-info/METADATA +0 -57
  43. kinfer-0.3.2.dist-info/RECORD +0 -39
  44. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,15 @@
1
+ use async_trait::async_trait;
2
+ use kinfer::model::{ModelError, ModelProvider, ModelRunner};
3
+ use kinfer::runtime::ModelRuntime;
4
+ use ndarray::{Array, Ix1, IxDyn};
5
+ use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
6
+ use pyo3::exceptions::PyNotImplementedError;
1
7
  use pyo3::prelude::*;
8
+ use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
2
9
  use pyo3_stub_gen::define_stub_info_gatherer;
3
- use pyo3_stub_gen::derive::gen_stub_pyfunction;
10
+ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
11
+ use std::sync::Arc;
12
+ use std::sync::Mutex;
4
13
 
5
14
  #[pyfunction]
6
15
  #[gen_stub_pyfunction]
@@ -8,9 +17,309 @@ fn get_version() -> String {
8
17
  env!("CARGO_PKG_VERSION").to_string()
9
18
  }
10
19
 
20
+ #[pyclass(subclass)]
21
+ #[gen_stub_pyclass]
22
+ pub struct ModelProviderABC;
23
+
24
+ #[gen_stub_pymethods]
25
+ #[pymethods]
26
+ impl ModelProviderABC {
27
+ #[new]
28
+ fn new() -> Self {
29
+ ModelProviderABC
30
+ }
31
+
32
+ fn get_joint_angles<'py>(
33
+ &self,
34
+ joint_names: Vec<String>,
35
+ ) -> PyResult<Bound<'py, PyArray1<f32>>> {
36
+ let n = joint_names.len();
37
+ Err(PyNotImplementedError::new_err(format!(
38
+ "Must override get_joint_angles with {} joint names",
39
+ n
40
+ )))
41
+ }
42
+
43
+ fn get_joint_angular_velocities<'py>(
44
+ &self,
45
+ joint_names: Vec<String>,
46
+ ) -> PyResult<Bound<'py, PyArray1<f32>>> {
47
+ let n = joint_names.len();
48
+ Err(PyNotImplementedError::new_err(format!(
49
+ "Must override get_joint_angular_velocities with {} joint names",
50
+ n
51
+ )))
52
+ }
53
+
54
+ fn get_projected_gravity<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
55
+ Err(PyNotImplementedError::new_err(
56
+ "Must override get_projected_gravity",
57
+ ))
58
+ }
59
+
60
+ fn get_accelerometer<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
61
+ Err(PyNotImplementedError::new_err(
62
+ "Must override get_accelerometer",
63
+ ))
64
+ }
65
+
66
+ fn get_gyroscope<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
67
+ Err(PyNotImplementedError::new_err(
68
+ "Must override get_gyroscope",
69
+ ))
70
+ }
71
+
72
+ fn take_action<'py>(
73
+ &self,
74
+ joint_names: Vec<String>,
75
+ action: Bound<'py, PyArray1<f32>>,
76
+ ) -> PyResult<()> {
77
+ let n = action.len()?;
78
+ assert_eq!(joint_names.len(), n);
79
+ Err(PyNotImplementedError::new_err(format!(
80
+ "Must override take_action with {} action",
81
+ n
82
+ )))
83
+ }
84
+ }
85
+
86
+ #[gen_stub_pyclass]
87
+ #[pyclass]
88
+ #[derive(Clone)]
89
+ struct PyModelProvider {
90
+ obj: Arc<Py<ModelProviderABC>>,
91
+ }
92
+
93
+ #[pymethods]
94
+ impl PyModelProvider {
95
+ #[new]
96
+ fn new(obj: Py<ModelProviderABC>) -> Self {
97
+ Self { obj: Arc::new(obj) }
98
+ }
99
+ }
100
+
101
+ #[async_trait]
102
+ impl ModelProvider for PyModelProvider {
103
+ async fn get_joint_angles(
104
+ &self,
105
+ joint_names: &[String],
106
+ ) -> Result<Array<f32, IxDyn>, ModelError> {
107
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
108
+ let obj = self.obj.clone();
109
+ let args = (joint_names,);
110
+ let result = obj.call_method(py, "get_joint_angles", args, None)?;
111
+ let array = result.extract::<Vec<f32>>(py)?;
112
+ Ok(Array::from_vec(array).into_dyn())
113
+ })
114
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
115
+ Ok(args)
116
+ }
117
+
118
+ async fn get_joint_angular_velocities(
119
+ &self,
120
+ joint_names: &[String],
121
+ ) -> Result<Array<f32, IxDyn>, ModelError> {
122
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
123
+ let obj = self.obj.clone();
124
+ let args = (joint_names,);
125
+ let result = obj.call_method(py, "get_joint_angular_velocities", args, None)?;
126
+ let array = result.extract::<Vec<f32>>(py)?;
127
+ Ok(Array::from_vec(array).into_dyn())
128
+ })
129
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
130
+ Ok(args)
131
+ }
132
+
133
+ async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError> {
134
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
135
+ let obj = self.obj.clone();
136
+ let args = ();
137
+ let result = obj.call_method(py, "get_projected_gravity", args, None)?;
138
+ let array = result.extract::<Vec<f32>>(py)?;
139
+ Ok(Array::from_vec(array).into_dyn())
140
+ })
141
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
142
+ Ok(args)
143
+ }
144
+
145
+ async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError> {
146
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
147
+ let obj = self.obj.clone();
148
+ let args = ();
149
+ let result = obj.call_method(py, "get_accelerometer", args, None)?;
150
+ let array = result.extract::<Vec<f32>>(py)?;
151
+ Ok(Array::from_vec(array).into_dyn())
152
+ })
153
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
154
+ Ok(args)
155
+ }
156
+
157
+ async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError> {
158
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
159
+ let obj = self.obj.clone();
160
+ let args = ();
161
+ let result = obj.call_method(py, "get_gyroscope", args, None)?;
162
+ let array = result.extract::<Vec<f32>>(py)?;
163
+ Ok(Array::from_vec(array).into_dyn())
164
+ })
165
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
166
+ Ok(args)
167
+ }
168
+
169
+ async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
170
+ Ok(carry)
171
+ }
172
+
173
+ async fn take_action(
174
+ &self,
175
+ joint_names: Vec<String>,
176
+ action: Array<f32, IxDyn>,
177
+ ) -> Result<(), ModelError> {
178
+ Python::with_gil(|py| -> PyResult<()> {
179
+ let obj = self.obj.clone();
180
+ let action_1d = action
181
+ .into_dimensionality::<Ix1>()
182
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
183
+ let args = (joint_names, PyArray1::from_array(py, &action_1d));
184
+ obj.call_method(py, "take_action", args, None)?;
185
+ Ok(())
186
+ })
187
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
188
+ Ok(())
189
+ }
190
+ }
191
+
192
+ #[gen_stub_pyclass]
193
+ #[pyclass]
194
+ #[derive(Clone)]
195
+ struct PyModelRunner {
196
+ runner: Arc<ModelRunner>,
197
+ }
198
+
199
+ #[gen_stub_pymethods]
200
+ #[pymethods]
201
+ impl PyModelRunner {
202
+ #[new]
203
+ fn new(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
204
+ let input_provider = Arc::new(PyModelProvider {
205
+ obj: Arc::new(provider),
206
+ });
207
+
208
+ let runner = tokio::runtime::Runtime::new().unwrap().block_on(async {
209
+ ModelRunner::new(model_path, input_provider)
210
+ .await
211
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
212
+ })?;
213
+
214
+ Ok(Self {
215
+ runner: Arc::new(runner),
216
+ })
217
+ }
218
+
219
+ fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
220
+ let runner = self.runner.clone();
221
+ let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
222
+ runner
223
+ .init()
224
+ .await
225
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
226
+ })?;
227
+
228
+ Python::with_gil(|py| {
229
+ let array = numpy::PyArray::from_array(py, &result);
230
+ Ok(array.into())
231
+ })
232
+ }
233
+
234
+ fn step(
235
+ &self,
236
+ carry: Py<PyArrayDyn<f32>>,
237
+ ) -> PyResult<(Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>)> {
238
+ let runner = self.runner.clone();
239
+ let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
240
+ let carry_array = carry.bind(py);
241
+ Ok(carry_array.to_owned_array())
242
+ })?;
243
+
244
+ let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
245
+ runner
246
+ .step(carry_array)
247
+ .await
248
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
249
+ })?;
250
+
251
+ Python::with_gil(|py| {
252
+ let (output, carry) = result;
253
+ let output_array = numpy::PyArray::from_array(py, &output);
254
+ let carry_array = numpy::PyArray::from_array(py, &carry);
255
+ Ok((output_array.into(), carry_array.into()))
256
+ })
257
+ }
258
+
259
+ fn take_action(&self, action: Py<PyArrayDyn<f32>>) -> PyResult<()> {
260
+ let runner = self.runner.clone();
261
+ let action_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
262
+ let action_array = action.bind(py);
263
+ Ok(action_array.to_owned_array())
264
+ })?;
265
+
266
+ tokio::runtime::Runtime::new().unwrap().block_on(async {
267
+ runner
268
+ .take_action(action_array)
269
+ .await
270
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
271
+ })?;
272
+
273
+ Ok(())
274
+ }
275
+ }
276
+
277
+ #[gen_stub_pyclass]
278
+ #[pyclass]
279
+ #[derive(Clone)]
280
+ struct PyModelRuntime {
281
+ runtime: Arc<Mutex<ModelRuntime>>,
282
+ }
283
+
284
+ #[gen_stub_pymethods]
285
+ #[pymethods]
286
+ impl PyModelRuntime {
287
+ #[new]
288
+ fn new(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
289
+ Ok(Self {
290
+ runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
291
+ })
292
+ }
293
+
294
+ fn set_slowdown_factor(&self, slowdown_factor: i32) {
295
+ let mut runtime = self.runtime.lock().unwrap();
296
+ runtime.set_slowdown_factor(slowdown_factor);
297
+ }
298
+
299
+ fn set_magnitude_factor(&self, magnitude_factor: f32) {
300
+ let mut runtime = self.runtime.lock().unwrap();
301
+ runtime.set_magnitude_factor(magnitude_factor);
302
+ }
303
+
304
+ fn start(&self) -> PyResult<()> {
305
+ let mut runtime = self.runtime.lock().unwrap();
306
+ runtime
307
+ .start()
308
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
309
+ }
310
+
311
+ fn stop(&self) {
312
+ let mut runtime = self.runtime.lock().unwrap();
313
+ runtime.stop();
314
+ }
315
+ }
316
+
11
317
  #[pymodule]
12
318
  fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
13
319
  m.add_function(wrap_pyfunction!(get_version, m)?)?;
320
+ m.add_class::<ModelProviderABC>()?;
321
+ m.add_class::<PyModelRunner>()?;
322
+ m.add_class::<PyModelRuntime>()?;
14
323
  Ok(())
15
324
  }
16
325
 
Binary file
kinfer/rust_bindings.pyi CHANGED
@@ -1,7 +1,35 @@
1
1
  # This file is automatically generated by pyo3_stub_gen
2
2
  # ruff: noqa: E501, F401
3
3
 
4
+ import builtins
5
+ import numpy
6
+ import numpy.typing
7
+ import typing
4
8
 
5
- def get_version() -> str:
9
+ class ModelProviderABC:
10
+ def __new__(cls) -> ModelProviderABC: ...
11
+ def get_joint_angles(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
12
+ def get_joint_angular_velocities(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
13
+ def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
14
+ def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
15
+ def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
+ def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
17
+
18
+ class PyModelProvider:
6
19
  ...
7
20
 
21
+ class PyModelRunner:
22
+ def __new__(cls, model_path:builtins.str, provider:ModelProviderABC) -> PyModelRunner: ...
23
+ def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
24
+ def step(self, carry:numpy.typing.NDArray[numpy.float32]) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
25
+ def take_action(self, action:numpy.typing.NDArray[numpy.float32]) -> None: ...
26
+
27
+ class PyModelRuntime:
28
+ def __new__(cls, model_runner:PyModelRunner, dt:builtins.int) -> PyModelRuntime: ...
29
+ def set_slowdown_factor(self, slowdown_factor:builtins.int) -> None: ...
30
+ def set_magnitude_factor(self, magnitude_factor:builtins.float) -> None: ...
31
+ def start(self) -> None: ...
32
+ def stop(self) -> None: ...
33
+
34
+ def get_version() -> builtins.str: ...
35
+
@@ -0,0 +1,55 @@
1
+ Metadata-Version: 2.4
2
+ Name: kinfer
3
+ Version: 0.4.0
4
+ Summary: Tool to make it easier to run a model on a real robot
5
+ Home-page: https://github.com/kscalelabs/kinfer.git
6
+ Author: K-Scale Labs
7
+ Requires-Python: >=3.11
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: onnx
11
+ Requires-Dist: onnxruntime==1.20.0
12
+ Requires-Dist: pydantic
13
+ Provides-Extra: dev
14
+ Requires-Dist: black; extra == "dev"
15
+ Requires-Dist: darglint; extra == "dev"
16
+ Requires-Dist: mypy; extra == "dev"
17
+ Requires-Dist: pytest; extra == "dev"
18
+ Requires-Dist: ruff; extra == "dev"
19
+ Requires-Dist: types-tensorflow; extra == "dev"
20
+ Provides-Extra: pytorch
21
+ Requires-Dist: torch; extra == "pytorch"
22
+ Provides-Extra: jax
23
+ Requires-Dist: tensorflow; extra == "jax"
24
+ Requires-Dist: tf2onnx; extra == "jax"
25
+ Requires-Dist: jax; extra == "jax"
26
+ Requires-Dist: equinox; extra == "jax"
27
+ Requires-Dist: numpy<2; extra == "jax"
28
+ Provides-Extra: all
29
+ Requires-Dist: black; extra == "all"
30
+ Requires-Dist: darglint; extra == "all"
31
+ Requires-Dist: mypy; extra == "all"
32
+ Requires-Dist: pytest; extra == "all"
33
+ Requires-Dist: ruff; extra == "all"
34
+ Requires-Dist: types-tensorflow; extra == "all"
35
+ Requires-Dist: torch; extra == "all"
36
+ Requires-Dist: tensorflow; extra == "all"
37
+ Requires-Dist: tf2onnx; extra == "all"
38
+ Requires-Dist: jax; extra == "all"
39
+ Requires-Dist: equinox; extra == "all"
40
+ Requires-Dist: numpy<2; extra == "all"
41
+ Dynamic: author
42
+ Dynamic: description
43
+ Dynamic: description-content-type
44
+ Dynamic: home-page
45
+ Dynamic: license-file
46
+ Dynamic: provides-extra
47
+ Dynamic: requires-dist
48
+ Dynamic: requires-python
49
+ Dynamic: summary
50
+
51
+ # kinfer
52
+
53
+ This package is designed to support running real-time robotics models.
54
+
55
+ For more information, see the documentation [here](https://docs.kscale.dev/docs/k-infer).
@@ -0,0 +1,26 @@
1
+ kinfer/requirements.txt,sha256=j08HO4ptA5afuj99j8FlAP2qla5Zf4_OiEBtgAmF7Jg,90
2
+ kinfer/__init__.py,sha256=YbtJIepEE4pbjYbdgFCV-rBP90AnQpBfDaprkflBmEE,99
3
+ kinfer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ kinfer/rust_bindings.pyi,sha256=oQKlW_bRw_c7fgVMbqiscPJ57pli25DkTCA_odsMOSU,1639
5
+ kinfer/rust_bindings.cpython-312-darwin.so,sha256=2qlbCxfS5Dsxdw077wpnPjt0WQhjm5CRtcA4S7aTG8A,1912656
6
+ kinfer/rust/Cargo.toml,sha256=0SLLbjtoODLrSMcyPdg22Ora-JNdaCqbMKb1c5OQblU,404
7
+ kinfer/rust/src/runtime.rs,sha256=eDflFSIa4IxOmG3SInwrHrYc02UzEANZMkK-liA7pVA,3407
8
+ kinfer/rust/src/lib.rs,sha256=Z3dWdhKhqhWqPJee6vWde4ptqquOYW6W9wB0wyaKCyk,71
9
+ kinfer/rust/src/model.rs,sha256=pODQzD-roXpCJ7-kzaE6uULgqHgsHKb3oQjv-uKtCYk,10668
10
+ kinfer/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ kinfer/common/types.py,sha256=Yf3c8Tui7vWguXMmxxa_QynwydCPu9RxrJS8JJMDuKg,147
12
+ kinfer/export/serialize.py,sha256=pOxXdcWMoaNfnAUhgzHUTb7GtZvn0MYtKskIarHzAiY,3007
13
+ kinfer/export/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ kinfer/export/jax.py,sha256=nbZB7Bxed4MQmfBZrH_Ml6PnkHQoyq5jdOT01DxiZDA,1337
15
+ kinfer/export/pytorch.py,sha256=twGpcX6OaJgiiwQ0teo_4rrCFTrCqu-ULH2UEHl8auc,1477
16
+ kinfer/export/common.py,sha256=ZiOAtehb2_ATg7t3tWzxbHUkGhtJ8vqjlaaw94-Iux4,1021
17
+ kinfer/rust_bindings/Cargo.toml,sha256=i1RGB9VNd9Q4FJ6gGwjZJQYo8DBBvpVWf3GJ95EfVgM,637
18
+ kinfer/rust_bindings/pyproject.toml,sha256=jLcJuHCnQRh9HWR_R7a9qLHwj6LMBgnHyeKK_DruO1Y,135
19
+ kinfer/rust_bindings/rust_bindings.pyi,sha256=oQKlW_bRw_c7fgVMbqiscPJ57pli25DkTCA_odsMOSU,1639
20
+ kinfer/rust_bindings/src/lib.rs,sha256=lKYr6Imu0Z-XxDKCguJLfrhIuai6S2bEhRAe__9sFnc,10196
21
+ kinfer/rust_bindings/src/bin/stub_gen.rs,sha256=hhoVGnaSfazbSfj5a4x6mPicGPOgWQAfsDmiPej0B6Y,133
22
+ kinfer-0.4.0.dist-info/RECORD,,
23
+ kinfer-0.4.0.dist-info/WHEEL,sha256=mP9bWt4ASeNWfyg7GBBbGbsOVFgblaN5WklJcvrSjIE,136
24
+ kinfer-0.4.0.dist-info/top_level.txt,sha256=6mY_t3PYr3Dm0dpqMk80uSnArbvGfCFkxOh1QWtgDEo,7
25
+ kinfer-0.4.0.dist-info/METADATA,sha256=2gGRmUwgmAoYK8kRG5qTtteIMrdaUYW7Edvpk_32yas,1745
26
+ kinfer-0.4.0.dist-info/licenses/LICENSE,sha256=Qw-Z0XTwS-diSW91e_jLeBPX9zZbAatOJTBLdPHPaC0,1069
@@ -1,5 +1,6 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (80.4.0)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp312-cp312-macosx_11_0_arm64
5
+ Generator: delocate 0.13.0
5
6
 
@@ -1 +0,0 @@
1
- from .python import *
@@ -1,92 +0,0 @@
1
- """ONNX model inference utilities for Python."""
2
-
3
- import base64
4
- from pathlib import Path
5
-
6
- import onnx
7
- import onnxruntime as ort
8
-
9
- from kinfer import proto as P
10
- from kinfer.export.pytorch import KINFER_METADATA_KEY
11
- from kinfer.serialize.numpy import NumpyMultiSerializer
12
-
13
-
14
- def _read_schema(model: onnx.ModelProto) -> P.ModelSchema:
15
- for prop in model.metadata_props:
16
- if prop.key == KINFER_METADATA_KEY:
17
- try:
18
- schema_bytes = base64.b64decode(prop.value)
19
- schema = P.ModelSchema()
20
- schema.ParseFromString(schema_bytes)
21
- return schema
22
- except Exception as e:
23
- raise ValueError("Failed to parse kinfer_metadata value") from e
24
- else:
25
- raise ValueError(f"Found arbitrary metadata key {prop.key}")
26
-
27
- raise ValueError(f"{KINFER_METADATA_KEY} not found in model metadata")
28
-
29
-
30
- class ONNXModel:
31
- """Wrapper for ONNX model inference."""
32
-
33
- def __init__(self: "ONNXModel", model_path: str | Path) -> None:
34
- """Initialize ONNX model.
35
-
36
- Args:
37
- model_path: Path to ONNX model file
38
- """
39
- self.model_path = model_path
40
-
41
- # Load model and create inference session
42
- self.model = onnx.load(model_path)
43
- self.session = ort.InferenceSession(model_path)
44
- self._schema = _read_schema(self.model)
45
-
46
- # Create serializers for input and output.
47
- self._input_serializer = NumpyMultiSerializer(self._schema.input_schema)
48
- self._output_serializer = NumpyMultiSerializer(self._schema.output_schema)
49
-
50
- def __call__(self: "ONNXModel", inputs: P.IO) -> P.IO:
51
- """Run inference on input data.
52
-
53
- Args:
54
- inputs: Input data, matching the input schema.
55
-
56
- Returns:
57
- Model outputs, matching the output schema.
58
- """
59
- inputs_np = self._input_serializer.serialize_io(inputs, as_dict=True)
60
- outputs_np = self.session.run(None, inputs_np)
61
- outputs = self._output_serializer.deserialize_io(outputs_np)
62
- return outputs
63
-
64
- @property
65
- def input_schema(self: "ONNXModel") -> P.IOSchema:
66
- """Get the input schema."""
67
- return self._schema.input_schema
68
-
69
- @property
70
- def output_schema(self: "ONNXModel") -> P.IOSchema:
71
- """Get the output schema."""
72
- return self._schema.output_schema
73
-
74
- @property
75
- def schema_input_keys(self: "ONNXModel") -> list[str]:
76
- """Get all value names from input schemas.
77
-
78
- Returns:
79
- List of value names from input schema.
80
- """
81
- input_names = [value.value_name for value in self._schema.input_schema.values]
82
- return input_names
83
-
84
- @property
85
- def schema_output_keys(self: "ONNXModel") -> list[str]:
86
- """Get all value names from output schemas.
87
-
88
- Returns:
89
- List of value names from output schema.
90
- """
91
- output_names = [value.value_name for value in self._schema.output_schema.values]
92
- return output_names
kinfer/proto/__init__.py DELETED
@@ -1,40 +0,0 @@
1
- """Defines helper types for the protocol buffers."""
2
-
3
- from .kinfer_pb2 import (
4
- IO,
5
- AudioFrameSchema,
6
- AudioFrameValue,
7
- CameraFrameSchema,
8
- CameraFrameValue,
9
- DType,
10
- ImuAccelerometerValue,
11
- ImuGyroscopeValue,
12
- ImuMagnetometerValue,
13
- ImuSchema,
14
- ImuValue,
15
- IOSchema,
16
- JointCommandsSchema,
17
- JointCommandsValue,
18
- JointCommandValue,
19
- JointPositionsSchema,
20
- JointPositionsValue,
21
- JointPositionUnit,
22
- JointPositionValue,
23
- JointTorquesSchema,
24
- JointTorquesValue,
25
- JointTorqueUnit,
26
- JointTorqueValue,
27
- JointVelocitiesSchema,
28
- JointVelocitiesValue,
29
- JointVelocityUnit,
30
- JointVelocityValue,
31
- ModelSchema,
32
- StateTensorSchema,
33
- StateTensorValue,
34
- TimestampSchema,
35
- TimestampValue,
36
- Value,
37
- ValueSchema,
38
- VectorCommandSchema,
39
- VectorCommandValue,
40
- )