kinfer 0.3.2__cp312-cp312-macosx_11_0_arm64.whl → 0.4.0__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kinfer/__init__.py +0 -1
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +11 -0
  4. kinfer/export/__init__.py +0 -1
  5. kinfer/export/common.py +35 -0
  6. kinfer/export/jax.py +51 -0
  7. kinfer/export/pytorch.py +42 -110
  8. kinfer/export/serialize.py +86 -0
  9. kinfer/requirements.txt +3 -4
  10. kinfer/rust/Cargo.toml +8 -6
  11. kinfer/rust/src/lib.rs +2 -11
  12. kinfer/rust/src/model.rs +271 -121
  13. kinfer/rust/src/runtime.rs +104 -0
  14. kinfer/rust_bindings/Cargo.toml +8 -1
  15. kinfer/rust_bindings/rust_bindings.pyi +35 -0
  16. kinfer/rust_bindings/src/lib.rs +310 -1
  17. kinfer/rust_bindings.cpython-312-darwin.so +0 -0
  18. kinfer/rust_bindings.pyi +29 -1
  19. kinfer-0.4.0.dist-info/METADATA +55 -0
  20. kinfer-0.4.0.dist-info/RECORD +26 -0
  21. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
  22. kinfer/inference/__init__.py +0 -1
  23. kinfer/inference/python.py +0 -92
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -36
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.2.dist-info/METADATA +0 -57
  43. kinfer-0.3.2.dist-info/RECORD +0 -39
  44. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,804 +0,0 @@
1
- use crate::serializer::{
2
- convert_position, convert_torque, convert_velocity, AudioFrameSerializer,
3
- CameraFrameSerializer, ImuSerializer, JointCommandsSerializer, JointPositionsSerializer,
4
- JointTorquesSerializer, JointVelocitiesSerializer, Serializer, StateTensorSerializer,
5
- TimestampSerializer, VectorCommandSerializer,
6
- };
7
-
8
- use ndarray::{s, Array, Array1, Array2, Array3, ArrayView, ArrayView1, ArrayView2};
9
- use ort::value::{Tensor, Value as OrtValue};
10
- use std::error::Error;
11
-
12
- // Import the re-exported types
13
- use crate::kinfer_proto::{
14
- AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, DType,
15
- ImuAccelerometerValue, ImuGyroscopeValue, ImuMagnetometerValue, ImuSchema, ImuValue,
16
- JointCommandValue, JointCommandsSchema, JointCommandsValue, JointPositionUnit,
17
- JointPositionValue, JointPositionsSchema, JointPositionsValue, JointTorqueUnit,
18
- JointTorqueValue, JointTorquesSchema, JointTorquesValue, JointVelocitiesSchema,
19
- JointVelocitiesValue, JointVelocityUnit, JointVelocityValue, ProtoIO, ProtoIOSchema,
20
- ProtoValue, StateTensorSchema, StateTensorValue, TimestampSchema, TimestampValue, ValueSchema,
21
- VectorCommandSchema, VectorCommandValue,
22
- };
23
-
24
- // Import the nested types
25
- use crate::kinfer_proto::proto::value::Value as EnumValue;
26
- use crate::kinfer_proto::proto::value_schema::ValueType;
27
-
28
- pub struct OnnxSerializer {
29
- schema: ValueSchema,
30
- }
31
-
32
- impl OnnxSerializer {
33
- pub fn new(schema: ValueSchema) -> Self {
34
- Self { schema }
35
- }
36
-
37
- fn array_to_value<T, D>(&self, array: Array<T, D>) -> Result<OrtValue, Box<dyn Error>>
38
- where
39
- T: Into<f32> + Copy,
40
- D: ndarray::Dimension,
41
- {
42
- let array = array.map(|&x| x.into());
43
- Tensor::from_array(array.into_dyn())
44
- .map(|tensor| tensor.into_dyn())
45
- .map_err(|e| Box::new(e) as Box<dyn Error>)
46
- }
47
- }
48
-
49
- impl JointPositionsSerializer for OnnxSerializer {
50
- fn serialize_joint_positions(
51
- &self,
52
- schema: &JointPositionsSchema,
53
- value: JointPositionsValue,
54
- ) -> Result<OrtValue, Box<dyn Error>> {
55
- let mut array = Array1::zeros(schema.joint_names.len());
56
- for (i, name) in schema.joint_names.iter().enumerate() {
57
- if let Some(joint) = value.values.iter().find(|v| v.joint_name == *name) {
58
- let from_unit = JointPositionUnit::try_from(joint.unit)?;
59
- let to_unit = JointPositionUnit::try_from(schema.unit)?;
60
- array[i] = convert_position(joint.value, from_unit, to_unit)?;
61
- }
62
- }
63
- self.array_to_value(array)
64
- }
65
-
66
- fn deserialize_joint_positions(
67
- &self,
68
- schema: &JointPositionsSchema,
69
- value: OrtValue,
70
- ) -> Result<JointPositionsValue, Box<dyn Error>> {
71
- let tensor = value.try_extract_tensor()?;
72
- let array = tensor.view();
73
-
74
- if array.len() != schema.joint_names.len() {
75
- return Err("Array length does not match number of joints".into());
76
- }
77
-
78
- Ok(JointPositionsValue {
79
- values: schema
80
- .joint_names
81
- .iter()
82
- .enumerate()
83
- .map(|(i, name)| JointPositionValue {
84
- joint_name: name.clone(),
85
- value: array[i],
86
- unit: schema.unit,
87
- })
88
- .collect(),
89
- })
90
- }
91
- }
92
-
93
- impl JointVelocitiesSerializer for OnnxSerializer {
94
- fn serialize_joint_velocities(
95
- &self,
96
- schema: &JointVelocitiesSchema,
97
- value: JointVelocitiesValue,
98
- ) -> Result<OrtValue, Box<dyn Error>> {
99
- let mut array = Array1::zeros(schema.joint_names.len());
100
- for (i, name) in schema.joint_names.iter().enumerate() {
101
- if let Some(joint) = value.values.iter().find(|v| v.joint_name == *name) {
102
- let from_unit = JointVelocityUnit::try_from(joint.unit)?;
103
- let to_unit = JointVelocityUnit::try_from(schema.unit)?;
104
- array[i] = convert_velocity(joint.value, from_unit, to_unit)?;
105
- }
106
- }
107
- self.array_to_value(array)
108
- }
109
-
110
- fn deserialize_joint_velocities(
111
- &self,
112
- schema: &JointVelocitiesSchema,
113
- value: OrtValue,
114
- ) -> Result<JointVelocitiesValue, Box<dyn Error>> {
115
- let tensor = value.try_extract_tensor()?;
116
- let array = tensor.view();
117
-
118
- if array.len() != schema.joint_names.len() {
119
- return Err("Array length does not match number of joints".into());
120
- }
121
-
122
- Ok(JointVelocitiesValue {
123
- values: schema
124
- .joint_names
125
- .iter()
126
- .enumerate()
127
- .map(|(i, name)| JointVelocityValue {
128
- joint_name: name.clone(),
129
- value: array[i],
130
- unit: schema.unit.clone(),
131
- })
132
- .collect(),
133
- })
134
- }
135
- }
136
-
137
- impl JointTorquesSerializer for OnnxSerializer {
138
- fn serialize_joint_torques(
139
- &self,
140
- schema: &JointTorquesSchema,
141
- value: JointTorquesValue,
142
- ) -> Result<OrtValue, Box<dyn Error>> {
143
- let mut array = Array1::zeros(schema.joint_names.len());
144
- for (i, name) in schema.joint_names.iter().enumerate() {
145
- if let Some(joint) = value.values.iter().find(|v| v.joint_name == *name) {
146
- let from_unit = JointTorqueUnit::try_from(joint.unit)?;
147
- let to_unit = JointTorqueUnit::try_from(schema.unit)?;
148
- array[i] = convert_torque(joint.value, from_unit, to_unit)?;
149
- }
150
- }
151
- self.array_to_value(array)
152
- }
153
-
154
- fn deserialize_joint_torques(
155
- &self,
156
- schema: &JointTorquesSchema,
157
- value: OrtValue,
158
- ) -> Result<JointTorquesValue, Box<dyn Error>> {
159
- let tensor = value.try_extract_tensor()?;
160
- let array = tensor.view();
161
-
162
- if array.len() != schema.joint_names.len() {
163
- return Err("Array length does not match number of joints".into());
164
- }
165
-
166
- Ok(JointTorquesValue {
167
- values: schema
168
- .joint_names
169
- .iter()
170
- .enumerate()
171
- .map(|(i, name)| JointTorqueValue {
172
- joint_name: name.clone(),
173
- value: array[i],
174
- unit: schema.unit,
175
- })
176
- .collect(),
177
- })
178
- }
179
- }
180
-
181
- impl JointCommandsSerializer for OnnxSerializer {
182
- fn serialize_joint_commands(
183
- &self,
184
- schema: &JointCommandsSchema,
185
- value: JointCommandsValue,
186
- ) -> Result<OrtValue, Box<dyn Error>> {
187
- let mut array = Array2::zeros((schema.joint_names.len(), 5));
188
- for (i, name) in schema.joint_names.iter().enumerate() {
189
- if let Some(cmd) = value.values.iter().find(|v| v.joint_name == *name) {
190
- let cmd_torque_unit = JointTorqueUnit::try_from(cmd.torque_unit)?;
191
- let cmd_velocity_unit = JointVelocityUnit::try_from(cmd.velocity_unit)?;
192
- let cmd_position_unit = JointPositionUnit::try_from(cmd.position_unit)?;
193
- let schema_torque_unit = JointTorqueUnit::try_from(schema.torque_unit)?;
194
- let schema_velocity_unit = JointVelocityUnit::try_from(schema.velocity_unit)?;
195
- let schema_position_unit = JointPositionUnit::try_from(schema.position_unit)?;
196
- array[[i, 0]] = convert_torque(cmd.torque, cmd_torque_unit, schema_torque_unit)?;
197
- array[[i, 1]] =
198
- convert_velocity(cmd.velocity, cmd_velocity_unit, schema_velocity_unit)?;
199
- array[[i, 2]] =
200
- convert_position(cmd.position, cmd_position_unit, schema_position_unit)?;
201
- array[[i, 3]] = cmd.kp;
202
- array[[i, 4]] = cmd.kd;
203
- }
204
- }
205
- self.array_to_value(array)
206
- }
207
-
208
- fn deserialize_joint_commands(
209
- &self,
210
- schema: &JointCommandsSchema,
211
- value: OrtValue,
212
- ) -> Result<JointCommandsValue, Box<dyn Error>> {
213
- let tensor = value.try_extract_tensor()?;
214
- let array = tensor.view();
215
-
216
- if array.shape() != [schema.joint_names.len(), 5] {
217
- return Err("Array shape does not match expected dimensions".into());
218
- }
219
-
220
- Ok(JointCommandsValue {
221
- values: schema
222
- .joint_names
223
- .iter()
224
- .enumerate()
225
- .map(|(i, name)| JointCommandValue {
226
- joint_name: name.clone(),
227
- torque: array[[i, 0]],
228
- velocity: array[[i, 1]],
229
- position: array[[i, 2]],
230
- kp: array[[i, 3]],
231
- kd: array[[i, 4]],
232
- torque_unit: schema.torque_unit,
233
- velocity_unit: schema.velocity_unit,
234
- position_unit: schema.position_unit,
235
- })
236
- .collect(),
237
- })
238
- }
239
- }
240
-
241
- impl CameraFrameSerializer for OnnxSerializer {
242
- fn serialize_camera_frame(
243
- &self,
244
- schema: &CameraFrameSchema,
245
- value: CameraFrameValue,
246
- ) -> Result<OrtValue, Box<dyn Error>> {
247
- let bytes = value.data;
248
- let array = Array3::from_shape_vec(
249
- (
250
- schema.channels as usize,
251
- schema.height as usize,
252
- schema.width as usize,
253
- ),
254
- bytes.iter().map(|&x| x as f32 / 255.0).collect(),
255
- )?;
256
- self.array_to_value(array)
257
- }
258
-
259
- fn deserialize_camera_frame(
260
- &self,
261
- schema: &CameraFrameSchema,
262
- value: OrtValue,
263
- ) -> Result<CameraFrameValue, Box<dyn Error>> {
264
- let tensor = value.try_extract_tensor()?;
265
- let array = tensor.view();
266
-
267
- if array.shape()
268
- != [
269
- schema.channels as usize,
270
- schema.height as usize,
271
- schema.width as usize,
272
- ]
273
- {
274
- return Err("Array shape does not match expected dimensions".into());
275
- }
276
-
277
- let bytes: Vec<u8> = array
278
- .iter()
279
- .map(|&x: &f32| (x * 255.0).clamp(0.0, 255.0) as u8)
280
- .collect();
281
-
282
- Ok(CameraFrameValue { data: bytes })
283
- }
284
- }
285
-
286
- impl AudioFrameSerializer for OnnxSerializer {
287
- fn serialize_audio_frame(
288
- &self,
289
- schema: &AudioFrameSchema,
290
- value: AudioFrameValue,
291
- ) -> Result<OrtValue, Box<dyn Error>> {
292
- let array = Array2::from_shape_vec(
293
- (schema.channels as usize, schema.sample_rate as usize),
294
- parse_audio_bytes(&value.data, schema.dtype.try_into()?)?,
295
- )?;
296
- self.array_to_value(array)
297
- }
298
-
299
- fn deserialize_audio_frame(
300
- &self,
301
- schema: &AudioFrameSchema,
302
- value: OrtValue,
303
- ) -> Result<AudioFrameValue, Box<dyn Error>> {
304
- let tensor = value.try_extract_tensor()?;
305
- let array = tensor.view();
306
-
307
- if array.shape() != [schema.channels as usize, schema.sample_rate as usize] {
308
- return Err("Array shape does not match expected dimensions".into());
309
- }
310
-
311
- let array = array.into_dimensionality::<ndarray::Ix2>()?;
312
-
313
- Ok(AudioFrameValue {
314
- data: audio_array_to_bytes(array, schema.dtype.try_into()?)?,
315
- })
316
- }
317
- }
318
-
319
- impl ImuSerializer for OnnxSerializer {
320
- fn serialize_imu(
321
- &self,
322
- schema: &ImuSchema,
323
- value: ImuValue,
324
- ) -> Result<OrtValue, Box<dyn Error>> {
325
- let mut vectors = Vec::new();
326
-
327
- if schema.use_accelerometer {
328
- if let Some(acc) = &value.linear_acceleration {
329
- vectors.push([acc.x, acc.y, acc.z]);
330
- }
331
- }
332
- if schema.use_gyroscope {
333
- if let Some(gyro) = &value.angular_velocity {
334
- vectors.push([gyro.x, gyro.y, gyro.z]);
335
- }
336
- }
337
- if schema.use_magnetometer {
338
- if let Some(mag) = &value.magnetic_field {
339
- vectors.push([mag.x, mag.y, mag.z]);
340
- }
341
- }
342
-
343
- let array = Array2::from_shape_vec(
344
- (vectors.len(), 3),
345
- vectors.into_iter().flat_map(|v| v.into_iter()).collect(),
346
- )?;
347
- self.array_to_value(array)
348
- }
349
-
350
- fn deserialize_imu(
351
- &self,
352
- schema: &ImuSchema,
353
- value: OrtValue,
354
- ) -> Result<ImuValue, Box<dyn Error>> {
355
- let tensor = value.try_extract_tensor()?;
356
- let array = tensor.view();
357
- let mut result = ImuValue::default();
358
- let mut idx = 0;
359
-
360
- if schema.use_accelerometer {
361
- result.linear_acceleration = Some(ImuAccelerometerValue {
362
- x: array[[idx, 0]],
363
- y: array[[idx, 1]],
364
- z: array[[idx, 2]],
365
- });
366
- idx += 1;
367
- }
368
- if schema.use_gyroscope {
369
- result.angular_velocity = Some(ImuGyroscopeValue {
370
- x: array[[idx, 0]],
371
- y: array[[idx, 1]],
372
- z: array[[idx, 2]],
373
- });
374
- idx += 1;
375
- }
376
- if schema.use_magnetometer {
377
- result.magnetic_field = Some(ImuMagnetometerValue {
378
- x: array[[idx, 0]],
379
- y: array[[idx, 1]],
380
- z: array[[idx, 2]],
381
- });
382
- }
383
-
384
- Ok(result)
385
- }
386
- }
387
-
388
- impl TimestampSerializer for OnnxSerializer {
389
- fn serialize_timestamp(
390
- &self,
391
- schema: &TimestampSchema,
392
- value: TimestampValue,
393
- ) -> Result<OrtValue, Box<dyn Error>> {
394
- let elapsed_seconds = value.seconds - schema.start_seconds;
395
- let elapsed_nanos = value.nanos - schema.start_nanos;
396
- let total_seconds = elapsed_seconds as f32 + (elapsed_nanos as f32 / 1_000_000_000.0);
397
- self.array_to_value(Array1::from_vec(vec![total_seconds]))
398
- }
399
-
400
- fn deserialize_timestamp(
401
- &self,
402
- schema: &TimestampSchema,
403
- value: OrtValue,
404
- ) -> Result<TimestampValue, Box<dyn Error>> {
405
- let tensor = value.try_extract_tensor()?;
406
- let array = tensor.view();
407
-
408
- // Get first element using iterator
409
- let total_seconds: f32 = *array.iter().next().ok_or("Timestamp tensor is empty")?;
410
-
411
- let elapsed_seconds = total_seconds.trunc() as i64;
412
- let elapsed_nanos = ((total_seconds.fract() * 1_000_000_000.0).round()) as i32;
413
-
414
- Ok(TimestampValue {
415
- seconds: schema.start_seconds + elapsed_seconds,
416
- nanos: schema.start_nanos + elapsed_nanos,
417
- })
418
- }
419
- }
420
-
421
- impl VectorCommandSerializer for OnnxSerializer {
422
- fn serialize_vector_command(
423
- &self,
424
- schema: &VectorCommandSchema,
425
- value: VectorCommandValue,
426
- ) -> Result<OrtValue, Box<dyn Error>> {
427
- let array = Array1::from_vec(value.values);
428
- self.array_to_value(array)
429
- }
430
-
431
- fn deserialize_vector_command(
432
- &self,
433
- schema: &VectorCommandSchema,
434
- value: OrtValue,
435
- ) -> Result<VectorCommandValue, Box<dyn Error>> {
436
- let tensor = value.try_extract_tensor()?;
437
- let array = tensor.view();
438
-
439
- if array.len() != schema.dimensions as usize {
440
- return Err("Array length does not match expected dimensions".into());
441
- }
442
-
443
- Ok(VectorCommandValue {
444
- values: array.iter().copied().collect(),
445
- })
446
- }
447
- }
448
-
449
- impl StateTensorSerializer for OnnxSerializer {
450
- fn serialize_state_tensor(
451
- &self,
452
- schema: &StateTensorSchema,
453
- value: StateTensorValue,
454
- ) -> Result<OrtValue, Box<dyn Error>> {
455
- let shape: Vec<usize> = schema.shape.iter().map(|&x| x as usize).collect();
456
- let array = Array::from_shape_vec(
457
- shape,
458
- parse_tensor_bytes(&value.data, schema.dtype.try_into()?)?,
459
- )?;
460
- self.array_to_value(array)
461
- }
462
-
463
- fn deserialize_state_tensor(
464
- &self,
465
- schema: &StateTensorSchema,
466
- value: OrtValue,
467
- ) -> Result<StateTensorValue, Box<dyn Error>> {
468
- let tensor = value.try_extract_tensor()?;
469
- let array = tensor.view();
470
-
471
- let expected_shape: Vec<usize> = schema.shape.iter().map(|&x| x as usize).collect();
472
- if array.shape() != expected_shape.as_slice() {
473
- return Err("Array shape does not match expected dimensions".into());
474
- }
475
-
476
- Ok(StateTensorValue {
477
- data: tensor_array_to_bytes(array.view(), schema.dtype.try_into()?)?,
478
- })
479
- }
480
- }
481
-
482
- // Helper functions for converting between raw little-endian bytes and f32 values
483
- fn parse_audio_bytes(bytes: &[u8], dtype: DType) -> Result<Vec<f32>, Box<dyn Error>> {
484
- match dtype {
485
- DType::Fp32 => {
486
- let mut result = Vec::with_capacity(bytes.len() / 4);
487
- for chunk in bytes.chunks_exact(4) {
488
- let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
489
- result.push(value);
490
- }
491
- Ok(result)
492
- }
493
- _ => Err("Unsupported audio data type".into()),
494
- }
495
- }
496
-
497
- fn audio_array_to_bytes(array: ArrayView2<f32>, dtype: DType) -> Result<Vec<u8>, Box<dyn Error>> {
498
- match dtype {
499
- DType::Fp32 => {
500
- let mut result = Vec::with_capacity(array.len() * 4);
501
- for &value in array.iter() {
502
- result.extend_from_slice(&value.to_le_bytes());
503
- }
504
- Ok(result)
505
- }
506
- _ => Err("Unsupported audio data type".into()),
507
- }
508
- }
509
-
510
- fn parse_tensor_bytes(bytes: &[u8], dtype: DType) -> Result<Vec<f32>, Box<dyn Error>> {
511
- match dtype {
512
- DType::Fp32 => {
513
- let mut result = Vec::with_capacity(bytes.len() / 4);
514
- for chunk in bytes.chunks_exact(4) {
515
- let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
516
- result.push(value);
517
- }
518
- Ok(result)
519
- }
520
- _ => Err("Unsupported tensor data type".into()),
521
- }
522
- }
523
-
524
- fn tensor_array_to_bytes(
525
- array: ArrayView<f32, ndarray::IxDyn>,
526
- dtype: DType,
527
- ) -> Result<Vec<u8>, Box<dyn Error>> {
528
- match dtype {
529
- DType::Fp32 => {
530
- let mut result = Vec::with_capacity(array.len() * 4);
531
- for &value in array.iter() {
532
- result.extend_from_slice(&value.to_le_bytes());
533
- }
534
- Ok(result)
535
- }
536
- _ => Err("Unsupported tensor data type".into()),
537
- }
538
- }
539
-
540
- impl Serializer for OnnxSerializer {
541
- fn serialize(
542
- &self,
543
- schema: &ValueSchema,
544
- value: ProtoValue,
545
- ) -> Result<OrtValue, Box<dyn Error>> {
546
- match schema.value_type.as_ref().ok_or("Missing value type")? {
547
- ValueType::JointPositions(ref joint_positions_schema) => match value.value {
548
- Some(EnumValue::JointPositions(values)) => {
549
- self.serialize_joint_positions(joint_positions_schema, values)
550
- }
551
- _ => Err("Unsupported value type".into()),
552
- },
553
- ValueType::JointVelocities(ref joint_velocities_schema) => match value.value {
554
- Some(EnumValue::JointVelocities(values)) => {
555
- self.serialize_joint_velocities(joint_velocities_schema, values)
556
- }
557
- _ => Err("Unsupported value type".into()),
558
- },
559
- ValueType::JointTorques(ref joint_torques_schema) => match value.value {
560
- Some(EnumValue::JointTorques(values)) => {
561
- self.serialize_joint_torques(joint_torques_schema, values)
562
- }
563
- _ => Err("Unsupported value type".into()),
564
- },
565
- ValueType::JointCommands(ref joint_commands_schema) => match value.value {
566
- Some(EnumValue::JointCommands(values)) => {
567
- self.serialize_joint_commands(joint_commands_schema, values)
568
- }
569
- _ => Err("Unsupported value type".into()),
570
- },
571
- ValueType::CameraFrame(ref camera_frame_schema) => match value.value {
572
- Some(EnumValue::CameraFrame(values)) => {
573
- self.serialize_camera_frame(camera_frame_schema, values)
574
- }
575
- _ => Err("Unsupported value type".into()),
576
- },
577
- ValueType::AudioFrame(ref audio_frame_schema) => match value.value {
578
- Some(EnumValue::AudioFrame(values)) => {
579
- self.serialize_audio_frame(audio_frame_schema, values)
580
- }
581
- _ => Err("Unsupported value type".into()),
582
- },
583
- ValueType::Imu(ref imu_schema) => match value.value {
584
- Some(EnumValue::Imu(values)) => self.serialize_imu(imu_schema, values),
585
- _ => Err("Unsupported value type".into()),
586
- },
587
- ValueType::Timestamp(ref timestamp_schema) => match value.value {
588
- Some(EnumValue::Timestamp(values)) => {
589
- self.serialize_timestamp(timestamp_schema, values)
590
- }
591
- _ => Err("Unsupported value type".into()),
592
- },
593
- ValueType::VectorCommand(ref vector_command_schema) => match value.value {
594
- Some(EnumValue::VectorCommand(values)) => {
595
- self.serialize_vector_command(vector_command_schema, values)
596
- }
597
- _ => Err("Unsupported value type".into()),
598
- },
599
- ValueType::StateTensor(ref state_tensor_schema) => match value.value {
600
- Some(EnumValue::StateTensor(values)) => {
601
- self.serialize_state_tensor(state_tensor_schema, values)
602
- }
603
- _ => Err("Unsupported value type".into()),
604
- },
605
- }
606
- }
607
-
608
- fn deserialize(
609
- &self,
610
- schema: &ValueSchema,
611
- value: OrtValue,
612
- ) -> Result<ProtoValue, Box<dyn Error>> {
613
- match schema.value_type.as_ref().ok_or("Missing value type")? {
614
- ValueType::JointPositions(ref joint_positions_schema) => {
615
- let positions = self.deserialize_joint_positions(joint_positions_schema, value)?;
616
- Ok(ProtoValue {
617
- value_name: schema.value_name.clone(),
618
- value: Some(EnumValue::JointPositions(positions)),
619
- })
620
- }
621
- ValueType::JointVelocities(ref joint_velocities_schema) => {
622
- let velocities =
623
- self.deserialize_joint_velocities(joint_velocities_schema, value)?;
624
- Ok(ProtoValue {
625
- value_name: schema.value_name.clone(),
626
- value: Some(EnumValue::JointVelocities(velocities)),
627
- })
628
- }
629
- ValueType::JointTorques(ref joint_torques_schema) => {
630
- let torques = self.deserialize_joint_torques(joint_torques_schema, value)?;
631
- Ok(ProtoValue {
632
- value_name: schema.value_name.clone(),
633
- value: Some(EnumValue::JointTorques(torques)),
634
- })
635
- }
636
- ValueType::JointCommands(ref joint_commands_schema) => {
637
- let commands = self.deserialize_joint_commands(joint_commands_schema, value)?;
638
- Ok(ProtoValue {
639
- value_name: schema.value_name.clone(),
640
- value: Some(EnumValue::JointCommands(commands)),
641
- })
642
- }
643
- ValueType::CameraFrame(ref camera_frame_schema) => {
644
- let frame = self.deserialize_camera_frame(camera_frame_schema, value)?;
645
- Ok(ProtoValue {
646
- value_name: schema.value_name.clone(),
647
- value: Some(EnumValue::CameraFrame(frame)),
648
- })
649
- }
650
- ValueType::AudioFrame(ref audio_frame_schema) => {
651
- let frame = self.deserialize_audio_frame(audio_frame_schema, value)?;
652
- Ok(ProtoValue {
653
- value_name: schema.value_name.clone(),
654
- value: Some(EnumValue::AudioFrame(frame)),
655
- })
656
- }
657
- ValueType::Imu(ref imu_schema) => {
658
- let imu = self.deserialize_imu(imu_schema, value)?;
659
- Ok(ProtoValue {
660
- value_name: schema.value_name.clone(),
661
- value: Some(EnumValue::Imu(imu)),
662
- })
663
- }
664
- ValueType::Timestamp(ref timestamp_schema) => {
665
- let timestamp = self.deserialize_timestamp(timestamp_schema, value)?;
666
- Ok(ProtoValue {
667
- value_name: schema.value_name.clone(),
668
- value: Some(EnumValue::Timestamp(timestamp)),
669
- })
670
- }
671
- ValueType::VectorCommand(ref vector_command_schema) => {
672
- let command = self.deserialize_vector_command(vector_command_schema, value)?;
673
- Ok(ProtoValue {
674
- value_name: schema.value_name.clone(),
675
- value: Some(EnumValue::VectorCommand(command)),
676
- })
677
- }
678
- ValueType::StateTensor(ref state_tensor_schema) => {
679
- let tensor = self.deserialize_state_tensor(state_tensor_schema, value)?;
680
- Ok(ProtoValue {
681
- value_name: schema.value_name.clone(),
682
- value: Some(EnumValue::StateTensor(tensor)),
683
- })
684
- }
685
- }
686
- }
687
- }
688
-
689
- fn calculate_value_size(schema: &ValueSchema) -> Result<usize, Box<dyn Error>> {
690
- match schema.value_type.as_ref().ok_or("Missing value type")? {
691
- ValueType::JointPositions(s) => Ok(s.joint_names.len()),
692
- ValueType::JointVelocities(s) => Ok(s.joint_names.len()),
693
- ValueType::JointTorques(s) => Ok(s.joint_names.len()),
694
- ValueType::JointCommands(s) => Ok(s.joint_names.len() * 5), // 5 values per joint
695
- ValueType::CameraFrame(s) => Ok((s.channels * s.height * s.width) as usize),
696
- ValueType::AudioFrame(s) => Ok((s.channels * s.sample_rate) as usize),
697
- ValueType::Imu(s) => {
698
- let mut size = 0;
699
- if s.use_accelerometer {
700
- size += 3;
701
- }
702
- if s.use_gyroscope {
703
- size += 3;
704
- }
705
- if s.use_magnetometer {
706
- size += 3;
707
- }
708
- Ok(size)
709
- }
710
- ValueType::Timestamp(_) => Ok(1),
711
- ValueType::VectorCommand(s) => Ok(s.dimensions as usize),
712
- ValueType::StateTensor(s) => Ok(s.shape.iter().product::<i32>() as usize),
713
- }
714
- }
715
-
716
- pub struct OnnxMultiSerializer {
717
- serializers: Vec<OnnxSerializer>,
718
- }
719
-
720
- impl OnnxMultiSerializer {
721
- pub fn new(schema: ProtoIOSchema) -> Self {
722
- Self {
723
- serializers: schema
724
- .values
725
- .into_iter()
726
- .map(|s| OnnxSerializer::new(s))
727
- .collect(),
728
- }
729
- }
730
-
731
- pub fn serialize_io(&self, io: ProtoIO) -> Result<OrtValue, Box<dyn Error>> {
732
- if io.values.len() != self.serializers.len() {
733
- return Err("Number of values does not match schema".into());
734
- }
735
-
736
- // Serialize each value according to its schema and concatenate the results
737
- let mut all_values: Vec<f32> = Vec::new();
738
- for (value, serializer) in io.values.iter().zip(self.serializers.iter()) {
739
- let tensor = serializer.serialize(&serializer.schema, value.clone())?;
740
- let array = tensor.try_extract_tensor::<f32>()?;
741
- let array_1d = array
742
- .as_standard_layout()
743
- .into_dimensionality::<ndarray::Ix1>()?;
744
- all_values.extend(array_1d.iter().copied());
745
- }
746
-
747
- // Convert to OrtValue
748
- Tensor::from_array(Array1::from_vec(all_values))
749
- .map(|tensor| tensor.into_dyn())
750
- .map_err(|e| Box::new(e) as Box<dyn Error>)
751
- }
752
-
753
- pub fn deserialize_io(&self, values: Vec<OrtValue>) -> Result<ProtoIO, Box<dyn Error>> {
754
- // Check if number of values matches number of serializers
755
- if values.len() != self.serializers.len() {
756
- return Err(format!(
757
- "Number of values ({}) does not match number of serializers ({})",
758
- values.len(),
759
- self.serializers.len()
760
- )
761
- .into());
762
- }
763
-
764
- // Deserialize each value using its corresponding serializer
765
- let proto_values = self
766
- .serializers
767
- .iter()
768
- .zip(values.into_iter())
769
- .map(|(serializer, value)| serializer.deserialize(&serializer.schema, value))
770
- .collect::<Result<Vec<_>, _>>()?;
771
-
772
- Ok(ProtoIO {
773
- values: proto_values,
774
- })
775
- }
776
-
777
- pub fn names(&self) -> Vec<String> {
778
- self.serializers
779
- .iter()
780
- .map(|s| s.schema.value_name.clone())
781
- .collect()
782
- }
783
-
784
- pub fn assign_names(
785
- &self,
786
- values: Vec<OrtValue>,
787
- ) -> Result<std::collections::HashMap<String, OrtValue>, Box<dyn Error>> {
788
- if values.len() != self.serializers.len() {
789
- return Err(format!(
790
- "Expected {} values, got {}",
791
- self.serializers.len(),
792
- values.len()
793
- )
794
- .into());
795
- }
796
-
797
- Ok(self
798
- .serializers
799
- .iter()
800
- .map(|s| s.schema.value_name.clone())
801
- .zip(values)
802
- .collect())
803
- }
804
- }