kinfer 0.3.3-cp312-cp312-macosx_11_0_arm64.whl → 0.4.0-cp312-cp312-macosx_11_0_arm64.whl
This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
- kinfer/__init__.py +0 -5
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +11 -0
- kinfer/export/common.py +35 -0
- kinfer/export/jax.py +51 -0
- kinfer/export/pytorch.py +42 -110
- kinfer/export/serialize.py +86 -0
- kinfer/requirements.txt +3 -4
- kinfer/rust/Cargo.toml +8 -6
- kinfer/rust/src/lib.rs +2 -11
- kinfer/rust/src/model.rs +271 -121
- kinfer/rust/src/runtime.rs +104 -0
- kinfer/rust_bindings/Cargo.toml +8 -1
- kinfer/rust_bindings/rust_bindings.pyi +35 -0
- kinfer/rust_bindings/src/lib.rs +310 -1
- kinfer/rust_bindings.cpython-312-darwin.so +0 -0
- kinfer/rust_bindings.pyi +29 -1
- kinfer-0.4.0.dist-info/METADATA +55 -0
- kinfer-0.4.0.dist-info/RECORD +26 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
- kinfer/inference/__init__.py +0 -2
- kinfer/inference/base.py +0 -64
- kinfer/inference/python.py +0 -66
- kinfer/proto/__init__.py +0 -40
- kinfer/proto/kinfer_pb2.py +0 -103
- kinfer/proto/kinfer_pb2.pyi +0 -1097
- kinfer/requirements-dev.txt +0 -8
- kinfer/rust/build.rs +0 -16
- kinfer/rust/src/kinfer_proto.rs +0 -14
- kinfer/rust/src/main.rs +0 -6
- kinfer/rust/src/onnx_serializer.rs +0 -804
- kinfer/rust/src/serializer.rs +0 -221
- kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
- kinfer/serialize/__init__.py +0 -60
- kinfer/serialize/base.py +0 -536
- kinfer/serialize/json.py +0 -399
- kinfer/serialize/numpy.py +0 -426
- kinfer/serialize/pytorch.py +0 -402
- kinfer/serialize/schema.py +0 -125
- kinfer/serialize/types.py +0 -17
- kinfer/serialize/utils.py +0 -177
- kinfer-0.3.3.dist-info/METADATA +0 -57
- kinfer-0.3.3.dist-info/RECORD +0 -40
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
kinfer/rust/src/onnx_serializer.rs
@@ -1,804 +0,0 @@
-use crate::serializer::{
-    convert_position, convert_torque, convert_velocity, AudioFrameSerializer,
-    CameraFrameSerializer, ImuSerializer, JointCommandsSerializer, JointPositionsSerializer,
-    JointTorquesSerializer, JointVelocitiesSerializer, Serializer, StateTensorSerializer,
-    TimestampSerializer, VectorCommandSerializer,
-};
-
-use ndarray::{s, Array, Array1, Array2, Array3, ArrayView, ArrayView1, ArrayView2};
-use ort::value::{Tensor, Value as OrtValue};
-use std::error::Error;
-
-// Import the re-exported types
-use crate::kinfer_proto::{
-    AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, DType,
-    ImuAccelerometerValue, ImuGyroscopeValue, ImuMagnetometerValue, ImuSchema, ImuValue,
-    JointCommandValue, JointCommandsSchema, JointCommandsValue, JointPositionUnit,
-    JointPositionValue, JointPositionsSchema, JointPositionsValue, JointTorqueUnit,
-    JointTorqueValue, JointTorquesSchema, JointTorquesValue, JointVelocitiesSchema,
-    JointVelocitiesValue, JointVelocityUnit, JointVelocityValue, ProtoIO, ProtoIOSchema,
-    ProtoValue, StateTensorSchema, StateTensorValue, TimestampSchema, TimestampValue, ValueSchema,
-    VectorCommandSchema, VectorCommandValue,
-};
-
-// Import the nested types
-use crate::kinfer_proto::proto::value::Value as EnumValue;
-use crate::kinfer_proto::proto::value_schema::ValueType;
-
-pub struct OnnxSerializer {
-    schema: ValueSchema,
-}
-
-impl OnnxSerializer {
-    pub fn new(schema: ValueSchema) -> Self {
-        Self { schema }
-    }
-
-    fn array_to_value<T, D>(&self, array: Array<T, D>) -> Result<OrtValue, Box<dyn Error>>
-    where
-        T: Into<f32> + Copy,
-        D: ndarray::Dimension,
-    {
-        let array = array.map(|&x| x.into());
-        Tensor::from_array(array.into_dyn())
-            .map(|tensor| tensor.into_dyn())
-            .map_err(|e| Box::new(e) as Box<dyn Error>)
-    }
-}
-
-impl JointPositionsSerializer for OnnxSerializer {
-    fn serialize_joint_positions(
-        &self,
-        schema: &JointPositionsSchema,
-        value: JointPositionsValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let mut array = Array1::zeros(schema.joint_names.len());
-        for (i, name) in schema.joint_names.iter().enumerate() {
-            if let Some(joint) = value.values.iter().find(|v| v.joint_name == *name) {
-                let from_unit = JointPositionUnit::try_from(joint.unit)?;
-                let to_unit = JointPositionUnit::try_from(schema.unit)?;
-                array[i] = convert_position(joint.value, from_unit, to_unit)?;
-            }
-        }
-        self.array_to_value(array)
-    }
-
-    fn deserialize_joint_positions(
-        &self,
-        schema: &JointPositionsSchema,
-        value: OrtValue,
-    ) -> Result<JointPositionsValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        if array.len() != schema.joint_names.len() {
-            return Err("Array length does not match number of joints".into());
-        }
-
-        Ok(JointPositionsValue {
-            values: schema
-                .joint_names
-                .iter()
-                .enumerate()
-                .map(|(i, name)| JointPositionValue {
-                    joint_name: name.clone(),
-                    value: array[i],
-                    unit: schema.unit,
-                })
-                .collect(),
-        })
-    }
-}
-
-impl JointVelocitiesSerializer for OnnxSerializer {
-    fn serialize_joint_velocities(
-        &self,
-        schema: &JointVelocitiesSchema,
-        value: JointVelocitiesValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let mut array = Array1::zeros(schema.joint_names.len());
-        for (i, name) in schema.joint_names.iter().enumerate() {
-            if let Some(joint) = value.values.iter().find(|v| v.joint_name == *name) {
-                let from_unit = JointVelocityUnit::try_from(joint.unit)?;
-                let to_unit = JointVelocityUnit::try_from(schema.unit)?;
-                array[i] = convert_velocity(joint.value, from_unit, to_unit)?;
-            }
-        }
-        self.array_to_value(array)
-    }
-
-    fn deserialize_joint_velocities(
-        &self,
-        schema: &JointVelocitiesSchema,
-        value: OrtValue,
-    ) -> Result<JointVelocitiesValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        if array.len() != schema.joint_names.len() {
-            return Err("Array length does not match number of joints".into());
-        }
-
-        Ok(JointVelocitiesValue {
-            values: schema
-                .joint_names
-                .iter()
-                .enumerate()
-                .map(|(i, name)| JointVelocityValue {
-                    joint_name: name.clone(),
-                    value: array[i],
-                    unit: schema.unit.clone(),
-                })
-                .collect(),
-        })
-    }
-}
-
-impl JointTorquesSerializer for OnnxSerializer {
-    fn serialize_joint_torques(
-        &self,
-        schema: &JointTorquesSchema,
-        value: JointTorquesValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let mut array = Array1::zeros(schema.joint_names.len());
-        for (i, name) in schema.joint_names.iter().enumerate() {
-            if let Some(joint) = value.values.iter().find(|v| v.joint_name == *name) {
-                let from_unit = JointTorqueUnit::try_from(joint.unit)?;
-                let to_unit = JointTorqueUnit::try_from(schema.unit)?;
-                array[i] = convert_torque(joint.value, from_unit, to_unit)?;
-            }
-        }
-        self.array_to_value(array)
-    }
-
-    fn deserialize_joint_torques(
-        &self,
-        schema: &JointTorquesSchema,
-        value: OrtValue,
-    ) -> Result<JointTorquesValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        if array.len() != schema.joint_names.len() {
-            return Err("Array length does not match number of joints".into());
-        }
-
-        Ok(JointTorquesValue {
-            values: schema
-                .joint_names
-                .iter()
-                .enumerate()
-                .map(|(i, name)| JointTorqueValue {
-                    joint_name: name.clone(),
-                    value: array[i],
-                    unit: schema.unit,
-                })
-                .collect(),
-        })
-    }
-}
-
-impl JointCommandsSerializer for OnnxSerializer {
-    fn serialize_joint_commands(
-        &self,
-        schema: &JointCommandsSchema,
-        value: JointCommandsValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let mut array = Array2::zeros((schema.joint_names.len(), 5));
-        for (i, name) in schema.joint_names.iter().enumerate() {
-            if let Some(cmd) = value.values.iter().find(|v| v.joint_name == *name) {
-                let cmd_torque_unit = JointTorqueUnit::try_from(cmd.torque_unit)?;
-                let cmd_velocity_unit = JointVelocityUnit::try_from(cmd.velocity_unit)?;
-                let cmd_position_unit = JointPositionUnit::try_from(cmd.position_unit)?;
-                let schema_torque_unit = JointTorqueUnit::try_from(schema.torque_unit)?;
-                let schema_velocity_unit = JointVelocityUnit::try_from(schema.velocity_unit)?;
-                let schema_position_unit = JointPositionUnit::try_from(schema.position_unit)?;
-                array[[i, 0]] = convert_torque(cmd.torque, cmd_torque_unit, schema_torque_unit)?;
-                array[[i, 1]] =
-                    convert_velocity(cmd.velocity, cmd_velocity_unit, schema_velocity_unit)?;
-                array[[i, 2]] =
-                    convert_position(cmd.position, cmd_position_unit, schema_position_unit)?;
-                array[[i, 3]] = cmd.kp;
-                array[[i, 4]] = cmd.kd;
-            }
-        }
-        self.array_to_value(array)
-    }
-
-    fn deserialize_joint_commands(
-        &self,
-        schema: &JointCommandsSchema,
-        value: OrtValue,
-    ) -> Result<JointCommandsValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        if array.shape() != [schema.joint_names.len(), 5] {
-            return Err("Array shape does not match expected dimensions".into());
-        }
-
-        Ok(JointCommandsValue {
-            values: schema
-                .joint_names
-                .iter()
-                .enumerate()
-                .map(|(i, name)| JointCommandValue {
-                    joint_name: name.clone(),
-                    torque: array[[i, 0]],
-                    velocity: array[[i, 1]],
-                    position: array[[i, 2]],
-                    kp: array[[i, 3]],
-                    kd: array[[i, 4]],
-                    torque_unit: schema.torque_unit,
-                    velocity_unit: schema.velocity_unit,
-                    position_unit: schema.position_unit,
-                })
-                .collect(),
-        })
-    }
-}
-
-impl CameraFrameSerializer for OnnxSerializer {
-    fn serialize_camera_frame(
-        &self,
-        schema: &CameraFrameSchema,
-        value: CameraFrameValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let bytes = value.data;
-        let array = Array3::from_shape_vec(
-            (
-                schema.channels as usize,
-                schema.height as usize,
-                schema.width as usize,
-            ),
-            bytes.iter().map(|&x| x as f32 / 255.0).collect(),
-        )?;
-        self.array_to_value(array)
-    }
-
-    fn deserialize_camera_frame(
-        &self,
-        schema: &CameraFrameSchema,
-        value: OrtValue,
-    ) -> Result<CameraFrameValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        if array.shape()
-            != [
-                schema.channels as usize,
-                schema.height as usize,
-                schema.width as usize,
-            ]
-        {
-            return Err("Array shape does not match expected dimensions".into());
-        }
-
-        let bytes: Vec<u8> = array
-            .iter()
-            .map(|&x: &f32| (x * 255.0).clamp(0.0, 255.0) as u8)
-            .collect();
-
-        Ok(CameraFrameValue { data: bytes })
-    }
-}
-
-impl AudioFrameSerializer for OnnxSerializer {
-    fn serialize_audio_frame(
-        &self,
-        schema: &AudioFrameSchema,
-        value: AudioFrameValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let array = Array2::from_shape_vec(
-            (schema.channels as usize, schema.sample_rate as usize),
-            parse_audio_bytes(&value.data, schema.dtype.try_into()?)?,
-        )?;
-        self.array_to_value(array)
-    }
-
-    fn deserialize_audio_frame(
-        &self,
-        schema: &AudioFrameSchema,
-        value: OrtValue,
-    ) -> Result<AudioFrameValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        if array.shape() != [schema.channels as usize, schema.sample_rate as usize] {
-            return Err("Array shape does not match expected dimensions".into());
-        }
-
-        let array = array.into_dimensionality::<ndarray::Ix2>()?;
-
-        Ok(AudioFrameValue {
-            data: audio_array_to_bytes(array, schema.dtype.try_into()?)?,
-        })
-    }
-}
-
-impl ImuSerializer for OnnxSerializer {
-    fn serialize_imu(
-        &self,
-        schema: &ImuSchema,
-        value: ImuValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let mut vectors = Vec::new();
-
-        if schema.use_accelerometer {
-            if let Some(acc) = &value.linear_acceleration {
-                vectors.push([acc.x, acc.y, acc.z]);
-            }
-        }
-        if schema.use_gyroscope {
-            if let Some(gyro) = &value.angular_velocity {
-                vectors.push([gyro.x, gyro.y, gyro.z]);
-            }
-        }
-        if schema.use_magnetometer {
-            if let Some(mag) = &value.magnetic_field {
-                vectors.push([mag.x, mag.y, mag.z]);
-            }
-        }
-
-        let array = Array2::from_shape_vec(
-            (vectors.len(), 3),
-            vectors.into_iter().flat_map(|v| v.into_iter()).collect(),
-        )?;
-        self.array_to_value(array)
-    }
-
-    fn deserialize_imu(
-        &self,
-        schema: &ImuSchema,
-        value: OrtValue,
-    ) -> Result<ImuValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-        let mut result = ImuValue::default();
-        let mut idx = 0;
-
-        if schema.use_accelerometer {
-            result.linear_acceleration = Some(ImuAccelerometerValue {
-                x: array[[idx, 0]],
-                y: array[[idx, 1]],
-                z: array[[idx, 2]],
-            });
-            idx += 1;
-        }
-        if schema.use_gyroscope {
-            result.angular_velocity = Some(ImuGyroscopeValue {
-                x: array[[idx, 0]],
-                y: array[[idx, 1]],
-                z: array[[idx, 2]],
-            });
-            idx += 1;
-        }
-        if schema.use_magnetometer {
-            result.magnetic_field = Some(ImuMagnetometerValue {
-                x: array[[idx, 0]],
-                y: array[[idx, 1]],
-                z: array[[idx, 2]],
-            });
-        }
-
-        Ok(result)
-    }
-}
-
-impl TimestampSerializer for OnnxSerializer {
-    fn serialize_timestamp(
-        &self,
-        schema: &TimestampSchema,
-        value: TimestampValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let elapsed_seconds = value.seconds - schema.start_seconds;
-        let elapsed_nanos = value.nanos - schema.start_nanos;
-        let total_seconds = elapsed_seconds as f32 + (elapsed_nanos as f32 / 1_000_000_000.0);
-        self.array_to_value(Array1::from_vec(vec![total_seconds]))
-    }
-
-    fn deserialize_timestamp(
-        &self,
-        schema: &TimestampSchema,
-        value: OrtValue,
-    ) -> Result<TimestampValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        // Get first element using iterator
-        let total_seconds: f32 = *array.iter().next().ok_or("Timestamp tensor is empty")?;
-
-        let elapsed_seconds = total_seconds.trunc() as i64;
-        let elapsed_nanos = ((total_seconds.fract() * 1_000_000_000.0).round()) as i32;
-
-        Ok(TimestampValue {
-            seconds: schema.start_seconds + elapsed_seconds,
-            nanos: schema.start_nanos + elapsed_nanos,
-        })
-    }
-}
-
-impl VectorCommandSerializer for OnnxSerializer {
-    fn serialize_vector_command(
-        &self,
-        schema: &VectorCommandSchema,
-        value: VectorCommandValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let array = Array1::from_vec(value.values);
-        self.array_to_value(array)
-    }
-
-    fn deserialize_vector_command(
-        &self,
-        schema: &VectorCommandSchema,
-        value: OrtValue,
-    ) -> Result<VectorCommandValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        if array.len() != schema.dimensions as usize {
-            return Err("Array length does not match expected dimensions".into());
-        }
-
-        Ok(VectorCommandValue {
-            values: array.iter().copied().collect(),
-        })
-    }
-}
-
-impl StateTensorSerializer for OnnxSerializer {
-    fn serialize_state_tensor(
-        &self,
-        schema: &StateTensorSchema,
-        value: StateTensorValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        let shape: Vec<usize> = schema.shape.iter().map(|&x| x as usize).collect();
-        let array = Array::from_shape_vec(
-            shape,
-            parse_tensor_bytes(&value.data, schema.dtype.try_into()?)?,
-        )?;
-        self.array_to_value(array)
-    }
-
-    fn deserialize_state_tensor(
-        &self,
-        schema: &StateTensorSchema,
-        value: OrtValue,
-    ) -> Result<StateTensorValue, Box<dyn Error>> {
-        let tensor = value.try_extract_tensor()?;
-        let array = tensor.view();
-
-        let expected_shape: Vec<usize> = schema.shape.iter().map(|&x| x as usize).collect();
-        if array.shape() != expected_shape.as_slice() {
-            return Err("Array shape does not match expected dimensions".into());
-        }
-
-        Ok(StateTensorValue {
-            data: tensor_array_to_bytes(array.view(), schema.dtype.try_into()?)?,
-        })
-    }
-}
-
-// Helper functions for parsing bytes
-fn parse_audio_bytes(bytes: &[u8], dtype: DType) -> Result<Vec<f32>, Box<dyn Error>> {
-    match dtype {
-        DType::Fp32 => {
-            let mut result = Vec::with_capacity(bytes.len() / 4);
-            for chunk in bytes.chunks_exact(4) {
-                let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
-                result.push(value);
-            }
-            Ok(result)
-        }
-        _ => Err("Unsupported audio data type".into()),
-    }
-}
-
-fn audio_array_to_bytes(array: ArrayView2<f32>, dtype: DType) -> Result<Vec<u8>, Box<dyn Error>> {
-    match dtype {
-        DType::Fp32 => {
-            let mut result = Vec::with_capacity(array.len() * 4);
-            for &value in array.iter() {
-                result.extend_from_slice(&value.to_le_bytes());
-            }
-            Ok(result)
-        }
-        _ => Err("Unsupported audio data type".into()),
-    }
-}
-
-fn parse_tensor_bytes(bytes: &[u8], dtype: DType) -> Result<Vec<f32>, Box<dyn Error>> {
-    match dtype {
-        DType::Fp32 => {
-            let mut result = Vec::with_capacity(bytes.len() / 4);
-            for chunk in bytes.chunks_exact(4) {
-                let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
-                result.push(value);
-            }
-            Ok(result)
-        }
-        _ => Err("Unsupported tensor data type".into()),
-    }
-}
-
-fn tensor_array_to_bytes(
-    array: ArrayView<f32, ndarray::IxDyn>,
-    dtype: DType,
-) -> Result<Vec<u8>, Box<dyn Error>> {
-    match dtype {
-        DType::Fp32 => {
-            let mut result = Vec::with_capacity(array.len() * 4);
-            for &value in array.iter() {
-                result.extend_from_slice(&value.to_le_bytes());
-            }
-            Ok(result)
-        }
-        _ => Err("Unsupported tensor data type".into()),
-    }
-}
-
-impl Serializer for OnnxSerializer {
-    fn serialize(
-        &self,
-        schema: &ValueSchema,
-        value: ProtoValue,
-    ) -> Result<OrtValue, Box<dyn Error>> {
-        match schema.value_type.as_ref().ok_or("Missing value type")? {
-            ValueType::JointPositions(ref joint_positions_schema) => match value.value {
-                Some(EnumValue::JointPositions(values)) => {
-                    self.serialize_joint_positions(joint_positions_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::JointVelocities(ref joint_velocities_schema) => match value.value {
-                Some(EnumValue::JointVelocities(values)) => {
-                    self.serialize_joint_velocities(joint_velocities_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::JointTorques(ref joint_torques_schema) => match value.value {
-                Some(EnumValue::JointTorques(values)) => {
-                    self.serialize_joint_torques(joint_torques_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::JointCommands(ref joint_commands_schema) => match value.value {
-                Some(EnumValue::JointCommands(values)) => {
-                    self.serialize_joint_commands(joint_commands_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::CameraFrame(ref camera_frame_schema) => match value.value {
-                Some(EnumValue::CameraFrame(values)) => {
-                    self.serialize_camera_frame(camera_frame_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::AudioFrame(ref audio_frame_schema) => match value.value {
-                Some(EnumValue::AudioFrame(values)) => {
-                    self.serialize_audio_frame(audio_frame_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::Imu(ref imu_schema) => match value.value {
-                Some(EnumValue::Imu(values)) => self.serialize_imu(imu_schema, values),
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::Timestamp(ref timestamp_schema) => match value.value {
-                Some(EnumValue::Timestamp(values)) => {
-                    self.serialize_timestamp(timestamp_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::VectorCommand(ref vector_command_schema) => match value.value {
-                Some(EnumValue::VectorCommand(values)) => {
-                    self.serialize_vector_command(vector_command_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-            ValueType::StateTensor(ref state_tensor_schema) => match value.value {
-                Some(EnumValue::StateTensor(values)) => {
-                    self.serialize_state_tensor(state_tensor_schema, values)
-                }
-                _ => Err("Unsupported value type".into()),
-            },
-        }
-    }
-
-    fn deserialize(
-        &self,
-        schema: &ValueSchema,
-        value: OrtValue,
-    ) -> Result<ProtoValue, Box<dyn Error>> {
-        match schema.value_type.as_ref().ok_or("Missing value type")? {
-            ValueType::JointPositions(ref joint_positions_schema) => {
-                let positions = self.deserialize_joint_positions(joint_positions_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::JointPositions(positions)),
-                })
-            }
-            ValueType::JointVelocities(ref joint_velocities_schema) => {
-                let velocities =
-                    self.deserialize_joint_velocities(joint_velocities_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::JointVelocities(velocities)),
-                })
-            }
-            ValueType::JointTorques(ref joint_torques_schema) => {
-                let torques = self.deserialize_joint_torques(joint_torques_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::JointTorques(torques)),
-                })
-            }
-            ValueType::JointCommands(ref joint_commands_schema) => {
-                let commands = self.deserialize_joint_commands(joint_commands_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::JointCommands(commands)),
-                })
-            }
-            ValueType::CameraFrame(ref camera_frame_schema) => {
-                let frame = self.deserialize_camera_frame(camera_frame_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::CameraFrame(frame)),
-                })
-            }
-            ValueType::AudioFrame(ref audio_frame_schema) => {
-                let frame = self.deserialize_audio_frame(audio_frame_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::AudioFrame(frame)),
-                })
-            }
-            ValueType::Imu(ref imu_schema) => {
-                let imu = self.deserialize_imu(imu_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::Imu(imu)),
-                })
-            }
-            ValueType::Timestamp(ref timestamp_schema) => {
-                let timestamp = self.deserialize_timestamp(timestamp_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::Timestamp(timestamp)),
-                })
-            }
-            ValueType::VectorCommand(ref vector_command_schema) => {
-                let command = self.deserialize_vector_command(vector_command_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::VectorCommand(command)),
-                })
-            }
-            ValueType::StateTensor(ref state_tensor_schema) => {
-                let tensor = self.deserialize_state_tensor(state_tensor_schema, value)?;
-                Ok(ProtoValue {
-                    value_name: schema.value_name.clone(),
-                    value: Some(EnumValue::StateTensor(tensor)),
-                })
-            }
-        }
-    }
-}
-
-fn calculate_value_size(schema: &ValueSchema) -> Result<usize, Box<dyn Error>> {
-    match schema.value_type.as_ref().ok_or("Missing value type")? {
-        ValueType::JointPositions(s) => Ok(s.joint_names.len()),
-        ValueType::JointVelocities(s) => Ok(s.joint_names.len()),
-        ValueType::JointTorques(s) => Ok(s.joint_names.len()),
-        ValueType::JointCommands(s) => Ok(s.joint_names.len() * 5), // 5 values per joint
-        ValueType::CameraFrame(s) => Ok((s.channels * s.height * s.width) as usize),
-        ValueType::AudioFrame(s) => Ok((s.channels * s.sample_rate) as usize),
-        ValueType::Imu(s) => {
-            let mut size = 0;
-            if s.use_accelerometer {
-                size += 3;
-            }
-            if s.use_gyroscope {
-                size += 3;
-            }
-            if s.use_magnetometer {
-                size += 3;
-            }
-            Ok(size)
-        }
-        ValueType::Timestamp(_) => Ok(1),
-        ValueType::VectorCommand(s) => Ok(s.dimensions as usize),
-        ValueType::StateTensor(s) => Ok(s.shape.iter().product::<i32>() as usize),
-    }
-}
-
-pub struct OnnxMultiSerializer {
-    serializers: Vec<OnnxSerializer>,
-}
-
-impl OnnxMultiSerializer {
-    pub fn new(schema: ProtoIOSchema) -> Self {
-        Self {
-            serializers: schema
-                .values
-                .into_iter()
-                .map(|s| OnnxSerializer::new(s))
-                .collect(),
-        }
-    }
-
-    pub fn serialize_io(&self, io: ProtoIO) -> Result<OrtValue, Box<dyn Error>> {
-        if io.values.len() != self.serializers.len() {
-            return Err("Number of values does not match schema".into());
-        }
-
-        // Serialize each value according to its schema and concatenate the results
-        let mut all_values: Vec<f32> = Vec::new();
-        for (value, serializer) in io.values.iter().zip(self.serializers.iter()) {
-            let tensor = serializer.serialize(&serializer.schema, value.clone())?;
-            let array = tensor.try_extract_tensor::<f32>()?;
-            let array_1d = array
-                .as_standard_layout()
-                .into_dimensionality::<ndarray::Ix1>()?;
-            all_values.extend(array_1d.iter().copied());
-        }
-
-        // Convert to OrtValue
-        Tensor::from_array(Array1::from_vec(all_values))
-            .map(|tensor| tensor.into_dyn())
-            .map_err(|e| Box::new(e) as Box<dyn Error>)
-    }
-
-    pub fn deserialize_io(&self, values: Vec<OrtValue>) -> Result<ProtoIO, Box<dyn Error>> {
-        // Check if number of values matches number of serializers
-        if values.len() != self.serializers.len() {
-            return Err(format!(
-                "Number of values ({}) does not match number of serializers ({})",
-                values.len(),
-                self.serializers.len()
-            )
-            .into());
-        }
-
-        // Deserialize each value using its corresponding serializer
-        let proto_values = self
-            .serializers
-            .iter()
-            .zip(values.into_iter())
-            .map(|(serializer, value)| serializer.deserialize(&serializer.schema, value))
-            .collect::<Result<Vec<_>, _>>()?;
-
-        Ok(ProtoIO {
-            values: proto_values,
-        })
-    }
-
-    pub fn names(&self) -> Vec<String> {
-        self.serializers
-            .iter()
-            .map(|s| s.schema.value_name.clone())
-            .collect()
-    }
-
-    pub fn assign_names(
-        &self,
-        values: Vec<OrtValue>,
-    ) -> Result<std::collections::HashMap<String, OrtValue>, Box<dyn Error>> {
-        if values.len() != self.serializers.len() {
-            return Err(format!(
-                "Expected {} values, got {}",
-                self.serializers.len(),
-                values.len()
-            )
-            .into());
-        }
-
-        Ok(self
-            .serializers
-            .iter()
-            .map(|s| s.schema.value_name.clone())
-            .zip(values)
-            .collect())
-    }
-}
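For orientation, here is a brief usage sketch of the removed 0.3.3 serialization layer. It is not code from either package version: the step_policy function and the run_model callback are hypothetical, and the module paths are assumed from the file layout shown in this diff. It only illustrates how OnnxMultiSerializer tied the pieces together: per-value serializers built from a ProtoIOSchema, typed inputs flattened into a single ONNX tensor, and session outputs mapped back into typed ProtoValue entries. Judging by the file list, this role appears to move to kinfer/export/serialize.py and kinfer/rust/src/runtime.rs in 0.4.0.

    // Hypothetical usage sketch (not part of the package), assuming the
    // crate's kinfer_proto and onnx_serializer modules are in scope.
    use std::error::Error;

    use crate::kinfer_proto::{ProtoIO, ProtoIOSchema};
    use crate::onnx_serializer::OnnxMultiSerializer;
    use ort::value::Value as OrtValue;

    fn step_policy(
        schema: ProtoIOSchema,
        inputs: ProtoIO,
        run_model: impl Fn(OrtValue) -> Result<Vec<OrtValue>, Box<dyn Error>>,
    ) -> Result<ProtoIO, Box<dyn Error>> {
        // One OnnxSerializer is created per ValueSchema in the ProtoIOSchema.
        let serializer = OnnxMultiSerializer::new(schema);

        // Flatten every typed value to f32 and concatenate into a single
        // 1-D ort::Value, the model input.
        let model_input = serializer.serialize_io(inputs)?;

        // The caller supplies the actual ONNX session invocation; it must
        // return one output tensor per entry in the output schema.
        let model_outputs = run_model(model_input)?;

        // Map each output tensor back into a typed ProtoValue.
        serializer.deserialize_io(model_outputs)
    }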