kinfer 0.3.2__cp312-cp312-macosx_11_0_arm64.whl → 0.4.0__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +0 -1
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +11 -0
- kinfer/export/__init__.py +0 -1
- kinfer/export/common.py +35 -0
- kinfer/export/jax.py +51 -0
- kinfer/export/pytorch.py +42 -110
- kinfer/export/serialize.py +86 -0
- kinfer/requirements.txt +3 -4
- kinfer/rust/Cargo.toml +8 -6
- kinfer/rust/src/lib.rs +2 -11
- kinfer/rust/src/model.rs +271 -121
- kinfer/rust/src/runtime.rs +104 -0
- kinfer/rust_bindings/Cargo.toml +8 -1
- kinfer/rust_bindings/rust_bindings.pyi +35 -0
- kinfer/rust_bindings/src/lib.rs +310 -1
- kinfer/rust_bindings.cpython-312-darwin.so +0 -0
- kinfer/rust_bindings.pyi +29 -1
- kinfer-0.4.0.dist-info/METADATA +55 -0
- kinfer-0.4.0.dist-info/RECORD +26 -0
- {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
- kinfer/inference/__init__.py +0 -1
- kinfer/inference/python.py +0 -92
- kinfer/proto/__init__.py +0 -40
- kinfer/proto/kinfer_pb2.py +0 -103
- kinfer/proto/kinfer_pb2.pyi +0 -1097
- kinfer/requirements-dev.txt +0 -8
- kinfer/rust/build.rs +0 -16
- kinfer/rust/src/kinfer_proto.rs +0 -14
- kinfer/rust/src/main.rs +0 -6
- kinfer/rust/src/onnx_serializer.rs +0 -804
- kinfer/rust/src/serializer.rs +0 -221
- kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
- kinfer/serialize/__init__.py +0 -36
- kinfer/serialize/base.py +0 -536
- kinfer/serialize/json.py +0 -399
- kinfer/serialize/numpy.py +0 -426
- kinfer/serialize/pytorch.py +0 -402
- kinfer/serialize/schema.py +0 -125
- kinfer/serialize/types.py +0 -17
- kinfer/serialize/utils.py +0 -177
- kinfer-0.3.2.dist-info/METADATA +0 -57
- kinfer-0.3.2.dist-info/RECORD +0 -39
- {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
- {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
kinfer/serialize/json.py
DELETED
@@ -1,399 +0,0 @@
|
|
1
|
-
"""Defines a serializer for JSON."""
|
2
|
-
|
3
|
-
import base64
|
4
|
-
from typing import Any, Mapping, Sequence
|
5
|
-
|
6
|
-
from kinfer import proto as P
|
7
|
-
from kinfer.serialize.base import (
|
8
|
-
AudioFrameSerializer,
|
9
|
-
CameraFrameSerializer,
|
10
|
-
ImuSerializer,
|
11
|
-
JointCommandsSerializer,
|
12
|
-
JointPositionsSerializer,
|
13
|
-
JointTorquesSerializer,
|
14
|
-
JointVelocitiesSerializer,
|
15
|
-
MultiSerializer,
|
16
|
-
Serializer,
|
17
|
-
StateTensorSerializer,
|
18
|
-
TimestampSerializer,
|
19
|
-
VectorCommandSerializer,
|
20
|
-
)
|
21
|
-
from kinfer.serialize.utils import (
|
22
|
-
as_float,
|
23
|
-
check_names_match,
|
24
|
-
convert_angular_position,
|
25
|
-
convert_angular_velocity,
|
26
|
-
convert_torque,
|
27
|
-
)
|
28
|
-
|
29
|
-
Prim = str | int | float
|
30
|
-
|
31
|
-
JsonValue = Mapping[
|
32
|
-
str,
|
33
|
-
Prim
|
34
|
-
| Sequence[Prim]
|
35
|
-
| Sequence[Mapping[str, Prim]]
|
36
|
-
| Mapping[str, Prim]
|
37
|
-
| Mapping[str, Sequence[Prim]]
|
38
|
-
| Mapping[str, Mapping[str, Prim]],
|
39
|
-
]
|
40
|
-
|
41
|
-
|
42
|
-
class JsonJointPositionsSerializer(JointPositionsSerializer[JsonValue]):
    """Converts joint positions to and from a JSON payload shaped like ``{"positions": [...]}``.

    The list is ordered by ``schema.joint_names``.
    """

    def serialize_joint_positions(
        self: "JsonJointPositionsSerializer",
        schema: P.JointPositionsSchema,
        value: P.JointPositionsValue,
    ) -> dict[str, list[float]]:
        """Emit one position per schema joint, in schema order, converted into ``schema.unit``.

        The joint names present in ``value`` are checked against the schema with ``check_names_match``
        before any conversion happens.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name.keys()))

        ordered: list[float] = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            ordered.append(convert_angular_position(entry.value, entry.unit, schema.unit))
        return {"positions": ordered}

    def deserialize_joint_positions(
        self: "JsonJointPositionsSerializer",
        schema: P.JointPositionsSchema,
        value: JsonValue,
    ) -> P.JointPositionsValue:
        """Parse ``value["positions"]`` into per-joint values tagged with ``schema.unit``.

        The i-th element is assigned to ``schema.joint_names[i]``; no unit conversion is applied.

        Raises:
            ValueError: If the key is missing, is not a list, or its length differs from the joint count.
        """
        if "positions" not in value:
            raise ValueError("Key 'positions' not found in value")
        raw = value["positions"]
        if not isinstance(raw, list):
            raise ValueError("Key 'positions' must be a list")
        expected = len(schema.joint_names)
        if len(raw) != expected:
            raise ValueError(f"Shape of positions must match number of joint names: {len(raw)} != {expected}")

        entries = [
            P.JointPositionValue(joint_name=joint, value=as_float(item), unit=schema.unit)
            for joint, item in zip(schema.joint_names, raw)
        ]
        return P.JointPositionsValue(values=entries)
|
77
|
-
|
78
|
-
|
79
|
-
class JsonJointVelocitiesSerializer(JointVelocitiesSerializer[JsonValue]):
    """Converts joint velocities to and from a JSON payload shaped like ``{"velocities": [...]}``.

    The list is ordered by ``schema.joint_names``.
    """

    def serialize_joint_velocities(
        self: "JsonJointVelocitiesSerializer",
        schema: P.JointVelocitiesSchema,
        value: P.JointVelocitiesValue,
    ) -> dict[str, list[float]]:
        """Emit one velocity per schema joint, in schema order, converted into ``schema.unit``.

        The joint names present in ``value`` are checked against the schema with ``check_names_match``
        before any conversion happens.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name.keys()))

        ordered: list[float] = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            ordered.append(convert_angular_velocity(entry.value, entry.unit, schema.unit))
        return {"velocities": ordered}

    def deserialize_joint_velocities(
        self: "JsonJointVelocitiesSerializer",
        schema: P.JointVelocitiesSchema,
        value: JsonValue,
    ) -> P.JointVelocitiesValue:
        """Parse ``value["velocities"]`` into per-joint values tagged with ``schema.unit``.

        The i-th element is assigned to ``schema.joint_names[i]``; no unit conversion is applied.

        Raises:
            ValueError: If the key is missing, is not a list, or its length differs from the joint count.
        """
        if "velocities" not in value:
            raise ValueError("Key 'velocities' not found in value")
        raw = value["velocities"]
        if not isinstance(raw, list):
            raise ValueError("Key 'velocities' must be a list")
        expected = len(schema.joint_names)
        if len(raw) != expected:
            raise ValueError(f"Shape of velocities must match number of joint names: {len(raw)} != {expected}")

        entries = [
            P.JointVelocityValue(joint_name=joint, value=as_float(item), unit=schema.unit)
            for joint, item in zip(schema.joint_names, raw)
        ]
        return P.JointVelocitiesValue(values=entries)
|
114
|
-
|
115
|
-
|
116
|
-
class JsonJointTorquesSerializer(JointTorquesSerializer[JsonValue]):
    """Converts joint torques to and from a JSON payload shaped like ``{"torques": [...]}``.

    The list is ordered by ``schema.joint_names``.
    """

    def serialize_joint_torques(
        self: "JsonJointTorquesSerializer",
        schema: P.JointTorquesSchema,
        value: P.JointTorquesValue,
    ) -> dict[str, list[float]]:
        """Emit one torque per schema joint, in schema order, converted into ``schema.unit``.

        The joint names present in ``value`` are checked against the schema with ``check_names_match``
        before any conversion happens.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name.keys()))

        ordered: list[float] = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            ordered.append(convert_torque(entry.value, entry.unit, schema.unit))
        return {"torques": ordered}

    def deserialize_joint_torques(
        self: "JsonJointTorquesSerializer",
        schema: P.JointTorquesSchema,
        value: JsonValue,
    ) -> P.JointTorquesValue:
        """Parse ``value["torques"]`` into per-joint values tagged with ``schema.unit``.

        The i-th element is assigned to ``schema.joint_names[i]``; no unit conversion is applied.

        Raises:
            ValueError: If the key is missing, is not a list, or its length differs from the joint count.
        """
        if "torques" not in value:
            raise ValueError("Key 'torques' not found in value")
        raw = value["torques"]
        if not isinstance(raw, list):
            raise ValueError("Key 'torques' must be a list")
        expected = len(schema.joint_names)
        if len(raw) != expected:
            raise ValueError(f"Shape of torques must match number of joint names: {len(raw)} != {expected}")

        entries = [
            P.JointTorqueValue(joint_name=joint, value=as_float(item), unit=schema.unit)
            for joint, item in zip(schema.joint_names, raw)
        ]
        return P.JointTorquesValue(values=entries)
|
150
|
-
|
151
|
-
|
152
|
-
class JsonJointCommandsSerializer(JointCommandsSerializer[JsonValue]):
    """Converts joint commands to and from JSON.

    The payload shape is ``{"commands": {joint_name: [torque, velocity, position, kp, kd]}}``.
    """

    def _convert_value_to_array(
        self: "JsonJointCommandsSerializer",
        value: P.JointCommandValue,
        schema: P.JointCommandsSchema,
    ) -> list[float]:
        """Flatten one command into ``[torque, velocity, position, kp, kd]``.

        Torque, velocity and position are converted into the schema's units; the gains are passed
        through ``float``.
        """
        torque = convert_torque(value.torque, value.torque_unit, schema.torque_unit)
        velocity = convert_angular_velocity(value.velocity, value.velocity_unit, schema.velocity_unit)
        position = convert_angular_position(value.position, value.position_unit, schema.position_unit)
        return [torque, velocity, position, float(value.kp), float(value.kd)]

    def _convert_array_to_value(
        self: "JsonJointCommandsSerializer",
        values: Any,  # noqa: ANN401
        schema: P.JointCommandsSchema,
        name: str,
    ) -> P.JointCommandValue:
        """Build the command for joint ``name`` from a five-element list, tagged with the schema's units.

        Raises:
            ValueError: If ``values`` is not a list of exactly five elements.
        """
        if not isinstance(values, list):
            raise ValueError("Value must be a list")
        count = len(values)
        if count != 5:
            raise ValueError(f"Shape of command must match number of joint commands: {count} != 5")

        torque, velocity, position, kp, kd = values
        return P.JointCommandValue(
            joint_name=name,
            torque=float(torque),
            velocity=float(velocity),
            position=float(position),
            kp=float(kp),
            kd=float(kd),
            torque_unit=schema.torque_unit,
            velocity_unit=schema.velocity_unit,
            position_unit=schema.position_unit,
        )

    def serialize_joint_commands(
        self: "JsonJointCommandsSerializer",
        schema: P.JointCommandsSchema,
        value: P.JointCommandsValue,
    ) -> dict[str, dict[str, list[float]]]:
        """Map every schema joint name to its flattened command array.

        The joint names present in ``value`` are checked against the schema with ``check_names_match``
        first.
        """
        by_name = {}
        for command in value.values:
            by_name[command.joint_name] = command
        check_names_match("schema", schema.joint_names, "value", list(by_name.keys()))

        commands: dict[str, list[float]] = {}
        for joint in schema.joint_names:
            commands[joint] = self._convert_value_to_array(by_name[joint], schema)
        return {"commands": commands}

    def deserialize_joint_commands(
        self: "JsonJointCommandsSerializer",
        schema: P.JointCommandsSchema,
        value: JsonValue,
    ) -> P.JointCommandsValue:
        """Parse ``value["commands"]`` into one command per schema joint, in schema order.

        Raises:
            ValueError: If the key is missing or is not a dict, or if any per-joint entry is malformed.
        """
        if "commands" not in value:
            raise ValueError("Key 'commands' not found in value")
        commands = value["commands"]
        if not isinstance(commands, dict):
            raise ValueError("Key 'commands' must be a dictionary")
        check_names_match("schema", schema.joint_names, "value", list(commands.keys()))

        parsed = [self._convert_array_to_value(commands[joint], schema, joint) for joint in schema.joint_names]
        return P.JointCommandsValue(values=parsed)
|
213
|
-
|
214
|
-
|
215
|
-
class JsonCameraFrameSerializer(CameraFrameSerializer[JsonValue]):
    """Converts camera frames to and from ``{"data": <base64 text>}``."""

    def serialize_camera_frame(
        self: "JsonCameraFrameSerializer",
        schema: P.CameraFrameSchema,
        value: P.CameraFrameValue,
    ) -> dict[str, str]:
        """Base64-encode the raw frame bytes and return them as a UTF-8 string."""
        encoded = base64.b64encode(value.data)
        return {"data": encoded.decode("utf-8")}

    def deserialize_camera_frame(
        self: "JsonCameraFrameSerializer",
        schema: P.CameraFrameSchema,
        value: JsonValue,
    ) -> P.CameraFrameValue:
        """Decode the base64 string stored under ``"data"`` back into raw frame bytes.

        Raises:
            ValueError: If the key is missing or its value is not a string.
        """
        if "data" not in value:
            raise ValueError("Key 'data' not found in value")
        if not isinstance(encoded := value["data"], str):
            raise ValueError("Key 'data' must be a string")
        return P.CameraFrameValue(data=base64.b64decode(encoded))
|
234
|
-
|
235
|
-
|
236
|
-
class JsonAudioFrameSerializer(AudioFrameSerializer[JsonValue]):
    """Converts audio frames to and from ``{"data": <base64 text>}``."""

    def serialize_audio_frame(
        self: "JsonAudioFrameSerializer",
        schema: P.AudioFrameSchema,
        value: P.AudioFrameValue,
    ) -> dict[str, str]:
        """Base64-encode the raw audio bytes and return them as a UTF-8 string."""
        encoded = base64.b64encode(value.data)
        return {"data": encoded.decode("utf-8")}

    def deserialize_audio_frame(
        self: "JsonAudioFrameSerializer",
        schema: P.AudioFrameSchema,
        value: JsonValue,
    ) -> P.AudioFrameValue:
        """Decode the base64 string stored under ``"data"`` back into raw audio bytes.

        Raises:
            ValueError: If the key is missing or its value is not a string.
        """
        if "data" not in value:
            raise ValueError("Key 'data' not found in value")
        if not isinstance(encoded := value["data"], str):
            raise ValueError("Key 'data' must be a string")
        return P.AudioFrameValue(data=base64.b64decode(encoded))
|
255
|
-
|
256
|
-
|
257
|
-
class JsonImuSerializer(ImuSerializer[JsonValue]):
    """Converts IMU readings to and from JSON.

    Each sensor enabled in the schema maps to a three-element ``[x, y, z]`` list under its own key:
    ``linear_acceleration`` (``use_accelerometer``), ``angular_velocity`` (``use_gyroscope``) and
    ``magnetic_field`` (``use_magnetometer``). Disabled sensors are omitted on serialization and ignored
    on deserialization.
    """

    @staticmethod
    def _imu_vector_to_json(vector: Any) -> list[float]:  # noqa: ANN401
        """Return the ``x``, ``y``, ``z`` attributes of a vector message as a list."""
        return [vector.x, vector.y, vector.z]

    @staticmethod
    def _imu_vector_from_json(value: JsonValue, key: str) -> tuple[float, float, float]:
        """Read ``value[key]`` as an ``[x, y, z]`` list and convert each component with ``as_float``.

        Raises:
            ValueError: If ``key`` is missing, is not a list, or does not hold exactly three elements.
        """
        # Raise ValueError (not KeyError) so a missing sensor field is reported like every other
        # missing key in this module.
        if key not in value:
            raise ValueError(f"Key '{key}' not found in value")
        vector = value[key]
        if not isinstance(vector, list):
            raise ValueError(f"Key '{key}' must be a list")
        # Explicit length check gives a readable error instead of a bare tuple-unpack failure.
        if len(vector) != 3:
            raise ValueError(f"Key '{key}' must have exactly 3 elements, got {len(vector)}")
        x, y, z = vector
        return as_float(x), as_float(y), as_float(z)

    def serialize_imu(
        self: "JsonImuSerializer",
        schema: P.ImuSchema,
        value: P.ImuValue,
    ) -> dict[str, list[float]]:
        """Emit an ``[x, y, z]`` list for each sensor enabled in ``schema``."""
        data: dict[str, list[float]] = {}
        if schema.use_accelerometer:
            data["linear_acceleration"] = self._imu_vector_to_json(value.linear_acceleration)
        if schema.use_gyroscope:
            data["angular_velocity"] = self._imu_vector_to_json(value.angular_velocity)
        if schema.use_magnetometer:
            data["magnetic_field"] = self._imu_vector_to_json(value.magnetic_field)
        return data

    def deserialize_imu(
        self: "JsonImuSerializer",
        schema: P.ImuSchema,
        value: JsonValue,
    ) -> P.ImuValue:
        """Build an ``ImuValue`` from the ``[x, y, z]`` lists of the sensors enabled in ``schema``.

        Raises:
            ValueError: If an enabled sensor's key is missing, is not a list, or is not length three.
        """
        imu_value = P.ImuValue()
        if schema.use_accelerometer:
            x, y, z = self._imu_vector_from_json(value, "linear_acceleration")
            imu_value.linear_acceleration.x = x
            imu_value.linear_acceleration.y = y
            imu_value.linear_acceleration.z = z
        if schema.use_gyroscope:
            x, y, z = self._imu_vector_from_json(value, "angular_velocity")
            imu_value.angular_velocity.x = x
            imu_value.angular_velocity.y = y
            imu_value.angular_velocity.z = z
        if schema.use_magnetometer:
            x, y, z = self._imu_vector_from_json(value, "magnetic_field")
            imu_value.magnetic_field.x = x
            imu_value.magnetic_field.y = y
            imu_value.magnetic_field.z = z
        return imu_value
|
312
|
-
|
313
|
-
|
314
|
-
class JsonTimestampSerializer(TimestampSerializer[JsonValue]):
    """Converts timestamps to and from ``{"seconds": <int>, "nanos": <int>}``."""

    def serialize_timestamp(
        self: "JsonTimestampSerializer",
        schema: P.TimestampSchema,
        value: P.TimestampValue,
    ) -> dict[str, int]:
        """Return the timestamp's ``seconds`` and ``nanos`` fields unchanged."""
        return {"seconds": value.seconds, "nanos": value.nanos}

    def deserialize_timestamp(
        self: "JsonTimestampSerializer",
        schema: P.TimestampSchema,
        value: JsonValue,
    ) -> P.TimestampValue:
        """Parse the ``seconds`` and ``nanos`` integers into a ``TimestampValue``.

        Raises:
            ValueError: If either key is missing, or if either value is not an integer (booleans are
                rejected).
        """
        if "seconds" not in value or "nanos" not in value:
            raise ValueError("Key 'seconds' or 'nanos' not found in value")
        seconds = value["seconds"]
        nanos = value["nanos"]
        # ``bool`` subclasses ``int``, so a bare isinstance check would silently accept JSON
        # ``true``/``false`` as 1/0. Reject booleans explicitly.
        if any(isinstance(field, bool) or not isinstance(field, int) for field in (seconds, nanos)):
            raise ValueError("Key 'seconds' and 'nanos' must be integers")
        return P.TimestampValue(seconds=seconds, nanos=nanos)
|
334
|
-
|
335
|
-
|
336
|
-
class JsonVectorCommandSerializer(VectorCommandSerializer[JsonValue]):
    """Converts vector commands to and from ``{"values": [...]}``."""

    def serialize_vector_command(
        self: "JsonVectorCommandSerializer",
        schema: P.VectorCommandSchema,
        value: P.VectorCommandValue,
    ) -> dict[str, list[float]]:
        """Copy the command's values into a plain list."""
        values = list(value.values)
        return {"values": values}

    def deserialize_vector_command(
        self: "JsonVectorCommandSerializer",
        schema: P.VectorCommandSchema,
        value: JsonValue,
    ) -> P.VectorCommandValue:
        """Parse ``value["values"]`` into a ``VectorCommandValue``, converting each entry with ``as_float``.

        Raises:
            ValueError: If the key is missing, is not a list, or its length differs from ``schema.dimensions``.
        """
        if "values" not in value:
            raise ValueError("Key 'values' not found in value")
        raw = value["values"]
        if not isinstance(raw, list):
            raise ValueError("Key 'values' must be a list")
        if len(raw) != schema.dimensions:
            raise ValueError(f"Length of list must match number of dimensions: {len(raw)} != {schema.dimensions}")
        parsed = list(map(as_float, raw))
        return P.VectorCommandValue(values=parsed)
|
357
|
-
|
358
|
-
|
359
|
-
class JsonStateTensorSerializer(StateTensorSerializer[JsonValue]):
    """Converts state tensors to and from ``{"data": <base64 text>}``."""

    def serialize_state_tensor(
        self: "JsonStateTensorSerializer",
        schema: P.StateTensorSchema,
        value: P.StateTensorValue,
    ) -> dict[str, str]:
        """Base64-encode the tensor's raw bytes and return them as a UTF-8 string."""
        encoded = base64.b64encode(value.data)
        return {"data": encoded.decode("utf-8")}

    def deserialize_state_tensor(
        self: "JsonStateTensorSerializer",
        schema: P.StateTensorSchema,
        value: JsonValue,
    ) -> P.StateTensorValue:
        """Decode the base64 string stored under ``"data"`` back into raw tensor bytes.

        Raises:
            ValueError: If the key is missing or its value is not a string.
        """
        if "data" not in value:
            raise ValueError("Key 'data' not found in value")
        if not isinstance(encoded := value["data"], str):
            raise ValueError("Key 'data' must be a string")
        return P.StateTensorValue(data=base64.b64decode(encoded))
|
378
|
-
|
379
|
-
|
380
|
-
class JsonSerializer(
    # Per-value-type JSON mixins; the base order defines the MRO and is kept as-is.
    JsonJointPositionsSerializer,
    JsonJointVelocitiesSerializer,
    JsonJointTorquesSerializer,
    JsonJointCommandsSerializer,
    JsonCameraFrameSerializer,
    JsonAudioFrameSerializer,
    JsonImuSerializer,
    JsonTimestampSerializer,
    JsonVectorCommandSerializer,
    JsonStateTensorSerializer,
    Serializer[JsonValue],
):
    """JSON serializer for a single value schema, combining every per-type JSON mixin above."""

    def __init__(self: "JsonSerializer", schema: P.ValueSchema) -> None:
        """Store ``schema`` via ``Serializer.__init__``.

        ``Serializer.__init__`` is called explicitly rather than through ``super()``, so the mixins'
        position in the MRO does not affect initialization.
        """
        Serializer.__init__(self, schema=schema)
|
395
|
-
|
396
|
-
|
397
|
-
class JsonMultiSerializer(MultiSerializer[JsonValue]):
    """JSON serializer for a whole IO schema, with one ``JsonSerializer`` per value schema."""

    def __init__(self: "JsonMultiSerializer", schema: P.IOSchema) -> None:
        """Create a ``JsonSerializer`` for each entry of ``schema.values``, in order, and hand them to the base class."""
        per_value = [JsonSerializer(schema=value_schema) for value_schema in schema.values]
        super().__init__(per_value)
|