kinfer 0.3.2__cp312-cp312-macosx_11_0_arm64.whl → 0.3.3__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +5 -1
- kinfer/export/__init__.py +0 -1
- kinfer/export/pytorch.py +3 -3
- kinfer/inference/__init__.py +2 -1
- kinfer/inference/base.py +64 -0
- kinfer/inference/python.py +9 -35
- kinfer/rust_bindings.cpython-312-darwin.so +0 -0
- kinfer/serialize/__init__.py +28 -4
- kinfer/serialize/base.py +61 -61
- kinfer/serialize/json.py +61 -61
- kinfer/serialize/numpy.py +62 -62
- kinfer/serialize/pytorch.py +55 -55
- kinfer/serialize/schema.py +31 -31
- kinfer/serialize/types.py +4 -4
- kinfer/serialize/utils.py +58 -58
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/METADATA +1 -1
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/RECORD +20 -19
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/LICENSE +0 -0
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/WHEEL +0 -0
- {kinfer-0.3.2.dist-info → kinfer-0.3.3.dist-info}/top_level.txt +0 -0
kinfer/serialize/json.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3
3
|
import base64
|
4
4
|
from typing import Any, Mapping, Sequence
|
5
5
|
|
6
|
-
from kinfer import proto as
|
6
|
+
from kinfer import proto as K
|
7
7
|
from kinfer.serialize.base import (
|
8
8
|
AudioFrameSerializer,
|
9
9
|
CameraFrameSerializer,
|
@@ -42,8 +42,8 @@ JsonValue = Mapping[
|
|
42
42
|
class JsonJointPositionsSerializer(JointPositionsSerializer[JsonValue]):
|
43
43
|
def serialize_joint_positions(
|
44
44
|
self: "JsonJointPositionsSerializer",
|
45
|
-
schema:
|
46
|
-
value:
|
45
|
+
schema: K.JointPositionsSchema,
|
46
|
+
value: K.JointPositionsValue,
|
47
47
|
) -> dict[str, list[float]]:
|
48
48
|
value_map = {v.joint_name: v for v in value.values}
|
49
49
|
check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
|
@@ -56,9 +56,9 @@ class JsonJointPositionsSerializer(JointPositionsSerializer[JsonValue]):
|
|
56
56
|
|
57
57
|
def deserialize_joint_positions(
|
58
58
|
self: "JsonJointPositionsSerializer",
|
59
|
-
schema:
|
59
|
+
schema: K.JointPositionsSchema,
|
60
60
|
value: JsonValue,
|
61
|
-
) ->
|
61
|
+
) -> K.JointPositionsValue:
|
62
62
|
if "positions" not in value:
|
63
63
|
raise ValueError("Key 'positions' not found in value")
|
64
64
|
positions = value["positions"]
|
@@ -68,9 +68,9 @@ class JsonJointPositionsSerializer(JointPositionsSerializer[JsonValue]):
|
|
68
68
|
raise ValueError(
|
69
69
|
f"Shape of positions must match number of joint names: {len(positions)} != {len(schema.joint_names)}"
|
70
70
|
)
|
71
|
-
return
|
71
|
+
return K.JointPositionsValue(
|
72
72
|
values=[
|
73
|
-
|
73
|
+
K.JointPositionValue(joint_name=name, value=as_float(positions[i]), unit=schema.unit)
|
74
74
|
for i, name in enumerate(schema.joint_names)
|
75
75
|
]
|
76
76
|
)
|
@@ -79,8 +79,8 @@ class JsonJointPositionsSerializer(JointPositionsSerializer[JsonValue]):
|
|
79
79
|
class JsonJointVelocitiesSerializer(JointVelocitiesSerializer[JsonValue]):
|
80
80
|
def serialize_joint_velocities(
|
81
81
|
self: "JsonJointVelocitiesSerializer",
|
82
|
-
schema:
|
83
|
-
value:
|
82
|
+
schema: K.JointVelocitiesSchema,
|
83
|
+
value: K.JointVelocitiesValue,
|
84
84
|
) -> dict[str, list[float]]:
|
85
85
|
value_map = {v.joint_name: v for v in value.values}
|
86
86
|
check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
|
@@ -93,9 +93,9 @@ class JsonJointVelocitiesSerializer(JointVelocitiesSerializer[JsonValue]):
|
|
93
93
|
|
94
94
|
def deserialize_joint_velocities(
|
95
95
|
self: "JsonJointVelocitiesSerializer",
|
96
|
-
schema:
|
96
|
+
schema: K.JointVelocitiesSchema,
|
97
97
|
value: JsonValue,
|
98
|
-
) ->
|
98
|
+
) -> K.JointVelocitiesValue:
|
99
99
|
if "velocities" not in value:
|
100
100
|
raise ValueError("Key 'velocities' not found in value")
|
101
101
|
velocities = value["velocities"]
|
@@ -105,9 +105,9 @@ class JsonJointVelocitiesSerializer(JointVelocitiesSerializer[JsonValue]):
|
|
105
105
|
raise ValueError(
|
106
106
|
f"Shape of velocities must match number of joint names: {len(velocities)} != {len(schema.joint_names)}"
|
107
107
|
)
|
108
|
-
return
|
108
|
+
return K.JointVelocitiesValue(
|
109
109
|
values=[
|
110
|
-
|
110
|
+
K.JointVelocityValue(joint_name=name, value=as_float(velocities[i]), unit=schema.unit)
|
111
111
|
for i, name in enumerate(schema.joint_names)
|
112
112
|
]
|
113
113
|
)
|
@@ -116,8 +116,8 @@ class JsonJointVelocitiesSerializer(JointVelocitiesSerializer[JsonValue]):
|
|
116
116
|
class JsonJointTorquesSerializer(JointTorquesSerializer[JsonValue]):
|
117
117
|
def serialize_joint_torques(
|
118
118
|
self: "JsonJointTorquesSerializer",
|
119
|
-
schema:
|
120
|
-
value:
|
119
|
+
schema: K.JointTorquesSchema,
|
120
|
+
value: K.JointTorquesValue,
|
121
121
|
) -> dict[str, list[float]]:
|
122
122
|
value_map = {v.joint_name: v for v in value.values}
|
123
123
|
check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
|
@@ -129,9 +129,9 @@ class JsonJointTorquesSerializer(JointTorquesSerializer[JsonValue]):
|
|
129
129
|
|
130
130
|
def deserialize_joint_torques(
|
131
131
|
self: "JsonJointTorquesSerializer",
|
132
|
-
schema:
|
132
|
+
schema: K.JointTorquesSchema,
|
133
133
|
value: JsonValue,
|
134
|
-
) ->
|
134
|
+
) -> K.JointTorquesValue:
|
135
135
|
if "torques" not in value:
|
136
136
|
raise ValueError("Key 'torques' not found in value")
|
137
137
|
torques = value["torques"]
|
@@ -141,9 +141,9 @@ class JsonJointTorquesSerializer(JointTorquesSerializer[JsonValue]):
|
|
141
141
|
raise ValueError(
|
142
142
|
f"Shape of torques must match number of joint names: {len(torques)} != {len(schema.joint_names)}"
|
143
143
|
)
|
144
|
-
return
|
144
|
+
return K.JointTorquesValue(
|
145
145
|
values=[
|
146
|
-
|
146
|
+
K.JointTorqueValue(joint_name=name, value=as_float(torques[i]), unit=schema.unit)
|
147
147
|
for i, name in enumerate(schema.joint_names)
|
148
148
|
]
|
149
149
|
)
|
@@ -152,8 +152,8 @@ class JsonJointTorquesSerializer(JointTorquesSerializer[JsonValue]):
|
|
152
152
|
class JsonJointCommandsSerializer(JointCommandsSerializer[JsonValue]):
|
153
153
|
def _convert_value_to_array(
|
154
154
|
self: "JsonJointCommandsSerializer",
|
155
|
-
value:
|
156
|
-
schema:
|
155
|
+
value: K.JointCommandValue,
|
156
|
+
schema: K.JointCommandsSchema,
|
157
157
|
) -> list[float]:
|
158
158
|
return [
|
159
159
|
convert_torque(value.torque, value.torque_unit, schema.torque_unit),
|
@@ -166,14 +166,14 @@ class JsonJointCommandsSerializer(JointCommandsSerializer[JsonValue]):
|
|
166
166
|
def _convert_array_to_value(
|
167
167
|
self: "JsonJointCommandsSerializer",
|
168
168
|
values: Any, # noqa: ANN401
|
169
|
-
schema:
|
169
|
+
schema: K.JointCommandsSchema,
|
170
170
|
name: str,
|
171
|
-
) ->
|
171
|
+
) -> K.JointCommandValue:
|
172
172
|
if not isinstance(values, list):
|
173
173
|
raise ValueError("Value must be a list")
|
174
174
|
if len(values) != 5:
|
175
175
|
raise ValueError(f"Shape of command must match number of joint commands: {len(values)} != 5")
|
176
|
-
return
|
176
|
+
return K.JointCommandValue(
|
177
177
|
joint_name=name,
|
178
178
|
torque=float(values[0]),
|
179
179
|
velocity=float(values[1]),
|
@@ -187,8 +187,8 @@ class JsonJointCommandsSerializer(JointCommandsSerializer[JsonValue]):
|
|
187
187
|
|
188
188
|
def serialize_joint_commands(
|
189
189
|
self: "JsonJointCommandsSerializer",
|
190
|
-
schema:
|
191
|
-
value:
|
190
|
+
schema: K.JointCommandsSchema,
|
191
|
+
value: K.JointCommandsValue,
|
192
192
|
) -> dict[str, dict[str, list[float]]]:
|
193
193
|
value_map = {v.joint_name: v for v in value.values}
|
194
194
|
check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
|
@@ -198,16 +198,16 @@ class JsonJointCommandsSerializer(JointCommandsSerializer[JsonValue]):
|
|
198
198
|
|
199
199
|
def deserialize_joint_commands(
|
200
200
|
self: "JsonJointCommandsSerializer",
|
201
|
-
schema:
|
201
|
+
schema: K.JointCommandsSchema,
|
202
202
|
value: JsonValue,
|
203
|
-
) ->
|
203
|
+
) -> K.JointCommandsValue:
|
204
204
|
if "commands" not in value:
|
205
205
|
raise ValueError("Key 'commands' not found in value")
|
206
206
|
commands = value["commands"]
|
207
207
|
if not isinstance(commands, dict):
|
208
208
|
raise ValueError("Key 'commands' must be a dictionary")
|
209
209
|
check_names_match("schema", schema.joint_names, "value", list(commands.keys()))
|
210
|
-
return
|
210
|
+
return K.JointCommandsValue(
|
211
211
|
values=[self._convert_array_to_value(commands[name], schema, name) for name in schema.joint_names]
|
212
212
|
)
|
213
213
|
|
@@ -215,50 +215,50 @@ class JsonJointCommandsSerializer(JointCommandsSerializer[JsonValue]):
|
|
215
215
|
class JsonCameraFrameSerializer(CameraFrameSerializer[JsonValue]):
|
216
216
|
def serialize_camera_frame(
|
217
217
|
self: "JsonCameraFrameSerializer",
|
218
|
-
schema:
|
219
|
-
value:
|
218
|
+
schema: K.CameraFrameSchema,
|
219
|
+
value: K.CameraFrameValue,
|
220
220
|
) -> dict[str, str]:
|
221
221
|
return {"data": base64.b64encode(value.data).decode("utf-8")}
|
222
222
|
|
223
223
|
def deserialize_camera_frame(
|
224
224
|
self: "JsonCameraFrameSerializer",
|
225
|
-
schema:
|
225
|
+
schema: K.CameraFrameSchema,
|
226
226
|
value: JsonValue,
|
227
|
-
) ->
|
227
|
+
) -> K.CameraFrameValue:
|
228
228
|
if "data" not in value:
|
229
229
|
raise ValueError("Key 'data' not found in value")
|
230
230
|
data = value["data"]
|
231
231
|
if not isinstance(data, str):
|
232
232
|
raise ValueError("Key 'data' must be a string")
|
233
|
-
return
|
233
|
+
return K.CameraFrameValue(data=base64.b64decode(data))
|
234
234
|
|
235
235
|
|
236
236
|
class JsonAudioFrameSerializer(AudioFrameSerializer[JsonValue]):
|
237
237
|
def serialize_audio_frame(
|
238
238
|
self: "JsonAudioFrameSerializer",
|
239
|
-
schema:
|
240
|
-
value:
|
239
|
+
schema: K.AudioFrameSchema,
|
240
|
+
value: K.AudioFrameValue,
|
241
241
|
) -> dict[str, str]:
|
242
242
|
return {"data": base64.b64encode(value.data).decode("utf-8")}
|
243
243
|
|
244
244
|
def deserialize_audio_frame(
|
245
245
|
self: "JsonAudioFrameSerializer",
|
246
|
-
schema:
|
246
|
+
schema: K.AudioFrameSchema,
|
247
247
|
value: JsonValue,
|
248
|
-
) ->
|
248
|
+
) -> K.AudioFrameValue:
|
249
249
|
if "data" not in value:
|
250
250
|
raise ValueError("Key 'data' not found in value")
|
251
251
|
data = value["data"]
|
252
252
|
if not isinstance(data, str):
|
253
253
|
raise ValueError("Key 'data' must be a string")
|
254
|
-
return
|
254
|
+
return K.AudioFrameValue(data=base64.b64decode(data))
|
255
255
|
|
256
256
|
|
257
257
|
class JsonImuSerializer(ImuSerializer[JsonValue]):
|
258
258
|
def serialize_imu(
|
259
259
|
self: "JsonImuSerializer",
|
260
|
-
schema:
|
261
|
-
value:
|
260
|
+
schema: K.ImuSchema,
|
261
|
+
value: K.ImuValue,
|
262
262
|
) -> dict[str, list[float]]:
|
263
263
|
data: dict[str, list[float]] = {}
|
264
264
|
if schema.use_accelerometer:
|
@@ -283,10 +283,10 @@ class JsonImuSerializer(ImuSerializer[JsonValue]):
|
|
283
283
|
|
284
284
|
def deserialize_imu(
|
285
285
|
self: "JsonImuSerializer",
|
286
|
-
schema:
|
286
|
+
schema: K.ImuSchema,
|
287
287
|
value: JsonValue,
|
288
|
-
) ->
|
289
|
-
imu_value =
|
288
|
+
) -> K.ImuValue:
|
289
|
+
imu_value = K.ImuValue()
|
290
290
|
if schema.use_accelerometer:
|
291
291
|
if not isinstance(linear_acceleration := value["linear_acceleration"], list):
|
292
292
|
raise ValueError("Key 'linear_acceleration' must be a list")
|
@@ -314,38 +314,38 @@ class JsonImuSerializer(ImuSerializer[JsonValue]):
|
|
314
314
|
class JsonTimestampSerializer(TimestampSerializer[JsonValue]):
|
315
315
|
def serialize_timestamp(
|
316
316
|
self: "JsonTimestampSerializer",
|
317
|
-
schema:
|
318
|
-
value:
|
317
|
+
schema: K.TimestampSchema,
|
318
|
+
value: K.TimestampValue,
|
319
319
|
) -> dict[str, int]:
|
320
320
|
return {"seconds": value.seconds, "nanos": value.nanos}
|
321
321
|
|
322
322
|
def deserialize_timestamp(
|
323
323
|
self: "JsonTimestampSerializer",
|
324
|
-
schema:
|
324
|
+
schema: K.TimestampSchema,
|
325
325
|
value: JsonValue,
|
326
|
-
) ->
|
326
|
+
) -> K.TimestampValue:
|
327
327
|
if "seconds" not in value or "nanos" not in value:
|
328
328
|
raise ValueError("Key 'seconds' or 'nanos' not found in value")
|
329
329
|
seconds = value["seconds"]
|
330
330
|
nanos = value["nanos"]
|
331
331
|
if not isinstance(seconds, int) or not isinstance(nanos, int):
|
332
332
|
raise ValueError("Key 'seconds' and 'nanos' must be integers")
|
333
|
-
return
|
333
|
+
return K.TimestampValue(seconds=seconds, nanos=nanos)
|
334
334
|
|
335
335
|
|
336
336
|
class JsonVectorCommandSerializer(VectorCommandSerializer[JsonValue]):
|
337
337
|
def serialize_vector_command(
|
338
338
|
self: "JsonVectorCommandSerializer",
|
339
|
-
schema:
|
340
|
-
value:
|
339
|
+
schema: K.VectorCommandSchema,
|
340
|
+
value: K.VectorCommandValue,
|
341
341
|
) -> dict[str, list[float]]:
|
342
342
|
return {"values": list(value.values)}
|
343
343
|
|
344
344
|
def deserialize_vector_command(
|
345
345
|
self: "JsonVectorCommandSerializer",
|
346
|
-
schema:
|
346
|
+
schema: K.VectorCommandSchema,
|
347
347
|
value: JsonValue,
|
348
|
-
) ->
|
348
|
+
) -> K.VectorCommandValue:
|
349
349
|
if "values" not in value:
|
350
350
|
raise ValueError("Key 'values' not found in value")
|
351
351
|
values = value["values"]
|
@@ -353,28 +353,28 @@ class JsonVectorCommandSerializer(VectorCommandSerializer[JsonValue]):
|
|
353
353
|
raise ValueError("Key 'values' must be a list")
|
354
354
|
if len(values) != schema.dimensions:
|
355
355
|
raise ValueError(f"Length of list must match number of dimensions: {len(values)} != {schema.dimensions}")
|
356
|
-
return
|
356
|
+
return K.VectorCommandValue(values=[as_float(v) for v in values])
|
357
357
|
|
358
358
|
|
359
359
|
class JsonStateTensorSerializer(StateTensorSerializer[JsonValue]):
|
360
360
|
def serialize_state_tensor(
|
361
361
|
self: "JsonStateTensorSerializer",
|
362
|
-
schema:
|
363
|
-
value:
|
362
|
+
schema: K.StateTensorSchema,
|
363
|
+
value: K.StateTensorValue,
|
364
364
|
) -> dict[str, str]:
|
365
365
|
return {"data": base64.b64encode(value.data).decode("utf-8")}
|
366
366
|
|
367
367
|
def deserialize_state_tensor(
|
368
368
|
self: "JsonStateTensorSerializer",
|
369
|
-
schema:
|
369
|
+
schema: K.StateTensorSchema,
|
370
370
|
value: JsonValue,
|
371
|
-
) ->
|
371
|
+
) -> K.StateTensorValue:
|
372
372
|
if "data" not in value:
|
373
373
|
raise ValueError("Key 'data' not found in value")
|
374
374
|
data = value["data"]
|
375
375
|
if not isinstance(data, str):
|
376
376
|
raise ValueError("Key 'data' must be a string")
|
377
|
-
return
|
377
|
+
return K.StateTensorValue(data=base64.b64decode(data))
|
378
378
|
|
379
379
|
|
380
380
|
class JsonSerializer(
|
@@ -390,10 +390,10 @@ class JsonSerializer(
|
|
390
390
|
JsonStateTensorSerializer,
|
391
391
|
Serializer[JsonValue],
|
392
392
|
):
|
393
|
-
def __init__(self: "JsonSerializer", schema:
|
393
|
+
def __init__(self: "JsonSerializer", schema: K.ValueSchema) -> None:
|
394
394
|
Serializer.__init__(self, schema=schema)
|
395
395
|
|
396
396
|
|
397
397
|
class JsonMultiSerializer(MultiSerializer[JsonValue]):
|
398
|
-
def __init__(self: "JsonMultiSerializer", schema:
|
398
|
+
def __init__(self: "JsonMultiSerializer", schema: K.IOSchema) -> None:
|
399
399
|
super().__init__([JsonSerializer(schema=s) for s in schema.values])
|