kinfer 0.3.3__cp312-cp312-macosx_11_0_arm64.whl → 0.4.1__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kinfer/__init__.py +0 -5
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +12 -0
  4. kinfer/export/common.py +41 -0
  5. kinfer/export/jax.py +53 -0
  6. kinfer/export/pytorch.py +45 -110
  7. kinfer/export/serialize.py +93 -0
  8. kinfer/requirements.txt +3 -4
  9. kinfer/rust/Cargo.toml +20 -8
  10. kinfer/rust/src/lib.rs +2 -11
  11. kinfer/rust/src/model.rs +286 -121
  12. kinfer/rust/src/runtime.rs +104 -0
  13. kinfer/rust_bindings/Cargo.toml +8 -1
  14. kinfer/rust_bindings/rust_bindings.pyi +36 -0
  15. kinfer/rust_bindings/src/lib.rs +326 -1
  16. kinfer/rust_bindings.cpython-312-darwin.so +0 -0
  17. kinfer/rust_bindings.pyi +30 -1
  18. kinfer-0.4.1.dist-info/METADATA +55 -0
  19. kinfer-0.4.1.dist-info/RECORD +26 -0
  20. {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info}/WHEEL +2 -1
  21. kinfer/inference/__init__.py +0 -2
  22. kinfer/inference/base.py +0 -64
  23. kinfer/inference/python.py +0 -66
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -60
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.3.dist-info/METADATA +0 -57
  43. kinfer-0.3.3.dist-info/RECORD +0 -40
  44. {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info}/top_level.txt +0 -0
kinfer/serialize/pytorch.py DELETED
@@ -1,402 +0,0 @@
1
- """Defines a serializer for PyTorch tensors."""
2
-
3
- from typing import cast
4
-
5
- import numpy as np
6
- import torch
7
- from torch import Tensor
8
-
9
- from kinfer import proto as K
10
- from kinfer.serialize.base import (
11
- AudioFrameSerializer,
12
- CameraFrameSerializer,
13
- ImuSerializer,
14
- JointCommandsSerializer,
15
- JointPositionsSerializer,
16
- JointTorquesSerializer,
17
- JointVelocitiesSerializer,
18
- MultiSerializer,
19
- Serializer,
20
- StateTensorSerializer,
21
- TimestampSerializer,
22
- VectorCommandSerializer,
23
- )
24
- from kinfer.serialize.utils import (
25
- check_names_match,
26
- convert_angular_position,
27
- convert_angular_velocity,
28
- convert_torque,
29
- dtype_num_bytes,
30
- dtype_range,
31
- numpy_dtype,
32
- parse_bytes,
33
- pytorch_dtype,
34
- )
35
-
36
-
37
- class PyTorchBaseSerializer:
38
- def __init__(
39
- self: "PyTorchBaseSerializer",
40
- device: str | torch.device | None = None,
41
- dtype: torch.dtype | None = None,
42
- ) -> None:
43
- self.device = device
44
- self.dtype = dtype
45
-
46
-
47
- class PyTorchJointPositionsSerializer(PyTorchBaseSerializer, JointPositionsSerializer[Tensor]):
48
- def serialize_joint_positions(
49
- self: "PyTorchJointPositionsSerializer",
50
- schema: K.JointPositionsSchema,
51
- value: K.JointPositionsValue,
52
- ) -> Tensor:
53
- value_map = {v.joint_name: v for v in value.values}
54
- check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
55
- tensor = torch.tensor(
56
- [
57
- convert_angular_position(value_map[name].value, value_map[name].unit, schema.unit)
58
- for name in schema.joint_names
59
- ],
60
- dtype=self.dtype,
61
- device=self.device,
62
- )
63
- return tensor
64
-
65
- def deserialize_joint_positions(
66
- self: "PyTorchJointPositionsSerializer",
67
- schema: K.JointPositionsSchema,
68
- value: Tensor,
69
- ) -> K.JointPositionsValue:
70
- if value.shape != (len(schema.joint_names),):
71
- raise ValueError(
72
- f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
73
- )
74
- value_list = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
75
- return K.JointPositionsValue(
76
- values=[
77
- K.JointPositionValue(joint_name=name, value=value_list[i], unit=schema.unit)
78
- for i, name in enumerate(schema.joint_names)
79
- ]
80
- )
81
-
82
-
83
- class PyTorchJointVelocitiesSerializer(PyTorchBaseSerializer, JointVelocitiesSerializer[Tensor]):
84
- def serialize_joint_velocities(
85
- self: "PyTorchJointVelocitiesSerializer",
86
- schema: K.JointVelocitiesSchema,
87
- value: K.JointVelocitiesValue,
88
- ) -> Tensor:
89
- value_map = {v.joint_name: v for v in value.values}
90
- check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
91
- tensor = torch.tensor(
92
- [
93
- convert_angular_velocity(value_map[name].value, value_map[name].unit, schema.unit)
94
- for name in schema.joint_names
95
- ],
96
- dtype=self.dtype,
97
- device=self.device,
98
- )
99
- return tensor
100
-
101
- def deserialize_joint_velocities(
102
- self: "PyTorchJointVelocitiesSerializer",
103
- schema: K.JointVelocitiesSchema,
104
- value: Tensor,
105
- ) -> K.JointVelocitiesValue:
106
- if value.shape != (len(schema.joint_names),):
107
- raise ValueError(
108
- f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
109
- )
110
- value_list = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
111
- return K.JointVelocitiesValue(
112
- values=[
113
- K.JointVelocityValue(joint_name=name, value=value_list[i], unit=schema.unit)
114
- for i, name in enumerate(schema.joint_names)
115
- ]
116
- )
117
-
118
-
119
- class PyTorchJointTorquesSerializer(PyTorchBaseSerializer, JointTorquesSerializer[Tensor]):
120
- def serialize_joint_torques(
121
- self: "PyTorchJointTorquesSerializer",
122
- schema: K.JointTorquesSchema,
123
- value: K.JointTorquesValue,
124
- ) -> Tensor:
125
- value_map = {v.joint_name: v for v in value.values}
126
- check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
127
- tensor = torch.tensor(
128
- [convert_torque(value_map[name].value, value_map[name].unit, schema.unit) for name in schema.joint_names],
129
- dtype=self.dtype,
130
- device=self.device,
131
- )
132
- return tensor
133
-
134
- def deserialize_joint_torques(
135
- self: "PyTorchJointTorquesSerializer",
136
- schema: K.JointTorquesSchema,
137
- value: Tensor,
138
- ) -> K.JointTorquesValue:
139
- if value.shape != (len(schema.joint_names),):
140
- raise ValueError(
141
- f"Shape of tensor must match number of joint names: {value.shape} != {len(schema.joint_names)}"
142
- )
143
- value_list = cast(list[float], value.detach().cpu().numpy().astype(float).tolist())
144
- return K.JointTorquesValue(
145
- values=[
146
- K.JointTorqueValue(joint_name=name, value=value_list[i], unit=schema.unit)
147
- for i, name in enumerate(schema.joint_names)
148
- ]
149
- )
150
-
151
-
152
- class PyTorchJointCommandsSerializer(PyTorchBaseSerializer, JointCommandsSerializer[Tensor]):
153
- def _convert_value_to_tensor(
154
- self: "PyTorchJointCommandsSerializer",
155
- value: K.JointCommandValue,
156
- schema: K.JointCommandsSchema,
157
- ) -> Tensor:
158
- return torch.tensor(
159
- [
160
- convert_torque(value.torque, value.torque_unit, schema.torque_unit),
161
- convert_angular_velocity(value.velocity, value.velocity_unit, schema.velocity_unit),
162
- convert_angular_position(value.position, value.position_unit, schema.position_unit),
163
- value.kp,
164
- value.kd,
165
- ],
166
- dtype=self.dtype,
167
- device=self.device,
168
- )
169
-
170
- def _convert_tensor_to_value(
171
- self: "PyTorchJointCommandsSerializer",
172
- values: list[float],
173
- schema: K.JointCommandsSchema,
174
- name: str,
175
- ) -> K.JointCommandValue:
176
- if len(values) != 5:
177
- raise ValueError(f"Shape of tensor must match number of joint commands: {len(values)} != 5")
178
- return K.JointCommandValue(
179
- joint_name=name,
180
- torque=values[0],
181
- velocity=values[1],
182
- position=values[2],
183
- kp=values[3],
184
- kd=values[4],
185
- torque_unit=schema.torque_unit,
186
- velocity_unit=schema.velocity_unit,
187
- position_unit=schema.position_unit,
188
- )
189
-
190
- def serialize_joint_commands(
191
- self: "PyTorchJointCommandsSerializer",
192
- schema: K.JointCommandsSchema,
193
- value: K.JointCommandsValue,
194
- ) -> Tensor:
195
- value_map = {v.joint_name: v for v in value.values}
196
- check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
197
- tensor = torch.stack(
198
- [self._convert_value_to_tensor(value_map[name], schema) for name in schema.joint_names],
199
- dim=0,
200
- )
201
- return tensor
202
-
203
- def deserialize_joint_commands(
204
- self: "PyTorchJointCommandsSerializer",
205
- schema: K.JointCommandsSchema,
206
- value: Tensor,
207
- ) -> K.JointCommandsValue:
208
- if value.shape != (len(schema.joint_names), 5):
209
- raise ValueError(
210
- "Shape of tensor must match number of joint names and commands: "
211
- f"{value.shape} != ({len(schema.joint_names)}, 5)"
212
- )
213
- value_list = cast(list[list[float]], value.detach().cpu().numpy().astype(float).tolist())
214
- return K.JointCommandsValue(
215
- values=[
216
- self._convert_tensor_to_value(value_list[i], schema, name) for i, name in enumerate(schema.joint_names)
217
- ]
218
- )
219
-
220
-
221
- class PyTorchCameraFrameSerializer(PyTorchBaseSerializer, CameraFrameSerializer[Tensor]):
222
- def serialize_camera_frame(
223
- self: "PyTorchCameraFrameSerializer", schema: K.CameraFrameSchema, value: K.CameraFrameValue
224
- ) -> Tensor:
225
- np_arr = parse_bytes(value.data, K.DType.UINT8)
226
- tensor = torch.from_numpy(np_arr).to(self.device, self.dtype) / 255.0
227
- if tensor.numel() != schema.channels * schema.height * schema.width:
228
- raise ValueError(
229
- "Length of data must match number of channels, height, and width: "
230
- f"{tensor.numel()} != {schema.channels} * {schema.height} * {schema.width}"
231
- )
232
- tensor = tensor.view(schema.channels, schema.height, schema.width)
233
- return tensor
234
-
235
- def deserialize_camera_frame(
236
- self: "PyTorchCameraFrameSerializer", schema: K.CameraFrameSchema, value: Tensor
237
- ) -> K.CameraFrameValue:
238
- np_arr = (value * 255.0).detach().cpu().flatten().numpy().astype(np.uint8)
239
- return K.CameraFrameValue(data=np_arr.tobytes())
240
-
241
-
242
- class PyTorchAudioFrameSerializer(PyTorchBaseSerializer, AudioFrameSerializer[Tensor]):
243
- def serialize_audio_frame(
244
- self: "PyTorchAudioFrameSerializer", schema: K.AudioFrameSchema, value: K.AudioFrameValue
245
- ) -> Tensor:
246
- value_bytes = value.data
247
- if len(value_bytes) != schema.channels * schema.sample_rate * dtype_num_bytes(schema.dtype):
248
- raise ValueError(
249
- "Length of data must match number of channels, sample rate, and dtype: "
250
- f"{len(value_bytes)} != {schema.channels} * {schema.sample_rate} * {dtype_num_bytes(schema.dtype)}"
251
- )
252
- _, max_value = dtype_range(schema.dtype)
253
- np_arr = parse_bytes(value_bytes, schema.dtype)
254
- tensor = torch.from_numpy(np_arr).to(self.device, self.dtype)
255
- tensor = tensor.view(schema.channels, -1)
256
- tensor = tensor / max_value
257
- return tensor
258
-
259
- def deserialize_audio_frame(
260
- self: "PyTorchAudioFrameSerializer", schema: K.AudioFrameSchema, value: Tensor
261
- ) -> K.AudioFrameValue:
262
- _, max_value = dtype_range(schema.dtype)
263
- np_arr = (value * max_value).detach().cpu().flatten().numpy().astype(numpy_dtype(schema.dtype))
264
- return K.AudioFrameValue(data=np_arr.tobytes())
265
-
266
-
267
- class PyTorchImuSerializer(PyTorchBaseSerializer, ImuSerializer[Tensor]):
268
- def serialize_imu(self: "PyTorchImuSerializer", schema: K.ImuSchema, value: K.ImuValue) -> Tensor:
269
- vectors: list[Tensor] = []
270
- if schema.use_accelerometer:
271
- vectors.append(
272
- torch.tensor(
273
- [value.linear_acceleration.x, value.linear_acceleration.y, value.linear_acceleration.z],
274
- dtype=self.dtype,
275
- device=self.device,
276
- )
277
- )
278
- if schema.use_gyroscope:
279
- vectors.append(
280
- torch.tensor(
281
- [value.angular_velocity.x, value.angular_velocity.y, value.angular_velocity.z],
282
- dtype=self.dtype,
283
- device=self.device,
284
- )
285
- )
286
- if schema.use_magnetometer:
287
- vectors.append(
288
- torch.tensor(
289
- [value.magnetic_field.x, value.magnetic_field.y, value.magnetic_field.z],
290
- dtype=self.dtype,
291
- device=self.device,
292
- )
293
- )
294
- if not vectors:
295
- raise ValueError("IMU has nothing to serialize")
296
- return torch.stack(vectors, dim=0)
297
-
298
- def deserialize_imu(self: "PyTorchImuSerializer", schema: K.ImuSchema, value: Tensor) -> K.ImuValue:
299
- vectors = value.tolist()
300
- imu_value = K.ImuValue()
301
- if schema.use_accelerometer:
302
- (x, y, z), vectors = vectors[0], vectors[1:]
303
- imu_value.linear_acceleration.x = x
304
- imu_value.linear_acceleration.y = y
305
- imu_value.linear_acceleration.z = z
306
- if schema.use_gyroscope:
307
- (x, y, z), vectors = vectors[0], vectors[1:]
308
- imu_value.angular_velocity.x = x
309
- imu_value.angular_velocity.y = y
310
- imu_value.angular_velocity.z = z
311
- if schema.use_magnetometer:
312
- (x, y, z), vectors = vectors[0], vectors[1:]
313
- imu_value.magnetic_field.x = x
314
- imu_value.magnetic_field.y = y
315
- imu_value.magnetic_field.z = z
316
- return imu_value
317
-
318
-
319
- class PyTorchTimestampSerializer(PyTorchBaseSerializer, TimestampSerializer[Tensor]):
320
- def serialize_timestamp(
321
- self: "PyTorchTimestampSerializer", schema: K.TimestampSchema, value: K.TimestampValue
322
- ) -> Tensor:
323
- elapsed_seconds = value.seconds - schema.start_seconds
324
- elapsed_nanos = value.nanos - schema.start_nanos
325
- if elapsed_nanos < 0:
326
- elapsed_seconds -= 1
327
- elapsed_nanos += 1_000_000_000
328
- total_elapsed_seconds = elapsed_seconds + elapsed_nanos / 1_000_000_000
329
- return torch.tensor([total_elapsed_seconds], dtype=self.dtype, device=self.device, requires_grad=False)
330
-
331
- def deserialize_timestamp(
332
- self: "PyTorchTimestampSerializer", schema: K.TimestampSchema, value: Tensor
333
- ) -> K.TimestampValue:
334
- total_elapsed_seconds = value.item()
335
- elapsed_seconds = int(total_elapsed_seconds)
336
- elapsed_nanos = int((total_elapsed_seconds - elapsed_seconds) * 1_000_000_000)
337
- return K.TimestampValue(seconds=elapsed_seconds, nanos=elapsed_nanos)
338
-
339
-
340
- class PyTorchVectorCommandSerializer(PyTorchBaseSerializer, VectorCommandSerializer[Tensor]):
341
- def serialize_vector_command(
342
- self: "PyTorchVectorCommandSerializer", schema: K.VectorCommandSchema, value: K.VectorCommandValue
343
- ) -> Tensor:
344
- return torch.tensor(value.values, dtype=self.dtype, device=self.device)
345
-
346
- def deserialize_vector_command(
347
- self: "PyTorchVectorCommandSerializer", schema: K.VectorCommandSchema, value: Tensor
348
- ) -> K.VectorCommandValue:
349
- if value.shape != (schema.dimensions,):
350
- raise ValueError(f"Shape of tensor must match number of dimensions: {value.shape} != {schema.dimensions}")
351
- values = cast(list[float], value.tolist())
352
- return K.VectorCommandValue(values=values)
353
-
354
-
355
- class PyTorchStateTensorSerializer(PyTorchBaseSerializer, StateTensorSerializer[Tensor]):
356
- def serialize_state_tensor(
357
- self: "PyTorchStateTensorSerializer", schema: K.StateTensorSchema, value: K.StateTensorValue
358
- ) -> Tensor:
359
- value_bytes = value.data
360
- if len(value_bytes) != np.prod(schema.shape) * dtype_num_bytes(schema.dtype):
361
- raise ValueError(
362
- "Length of data must match number of elements: "
363
- f"{len(value_bytes)} != {np.prod(schema.shape)} * {dtype_num_bytes(schema.dtype)}"
364
- )
365
- np_arr = parse_bytes(value_bytes, schema.dtype)
366
- tensor = torch.from_numpy(np_arr).to(self.device, pytorch_dtype(schema.dtype))
367
- tensor = tensor.view(tuple(schema.shape))
368
- return tensor
369
-
370
- def deserialize_state_tensor(
371
- self: "PyTorchStateTensorSerializer", schema: K.StateTensorSchema, value: Tensor
372
- ) -> K.StateTensorValue:
373
- return K.StateTensorValue(data=value.detach().cpu().flatten().numpy().tobytes())
374
-
375
-
376
- class PyTorchSerializer(
377
- PyTorchJointPositionsSerializer,
378
- PyTorchJointVelocitiesSerializer,
379
- PyTorchJointTorquesSerializer,
380
- PyTorchJointCommandsSerializer,
381
- PyTorchCameraFrameSerializer,
382
- PyTorchAudioFrameSerializer,
383
- PyTorchImuSerializer,
384
- PyTorchTimestampSerializer,
385
- PyTorchVectorCommandSerializer,
386
- PyTorchStateTensorSerializer,
387
- Serializer[Tensor],
388
- ):
389
- def __init__(
390
- self: "PyTorchSerializer",
391
- schema: K.ValueSchema,
392
- *,
393
- device: str | torch.device | None = None,
394
- dtype: torch.dtype | None = None,
395
- ) -> None:
396
- PyTorchBaseSerializer.__init__(self, device=device, dtype=dtype)
397
- Serializer.__init__(self, schema=schema)
398
-
399
-
400
- class PyTorchMultiSerializer(MultiSerializer[Tensor]):
401
- def __init__(self: "PyTorchMultiSerializer", schema: K.IOSchema) -> None:
402
- super().__init__([PyTorchSerializer(schema=s) for s in schema.values])
kinfer/serialize/schema.py DELETED
@@ -1,125 +0,0 @@
1
- """Defines utility functions for the schema."""
2
-
3
- import numpy as np
4
-
5
- from kinfer import proto as K
6
- from kinfer.serialize.utils import dtype_num_bytes
7
-
8
-
9
- def get_dummy_value(value_schema: K.ValueSchema) -> K.Value:
10
- value_type = value_schema.WhichOneof("value_type")
11
-
12
- match value_type:
13
- case "joint_positions":
14
- return K.Value(
15
- joint_positions=K.JointPositionsValue(
16
- values=[
17
- K.JointPositionValue(
18
- joint_name=joint_name,
19
- value=0.0,
20
- unit=value_schema.joint_positions.unit,
21
- )
22
- for joint_name in value_schema.joint_positions.joint_names
23
- ]
24
- ),
25
- )
26
- case "joint_velocities":
27
- return K.Value(
28
- joint_velocities=K.JointVelocitiesValue(
29
- values=[
30
- K.JointVelocityValue(
31
- joint_name=joint_name,
32
- value=0.0,
33
- unit=value_schema.joint_velocities.unit,
34
- )
35
- for joint_name in value_schema.joint_velocities.joint_names
36
- ]
37
- ),
38
- )
39
- case "joint_torques":
40
- return K.Value(
41
- joint_torques=K.JointTorquesValue(
42
- values=[
43
- K.JointTorqueValue(
44
- joint_name=joint_name,
45
- value=0.0,
46
- unit=value_schema.joint_torques.unit,
47
- )
48
- for joint_name in value_schema.joint_torques.joint_names
49
- ]
50
- ),
51
- )
52
- case "joint_commands":
53
- return K.Value(
54
- joint_commands=K.JointCommandsValue(
55
- values=[
56
- K.JointCommandValue(
57
- joint_name=joint_name,
58
- torque=0.0,
59
- velocity=0.0,
60
- position=0.0,
61
- kp=0.0,
62
- kd=0.0,
63
- torque_unit=value_schema.joint_commands.torque_unit,
64
- velocity_unit=value_schema.joint_commands.velocity_unit,
65
- position_unit=value_schema.joint_commands.position_unit,
66
- )
67
- for joint_name in value_schema.joint_commands.joint_names
68
- ]
69
- ),
70
- )
71
- case "camera_frame":
72
- return K.Value(
73
- camera_frame=K.CameraFrameValue(
74
- data=b"\x00"
75
- * (
76
- value_schema.camera_frame.width
77
- * value_schema.camera_frame.height
78
- * value_schema.camera_frame.channels
79
- )
80
- ),
81
- )
82
- case "audio_frame":
83
- return K.Value(
84
- audio_frame=K.AudioFrameValue(
85
- data=b"\x00"
86
- * (
87
- value_schema.audio_frame.channels
88
- * value_schema.audio_frame.sample_rate
89
- * dtype_num_bytes(value_schema.audio_frame.dtype)
90
- )
91
- ),
92
- )
93
- case "imu":
94
- return K.Value(
95
- imu=K.ImuValue(
96
- linear_acceleration=K.ImuAccelerometerValue(x=0.0, y=0.0, z=0.0),
97
- angular_velocity=K.ImuGyroscopeValue(x=0.0, y=0.0, z=0.0),
98
- magnetic_field=K.ImuMagnetometerValue(x=0.0, y=0.0, z=0.0),
99
- ),
100
- )
101
- case "timestamp":
102
- return K.Value(
103
- timestamp=K.TimestampValue(seconds=1728000000, nanos=0),
104
- )
105
- case "vector_command":
106
- return K.Value(
107
- vector_command=K.VectorCommandValue(values=[0.0] * value_schema.vector_command.dimensions),
108
- )
109
- case "state_tensor":
110
- return K.Value(
111
- state_tensor=K.StateTensorValue(
112
- data=b"\x00"
113
- * np.prod(value_schema.state_tensor.shape)
114
- * dtype_num_bytes(value_schema.state_tensor.dtype)
115
- ),
116
- )
117
- case _:
118
- raise ValueError(f"Invalid value type: {value_type}")
119
-
120
-
121
- def get_dummy_io(schema: K.IOSchema) -> K.IO:
122
- io_value = K.IO()
123
- for value_schema in schema.values:
124
- io_value.values.append(get_dummy_value(value_schema))
125
- return io_value
kinfer/serialize/types.py DELETED
@@ -1,17 +0,0 @@
1
- """Type conversion utilities for serializers."""
2
-
3
- from typing import Type, TypeVar, cast
4
-
5
- from kinfer import proto as K
6
-
7
- T = TypeVar("T")
8
-
9
-
10
- def to_value_type(enum_value: T) -> K.Value:
11
- """Convert an enum value to ValueType."""
12
- return cast(K.Value, enum_value)
13
-
14
-
15
- def from_value_type(value_type: K.Value, enum_class: Type[T]) -> T:
16
- """Convert a ValueType to the specified enum type."""
17
- return cast(T, value_type)