kinfer 0.3.3__cp312-cp312-macosx_11_0_arm64.whl → 0.4.1__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +0 -5
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +12 -0
- kinfer/export/common.py +41 -0
- kinfer/export/jax.py +53 -0
- kinfer/export/pytorch.py +45 -110
- kinfer/export/serialize.py +93 -0
- kinfer/requirements.txt +3 -4
- kinfer/rust/Cargo.toml +20 -8
- kinfer/rust/src/lib.rs +2 -11
- kinfer/rust/src/model.rs +286 -121
- kinfer/rust/src/runtime.rs +104 -0
- kinfer/rust_bindings/Cargo.toml +8 -1
- kinfer/rust_bindings/rust_bindings.pyi +36 -0
- kinfer/rust_bindings/src/lib.rs +326 -1
- kinfer/rust_bindings.cpython-312-darwin.so +0 -0
- kinfer/rust_bindings.pyi +30 -1
- kinfer-0.4.1.dist-info/METADATA +55 -0
- kinfer-0.4.1.dist-info/RECORD +26 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info}/WHEEL +2 -1
- kinfer/inference/__init__.py +0 -2
- kinfer/inference/base.py +0 -64
- kinfer/inference/python.py +0 -66
- kinfer/proto/__init__.py +0 -40
- kinfer/proto/kinfer_pb2.py +0 -103
- kinfer/proto/kinfer_pb2.pyi +0 -1097
- kinfer/requirements-dev.txt +0 -8
- kinfer/rust/build.rs +0 -16
- kinfer/rust/src/kinfer_proto.rs +0 -14
- kinfer/rust/src/main.rs +0 -6
- kinfer/rust/src/onnx_serializer.rs +0 -804
- kinfer/rust/src/serializer.rs +0 -221
- kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
- kinfer/serialize/__init__.py +0 -60
- kinfer/serialize/base.py +0 -536
- kinfer/serialize/json.py +0 -399
- kinfer/serialize/numpy.py +0 -426
- kinfer/serialize/pytorch.py +0 -402
- kinfer/serialize/schema.py +0 -125
- kinfer/serialize/types.py +0 -17
- kinfer/serialize/utils.py +0 -177
- kinfer-0.3.3.dist-info/METADATA +0 -57
- kinfer-0.3.3.dist-info/RECORD +0 -40
- {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info/licenses}/LICENSE +0 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info}/top_level.txt +0 -0
kinfer/rust/src/serializer.rs
DELETED
@@ -1,221 +0,0 @@
|
|
1
|
-
use crate::kinfer_proto::{
|
2
|
-
AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, ImuSchema, ImuValue,
|
3
|
-
JointCommandsSchema, JointCommandsValue, JointPositionUnit, JointPositionsSchema,
|
4
|
-
JointPositionsValue, JointTorqueUnit, JointTorquesSchema, JointTorquesValue,
|
5
|
-
JointVelocitiesSchema, JointVelocitiesValue, JointVelocityUnit, ProtoIO, ProtoIOSchema,
|
6
|
-
ProtoValue, StateTensorSchema, StateTensorValue, TimestampSchema, TimestampValue, ValueSchema,
|
7
|
-
VectorCommandSchema, VectorCommandValue,
|
8
|
-
};
|
9
|
-
|
10
|
-
use ort::value::Value as OrtValue;
|
11
|
-
use std::error::Error;
|
12
|
-
|
13
|
-
pub trait JointPositionsSerializer {
|
14
|
-
fn serialize_joint_positions(
|
15
|
-
&self,
|
16
|
-
schema: &JointPositionsSchema,
|
17
|
-
value: JointPositionsValue,
|
18
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
19
|
-
|
20
|
-
fn deserialize_joint_positions(
|
21
|
-
&self,
|
22
|
-
schema: &JointPositionsSchema,
|
23
|
-
value: OrtValue,
|
24
|
-
) -> Result<JointPositionsValue, Box<dyn Error>>;
|
25
|
-
}
|
26
|
-
|
27
|
-
pub trait JointVelocitiesSerializer {
|
28
|
-
fn serialize_joint_velocities(
|
29
|
-
&self,
|
30
|
-
schema: &JointVelocitiesSchema,
|
31
|
-
value: JointVelocitiesValue,
|
32
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
33
|
-
|
34
|
-
fn deserialize_joint_velocities(
|
35
|
-
&self,
|
36
|
-
schema: &JointVelocitiesSchema,
|
37
|
-
value: OrtValue,
|
38
|
-
) -> Result<JointVelocitiesValue, Box<dyn Error>>;
|
39
|
-
}
|
40
|
-
|
41
|
-
pub trait JointTorquesSerializer {
|
42
|
-
fn serialize_joint_torques(
|
43
|
-
&self,
|
44
|
-
schema: &JointTorquesSchema,
|
45
|
-
value: JointTorquesValue,
|
46
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
47
|
-
|
48
|
-
fn deserialize_joint_torques(
|
49
|
-
&self,
|
50
|
-
schema: &JointTorquesSchema,
|
51
|
-
value: OrtValue,
|
52
|
-
) -> Result<JointTorquesValue, Box<dyn Error>>;
|
53
|
-
}
|
54
|
-
|
55
|
-
pub trait JointCommandsSerializer {
|
56
|
-
fn serialize_joint_commands(
|
57
|
-
&self,
|
58
|
-
schema: &JointCommandsSchema,
|
59
|
-
value: JointCommandsValue,
|
60
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
61
|
-
|
62
|
-
fn deserialize_joint_commands(
|
63
|
-
&self,
|
64
|
-
schema: &JointCommandsSchema,
|
65
|
-
value: OrtValue,
|
66
|
-
) -> Result<JointCommandsValue, Box<dyn Error>>;
|
67
|
-
}
|
68
|
-
|
69
|
-
pub trait CameraFrameSerializer {
|
70
|
-
fn serialize_camera_frame(
|
71
|
-
&self,
|
72
|
-
schema: &CameraFrameSchema,
|
73
|
-
value: CameraFrameValue,
|
74
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
75
|
-
|
76
|
-
fn deserialize_camera_frame(
|
77
|
-
&self,
|
78
|
-
schema: &CameraFrameSchema,
|
79
|
-
value: OrtValue,
|
80
|
-
) -> Result<CameraFrameValue, Box<dyn Error>>;
|
81
|
-
}
|
82
|
-
|
83
|
-
pub trait AudioFrameSerializer {
|
84
|
-
fn serialize_audio_frame(
|
85
|
-
&self,
|
86
|
-
schema: &AudioFrameSchema,
|
87
|
-
value: AudioFrameValue,
|
88
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
89
|
-
|
90
|
-
fn deserialize_audio_frame(
|
91
|
-
&self,
|
92
|
-
schema: &AudioFrameSchema,
|
93
|
-
value: OrtValue,
|
94
|
-
) -> Result<AudioFrameValue, Box<dyn Error>>;
|
95
|
-
}
|
96
|
-
|
97
|
-
pub trait ImuSerializer {
|
98
|
-
fn serialize_imu(
|
99
|
-
&self,
|
100
|
-
schema: &ImuSchema,
|
101
|
-
value: ImuValue,
|
102
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
103
|
-
|
104
|
-
fn deserialize_imu(
|
105
|
-
&self,
|
106
|
-
schema: &ImuSchema,
|
107
|
-
value: OrtValue,
|
108
|
-
) -> Result<ImuValue, Box<dyn Error>>;
|
109
|
-
}
|
110
|
-
|
111
|
-
pub trait TimestampSerializer {
|
112
|
-
fn serialize_timestamp(
|
113
|
-
&self,
|
114
|
-
schema: &TimestampSchema,
|
115
|
-
value: TimestampValue,
|
116
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
117
|
-
|
118
|
-
fn deserialize_timestamp(
|
119
|
-
&self,
|
120
|
-
schema: &TimestampSchema,
|
121
|
-
value: OrtValue,
|
122
|
-
) -> Result<TimestampValue, Box<dyn Error>>;
|
123
|
-
}
|
124
|
-
|
125
|
-
pub trait VectorCommandSerializer {
|
126
|
-
fn serialize_vector_command(
|
127
|
-
&self,
|
128
|
-
schema: &VectorCommandSchema,
|
129
|
-
value: VectorCommandValue,
|
130
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
131
|
-
|
132
|
-
fn deserialize_vector_command(
|
133
|
-
&self,
|
134
|
-
schema: &VectorCommandSchema,
|
135
|
-
value: OrtValue,
|
136
|
-
) -> Result<VectorCommandValue, Box<dyn Error>>;
|
137
|
-
}
|
138
|
-
|
139
|
-
pub trait StateTensorSerializer {
|
140
|
-
fn serialize_state_tensor(
|
141
|
-
&self,
|
142
|
-
schema: &StateTensorSchema,
|
143
|
-
value: StateTensorValue,
|
144
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
145
|
-
|
146
|
-
fn deserialize_state_tensor(
|
147
|
-
&self,
|
148
|
-
schema: &StateTensorSchema,
|
149
|
-
value: OrtValue,
|
150
|
-
) -> Result<StateTensorValue, Box<dyn Error>>;
|
151
|
-
}
|
152
|
-
|
153
|
-
pub trait Serializer:
|
154
|
-
JointPositionsSerializer
|
155
|
-
+ JointVelocitiesSerializer
|
156
|
-
+ JointTorquesSerializer
|
157
|
-
+ JointCommandsSerializer
|
158
|
-
+ CameraFrameSerializer
|
159
|
-
+ AudioFrameSerializer
|
160
|
-
+ ImuSerializer
|
161
|
-
+ TimestampSerializer
|
162
|
-
+ VectorCommandSerializer
|
163
|
-
+ StateTensorSerializer
|
164
|
-
{
|
165
|
-
fn serialize(
|
166
|
-
&self,
|
167
|
-
schema: &ValueSchema,
|
168
|
-
value: ProtoValue,
|
169
|
-
) -> Result<OrtValue, Box<dyn Error>>;
|
170
|
-
|
171
|
-
fn deserialize(
|
172
|
-
&self,
|
173
|
-
schema: &ValueSchema,
|
174
|
-
value: OrtValue,
|
175
|
-
) -> Result<ProtoValue, Box<dyn Error>>;
|
176
|
-
}
|
177
|
-
|
178
|
-
pub fn convert_position(
|
179
|
-
value: f32,
|
180
|
-
from_unit: JointPositionUnit,
|
181
|
-
to_unit: JointPositionUnit,
|
182
|
-
) -> Result<f32, Box<dyn Error>> {
|
183
|
-
match (from_unit, to_unit) {
|
184
|
-
(JointPositionUnit::Radians, JointPositionUnit::Degrees) => {
|
185
|
-
Ok(value * 180.0 / std::f32::consts::PI)
|
186
|
-
}
|
187
|
-
(JointPositionUnit::Degrees, JointPositionUnit::Radians) => {
|
188
|
-
Ok(value * std::f32::consts::PI / 180.0)
|
189
|
-
}
|
190
|
-
(a, b) if a == b => Ok(value),
|
191
|
-
_ => Err("Unsupported position unit conversion".into()),
|
192
|
-
}
|
193
|
-
}
|
194
|
-
|
195
|
-
pub fn convert_velocity(
|
196
|
-
value: f32,
|
197
|
-
from_unit: JointVelocityUnit,
|
198
|
-
to_unit: JointVelocityUnit,
|
199
|
-
) -> Result<f32, Box<dyn Error>> {
|
200
|
-
match (from_unit, to_unit) {
|
201
|
-
(JointVelocityUnit::RadiansPerSecond, JointVelocityUnit::DegreesPerSecond) => {
|
202
|
-
Ok(value * 180.0 / std::f32::consts::PI)
|
203
|
-
}
|
204
|
-
(JointVelocityUnit::DegreesPerSecond, JointVelocityUnit::RadiansPerSecond) => {
|
205
|
-
Ok(value * std::f32::consts::PI / 180.0)
|
206
|
-
}
|
207
|
-
(a, b) if a == b => Ok(value),
|
208
|
-
_ => Err("Unsupported velocity unit conversion".into()),
|
209
|
-
}
|
210
|
-
}
|
211
|
-
|
212
|
-
pub fn convert_torque(
|
213
|
-
value: f32,
|
214
|
-
from_unit: JointTorqueUnit,
|
215
|
-
to_unit: JointTorqueUnit,
|
216
|
-
) -> Result<f32, Box<dyn Error>> {
|
217
|
-
match (from_unit, to_unit) {
|
218
|
-
(a, b) if a == b => Ok(value),
|
219
|
-
_ => Err("Unsupported torque unit conversion".into()),
|
220
|
-
}
|
221
|
-
}
|
kinfer/rust/src/tests/onnx_serializer_tests.rs
DELETED
@@ -1,212 +0,0 @@
|
|
1
|
-
use crate::{
|
2
|
-
kinfer_proto::{
|
3
|
-
self as P, AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, DType,
|
4
|
-
ImuAccelerometerValue, ImuGyroscopeValue, ImuMagnetometerValue, ImuSchema, ImuValue,
|
5
|
-
JointCommandValue, JointCommandsSchema, JointCommandsValue, JointPositionUnit,
|
6
|
-
JointPositionValue, JointPositionsSchema, JointPositionsValue, JointTorqueUnit,
|
7
|
-
JointTorqueValue, JointTorquesSchema, JointTorquesValue, JointVelocitiesSchema,
|
8
|
-
JointVelocitiesValue, JointVelocityUnit, JointVelocityValue, ProtoValue, StateTensorSchema,
|
9
|
-
StateTensorValue, TimestampSchema, TimestampValue, ValueSchema, VectorCommandSchema,
|
10
|
-
VectorCommandValue,
|
11
|
-
},
|
12
|
-
onnx_serializer::OnnxSerializer,
|
13
|
-
serializer::{
|
14
|
-
AudioFrameSerializer, CameraFrameSerializer, ImuSerializer, JointCommandsSerializer,
|
15
|
-
JointPositionsSerializer, JointTorquesSerializer, JointVelocitiesSerializer,
|
16
|
-
StateTensorSerializer, TimestampSerializer, VectorCommandSerializer,
|
17
|
-
},
|
18
|
-
};
|
19
|
-
|
20
|
-
use ndarray::Array;
|
21
|
-
use ort::value::Value as OrtValue;
|
22
|
-
use std::f32::consts::PI;
|
23
|
-
|
24
|
-
#[test]
|
25
|
-
fn test_serialize_joint_positions() {
|
26
|
-
let joint_names = vec![
|
27
|
-
"joint_1".to_string(),
|
28
|
-
"joint_2".to_string(),
|
29
|
-
"joint_3".to_string(),
|
30
|
-
];
|
31
|
-
let schema = ValueSchema {
|
32
|
-
value_name: "test".to_string(),
|
33
|
-
value_type: Some(P::proto::value_schema::ValueType::JointPositions(
|
34
|
-
JointPositionsSchema {
|
35
|
-
unit: JointPositionUnit::Degrees as i32,
|
36
|
-
joint_names: joint_names.clone(),
|
37
|
-
},
|
38
|
-
)),
|
39
|
-
};
|
40
|
-
|
41
|
-
let serializer = OnnxSerializer::new(schema.clone());
|
42
|
-
|
43
|
-
// Test with matching units
|
44
|
-
let value = JointPositionsValue {
|
45
|
-
values: vec![
|
46
|
-
JointPositionValue {
|
47
|
-
joint_name: "joint_1".to_string(),
|
48
|
-
value: 60.0,
|
49
|
-
unit: JointPositionUnit::Degrees as i32,
|
50
|
-
},
|
51
|
-
JointPositionValue {
|
52
|
-
joint_name: "joint_2".to_string(),
|
53
|
-
value: 30.0,
|
54
|
-
unit: JointPositionUnit::Degrees as i32,
|
55
|
-
},
|
56
|
-
JointPositionValue {
|
57
|
-
joint_name: "joint_3".to_string(),
|
58
|
-
value: 90.0,
|
59
|
-
unit: JointPositionUnit::Degrees as i32,
|
60
|
-
},
|
61
|
-
],
|
62
|
-
};
|
63
|
-
|
64
|
-
let result = match schema.value_type.as_ref().unwrap() {
|
65
|
-
P::proto::value_schema::ValueType::JointPositions(schema) => {
|
66
|
-
serializer.serialize_joint_positions(schema, value.clone())
|
67
|
-
}
|
68
|
-
_ => panic!("Wrong schema type"),
|
69
|
-
}
|
70
|
-
.unwrap();
|
71
|
-
|
72
|
-
// Verify tensor shape and values
|
73
|
-
let tensor = result.try_extract_tensor::<f32>().unwrap();
|
74
|
-
let array = tensor.view();
|
75
|
-
assert_eq!(array.shape(), &[3]);
|
76
|
-
assert_eq!(array[[0]], 60.0); // joint_1
|
77
|
-
assert_eq!(array[[1]], 30.0); // joint_2
|
78
|
-
assert_eq!(array[[2]], 90.0); // joint_3
|
79
|
-
|
80
|
-
let deserialized = match schema.value_type.as_ref().unwrap() {
|
81
|
-
P::proto::value_schema::ValueType::JointPositions(schema) => {
|
82
|
-
serializer.deserialize_joint_positions(schema, result)
|
83
|
-
}
|
84
|
-
_ => panic!("Wrong schema type"),
|
85
|
-
}
|
86
|
-
.unwrap();
|
87
|
-
|
88
|
-
// Verify full deserialization
|
89
|
-
assert_eq!(deserialized.values.len(), value.values.len());
|
90
|
-
for (expected, actual) in value.values.iter().zip(deserialized.values.iter()) {
|
91
|
-
assert_eq!(expected.joint_name, actual.joint_name);
|
92
|
-
assert_eq!(expected.value, actual.value);
|
93
|
-
assert_eq!(expected.unit, actual.unit);
|
94
|
-
}
|
95
|
-
|
96
|
-
// Test unit conversion
|
97
|
-
let value_radians = JointPositionsValue {
|
98
|
-
values: vec![JointPositionValue {
|
99
|
-
joint_name: "joint_1".to_string(),
|
100
|
-
value: PI / 6.0,
|
101
|
-
unit: JointPositionUnit::Radians as i32,
|
102
|
-
}],
|
103
|
-
};
|
104
|
-
|
105
|
-
let schema_radians = ValueSchema {
|
106
|
-
value_name: "test".to_string(),
|
107
|
-
value_type: Some(P::proto::value_schema::ValueType::JointPositions(
|
108
|
-
JointPositionsSchema {
|
109
|
-
unit: JointPositionUnit::Radians as i32,
|
110
|
-
joint_names: vec!["joint_1".to_string()],
|
111
|
-
},
|
112
|
-
)),
|
113
|
-
};
|
114
|
-
|
115
|
-
let serializer = OnnxSerializer::new(schema_radians.clone());
|
116
|
-
let result = match schema_radians.value_type.as_ref().unwrap() {
|
117
|
-
P::proto::value_schema::ValueType::JointPositions(schema) => {
|
118
|
-
serializer.serialize_joint_positions(schema, value_radians.clone())
|
119
|
-
}
|
120
|
-
_ => panic!("Wrong schema type"),
|
121
|
-
}
|
122
|
-
.unwrap();
|
123
|
-
|
124
|
-
let tensor = result.try_extract_tensor::<f32>().unwrap();
|
125
|
-
let array = tensor.view();
|
126
|
-
assert!((array[[0]] - PI / 6.0).abs() < 1e-6);
|
127
|
-
}
|
128
|
-
|
129
|
-
#[test]
|
130
|
-
fn test_serialize_joint_positions_errors() {
|
131
|
-
let schema = ValueSchema {
|
132
|
-
value_name: "test".to_string(),
|
133
|
-
value_type: Some(P::proto::value_schema::ValueType::JointPositions(
|
134
|
-
JointPositionsSchema {
|
135
|
-
unit: JointPositionUnit::Degrees as i32,
|
136
|
-
joint_names: vec!["joint_1".to_string(), "joint_2".to_string()],
|
137
|
-
},
|
138
|
-
)),
|
139
|
-
};
|
140
|
-
|
141
|
-
let serializer = OnnxSerializer::new(schema.clone());
|
142
|
-
|
143
|
-
// Test cases that should fail:
|
144
|
-
|
145
|
-
// Case 1: Wrong number of joints
|
146
|
-
let value_wrong_count = JointPositionsValue {
|
147
|
-
values: vec![JointPositionValue {
|
148
|
-
joint_name: "joint_1".to_string(),
|
149
|
-
value: 60.0,
|
150
|
-
unit: JointPositionUnit::Degrees as i32,
|
151
|
-
}],
|
152
|
-
};
|
153
|
-
|
154
|
-
let result = match schema.value_type.as_ref().unwrap() {
|
155
|
-
P::proto::value_schema::ValueType::JointPositions(schema) => {
|
156
|
-
serializer.serialize_joint_positions(schema, value_wrong_count)
|
157
|
-
}
|
158
|
-
_ => panic!("Wrong schema type"),
|
159
|
-
};
|
160
|
-
assert!(
|
161
|
-
result.is_err(),
|
162
|
-
"Should fail when joint count doesn't match"
|
163
|
-
);
|
164
|
-
|
165
|
-
// Case 2: Wrong joint names
|
166
|
-
let value_wrong_names = JointPositionsValue {
|
167
|
-
values: vec![
|
168
|
-
JointPositionValue {
|
169
|
-
joint_name: "wrong_joint_1".to_string(),
|
170
|
-
value: 60.0,
|
171
|
-
unit: JointPositionUnit::Degrees as i32,
|
172
|
-
},
|
173
|
-
JointPositionValue {
|
174
|
-
joint_name: "wrong_joint_2".to_string(),
|
175
|
-
value: 30.0,
|
176
|
-
unit: JointPositionUnit::Degrees as i32,
|
177
|
-
},
|
178
|
-
],
|
179
|
-
};
|
180
|
-
|
181
|
-
let result = match schema.value_type.as_ref().unwrap() {
|
182
|
-
P::proto::value_schema::ValueType::JointPositions(schema) => {
|
183
|
-
serializer.serialize_joint_positions(schema, value_wrong_names)
|
184
|
-
}
|
185
|
-
_ => panic!("Wrong schema type"),
|
186
|
-
};
|
187
|
-
assert!(result.is_err(), "Should fail when joint names don't match");
|
188
|
-
|
189
|
-
// Case 3: Wrong unit type
|
190
|
-
let value_wrong_unit = JointPositionsValue {
|
191
|
-
values: vec![
|
192
|
-
JointPositionValue {
|
193
|
-
joint_name: "joint_1".to_string(),
|
194
|
-
value: 60.0,
|
195
|
-
unit: JointPositionUnit::Radians as i32,
|
196
|
-
},
|
197
|
-
JointPositionValue {
|
198
|
-
joint_name: "joint_2".to_string(),
|
199
|
-
value: 30.0,
|
200
|
-
unit: JointPositionUnit::Radians as i32,
|
201
|
-
},
|
202
|
-
],
|
203
|
-
};
|
204
|
-
|
205
|
-
let result = match schema.value_type.as_ref().unwrap() {
|
206
|
-
P::proto::value_schema::ValueType::JointPositions(schema) => {
|
207
|
-
serializer.serialize_joint_positions(schema, value_wrong_unit)
|
208
|
-
}
|
209
|
-
_ => panic!("Wrong schema type"),
|
210
|
-
};
|
211
|
-
assert!(result.is_err(), "Should fail when units don't match");
|
212
|
-
}
|
kinfer/serialize/__init__.py
DELETED
@@ -1,60 +0,0 @@
|
|
1
|
-
"""Defines an interface for instantiating serializers."""
|
2
|
-
|
3
|
-
from typing import Literal, overload
|
4
|
-
|
5
|
-
from kinfer import proto as K
|
6
|
-
|
7
|
-
from .base import MultiSerializer, Serializer
|
8
|
-
from .json import JsonMultiSerializer, JsonSerializer
|
9
|
-
from .numpy import NumpyMultiSerializer, NumpySerializer
|
10
|
-
from .pytorch import PyTorchMultiSerializer, PyTorchSerializer
|
11
|
-
|
12
|
-
SerializerType = Literal["json", "numpy", "pytorch"]
|
13
|
-
|
14
|
-
|
15
|
-
@overload
|
16
|
-
def get_serializer(schema: K.ValueSchema, serializer_type: Literal["json"]) -> JsonSerializer: ...
|
17
|
-
|
18
|
-
|
19
|
-
@overload
|
20
|
-
def get_serializer(schema: K.ValueSchema, serializer_type: Literal["numpy"]) -> NumpySerializer: ...
|
21
|
-
|
22
|
-
|
23
|
-
@overload
|
24
|
-
def get_serializer(schema: K.ValueSchema, serializer_type: Literal["pytorch"]) -> PyTorchSerializer: ...
|
25
|
-
|
26
|
-
|
27
|
-
def get_serializer(schema: K.ValueSchema, serializer_type: SerializerType) -> Serializer:
|
28
|
-
match serializer_type:
|
29
|
-
case "json":
|
30
|
-
return JsonSerializer(schema=schema)
|
31
|
-
case "numpy":
|
32
|
-
return NumpySerializer(schema=schema)
|
33
|
-
case "pytorch":
|
34
|
-
return PyTorchSerializer(schema=schema)
|
35
|
-
case _:
|
36
|
-
raise ValueError(f"Unsupported serializer type: {serializer_type}")
|
37
|
-
|
38
|
-
|
39
|
-
@overload
|
40
|
-
def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["json"]) -> JsonMultiSerializer: ...
|
41
|
-
|
42
|
-
|
43
|
-
@overload
|
44
|
-
def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["numpy"]) -> NumpyMultiSerializer: ...
|
45
|
-
|
46
|
-
|
47
|
-
@overload
|
48
|
-
def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["pytorch"]) -> PyTorchMultiSerializer: ...
|
49
|
-
|
50
|
-
|
51
|
-
def get_multi_serializer(schema: K.IOSchema, serializer_type: SerializerType) -> MultiSerializer:
|
52
|
-
match serializer_type:
|
53
|
-
case "json":
|
54
|
-
return JsonMultiSerializer(schema=schema)
|
55
|
-
case "numpy":
|
56
|
-
return NumpyMultiSerializer(schema=schema)
|
57
|
-
case "pytorch":
|
58
|
-
return PyTorchMultiSerializer(schema=schema)
|
59
|
-
case _:
|
60
|
-
raise ValueError(f"Unsupported serializer type: {serializer_type}")
|