kinfer 0.3.3__cp311-cp311-macosx_11_0_arm64.whl → 0.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. kinfer/__init__.py +0 -5
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +11 -0
  4. kinfer/export/common.py +35 -0
  5. kinfer/export/jax.py +51 -0
  6. kinfer/export/pytorch.py +42 -110
  7. kinfer/export/serialize.py +86 -0
  8. kinfer/requirements.txt +3 -4
  9. kinfer/rust/Cargo.toml +8 -6
  10. kinfer/rust/src/lib.rs +2 -11
  11. kinfer/rust/src/model.rs +271 -121
  12. kinfer/rust/src/runtime.rs +104 -0
  13. kinfer/rust_bindings/Cargo.toml +8 -1
  14. kinfer/rust_bindings/rust_bindings.pyi +35 -0
  15. kinfer/rust_bindings/src/lib.rs +310 -1
  16. kinfer/rust_bindings.cpython-311-darwin.so +0 -0
  17. kinfer/rust_bindings.pyi +29 -1
  18. kinfer-0.4.0.dist-info/METADATA +55 -0
  19. kinfer-0.4.0.dist-info/RECORD +26 -0
  20. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
  21. kinfer/inference/__init__.py +0 -2
  22. kinfer/inference/base.py +0 -64
  23. kinfer/inference/python.py +0 -66
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -60
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.3.dist-info/METADATA +0 -57
  43. kinfer-0.3.3.dist-info/RECORD +0 -40
  44. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,15 @@
1
+ use async_trait::async_trait;
2
+ use kinfer::model::{ModelError, ModelProvider, ModelRunner};
3
+ use kinfer::runtime::ModelRuntime;
4
+ use ndarray::{Array, Ix1, IxDyn};
5
+ use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
6
+ use pyo3::exceptions::PyNotImplementedError;
1
7
  use pyo3::prelude::*;
8
+ use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
2
9
  use pyo3_stub_gen::define_stub_info_gatherer;
3
- use pyo3_stub_gen::derive::gen_stub_pyfunction;
10
+ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
11
+ use std::sync::Arc;
12
+ use std::sync::Mutex;
4
13
 
5
14
  #[pyfunction]
6
15
  #[gen_stub_pyfunction]
@@ -8,9 +17,309 @@ fn get_version() -> String {
8
17
  env!("CARGO_PKG_VERSION").to_string()
9
18
  }
10
19
 
20
+ #[pyclass(subclass)]
21
+ #[gen_stub_pyclass]
22
+ pub struct ModelProviderABC;
23
+
24
+ #[gen_stub_pymethods]
25
+ #[pymethods]
26
+ impl ModelProviderABC {
27
+ #[new]
28
+ fn new() -> Self {
29
+ ModelProviderABC
30
+ }
31
+
32
+ fn get_joint_angles<'py>(
33
+ &self,
34
+ joint_names: Vec<String>,
35
+ ) -> PyResult<Bound<'py, PyArray1<f32>>> {
36
+ let n = joint_names.len();
37
+ Err(PyNotImplementedError::new_err(format!(
38
+ "Must override get_joint_angles with {} joint names",
39
+ n
40
+ )))
41
+ }
42
+
43
+ fn get_joint_angular_velocities<'py>(
44
+ &self,
45
+ joint_names: Vec<String>,
46
+ ) -> PyResult<Bound<'py, PyArray1<f32>>> {
47
+ let n = joint_names.len();
48
+ Err(PyNotImplementedError::new_err(format!(
49
+ "Must override get_joint_angular_velocities with {} joint names",
50
+ n
51
+ )))
52
+ }
53
+
54
+ fn get_projected_gravity<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
55
+ Err(PyNotImplementedError::new_err(
56
+ "Must override get_projected_gravity",
57
+ ))
58
+ }
59
+
60
+ fn get_accelerometer<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
61
+ Err(PyNotImplementedError::new_err(
62
+ "Must override get_accelerometer",
63
+ ))
64
+ }
65
+
66
+ fn get_gyroscope<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
67
+ Err(PyNotImplementedError::new_err(
68
+ "Must override get_gyroscope",
69
+ ))
70
+ }
71
+
72
+ fn take_action<'py>(
73
+ &self,
74
+ joint_names: Vec<String>,
75
+ action: Bound<'py, PyArray1<f32>>,
76
+ ) -> PyResult<()> {
77
+ let n = action.len()?;
78
+ assert_eq!(joint_names.len(), n);
79
+ Err(PyNotImplementedError::new_err(format!(
80
+ "Must override take_action with {} action",
81
+ n
82
+ )))
83
+ }
84
+ }
85
+
86
+ #[gen_stub_pyclass]
87
+ #[pyclass]
88
+ #[derive(Clone)]
89
+ struct PyModelProvider {
90
+ obj: Arc<Py<ModelProviderABC>>,
91
+ }
92
+
93
+ #[pymethods]
94
+ impl PyModelProvider {
95
+ #[new]
96
+ fn new(obj: Py<ModelProviderABC>) -> Self {
97
+ Self { obj: Arc::new(obj) }
98
+ }
99
+ }
100
+
101
+ #[async_trait]
102
+ impl ModelProvider for PyModelProvider {
103
+ async fn get_joint_angles(
104
+ &self,
105
+ joint_names: &[String],
106
+ ) -> Result<Array<f32, IxDyn>, ModelError> {
107
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
108
+ let obj = self.obj.clone();
109
+ let args = (joint_names,);
110
+ let result = obj.call_method(py, "get_joint_angles", args, None)?;
111
+ let array = result.extract::<Vec<f32>>(py)?;
112
+ Ok(Array::from_vec(array).into_dyn())
113
+ })
114
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
115
+ Ok(args)
116
+ }
117
+
118
+ async fn get_joint_angular_velocities(
119
+ &self,
120
+ joint_names: &[String],
121
+ ) -> Result<Array<f32, IxDyn>, ModelError> {
122
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
123
+ let obj = self.obj.clone();
124
+ let args = (joint_names,);
125
+ let result = obj.call_method(py, "get_joint_angular_velocities", args, None)?;
126
+ let array = result.extract::<Vec<f32>>(py)?;
127
+ Ok(Array::from_vec(array).into_dyn())
128
+ })
129
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
130
+ Ok(args)
131
+ }
132
+
133
+ async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError> {
134
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
135
+ let obj = self.obj.clone();
136
+ let args = ();
137
+ let result = obj.call_method(py, "get_projected_gravity", args, None)?;
138
+ let array = result.extract::<Vec<f32>>(py)?;
139
+ Ok(Array::from_vec(array).into_dyn())
140
+ })
141
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
142
+ Ok(args)
143
+ }
144
+
145
+ async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError> {
146
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
147
+ let obj = self.obj.clone();
148
+ let args = ();
149
+ let result = obj.call_method(py, "get_accelerometer", args, None)?;
150
+ let array = result.extract::<Vec<f32>>(py)?;
151
+ Ok(Array::from_vec(array).into_dyn())
152
+ })
153
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
154
+ Ok(args)
155
+ }
156
+
157
+ async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError> {
158
+ let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
159
+ let obj = self.obj.clone();
160
+ let args = ();
161
+ let result = obj.call_method(py, "get_gyroscope", args, None)?;
162
+ let array = result.extract::<Vec<f32>>(py)?;
163
+ Ok(Array::from_vec(array).into_dyn())
164
+ })
165
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
166
+ Ok(args)
167
+ }
168
+
169
+ async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
170
+ Ok(carry)
171
+ }
172
+
173
+ async fn take_action(
174
+ &self,
175
+ joint_names: Vec<String>,
176
+ action: Array<f32, IxDyn>,
177
+ ) -> Result<(), ModelError> {
178
+ Python::with_gil(|py| -> PyResult<()> {
179
+ let obj = self.obj.clone();
180
+ let action_1d = action
181
+ .into_dimensionality::<Ix1>()
182
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
183
+ let args = (joint_names, PyArray1::from_array(py, &action_1d));
184
+ obj.call_method(py, "take_action", args, None)?;
185
+ Ok(())
186
+ })
187
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
188
+ Ok(())
189
+ }
190
+ }
191
+
192
+ #[gen_stub_pyclass]
193
+ #[pyclass]
194
+ #[derive(Clone)]
195
+ struct PyModelRunner {
196
+ runner: Arc<ModelRunner>,
197
+ }
198
+
199
+ #[gen_stub_pymethods]
200
+ #[pymethods]
201
+ impl PyModelRunner {
202
+ #[new]
203
+ fn new(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
204
+ let input_provider = Arc::new(PyModelProvider {
205
+ obj: Arc::new(provider),
206
+ });
207
+
208
+ let runner = tokio::runtime::Runtime::new().unwrap().block_on(async {
209
+ ModelRunner::new(model_path, input_provider)
210
+ .await
211
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
212
+ })?;
213
+
214
+ Ok(Self {
215
+ runner: Arc::new(runner),
216
+ })
217
+ }
218
+
219
+ fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
220
+ let runner = self.runner.clone();
221
+ let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
222
+ runner
223
+ .init()
224
+ .await
225
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
226
+ })?;
227
+
228
+ Python::with_gil(|py| {
229
+ let array = numpy::PyArray::from_array(py, &result);
230
+ Ok(array.into())
231
+ })
232
+ }
233
+
234
+ fn step(
235
+ &self,
236
+ carry: Py<PyArrayDyn<f32>>,
237
+ ) -> PyResult<(Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>)> {
238
+ let runner = self.runner.clone();
239
+ let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
240
+ let carry_array = carry.bind(py);
241
+ Ok(carry_array.to_owned_array())
242
+ })?;
243
+
244
+ let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
245
+ runner
246
+ .step(carry_array)
247
+ .await
248
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
249
+ })?;
250
+
251
+ Python::with_gil(|py| {
252
+ let (output, carry) = result;
253
+ let output_array = numpy::PyArray::from_array(py, &output);
254
+ let carry_array = numpy::PyArray::from_array(py, &carry);
255
+ Ok((output_array.into(), carry_array.into()))
256
+ })
257
+ }
258
+
259
+ fn take_action(&self, action: Py<PyArrayDyn<f32>>) -> PyResult<()> {
260
+ let runner = self.runner.clone();
261
+ let action_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
262
+ let action_array = action.bind(py);
263
+ Ok(action_array.to_owned_array())
264
+ })?;
265
+
266
+ tokio::runtime::Runtime::new().unwrap().block_on(async {
267
+ runner
268
+ .take_action(action_array)
269
+ .await
270
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
271
+ })?;
272
+
273
+ Ok(())
274
+ }
275
+ }
276
+
277
+ #[gen_stub_pyclass]
278
+ #[pyclass]
279
+ #[derive(Clone)]
280
+ struct PyModelRuntime {
281
+ runtime: Arc<Mutex<ModelRuntime>>,
282
+ }
283
+
284
+ #[gen_stub_pymethods]
285
+ #[pymethods]
286
+ impl PyModelRuntime {
287
+ #[new]
288
+ fn new(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
289
+ Ok(Self {
290
+ runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
291
+ })
292
+ }
293
+
294
+ fn set_slowdown_factor(&self, slowdown_factor: i32) {
295
+ let mut runtime = self.runtime.lock().unwrap();
296
+ runtime.set_slowdown_factor(slowdown_factor);
297
+ }
298
+
299
+ fn set_magnitude_factor(&self, magnitude_factor: f32) {
300
+ let mut runtime = self.runtime.lock().unwrap();
301
+ runtime.set_magnitude_factor(magnitude_factor);
302
+ }
303
+
304
+ fn start(&self) -> PyResult<()> {
305
+ let mut runtime = self.runtime.lock().unwrap();
306
+ runtime
307
+ .start()
308
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
309
+ }
310
+
311
+ fn stop(&self) {
312
+ let mut runtime = self.runtime.lock().unwrap();
313
+ runtime.stop();
314
+ }
315
+ }
316
+
11
317
  #[pymodule]
12
318
  fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
13
319
  m.add_function(wrap_pyfunction!(get_version, m)?)?;
320
+ m.add_class::<ModelProviderABC>()?;
321
+ m.add_class::<PyModelRunner>()?;
322
+ m.add_class::<PyModelRuntime>()?;
14
323
  Ok(())
15
324
  }
16
325
 
Binary file
kinfer/rust_bindings.pyi CHANGED
@@ -1,7 +1,35 @@
1
1
  # This file is automatically generated by pyo3_stub_gen
2
2
  # ruff: noqa: E501, F401
3
3
 
4
+ import builtins
5
+ import numpy
6
+ import numpy.typing
7
+ import typing
4
8
 
5
- def get_version() -> str:
9
+ class ModelProviderABC:
10
+ def __new__(cls) -> ModelProviderABC: ...
11
+ def get_joint_angles(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
12
+ def get_joint_angular_velocities(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
13
+ def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
14
+ def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
15
+ def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
+ def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
17
+
18
+ class PyModelProvider:
6
19
  ...
7
20
 
21
+ class PyModelRunner:
22
+ def __new__(cls, model_path:builtins.str, provider:ModelProviderABC) -> PyModelRunner: ...
23
+ def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
24
+ def step(self, carry:numpy.typing.NDArray[numpy.float32]) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
25
+ def take_action(self, action:numpy.typing.NDArray[numpy.float32]) -> None: ...
26
+
27
+ class PyModelRuntime:
28
+ def __new__(cls, model_runner:PyModelRunner, dt:builtins.int) -> PyModelRuntime: ...
29
+ def set_slowdown_factor(self, slowdown_factor:builtins.int) -> None: ...
30
+ def set_magnitude_factor(self, magnitude_factor:builtins.float) -> None: ...
31
+ def start(self) -> None: ...
32
+ def stop(self) -> None: ...
33
+
34
+ def get_version() -> builtins.str: ...
35
+
@@ -0,0 +1,55 @@
1
+ Metadata-Version: 2.4
2
+ Name: kinfer
3
+ Version: 0.4.0
4
+ Summary: Tool to make it easier to run a model on a real robot
5
+ Home-page: https://github.com/kscalelabs/kinfer.git
6
+ Author: K-Scale Labs
7
+ Requires-Python: >=3.11
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: onnx
11
+ Requires-Dist: onnxruntime==1.20.0
12
+ Requires-Dist: pydantic
13
+ Provides-Extra: dev
14
+ Requires-Dist: black; extra == "dev"
15
+ Requires-Dist: darglint; extra == "dev"
16
+ Requires-Dist: mypy; extra == "dev"
17
+ Requires-Dist: pytest; extra == "dev"
18
+ Requires-Dist: ruff; extra == "dev"
19
+ Requires-Dist: types-tensorflow; extra == "dev"
20
+ Provides-Extra: pytorch
21
+ Requires-Dist: torch; extra == "pytorch"
22
+ Provides-Extra: jax
23
+ Requires-Dist: tensorflow; extra == "jax"
24
+ Requires-Dist: tf2onnx; extra == "jax"
25
+ Requires-Dist: jax; extra == "jax"
26
+ Requires-Dist: equinox; extra == "jax"
27
+ Requires-Dist: numpy<2; extra == "jax"
28
+ Provides-Extra: all
29
+ Requires-Dist: black; extra == "all"
30
+ Requires-Dist: darglint; extra == "all"
31
+ Requires-Dist: mypy; extra == "all"
32
+ Requires-Dist: pytest; extra == "all"
33
+ Requires-Dist: ruff; extra == "all"
34
+ Requires-Dist: types-tensorflow; extra == "all"
35
+ Requires-Dist: torch; extra == "all"
36
+ Requires-Dist: tensorflow; extra == "all"
37
+ Requires-Dist: tf2onnx; extra == "all"
38
+ Requires-Dist: jax; extra == "all"
39
+ Requires-Dist: equinox; extra == "all"
40
+ Requires-Dist: numpy<2; extra == "all"
41
+ Dynamic: author
42
+ Dynamic: description
43
+ Dynamic: description-content-type
44
+ Dynamic: home-page
45
+ Dynamic: license-file
46
+ Dynamic: provides-extra
47
+ Dynamic: requires-dist
48
+ Dynamic: requires-python
49
+ Dynamic: summary
50
+
51
+ # kinfer
52
+
53
+ This package is designed to support running real-time robotics models.
54
+
55
+ For more information, see the documentation [here](https://docs.kscale.dev/docs/k-infer).
@@ -0,0 +1,26 @@
1
+ kinfer/rust_bindings.cpython-311-darwin.so,sha256=lwkKH5R0rx9hsMdjtsNcSWwdFM93Bz2Eg_91Hpcx7vg,1931696
2
+ kinfer/requirements.txt,sha256=j08HO4ptA5afuj99j8FlAP2qla5Zf4_OiEBtgAmF7Jg,90
3
+ kinfer/__init__.py,sha256=YbtJIepEE4pbjYbdgFCV-rBP90AnQpBfDaprkflBmEE,99
4
+ kinfer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ kinfer/rust_bindings.pyi,sha256=oQKlW_bRw_c7fgVMbqiscPJ57pli25DkTCA_odsMOSU,1639
6
+ kinfer/rust/Cargo.toml,sha256=0SLLbjtoODLrSMcyPdg22Ora-JNdaCqbMKb1c5OQblU,404
7
+ kinfer/rust/src/runtime.rs,sha256=eDflFSIa4IxOmG3SInwrHrYc02UzEANZMkK-liA7pVA,3407
8
+ kinfer/rust/src/lib.rs,sha256=Z3dWdhKhqhWqPJee6vWde4ptqquOYW6W9wB0wyaKCyk,71
9
+ kinfer/rust/src/model.rs,sha256=pODQzD-roXpCJ7-kzaE6uULgqHgsHKb3oQjv-uKtCYk,10668
10
+ kinfer/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ kinfer/common/types.py,sha256=Yf3c8Tui7vWguXMmxxa_QynwydCPu9RxrJS8JJMDuKg,147
12
+ kinfer/export/serialize.py,sha256=pOxXdcWMoaNfnAUhgzHUTb7GtZvn0MYtKskIarHzAiY,3007
13
+ kinfer/export/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ kinfer/export/jax.py,sha256=nbZB7Bxed4MQmfBZrH_Ml6PnkHQoyq5jdOT01DxiZDA,1337
15
+ kinfer/export/pytorch.py,sha256=twGpcX6OaJgiiwQ0teo_4rrCFTrCqu-ULH2UEHl8auc,1477
16
+ kinfer/export/common.py,sha256=ZiOAtehb2_ATg7t3tWzxbHUkGhtJ8vqjlaaw94-Iux4,1021
17
+ kinfer/rust_bindings/Cargo.toml,sha256=i1RGB9VNd9Q4FJ6gGwjZJQYo8DBBvpVWf3GJ95EfVgM,637
18
+ kinfer/rust_bindings/pyproject.toml,sha256=jLcJuHCnQRh9HWR_R7a9qLHwj6LMBgnHyeKK_DruO1Y,135
19
+ kinfer/rust_bindings/rust_bindings.pyi,sha256=oQKlW_bRw_c7fgVMbqiscPJ57pli25DkTCA_odsMOSU,1639
20
+ kinfer/rust_bindings/src/lib.rs,sha256=lKYr6Imu0Z-XxDKCguJLfrhIuai6S2bEhRAe__9sFnc,10196
21
+ kinfer/rust_bindings/src/bin/stub_gen.rs,sha256=hhoVGnaSfazbSfj5a4x6mPicGPOgWQAfsDmiPej0B6Y,133
22
+ kinfer-0.4.0.dist-info/RECORD,,
23
+ kinfer-0.4.0.dist-info/WHEEL,sha256=rtKwvZSAzWV03G3Ircwq_TRBlj2DQz4ocNXl0bd9DbU,136
24
+ kinfer-0.4.0.dist-info/top_level.txt,sha256=6mY_t3PYr3Dm0dpqMk80uSnArbvGfCFkxOh1QWtgDEo,7
25
+ kinfer-0.4.0.dist-info/METADATA,sha256=2gGRmUwgmAoYK8kRG5qTtteIMrdaUYW7Edvpk_32yas,1745
26
+ kinfer-0.4.0.dist-info/licenses/LICENSE,sha256=Qw-Z0XTwS-diSW91e_jLeBPX9zZbAatOJTBLdPHPaC0,1069
@@ -1,5 +1,6 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (80.4.0)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp311-cp311-macosx_11_0_arm64
5
+ Generator: delocate 0.13.0
5
6
 
@@ -1,2 +0,0 @@
1
- from .base import KModel
2
- from .python import ONNXModel
kinfer/inference/base.py DELETED
@@ -1,64 +0,0 @@
1
- """Defines the base interface for running model inference.
2
-
3
- All kinfer models must implement this interface - the model inputs and outputs
4
- should match the provided schema, and the `__call__` method should take the
5
- inputs and return the outputs according to this schema.
6
- """
7
-
8
- import functools
9
- from abc import ABC, abstractmethod
10
-
11
- from kinfer import proto as K
12
-
13
-
14
- class KModel(ABC):
15
- """Base interface for running model inference."""
16
-
17
- @abstractmethod
18
- def get_schema(self) -> K.ModelSchema:
19
- """Get the model schema."""
20
-
21
- @abstractmethod
22
- def __call__(self, inputs: K.IO) -> K.IO:
23
- """Run inference on input data.
24
-
25
- Args:
26
- inputs: Input data, matching the input schema.
27
-
28
- Returns:
29
- Model outputs, matching the output schema.
30
- """
31
-
32
- @functools.cached_property
33
- def schema(self) -> K.ModelSchema:
34
- return self.get_schema()
35
-
36
- @property
37
- def input_schema(self) -> K.IOSchema:
38
- """Get the input schema."""
39
- return self.schema.input_schema
40
-
41
- @property
42
- def output_schema(self) -> K.IOSchema:
43
- """Get the output schema."""
44
- return self.schema.output_schema
45
-
46
- @property
47
- def schema_input_keys(self) -> list[str]:
48
- """Get all value names from input schemas.
49
-
50
- Returns:
51
- List of value names from input schema.
52
- """
53
- input_names = [value.value_name for value in self.input_schema.values]
54
- return input_names
55
-
56
- @property
57
- def schema_output_keys(self) -> list[str]:
58
- """Get all value names from output schemas.
59
-
60
- Returns:
61
- List of value names from output schema.
62
- """
63
- output_names = [value.value_name for value in self.output_schema.values]
64
- return output_names
@@ -1,66 +0,0 @@
1
- """ONNX model inference utilities for Python."""
2
-
3
- import base64
4
- from pathlib import Path
5
-
6
- import onnx
7
- import onnxruntime as ort
8
-
9
- from kinfer import proto as K
10
- from kinfer.export.pytorch import KINFER_METADATA_KEY
11
- from kinfer.inference.base import KModel
12
- from kinfer.serialize.numpy import NumpyMultiSerializer
13
-
14
-
15
- def _read_schema(model: onnx.ModelProto) -> K.ModelSchema:
16
- for prop in model.metadata_props:
17
- if prop.key == KINFER_METADATA_KEY:
18
- try:
19
- schema_bytes = base64.b64decode(prop.value)
20
- schema = K.ModelSchema()
21
- schema.ParseFromString(schema_bytes)
22
- return schema
23
- except Exception as e:
24
- raise ValueError("Failed to parse kinfer_metadata value") from e
25
- else:
26
- raise ValueError(f"Found arbitrary metadata key {prop.key}")
27
-
28
- raise ValueError(f"{KINFER_METADATA_KEY} not found in model metadata")
29
-
30
-
31
- class ONNXModel(KModel):
32
- """Wrapper for ONNX model inference."""
33
-
34
- def __init__(self: "ONNXModel", model_path: str | Path) -> None:
35
- """Initialize ONNX model.
36
-
37
- Args:
38
- model_path: Path to ONNX model file
39
- """
40
- self.model_path = model_path
41
-
42
- # Load model and create inference session
43
- self.model = onnx.load(model_path)
44
- self.session = ort.InferenceSession(model_path)
45
- self._schema = _read_schema(self.model)
46
-
47
- # Create serializers for input and output.
48
- self._input_serializer = NumpyMultiSerializer(self._schema.input_schema)
49
- self._output_serializer = NumpyMultiSerializer(self._schema.output_schema)
50
-
51
- def get_schema(self) -> K.ModelSchema:
52
- return self._schema
53
-
54
- def __call__(self, inputs: K.IO) -> K.IO:
55
- """Run inference on input data.
56
-
57
- Args:
58
- inputs: Input data, matching the input schema.
59
-
60
- Returns:
61
- Model outputs, matching the output schema.
62
- """
63
- inputs_np = self._input_serializer.serialize_io(inputs, as_dict=True)
64
- outputs_np = self.session.run(None, inputs_np)
65
- outputs = self._output_serializer.deserialize_io(outputs_np)
66
- return outputs
kinfer/proto/__init__.py DELETED
@@ -1,40 +0,0 @@
1
- """Defines helper types for the protocol buffers."""
2
-
3
- from .kinfer_pb2 import (
4
- IO,
5
- AudioFrameSchema,
6
- AudioFrameValue,
7
- CameraFrameSchema,
8
- CameraFrameValue,
9
- DType,
10
- ImuAccelerometerValue,
11
- ImuGyroscopeValue,
12
- ImuMagnetometerValue,
13
- ImuSchema,
14
- ImuValue,
15
- IOSchema,
16
- JointCommandsSchema,
17
- JointCommandsValue,
18
- JointCommandValue,
19
- JointPositionsSchema,
20
- JointPositionsValue,
21
- JointPositionUnit,
22
- JointPositionValue,
23
- JointTorquesSchema,
24
- JointTorquesValue,
25
- JointTorqueUnit,
26
- JointTorqueValue,
27
- JointVelocitiesSchema,
28
- JointVelocitiesValue,
29
- JointVelocityUnit,
30
- JointVelocityValue,
31
- ModelSchema,
32
- StateTensorSchema,
33
- StateTensorValue,
34
- TimestampSchema,
35
- TimestampValue,
36
- Value,
37
- ValueSchema,
38
- VectorCommandSchema,
39
- VectorCommandValue,
40
- )