flwr-nightly 1.7.0.dev20240119__py3-none-any.whl → 1.7.0.dev20240123__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -87,8 +87,17 @@ class ConfigsRecord:
87
87
  # 1s to check 10M element list on a M2 Pro
88
88
  # In such settings, you'd be better off treating such config as
89
89
  # an array and passing it to a ParametersRecord.
90
- for list_value in value:
91
- is_valid(list_value)
90
+ # Empty lists are valid
91
+ if len(value) > 0:
92
+ is_valid(value[0])
93
+ # all elements in the list must be of the same valid type
94
+ # this is needed for protobuf
95
+ value_type = type(value[0])
96
+ if not all(isinstance(v, value_type) for v in value):
97
+ raise TypeError(
98
+ "All values in a list must be of the same valid type. "
99
+ f"One of {ConfigsScalar}."
100
+ )
92
101
  else:
93
102
  is_valid(value)
94
103
 
@@ -87,8 +87,17 @@ class MetricsRecord:
87
87
  # 1s to check 10M element list on a M2 Pro
88
88
  # In such settings, you'd be better off treating such metric as
89
89
  # an array and passing it to a ParametersRecord.
90
- for list_value in value:
91
- is_valid(list_value)
90
+ # Empty lists are valid
91
+ if len(value) > 0:
92
+ is_valid(value[0])
93
+ # all elements in the list must be of the same valid type
94
+ # this is needed for protobuf
95
+ value_type = type(value[0])
96
+ if not all(isinstance(v, value_type) for v in value):
97
+ raise TypeError(
98
+ "All values in a list must be of the same valid type. "
99
+ f"One of {MetricsScalar}."
100
+ )
92
101
  else:
93
102
  is_valid(value)
94
103
 
@@ -0,0 +1,401 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """RecordSet utilities."""
16
+
17
+
18
+ from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args
19
+
20
+ from .configsrecord import ConfigsRecord
21
+ from .metricsrecord import MetricsRecord
22
+ from .parametersrecord import Array, ParametersRecord
23
+ from .recordset import RecordSet
24
+ from .typing import (
25
+ Code,
26
+ ConfigsRecordValues,
27
+ EvaluateIns,
28
+ EvaluateRes,
29
+ FitIns,
30
+ FitRes,
31
+ GetParametersIns,
32
+ GetParametersRes,
33
+ GetPropertiesIns,
34
+ GetPropertiesRes,
35
+ MetricsRecordValues,
36
+ Parameters,
37
+ Scalar,
38
+ Status,
39
+ )
40
+
41
+
42
def parametersrecord_to_parameters(
    record: ParametersRecord, keep_input: bool = False
) -> Parameters:
    """Convert ParametersRecord to legacy Parameters.

    Warning: Because `Arrays` in `ParametersRecord` encode more information of the
    array-like or tensor-like data (e.g. their datatype, shape) than `Parameters`, it
    might not be possible to reconstruct such data structures from `Parameters`
    objects alone. Additional information or metadata must be provided from
    elsewhere.

    Parameters
    ----------
    record : ParametersRecord
        The record to be converted into Parameters.
    keep_input : bool (default: False)
        A boolean indicating whether entries in the record should be kept. If False,
        each entry is deleted from `record.data` immediately after its data has been
        added to the returned Parameters.

    Returns
    -------
    Parameters
        Parameters whose `tensors` hold the `data` of every Array in the record (in
        the record's key order) and whose `tensor_type` is the first non-empty
        `stype` encountered.
    """
    parameters = Parameters(tensors=[], tensor_type="")

    # Iterate over a snapshot of the keys so entries can be deleted while looping.
    for key in list(record.data.keys()):
        array = record[key]
        parameters.tensors.append(array.data)

        if not parameters.tensor_type:
            # Taken from the first array with a non-empty `stype`. Recall the warning
            # in the docstring of this function.
            parameters.tensor_type = array.stype

        if not keep_input:
            del record.data[key]

    return parameters
74
+
75
+
76
def parameters_to_parametersrecord(
    parameters: Parameters, keep_input: bool = False
) -> ParametersRecord:
    """Convert legacy Parameters into a single ParametersRecord.

    Because there is no concept of names in the legacy Parameters, arbitrary keys will
    be used when constructing the ParametersRecord (the positional index of each
    tensor, as a string). Similarly, the shape and data type won't be recorded in the
    Array objects.

    Parameters
    ----------
    parameters : Parameters
        Parameters object to be represented as a ParametersRecord.
    keep_input : bool (default: False)
        A boolean indicating whether the tensors in `parameters` should be kept. If
        False, `parameters.tensors` (i.e. a list of serialized NumPy arrays) is
        emptied once its entries have been added to the record.

    Returns
    -------
    ParametersRecord
        A record with one Array per tensor, keyed "0", "1", ..., each carrying
        `parameters.tensor_type` as its `stype`.
    """
    tensor_type = parameters.tensor_type

    # Build all Arrays first and add them with a single `set_parameters` call.
    # This avoids repeated `list.pop(0)`, which made the conversion O(n^2).
    array_dict: OrderedDict[str, Array] = OrderedDict()
    for idx, tensor in enumerate(parameters.tensors):
        array_dict[str(idx)] = Array(data=tensor, dtype="", stype=tensor_type, shape=[])

    if not keep_input:
        # The record now holds the only references handed over by this function.
        parameters.tensors.clear()

    p_record = ParametersRecord()
    p_record.set_parameters(array_dict)

    return p_record
111
+
112
+
113
def _check_mapping_from_recordscalartype_to_scalar(
    record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
) -> Dict[str, Scalar]:
    """Check mapping `common.*RecordValues` into `common.Scalar` is possible.

    Parameters
    ----------
    record_data : Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
        The `data` of a `ConfigsRecord` or a `MetricsRecord`.

    Returns
    -------
    Dict[str, Scalar]
        `record_data` itself (not a copy), typed as a `Dict[str, Scalar]`.

    Raises
    ------
    TypeError
        If any value is not an instance of one of the `common.Scalar` types.
    """
    scalar_types = get_args(Scalar)
    for value in record_data.values():
        if not isinstance(value, scalar_types):
            raise TypeError(
                "There is not a 1:1 mapping between `common.Scalar` types and those "
                "supported in `common.ConfigsRecordValues` or "
                "`common.MetricsRecordValues`. Consider casting your values to a type "
                "supported by the `common.RecordSet` infrastructure. "
                f"You used type: {type(value)}"
            )
    return cast(Dict[str, Scalar], record_data)
127
+
128
+
129
def _recordset_to_fit_or_evaluate_ins_components(
    recordset: RecordSet,
    ins_str: str,
    keep_input: bool,
) -> Tuple[Parameters, Dict[str, Scalar]]:
    """Derive Fit/Evaluate Ins from a RecordSet.

    Reads the ParametersRecord named `"{ins_str}.parameters"` and the ConfigsRecord
    named `"{ins_str}.config"`. It returns them as legacy `Parameters` and a config
    dict. When `keep_input` is False, the parameter entries are removed from their
    record as they are converted.
    """
    params_key = f"{ins_str}.parameters"
    config_key = f"{ins_str}.config"

    parameters = parametersrecord_to_parameters(
        recordset.get_parameters(params_key), keep_input=keep_input
    )
    config_dict = _check_mapping_from_recordscalartype_to_scalar(
        recordset.get_configs(config_key).data
    )
    return parameters, config_dict
148
+
149
+
150
def _fit_or_evaluate_ins_to_recordset(
    ins: Union[FitIns, EvaluateIns], keep_input: bool
) -> RecordSet:
    """Pack a FitIns or EvaluateIns into a new RecordSet.

    The prefix is `"fitins"` for a FitIns and `"evaluateins"` otherwise. The
    parameters are stored under `"{prefix}.parameters"` and the config under
    `"{prefix}.config"`.
    """
    prefix = "fitins" if isinstance(ins, FitIns) else "evaluateins"
    p_record = parameters_to_parametersrecord(ins.parameters, keep_input=keep_input)

    recordset = RecordSet()
    recordset.set_parameters(name=f"{prefix}.parameters", record=p_record)
    recordset.set_configs(
        name=f"{prefix}.config", record=ConfigsRecord(ins.config)  # type: ignore
    )
    return recordset
166
+
167
+
168
def _embed_status_into_recordset(
    res_str: str, status: Status, recordset: RecordSet
) -> RecordSet:
    """Store `status` in `recordset` under `"{res_str}.status"` and return it.

    The status is saved with the keys `"code"` (the integer value of
    `status.code`) and `"message"`. A `ConfigsRecord` is used because
    `status.message` is a string, and `str` values aren't supported in
    `MetricsRecord`s.
    """
    status_configs: Dict[str, ConfigsRecordValues] = {
        "code": int(status.code.value),
        "message": status.message,
    }
    status_record = ConfigsRecord(status_configs)
    recordset.set_configs(f"{res_str}.status", record=status_record)
    return recordset
179
+
180
+
181
def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status:
    """Rebuild the Status stored in the `"{res_str}.status"` ConfigsRecord."""
    status_record = recordset.get_configs(f"{res_str}.status")
    return Status(
        code=Code(cast(int, status_record["code"])),
        message=str(status_record["message"]),
    )
185
+
186
+
187
def recordset_to_fitins(recordset: RecordSet, keep_input: bool) -> FitIns:
    """Derive FitIns from a RecordSet object.

    Reads the `"fitins.parameters"` and `"fitins.config"` records. When `keep_input`
    is False, the parameter entries are removed from their record as they are
    converted.
    """
    parameters, config = _recordset_to_fit_or_evaluate_ins_components(
        recordset, "fitins", keep_input
    )
    return FitIns(parameters=parameters, config=config)
196
+
197
+
198
def fitins_to_recordset(fitins: FitIns, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from a FitIns object.

    The parameters are stored as `"fitins.parameters"` and the config as
    `"fitins.config"`.
    """
    recordset = _fit_or_evaluate_ins_to_recordset(ins=fitins, keep_input=keep_input)
    return recordset
201
+
202
+
203
def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes:
    """Derive FitRes from a RecordSet object.

    Reads the following records:

    - `"fitres.parameters"` (ParametersRecord)
    - `"fitres.num_examples"` (MetricsRecord)
    - `"fitres.metrics"` (ConfigsRecord)
    - `"fitres.status"` (ConfigsRecord)
    """
    res_str = "fitres"

    parameters_record = recordset.get_parameters(f"{res_str}.parameters")
    parameters = parametersrecord_to_parameters(parameters_record, keep_input=keep_input)

    num_examples_record = recordset.get_metrics(f"{res_str}.num_examples")
    num_examples = cast(int, num_examples_record["num_examples"])

    metrics = _check_mapping_from_recordscalartype_to_scalar(
        recordset.get_configs(f"{res_str}.metrics").data
    )
    status = _extract_status_from_recordset(res_str, recordset)

    return FitRes(
        status=status,
        parameters=parameters,
        num_examples=num_examples,
        metrics=metrics,
    )
221
+
222
+
223
def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from a FitRes object.

    The fields are written to the following records:

    - `"fitres.metrics"` (ConfigsRecord)
    - `"fitres.num_examples"` (MetricsRecord)
    - `"fitres.parameters"` (ParametersRecord)
    - `"fitres.status"` (ConfigsRecord)
    """
    prefix = "fitres"
    recordset = RecordSet()

    # NOTE(review): metrics go into a ConfigsRecord, presumably because legacy
    # metrics may hold value types (e.g. `str`) that a MetricsRecord rejects.
    metrics_as_configs = ConfigsRecord(fitres.metrics)  # type: ignore
    recordset.set_configs(name=f"{prefix}.metrics", record=metrics_as_configs)

    num_examples_record = MetricsRecord({"num_examples": fitres.num_examples})
    recordset.set_metrics(name=f"{prefix}.num_examples", record=num_examples_record)

    parameters_record = parameters_to_parametersrecord(fitres.parameters, keep_input)
    recordset.set_parameters(name=f"{prefix}.parameters", record=parameters_record)

    return _embed_status_into_recordset(prefix, fitres.status, recordset)
245
+
246
+
247
def recordset_to_evaluateins(recordset: RecordSet, keep_input: bool) -> EvaluateIns:
    """Derive EvaluateIns from a RecordSet object.

    Reads the `"evaluateins.parameters"` and `"evaluateins.config"` records. When
    `keep_input` is False, the parameter entries are removed from their record as
    they are converted.
    """
    parameters, config = _recordset_to_fit_or_evaluate_ins_components(
        recordset, "evaluateins", keep_input
    )
    return EvaluateIns(parameters=parameters, config=config)
256
+
257
+
258
def evaluateins_to_recordset(evaluateins: EvaluateIns, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from an EvaluateIns object.

    The parameters are stored as `"evaluateins.parameters"` and the config as
    `"evaluateins.config"`.
    """
    recordset = _fit_or_evaluate_ins_to_recordset(
        ins=evaluateins, keep_input=keep_input
    )
    return recordset
261
+
262
+
263
def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes:
    """Derive EvaluateRes from a RecordSet object.

    Reads the following records:

    - `"evaluateres.loss"` and `"evaluateres.num_examples"` (MetricsRecords)
    - `"evaluateres.metrics"` (ConfigsRecord)
    - `"evaluateres.status"` (ConfigsRecord)
    """
    res_str = "evaluateres"

    # `EvaluateRes.loss` is typed as a float; casting it to `int` misinformed
    # static type checkers. `cast` performs no runtime conversion.
    loss = cast(float, recordset.get_metrics(f"{res_str}.loss")["loss"])

    num_examples = cast(
        int, recordset.get_metrics(f"{res_str}.num_examples")["num_examples"]
    )
    configs_record = recordset.get_configs(f"{res_str}.metrics")

    metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data)
    status = _extract_status_from_recordset(res_str, recordset)

    return EvaluateRes(
        status=status, loss=loss, num_examples=num_examples, metrics=metrics
    )
280
+
281
+
282
def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet:
    """Construct a RecordSet from an EvaluateRes object.

    The fields are written to the following records:

    - `"evaluateres.loss"` (MetricsRecord)
    - `"evaluateres.num_examples"` (MetricsRecord)
    - `"evaluateres.metrics"` (ConfigsRecord)
    - `"evaluateres.status"` (ConfigsRecord)
    """
    prefix = "evaluateres"
    recordset = RecordSet()

    loss_record = MetricsRecord({"loss": evaluateres.loss})
    recordset.set_metrics(name=f"{prefix}.loss", record=loss_record)

    num_examples_record = MetricsRecord({"num_examples": evaluateres.num_examples})
    recordset.set_metrics(name=f"{prefix}.num_examples", record=num_examples_record)

    metrics_as_configs = ConfigsRecord(evaluateres.metrics)  # type: ignore
    recordset.set_configs(name=f"{prefix}.metrics", record=metrics_as_configs)

    return _embed_status_into_recordset(prefix, evaluateres.status, recordset)
311
+
312
+
313
def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns:
    """Derive GetParametersIns from a RecordSet object.

    The config is read from the `"getparametersins.config"` ConfigsRecord.
    """
    config_data = recordset.get_configs("getparametersins.config").data
    return GetParametersIns(
        config=_check_mapping_from_recordscalartype_to_scalar(config_data)
    )
320
+
321
+
322
def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> RecordSet:
    """Construct a RecordSet from a GetParametersIns object.

    The config is stored as the `"getparametersins.config"` ConfigsRecord.
    """
    config_record = ConfigsRecord(getparameters_ins.config)  # type: ignore
    recordset = RecordSet()
    recordset.set_configs(name="getparametersins.config", record=config_record)
    return recordset
331
+
332
+
333
def getparametersres_to_recordset(
    getparametersres: GetParametersRes, keep_input: bool = False
) -> RecordSet:
    """Construct a RecordSet from a GetParametersRes object.

    Parameters
    ----------
    getparametersres : GetParametersRes
        The result whose parameters are stored as `"getparametersres.parameters"`
        and whose status is stored as `"getparametersres.status"`.
    keep_input : bool (default: False)
        Whether the tensors in `getparametersres.parameters` should be kept. With
        the default (False), they are removed from the input as they are moved into
        the record, which is what calling this function without the argument does.

    Returns
    -------
    RecordSet
        A new RecordSet holding the parameters and status.
    """
    recordset = RecordSet()
    res_str = "getparametersres"
    parameters_record = parameters_to_parametersrecord(
        getparametersres.parameters, keep_input=keep_input
    )
    recordset.set_parameters(f"{res_str}.parameters", parameters_record)

    # status
    recordset = _embed_status_into_recordset(
        res_str, getparametersres.status, recordset
    )

    return recordset
346
+
347
+
348
def recordset_to_getparametersres(
    recordset: RecordSet, keep_input: bool = False
) -> GetParametersRes:
    """Derive GetParametersRes from a RecordSet object.

    Parameters
    ----------
    recordset : RecordSet
        The RecordSet from which the `"getparametersres.parameters"`
        ParametersRecord and the `"getparametersres.status"` ConfigsRecord are read.
    keep_input : bool (default: False)
        Whether the entries of the `"getparametersres.parameters"` record should be
        kept. With the default (False), they are deleted from the record as they
        are converted, which is what calling this function without the argument
        does.

    Returns
    -------
    GetParametersRes
        The reconstructed result.
    """
    res_str = "getparametersres"
    parameters = parametersrecord_to_parameters(
        recordset.get_parameters(f"{res_str}.parameters"), keep_input=keep_input
    )

    status = _extract_status_from_recordset(res_str, recordset)
    return GetParametersRes(status=status, parameters=parameters)
357
+
358
+
359
def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns:
    """Derive GetPropertiesIns from a RecordSet object.

    The config is read from the `"getpropertiesins.config"` ConfigsRecord.
    """
    config_data = recordset.get_configs("getpropertiesins.config").data
    return GetPropertiesIns(
        config=_check_mapping_from_recordscalartype_to_scalar(config_data)
    )
365
+
366
+
367
def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet:
    """Construct a RecordSet from a GetPropertiesIns object.

    The config is stored as the `"getpropertiesins.config"` ConfigsRecord.

    Parameters
    ----------
    getpropertiesins : GetPropertiesIns
        The instructions whose `config` is stored in the RecordSet.

    Returns
    -------
    RecordSet
        A new RecordSet holding the config.
    """
    recordset = RecordSet()
    recordset.set_configs(
        name="getpropertiesins.config",
        record=ConfigsRecord(getpropertiesins.config),  # type: ignore
    )
    return recordset
375
+
376
+
377
def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes:
    """Derive GetPropertiesRes from a RecordSet object.

    The properties are read from the `"getpropertiesres.properties"` ConfigsRecord
    and the status from `"getpropertiesres.status"`.
    """
    prefix = "getpropertiesres"
    properties_data = recordset.get_configs(f"{prefix}.properties").data
    properties = _check_mapping_from_recordscalartype_to_scalar(properties_data)
    return GetPropertiesRes(
        status=_extract_status_from_recordset(prefix, recordset=recordset),
        properties=properties,
    )
386
+
387
+
388
def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordSet:
    """Construct a RecordSet from a GetPropertiesRes object.

    The properties are stored as the `"getpropertiesres.properties"` ConfigsRecord
    and the status as `"getpropertiesres.status"`.
    """
    prefix = "getpropertiesres"
    properties_record = ConfigsRecord(getpropertiesres.properties)  # type: ignore

    recordset = RecordSet()
    recordset.set_configs(name=f"{prefix}.properties", record=properties_record)
    return _embed_status_into_recordset(prefix, getpropertiesres.status, recordset)
flwr/common/serde.py CHANGED
@@ -15,10 +15,23 @@
15
15
  """ProtoBuf serialization and deserialization."""
16
16
 
17
17
 
18
- from typing import Any, Dict, List, MutableMapping, cast
19
-
20
- from flwr.proto.task_pb2 import Value # pylint: disable=E0611
21
- from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
18
+ from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast
19
+
20
+ from google.protobuf.message import Message
21
+
22
+ # pylint: disable=E0611
23
+ from flwr.proto.recordset_pb2 import Array as ProtoArray
24
+ from flwr.proto.recordset_pb2 import BoolList, BytesList
25
+ from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
26
+ from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
27
+ from flwr.proto.recordset_pb2 import DoubleList
28
+ from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
29
+ from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
30
+ from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
31
+ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
32
+ from flwr.proto.recordset_pb2 import Sint64List, StringList
33
+ from flwr.proto.task_pb2 import Value
34
+ from flwr.proto.transport_pb2 import (
22
35
  ClientMessage,
23
36
  Code,
24
37
  Parameters,
@@ -28,7 +41,12 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
28
41
  Status,
29
42
  )
30
43
 
44
+ # pylint: enable=E0611
31
45
  from . import typing
46
+ from .configsrecord import ConfigsRecord
47
+ from .metricsrecord import MetricsRecord
48
+ from .parametersrecord import Array, ParametersRecord
49
+ from .recordset import RecordSet
32
50
 
33
51
  # === ServerMessage message ===
34
52
 
@@ -493,7 +511,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
493
511
  # === Value messages ===
494
512
 
495
513
 
496
- _python_type_to_field_name = {
514
+ _type_to_field = {
497
515
  float: "double",
498
516
  int: "sint64",
499
517
  bool: "bool",
@@ -502,22 +520,20 @@ _python_type_to_field_name = {
502
520
  }
503
521
 
504
522
 
505
- _python_list_type_to_message_and_field_name = {
506
- float: (Value.DoubleList, "double_list"),
507
- int: (Value.Sint64List, "sint64_list"),
508
- bool: (Value.BoolList, "bool_list"),
509
- str: (Value.StringList, "string_list"),
510
- bytes: (Value.BytesList, "bytes_list"),
523
# Python element type -> (ProtoBuf list message class, name of the oneof field that
# carries a list of that element type). Consulted when serializing list values.
_list_type_to_class_and_field = {
    float: (DoubleList, "double_list"),
    int: (Sint64List, "sint64_list"),
    bool: (BoolList, "bool_list"),
    str: (StringList, "string_list"),
    bytes: (BytesList, "bytes_list"),
}
512
530
 
513
531
 
514
532
  def _check_value(value: typing.Value) -> None:
515
- if isinstance(value, tuple(_python_type_to_field_name.keys())):
533
+ if isinstance(value, tuple(_type_to_field.keys())):
516
534
  return
517
535
  if isinstance(value, list):
518
- if len(value) > 0 and isinstance(
519
- value[0], tuple(_python_type_to_field_name.keys())
520
- ):
536
+ if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())):
521
537
  data_type = type(value[0])
522
538
  for element in value:
523
539
  if isinstance(element, data_type):
@@ -539,12 +555,12 @@ def value_to_proto(value: typing.Value) -> Value:
539
555
 
540
556
  arg = {}
541
557
  if isinstance(value, list):
542
- msg_class, field_name = _python_list_type_to_message_and_field_name[
558
+ msg_class, field_name = _list_type_to_class_and_field[
543
559
  type(value[0]) if len(value) > 0 else int
544
560
  ]
545
561
  arg[field_name] = msg_class(vals=value)
546
562
  else:
547
- arg[_python_type_to_field_name[type(value)]] = value
563
+ arg[_type_to_field[type(value)]] = value
548
564
  return Value(**arg)
549
565
 
550
566
 
@@ -573,3 +589,165 @@ def named_values_from_proto(
573
589
  ) -> Dict[str, typing.Value]:
574
590
  """Deserialize named values from ProtoBuf."""
575
591
  return {name: value_from_proto(value) for name, value in named_values_proto.items()}
592
+
593
+
594
+ # === Record messages ===
595
+
596
+
597
T = TypeVar("T")


def _record_value_to_proto(
    value: Any, allowed_types: List[type], proto_class: Type[T]
) -> T:
    """Serialize `*RecordValue` to ProtoBuf.

    Parameters
    ----------
    value : Any
        A single value or a list of values taken from a record.
    allowed_types : List[type]
        The Python types accepted by the target record, tried in order.
    proto_class : Type[T]
        The ProtoBuf `*RecordValue` message class to instantiate.

    Returns
    -------
    T
        An instance of `proto_class` with the field matching `value`'s type set.

    Raises
    ------
    TypeError
        If `value` does not match any of `allowed_types`.
    """
    # Single element. An exact type check is used because
    # `isinstance(False, int) == True`, which would encode a `bool` as an `int`.
    for t in allowed_types:
        if type(value) == t:  # pylint: disable=C0123
            return proto_class(**{_type_to_field[t]: value})

    # List
    if isinstance(value, list):
        # The same `bool`/`int` overlap applies to list elements. When `int` comes
        # before `bool` in `allowed_types`, a list of bools would be encoded as a
        # `sint64_list` and deserialize as 0/1 integers. So a non-empty, all-bool
        # list is routed to the `bool` list type whenever `bool` is allowed.
        all_bools = bool(value) and all(isinstance(item, bool) for item in value)
        list_types: List[type] = (
            [bool] if all_bools and bool in allowed_types else allowed_types
        )
        for t in list_types:
            if all(isinstance(item, t) for item in value):
                list_class, field_name = _list_type_to_class_and_field[t]
                return proto_class(**{field_name: list_class(vals=value)})

    # Invalid types
    raise TypeError(
        f"The type of the following value is not allowed "
        f"in '{proto_class.__name__}':\n{value}"
    )
621
+
622
+
623
def _record_value_from_proto(value_proto: Message) -> Any:
    """Deserialize `*RecordValue` from ProtoBuf.

    Returns whichever field of the `value` oneof is set. If the field's name ends
    with `"list"`, its `vals` are returned as a Python `list`; otherwise the field
    value is returned as-is.
    """
    field_name = cast(str, value_proto.WhichOneof("value"))
    is_list = field_name.endswith("list")
    field_value = getattr(value_proto, field_name)
    return list(field_value.vals) if is_list else field_value
631
+
632
+
633
def _record_value_dict_to_proto(
    value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T]
) -> Dict[str, T]:
    """Serialize the record value dict to ProtoBuf.

    Keys are kept unchanged. Every value is converted with `_record_value_to_proto`,
    using the same `allowed_types` and `value_proto_class`.
    """
    serialized: Dict[str, T] = {}
    for key, value in value_dict.items():
        serialized[key] = _record_value_to_proto(
            value, allowed_types, value_proto_class
        )
    return serialized
642
+
643
+
644
def _record_value_dict_from_proto(
    value_dict_proto: MutableMapping[str, Any]
) -> Dict[str, Any]:
    """Deserialize the record value dict from ProtoBuf.

    Keys are kept unchanged; each value is converted with `_record_value_from_proto`.
    """
    deserialized: Dict[str, Any] = {}
    for key, value_proto in value_dict_proto.items():
        deserialized[key] = _record_value_from_proto(value_proto)
    return deserialized
649
+
650
+
651
def array_to_proto(array: Array) -> ProtoArray:
    """Serialize Array to ProtoBuf.

    The four fields `dtype`, `shape`, `stype` and `data` are passed explicitly,
    mirroring `array_from_proto`. Using `**vars(array)` would tie serialization to
    every instance attribute of `Array`. It would then break if `Array` ever gained
    a non-proto attribute or used `__slots__`.
    """
    return ProtoArray(
        dtype=array.dtype,
        shape=array.shape,
        stype=array.stype,
        data=array.data,
    )
654
+
655
+
656
def array_from_proto(array_proto: ProtoArray) -> Array:
    """Deserialize Array from ProtoBuf.

    The repeated `shape` field is copied into a Python `list`. `dtype`, `stype` and
    `data` are taken over unchanged.
    """
    shape = list(array_proto.shape)
    return Array(
        data=array_proto.data,
        stype=array_proto.stype,
        dtype=array_proto.dtype,
        shape=shape,
    )
664
+
665
+
666
def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord:
    """Serialize ParametersRecord to ProtoBuf.

    Keys and Arrays are written to the parallel repeated fields `data_keys` and
    `data_values`, in the record's iteration order.
    """
    keys = list(record.data.keys())
    values = [array_to_proto(array) for array in record.data.values()]
    return ProtoParametersRecord(data_keys=keys, data_values=values)
672
+
673
+
674
def parameters_record_from_proto(
    record_proto: ProtoParametersRecord,
) -> ParametersRecord:
    """Deserialize ParametersRecord from ProtoBuf.

    Pairs `data_keys` with the deserialized `data_values` into an OrderedDict. That
    dict is handed to the new record with `keep_input=False`.
    """
    arrays = [array_from_proto(array_proto) for array_proto in record_proto.data_values]
    array_dict = OrderedDict(zip(record_proto.data_keys, arrays))
    return ParametersRecord(array_dict=array_dict, keep_input=False)
684
+
685
+
686
def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord:
    """Serialize MetricsRecord to ProtoBuf.

    Values are serialized with `float` and `int` as the allowed types, in that
    order.
    """
    allowed_types: List[type] = [float, int]
    data = _record_value_dict_to_proto(
        record.data, allowed_types, ProtoMetricsRecordValue
    )
    return ProtoMetricsRecord(data=data)
693
+
694
+
695
def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord:
    """Deserialize MetricsRecord from ProtoBuf.

    The decoded dict is handed to the new record with `keep_input=False`.
    """
    metrics_dict = cast(
        Dict[str, typing.MetricsRecordValues],
        _record_value_dict_from_proto(record_proto.data),
    )
    return MetricsRecord(metrics_dict=metrics_dict, keep_input=False)
704
+
705
+
706
def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord:
    """Serialize ConfigsRecord to ProtoBuf.

    Values are serialized with `int`, `float`, `bool`, `str` and `bytes` as the
    allowed types, in that order.
    """
    allowed_types: List[type] = [int, float, bool, str, bytes]
    data = _record_value_dict_to_proto(
        record.data, allowed_types, ProtoConfigsRecordValue
    )
    return ProtoConfigsRecord(data=data)
713
+
714
+
715
def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord:
    """Deserialize ConfigsRecord from ProtoBuf.

    The decoded dict is handed to the new record with `keep_input=False`.
    """
    configs_dict = cast(
        Dict[str, typing.ConfigsRecordValues],
        _record_value_dict_from_proto(record_proto.data),
    )
    return ConfigsRecord(configs_dict=configs_dict, keep_input=False)
724
+
725
+
726
+ # === RecordSet message ===
727
+
728
+
729
def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
    """Serialize RecordSet to ProtoBuf.

    Each named parameters, metrics and configs record is serialized with its
    dedicated converter. The record names are kept as the map keys.
    """
    parameters_protos = {
        name: parameters_record_to_proto(record)
        for name, record in recordset.parameters.items()
    }
    metrics_protos = {
        name: metrics_record_to_proto(record)
        for name, record in recordset.metrics.items()
    }
    configs_protos = {
        name: configs_record_to_proto(record)
        for name, record in recordset.configs.items()
    }
    return ProtoRecordSet(
        parameters=parameters_protos,
        metrics=metrics_protos,
        configs=configs_protos,
    )
738
+
739
+
740
def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
    """Deserialize RecordSet from ProtoBuf.

    Every entry of the `parameters`, `metrics` and `configs` maps is deserialized
    with its dedicated converter, keeping the map keys as record names.
    """
    parameters = {
        name: parameters_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.parameters.items()
    }
    metrics = {
        name: metrics_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.metrics.items()
    }
    configs = {
        name: configs_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.configs.items()
    }
    return RecordSet(parameters=parameters, metrics=metrics, configs=configs)