kinfer 0.3.2__cp312-cp312-macosx_11_0_arm64.whl → 0.4.0__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kinfer/__init__.py +0 -1
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +11 -0
  4. kinfer/export/__init__.py +0 -1
  5. kinfer/export/common.py +35 -0
  6. kinfer/export/jax.py +51 -0
  7. kinfer/export/pytorch.py +42 -110
  8. kinfer/export/serialize.py +86 -0
  9. kinfer/requirements.txt +3 -4
  10. kinfer/rust/Cargo.toml +8 -6
  11. kinfer/rust/src/lib.rs +2 -11
  12. kinfer/rust/src/model.rs +271 -121
  13. kinfer/rust/src/runtime.rs +104 -0
  14. kinfer/rust_bindings/Cargo.toml +8 -1
  15. kinfer/rust_bindings/rust_bindings.pyi +35 -0
  16. kinfer/rust_bindings/src/lib.rs +310 -1
  17. kinfer/rust_bindings.cpython-312-darwin.so +0 -0
  18. kinfer/rust_bindings.pyi +29 -1
  19. kinfer-0.4.0.dist-info/METADATA +55 -0
  20. kinfer-0.4.0.dist-info/RECORD +26 -0
  21. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
  22. kinfer/inference/__init__.py +0 -1
  23. kinfer/inference/python.py +0 -92
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -36
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.2.dist-info/METADATA +0 -57
  43. kinfer-0.3.2.dist-info/RECORD +0 -39
  44. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.2.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
kinfer/serialize/json.py DELETED
@@ -1,399 +0,0 @@
1
- """Defines a serializer for JSON."""
2
-
3
- import base64
4
- from typing import Any, Mapping, Sequence
5
-
6
- from kinfer import proto as P
7
- from kinfer.serialize.base import (
8
- AudioFrameSerializer,
9
- CameraFrameSerializer,
10
- ImuSerializer,
11
- JointCommandsSerializer,
12
- JointPositionsSerializer,
13
- JointTorquesSerializer,
14
- JointVelocitiesSerializer,
15
- MultiSerializer,
16
- Serializer,
17
- StateTensorSerializer,
18
- TimestampSerializer,
19
- VectorCommandSerializer,
20
- )
21
- from kinfer.serialize.utils import (
22
- as_float,
23
- check_names_match,
24
- convert_angular_position,
25
- convert_angular_velocity,
26
- convert_torque,
27
- )
28
-
29
- Prim = str | int | float
30
-
31
- JsonValue = Mapping[
32
- str,
33
- Prim
34
- | Sequence[Prim]
35
- | Sequence[Mapping[str, Prim]]
36
- | Mapping[str, Prim]
37
- | Mapping[str, Sequence[Prim]]
38
- | Mapping[str, Mapping[str, Prim]],
39
- ]
40
-
41
-
42
class JsonJointPositionsSerializer(JointPositionsSerializer[JsonValue]):
    """JSON codec for joint positions.

    Wire format: ``{"positions": [p0, p1, ...]}`` with one number per joint,
    ordered as ``schema.joint_names``.
    """

    def serialize_joint_positions(
        self: "JsonJointPositionsSerializer",
        schema: P.JointPositionsSchema,
        value: P.JointPositionsValue,
    ) -> dict[str, list[float]]:
        """Encode ``value`` in schema joint order.

        Each joint's value is passed through ``convert_angular_position`` from
        its own unit with ``schema.unit`` as the target unit.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name.keys()))

        converted = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            converted.append(convert_angular_position(entry.value, entry.unit, schema.unit))
        return {"positions": converted}

    def deserialize_joint_positions(
        self: "JsonJointPositionsSerializer",
        schema: P.JointPositionsSchema,
        value: JsonValue,
    ) -> P.JointPositionsValue:
        """Decode a ``{"positions": [...]}`` payload, tagging each entry with ``schema.unit``.

        Raises:
            ValueError: If the key is missing, is not a list, or its length
                differs from the number of schema joints.
        """
        if "positions" not in value:
            raise ValueError("Key 'positions' not found in value")
        raw = value["positions"]
        if not isinstance(raw, list):
            raise ValueError("Key 'positions' must be a list")
        expected = len(schema.joint_names)
        if len(raw) != expected:
            raise ValueError(f"Shape of positions must match number of joint names: {len(raw)} != {expected}")
        # Lengths are equal here, so zip pairs every joint with exactly one entry.
        return P.JointPositionsValue(
            values=[
                P.JointPositionValue(joint_name=joint, value=as_float(item), unit=schema.unit)
                for joint, item in zip(schema.joint_names, raw)
            ]
        )
77
-
78
-
79
class JsonJointVelocitiesSerializer(JointVelocitiesSerializer[JsonValue]):
    """JSON codec for joint velocities.

    Wire format: ``{"velocities": [v0, v1, ...]}`` with one number per joint,
    ordered as ``schema.joint_names``.
    """

    def serialize_joint_velocities(
        self: "JsonJointVelocitiesSerializer",
        schema: P.JointVelocitiesSchema,
        value: P.JointVelocitiesValue,
    ) -> dict[str, list[float]]:
        """Encode ``value`` in schema joint order.

        Each joint's value is passed through ``convert_angular_velocity`` from
        its own unit with ``schema.unit`` as the target unit.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name.keys()))

        converted = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            converted.append(convert_angular_velocity(entry.value, entry.unit, schema.unit))
        return {"velocities": converted}

    def deserialize_joint_velocities(
        self: "JsonJointVelocitiesSerializer",
        schema: P.JointVelocitiesSchema,
        value: JsonValue,
    ) -> P.JointVelocitiesValue:
        """Decode a ``{"velocities": [...]}`` payload, tagging each entry with ``schema.unit``.

        Raises:
            ValueError: If the key is missing, is not a list, or its length
                differs from the number of schema joints.
        """
        if "velocities" not in value:
            raise ValueError("Key 'velocities' not found in value")
        raw = value["velocities"]
        if not isinstance(raw, list):
            raise ValueError("Key 'velocities' must be a list")
        expected = len(schema.joint_names)
        if len(raw) != expected:
            raise ValueError(f"Shape of velocities must match number of joint names: {len(raw)} != {expected}")
        # Lengths are equal here, so zip pairs every joint with exactly one entry.
        return P.JointVelocitiesValue(
            values=[
                P.JointVelocityValue(joint_name=joint, value=as_float(item), unit=schema.unit)
                for joint, item in zip(schema.joint_names, raw)
            ]
        )
114
-
115
-
116
class JsonJointTorquesSerializer(JointTorquesSerializer[JsonValue]):
    """JSON codec for joint torques.

    Wire format: ``{"torques": [t0, t1, ...]}`` with one number per joint,
    ordered as ``schema.joint_names``.
    """

    def serialize_joint_torques(
        self: "JsonJointTorquesSerializer",
        schema: P.JointTorquesSchema,
        value: P.JointTorquesValue,
    ) -> dict[str, list[float]]:
        """Encode ``value`` in schema joint order.

        Each joint's value is passed through ``convert_torque`` from its own
        unit with ``schema.unit`` as the target unit.
        """
        by_name = {}
        for entry in value.values:
            by_name[entry.joint_name] = entry
        check_names_match("schema", schema.joint_names, "value", list(by_name.keys()))

        converted = []
        for joint in schema.joint_names:
            entry = by_name[joint]
            converted.append(convert_torque(entry.value, entry.unit, schema.unit))
        return {"torques": converted}

    def deserialize_joint_torques(
        self: "JsonJointTorquesSerializer",
        schema: P.JointTorquesSchema,
        value: JsonValue,
    ) -> P.JointTorquesValue:
        """Decode a ``{"torques": [...]}`` payload, tagging each entry with ``schema.unit``.

        Raises:
            ValueError: If the key is missing, is not a list, or its length
                differs from the number of schema joints.
        """
        if "torques" not in value:
            raise ValueError("Key 'torques' not found in value")
        raw = value["torques"]
        if not isinstance(raw, list):
            raise ValueError("Key 'torques' must be a list")
        expected = len(schema.joint_names)
        if len(raw) != expected:
            raise ValueError(f"Shape of torques must match number of joint names: {len(raw)} != {expected}")
        # Lengths are equal here, so zip pairs every joint with exactly one entry.
        return P.JointTorquesValue(
            values=[
                P.JointTorqueValue(joint_name=joint, value=as_float(item), unit=schema.unit)
                for joint, item in zip(schema.joint_names, raw)
            ]
        )
150
-
151
-
152
class JsonJointCommandsSerializer(JointCommandsSerializer[JsonValue]):
    """JSON codec for per-joint commands.

    Wire format::

        {"commands": {joint_name: [torque, velocity, position, kp, kd], ...}}

    When serializing, torque/velocity/position are passed through the
    ``convert_*`` helpers with the schema's units as the target. When
    deserializing, they are tagged with the schema's units as-is.
    """

    # Order of the entries in each per-joint command array on the wire.
    _COMMAND_FIELDS = ("torque", "velocity", "position", "kp", "kd")

    def _convert_value_to_array(
        self: "JsonJointCommandsSerializer",
        value: P.JointCommandValue,
        schema: P.JointCommandsSchema,
    ) -> list[float]:
        """Flatten one joint command into the wire array, in ``_COMMAND_FIELDS`` order."""
        return [
            convert_torque(value.torque, value.torque_unit, schema.torque_unit),
            convert_angular_velocity(value.velocity, value.velocity_unit, schema.velocity_unit),
            convert_angular_position(value.position, value.position_unit, schema.position_unit),
            float(value.kp),
            float(value.kd),
        ]

    def _convert_array_to_value(
        self: "JsonJointCommandsSerializer",
        values: Any,  # noqa: ANN401
        schema: P.JointCommandsSchema,
        name: str,
    ) -> P.JointCommandValue:
        """Build the command for joint ``name`` from its wire array.

        Raises:
            ValueError: If ``values`` is not a list with exactly one entry per
                field in ``_COMMAND_FIELDS``.
        """
        if not isinstance(values, list):
            raise ValueError(f"Command for joint '{name}' must be a list")
        num_fields = len(self._COMMAND_FIELDS)
        if len(values) != num_fields:
            # The length being checked is the number of fields per command,
            # not the number of joints; say so explicitly in the message.
            raise ValueError(
                f"Command for joint '{name}' must have {num_fields} values "
                f"({', '.join(self._COMMAND_FIELDS)}): got {len(values)}"
            )
        return P.JointCommandValue(
            joint_name=name,
            torque=float(values[0]),
            velocity=float(values[1]),
            position=float(values[2]),
            kp=float(values[3]),
            kd=float(values[4]),
            torque_unit=schema.torque_unit,
            velocity_unit=schema.velocity_unit,
            position_unit=schema.position_unit,
        )

    def serialize_joint_commands(
        self: "JsonJointCommandsSerializer",
        schema: P.JointCommandsSchema,
        value: P.JointCommandsValue,
    ) -> dict[str, dict[str, list[float]]]:
        """Encode every joint command, keyed by joint name in schema order."""
        value_map = {v.joint_name: v for v in value.values}
        check_names_match("schema", schema.joint_names, "value", list(value_map.keys()))
        return {
            "commands": {name: self._convert_value_to_array(value_map[name], schema) for name in schema.joint_names}
        }

    def deserialize_joint_commands(
        self: "JsonJointCommandsSerializer",
        schema: P.JointCommandsSchema,
        value: JsonValue,
    ) -> P.JointCommandsValue:
        """Decode a ``{"commands": {...}}`` payload into commands in schema joint order.

        Raises:
            ValueError: If the key is missing or not a dict, or if any per-joint
                array is malformed.
        """
        if "commands" not in value:
            raise ValueError("Key 'commands' not found in value")
        commands = value["commands"]
        if not isinstance(commands, dict):
            raise ValueError("Key 'commands' must be a dictionary")
        check_names_match("schema", schema.joint_names, "value", list(commands.keys()))
        return P.JointCommandsValue(
            values=[self._convert_array_to_value(commands[name], schema, name) for name in schema.joint_names]
        )
213
-
214
-
215
class JsonCameraFrameSerializer(CameraFrameSerializer[JsonValue]):
    """JSON codec for camera frames: the raw bytes travel base64-encoded under ``"data"``."""

    def serialize_camera_frame(
        self: "JsonCameraFrameSerializer",
        schema: P.CameraFrameSchema,
        value: P.CameraFrameValue,
    ) -> dict[str, str]:
        """Return ``{"data": <base64 text of value.data>}``."""
        encoded = base64.b64encode(value.data)
        return {"data": encoded.decode("utf-8")}

    def deserialize_camera_frame(
        self: "JsonCameraFrameSerializer",
        schema: P.CameraFrameSchema,
        value: JsonValue,
    ) -> P.CameraFrameValue:
        """Decode the base64 string under ``"data"`` back into frame bytes.

        Raises:
            ValueError: If ``"data"`` is missing or is not a string.
        """
        if "data" not in value:
            raise ValueError("Key 'data' not found in value")
        encoded = value["data"]
        if isinstance(encoded, str):
            return P.CameraFrameValue(data=base64.b64decode(encoded))
        raise ValueError("Key 'data' must be a string")
234
-
235
-
236
class JsonAudioFrameSerializer(AudioFrameSerializer[JsonValue]):
    """JSON codec for audio frames: the raw bytes travel base64-encoded under ``"data"``."""

    def serialize_audio_frame(
        self: "JsonAudioFrameSerializer",
        schema: P.AudioFrameSchema,
        value: P.AudioFrameValue,
    ) -> dict[str, str]:
        """Return ``{"data": <base64 text of value.data>}``."""
        encoded = base64.b64encode(value.data)
        return {"data": encoded.decode("utf-8")}

    def deserialize_audio_frame(
        self: "JsonAudioFrameSerializer",
        schema: P.AudioFrameSchema,
        value: JsonValue,
    ) -> P.AudioFrameValue:
        """Decode the base64 string under ``"data"`` back into frame bytes.

        Raises:
            ValueError: If ``"data"`` is missing or is not a string.
        """
        if "data" not in value:
            raise ValueError("Key 'data' not found in value")
        encoded = value["data"]
        if isinstance(encoded, str):
            return P.AudioFrameValue(data=base64.b64decode(encoded))
        raise ValueError("Key 'data' must be a string")
255
-
256
-
257
class JsonImuSerializer(ImuSerializer[JsonValue]):
    """JSON codec for IMU readings.

    Each sensor enabled in the schema is stored under its own key as an
    ``[x, y, z]`` list:

    - ``use_accelerometer`` maps to ``"linear_acceleration"``.
    - ``use_gyroscope`` maps to ``"angular_velocity"``.
    - ``use_magnetometer`` maps to ``"magnetic_field"``.

    Disabled sensors are omitted on serialization and ignored on
    deserialization.
    """

    def serialize_imu(
        self: "JsonImuSerializer",
        schema: P.ImuSchema,
        value: P.ImuValue,
    ) -> dict[str, list[float]]:
        """Encode the enabled sensor vectors of ``value`` as ``[x, y, z]`` lists."""
        data: dict[str, list[float]] = {}
        if schema.use_accelerometer:
            data["linear_acceleration"] = self._imu_vector_to_list(value.linear_acceleration)
        if schema.use_gyroscope:
            data["angular_velocity"] = self._imu_vector_to_list(value.angular_velocity)
        if schema.use_magnetometer:
            data["magnetic_field"] = self._imu_vector_to_list(value.magnetic_field)
        return data

    def deserialize_imu(
        self: "JsonImuSerializer",
        schema: P.ImuSchema,
        value: JsonValue,
    ) -> P.ImuValue:
        """Decode the enabled sensor vectors from ``value`` into a new ``ImuValue``.

        Raises:
            ValueError: If an enabled sensor's key is missing, is not a list,
                or does not have exactly three elements.
        """
        imu_value = P.ImuValue()
        if schema.use_accelerometer:
            self._imu_list_to_vector(value, "linear_acceleration", imu_value.linear_acceleration)
        if schema.use_gyroscope:
            self._imu_list_to_vector(value, "angular_velocity", imu_value.angular_velocity)
        if schema.use_magnetometer:
            self._imu_list_to_vector(value, "magnetic_field", imu_value.magnetic_field)
        return imu_value

    @staticmethod
    def _imu_vector_to_list(vector: Any) -> list[float]:  # noqa: ANN401
        """Return the ``x``, ``y``, ``z`` fields of ``vector`` as a list."""
        return [vector.x, vector.y, vector.z]

    @staticmethod
    def _imu_list_to_vector(value: JsonValue, key: str, target: Any) -> None:  # noqa: ANN401
        """Copy the ``[x, y, z]`` list stored under ``key`` into ``target`` in place.

        Raises:
            ValueError: If ``key`` is missing, is not a list, or is not length 3.
        """
        # Explicit presence check so a missing key raises ValueError, like every
        # other deserializer in this module, instead of a bare KeyError.
        if key not in value:
            raise ValueError(f"Key '{key}' not found in value")
        components = value[key]
        if not isinstance(components, list):
            raise ValueError(f"Key '{key}' must be a list")
        if len(components) != 3:
            raise ValueError(f"Key '{key}' must have exactly 3 elements (x, y, z): got {len(components)}")
        x, y, z = components
        target.x = as_float(x)
        target.y = as_float(y)
        target.z = as_float(z)
312
-
313
-
314
class JsonTimestampSerializer(TimestampSerializer[JsonValue]):
    """JSON codec for timestamps: ``{"seconds": <int>, "nanos": <int>}``."""

    def serialize_timestamp(
        self: "JsonTimestampSerializer",
        schema: P.TimestampSchema,
        value: P.TimestampValue,
    ) -> dict[str, int]:
        """Return the ``seconds`` and ``nanos`` fields of ``value`` as a dict."""
        return {"seconds": value.seconds, "nanos": value.nanos}

    def deserialize_timestamp(
        self: "JsonTimestampSerializer",
        schema: P.TimestampSchema,
        value: JsonValue,
    ) -> P.TimestampValue:
        """Decode a ``{"seconds": ..., "nanos": ...}`` payload.

        Raises:
            ValueError: If either key is missing, or either value is not a
                (non-bool) integer.
        """
        if "seconds" not in value or "nanos" not in value:
            raise ValueError("Key 'seconds' or 'nanos' not found in value")
        seconds = value["seconds"]
        nanos = value["nanos"]
        # bool is a subclass of int, so a bare isinstance(x, int) check would
        # silently accept JSON true/false as 1/0; reject bools explicitly.
        if not all(isinstance(field, int) and not isinstance(field, bool) for field in (seconds, nanos)):
            raise ValueError("Key 'seconds' and 'nanos' must be integers")
        return P.TimestampValue(seconds=seconds, nanos=nanos)
334
-
335
-
336
class JsonVectorCommandSerializer(VectorCommandSerializer[JsonValue]):
    """JSON codec for vector commands: ``{"values": [c0, c1, ...]}``."""

    def serialize_vector_command(
        self: "JsonVectorCommandSerializer",
        schema: P.VectorCommandSchema,
        value: P.VectorCommandValue,
    ) -> dict[str, list[float]]:
        """Copy ``value.values`` into a plain list under ``"values"``."""
        components = list(value.values)
        return {"values": components}

    def deserialize_vector_command(
        self: "JsonVectorCommandSerializer",
        schema: P.VectorCommandSchema,
        value: JsonValue,
    ) -> P.VectorCommandValue:
        """Decode the ``"values"`` list, which must have ``schema.dimensions`` entries.

        Raises:
            ValueError: If the key is missing, is not a list, or has the wrong length.
        """
        if "values" not in value:
            raise ValueError("Key 'values' not found in value")
        raw = value["values"]
        if not isinstance(raw, list):
            raise ValueError("Key 'values' must be a list")
        count = len(raw)
        if count != schema.dimensions:
            raise ValueError(f"Length of list must match number of dimensions: {count} != {schema.dimensions}")
        return P.VectorCommandValue(values=list(map(as_float, raw)))
357
-
358
-
359
class JsonStateTensorSerializer(StateTensorSerializer[JsonValue]):
    """JSON codec for state tensors: the raw bytes travel base64-encoded under ``"data"``."""

    def serialize_state_tensor(
        self: "JsonStateTensorSerializer",
        schema: P.StateTensorSchema,
        value: P.StateTensorValue,
    ) -> dict[str, str]:
        """Return ``{"data": <base64 text of value.data>}``."""
        encoded = base64.b64encode(value.data)
        return {"data": encoded.decode("utf-8")}

    def deserialize_state_tensor(
        self: "JsonStateTensorSerializer",
        schema: P.StateTensorSchema,
        value: JsonValue,
    ) -> P.StateTensorValue:
        """Decode the base64 string under ``"data"`` back into tensor bytes.

        Raises:
            ValueError: If ``"data"`` is missing or is not a string.
        """
        if "data" not in value:
            raise ValueError("Key 'data' not found in value")
        encoded = value["data"]
        if isinstance(encoded, str):
            return P.StateTensorValue(data=base64.b64decode(encoded))
        raise ValueError("Key 'data' must be a string")
378
-
379
-
380
class JsonSerializer(
    JsonJointPositionsSerializer,
    JsonJointVelocitiesSerializer,
    JsonJointTorquesSerializer,
    JsonJointCommandsSerializer,
    JsonCameraFrameSerializer,
    JsonAudioFrameSerializer,
    JsonImuSerializer,
    JsonTimestampSerializer,
    JsonVectorCommandSerializer,
    JsonStateTensorSerializer,
    Serializer[JsonValue],
):
    """JSON serializer for a single ``P.ValueSchema``.

    Combines every per-value-type JSON codec in this module. The order of
    the base classes fixes the method resolution order, so keep it stable.
    """

    def __init__(self: "JsonSerializer", schema: P.ValueSchema) -> None:
        # Serializer.__init__ is invoked directly rather than through super(),
        # so no __init__ defined earlier in the MRO runs here.
        Serializer.__init__(self, schema=schema)
395
-
396
-
397
class JsonMultiSerializer(MultiSerializer[JsonValue]):
    """Build one ``JsonSerializer`` per value schema in an ``IOSchema``."""

    def __init__(self: "JsonMultiSerializer", schema: P.IOSchema) -> None:
        per_value = [JsonSerializer(schema=value_schema) for value_schema in schema.values]
        super().__init__(per_value)