kinfer 0.3.3__cp311-cp311-macosx_11_0_arm64.whl → 0.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. kinfer/__init__.py +0 -5
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +11 -0
  4. kinfer/export/common.py +35 -0
  5. kinfer/export/jax.py +51 -0
  6. kinfer/export/pytorch.py +42 -110
  7. kinfer/export/serialize.py +86 -0
  8. kinfer/requirements.txt +3 -4
  9. kinfer/rust/Cargo.toml +8 -6
  10. kinfer/rust/src/lib.rs +2 -11
  11. kinfer/rust/src/model.rs +271 -121
  12. kinfer/rust/src/runtime.rs +104 -0
  13. kinfer/rust_bindings/Cargo.toml +8 -1
  14. kinfer/rust_bindings/rust_bindings.pyi +35 -0
  15. kinfer/rust_bindings/src/lib.rs +310 -1
  16. kinfer/rust_bindings.cpython-311-darwin.so +0 -0
  17. kinfer/rust_bindings.pyi +29 -1
  18. kinfer-0.4.0.dist-info/METADATA +55 -0
  19. kinfer-0.4.0.dist-info/RECORD +26 -0
  20. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
  21. kinfer/inference/__init__.py +0 -2
  22. kinfer/inference/base.py +0 -64
  23. kinfer/inference/python.py +0 -66
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -60
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.3.dist-info/METADATA +0 -57
  43. kinfer-0.3.3.dist-info/RECORD +0 -40
  44. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,221 +0,0 @@
1
- use crate::kinfer_proto::{
2
- AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, ImuSchema, ImuValue,
3
- JointCommandsSchema, JointCommandsValue, JointPositionUnit, JointPositionsSchema,
4
- JointPositionsValue, JointTorqueUnit, JointTorquesSchema, JointTorquesValue,
5
- JointVelocitiesSchema, JointVelocitiesValue, JointVelocityUnit, ProtoIO, ProtoIOSchema,
6
- ProtoValue, StateTensorSchema, StateTensorValue, TimestampSchema, TimestampValue, ValueSchema,
7
- VectorCommandSchema, VectorCommandValue,
8
- };
9
-
10
- use ort::value::Value as OrtValue;
11
- use std::error::Error;
12
-
13
- pub trait JointPositionsSerializer {
14
- fn serialize_joint_positions(
15
- &self,
16
- schema: &JointPositionsSchema,
17
- value: JointPositionsValue,
18
- ) -> Result<OrtValue, Box<dyn Error>>;
19
-
20
- fn deserialize_joint_positions(
21
- &self,
22
- schema: &JointPositionsSchema,
23
- value: OrtValue,
24
- ) -> Result<JointPositionsValue, Box<dyn Error>>;
25
- }
26
-
27
- pub trait JointVelocitiesSerializer {
28
- fn serialize_joint_velocities(
29
- &self,
30
- schema: &JointVelocitiesSchema,
31
- value: JointVelocitiesValue,
32
- ) -> Result<OrtValue, Box<dyn Error>>;
33
-
34
- fn deserialize_joint_velocities(
35
- &self,
36
- schema: &JointVelocitiesSchema,
37
- value: OrtValue,
38
- ) -> Result<JointVelocitiesValue, Box<dyn Error>>;
39
- }
40
-
41
- pub trait JointTorquesSerializer {
42
- fn serialize_joint_torques(
43
- &self,
44
- schema: &JointTorquesSchema,
45
- value: JointTorquesValue,
46
- ) -> Result<OrtValue, Box<dyn Error>>;
47
-
48
- fn deserialize_joint_torques(
49
- &self,
50
- schema: &JointTorquesSchema,
51
- value: OrtValue,
52
- ) -> Result<JointTorquesValue, Box<dyn Error>>;
53
- }
54
-
55
- pub trait JointCommandsSerializer {
56
- fn serialize_joint_commands(
57
- &self,
58
- schema: &JointCommandsSchema,
59
- value: JointCommandsValue,
60
- ) -> Result<OrtValue, Box<dyn Error>>;
61
-
62
- fn deserialize_joint_commands(
63
- &self,
64
- schema: &JointCommandsSchema,
65
- value: OrtValue,
66
- ) -> Result<JointCommandsValue, Box<dyn Error>>;
67
- }
68
-
69
- pub trait CameraFrameSerializer {
70
- fn serialize_camera_frame(
71
- &self,
72
- schema: &CameraFrameSchema,
73
- value: CameraFrameValue,
74
- ) -> Result<OrtValue, Box<dyn Error>>;
75
-
76
- fn deserialize_camera_frame(
77
- &self,
78
- schema: &CameraFrameSchema,
79
- value: OrtValue,
80
- ) -> Result<CameraFrameValue, Box<dyn Error>>;
81
- }
82
-
83
- pub trait AudioFrameSerializer {
84
- fn serialize_audio_frame(
85
- &self,
86
- schema: &AudioFrameSchema,
87
- value: AudioFrameValue,
88
- ) -> Result<OrtValue, Box<dyn Error>>;
89
-
90
- fn deserialize_audio_frame(
91
- &self,
92
- schema: &AudioFrameSchema,
93
- value: OrtValue,
94
- ) -> Result<AudioFrameValue, Box<dyn Error>>;
95
- }
96
-
97
- pub trait ImuSerializer {
98
- fn serialize_imu(
99
- &self,
100
- schema: &ImuSchema,
101
- value: ImuValue,
102
- ) -> Result<OrtValue, Box<dyn Error>>;
103
-
104
- fn deserialize_imu(
105
- &self,
106
- schema: &ImuSchema,
107
- value: OrtValue,
108
- ) -> Result<ImuValue, Box<dyn Error>>;
109
- }
110
-
111
- pub trait TimestampSerializer {
112
- fn serialize_timestamp(
113
- &self,
114
- schema: &TimestampSchema,
115
- value: TimestampValue,
116
- ) -> Result<OrtValue, Box<dyn Error>>;
117
-
118
- fn deserialize_timestamp(
119
- &self,
120
- schema: &TimestampSchema,
121
- value: OrtValue,
122
- ) -> Result<TimestampValue, Box<dyn Error>>;
123
- }
124
-
125
- pub trait VectorCommandSerializer {
126
- fn serialize_vector_command(
127
- &self,
128
- schema: &VectorCommandSchema,
129
- value: VectorCommandValue,
130
- ) -> Result<OrtValue, Box<dyn Error>>;
131
-
132
- fn deserialize_vector_command(
133
- &self,
134
- schema: &VectorCommandSchema,
135
- value: OrtValue,
136
- ) -> Result<VectorCommandValue, Box<dyn Error>>;
137
- }
138
-
139
- pub trait StateTensorSerializer {
140
- fn serialize_state_tensor(
141
- &self,
142
- schema: &StateTensorSchema,
143
- value: StateTensorValue,
144
- ) -> Result<OrtValue, Box<dyn Error>>;
145
-
146
- fn deserialize_state_tensor(
147
- &self,
148
- schema: &StateTensorSchema,
149
- value: OrtValue,
150
- ) -> Result<StateTensorValue, Box<dyn Error>>;
151
- }
152
-
153
- pub trait Serializer:
154
- JointPositionsSerializer
155
- + JointVelocitiesSerializer
156
- + JointTorquesSerializer
157
- + JointCommandsSerializer
158
- + CameraFrameSerializer
159
- + AudioFrameSerializer
160
- + ImuSerializer
161
- + TimestampSerializer
162
- + VectorCommandSerializer
163
- + StateTensorSerializer
164
- {
165
- fn serialize(
166
- &self,
167
- schema: &ValueSchema,
168
- value: ProtoValue,
169
- ) -> Result<OrtValue, Box<dyn Error>>;
170
-
171
- fn deserialize(
172
- &self,
173
- schema: &ValueSchema,
174
- value: OrtValue,
175
- ) -> Result<ProtoValue, Box<dyn Error>>;
176
- }
177
-
178
- pub fn convert_position(
179
- value: f32,
180
- from_unit: JointPositionUnit,
181
- to_unit: JointPositionUnit,
182
- ) -> Result<f32, Box<dyn Error>> {
183
- match (from_unit, to_unit) {
184
- (JointPositionUnit::Radians, JointPositionUnit::Degrees) => {
185
- Ok(value * 180.0 / std::f32::consts::PI)
186
- }
187
- (JointPositionUnit::Degrees, JointPositionUnit::Radians) => {
188
- Ok(value * std::f32::consts::PI / 180.0)
189
- }
190
- (a, b) if a == b => Ok(value),
191
- _ => Err("Unsupported position unit conversion".into()),
192
- }
193
- }
194
-
195
- pub fn convert_velocity(
196
- value: f32,
197
- from_unit: JointVelocityUnit,
198
- to_unit: JointVelocityUnit,
199
- ) -> Result<f32, Box<dyn Error>> {
200
- match (from_unit, to_unit) {
201
- (JointVelocityUnit::RadiansPerSecond, JointVelocityUnit::DegreesPerSecond) => {
202
- Ok(value * 180.0 / std::f32::consts::PI)
203
- }
204
- (JointVelocityUnit::DegreesPerSecond, JointVelocityUnit::RadiansPerSecond) => {
205
- Ok(value * std::f32::consts::PI / 180.0)
206
- }
207
- (a, b) if a == b => Ok(value),
208
- _ => Err("Unsupported velocity unit conversion".into()),
209
- }
210
- }
211
-
212
- pub fn convert_torque(
213
- value: f32,
214
- from_unit: JointTorqueUnit,
215
- to_unit: JointTorqueUnit,
216
- ) -> Result<f32, Box<dyn Error>> {
217
- match (from_unit, to_unit) {
218
- (a, b) if a == b => Ok(value),
219
- _ => Err("Unsupported torque unit conversion".into()),
220
- }
221
- }
@@ -1,212 +0,0 @@
1
- use crate::{
2
- kinfer_proto::{
3
- self as P, AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, DType,
4
- ImuAccelerometerValue, ImuGyroscopeValue, ImuMagnetometerValue, ImuSchema, ImuValue,
5
- JointCommandValue, JointCommandsSchema, JointCommandsValue, JointPositionUnit,
6
- JointPositionValue, JointPositionsSchema, JointPositionsValue, JointTorqueUnit,
7
- JointTorqueValue, JointTorquesSchema, JointTorquesValue, JointVelocitiesSchema,
8
- JointVelocitiesValue, JointVelocityUnit, JointVelocityValue, ProtoValue, StateTensorSchema,
9
- StateTensorValue, TimestampSchema, TimestampValue, ValueSchema, VectorCommandSchema,
10
- VectorCommandValue,
11
- },
12
- onnx_serializer::OnnxSerializer,
13
- serializer::{
14
- AudioFrameSerializer, CameraFrameSerializer, ImuSerializer, JointCommandsSerializer,
15
- JointPositionsSerializer, JointTorquesSerializer, JointVelocitiesSerializer,
16
- StateTensorSerializer, TimestampSerializer, VectorCommandSerializer,
17
- },
18
- };
19
-
20
- use ndarray::Array;
21
- use ort::value::Value as OrtValue;
22
- use std::f32::consts::PI;
23
-
24
- #[test]
25
- fn test_serialize_joint_positions() {
26
- let joint_names = vec![
27
- "joint_1".to_string(),
28
- "joint_2".to_string(),
29
- "joint_3".to_string(),
30
- ];
31
- let schema = ValueSchema {
32
- value_name: "test".to_string(),
33
- value_type: Some(P::proto::value_schema::ValueType::JointPositions(
34
- JointPositionsSchema {
35
- unit: JointPositionUnit::Degrees as i32,
36
- joint_names: joint_names.clone(),
37
- },
38
- )),
39
- };
40
-
41
- let serializer = OnnxSerializer::new(schema.clone());
42
-
43
- // Test with matching units
44
- let value = JointPositionsValue {
45
- values: vec![
46
- JointPositionValue {
47
- joint_name: "joint_1".to_string(),
48
- value: 60.0,
49
- unit: JointPositionUnit::Degrees as i32,
50
- },
51
- JointPositionValue {
52
- joint_name: "joint_2".to_string(),
53
- value: 30.0,
54
- unit: JointPositionUnit::Degrees as i32,
55
- },
56
- JointPositionValue {
57
- joint_name: "joint_3".to_string(),
58
- value: 90.0,
59
- unit: JointPositionUnit::Degrees as i32,
60
- },
61
- ],
62
- };
63
-
64
- let result = match schema.value_type.as_ref().unwrap() {
65
- P::proto::value_schema::ValueType::JointPositions(schema) => {
66
- serializer.serialize_joint_positions(schema, value.clone())
67
- }
68
- _ => panic!("Wrong schema type"),
69
- }
70
- .unwrap();
71
-
72
- // Verify tensor shape and values
73
- let tensor = result.try_extract_tensor::<f32>().unwrap();
74
- let array = tensor.view();
75
- assert_eq!(array.shape(), &[3]);
76
- assert_eq!(array[[0]], 60.0); // joint_1
77
- assert_eq!(array[[1]], 30.0); // joint_2
78
- assert_eq!(array[[2]], 90.0); // joint_3
79
-
80
- let deserialized = match schema.value_type.as_ref().unwrap() {
81
- P::proto::value_schema::ValueType::JointPositions(schema) => {
82
- serializer.deserialize_joint_positions(schema, result)
83
- }
84
- _ => panic!("Wrong schema type"),
85
- }
86
- .unwrap();
87
-
88
- // Verify full deserialization
89
- assert_eq!(deserialized.values.len(), value.values.len());
90
- for (expected, actual) in value.values.iter().zip(deserialized.values.iter()) {
91
- assert_eq!(expected.joint_name, actual.joint_name);
92
- assert_eq!(expected.value, actual.value);
93
- assert_eq!(expected.unit, actual.unit);
94
- }
95
-
96
- // Test unit conversion
97
- let value_radians = JointPositionsValue {
98
- values: vec![JointPositionValue {
99
- joint_name: "joint_1".to_string(),
100
- value: PI / 6.0,
101
- unit: JointPositionUnit::Radians as i32,
102
- }],
103
- };
104
-
105
- let schema_radians = ValueSchema {
106
- value_name: "test".to_string(),
107
- value_type: Some(P::proto::value_schema::ValueType::JointPositions(
108
- JointPositionsSchema {
109
- unit: JointPositionUnit::Radians as i32,
110
- joint_names: vec!["joint_1".to_string()],
111
- },
112
- )),
113
- };
114
-
115
- let serializer = OnnxSerializer::new(schema_radians.clone());
116
- let result = match schema_radians.value_type.as_ref().unwrap() {
117
- P::proto::value_schema::ValueType::JointPositions(schema) => {
118
- serializer.serialize_joint_positions(schema, value_radians.clone())
119
- }
120
- _ => panic!("Wrong schema type"),
121
- }
122
- .unwrap();
123
-
124
- let tensor = result.try_extract_tensor::<f32>().unwrap();
125
- let array = tensor.view();
126
- assert!((array[[0]] - PI / 6.0).abs() < 1e-6);
127
- }
128
-
129
- #[test]
130
- fn test_serialize_joint_positions_errors() {
131
- let schema = ValueSchema {
132
- value_name: "test".to_string(),
133
- value_type: Some(P::proto::value_schema::ValueType::JointPositions(
134
- JointPositionsSchema {
135
- unit: JointPositionUnit::Degrees as i32,
136
- joint_names: vec!["joint_1".to_string(), "joint_2".to_string()],
137
- },
138
- )),
139
- };
140
-
141
- let serializer = OnnxSerializer::new(schema.clone());
142
-
143
- // Test cases that should fail:
144
-
145
- // Case 1: Wrong number of joints
146
- let value_wrong_count = JointPositionsValue {
147
- values: vec![JointPositionValue {
148
- joint_name: "joint_1".to_string(),
149
- value: 60.0,
150
- unit: JointPositionUnit::Degrees as i32,
151
- }],
152
- };
153
-
154
- let result = match schema.value_type.as_ref().unwrap() {
155
- P::proto::value_schema::ValueType::JointPositions(schema) => {
156
- serializer.serialize_joint_positions(schema, value_wrong_count)
157
- }
158
- _ => panic!("Wrong schema type"),
159
- };
160
- assert!(
161
- result.is_err(),
162
- "Should fail when joint count doesn't match"
163
- );
164
-
165
- // Case 2: Wrong joint names
166
- let value_wrong_names = JointPositionsValue {
167
- values: vec![
168
- JointPositionValue {
169
- joint_name: "wrong_joint_1".to_string(),
170
- value: 60.0,
171
- unit: JointPositionUnit::Degrees as i32,
172
- },
173
- JointPositionValue {
174
- joint_name: "wrong_joint_2".to_string(),
175
- value: 30.0,
176
- unit: JointPositionUnit::Degrees as i32,
177
- },
178
- ],
179
- };
180
-
181
- let result = match schema.value_type.as_ref().unwrap() {
182
- P::proto::value_schema::ValueType::JointPositions(schema) => {
183
- serializer.serialize_joint_positions(schema, value_wrong_names)
184
- }
185
- _ => panic!("Wrong schema type"),
186
- };
187
- assert!(result.is_err(), "Should fail when joint names don't match");
188
-
189
- // Case 3: Wrong unit type
190
- let value_wrong_unit = JointPositionsValue {
191
- values: vec![
192
- JointPositionValue {
193
- joint_name: "joint_1".to_string(),
194
- value: 60.0,
195
- unit: JointPositionUnit::Radians as i32,
196
- },
197
- JointPositionValue {
198
- joint_name: "joint_2".to_string(),
199
- value: 30.0,
200
- unit: JointPositionUnit::Radians as i32,
201
- },
202
- ],
203
- };
204
-
205
- let result = match schema.value_type.as_ref().unwrap() {
206
- P::proto::value_schema::ValueType::JointPositions(schema) => {
207
- serializer.serialize_joint_positions(schema, value_wrong_unit)
208
- }
209
- _ => panic!("Wrong schema type"),
210
- };
211
- assert!(result.is_err(), "Should fail when units don't match");
212
- }
@@ -1,60 +0,0 @@
1
- """Defines an interface for instantiating serializers."""
2
-
3
- from typing import Literal, overload
4
-
5
- from kinfer import proto as K
6
-
7
- from .base import MultiSerializer, Serializer
8
- from .json import JsonMultiSerializer, JsonSerializer
9
- from .numpy import NumpyMultiSerializer, NumpySerializer
10
- from .pytorch import PyTorchMultiSerializer, PyTorchSerializer
11
-
12
- SerializerType = Literal["json", "numpy", "pytorch"]
13
-
14
-
15
- @overload
16
- def get_serializer(schema: K.ValueSchema, serializer_type: Literal["json"]) -> JsonSerializer: ...
17
-
18
-
19
- @overload
20
- def get_serializer(schema: K.ValueSchema, serializer_type: Literal["numpy"]) -> NumpySerializer: ...
21
-
22
-
23
- @overload
24
- def get_serializer(schema: K.ValueSchema, serializer_type: Literal["pytorch"]) -> PyTorchSerializer: ...
25
-
26
-
27
- def get_serializer(schema: K.ValueSchema, serializer_type: SerializerType) -> Serializer:
28
- match serializer_type:
29
- case "json":
30
- return JsonSerializer(schema=schema)
31
- case "numpy":
32
- return NumpySerializer(schema=schema)
33
- case "pytorch":
34
- return PyTorchSerializer(schema=schema)
35
- case _:
36
- raise ValueError(f"Unsupported serializer type: {serializer_type}")
37
-
38
-
39
- @overload
40
- def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["json"]) -> JsonMultiSerializer: ...
41
-
42
-
43
- @overload
44
- def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["numpy"]) -> NumpyMultiSerializer: ...
45
-
46
-
47
- @overload
48
- def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["pytorch"]) -> PyTorchMultiSerializer: ...
49
-
50
-
51
- def get_multi_serializer(schema: K.IOSchema, serializer_type: SerializerType) -> MultiSerializer:
52
- match serializer_type:
53
- case "json":
54
- return JsonMultiSerializer(schema=schema)
55
- case "numpy":
56
- return NumpyMultiSerializer(schema=schema)
57
- case "pytorch":
58
- return PyTorchMultiSerializer(schema=schema)
59
- case _:
60
- raise ValueError(f"Unsupported serializer type: {serializer_type}")