kinfer 0.3.2__cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl → 0.3.3__cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kinfer/__init__.py CHANGED
@@ -1,6 +1,10 @@
1
1
  """Defines the kinfer API."""
2
2
 
3
- from . import export, inference
3
+ from . import proto as K
4
+ from .export.pytorch import export_model, get_model
5
+ from .inference.base import KModel
6
+ from .inference.python import ONNXModel
4
7
  from .rust_bindings import get_version
8
+ from .serialize import get_multi_serializer, get_serializer
5
9
 
6
10
  __version__ = get_version()
kinfer/export/__init__.py CHANGED
@@ -1 +0,0 @@
1
- from .pytorch import *
kinfer/export/pytorch.py CHANGED
@@ -10,7 +10,7 @@ import onnxruntime as ort
10
10
  import torch
11
11
  from torch import Tensor
12
12
 
13
- from kinfer import proto as P
13
+ from kinfer import proto as K
14
14
  from kinfer.serialize.pytorch import PyTorchMultiSerializer
15
15
  from kinfer.serialize.schema import get_dummy_io
16
16
  from kinfer.serialize.utils import check_names_match
@@ -18,7 +18,7 @@ from kinfer.serialize.utils import check_names_match
18
18
  KINFER_METADATA_KEY = "kinfer_metadata"
19
19
 
20
20
 
21
- def _add_metadata_to_onnx(model_proto: onnx.ModelProto, schema: P.ModelSchema) -> onnx.ModelProto:
21
+ def _add_metadata_to_onnx(model_proto: onnx.ModelProto, schema: K.ModelSchema) -> onnx.ModelProto:
22
22
  """Add metadata to ONNX model.
23
23
 
24
24
  Args:
@@ -37,7 +37,7 @@ def _add_metadata_to_onnx(model_proto: onnx.ModelProto, schema: P.ModelSchema) -
37
37
  return model_proto
38
38
 
39
39
 
40
- def export_model(model: torch.jit.ScriptModule, schema: P.ModelSchema) -> onnx.ModelProto:
40
+ def export_model(model: torch.jit.ScriptModule, schema: K.ModelSchema) -> onnx.ModelProto:
41
41
  """Export PyTorch model to ONNX format with metadata.
42
42
 
43
43
  Args:
@@ -1 +1,2 @@
1
- from .python import *
1
+ from .base import KModel
2
+ from .python import ONNXModel
@@ -0,0 +1,64 @@
1
+ """Defines the base interface for running model inference.
2
+
3
+ All kinfer models must implement this interface - the model inputs and outputs
4
+ should match the provided schema, and the `__call__` method should take the
5
+ inputs and return the outputs according to this schema.
6
+ """
7
+
8
+ import functools
9
+ from abc import ABC, abstractmethod
10
+
11
+ from kinfer import proto as K
12
+
13
+
14
+ class KModel(ABC):
15
+ """Base interface for running model inference."""
16
+
17
+ @abstractmethod
18
+ def get_schema(self) -> K.ModelSchema:
19
+ """Get the model schema."""
20
+
21
+ @abstractmethod
22
+ def __call__(self, inputs: K.IO) -> K.IO:
23
+ """Run inference on input data.
24
+
25
+ Args:
26
+ inputs: Input data, matching the input schema.
27
+
28
+ Returns:
29
+ Model outputs, matching the output schema.
30
+ """
31
+
32
+ @functools.cached_property
33
+ def schema(self) -> K.ModelSchema:
34
+ return self.get_schema()
35
+
36
+ @property
37
+ def input_schema(self) -> K.IOSchema:
38
+ """Get the input schema."""
39
+ return self.schema.input_schema
40
+
41
+ @property
42
+ def output_schema(self) -> K.IOSchema:
43
+ """Get the output schema."""
44
+ return self.schema.output_schema
45
+
46
+ @property
47
+ def schema_input_keys(self) -> list[str]:
48
+ """Get all value names from input schemas.
49
+
50
+ Returns:
51
+ List of value names from input schema.
52
+ """
53
+ input_names = [value.value_name for value in self.input_schema.values]
54
+ return input_names
55
+
56
+ @property
57
+ def schema_output_keys(self) -> list[str]:
58
+ """Get all value names from output schemas.
59
+
60
+ Returns:
61
+ List of value names from output schema.
62
+ """
63
+ output_names = [value.value_name for value in self.output_schema.values]
64
+ return output_names
@@ -6,17 +6,18 @@ from pathlib import Path
6
6
  import onnx
7
7
  import onnxruntime as ort
8
8
 
9
- from kinfer import proto as P
9
+ from kinfer import proto as K
10
10
  from kinfer.export.pytorch import KINFER_METADATA_KEY
11
+ from kinfer.inference.base import KModel
11
12
  from kinfer.serialize.numpy import NumpyMultiSerializer
12
13
 
13
14
 
14
- def _read_schema(model: onnx.ModelProto) -> P.ModelSchema:
15
+ def _read_schema(model: onnx.ModelProto) -> K.ModelSchema:
15
16
  for prop in model.metadata_props:
16
17
  if prop.key == KINFER_METADATA_KEY:
17
18
  try:
18
19
  schema_bytes = base64.b64decode(prop.value)
19
- schema = P.ModelSchema()
20
+ schema = K.ModelSchema()
20
21
  schema.ParseFromString(schema_bytes)
21
22
  return schema
22
23
  except Exception as e:
@@ -27,7 +28,7 @@ def _read_schema(model: onnx.ModelProto) -> P.ModelSchema:
27
28
  raise ValueError(f"{KINFER_METADATA_KEY} not found in model metadata")
28
29
 
29
30
 
30
- class ONNXModel:
31
+ class ONNXModel(KModel):
31
32
  """Wrapper for ONNX model inference."""
32
33
 
33
34
  def __init__(self: "ONNXModel", model_path: str | Path) -> None:
@@ -47,7 +48,10 @@ class ONNXModel:
47
48
  self._input_serializer = NumpyMultiSerializer(self._schema.input_schema)
48
49
  self._output_serializer = NumpyMultiSerializer(self._schema.output_schema)
49
50
 
50
- def __call__(self: "ONNXModel", inputs: P.IO) -> P.IO:
51
+ def get_schema(self) -> K.ModelSchema:
52
+ return self._schema
53
+
54
+ def __call__(self, inputs: K.IO) -> K.IO:
51
55
  """Run inference on input data.
52
56
 
53
57
  Args:
@@ -60,33 +64,3 @@ class ONNXModel:
60
64
  outputs_np = self.session.run(None, inputs_np)
61
65
  outputs = self._output_serializer.deserialize_io(outputs_np)
62
66
  return outputs
63
-
64
- @property
65
- def input_schema(self: "ONNXModel") -> P.IOSchema:
66
- """Get the input schema."""
67
- return self._schema.input_schema
68
-
69
- @property
70
- def output_schema(self: "ONNXModel") -> P.IOSchema:
71
- """Get the output schema."""
72
- return self._schema.output_schema
73
-
74
- @property
75
- def schema_input_keys(self: "ONNXModel") -> list[str]:
76
- """Get all value names from input schemas.
77
-
78
- Returns:
79
- List of value names from input schema.
80
- """
81
- input_names = [value.value_name for value in self._schema.input_schema.values]
82
- return input_names
83
-
84
- @property
85
- def schema_output_keys(self: "ONNXModel") -> list[str]:
86
- """Get all value names from output schemas.
87
-
88
- Returns:
89
- List of value names from output schema.
90
- """
91
- output_names = [value.value_name for value in self._schema.output_schema.values]
92
- return output_names
@@ -1,8 +1,8 @@
1
1
  """Defines an interface for instantiating serializers."""
2
2
 
3
- from typing import Literal
3
+ from typing import Literal, overload
4
4
 
5
- from kinfer import proto as P
5
+ from kinfer import proto as K
6
6
 
7
7
  from .base import MultiSerializer, Serializer
8
8
  from .json import JsonMultiSerializer, JsonSerializer
@@ -12,7 +12,19 @@ from .pytorch import PyTorchMultiSerializer, PyTorchSerializer
12
12
  SerializerType = Literal["json", "numpy", "pytorch"]
13
13
 
14
14
 
15
- def get_serializer(schema: P.ValueSchema, serializer_type: SerializerType) -> Serializer:
15
+ @overload
16
+ def get_serializer(schema: K.ValueSchema, serializer_type: Literal["json"]) -> JsonSerializer: ...
17
+
18
+
19
+ @overload
20
+ def get_serializer(schema: K.ValueSchema, serializer_type: Literal["numpy"]) -> NumpySerializer: ...
21
+
22
+
23
+ @overload
24
+ def get_serializer(schema: K.ValueSchema, serializer_type: Literal["pytorch"]) -> PyTorchSerializer: ...
25
+
26
+
27
+ def get_serializer(schema: K.ValueSchema, serializer_type: SerializerType) -> Serializer:
16
28
  match serializer_type:
17
29
  case "json":
18
30
  return JsonSerializer(schema=schema)
@@ -24,7 +36,19 @@ def get_serializer(schema: P.ValueSchema, serializer_type: SerializerType) -> Se
24
36
  raise ValueError(f"Unsupported serializer type: {serializer_type}")
25
37
 
26
38
 
27
- def get_multi_serializer(schema: P.IOSchema, serializer_type: SerializerType) -> MultiSerializer:
39
+ @overload
40
+ def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["json"]) -> JsonMultiSerializer: ...
41
+
42
+
43
+ @overload
44
+ def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["numpy"]) -> NumpyMultiSerializer: ...
45
+
46
+
47
+ @overload
48
+ def get_multi_serializer(schema: K.IOSchema, serializer_type: Literal["pytorch"]) -> PyTorchMultiSerializer: ...
49
+
50
+
51
+ def get_multi_serializer(schema: K.IOSchema, serializer_type: SerializerType) -> MultiSerializer:
28
52
  match serializer_type:
29
53
  case "json":
30
54
  return JsonMultiSerializer(schema=schema)
kinfer/serialize/base.py CHANGED
@@ -3,7 +3,7 @@
3
3
  from abc import ABC, abstractmethod
4
4
  from typing import Generic, Literal, Sequence, TypeVar, overload
5
5
 
6
- from kinfer import proto as P
6
+ from kinfer import proto as K
7
7
 
8
8
  T = TypeVar("T")
9
9
 
@@ -12,8 +12,8 @@ class JointPositionsSerializer(ABC, Generic[T]):
12
12
  @abstractmethod
13
13
  def serialize_joint_positions(
14
14
  self: "JointPositionsSerializer[T]",
15
- schema: P.JointPositionsSchema,
16
- value: P.JointPositionsValue,
15
+ schema: K.JointPositionsSchema,
16
+ value: K.JointPositionsValue,
17
17
  ) -> T:
18
18
  """Serialize a joint positions value.
19
19
 
@@ -28,9 +28,9 @@ class JointPositionsSerializer(ABC, Generic[T]):
28
28
  @abstractmethod
29
29
  def deserialize_joint_positions(
30
30
  self: "JointPositionsSerializer[T]",
31
- schema: P.JointPositionsSchema,
31
+ schema: K.JointPositionsSchema,
32
32
  value: T,
33
- ) -> P.JointPositionsValue:
33
+ ) -> K.JointPositionsValue:
34
34
  """Deserialize a joint positions value.
35
35
 
36
36
  Args:
@@ -47,8 +47,8 @@ class JointVelocitiesSerializer(ABC, Generic[T]):
47
47
  @abstractmethod
48
48
  def serialize_joint_velocities(
49
49
  self: "JointVelocitiesSerializer[T]",
50
- schema: P.JointVelocitiesSchema,
51
- value: P.JointVelocitiesValue,
50
+ schema: K.JointVelocitiesSchema,
51
+ value: K.JointVelocitiesValue,
52
52
  ) -> T:
53
53
  """Serialize a joint velocities value.
54
54
 
@@ -63,9 +63,9 @@ class JointVelocitiesSerializer(ABC, Generic[T]):
63
63
  @abstractmethod
64
64
  def deserialize_joint_velocities(
65
65
  self: "JointVelocitiesSerializer[T]",
66
- schema: P.JointVelocitiesSchema,
66
+ schema: K.JointVelocitiesSchema,
67
67
  value: T,
68
- ) -> P.JointVelocitiesValue:
68
+ ) -> K.JointVelocitiesValue:
69
69
  """Deserialize a joint velocities value.
70
70
 
71
71
  Args:
@@ -81,8 +81,8 @@ class JointTorquesSerializer(ABC, Generic[T]):
81
81
  @abstractmethod
82
82
  def serialize_joint_torques(
83
83
  self: "JointTorquesSerializer[T]",
84
- schema: P.JointTorquesSchema,
85
- value: P.JointTorquesValue,
84
+ schema: K.JointTorquesSchema,
85
+ value: K.JointTorquesValue,
86
86
  ) -> T:
87
87
  """Serialize a joint torques value.
88
88
 
@@ -97,9 +97,9 @@ class JointTorquesSerializer(ABC, Generic[T]):
97
97
  @abstractmethod
98
98
  def deserialize_joint_torques(
99
99
  self: "JointTorquesSerializer[T]",
100
- schema: P.JointTorquesSchema,
100
+ schema: K.JointTorquesSchema,
101
101
  value: T,
102
- ) -> P.JointTorquesValue:
102
+ ) -> K.JointTorquesValue:
103
103
  """Deserialize a joint torques value.
104
104
 
105
105
  Args:
@@ -115,8 +115,8 @@ class JointCommandsSerializer(ABC, Generic[T]):
115
115
  @abstractmethod
116
116
  def serialize_joint_commands(
117
117
  self: "JointCommandsSerializer[T]",
118
- schema: P.JointCommandsSchema,
119
- value: P.JointCommandsValue,
118
+ schema: K.JointCommandsSchema,
119
+ value: K.JointCommandsValue,
120
120
  ) -> T:
121
121
  """Serialize a joint commands value.
122
122
 
@@ -131,9 +131,9 @@ class JointCommandsSerializer(ABC, Generic[T]):
131
131
  @abstractmethod
132
132
  def deserialize_joint_commands(
133
133
  self: "JointCommandsSerializer[T]",
134
- schema: P.JointCommandsSchema,
134
+ schema: K.JointCommandsSchema,
135
135
  value: T,
136
- ) -> P.JointCommandsValue:
136
+ ) -> K.JointCommandsValue:
137
137
  """Deserialize a joint commands value.
138
138
 
139
139
  Args:
@@ -149,8 +149,8 @@ class CameraFrameSerializer(ABC, Generic[T]):
149
149
  @abstractmethod
150
150
  def serialize_camera_frame(
151
151
  self: "CameraFrameSerializer[T]",
152
- schema: P.CameraFrameSchema,
153
- value: P.CameraFrameValue,
152
+ schema: K.CameraFrameSchema,
153
+ value: K.CameraFrameValue,
154
154
  ) -> T:
155
155
  """Serialize a camera frame value.
156
156
 
@@ -165,9 +165,9 @@ class CameraFrameSerializer(ABC, Generic[T]):
165
165
  @abstractmethod
166
166
  def deserialize_camera_frame(
167
167
  self: "CameraFrameSerializer[T]",
168
- schema: P.CameraFrameSchema,
168
+ schema: K.CameraFrameSchema,
169
169
  value: T,
170
- ) -> P.CameraFrameValue:
170
+ ) -> K.CameraFrameValue:
171
171
  """Deserialize a camera frame value.
172
172
 
173
173
  Args:
@@ -183,8 +183,8 @@ class AudioFrameSerializer(ABC, Generic[T]):
183
183
  @abstractmethod
184
184
  def serialize_audio_frame(
185
185
  self: "AudioFrameSerializer[T]",
186
- schema: P.AudioFrameSchema,
187
- value: P.AudioFrameValue,
186
+ schema: K.AudioFrameSchema,
187
+ value: K.AudioFrameValue,
188
188
  ) -> T:
189
189
  """Serialize an audio frame value.
190
190
 
@@ -199,9 +199,9 @@ class AudioFrameSerializer(ABC, Generic[T]):
199
199
  @abstractmethod
200
200
  def deserialize_audio_frame(
201
201
  self: "AudioFrameSerializer[T]",
202
- schema: P.AudioFrameSchema,
202
+ schema: K.AudioFrameSchema,
203
203
  value: T,
204
- ) -> P.AudioFrameValue:
204
+ ) -> K.AudioFrameValue:
205
205
  """Deserialize an audio frame value.
206
206
 
207
207
  Args:
@@ -217,8 +217,8 @@ class ImuSerializer(ABC, Generic[T]):
217
217
  @abstractmethod
218
218
  def serialize_imu(
219
219
  self: "ImuSerializer[T]",
220
- schema: P.ImuSchema,
221
- value: P.ImuValue,
220
+ schema: K.ImuSchema,
221
+ value: K.ImuValue,
222
222
  ) -> T:
223
223
  """Serialize an IMU value.
224
224
 
@@ -233,9 +233,9 @@ class ImuSerializer(ABC, Generic[T]):
233
233
  @abstractmethod
234
234
  def deserialize_imu(
235
235
  self: "ImuSerializer[T]",
236
- schema: P.ImuSchema,
236
+ schema: K.ImuSchema,
237
237
  value: T,
238
- ) -> P.ImuValue:
238
+ ) -> K.ImuValue:
239
239
  """Deserialize an IMU value.
240
240
 
241
241
  Args:
@@ -251,8 +251,8 @@ class TimestampSerializer(ABC, Generic[T]):
251
251
  @abstractmethod
252
252
  def serialize_timestamp(
253
253
  self: "TimestampSerializer[T]",
254
- schema: P.TimestampSchema,
255
- value: P.TimestampValue,
254
+ schema: K.TimestampSchema,
255
+ value: K.TimestampValue,
256
256
  ) -> T:
257
257
  """Serialize a timestamp value.
258
258
 
@@ -267,9 +267,9 @@ class TimestampSerializer(ABC, Generic[T]):
267
267
  @abstractmethod
268
268
  def deserialize_timestamp(
269
269
  self: "TimestampSerializer[T]",
270
- schema: P.TimestampSchema,
270
+ schema: K.TimestampSchema,
271
271
  value: T,
272
- ) -> P.TimestampValue:
272
+ ) -> K.TimestampValue:
273
273
  """Deserialize a timestamp value.
274
274
 
275
275
  Args:
@@ -285,8 +285,8 @@ class VectorCommandSerializer(ABC, Generic[T]):
285
285
  @abstractmethod
286
286
  def serialize_vector_command(
287
287
  self: "VectorCommandSerializer[T]",
288
- schema: P.VectorCommandSchema,
289
- value: P.VectorCommandValue,
288
+ schema: K.VectorCommandSchema,
289
+ value: K.VectorCommandValue,
290
290
  ) -> T:
291
291
  """Serialize an XY command value.
292
292
 
@@ -301,9 +301,9 @@ class VectorCommandSerializer(ABC, Generic[T]):
301
301
  @abstractmethod
302
302
  def deserialize_vector_command(
303
303
  self: "VectorCommandSerializer[T]",
304
- schema: P.VectorCommandSchema,
304
+ schema: K.VectorCommandSchema,
305
305
  value: T,
306
- ) -> P.VectorCommandValue:
306
+ ) -> K.VectorCommandValue:
307
307
  """Deserialize a vector command value.
308
308
 
309
309
  Args:
@@ -319,8 +319,8 @@ class StateTensorSerializer(ABC, Generic[T]):
319
319
  @abstractmethod
320
320
  def serialize_state_tensor(
321
321
  self: "StateTensorSerializer[T]",
322
- schema: P.StateTensorSchema,
323
- value: P.StateTensorValue,
322
+ schema: K.StateTensorSchema,
323
+ value: K.StateTensorValue,
324
324
  ) -> T:
325
325
  """Serialize a state tensor value.
326
326
 
@@ -335,9 +335,9 @@ class StateTensorSerializer(ABC, Generic[T]):
335
335
  @abstractmethod
336
336
  def deserialize_state_tensor(
337
337
  self: "StateTensorSerializer[T]",
338
- schema: P.StateTensorSchema,
338
+ schema: K.StateTensorSchema,
339
339
  value: T,
340
- ) -> P.StateTensorValue:
340
+ ) -> K.StateTensorValue:
341
341
  """Deserialize a state tensor value.
342
342
 
343
343
  Args:
@@ -362,10 +362,10 @@ class Serializer(
362
362
  StateTensorSerializer[T],
363
363
  Generic[T],
364
364
  ):
365
- def __init__(self: "Serializer[T]", schema: P.ValueSchema) -> None:
365
+ def __init__(self: "Serializer[T]", schema: K.ValueSchema) -> None:
366
366
  self.schema = schema
367
367
 
368
- def serialize(self: "Serializer[T]", value: P.Value) -> T:
368
+ def serialize(self: "Serializer[T]", value: K.Value) -> T:
369
369
  value_type = value.WhichOneof("value")
370
370
 
371
371
  match value_type:
@@ -422,75 +422,75 @@ class Serializer(
422
422
  case _:
423
423
  raise ValueError(f"Unsupported value type: {value_type}")
424
424
 
425
- def deserialize(self: "Serializer[T]", value: T) -> P.Value:
425
+ def deserialize(self: "Serializer[T]", value: T) -> K.Value:
426
426
  value_type = self.schema.WhichOneof("value_type")
427
427
 
428
428
  match value_type:
429
429
  case "joint_positions":
430
- return P.Value(
430
+ return K.Value(
431
431
  joint_positions=self.deserialize_joint_positions(
432
432
  schema=self.schema.joint_positions,
433
433
  value=value,
434
434
  ),
435
435
  )
436
436
  case "joint_velocities":
437
- return P.Value(
437
+ return K.Value(
438
438
  joint_velocities=self.deserialize_joint_velocities(
439
439
  schema=self.schema.joint_velocities,
440
440
  value=value,
441
441
  ),
442
442
  )
443
443
  case "joint_torques":
444
- return P.Value(
444
+ return K.Value(
445
445
  joint_torques=self.deserialize_joint_torques(
446
446
  schema=self.schema.joint_torques,
447
447
  value=value,
448
448
  ),
449
449
  )
450
450
  case "joint_commands":
451
- return P.Value(
451
+ return K.Value(
452
452
  joint_commands=self.deserialize_joint_commands(
453
453
  schema=self.schema.joint_commands,
454
454
  value=value,
455
455
  ),
456
456
  )
457
457
  case "camera_frame":
458
- return P.Value(
458
+ return K.Value(
459
459
  camera_frame=self.deserialize_camera_frame(
460
460
  schema=self.schema.camera_frame,
461
461
  value=value,
462
462
  ),
463
463
  )
464
464
  case "audio_frame":
465
- return P.Value(
465
+ return K.Value(
466
466
  audio_frame=self.deserialize_audio_frame(
467
467
  schema=self.schema.audio_frame,
468
468
  value=value,
469
469
  ),
470
470
  )
471
471
  case "imu":
472
- return P.Value(
472
+ return K.Value(
473
473
  imu=self.deserialize_imu(
474
474
  schema=self.schema.imu,
475
475
  value=value,
476
476
  ),
477
477
  )
478
478
  case "timestamp":
479
- return P.Value(
479
+ return K.Value(
480
480
  timestamp=self.deserialize_timestamp(
481
481
  schema=self.schema.timestamp,
482
482
  value=value,
483
483
  ),
484
484
  )
485
485
  case "vector_command":
486
- return P.Value(
486
+ return K.Value(
487
487
  vector_command=self.deserialize_vector_command(
488
488
  schema=self.schema.vector_command,
489
489
  value=value,
490
490
  ),
491
491
  )
492
492
  case "state_tensor":
493
- return P.Value(
493
+ return K.Value(
494
494
  state_tensor=self.deserialize_state_tensor(
495
495
  schema=self.schema.state_tensor,
496
496
  value=value,
@@ -505,24 +505,24 @@ class MultiSerializer(Generic[T]):
505
505
  self.serializers = list(serializers)
506
506
 
507
507
  @overload
508
- def serialize_io(self: "MultiSerializer[T]", io: P.IO, *, as_dict: Literal[True]) -> dict[str, T]: ...
508
+ def serialize_io(self: "MultiSerializer[T]", io: K.IO, *, as_dict: Literal[True]) -> dict[str, T]: ...
509
509
 
510
510
  @overload
511
- def serialize_io(self: "MultiSerializer[T]", io: P.IO, *, as_dict: Literal[False] = False) -> list[T]: ...
511
+ def serialize_io(self: "MultiSerializer[T]", io: K.IO, *, as_dict: Literal[False] = False) -> list[T]: ...
512
512
 
513
- def serialize_io(self: "MultiSerializer[T]", io: P.IO, *, as_dict: bool = False) -> dict[str, T] | list[T]:
514
- if not isinstance(io, P.IO):
513
+ def serialize_io(self: "MultiSerializer[T]", io: K.IO, *, as_dict: bool = False) -> dict[str, T] | list[T]:
514
+ if not isinstance(io, K.IO):
515
515
  raise ValueError(f"Inputs must be an IO protobuf, not {type(io)}")
516
516
  if as_dict:
517
517
  return {s.schema.value_name: s.serialize(i) for s, i in zip(self.serializers, io.values)}
518
518
  return [s.serialize(i) for s, i in zip(self.serializers, io.values)]
519
519
 
520
- def deserialize_io(self: "MultiSerializer[T]", io: dict[str, T] | list[T]) -> P.IO:
520
+ def deserialize_io(self: "MultiSerializer[T]", io: dict[str, T] | list[T]) -> K.IO:
521
521
  if not isinstance(io, (dict, list)):
522
522
  raise ValueError(f"Inputs must be a dictionary or list, not {type(io)}")
523
523
  if isinstance(io, dict):
524
- return P.IO(values=[s.deserialize(i) for s, i in zip(self.serializers, io.values())])
525
- return P.IO(values=[s.deserialize(i) for s, i in zip(self.serializers, io)])
524
+ return K.IO(values=[s.deserialize(i) for s, i in zip(self.serializers, io.values())])
525
+ return K.IO(values=[s.deserialize(i) for s, i in zip(self.serializers, io)])
526
526
 
527
527
  def assign_names(self: "MultiSerializer[T]", values: Sequence[T]) -> dict[str, T]:
528
528
  if not isinstance(values, Sequence):