kinfer 0.3.3__cp312-cp312-macosx_11_0_arm64.whl → 0.4.1__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +0 -5
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +12 -0
- kinfer/export/common.py +41 -0
- kinfer/export/jax.py +53 -0
- kinfer/export/pytorch.py +45 -110
- kinfer/export/serialize.py +93 -0
- kinfer/requirements.txt +3 -4
- kinfer/rust/Cargo.toml +20 -8
- kinfer/rust/src/lib.rs +2 -11
- kinfer/rust/src/model.rs +286 -121
- kinfer/rust/src/runtime.rs +104 -0
- kinfer/rust_bindings/Cargo.toml +8 -1
- kinfer/rust_bindings/rust_bindings.pyi +36 -0
- kinfer/rust_bindings/src/lib.rs +326 -1
- kinfer/rust_bindings.cpython-312-darwin.so +0 -0
- kinfer/rust_bindings.pyi +30 -1
- kinfer-0.4.1.dist-info/METADATA +55 -0
- kinfer-0.4.1.dist-info/RECORD +26 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info}/WHEEL +2 -1
- kinfer/inference/__init__.py +0 -2
- kinfer/inference/base.py +0 -64
- kinfer/inference/python.py +0 -66
- kinfer/proto/__init__.py +0 -40
- kinfer/proto/kinfer_pb2.py +0 -103
- kinfer/proto/kinfer_pb2.pyi +0 -1097
- kinfer/requirements-dev.txt +0 -8
- kinfer/rust/build.rs +0 -16
- kinfer/rust/src/kinfer_proto.rs +0 -14
- kinfer/rust/src/main.rs +0 -6
- kinfer/rust/src/onnx_serializer.rs +0 -804
- kinfer/rust/src/serializer.rs +0 -221
- kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
- kinfer/serialize/__init__.py +0 -60
- kinfer/serialize/base.py +0 -536
- kinfer/serialize/json.py +0 -399
- kinfer/serialize/numpy.py +0 -426
- kinfer/serialize/pytorch.py +0 -402
- kinfer/serialize/schema.py +0 -125
- kinfer/serialize/types.py +0 -17
- kinfer/serialize/utils.py +0 -177
- kinfer-0.3.3.dist-info/METADATA +0 -57
- kinfer-0.3.3.dist-info/RECORD +0 -40
- {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info/licenses}/LICENSE +0 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info}/top_level.txt +0 -0
kinfer/serialize/pytorch.py
DELETED
@@ -1,402 +0,0 @@
|
|
1
|
-
"""Defines a serializer for PyTorch tensors."""
|
2
|
-
|
3
|
-
from typing import cast
|
4
|
-
|
5
|
-
import numpy as np
|
6
|
-
import torch
|
7
|
-
from torch import Tensor
|
8
|
-
|
9
|
-
from kinfer import proto as K
|
10
|
-
from kinfer.serialize.base import (
|
11
|
-
AudioFrameSerializer,
|
12
|
-
CameraFrameSerializer,
|
13
|
-
ImuSerializer,
|
14
|
-
JointCommandsSerializer,
|
15
|
-
JointPositionsSerializer,
|
16
|
-
JointTorquesSerializer,
|
17
|
-
JointVelocitiesSerializer,
|
18
|
-
MultiSerializer,
|
19
|
-
Serializer,
|
20
|
-
StateTensorSerializer,
|
21
|
-
TimestampSerializer,
|
22
|
-
VectorCommandSerializer,
|
23
|
-
)
|
24
|
-
from kinfer.serialize.utils import (
|
25
|
-
check_names_match,
|
26
|
-
convert_angular_position,
|
27
|
-
convert_angular_velocity,
|
28
|
-
convert_torque,
|
29
|
-
dtype_num_bytes,
|
30
|
-
dtype_range,
|
31
|
-
numpy_dtype,
|
32
|
-
parse_bytes,
|
33
|
-
pytorch_dtype,
|
34
|
-
)
|
35
|
-
|
36
|
-
|
37
|
-
class PyTorchBaseSerializer:
|
38
|
-
def __init__(
|
39
|
-
self: "PyTorchBaseSerializer",
|
40
|
-
device: str | torch.device | None = None,
|
41
|
-
dtype: torch.dtype | None = None,
|
42
|
-
) -> None:
|
43
|
-
self.device = device
|
44
|
-
self.dtype = dtype
|
45
|
-
|
46
|
-
|
47
|
-
class PyTorchJointPositionsSerializer(PyTorchBaseSerializer, JointPositionsSerializer[Tensor]):
    """Converts joint-position messages to and from 1-D PyTorch tensors."""

    def serialize_joint_positions(
        self: "PyTorchJointPositionsSerializer",
        schema: K.JointPositionsSchema,
        value: K.JointPositionsValue,
    ) -> Tensor:
        """Pack joint positions into a tensor ordered like ``schema.joint_names``.

        Every entry is passed through ``convert_angular_position`` together
        with its own unit and ``schema.unit``. The joint names in ``value`` are
        first checked against the schema with ``check_names_match``.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name))

        converted = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            converted.append(convert_angular_position(entry.value, entry.unit, schema.unit))
        return torch.tensor(converted, dtype=self.dtype, device=self.device)

    def deserialize_joint_positions(
        self: "PyTorchJointPositionsSerializer",
        schema: K.JointPositionsSchema,
        value: Tensor,
    ) -> K.JointPositionsValue:
        """Unpack a 1-D tensor into per-joint position messages.

        Raises:
            ValueError: If ``value`` does not have exactly one element per
                schema joint name.
        """
        if value.shape != (len(schema.joint_names),):
            raise ValueError(
                f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
            )
        floats = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
        # The shape check above guarantees one float per joint name.
        entries = [
            K.JointPositionValue(joint_name=joint, value=number, unit=schema.unit)
            for joint, number in zip(schema.joint_names, floats)
        ]
        return K.JointPositionsValue(values=entries)
|
81
|
-
|
82
|
-
|
83
|
-
class PyTorchJointVelocitiesSerializer(PyTorchBaseSerializer, JointVelocitiesSerializer[Tensor]):
    """Converts joint-velocity messages to and from 1-D PyTorch tensors."""

    def serialize_joint_velocities(
        self: "PyTorchJointVelocitiesSerializer",
        schema: K.JointVelocitiesSchema,
        value: K.JointVelocitiesValue,
    ) -> Tensor:
        """Pack joint velocities into a tensor ordered like ``schema.joint_names``.

        Every entry is passed through ``convert_angular_velocity`` together
        with its own unit and ``schema.unit``. The joint names in ``value`` are
        first checked against the schema with ``check_names_match``.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name))

        converted = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            converted.append(convert_angular_velocity(entry.value, entry.unit, schema.unit))
        return torch.tensor(converted, dtype=self.dtype, device=self.device)

    def deserialize_joint_velocities(
        self: "PyTorchJointVelocitiesSerializer",
        schema: K.JointVelocitiesSchema,
        value: Tensor,
    ) -> K.JointVelocitiesValue:
        """Unpack a 1-D tensor into per-joint velocity messages.

        Raises:
            ValueError: If ``value`` does not have exactly one element per
                schema joint name.
        """
        if value.shape != (len(schema.joint_names),):
            raise ValueError(
                f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
            )
        floats = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
        # The shape check above guarantees one float per joint name.
        entries = [
            K.JointVelocityValue(joint_name=joint, value=number, unit=schema.unit)
            for joint, number in zip(schema.joint_names, floats)
        ]
        return K.JointVelocitiesValue(values=entries)
|
117
|
-
|
118
|
-
|
119
|
-
class PyTorchJointTorquesSerializer(PyTorchBaseSerializer, JointTorquesSerializer[Tensor]):
    """Converts joint-torque messages to and from 1-D PyTorch tensors."""

    def serialize_joint_torques(
        self: "PyTorchJointTorquesSerializer",
        schema: K.JointTorquesSchema,
        value: K.JointTorquesValue,
    ) -> Tensor:
        """Pack joint torques into a tensor ordered like ``schema.joint_names``.

        Every entry is passed through ``convert_torque`` together with its own
        unit and ``schema.unit``. The joint names in ``value`` are first
        checked against the schema with ``check_names_match``.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name))

        converted = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            converted.append(convert_torque(entry.value, entry.unit, schema.unit))
        return torch.tensor(converted, dtype=self.dtype, device=self.device)

    def deserialize_joint_torques(
        self: "PyTorchJointTorquesSerializer",
        schema: K.JointTorquesSchema,
        value: Tensor,
    ) -> K.JointTorquesValue:
        """Unpack a 1-D tensor into per-joint torque messages.

        Raises:
            ValueError: If ``value`` does not have exactly one element per
                schema joint name.
        """
        if value.shape != (len(schema.joint_names),):
            raise ValueError(
                f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
            )
        floats = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
        # The shape check above guarantees one float per joint name.
        entries = [
            K.JointTorqueValue(joint_name=joint, value=number, unit=schema.unit)
            for joint, number in zip(schema.joint_names, floats)
        ]
        return K.JointTorquesValue(values=entries)
|
150
|
-
|
151
|
-
|
152
|
-
class PyTorchJointCommandsSerializer(PyTorchBaseSerializer, JointCommandsSerializer[Tensor]):
    """Converts joint-command messages to and from ``(num_joints, 5)`` tensors.

    Each row is laid out as ``[torque, velocity, position, kp, kd]``.
    """

    def _convert_value_to_tensor(
        self: "PyTorchJointCommandsSerializer",
        value: K.JointCommandValue,
        schema: K.JointCommandsSchema,
    ) -> Tensor:
        """Build the five-element row for a single joint command.

        Torque, velocity and position go through their unit-conversion helpers
        with the schema's units; ``kp`` and ``kd`` are copied as-is.
        """
        torque = convert_torque(value.torque, value.torque_unit, schema.torque_unit)
        velocity = convert_angular_velocity(value.velocity, value.velocity_unit, schema.velocity_unit)
        position = convert_angular_position(value.position, value.position_unit, schema.position_unit)
        row = [torque, velocity, position, value.kp, value.kd]
        return torch.tensor(row, dtype=self.dtype, device=self.device)

    def _convert_tensor_to_value(
        self: "PyTorchJointCommandsSerializer",
        values: list[float],
        schema: K.JointCommandsSchema,
        name: str,
    ) -> K.JointCommandValue:
        """Turn one five-element row back into a joint command message.

        Raises:
            ValueError: If ``values`` does not contain exactly five numbers.
        """
        if len(values) != 5:
            raise ValueError(f"Shape of tensor must match number of joint commands: {len(values)} != 5")
        torque, velocity, position, kp, kd = values
        return K.JointCommandValue(
            joint_name=name,
            torque=torque,
            velocity=velocity,
            position=position,
            kp=kp,
            kd=kd,
            torque_unit=schema.torque_unit,
            velocity_unit=schema.velocity_unit,
            position_unit=schema.position_unit,
        )

    def serialize_joint_commands(
        self: "PyTorchJointCommandsSerializer",
        schema: K.JointCommandsSchema,
        value: K.JointCommandsValue,
    ) -> Tensor:
        """Stack one row per joint, ordered like ``schema.joint_names``.

        The joint names in ``value`` are first checked against the schema with
        ``check_names_match``.
        """
        by_name = {}
        for command in value.values:
            by_name[command.joint_name] = command
        check_names_match("schema", schema.joint_names, "value", list(by_name))

        rows = []
        for joint in schema.joint_names:
            rows.append(self._convert_value_to_tensor(by_name[joint], schema))
        return torch.stack(rows, dim=0)

    def deserialize_joint_commands(
        self: "PyTorchJointCommandsSerializer",
        schema: K.JointCommandsSchema,
        value: Tensor,
    ) -> K.JointCommandsValue:
        """Split a ``(num_joints, 5)`` tensor into per-joint command messages.

        Raises:
            ValueError: If ``value`` is not shaped ``(len(schema.joint_names), 5)``.
        """
        if value.shape != (len(schema.joint_names), 5):
            raise ValueError(
                "Shape of tensor must match number of joint names and commands: "
                f"{value.shape} != ({len(schema.joint_names)}, 5)"
            )
        rows = cast(list[list[float]], value.detach().cpu().numpy().astype(float).tolist())
        # The shape check above guarantees one row per joint name.
        commands = [self._convert_tensor_to_value(row, schema, joint) for joint, row in zip(schema.joint_names, rows)]
        return K.JointCommandsValue(values=commands)
|
219
|
-
|
220
|
-
|
221
|
-
class PyTorchCameraFrameSerializer(PyTorchBaseSerializer, CameraFrameSerializer[Tensor]):
    """Converts camera-frame byte buffers to and from PyTorch tensors."""

    def serialize_camera_frame(
        self: "PyTorchCameraFrameSerializer", schema: K.CameraFrameSchema, value: K.CameraFrameValue
    ) -> Tensor:
        """Decode UINT8 frame bytes into a ``(channels, height, width)`` tensor.

        The parsed bytes are divided by 255.0 before being reshaped.

        Raises:
            ValueError: If the number of elements differs from
                ``channels * height * width``.
        """
        pixels = parse_bytes(value.data, K.DType.UINT8)
        tensor = torch.from_numpy(pixels).to(self.device, self.dtype) / 255.0
        if tensor.numel() == schema.channels * schema.height * schema.width:
            return tensor.view(schema.channels, schema.height, schema.width)
        raise ValueError(
            "Length of data must match number of channels, height, and width: "
            f"{tensor.numel()} != {schema.channels} * {schema.height} * {schema.width}"
        )

    def deserialize_camera_frame(
        self: "PyTorchCameraFrameSerializer", schema: K.CameraFrameSchema, value: Tensor
    ) -> K.CameraFrameValue:
        """Encode a tensor as flat uint8 frame bytes.

        The tensor is multiplied by 255.0, flattened and cast to ``np.uint8``;
        ``schema`` is not consulted.
        """
        scaled = value * 255.0
        flat = scaled.detach().cpu().flatten().numpy()
        return K.CameraFrameValue(data=flat.astype(np.uint8).tobytes())
|
240
|
-
|
241
|
-
|
242
|
-
class PyTorchAudioFrameSerializer(PyTorchBaseSerializer, AudioFrameSerializer[Tensor]):
    """Converts audio-frame byte buffers to and from PyTorch tensors."""

    def serialize_audio_frame(
        self: "PyTorchAudioFrameSerializer", schema: K.AudioFrameSchema, value: K.AudioFrameValue
    ) -> Tensor:
        """Decode audio bytes into a ``(channels, -1)`` tensor.

        The samples are divided by the upper bound that ``dtype_range`` reports
        for ``schema.dtype``.

        Raises:
            ValueError: If the byte length differs from
                ``channels * sample_rate * dtype_num_bytes(dtype)``.
        """
        value_bytes = value.data
        expected_length = schema.channels * schema.sample_rate * dtype_num_bytes(schema.dtype)
        if len(value_bytes) != expected_length:
            raise ValueError(
                "Length of data must match number of channels, sample rate, and dtype: "
                f"{len(value_bytes)} != {schema.channels} * {schema.sample_rate} * {dtype_num_bytes(schema.dtype)}"
            )
        peak = dtype_range(schema.dtype)[1]
        samples = parse_bytes(value_bytes, schema.dtype)
        per_channel = torch.from_numpy(samples).to(self.device, self.dtype).view(schema.channels, -1)
        return per_channel / peak

    def deserialize_audio_frame(
        self: "PyTorchAudioFrameSerializer", schema: K.AudioFrameSchema, value: Tensor
    ) -> K.AudioFrameValue:
        """Encode a tensor as flat audio bytes in the schema dtype.

        The tensor is multiplied by the upper bound from ``dtype_range``,
        flattened and cast to ``numpy_dtype(schema.dtype)``.
        """
        peak = dtype_range(schema.dtype)[1]
        flat = (value * peak).detach().cpu().flatten().numpy()
        encoded = flat.astype(numpy_dtype(schema.dtype))
        return K.AudioFrameValue(data=encoded.tobytes())
|
265
|
-
|
266
|
-
|
267
|
-
class PyTorchImuSerializer(PyTorchBaseSerializer, ImuSerializer[Tensor]):
    """Converts IMU readings to and from stacked 3-vectors.

    Rows appear in the fixed order accelerometer, gyroscope, magnetometer,
    with a row present only when the matching ``schema.use_*`` flag is set.
    """

    def serialize_imu(self: "PyTorchImuSerializer", schema: K.ImuSchema, value: K.ImuValue) -> Tensor:
        """Stack the enabled sensor readings into one tensor, one row each.

        Raises:
            ValueError: If the schema enables none of the three sensors.
        """
        sensors = (
            (schema.use_accelerometer, "linear_acceleration"),
            (schema.use_gyroscope, "angular_velocity"),
            (schema.use_magnetometer, "magnetic_field"),
        )
        rows: list[Tensor] = []
        for enabled, field_name in sensors:
            if not enabled:
                continue
            reading = getattr(value, field_name)
            rows.append(
                torch.tensor(
                    [reading.x, reading.y, reading.z],
                    dtype=self.dtype,
                    device=self.device,
                )
            )
        if rows:
            return torch.stack(rows, dim=0)
        raise ValueError("IMU has nothing to serialize")

    def deserialize_imu(self: "PyTorchImuSerializer", schema: K.ImuSchema, value: Tensor) -> K.ImuValue:
        """Read the enabled sensor rows of ``value`` back into an IMU message.

        Rows are consumed in order, one per enabled sensor; sensors that the
        schema disables are left untouched on the returned message.
        """
        rows = value.tolist()
        imu_value = K.ImuValue()
        sensors = (
            (schema.use_accelerometer, "linear_acceleration"),
            (schema.use_gyroscope, "angular_velocity"),
            (schema.use_magnetometer, "magnetic_field"),
        )
        cursor = 0  # index of the next unread row
        for enabled, field_name in sensors:
            if not enabled:
                continue
            x, y, z = rows[cursor]
            cursor += 1
            target = getattr(imu_value, field_name)
            target.x = x
            target.y = y
            target.z = z
        return imu_value
|
317
|
-
|
318
|
-
|
319
|
-
class PyTorchTimestampSerializer(PyTorchBaseSerializer, TimestampSerializer[Tensor]):
    """Converts timestamps to and from one-element tensors of elapsed seconds."""

    def serialize_timestamp(
        self: "PyTorchTimestampSerializer", schema: K.TimestampSchema, value: K.TimestampValue
    ) -> Tensor:
        """Return the seconds elapsed since the schema start as a 1-element tensor.

        The seconds and nanoseconds are subtracted field by field from
        ``schema.start_seconds`` / ``schema.start_nanos`` and then combined
        into one fractional number of seconds.
        """
        seconds_delta = value.seconds - schema.start_seconds
        nanos_delta = value.nanos - schema.start_nanos
        if nanos_delta < 0:
            # Borrow one second so the nanosecond part is non-negative.
            seconds_delta, nanos_delta = seconds_delta - 1, nanos_delta + 1_000_000_000
        elapsed = seconds_delta + nanos_delta / 1_000_000_000
        return torch.tensor([elapsed], dtype=self.dtype, device=self.device, requires_grad=False)

    def deserialize_timestamp(
        self: "PyTorchTimestampSerializer", schema: K.TimestampSchema, value: Tensor
    ) -> K.TimestampValue:
        """Split a scalar number of seconds into whole seconds and nanoseconds.

        NOTE(review): ``schema.start_seconds`` / ``schema.start_nanos`` are not
        added back here, so the result is the elapsed time rather than the
        inverse of ``serialize_timestamp`` - confirm whether that is intended.
        """
        elapsed = value.item()
        whole_seconds = int(elapsed)  # truncates toward zero
        fraction = elapsed - whole_seconds
        return K.TimestampValue(seconds=whole_seconds, nanos=int(fraction * 1_000_000_000))
|
338
|
-
|
339
|
-
|
340
|
-
class PyTorchVectorCommandSerializer(PyTorchBaseSerializer, VectorCommandSerializer[Tensor]):
    """Converts vector-command messages to and from 1-D PyTorch tensors."""

    def serialize_vector_command(
        self: "PyTorchVectorCommandSerializer", schema: K.VectorCommandSchema, value: K.VectorCommandValue
    ) -> Tensor:
        """Copy ``value.values`` into a tensor; ``schema`` is not consulted."""
        components = value.values
        return torch.tensor(components, dtype=self.dtype, device=self.device)

    def deserialize_vector_command(
        self: "PyTorchVectorCommandSerializer", schema: K.VectorCommandSchema, value: Tensor
    ) -> K.VectorCommandValue:
        """Copy a 1-D tensor of ``schema.dimensions`` entries into a message.

        Raises:
            ValueError: If ``value`` is not shaped ``(schema.dimensions,)``.
        """
        expected_shape = (schema.dimensions,)
        if value.shape != expected_shape:
            raise ValueError(f"Shape of tensor must match number of dimensions: {value.shape} != {schema.dimensions}")
        components = cast(list[float], value.tolist())
        return K.VectorCommandValue(values=components)
|
353
|
-
|
354
|
-
|
355
|
-
class PyTorchStateTensorSerializer(PyTorchBaseSerializer, StateTensorSerializer[Tensor]):
    """Converts raw state-tensor byte buffers to and from PyTorch tensors."""

    def serialize_state_tensor(
        self: "PyTorchStateTensorSerializer", schema: K.StateTensorSchema, value: K.StateTensorValue
    ) -> Tensor:
        """Decode state bytes into a tensor shaped like ``schema.shape``.

        The tensor dtype comes from ``pytorch_dtype(schema.dtype)``, not from
        ``self.dtype``.

        Raises:
            ValueError: If the byte length differs from the element count
                times ``dtype_num_bytes(schema.dtype)``.
        """
        value_bytes = value.data
        expected_length = np.prod(schema.shape) * dtype_num_bytes(schema.dtype)
        if len(value_bytes) != expected_length:
            raise ValueError(
                "Length of data must match number of elements: "
                f"{len(value_bytes)} != {np.prod(schema.shape)} * {dtype_num_bytes(schema.dtype)}"
            )
        elements = parse_bytes(value_bytes, schema.dtype)
        flat = torch.from_numpy(elements).to(self.device, pytorch_dtype(schema.dtype))
        return flat.view(tuple(schema.shape))

    def deserialize_state_tensor(
        self: "PyTorchStateTensorSerializer", schema: K.StateTensorSchema, value: Tensor
    ) -> K.StateTensorValue:
        """Dump the flattened tensor's bytes into a message; ``schema`` is not consulted."""
        flat = value.detach().cpu().flatten()
        payload = flat.numpy().tobytes()
        return K.StateTensorValue(data=payload)
|
374
|
-
|
375
|
-
|
376
|
-
class PyTorchSerializer(
    PyTorchJointPositionsSerializer,
    PyTorchJointVelocitiesSerializer,
    PyTorchJointTorquesSerializer,
    PyTorchJointCommandsSerializer,
    PyTorchCameraFrameSerializer,
    PyTorchAudioFrameSerializer,
    PyTorchImuSerializer,
    PyTorchTimestampSerializer,
    PyTorchVectorCommandSerializer,
    PyTorchStateTensorSerializer,
    Serializer[Tensor],
):
    """Single serializer that combines every per-type PyTorch serializer."""

    def __init__(
        self: "PyTorchSerializer",
        schema: K.ValueSchema,
        *,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Set up the tensor options and the schema.

        Args:
            schema: Value schema handed to ``Serializer.__init__``.
            device: Device used for created tensors (keyword-only).
            dtype: Dtype used for created tensors (keyword-only).
        """
        # Both bases are initialised explicitly, in this order, rather than via super().
        PyTorchBaseSerializer.__init__(self, device, dtype)
        Serializer.__init__(self, schema=schema)
|
398
|
-
|
399
|
-
|
400
|
-
class PyTorchMultiSerializer(MultiSerializer[Tensor]):
    """Multi-value serializer with one ``PyTorchSerializer`` per schema entry."""

    def __init__(self: "PyTorchMultiSerializer", schema: K.IOSchema) -> None:
        """Create a default-configured ``PyTorchSerializer`` for each of ``schema.values``."""
        serializers = []
        for value_schema in schema.values:
            serializers.append(PyTorchSerializer(schema=value_schema))
        super().__init__(serializers)
|
kinfer/serialize/schema.py
DELETED
@@ -1,125 +0,0 @@
|
|
1
|
-
"""Defines utility functions for the schema."""
|
2
|
-
|
3
|
-
import numpy as np
|
4
|
-
|
5
|
-
from kinfer import proto as K
|
6
|
-
from kinfer.serialize.utils import dtype_num_bytes
|
7
|
-
|
8
|
-
|
9
|
-
def get_dummy_value(value_schema: K.ValueSchema) -> K.Value:
|
10
|
-
value_type = value_schema.WhichOneof("value_type")
|
11
|
-
|
12
|
-
match value_type:
|
13
|
-
case "joint_positions":
|
14
|
-
return K.Value(
|
15
|
-
joint_positions=K.JointPositionsValue(
|
16
|
-
values=[
|
17
|
-
K.JointPositionValue(
|
18
|
-
joint_name=joint_name,
|
19
|
-
value=0.0,
|
20
|
-
unit=value_schema.joint_positions.unit,
|
21
|
-
)
|
22
|
-
for joint_name in value_schema.joint_positions.joint_names
|
23
|
-
]
|
24
|
-
),
|
25
|
-
)
|
26
|
-
case "joint_velocities":
|
27
|
-
return K.Value(
|
28
|
-
joint_velocities=K.JointVelocitiesValue(
|
29
|
-
values=[
|
30
|
-
K.JointVelocityValue(
|
31
|
-
joint_name=joint_name,
|
32
|
-
value=0.0,
|
33
|
-
unit=value_schema.joint_velocities.unit,
|
34
|
-
)
|
35
|
-
for joint_name in value_schema.joint_velocities.joint_names
|
36
|
-
]
|
37
|
-
),
|
38
|
-
)
|
39
|
-
case "joint_torques":
|
40
|
-
return K.Value(
|
41
|
-
joint_torques=K.JointTorquesValue(
|
42
|
-
values=[
|
43
|
-
K.JointTorqueValue(
|
44
|
-
joint_name=joint_name,
|
45
|
-
value=0.0,
|
46
|
-
unit=value_schema.joint_torques.unit,
|
47
|
-
)
|
48
|
-
for joint_name in value_schema.joint_torques.joint_names
|
49
|
-
]
|
50
|
-
),
|
51
|
-
)
|
52
|
-
case "joint_commands":
|
53
|
-
return K.Value(
|
54
|
-
joint_commands=K.JointCommandsValue(
|
55
|
-
values=[
|
56
|
-
K.JointCommandValue(
|
57
|
-
joint_name=joint_name,
|
58
|
-
torque=0.0,
|
59
|
-
velocity=0.0,
|
60
|
-
position=0.0,
|
61
|
-
kp=0.0,
|
62
|
-
kd=0.0,
|
63
|
-
torque_unit=value_schema.joint_commands.torque_unit,
|
64
|
-
velocity_unit=value_schema.joint_commands.velocity_unit,
|
65
|
-
position_unit=value_schema.joint_commands.position_unit,
|
66
|
-
)
|
67
|
-
for joint_name in value_schema.joint_commands.joint_names
|
68
|
-
]
|
69
|
-
),
|
70
|
-
)
|
71
|
-
case "camera_frame":
|
72
|
-
return K.Value(
|
73
|
-
camera_frame=K.CameraFrameValue(
|
74
|
-
data=b"\x00"
|
75
|
-
* (
|
76
|
-
value_schema.camera_frame.width
|
77
|
-
* value_schema.camera_frame.height
|
78
|
-
* value_schema.camera_frame.channels
|
79
|
-
)
|
80
|
-
),
|
81
|
-
)
|
82
|
-
case "audio_frame":
|
83
|
-
return K.Value(
|
84
|
-
audio_frame=K.AudioFrameValue(
|
85
|
-
data=b"\x00"
|
86
|
-
* (
|
87
|
-
value_schema.audio_frame.channels
|
88
|
-
* value_schema.audio_frame.sample_rate
|
89
|
-
* dtype_num_bytes(value_schema.audio_frame.dtype)
|
90
|
-
)
|
91
|
-
),
|
92
|
-
)
|
93
|
-
case "imu":
|
94
|
-
return K.Value(
|
95
|
-
imu=K.ImuValue(
|
96
|
-
linear_acceleration=K.ImuAccelerometerValue(x=0.0, y=0.0, z=0.0),
|
97
|
-
angular_velocity=K.ImuGyroscopeValue(x=0.0, y=0.0, z=0.0),
|
98
|
-
magnetic_field=K.ImuMagnetometerValue(x=0.0, y=0.0, z=0.0),
|
99
|
-
),
|
100
|
-
)
|
101
|
-
case "timestamp":
|
102
|
-
return K.Value(
|
103
|
-
timestamp=K.TimestampValue(seconds=1728000000, nanos=0),
|
104
|
-
)
|
105
|
-
case "vector_command":
|
106
|
-
return K.Value(
|
107
|
-
vector_command=K.VectorCommandValue(values=[0.0] * value_schema.vector_command.dimensions),
|
108
|
-
)
|
109
|
-
case "state_tensor":
|
110
|
-
return K.Value(
|
111
|
-
state_tensor=K.StateTensorValue(
|
112
|
-
data=b"\x00"
|
113
|
-
* np.prod(value_schema.state_tensor.shape)
|
114
|
-
* dtype_num_bytes(value_schema.state_tensor.dtype)
|
115
|
-
),
|
116
|
-
)
|
117
|
-
case _:
|
118
|
-
raise ValueError(f"Invalid value type: {value_type}")
|
119
|
-
|
120
|
-
|
121
|
-
def get_dummy_io(schema: K.IOSchema) -> K.IO:
    """Build a ``K.IO`` holding one dummy value per entry of ``schema.values``.

    The values are produced by ``get_dummy_value`` and kept in schema order.
    """
    dummy_values = [get_dummy_value(entry) for entry in schema.values]
    io_value = K.IO()
    io_value.values.extend(dummy_values)
    return io_value
|
kinfer/serialize/types.py
DELETED
@@ -1,17 +0,0 @@
|
|
1
|
-
"""Type conversion utilities for serializers."""
|
2
|
-
|
3
|
-
from typing import Type, TypeVar, cast
|
4
|
-
|
5
|
-
from kinfer import proto as K
|
6
|
-
|
7
|
-
T = TypeVar("T")
|
8
|
-
|
9
|
-
|
10
|
-
def to_value_type(enum_value: T) -> K.Value:
    """Present ``enum_value`` to the type checker as a ``K.Value``.

    This is a ``typing.cast`` only: the argument is returned unchanged at
    runtime.
    """
    as_value = cast(K.Value, enum_value)
    return as_value
|
13
|
-
|
14
|
-
|
15
|
-
def from_value_type(value_type: K.Value, enum_class: Type[T]) -> T:
    """Present ``value_type`` to the type checker as an instance of ``enum_class``.

    This is a ``typing.cast`` only: ``enum_class`` serves purely to bind the
    return type, and ``value_type`` is returned unchanged at runtime.
    """
    as_enum = cast(T, value_type)
    return as_enum
|