kinfer 0.3.2__cp312-cp312-macosx_11_0_arm64.whl → 0.3.3__cp312-cp312-macosx_11_0_arm64.whl

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
kinfer/serialize/numpy.py CHANGED
@@ -4,7 +4,7 @@ from typing import cast
4
4
 
5
5
  import numpy as np
6
6
 
7
- from kinfer import proto as P
7
+ from kinfer import proto as K
8
8
  from kinfer.serialize.base import (
9
9
  AudioFrameSerializer,
10
10
  CameraFrameSerializer,
@@ -40,8 +40,8 @@ class NumpyBaseSerializer:
40
40
  class NumpyJointPositionsSerializer(NumpyBaseSerializer, JointPositionsSerializer[np.ndarray]):
41
41
  def serialize_joint_positions(
42
42
  self: "NumpyJointPositionsSerializer",
43
- schema: P.JointPositionsSchema,
44
- value: P.JointPositionsValue,
43
+ schema: K.JointPositionsSchema,
44
+ value: K.JointPositionsValue,
45
45
  ) -> np.ndarray:
46
46
  value_map = {v.joint_name: v for v in value.values}
47
47
  check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
@@ -57,17 +57,17 @@ class NumpyJointPositionsSerializer(NumpyBaseSerializer, JointPositionsSerialize
57
57
 
58
58
  def deserialize_joint_positions(
59
59
  self: "NumpyJointPositionsSerializer",
60
- schema: P.JointPositionsSchema,
60
+ schema: K.JointPositionsSchema,
61
61
  value: np.ndarray,
62
- ) -> P.JointPositionsValue:
62
+ ) -> K.JointPositionsValue:
63
63
  if value.shape != (len(schema.joint_names),):
64
64
  raise ValueError(
65
65
  f"Shape of array must match number of joint names: {value.shape} != {len(schema.joint_names)}"
66
66
  )
67
67
  value_list = cast(list[float], value.astype(float).tolist())
68
- return P.JointPositionsValue(
68
+ return K.JointPositionsValue(
69
69
  values=[
70
- P.JointPositionValue(
70
+ K.JointPositionValue(
71
71
  joint_name=name,
72
72
  value=float(value_list[i]),
73
73
  unit=schema.unit,
@@ -80,8 +80,8 @@ class NumpyJointPositionsSerializer(NumpyBaseSerializer, JointPositionsSerialize
80
80
  class NumpyJointVelocitiesSerializer(NumpyBaseSerializer, JointVelocitiesSerializer[np.ndarray]):
81
81
  def serialize_joint_velocities(
82
82
  self: "NumpyJointVelocitiesSerializer",
83
- schema: P.JointVelocitiesSchema,
84
- value: P.JointVelocitiesValue,
83
+ schema: K.JointVelocitiesSchema,
84
+ value: K.JointVelocitiesValue,
85
85
  ) -> np.ndarray:
86
86
  value_map = {v.joint_name: v for v in value.values}
87
87
  check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
@@ -96,17 +96,17 @@ class NumpyJointVelocitiesSerializer(NumpyBaseSerializer, JointVelocitiesSeriali
96
96
 
97
97
  def deserialize_joint_velocities(
98
98
  self: "NumpyJointVelocitiesSerializer",
99
- schema: P.JointVelocitiesSchema,
99
+ schema: K.JointVelocitiesSchema,
100
100
  value: np.ndarray,
101
- ) -> P.JointVelocitiesValue:
101
+ ) -> K.JointVelocitiesValue:
102
102
  if value.shape != (len(schema.joint_names),):
103
103
  raise ValueError(
104
104
  f"Shape of array must match number of joint names: {value.shape} != {len(schema.joint_names)}"
105
105
  )
106
106
  value_list = cast(list[float], value.astype(float).tolist())
107
- return P.JointVelocitiesValue(
107
+ return K.JointVelocitiesValue(
108
108
  values=[
109
- P.JointVelocityValue(joint_name=name, value=value_list[i], unit=schema.unit)
109
+ K.JointVelocityValue(joint_name=name, value=value_list[i], unit=schema.unit)
110
110
  for i, name in enumerate(schema.joint_names)
111
111
  ]
112
112
  )
@@ -115,8 +115,8 @@ class NumpyJointVelocitiesSerializer(NumpyBaseSerializer, JointVelocitiesSeriali
115
115
  class NumpyJointTorquesSerializer(NumpyBaseSerializer, JointTorquesSerializer[np.ndarray]):
116
116
  def serialize_joint_torques(
117
117
  self: "NumpyJointTorquesSerializer",
118
- schema: P.JointTorquesSchema,
119
- value: P.JointTorquesValue,
118
+ schema: K.JointTorquesSchema,
119
+ value: K.JointTorquesValue,
120
120
  ) -> np.ndarray:
121
121
  value_map = {v.joint_name: v for v in value.values}
122
122
  check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
@@ -128,17 +128,17 @@ class NumpyJointTorquesSerializer(NumpyBaseSerializer, JointTorquesSerializer[np
128
128
 
129
129
  def deserialize_joint_torques(
130
130
  self: "NumpyJointTorquesSerializer",
131
- schema: P.JointTorquesSchema,
131
+ schema: K.JointTorquesSchema,
132
132
  value: np.ndarray,
133
- ) -> P.JointTorquesValue:
133
+ ) -> K.JointTorquesValue:
134
134
  if value.shape != (len(schema.joint_names),):
135
135
  raise ValueError(
136
136
  f"Shape of array must match number of joint names: {value.shape} != {len(schema.joint_names)}"
137
137
  )
138
138
  value_list = cast(list[float], value.astype(float).tolist())
139
- return P.JointTorquesValue(
139
+ return K.JointTorquesValue(
140
140
  values=[
141
- P.JointTorqueValue(joint_name=name, value=float(value_list[i]), unit=schema.unit)
141
+ K.JointTorqueValue(joint_name=name, value=float(value_list[i]), unit=schema.unit)
142
142
  for i, name in enumerate(schema.joint_names)
143
143
  ]
144
144
  )
@@ -147,8 +147,8 @@ class NumpyJointTorquesSerializer(NumpyBaseSerializer, JointTorquesSerializer[np
147
147
  class NumpyJointCommandsSerializer(NumpyBaseSerializer, JointCommandsSerializer[np.ndarray]):
148
148
  def _convert_value_to_array(
149
149
  self: "NumpyJointCommandsSerializer",
150
- value: P.JointCommandValue,
151
- schema: P.JointCommandsSchema,
150
+ value: K.JointCommandValue,
151
+ schema: K.JointCommandsSchema,
152
152
  ) -> np.ndarray:
153
153
  return np.array(
154
154
  [
@@ -164,12 +164,12 @@ class NumpyJointCommandsSerializer(NumpyBaseSerializer, JointCommandsSerializer[
164
164
  def _convert_array_to_value(
165
165
  self: "NumpyJointCommandsSerializer",
166
166
  values: list[float],
167
- schema: P.JointCommandsSchema,
167
+ schema: K.JointCommandsSchema,
168
168
  name: str,
169
- ) -> P.JointCommandValue:
169
+ ) -> K.JointCommandValue:
170
170
  if len(values) != 5:
171
171
  raise ValueError(f"Shape of array must match number of joint commands: {len(values)} != 5")
172
- return P.JointCommandValue(
172
+ return K.JointCommandValue(
173
173
  joint_name=name,
174
174
  torque=values[0],
175
175
  velocity=values[1],
@@ -183,8 +183,8 @@ class NumpyJointCommandsSerializer(NumpyBaseSerializer, JointCommandsSerializer[
183
183
 
184
184
  def serialize_joint_commands(
185
185
  self: "NumpyJointCommandsSerializer",
186
- schema: P.JointCommandsSchema,
187
- value: P.JointCommandsValue,
186
+ schema: K.JointCommandsSchema,
187
+ value: K.JointCommandsValue,
188
188
  ) -> np.ndarray:
189
189
  value_map = {v.joint_name: v for v in value.values}
190
190
  check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
@@ -196,16 +196,16 @@ class NumpyJointCommandsSerializer(NumpyBaseSerializer, JointCommandsSerializer[
196
196
 
197
197
  def deserialize_joint_commands(
198
198
  self: "NumpyJointCommandsSerializer",
199
- schema: P.JointCommandsSchema,
199
+ schema: K.JointCommandsSchema,
200
200
  value: np.ndarray,
201
- ) -> P.JointCommandsValue:
201
+ ) -> K.JointCommandsValue:
202
202
  if value.shape != (len(schema.joint_names), 5):
203
203
  raise ValueError(
204
204
  "Shape of array must match number of joint names and commands: "
205
205
  f"{value.shape} != ({len(schema.joint_names)}, 5)"
206
206
  )
207
207
  value_list = cast(list[list[float]], value.astype(float).tolist())
208
- return P.JointCommandsValue(
208
+ return K.JointCommandsValue(
209
209
  values=[
210
210
  self._convert_array_to_value(value_list[i], schema, name) for i, name in enumerate(schema.joint_names)
211
211
  ]
@@ -215,10 +215,10 @@ class NumpyJointCommandsSerializer(NumpyBaseSerializer, JointCommandsSerializer[
215
215
  class NumpyCameraFrameSerializer(NumpyBaseSerializer, CameraFrameSerializer[np.ndarray]):
216
216
  def serialize_camera_frame(
217
217
  self: "NumpyCameraFrameSerializer",
218
- schema: P.CameraFrameSchema,
219
- value: P.CameraFrameValue,
218
+ schema: K.CameraFrameSchema,
219
+ value: K.CameraFrameValue,
220
220
  ) -> np.ndarray:
221
- np_arr = parse_bytes(value.data, P.DType.UINT8)
221
+ np_arr = parse_bytes(value.data, K.DType.UINT8)
222
222
  array = np_arr.astype(self.dtype) / 255.0
223
223
  if array.size != schema.channels * schema.height * schema.width:
224
224
  raise ValueError(
@@ -230,18 +230,18 @@ class NumpyCameraFrameSerializer(NumpyBaseSerializer, CameraFrameSerializer[np.n
230
230
 
231
231
  def deserialize_camera_frame(
232
232
  self: "NumpyCameraFrameSerializer",
233
- schema: P.CameraFrameSchema,
233
+ schema: K.CameraFrameSchema,
234
234
  value: np.ndarray,
235
- ) -> P.CameraFrameValue:
235
+ ) -> K.CameraFrameValue:
236
236
  np_arr = (value * 255.0).flatten().astype(np.uint8)
237
- return P.CameraFrameValue(data=np_arr.tobytes())
237
+ return K.CameraFrameValue(data=np_arr.tobytes())
238
238
 
239
239
 
240
240
  class NumpyAudioFrameSerializer(NumpyBaseSerializer, AudioFrameSerializer[np.ndarray]):
241
241
  def serialize_audio_frame(
242
242
  self: "NumpyAudioFrameSerializer",
243
- schema: P.AudioFrameSchema,
244
- value: P.AudioFrameValue,
243
+ schema: K.AudioFrameSchema,
244
+ value: K.AudioFrameValue,
245
245
  ) -> np.ndarray:
246
246
  value_bytes = value.data
247
247
  if len(value_bytes) != schema.channels * schema.sample_rate * dtype_num_bytes(schema.dtype):
@@ -258,19 +258,19 @@ class NumpyAudioFrameSerializer(NumpyBaseSerializer, AudioFrameSerializer[np.nda
258
258
 
259
259
  def deserialize_audio_frame(
260
260
  self: "NumpyAudioFrameSerializer",
261
- schema: P.AudioFrameSchema,
261
+ schema: K.AudioFrameSchema,
262
262
  value: np.ndarray,
263
- ) -> P.AudioFrameValue:
263
+ ) -> K.AudioFrameValue:
264
264
  _, max_value = dtype_range(schema.dtype)
265
265
  np_arr = (value * max_value).flatten().astype(numpy_dtype(schema.dtype))
266
- return P.AudioFrameValue(data=np_arr.tobytes())
266
+ return K.AudioFrameValue(data=np_arr.tobytes())
267
267
 
268
268
 
269
269
  class NumpyImuSerializer(NumpyBaseSerializer, ImuSerializer[np.ndarray]):
270
270
  def serialize_imu(
271
271
  self: "NumpyImuSerializer",
272
- schema: P.ImuSchema,
273
- value: P.ImuValue,
272
+ schema: K.ImuSchema,
273
+ value: K.ImuValue,
274
274
  ) -> np.ndarray:
275
275
  vectors = []
276
276
  if schema.use_accelerometer:
@@ -300,16 +300,16 @@ class NumpyImuSerializer(NumpyBaseSerializer, ImuSerializer[np.ndarray]):
300
300
 
301
301
  def deserialize_imu(
302
302
  self: "NumpyImuSerializer",
303
- schema: P.ImuSchema,
303
+ schema: K.ImuSchema,
304
304
  value: np.ndarray,
305
- ) -> P.ImuValue:
305
+ ) -> K.ImuValue:
306
306
  num_vectors = sum([schema.use_accelerometer, schema.use_gyroscope, schema.use_magnetometer])
307
307
  if value.shape != (num_vectors, 3):
308
308
  raise ValueError(
309
309
  f"Shape of array must match number of vectors and components: {value.shape} != ({num_vectors}, 3)"
310
310
  )
311
311
  vectors = cast(list[list[float]], value.astype(float).tolist())
312
- imu_value = P.ImuValue()
312
+ imu_value = K.ImuValue()
313
313
  if schema.use_accelerometer:
314
314
  x, y, z = vectors.pop(0)
315
315
  imu_value.linear_acceleration.x = as_float(x)
@@ -331,8 +331,8 @@ class NumpyImuSerializer(NumpyBaseSerializer, ImuSerializer[np.ndarray]):
331
331
  class NumpyTimestampSerializer(NumpyBaseSerializer, TimestampSerializer[np.ndarray]):
332
332
  def serialize_timestamp(
333
333
  self: "NumpyTimestampSerializer",
334
- schema: P.TimestampSchema,
335
- value: P.TimestampValue,
334
+ schema: K.TimestampSchema,
335
+ value: K.TimestampValue,
336
336
  ) -> np.ndarray:
337
337
  elapsed_seconds = value.seconds - schema.start_seconds
338
338
  elapsed_nanos = value.nanos - schema.start_nanos
@@ -344,39 +344,39 @@ class NumpyTimestampSerializer(NumpyBaseSerializer, TimestampSerializer[np.ndarr
344
344
 
345
345
  def deserialize_timestamp(
346
346
  self: "NumpyTimestampSerializer",
347
- schema: P.TimestampSchema,
347
+ schema: K.TimestampSchema,
348
348
  value: np.ndarray,
349
- ) -> P.TimestampValue:
349
+ ) -> K.TimestampValue:
350
350
  total_elapsed_seconds = float(value.item())
351
351
  elapsed_seconds = int(total_elapsed_seconds)
352
352
  elapsed_nanos = int((total_elapsed_seconds - elapsed_seconds) * 1_000_000_000)
353
- return P.TimestampValue(seconds=elapsed_seconds, nanos=elapsed_nanos)
353
+ return K.TimestampValue(seconds=elapsed_seconds, nanos=elapsed_nanos)
354
354
 
355
355
 
356
356
  class NumpyVectorCommandSerializer(NumpyBaseSerializer, VectorCommandSerializer[np.ndarray]):
357
357
  def serialize_vector_command(
358
358
  self: "NumpyVectorCommandSerializer",
359
- schema: P.VectorCommandSchema,
360
- value: P.VectorCommandValue,
359
+ schema: K.VectorCommandSchema,
360
+ value: K.VectorCommandValue,
361
361
  ) -> np.ndarray:
362
362
  return np.array(value.values, dtype=self.dtype)
363
363
 
364
364
  def deserialize_vector_command(
365
365
  self: "NumpyVectorCommandSerializer",
366
- schema: P.VectorCommandSchema,
366
+ schema: K.VectorCommandSchema,
367
367
  value: np.ndarray,
368
- ) -> P.VectorCommandValue:
368
+ ) -> K.VectorCommandValue:
369
369
  if value.shape != (schema.dimensions,):
370
370
  raise ValueError(f"Shape of array must match number of dimensions: {value.shape} != {schema.dimensions}")
371
371
  values = cast(list[float], value.astype(float).tolist())
372
- return P.VectorCommandValue(values=values)
372
+ return K.VectorCommandValue(values=values)
373
373
 
374
374
 
375
375
  class NumpyStateTensorSerializer(NumpyBaseSerializer, StateTensorSerializer[np.ndarray]):
376
376
  def serialize_state_tensor(
377
377
  self: "NumpyStateTensorSerializer",
378
- schema: P.StateTensorSchema,
379
- value: P.StateTensorValue,
378
+ schema: K.StateTensorSchema,
379
+ value: K.StateTensorValue,
380
380
  ) -> np.ndarray:
381
381
  value_bytes = value.data
382
382
  if len(value_bytes) != np.prod(schema.shape) * dtype_num_bytes(schema.dtype):
@@ -391,11 +391,11 @@ class NumpyStateTensorSerializer(NumpyBaseSerializer, StateTensorSerializer[np.n
391
391
 
392
392
  def deserialize_state_tensor(
393
393
  self: "NumpyStateTensorSerializer",
394
- schema: P.StateTensorSchema,
394
+ schema: K.StateTensorSchema,
395
395
  value: np.ndarray,
396
- ) -> P.StateTensorValue:
396
+ ) -> K.StateTensorValue:
397
397
  contiguous_value = np.ascontiguousarray(value)
398
- return P.StateTensorValue(data=contiguous_value.flatten().tobytes())
398
+ return K.StateTensorValue(data=contiguous_value.flatten().tobytes())
399
399
 
400
400
 
401
401
  class NumpySerializer(
@@ -413,7 +413,7 @@ class NumpySerializer(
413
413
  ):
414
414
  def __init__(
415
415
  self: "NumpySerializer",
416
- schema: P.ValueSchema,
416
+ schema: K.ValueSchema,
417
417
  *,
418
418
  dtype: np.dtype | None = None,
419
419
  ) -> None:
@@ -422,5 +422,5 @@ class NumpySerializer(
422
422
 
423
423
 
424
424
  class NumpyMultiSerializer(MultiSerializer[np.ndarray]):
425
- def __init__(self: "NumpyMultiSerializer", schema: P.IOSchema) -> None:
425
+ def __init__(self: "NumpyMultiSerializer", schema: K.IOSchema) -> None:
426
426
  super().__init__([NumpySerializer(schema=s) for s in schema.values])
@@ -6,7 +6,7 @@ import numpy as np
6
6
  import torch
7
7
  from torch import Tensor
8
8
 
9
- from kinfer import proto as P
9
+ from kinfer import proto as K
10
10
  from kinfer.serialize.base import (
11
11
  AudioFrameSerializer,
12
12
  CameraFrameSerializer,
@@ -47,8 +47,8 @@ class PyTorchBaseSerializer:
47
47
  class PyTorchJointPositionsSerializer(PyTorchBaseSerializer, JointPositionsSerializer[Tensor]):
48
48
  def serialize_joint_positions(
49
49
  self: "PyTorchJointPositionsSerializer",
50
- schema: P.JointPositionsSchema,
51
- value: P.JointPositionsValue,
50
+ schema: K.JointPositionsSchema,
51
+ value: K.JointPositionsValue,
52
52
  ) -> Tensor:
53
53
  value_map = {v.joint_name: v for v in value.values}
54
54
  check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
@@ -64,17 +64,17 @@ class PyTorchJointPositionsSerializer(PyTorchBaseSerializer, JointPositionsSeria
64
64
 
65
65
  def deserialize_joint_positions(
66
66
  self: "PyTorchJointPositionsSerializer",
67
- schema: P.JointPositionsSchema,
67
+ schema: K.JointPositionsSchema,
68
68
  value: Tensor,
69
- ) -> P.JointPositionsValue:
69
+ ) -> K.JointPositionsValue:
70
70
  if value.shape != (len(schema.joint_names),):
71
71
  raise ValueError(
72
72
  f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
73
73
  )
74
74
  value_list = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
75
- return P.JointPositionsValue(
75
+ return K.JointPositionsValue(
76
76
  values=[
77
- P.JointPositionValue(joint_name=name, value=value_list[i], unit=schema.unit)
77
+ K.JointPositionValue(joint_name=name, value=value_list[i], unit=schema.unit)
78
78
  for i, name in enumerate(schema.joint_names)
79
79
  ]
80
80
  )
@@ -83,8 +83,8 @@ class PyTorchJointPositionsSerializer(PyTorchBaseSerializer, JointPositionsSeria
83
83
  class PyTorchJointVelocitiesSerializer(PyTorchBaseSerializer, JointVelocitiesSerializer[Tensor]):
84
84
  def serialize_joint_velocities(
85
85
  self: "PyTorchJointVelocitiesSerializer",
86
- schema: P.JointVelocitiesSchema,
87
- value: P.JointVelocitiesValue,
86
+ schema: K.JointVelocitiesSchema,
87
+ value: K.JointVelocitiesValue,
88
88
  ) -> Tensor:
89
89
  value_map = {v.joint_name: v for v in value.values}
90
90
  check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
@@ -100,17 +100,17 @@ class PyTorchJointVelocitiesSerializer(PyTorchBaseSerializer, JointVelocitiesSer
100
100
 
101
101
  def deserialize_joint_velocities(
102
102
  self: "PyTorchJointVelocitiesSerializer",
103
- schema: P.JointVelocitiesSchema,
103
+ schema: K.JointVelocitiesSchema,
104
104
  value: Tensor,
105
- ) -> P.JointVelocitiesValue:
105
+ ) -> K.JointVelocitiesValue:
106
106
  if value.shape != (len(schema.joint_names),):
107
107
  raise ValueError(
108
108
  f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
109
109
  )
110
110
  value_list = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
111
- return P.JointVelocitiesValue(
111
+ return K.JointVelocitiesValue(
112
112
  values=[
113
- P.JointVelocityValue(joint_name=name, value=value_list[i], unit=schema.unit)
113
+ K.JointVelocityValue(joint_name=name, value=value_list[i], unit=schema.unit)
114
114
  for i, name in enumerate(schema.joint_names)
115
115
  ]
116
116
  )
@@ -119,8 +119,8 @@ class PyTorchJointVelocitiesSerializer(PyTorchBaseSerializer, JointVelocitiesSer
119
119
  class PyTorchJointTorquesSerializer(PyTorchBaseSerializer, JointTorquesSerializer[Tensor]):
120
120
  def serialize_joint_torques(
121
121
  self: "PyTorchJointTorquesSerializer",
122
- schema: P.JointTorquesSchema,
123
- value: P.JointTorquesValue,
122
+ schema: K.JointTorquesSchema,
123
+ value: K.JointTorquesValue,
124
124
  ) -> Tensor:
125
125
  value_map = {v.joint_name: v for v in value.values}
126
126
  check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
@@ -133,17 +133,17 @@ class PyTorchJointTorquesSerializer(PyTorchBaseSerializer, JointTorquesSerialize
133
133
 
134
134
  def deserialize_joint_torques(
135
135
  self: "PyTorchJointTorquesSerializer",
136
- schema: P.JointTorquesSchema,
136
+ schema: K.JointTorquesSchema,
137
137
  value: Tensor,
138
- ) -> P.JointTorquesValue:
138
+ ) -> K.JointTorquesValue:
139
139
  if value.shape != (len(schema.joint_names),):
140
140
  raise ValueError(
141
141
  f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
142
142
  )
143
143
  value_list = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
144
- return P.JointTorquesValue(
144
+ return K.JointTorquesValue(
145
145
  values=[
146
- P.JointTorqueValue(joint_name=name, value=value_list[i], unit=schema.unit)
146
+ K.JointTorqueValue(joint_name=name, value=value_list[i], unit=schema.unit)
147
147
  for i, name in enumerate(schema.joint_names)
148
148
  ]
149
149
  )
@@ -152,8 +152,8 @@ class PyTorchJointTorquesSerializer(PyTorchBaseSerializer, JointTorquesSerialize
152
152
  class PyTorchJointCommandsSerializer(PyTorchBaseSerializer, JointCommandsSerializer[Tensor]):
153
153
  def _convert_value_to_tensor(
154
154
  self: "PyTorchJointCommandsSerializer",
155
- value: P.JointCommandValue,
156
- schema: P.JointCommandsSchema,
155
+ value: K.JointCommandValue,
156
+ schema: K.JointCommandsSchema,
157
157
  ) -> Tensor:
158
158
  return torch.tensor(
159
159
  [
@@ -170,12 +170,12 @@ class PyTorchJointCommandsSerializer(PyTorchBaseSerializer, JointCommandsSeriali
170
170
  def _convert_tensor_to_value(
171
171
  self: "PyTorchJointCommandsSerializer",
172
172
  values: list[float],
173
- schema: P.JointCommandsSchema,
173
+ schema: K.JointCommandsSchema,
174
174
  name: str,
175
- ) -> P.JointCommandValue:
175
+ ) -> K.JointCommandValue:
176
176
  if len(values) != 5:
177
177
  raise ValueError(f"Shape of tensor must match number of joint commands: {len(values)} != 5")
178
- return P.JointCommandValue(
178
+ return K.JointCommandValue(
179
179
  joint_name=name,
180
180
  torque=values[0],
181
181
  velocity=values[1],
@@ -189,8 +189,8 @@ class PyTorchJointCommandsSerializer(PyTorchBaseSerializer, JointCommandsSeriali
189
189
 
190
190
  def serialize_joint_commands(
191
191
  self: "PyTorchJointCommandsSerializer",
192
- schema: P.JointCommandsSchema,
193
- value: P.JointCommandsValue,
192
+ schema: K.JointCommandsSchema,
193
+ value: K.JointCommandsValue,
194
194
  ) -> Tensor:
195
195
  value_map = {v.joint_name: v for v in value.values}
196
196
  check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
@@ -202,16 +202,16 @@ class PyTorchJointCommandsSerializer(PyTorchBaseSerializer, JointCommandsSeriali
202
202
 
203
203
  def deserialize_joint_commands(
204
204
  self: "PyTorchJointCommandsSerializer",
205
- schema: P.JointCommandsSchema,
205
+ schema: K.JointCommandsSchema,
206
206
  value: Tensor,
207
- ) -> P.JointCommandsValue:
207
+ ) -> K.JointCommandsValue:
208
208
  if value.shape != (len(schema.joint_names), 5):
209
209
  raise ValueError(
210
210
  "Shape of tensor must match number of joint names and commands: "
211
211
  f"{value.shape} != ({len(schema.joint_names)}, 5)"
212
212
  )
213
213
  value_list = cast(list[list[float]], value.detach().cpu().numpy().astype(float).tolist())
214
- return P.JointCommandsValue(
214
+ return K.JointCommandsValue(
215
215
  values=[
216
216
  self._convert_tensor_to_value(value_list[i], schema, name) for i, name in enumerate(schema.joint_names)
217
217
  ]
@@ -220,9 +220,9 @@ class PyTorchJointCommandsSerializer(PyTorchBaseSerializer, JointCommandsSeriali
220
220
 
221
221
  class PyTorchCameraFrameSerializer(PyTorchBaseSerializer, CameraFrameSerializer[Tensor]):
222
222
  def serialize_camera_frame(
223
- self: "PyTorchCameraFrameSerializer", schema: P.CameraFrameSchema, value: P.CameraFrameValue
223
+ self: "PyTorchCameraFrameSerializer", schema: K.CameraFrameSchema, value: K.CameraFrameValue
224
224
  ) -> Tensor:
225
- np_arr = parse_bytes(value.data, P.DType.UINT8)
225
+ np_arr = parse_bytes(value.data, K.DType.UINT8)
226
226
  tensor = torch.from_numpy(np_arr).to(self.device, self.dtype) / 255.0
227
227
  if tensor.numel() != schema.channels * schema.height * schema.width:
228
228
  raise ValueError(
@@ -233,15 +233,15 @@ class PyTorchCameraFrameSerializer(PyTorchBaseSerializer, CameraFrameSerializer[
233
233
  return tensor
234
234
 
235
235
  def deserialize_camera_frame(
236
- self: "PyTorchCameraFrameSerializer", schema: P.CameraFrameSchema, value: Tensor
237
- ) -> P.CameraFrameValue:
236
+ self: "PyTorchCameraFrameSerializer", schema: K.CameraFrameSchema, value: Tensor
237
+ ) -> K.CameraFrameValue:
238
238
  np_arr = (value * 255.0).detach().cpu().flatten().numpy().astype(np.uint8)
239
- return P.CameraFrameValue(data=np_arr.tobytes())
239
+ return K.CameraFrameValue(data=np_arr.tobytes())
240
240
 
241
241
 
242
242
  class PyTorchAudioFrameSerializer(PyTorchBaseSerializer, AudioFrameSerializer[Tensor]):
243
243
  def serialize_audio_frame(
244
- self: "PyTorchAudioFrameSerializer", schema: P.AudioFrameSchema, value: P.AudioFrameValue
244
+ self: "PyTorchAudioFrameSerializer", schema: K.AudioFrameSchema, value: K.AudioFrameValue
245
245
  ) -> Tensor:
246
246
  value_bytes = value.data
247
247
  if len(value_bytes) != schema.channels * schema.sample_rate * dtype_num_bytes(schema.dtype):
@@ -257,15 +257,15 @@ class PyTorchAudioFrameSerializer(PyTorchBaseSerializer, AudioFrameSerializer[Te
257
257
  return tensor
258
258
 
259
259
  def deserialize_audio_frame(
260
- self: "PyTorchAudioFrameSerializer", schema: P.AudioFrameSchema, value: Tensor
261
- ) -> P.AudioFrameValue:
260
+ self: "PyTorchAudioFrameSerializer", schema: K.AudioFrameSchema, value: Tensor
261
+ ) -> K.AudioFrameValue:
262
262
  _, max_value = dtype_range(schema.dtype)
263
263
  np_arr = (value * max_value).detach().cpu().flatten().numpy().astype(numpy_dtype(schema.dtype))
264
- return P.AudioFrameValue(data=np_arr.tobytes())
264
+ return K.AudioFrameValue(data=np_arr.tobytes())
265
265
 
266
266
 
267
267
  class PyTorchImuSerializer(PyTorchBaseSerializer, ImuSerializer[Tensor]):
268
- def serialize_imu(self: "PyTorchImuSerializer", schema: P.ImuSchema, value: P.ImuValue) -> Tensor:
268
+ def serialize_imu(self: "PyTorchImuSerializer", schema: K.ImuSchema, value: K.ImuValue) -> Tensor:
269
269
  vectors: list[Tensor] = []
270
270
  if schema.use_accelerometer:
271
271
  vectors.append(
@@ -295,9 +295,9 @@ class PyTorchImuSerializer(PyTorchBaseSerializer, ImuSerializer[Tensor]):
295
295
  raise ValueError("IMU has nothing to serialize")
296
296
  return torch.stack(vectors, dim=0)
297
297
 
298
- def deserialize_imu(self: "PyTorchImuSerializer", schema: P.ImuSchema, value: Tensor) -> P.ImuValue:
298
+ def deserialize_imu(self: "PyTorchImuSerializer", schema: K.ImuSchema, value: Tensor) -> K.ImuValue:
299
299
  vectors = value.tolist()
300
- imu_value = P.ImuValue()
300
+ imu_value = K.ImuValue()
301
301
  if schema.use_accelerometer:
302
302
  (x, y, z), vectors = vectors[0], vectors[1:]
303
303
  imu_value.linear_acceleration.x = x
@@ -318,7 +318,7 @@ class PyTorchImuSerializer(PyTorchBaseSerializer, ImuSerializer[Tensor]):
318
318
 
319
319
  class PyTorchTimestampSerializer(PyTorchBaseSerializer, TimestampSerializer[Tensor]):
320
320
  def serialize_timestamp(
321
- self: "PyTorchTimestampSerializer", schema: P.TimestampSchema, value: P.TimestampValue
321
+ self: "PyTorchTimestampSerializer", schema: K.TimestampSchema, value: K.TimestampValue
322
322
  ) -> Tensor:
323
323
  elapsed_seconds = value.seconds - schema.start_seconds
324
324
  elapsed_nanos = value.nanos - schema.start_nanos
@@ -329,32 +329,32 @@ class PyTorchTimestampSerializer(PyTorchBaseSerializer, TimestampSerializer[Tens
329
329
  return torch.tensor([total_elapsed_seconds], dtype=self.dtype, device=self.device, requires_grad=False)
330
330
 
331
331
  def deserialize_timestamp(
332
- self: "PyTorchTimestampSerializer", schema: P.TimestampSchema, value: Tensor
333
- ) -> P.TimestampValue:
332
+ self: "PyTorchTimestampSerializer", schema: K.TimestampSchema, value: Tensor
333
+ ) -> K.TimestampValue:
334
334
  total_elapsed_seconds = value.item()
335
335
  elapsed_seconds = int(total_elapsed_seconds)
336
336
  elapsed_nanos = int((total_elapsed_seconds - elapsed_seconds) * 1_000_000_000)
337
- return P.TimestampValue(seconds=elapsed_seconds, nanos=elapsed_nanos)
337
+ return K.TimestampValue(seconds=elapsed_seconds, nanos=elapsed_nanos)
338
338
 
339
339
 
340
340
  class PyTorchVectorCommandSerializer(PyTorchBaseSerializer, VectorCommandSerializer[Tensor]):
341
341
  def serialize_vector_command(
342
- self: "PyTorchVectorCommandSerializer", schema: P.VectorCommandSchema, value: P.VectorCommandValue
342
+ self: "PyTorchVectorCommandSerializer", schema: K.VectorCommandSchema, value: K.VectorCommandValue
343
343
  ) -> Tensor:
344
344
  return torch.tensor(value.values, dtype=self.dtype, device=self.device)
345
345
 
346
346
  def deserialize_vector_command(
347
- self: "PyTorchVectorCommandSerializer", schema: P.VectorCommandSchema, value: Tensor
348
- ) -> P.VectorCommandValue:
347
+ self: "PyTorchVectorCommandSerializer", schema: K.VectorCommandSchema, value: Tensor
348
+ ) -> K.VectorCommandValue:
349
349
  if value.shape != (schema.dimensions,):
350
350
  raise ValueError(f"Shape of tensor must match number of dimensions: {value.shape} != {schema.dimensions}")
351
351
  values = cast(list[float], value.tolist())
352
- return P.VectorCommandValue(values=values)
352
+ return K.VectorCommandValue(values=values)
353
353
 
354
354
 
355
355
  class PyTorchStateTensorSerializer(PyTorchBaseSerializer, StateTensorSerializer[Tensor]):
356
356
  def serialize_state_tensor(
357
- self: "PyTorchStateTensorSerializer", schema: P.StateTensorSchema, value: P.StateTensorValue
357
+ self: "PyTorchStateTensorSerializer", schema: K.StateTensorSchema, value: K.StateTensorValue
358
358
  ) -> Tensor:
359
359
  value_bytes = value.data
360
360
  if len(value_bytes) != np.prod(schema.shape) * dtype_num_bytes(schema.dtype):
@@ -368,9 +368,9 @@ class PyTorchStateTensorSerializer(PyTorchBaseSerializer, StateTensorSerializer[
368
368
  return tensor
369
369
 
370
370
  def deserialize_state_tensor(
371
- self: "PyTorchStateTensorSerializer", schema: P.StateTensorSchema, value: Tensor
372
- ) -> P.StateTensorValue:
373
- return P.StateTensorValue(data=value.detach().cpu().flatten().numpy().tobytes())
371
+ self: "PyTorchStateTensorSerializer", schema: K.StateTensorSchema, value: Tensor
372
+ ) -> K.StateTensorValue:
373
+ return K.StateTensorValue(data=value.detach().cpu().flatten().numpy().tobytes())
374
374
 
375
375
 
376
376
  class PyTorchSerializer(
@@ -388,7 +388,7 @@ class PyTorchSerializer(
388
388
  ):
389
389
  def __init__(
390
390
  self: "PyTorchSerializer",
391
- schema: P.ValueSchema,
391
+ schema: K.ValueSchema,
392
392
  *,
393
393
  device: str | torch.device | None = None,
394
394
  dtype: torch.dtype | None = None,
@@ -398,5 +398,5 @@ class PyTorchSerializer(
398
398
 
399
399
 
400
400
  class PyTorchMultiSerializer(MultiSerializer[Tensor]):
401
- def __init__(self: "PyTorchMultiSerializer", schema: P.IOSchema) -> None:
401
+ def __init__(self: "PyTorchMultiSerializer", schema: K.IOSchema) -> None:
402
402
  super().__init__([PyTorchSerializer(schema=s) for s in schema.values])