kinfer 0.3.1__cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,221 @@
1
+ use crate::kinfer_proto::{
2
+ AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, ImuSchema, ImuValue,
3
+ JointCommandsSchema, JointCommandsValue, JointPositionUnit, JointPositionsSchema,
4
+ JointPositionsValue, JointTorqueUnit, JointTorquesSchema, JointTorquesValue,
5
+ JointVelocitiesSchema, JointVelocitiesValue, JointVelocityUnit, ProtoIO, ProtoIOSchema,
6
+ ProtoValue, StateTensorSchema, StateTensorValue, TimestampSchema, TimestampValue, ValueSchema,
7
+ VectorCommandSchema, VectorCommandValue,
8
+ };
9
+
10
+ use ort::value::Value as OrtValue;
11
+ use std::error::Error;
12
+
13
+ pub trait JointPositionsSerializer {
14
+ fn serialize_joint_positions(
15
+ &self,
16
+ schema: &JointPositionsSchema,
17
+ value: JointPositionsValue,
18
+ ) -> Result<OrtValue, Box<dyn Error>>;
19
+
20
+ fn deserialize_joint_positions(
21
+ &self,
22
+ schema: &JointPositionsSchema,
23
+ value: OrtValue,
24
+ ) -> Result<JointPositionsValue, Box<dyn Error>>;
25
+ }
26
+
27
+ pub trait JointVelocitiesSerializer {
28
+ fn serialize_joint_velocities(
29
+ &self,
30
+ schema: &JointVelocitiesSchema,
31
+ value: JointVelocitiesValue,
32
+ ) -> Result<OrtValue, Box<dyn Error>>;
33
+
34
+ fn deserialize_joint_velocities(
35
+ &self,
36
+ schema: &JointVelocitiesSchema,
37
+ value: OrtValue,
38
+ ) -> Result<JointVelocitiesValue, Box<dyn Error>>;
39
+ }
40
+
41
+ pub trait JointTorquesSerializer {
42
+ fn serialize_joint_torques(
43
+ &self,
44
+ schema: &JointTorquesSchema,
45
+ value: JointTorquesValue,
46
+ ) -> Result<OrtValue, Box<dyn Error>>;
47
+
48
+ fn deserialize_joint_torques(
49
+ &self,
50
+ schema: &JointTorquesSchema,
51
+ value: OrtValue,
52
+ ) -> Result<JointTorquesValue, Box<dyn Error>>;
53
+ }
54
+
55
+ pub trait JointCommandsSerializer {
56
+ fn serialize_joint_commands(
57
+ &self,
58
+ schema: &JointCommandsSchema,
59
+ value: JointCommandsValue,
60
+ ) -> Result<OrtValue, Box<dyn Error>>;
61
+
62
+ fn deserialize_joint_commands(
63
+ &self,
64
+ schema: &JointCommandsSchema,
65
+ value: OrtValue,
66
+ ) -> Result<JointCommandsValue, Box<dyn Error>>;
67
+ }
68
+
69
+ pub trait CameraFrameSerializer {
70
+ fn serialize_camera_frame(
71
+ &self,
72
+ schema: &CameraFrameSchema,
73
+ value: CameraFrameValue,
74
+ ) -> Result<OrtValue, Box<dyn Error>>;
75
+
76
+ fn deserialize_camera_frame(
77
+ &self,
78
+ schema: &CameraFrameSchema,
79
+ value: OrtValue,
80
+ ) -> Result<CameraFrameValue, Box<dyn Error>>;
81
+ }
82
+
83
+ pub trait AudioFrameSerializer {
84
+ fn serialize_audio_frame(
85
+ &self,
86
+ schema: &AudioFrameSchema,
87
+ value: AudioFrameValue,
88
+ ) -> Result<OrtValue, Box<dyn Error>>;
89
+
90
+ fn deserialize_audio_frame(
91
+ &self,
92
+ schema: &AudioFrameSchema,
93
+ value: OrtValue,
94
+ ) -> Result<AudioFrameValue, Box<dyn Error>>;
95
+ }
96
+
97
+ pub trait ImuSerializer {
98
+ fn serialize_imu(
99
+ &self,
100
+ schema: &ImuSchema,
101
+ value: ImuValue,
102
+ ) -> Result<OrtValue, Box<dyn Error>>;
103
+
104
+ fn deserialize_imu(
105
+ &self,
106
+ schema: &ImuSchema,
107
+ value: OrtValue,
108
+ ) -> Result<ImuValue, Box<dyn Error>>;
109
+ }
110
+
111
+ pub trait TimestampSerializer {
112
+ fn serialize_timestamp(
113
+ &self,
114
+ schema: &TimestampSchema,
115
+ value: TimestampValue,
116
+ ) -> Result<OrtValue, Box<dyn Error>>;
117
+
118
+ fn deserialize_timestamp(
119
+ &self,
120
+ schema: &TimestampSchema,
121
+ value: OrtValue,
122
+ ) -> Result<TimestampValue, Box<dyn Error>>;
123
+ }
124
+
125
+ pub trait VectorCommandSerializer {
126
+ fn serialize_vector_command(
127
+ &self,
128
+ schema: &VectorCommandSchema,
129
+ value: VectorCommandValue,
130
+ ) -> Result<OrtValue, Box<dyn Error>>;
131
+
132
+ fn deserialize_vector_command(
133
+ &self,
134
+ schema: &VectorCommandSchema,
135
+ value: OrtValue,
136
+ ) -> Result<VectorCommandValue, Box<dyn Error>>;
137
+ }
138
+
139
+ pub trait StateTensorSerializer {
140
+ fn serialize_state_tensor(
141
+ &self,
142
+ schema: &StateTensorSchema,
143
+ value: StateTensorValue,
144
+ ) -> Result<OrtValue, Box<dyn Error>>;
145
+
146
+ fn deserialize_state_tensor(
147
+ &self,
148
+ schema: &StateTensorSchema,
149
+ value: OrtValue,
150
+ ) -> Result<StateTensorValue, Box<dyn Error>>;
151
+ }
152
+
153
+ pub trait Serializer:
154
+ JointPositionsSerializer
155
+ + JointVelocitiesSerializer
156
+ + JointTorquesSerializer
157
+ + JointCommandsSerializer
158
+ + CameraFrameSerializer
159
+ + AudioFrameSerializer
160
+ + ImuSerializer
161
+ + TimestampSerializer
162
+ + VectorCommandSerializer
163
+ + StateTensorSerializer
164
+ {
165
+ fn serialize(
166
+ &self,
167
+ schema: &ValueSchema,
168
+ value: ProtoValue,
169
+ ) -> Result<OrtValue, Box<dyn Error>>;
170
+
171
+ fn deserialize(
172
+ &self,
173
+ schema: &ValueSchema,
174
+ value: OrtValue,
175
+ ) -> Result<ProtoValue, Box<dyn Error>>;
176
+ }
177
+
178
+ pub fn convert_position(
179
+ value: f32,
180
+ from_unit: JointPositionUnit,
181
+ to_unit: JointPositionUnit,
182
+ ) -> Result<f32, Box<dyn Error>> {
183
+ match (from_unit, to_unit) {
184
+ (JointPositionUnit::Radians, JointPositionUnit::Degrees) => {
185
+ Ok(value * 180.0 / std::f32::consts::PI)
186
+ }
187
+ (JointPositionUnit::Degrees, JointPositionUnit::Radians) => {
188
+ Ok(value * std::f32::consts::PI / 180.0)
189
+ }
190
+ (a, b) if a == b => Ok(value),
191
+ _ => Err("Unsupported position unit conversion".into()),
192
+ }
193
+ }
194
+
195
+ pub fn convert_velocity(
196
+ value: f32,
197
+ from_unit: JointVelocityUnit,
198
+ to_unit: JointVelocityUnit,
199
+ ) -> Result<f32, Box<dyn Error>> {
200
+ match (from_unit, to_unit) {
201
+ (JointVelocityUnit::RadiansPerSecond, JointVelocityUnit::DegreesPerSecond) => {
202
+ Ok(value * 180.0 / std::f32::consts::PI)
203
+ }
204
+ (JointVelocityUnit::DegreesPerSecond, JointVelocityUnit::RadiansPerSecond) => {
205
+ Ok(value * std::f32::consts::PI / 180.0)
206
+ }
207
+ (a, b) if a == b => Ok(value),
208
+ _ => Err("Unsupported velocity unit conversion".into()),
209
+ }
210
+ }
211
+
212
+ pub fn convert_torque(
213
+ value: f32,
214
+ from_unit: JointTorqueUnit,
215
+ to_unit: JointTorqueUnit,
216
+ ) -> Result<f32, Box<dyn Error>> {
217
+ match (from_unit, to_unit) {
218
+ (a, b) if a == b => Ok(value),
219
+ _ => Err("Unsupported torque unit conversion".into()),
220
+ }
221
+ }
@@ -0,0 +1,212 @@
1
+ use crate::{
2
+ kinfer_proto::{
3
+ self as P, AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, DType,
4
+ ImuAccelerometerValue, ImuGyroscopeValue, ImuMagnetometerValue, ImuSchema, ImuValue,
5
+ JointCommandValue, JointCommandsSchema, JointCommandsValue, JointPositionUnit,
6
+ JointPositionValue, JointPositionsSchema, JointPositionsValue, JointTorqueUnit,
7
+ JointTorqueValue, JointTorquesSchema, JointTorquesValue, JointVelocitiesSchema,
8
+ JointVelocitiesValue, JointVelocityUnit, JointVelocityValue, ProtoValue, StateTensorSchema,
9
+ StateTensorValue, TimestampSchema, TimestampValue, ValueSchema, VectorCommandSchema,
10
+ VectorCommandValue,
11
+ },
12
+ onnx_serializer::OnnxSerializer,
13
+ serializer::{
14
+ AudioFrameSerializer, CameraFrameSerializer, ImuSerializer, JointCommandsSerializer,
15
+ JointPositionsSerializer, JointTorquesSerializer, JointVelocitiesSerializer,
16
+ StateTensorSerializer, TimestampSerializer, VectorCommandSerializer,
17
+ },
18
+ };
19
+
20
+ use ndarray::Array;
21
+ use ort::value::Value as OrtValue;
22
+ use std::f32::consts::PI;
23
+
24
/// Round-trips degree-valued joint positions through [`OnnxSerializer`].
///
/// Afterwards it checks that a schema declared in radians accepts radian-valued
/// input unchanged.
#[test]
fn test_serialize_joint_positions() {
    // Returns the joint-positions schema held by `schema`; panics on any other variant.
    fn positions_schema(schema: &ValueSchema) -> &JointPositionsSchema {
        match schema.value_type.as_ref().unwrap() {
            P::proto::value_schema::ValueType::JointPositions(inner) => inner,
            _ => panic!("Wrong schema type"),
        }
    }

    let schema = ValueSchema {
        value_name: "test".to_string(),
        value_type: Some(P::proto::value_schema::ValueType::JointPositions(
            JointPositionsSchema {
                unit: JointPositionUnit::Degrees as i32,
                joint_names: ["joint_1", "joint_2", "joint_3"]
                    .iter()
                    .map(|&name| String::from(name))
                    .collect(),
            },
        )),
    };
    let serializer = OnnxSerializer::new(schema.clone());

    // Input whose units already match the schema (degrees), listed in schema order.
    let expected_degrees: [(&str, f32); 3] =
        [("joint_1", 60.0), ("joint_2", 30.0), ("joint_3", 90.0)];
    let value = JointPositionsValue {
        values: expected_degrees
            .iter()
            .map(|&(name, degrees)| JointPositionValue {
                joint_name: name.to_string(),
                value: degrees,
                unit: JointPositionUnit::Degrees as i32,
            })
            .collect(),
    };

    let result = serializer
        .serialize_joint_positions(positions_schema(&schema), value.clone())
        .unwrap();

    // The tensor is one-dimensional with one entry per joint, in schema order.
    // Scoped so the borrow of `result` ends before it is moved into deserialization.
    {
        let tensor = result.try_extract_tensor::<f32>().unwrap();
        let array = tensor.view();
        assert_eq!(array.shape(), &[3]);
        for (idx, &(_, degrees)) in expected_degrees.iter().enumerate() {
            assert_eq!(array[[idx]], degrees);
        }
    }

    let deserialized = serializer
        .deserialize_joint_positions(positions_schema(&schema), result)
        .unwrap();

    // Every field must survive the round trip unchanged.
    assert_eq!(deserialized.values.len(), value.values.len());
    for (expected, actual) in value.values.iter().zip(&deserialized.values) {
        assert_eq!(expected.joint_name, actual.joint_name);
        assert_eq!(expected.value, actual.value);
        assert_eq!(expected.unit, actual.unit);
    }

    // A radians schema fed radian-valued input: no conversion is involved, and
    // the value must come out as it went in.
    let schema_radians = ValueSchema {
        value_name: "test".to_string(),
        value_type: Some(P::proto::value_schema::ValueType::JointPositions(
            JointPositionsSchema {
                unit: JointPositionUnit::Radians as i32,
                joint_names: vec!["joint_1".to_string()],
            },
        )),
    };
    let value_radians = JointPositionsValue {
        values: vec![JointPositionValue {
            joint_name: "joint_1".to_string(),
            value: PI / 6.0,
            unit: JointPositionUnit::Radians as i32,
        }],
    };

    let serializer = OnnxSerializer::new(schema_radians.clone());
    let result = serializer
        .serialize_joint_positions(positions_schema(&schema_radians), value_radians)
        .unwrap();
    let tensor = result.try_extract_tensor::<f32>().unwrap();
    assert!((tensor.view()[[0]] - PI / 6.0).abs() < 1e-6);
}
128
+
129
/// Checks that [`OnnxSerializer`] rejects joint positions that disagree with the schema.
///
/// The mismatches covered are joint count, joint names and units.
#[test]
fn test_serialize_joint_positions_errors() {
    // Builds a single joint reading with the given name, value and unit.
    fn position(name: &str, value: f32, unit: JointPositionUnit) -> JointPositionValue {
        JointPositionValue {
            joint_name: name.to_string(),
            value,
            unit: unit as i32,
        }
    }

    let schema = ValueSchema {
        value_name: "test".to_string(),
        value_type: Some(P::proto::value_schema::ValueType::JointPositions(
            JointPositionsSchema {
                unit: JointPositionUnit::Degrees as i32,
                joint_names: vec!["joint_1".to_string(), "joint_2".to_string()],
            },
        )),
    };
    let serializer = OnnxSerializer::new(schema.clone());
    let joint_schema = match schema.value_type.as_ref().unwrap() {
        P::proto::value_schema::ValueType::JointPositions(inner) => inner,
        _ => panic!("Wrong schema type"),
    };

    // Each case pairs an invalid input with the message reported if it is accepted.
    let failing_cases = vec![
        // Case 1: fewer joints than the schema declares.
        (
            vec![position("joint_1", 60.0, JointPositionUnit::Degrees)],
            "Should fail when joint count doesn't match",
        ),
        // Case 2: joint names that the schema does not list.
        (
            vec![
                position("wrong_joint_1", 60.0, JointPositionUnit::Degrees),
                position("wrong_joint_2", 30.0, JointPositionUnit::Degrees),
            ],
            "Should fail when joint names don't match",
        ),
        // Case 3: radians supplied to a degrees schema.
        (
            vec![
                position("joint_1", 60.0, JointPositionUnit::Radians),
                position("joint_2", 30.0, JointPositionUnit::Radians),
            ],
            "Should fail when units don't match",
        ),
    ];

    for (values, message) in failing_cases {
        let result =
            serializer.serialize_joint_positions(joint_schema, JointPositionsValue { values });
        assert!(result.is_err(), "{}", message);
    }
}
@@ -0,0 +1,19 @@
1
+ [package]
2
+
3
+ name = "rust_bindings"
4
+ version.workspace = true
5
+ edition.workspace = true
6
+ description.workspace = true
7
+ authors.workspace = true
8
+ repository.workspace = true
9
+ license.workspace = true
10
+ readme.workspace = true
11
+
12
+ [lib]
13
+
14
+ name = "rust_bindings"
15
+ crate-type = ["cdylib", "rlib"]
16
+
17
+ [dependencies]
18
+ pyo3 = { version = ">= 0.21.0", features = ["extension-module"] }
19
+ pyo3-stub-gen = ">= 0.6.0"
@@ -0,0 +1,7 @@
1
+ [build-system]
2
+ requires = ["maturin>=1.1,<2.0"]
3
+ build-backend = "maturin"
4
+
5
+ [project]
6
+ name = "rust_bindings"
7
+ requires-python = ">=3.11"
@@ -0,0 +1,7 @@
1
+ use pyo3_stub_gen::Result;
2
+
3
+ fn main() -> Result<()> {
4
+ let stub = rust_bindings::stub_info()?;
5
+ stub.generate()?;
6
+ Ok(())
7
+ }
@@ -0,0 +1,17 @@
1
+ use pyo3::prelude::*;
2
+ use pyo3_stub_gen::define_stub_info_gatherer;
3
+ use pyo3_stub_gen::derive::gen_stub_pyfunction;
4
+
5
+ #[pyfunction]
6
+ #[gen_stub_pyfunction]
7
+ fn get_version() -> String {
8
+ env!("CARGO_PKG_VERSION").to_string()
9
+ }
10
+
11
+ #[pymodule]
12
+ fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
13
+ m.add_function(wrap_pyfunction!(get_version, m)?)?;
14
+ Ok(())
15
+ }
16
+
17
+ define_stub_info_gatherer!(stub_info);
@@ -0,0 +1,7 @@
1
+ # This file is automatically generated by pyo3_stub_gen
2
+ # ruff: noqa: E501, F401
3
+
4
+
5
+ def get_version() -> str:
6
+ ...
7
+
@@ -0,0 +1,36 @@
1
+ """Defines an interface for instantiating serializers."""
2
+
3
+ from typing import Literal
4
+
5
+ from kinfer import proto as P
6
+
7
+ from .base import MultiSerializer, Serializer
8
+ from .json import JsonMultiSerializer, JsonSerializer
9
+ from .numpy import NumpyMultiSerializer, NumpySerializer
10
+ from .pytorch import PyTorchMultiSerializer, PyTorchSerializer
11
+
12
# Names accepted by the serializer factory functions in this module.
SerializerType = Literal["json", "numpy", "pytorch"]


def get_serializer(schema: P.ValueSchema, serializer_type: SerializerType) -> Serializer:
    """Build the single-value serializer registered under ``serializer_type``.

    Args:
        schema: Value schema forwarded to the serializer's constructor.
        serializer_type: One of ``"json"``, ``"numpy"`` or ``"pytorch"``.

    Returns:
        A new serializer instance constructed with ``schema=schema``.

    Raises:
        ValueError: If ``serializer_type`` is not one of the supported names.
    """
    if serializer_type == "json":
        return JsonSerializer(schema=schema)
    if serializer_type == "numpy":
        return NumpySerializer(schema=schema)
    if serializer_type == "pytorch":
        return PyTorchSerializer(schema=schema)
    raise ValueError(f"Unsupported serializer type: {serializer_type}")
25
+
26
+
27
def get_multi_serializer(schema: P.IOSchema, serializer_type: SerializerType) -> MultiSerializer:
    """Build the multi-value serializer registered under ``serializer_type``.

    Args:
        schema: I/O schema forwarded to the serializer's constructor.
        serializer_type: One of ``"json"``, ``"numpy"`` or ``"pytorch"``.

    Returns:
        A new multi-serializer instance constructed with ``schema=schema``.

    Raises:
        ValueError: If ``serializer_type`` is not one of the supported names.
    """
    # Checked in order with ``==``, so unhashable inputs still reach the ValueError below.
    factories = (
        ("json", JsonMultiSerializer),
        ("numpy", NumpyMultiSerializer),
        ("pytorch", PyTorchMultiSerializer),
    )
    for name, factory in factories:
        if serializer_type == name:
            return factory(schema=schema)
    raise ValueError(f"Unsupported serializer type: {serializer_type}")