kinfer 0.3.2__cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl → 0.3.3__cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +5 -1
- kinfer/export/__init__.py +0 -1
- kinfer/export/pytorch.py +3 -3
- kinfer/inference/__init__.py +2 -1
- kinfer/inference/base.py +64 -0
- kinfer/inference/python.py +9 -35
- kinfer/rust_bindings.cpython-311-aarch64-linux-gnu.so +0 -0
- kinfer/serialize/__init__.py +28 -4
- kinfer/serialize/base.py +61 -61
- kinfer/serialize/json.py +61 -61
- kinfer/serialize/numpy.py +62 -62
- kinfer/serialize/pytorch.py +55 -55
- kinfer/serialize/schema.py +31 -31
- kinfer/serialize/types.py +4 -4
- kinfer/serialize/utils.py +58 -58
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/METADATA +1 -1
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/RECORD +20 -19
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/LICENSE +0 -0
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/WHEEL +0 -0
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/top_level.txt +0 -0
kinfer/__init__.py
CHANGED
@@ -1,6 +1,10 @@
|
|
1
1
|
"""Defines the kinfer API."""
|
2
2
|
|
3
|
-
from . import
|
3
|
+
from . import proto as K
|
4
|
+
from .export.pytorch import export_model, get_model
|
5
|
+
from .inference.base import KModel
|
6
|
+
from .inference.python import ONNXModel
|
4
7
|
from .rust_bindings import get_version
|
8
|
+
from .serialize import get_multi_serializer, get_serializer
|
5
9
|
|
6
10
|
__version__ = get_version()
|
kinfer/export/__init__.py
CHANGED
@@ -1 +0,0 @@
|
|
1
|
-
from .pytorch import *
|
kinfer/export/pytorch.py
CHANGED
@@ -10,7 +10,7 @@ import onnxruntime as ort
|
|
10
10
|
import torch
|
11
11
|
from torch import Tensor
|
12
12
|
|
13
|
-
from kinfer import proto as
|
13
|
+
from kinfer import proto as K
|
14
14
|
from kinfer.serialize.pytorch import PyTorchMultiSerializer
|
15
15
|
from kinfer.serialize.schema import get_dummy_io
|
16
16
|
from kinfer.serialize.utils import check_names_match
|
@@ -18,7 +18,7 @@ from kinfer.serialize.utils import check_names_match
|
|
18
18
|
KINFER_METADATA_KEY = "kinfer_metadata"
|
19
19
|
|
20
20
|
|
21
|
-
def _add_metadata_to_onnx(model_proto: onnx.ModelProto, schema:
|
21
|
+
def _add_metadata_to_onnx(model_proto: onnx.ModelProto, schema: K.ModelSchema) -> onnx.ModelProto:
|
22
22
|
"""Add metadata to ONNX model.
|
23
23
|
|
24
24
|
Args:
|
@@ -37,7 +37,7 @@ def _add_metadata_to_onnx(model_proto: onnx.ModelProto, schema: P.ModelSchema) -
|
|
37
37
|
return model_proto
|
38
38
|
|
39
39
|
|
40
|
-
def export_model(model: torch.jit.ScriptModule, schema:
|
40
|
+
def export_model(model: torch.jit.ScriptModule, schema: K.ModelSchema) -> onnx.ModelProto:
|
41
41
|
"""Export PyTorch model to ONNX format with metadata.
|
42
42
|
|
43
43
|
Args:
|
kinfer/inference/__init__.py
CHANGED
@@ -1 +1,2 @@
|
|
1
|
-
from .
|
1
|
+
from .base import KModel
|
2
|
+
from .python import ONNXModel
|
kinfer/inference/base.py
ADDED
@@ -0,0 +1,64 @@
|
|
1
|
+
"""Defines the base interface for running model inference.
|
2
|
+
|
3
|
+
All kinfer models must implement this interface - the model inputs and outputs
|
4
|
+
should match the provided schema, and the `__call__` method should take the
|
5
|
+
inputs and return the outputs according to this schema.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import functools
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
|
11
|
+
from kinfer import proto as K
|
12
|
+
|
13
|
+
|
14
|
+
class KModel(ABC):
|
15
|
+
"""Base interface for running model inference."""
|
16
|
+
|
17
|
+
@abstractmethod
|
18
|
+
def get_schema(self) -> K.ModelSchema:
|
19
|
+
"""Get the model schema."""
|
20
|
+
|
21
|
+
@abstractmethod
|
22
|
+
def __call__(self, inputs: K.IO) -> K.IO:
|
23
|
+
"""Run inference on input data.
|
24
|
+
|
25
|
+
Args:
|
26
|
+
inputs: Input data, matching the input schema.
|
27
|
+
|
28
|
+
Returns:
|
29
|
+
Model outputs, matching the output schema.
|
30
|
+
"""
|
31
|
+
|
32
|
+
@functools.cached_property
|
33
|
+
def schema(self) -> K.ModelSchema:
|
34
|
+
return self.get_schema()
|
35
|
+
|
36
|
+
@property
|
37
|
+
def input_schema(self) -> K.IOSchema:
|
38
|
+
"""Get the input schema."""
|
39
|
+
return self.schema.input_schema
|
40
|
+
|
41
|
+
@property
|
42
|
+
def output_schema(self) -> K.IOSchema:
|
43
|
+
"""Get the output schema."""
|
44
|
+
return self.schema.output_schema
|
45
|
+
|
46
|
+
@property
|
47
|
+
def schema_input_keys(self) -> list[str]:
|
48
|
+
"""Get all value names from input schemas.
|
49
|
+
|
50
|
+
Returns:
|
51
|
+
List of value names from input schema.
|
52
|
+
"""
|
53
|
+
input_names = [value.value_name for value in self.input_schema.values]
|
54
|
+
return input_names
|
55
|
+
|
56
|
+
@property
|
57
|
+
def schema_output_keys(self) -> list[str]:
|
58
|
+
"""Get all value names from output schemas.
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
List of value names from output schema.
|
62
|
+
"""
|
63
|
+
output_names = [value.value_name for value in self.output_schema.values]
|
64
|
+
return output_names
|
kinfer/inference/python.py
CHANGED
@@ -6,17 +6,18 @@ from pathlib import Path
|
|
6
6
|
import onnx
|
7
7
|
import onnxruntime as ort
|
8
8
|
|
9
|
-
from kinfer import proto as
|
9
|
+
from kinfer import proto as K
|
10
10
|
from kinfer.export.pytorch import KINFER_METADATA_KEY
|
11
|
+
from kinfer.inference.base import KModel
|
11
12
|
from kinfer.serialize.numpy import NumpyMultiSerializer
|
12
13
|
|
13
14
|
|
14
|
-
def _read_schema(model: onnx.ModelProto) ->
|
15
|
+
def _read_schema(model: onnx.ModelProto) -> K.ModelSchema:
|
15
16
|
for prop in model.metadata_props:
|
16
17
|
if prop.key == KINFER_METADATA_KEY:
|
17
18
|
try:
|
18
19
|
schema_bytes = base64.b64decode(prop.value)
|
19
|
-
schema =
|
20
|
+
schema = K.ModelSchema()
|
20
21
|
schema.ParseFromString(schema_bytes)
|
21
22
|
return schema
|
22
23
|
except Exception as e:
|
@@ -27,7 +28,7 @@ def _read_schema(model: onnx.ModelProto) -> P.ModelSchema:
|
|
27
28
|
raise ValueError(f"{KINFER_METADATA_KEY} not found in model metadata")
|
28
29
|
|
29
30
|
|
30
|
-
class ONNXModel:
|
31
|
+
class ONNXModel(KModel):
|
31
32
|
"""Wrapper for ONNX model inference."""
|
32
33
|
|
33
34
|
def __init__(self: "ONNXModel", model_path: str | Path) -> None:
|
@@ -47,7 +48,10 @@ class ONNXModel:
|
|
47
48
|
self._input_serializer = NumpyMultiSerializer(self._schema.input_schema)
|
48
49
|
self._output_serializer = NumpyMultiSerializer(self._schema.output_schema)
|
49
50
|
|
50
|
-
def
|
51
|
+
def get_schema(self) -> K.ModelSchema:
|
52
|
+
return self._schema
|
53
|
+
|
54
|
+
def __call__(self, inputs: K.IO) -> K.IO:
|
51
55
|
"""Run inference on input data.
|
52
56
|
|
53
57
|
Args:
|
@@ -60,33 +64,3 @@ class ONNXModel:
|
|
60
64
|
outputs_np = self.session.run(None, inputs_np)
|
61
65
|
outputs = self._output_serializer.deserialize_io(outputs_np)
|
62
66
|
return outputs
|
63
|
-
|
64
|
-
@property
|
65
|
-
def input_schema(self: "ONNXModel") -> P.IOSchema:
|
66
|
-
"""Get the input schema."""
|
67
|
-
return self._schema.input_schema
|
68
|
-
|
69
|
-
@property
|
70
|
-
def output_schema(self: "ONNXModel") -> P.IOSchema:
|
71
|
-
"""Get the output schema."""
|
72
|
-
return self._schema.output_schema
|
73
|
-
|
74
|
-
@property
|
75
|
-
def schema_input_keys(self: "ONNXModel") -> list[str]:
|
76
|
-
"""Get all value names from input schemas.
|
77
|
-
|
78
|
-
Returns:
|
79
|
-
List of value names from input schema.
|
80
|
-
"""
|
81
|
-
input_names = [value.value_name for value in self._schema.input_schema.values]
|
82
|
-
return input_names
|
83
|
-
|
84
|
-
@property
|
85
|
-
def schema_output_keys(self: "ONNXModel") -> list[str]:
|
86
|
-
"""Get all value names from output schemas.
|
87
|
-
|
88
|
-
Returns:
|
89
|
-
List of value names from output schema.
|
90
|
-
"""
|
91
|
-
output_names = [value.value_name for value in self._schema.output_schema.values]
|
92
|
-
return output_names
|
Binary file
|
kinfer/serialize/__init__.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
"""Defines an interface for instantiating serializers."""
|
2
2
|
|
3
|
-
from typing import Literal
|
3
|
+
from typing import Literal, overload
|
4
4
|
|
5
|
-
from kinfer import proto as
|
5
|
+
from kinfer import proto as K
|
6
6
|
|
7
7
|
from .base import MultiSerializer, Serializer
|
8
8
|
from .json import JsonMultiSerializer, JsonSerializer
|
@@ -12,7 +12,19 @@ from .pytorch import PyTorchMultiSerializer, PyTorchSerializer
|
|
12
12
|
SerializerType = Literal["json", "numpy", "pytorch"]
|
13
13
|
|
14
14
|
|
15
|
-
|
15
|
+
@overload
|
16
|
+
def get_serializer(schema: K.ValueSchema, serializer_type: Literal["json"]) -> JsonSerializer: ...
|
17
|
+
|
18
|
+
|
19
|
+
@overload
|
20
|
+
def get_serializer(schema: K.ValueSchema, serializer_type: Literal["numpy"]) -> NumpySerializer: ...
|
21
|
+
|
22
|
+
|
23
|
+
@overload
|
24
|
+
def get_serializer(schema: K.ValueSchema, serializer_type: Literal["pytorch"]) -> PyTorchSerializer: ...
|
25
|
+
|
26
|
+
|
27
|
+
def get_serializer(schema: K.ValueSchema, serializer_type: SerializerType) -> Serializer:
|
16
28
|
match serializer_type:
|
17
29
|
case "json":
|
18
30
|
return JsonSerializer(schema=schema)
|
@@ -24,7 +36,19 @@ def get_serializer(schema: P.ValueSchema, serializer_type: SerializerType) -> Se
|
|
24
36
|
raise ValueError(f"Unsupported serializer type: {serializer_type}")
|
25
37
|
|
26
38
|
|
27
|
-
|
39
|
+
@overload
|
40
|
+
def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["json"]) -> JsonMultiSerializer: ...
|
41
|
+
|
42
|
+
|
43
|
+
@overload
|
44
|
+
def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["numpy"]) -> NumpyMultiSerializer: ...
|
45
|
+
|
46
|
+
|
47
|
+
@overload
|
48
|
+
def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["pytorch"]) -> PyTorchMultiSerializer: ...
|
49
|
+
|
50
|
+
|
51
|
+
def get_multi_serializer(schema: K.IOSchema, serializer_type: SerializerType) -> MultiSerializer:
|
28
52
|
match serializer_type:
|
29
53
|
case "json":
|
30
54
|
return JsonMultiSerializer(schema=schema)
|
kinfer/serialize/base.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from typing import Generic, Literal, Sequence, TypeVar, overload
|
5
5
|
|
6
|
-
from kinfer import proto as
|
6
|
+
from kinfer import proto as K
|
7
7
|
|
8
8
|
T = TypeVar("T")
|
9
9
|
|
@@ -12,8 +12,8 @@ class JointPositionsSerializer(ABC, Generic[T]):
|
|
12
12
|
@abstractmethod
|
13
13
|
def serialize_joint_positions(
|
14
14
|
self: "JointPositionsSerializer[T]",
|
15
|
-
schema:
|
16
|
-
value:
|
15
|
+
schema: K.JointPositionsSchema,
|
16
|
+
value: K.JointPositionsValue,
|
17
17
|
) -> T:
|
18
18
|
"""Serialize a joint positions value.
|
19
19
|
|
@@ -28,9 +28,9 @@ class JointPositionsSerializer(ABC, Generic[T]):
|
|
28
28
|
@abstractmethod
|
29
29
|
def deserialize_joint_positions(
|
30
30
|
self: "JointPositionsSerializer[T]",
|
31
|
-
schema:
|
31
|
+
schema: K.JointPositionsSchema,
|
32
32
|
value: T,
|
33
|
-
) ->
|
33
|
+
) -> K.JointPositionsValue:
|
34
34
|
"""Deserialize a joint positions value.
|
35
35
|
|
36
36
|
Args:
|
@@ -47,8 +47,8 @@ class JointVelocitiesSerializer(ABC, Generic[T]):
|
|
47
47
|
@abstractmethod
|
48
48
|
def serialize_joint_velocities(
|
49
49
|
self: "JointVelocitiesSerializer[T]",
|
50
|
-
schema:
|
51
|
-
value:
|
50
|
+
schema: K.JointVelocitiesSchema,
|
51
|
+
value: K.JointVelocitiesValue,
|
52
52
|
) -> T:
|
53
53
|
"""Serialize a joint velocities value.
|
54
54
|
|
@@ -63,9 +63,9 @@ class JointVelocitiesSerializer(ABC, Generic[T]):
|
|
63
63
|
@abstractmethod
|
64
64
|
def deserialize_joint_velocities(
|
65
65
|
self: "JointVelocitiesSerializer[T]",
|
66
|
-
schema:
|
66
|
+
schema: K.JointVelocitiesSchema,
|
67
67
|
value: T,
|
68
|
-
) ->
|
68
|
+
) -> K.JointVelocitiesValue:
|
69
69
|
"""Deserialize a joint velocities value.
|
70
70
|
|
71
71
|
Args:
|
@@ -81,8 +81,8 @@ class JointTorquesSerializer(ABC, Generic[T]):
|
|
81
81
|
@abstractmethod
|
82
82
|
def serialize_joint_torques(
|
83
83
|
self: "JointTorquesSerializer[T]",
|
84
|
-
schema:
|
85
|
-
value:
|
84
|
+
schema: K.JointTorquesSchema,
|
85
|
+
value: K.JointTorquesValue,
|
86
86
|
) -> T:
|
87
87
|
"""Serialize a joint torques value.
|
88
88
|
|
@@ -97,9 +97,9 @@ class JointTorquesSerializer(ABC, Generic[T]):
|
|
97
97
|
@abstractmethod
|
98
98
|
def deserialize_joint_torques(
|
99
99
|
self: "JointTorquesSerializer[T]",
|
100
|
-
schema:
|
100
|
+
schema: K.JointTorquesSchema,
|
101
101
|
value: T,
|
102
|
-
) ->
|
102
|
+
) -> K.JointTorquesValue:
|
103
103
|
"""Deserialize a joint torques value.
|
104
104
|
|
105
105
|
Args:
|
@@ -115,8 +115,8 @@ class JointCommandsSerializer(ABC, Generic[T]):
|
|
115
115
|
@abstractmethod
|
116
116
|
def serialize_joint_commands(
|
117
117
|
self: "JointCommandsSerializer[T]",
|
118
|
-
schema:
|
119
|
-
value:
|
118
|
+
schema: K.JointCommandsSchema,
|
119
|
+
value: K.JointCommandsValue,
|
120
120
|
) -> T:
|
121
121
|
"""Serialize a joint commands value.
|
122
122
|
|
@@ -131,9 +131,9 @@ class JointCommandsSerializer(ABC, Generic[T]):
|
|
131
131
|
@abstractmethod
|
132
132
|
def deserialize_joint_commands(
|
133
133
|
self: "JointCommandsSerializer[T]",
|
134
|
-
schema:
|
134
|
+
schema: K.JointCommandsSchema,
|
135
135
|
value: T,
|
136
|
-
) ->
|
136
|
+
) -> K.JointCommandsValue:
|
137
137
|
"""Deserialize a joint commands value.
|
138
138
|
|
139
139
|
Args:
|
@@ -149,8 +149,8 @@ class CameraFrameSerializer(ABC, Generic[T]):
|
|
149
149
|
@abstractmethod
|
150
150
|
def serialize_camera_frame(
|
151
151
|
self: "CameraFrameSerializer[T]",
|
152
|
-
schema:
|
153
|
-
value:
|
152
|
+
schema: K.CameraFrameSchema,
|
153
|
+
value: K.CameraFrameValue,
|
154
154
|
) -> T:
|
155
155
|
"""Serialize a camera frame value.
|
156
156
|
|
@@ -165,9 +165,9 @@ class CameraFrameSerializer(ABC, Generic[T]):
|
|
165
165
|
@abstractmethod
|
166
166
|
def deserialize_camera_frame(
|
167
167
|
self: "CameraFrameSerializer[T]",
|
168
|
-
schema:
|
168
|
+
schema: K.CameraFrameSchema,
|
169
169
|
value: T,
|
170
|
-
) ->
|
170
|
+
) -> K.CameraFrameValue:
|
171
171
|
"""Deserialize a camera frame value.
|
172
172
|
|
173
173
|
Args:
|
@@ -183,8 +183,8 @@ class AudioFrameSerializer(ABC, Generic[T]):
|
|
183
183
|
@abstractmethod
|
184
184
|
def serialize_audio_frame(
|
185
185
|
self: "AudioFrameSerializer[T]",
|
186
|
-
schema:
|
187
|
-
value:
|
186
|
+
schema: K.AudioFrameSchema,
|
187
|
+
value: K.AudioFrameValue,
|
188
188
|
) -> T:
|
189
189
|
"""Serialize an audio frame value.
|
190
190
|
|
@@ -199,9 +199,9 @@ class AudioFrameSerializer(ABC, Generic[T]):
|
|
199
199
|
@abstractmethod
|
200
200
|
def deserialize_audio_frame(
|
201
201
|
self: "AudioFrameSerializer[T]",
|
202
|
-
schema:
|
202
|
+
schema: K.AudioFrameSchema,
|
203
203
|
value: T,
|
204
|
-
) ->
|
204
|
+
) -> K.AudioFrameValue:
|
205
205
|
"""Deserialize an audio frame value.
|
206
206
|
|
207
207
|
Args:
|
@@ -217,8 +217,8 @@ class ImuSerializer(ABC, Generic[T]):
|
|
217
217
|
@abstractmethod
|
218
218
|
def serialize_imu(
|
219
219
|
self: "ImuSerializer[T]",
|
220
|
-
schema:
|
221
|
-
value:
|
220
|
+
schema: K.ImuSchema,
|
221
|
+
value: K.ImuValue,
|
222
222
|
) -> T:
|
223
223
|
"""Serialize an IMU value.
|
224
224
|
|
@@ -233,9 +233,9 @@ class ImuSerializer(ABC, Generic[T]):
|
|
233
233
|
@abstractmethod
|
234
234
|
def deserialize_imu(
|
235
235
|
self: "ImuSerializer[T]",
|
236
|
-
schema:
|
236
|
+
schema: K.ImuSchema,
|
237
237
|
value: T,
|
238
|
-
) ->
|
238
|
+
) -> K.ImuValue:
|
239
239
|
"""Deserialize an IMU value.
|
240
240
|
|
241
241
|
Args:
|
@@ -251,8 +251,8 @@ class TimestampSerializer(ABC, Generic[T]):
|
|
251
251
|
@abstractmethod
|
252
252
|
def serialize_timestamp(
|
253
253
|
self: "TimestampSerializer[T]",
|
254
|
-
schema:
|
255
|
-
value:
|
254
|
+
schema: K.TimestampSchema,
|
255
|
+
value: K.TimestampValue,
|
256
256
|
) -> T:
|
257
257
|
"""Serialize a timestamp value.
|
258
258
|
|
@@ -267,9 +267,9 @@ class TimestampSerializer(ABC, Generic[T]):
|
|
267
267
|
@abstractmethod
|
268
268
|
def deserialize_timestamp(
|
269
269
|
self: "TimestampSerializer[T]",
|
270
|
-
schema:
|
270
|
+
schema: K.TimestampSchema,
|
271
271
|
value: T,
|
272
|
-
) ->
|
272
|
+
) -> K.TimestampValue:
|
273
273
|
"""Deserialize a timestamp value.
|
274
274
|
|
275
275
|
Args:
|
@@ -285,8 +285,8 @@ class VectorCommandSerializer(ABC, Generic[T]):
|
|
285
285
|
@abstractmethod
|
286
286
|
def serialize_vector_command(
|
287
287
|
self: "VectorCommandSerializer[T]",
|
288
|
-
schema:
|
289
|
-
value:
|
288
|
+
schema: K.VectorCommandSchema,
|
289
|
+
value: K.VectorCommandValue,
|
290
290
|
) -> T:
|
291
291
|
"""Serialize an XY command value.
|
292
292
|
|
@@ -301,9 +301,9 @@ class VectorCommandSerializer(ABC, Generic[T]):
|
|
301
301
|
@abstractmethod
|
302
302
|
def deserialize_vector_command(
|
303
303
|
self: "VectorCommandSerializer[T]",
|
304
|
-
schema:
|
304
|
+
schema: K.VectorCommandSchema,
|
305
305
|
value: T,
|
306
|
-
) ->
|
306
|
+
) -> K.VectorCommandValue:
|
307
307
|
"""Deserialize a vector command value.
|
308
308
|
|
309
309
|
Args:
|
@@ -319,8 +319,8 @@ class StateTensorSerializer(ABC, Generic[T]):
|
|
319
319
|
@abstractmethod
|
320
320
|
def serialize_state_tensor(
|
321
321
|
self: "StateTensorSerializer[T]",
|
322
|
-
schema:
|
323
|
-
value:
|
322
|
+
schema: K.StateTensorSchema,
|
323
|
+
value: K.StateTensorValue,
|
324
324
|
) -> T:
|
325
325
|
"""Serialize a state tensor value.
|
326
326
|
|
@@ -335,9 +335,9 @@ class StateTensorSerializer(ABC, Generic[T]):
|
|
335
335
|
@abstractmethod
|
336
336
|
def deserialize_state_tensor(
|
337
337
|
self: "StateTensorSerializer[T]",
|
338
|
-
schema:
|
338
|
+
schema: K.StateTensorSchema,
|
339
339
|
value: T,
|
340
|
-
) ->
|
340
|
+
) -> K.StateTensorValue:
|
341
341
|
"""Deserialize a state tensor value.
|
342
342
|
|
343
343
|
Args:
|
@@ -362,10 +362,10 @@ class Serializer(
|
|
362
362
|
StateTensorSerializer[T],
|
363
363
|
Generic[T],
|
364
364
|
):
|
365
|
-
def __init__(self: "Serializer[T]", schema:
|
365
|
+
def __init__(self: "Serializer[T]", schema: K.ValueSchema) -> None:
|
366
366
|
self.schema = schema
|
367
367
|
|
368
|
-
def serialize(self: "Serializer[T]", value:
|
368
|
+
def serialize(self: "Serializer[T]", value: K.Value) -> T:
|
369
369
|
value_type = value.WhichOneof("value")
|
370
370
|
|
371
371
|
match value_type:
|
@@ -422,75 +422,75 @@ class Serializer(
|
|
422
422
|
case _:
|
423
423
|
raise ValueError(f"Unsupported value type: {value_type}")
|
424
424
|
|
425
|
-
def deserialize(self: "Serializer[T]", value: T) ->
|
425
|
+
def deserialize(self: "Serializer[T]", value: T) -> K.Value:
|
426
426
|
value_type = self.schema.WhichOneof("value_type")
|
427
427
|
|
428
428
|
match value_type:
|
429
429
|
case "joint_positions":
|
430
|
-
return
|
430
|
+
return K.Value(
|
431
431
|
joint_positions=self.deserialize_joint_positions(
|
432
432
|
schema=self.schema.joint_positions,
|
433
433
|
value=value,
|
434
434
|
),
|
435
435
|
)
|
436
436
|
case "joint_velocities":
|
437
|
-
return
|
437
|
+
return K.Value(
|
438
438
|
joint_velocities=self.deserialize_joint_velocities(
|
439
439
|
schema=self.schema.joint_velocities,
|
440
440
|
value=value,
|
441
441
|
),
|
442
442
|
)
|
443
443
|
case "joint_torques":
|
444
|
-
return
|
444
|
+
return K.Value(
|
445
445
|
joint_torques=self.deserialize_joint_torques(
|
446
446
|
schema=self.schema.joint_torques,
|
447
447
|
value=value,
|
448
448
|
),
|
449
449
|
)
|
450
450
|
case "joint_commands":
|
451
|
-
return
|
451
|
+
return K.Value(
|
452
452
|
joint_commands=self.deserialize_joint_commands(
|
453
453
|
schema=self.schema.joint_commands,
|
454
454
|
value=value,
|
455
455
|
),
|
456
456
|
)
|
457
457
|
case "camera_frame":
|
458
|
-
return
|
458
|
+
return K.Value(
|
459
459
|
camera_frame=self.deserialize_camera_frame(
|
460
460
|
schema=self.schema.camera_frame,
|
461
461
|
value=value,
|
462
462
|
),
|
463
463
|
)
|
464
464
|
case "audio_frame":
|
465
|
-
return
|
465
|
+
return K.Value(
|
466
466
|
audio_frame=self.deserialize_audio_frame(
|
467
467
|
schema=self.schema.audio_frame,
|
468
468
|
value=value,
|
469
469
|
),
|
470
470
|
)
|
471
471
|
case "imu":
|
472
|
-
return
|
472
|
+
return K.Value(
|
473
473
|
imu=self.deserialize_imu(
|
474
474
|
schema=self.schema.imu,
|
475
475
|
value=value,
|
476
476
|
),
|
477
477
|
)
|
478
478
|
case "timestamp":
|
479
|
-
return
|
479
|
+
return K.Value(
|
480
480
|
timestamp=self.deserialize_timestamp(
|
481
481
|
schema=self.schema.timestamp,
|
482
482
|
value=value,
|
483
483
|
),
|
484
484
|
)
|
485
485
|
case "vector_command":
|
486
|
-
return
|
486
|
+
return K.Value(
|
487
487
|
vector_command=self.deserialize_vector_command(
|
488
488
|
schema=self.schema.vector_command,
|
489
489
|
value=value,
|
490
490
|
),
|
491
491
|
)
|
492
492
|
case "state_tensor":
|
493
|
-
return
|
493
|
+
return K.Value(
|
494
494
|
state_tensor=self.deserialize_state_tensor(
|
495
495
|
schema=self.schema.state_tensor,
|
496
496
|
value=value,
|
@@ -505,24 +505,24 @@ class MultiSerializer(Generic[T]):
|
|
505
505
|
self.serializers = list(serializers)
|
506
506
|
|
507
507
|
@overload
|
508
|
-
def serialize_io(self: "MultiSerializer[T]", io:
|
508
|
+
def serialize_io(self: "MultiSerializer[T]", io: K.IO, *, as_dict: Literal[True]) -> dict[str, T]: ...
|
509
509
|
|
510
510
|
@overload
|
511
|
-
def serialize_io(self: "MultiSerializer[T]", io:
|
511
|
+
def serialize_io(self: "MultiSerializer[T]", io: K.IO, *, as_dict: Literal[False] = False) -> list[T]: ...
|
512
512
|
|
513
|
-
def serialize_io(self: "MultiSerializer[T]", io:
|
514
|
-
if not isinstance(io,
|
513
|
+
def serialize_io(self: "MultiSerializer[T]", io: K.IO, *, as_dict: bool = False) -> dict[str, T] | list[T]:
|
514
|
+
if not isinstance(io, K.IO):
|
515
515
|
raise ValueError(f"Inputs must be an IO protobuf, not {type(io)}")
|
516
516
|
if as_dict:
|
517
517
|
return {s.schema.value_name: s.serialize(i) for s, i in zip(self.serializers, io.values)}
|
518
518
|
return [s.serialize(i) for s, i in zip(self.serializers, io.values)]
|
519
519
|
|
520
|
-
def deserialize_io(self: "MultiSerializer[T]", io: dict[str, T] | list[T]) ->
|
520
|
+
def deserialize_io(self: "MultiSerializer[T]", io: dict[str, T] | list[T]) -> K.IO:
|
521
521
|
if not isinstance(io, (dict, list)):
|
522
522
|
raise ValueError(f"Inputs must be a dictionary or list, not {type(io)}")
|
523
523
|
if isinstance(io, dict):
|
524
|
-
return
|
525
|
-
return
|
524
|
+
return K.IO(values=[s.deserialize(i) for s, i in zip(self.serializers, io.values())])
|
525
|
+
return K.IO(values=[s.deserialize(i) for s, i in zip(self.serializers, io)])
|
526
526
|
|
527
527
|
def assign_names(self: "MultiSerializer[T]", values: Sequence[T]) -> dict[str, T]:
|
528
528
|
if not isinstance(values, Sequence):
|