flwr-nightly 1.7.0.dev20240119__py3-none-any.whl → 1.7.0.dev20240123__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -87,8 +87,17 @@ class ConfigsRecord:
87
87
  # 1s to check 10M element list on a M2 Pro
88
88
  # In such settings, you'd be better of treating such config as
89
89
  # an array and pass it to a ParametersRecord.
90
- for list_value in value:
91
- is_valid(list_value)
90
+ # Empty lists are valid
91
+ if len(value) > 0:
92
+ is_valid(value[0])
93
+ # all elements in the list must be of the same valid type
94
+ # this is needed for protobuf
95
+ value_type = type(value[0])
96
+ if not all(isinstance(v, value_type) for v in value):
97
+ raise TypeError(
98
+ "All values in a list must be of the same valid type. "
99
+ f"One of {ConfigsScalar}."
100
+ )
92
101
  else:
93
102
  is_valid(value)
94
103
 
@@ -87,8 +87,17 @@ class MetricsRecord:
87
87
  # 1s to check 10M element list on a M2 Pro
88
88
  # In such settings, you'd be better of treating such metric as
89
89
  # an array and pass it to a ParametersRecord.
90
- for list_value in value:
91
- is_valid(list_value)
90
+ # Empty lists are valid
91
+ if len(value) > 0:
92
+ is_valid(value[0])
93
+ # all elements in the list must be of the same valid type
94
+ # this is needed for protobuf
95
+ value_type = type(value[0])
96
+ if not all(isinstance(v, value_type) for v in value):
97
+ raise TypeError(
98
+ "All values in a list must be of the same valid type. "
99
+ f"One of {MetricsScalar}."
100
+ )
92
101
  else:
93
102
  is_valid(value)
94
103
 
@@ -0,0 +1,401 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """RecordSet utilities."""
16
+
17
+
18
+ from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args
19
+
20
+ from .configsrecord import ConfigsRecord
21
+ from .metricsrecord import MetricsRecord
22
+ from .parametersrecord import Array, ParametersRecord
23
+ from .recordset import RecordSet
24
+ from .typing import (
25
+ Code,
26
+ ConfigsRecordValues,
27
+ EvaluateIns,
28
+ EvaluateRes,
29
+ FitIns,
30
+ FitRes,
31
+ GetParametersIns,
32
+ GetParametersRes,
33
+ GetPropertiesIns,
34
+ GetPropertiesRes,
35
+ MetricsRecordValues,
36
+ Parameters,
37
+ Scalar,
38
+ Status,
39
+ )
40
+
41
+
42
def parametersrecord_to_parameters(
    record: ParametersRecord, keep_input: bool = False
) -> Parameters:
    """Convert ParametersRecord to legacy Parameters.

    Warning: Because `Arrays` in `ParametersRecord` encode more information of the
    array-like or tensor-like data (e.g. their datatype, shape) than `Parameters`,
    it might not be possible to reconstruct such data structures from `Parameters`
    objects alone. Additional information or metadata must be provided from
    elsewhere.

    Parameters
    ----------
    record : ParametersRecord
        The record to be converted into Parameters.
    keep_input : bool (default: False)
        A boolean indicating whether entries in the record should be kept. If
        False, each entry is deleted from the input record immediately after its
        data has been appended to the returned Parameters.

    Returns
    -------
    parameters : Parameters
        The `data` of every Array in the record, in the record's iteration order.
        `tensor_type` is taken from the first Array with a non-empty `stype`.
    """
    parameters = Parameters(tensors=[], tensor_type="")

    # Iterate over a snapshot of the keys because entries may be deleted below.
    for key in list(record.data.keys()):
        array = record[key]
        parameters.tensors.append(array.data)

        if not parameters.tensor_type:
            # Setting from first array in record. Recall the warning in the
            # docstring of this function.
            parameters.tensor_type = array.stype

        if not keep_input:
            del record.data[key]

    return parameters
74
+
75
+
76
def parameters_to_parametersrecord(
    parameters: Parameters, keep_input: bool = False
) -> ParametersRecord:
    """Convert legacy Parameters into a single ParametersRecord.

    Because there is no concept of names in the legacy Parameters, the position of
    each tensor (`"0"`, `"1"`, ...) is used as its key in the ParametersRecord.
    Similarly, the shape and data type won't be recorded in the Array objects
    (`dtype=""`, `shape=[]`); every Array's `stype` is `parameters.tensor_type`.

    Parameters
    ----------
    parameters : Parameters
        Parameters object to be represented as a ParametersRecord.
    keep_input : bool (default: False)
        A boolean indicating whether the tensors should be kept in the input
        Parameters object. If False, `parameters.tensors` (i.e. the list of
        serialized arrays) is emptied once all tensors have been added to the
        record.

    Returns
    -------
    p_record : ParametersRecord
        A record holding one Array per tensor, in the original order.
    """
    tensor_type = parameters.tensor_type

    p_record = ParametersRecord()

    for idx, tensor in enumerate(parameters.tensors):
        p_record.set_parameters(
            OrderedDict(
                {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])}
            )
        )

    if not keep_input:
        # Empty the input in a single O(n) step; repeatedly calling `pop(0)` on a
        # list is O(n) per call and therefore O(n^2) overall.
        parameters.tensors.clear()

    return p_record
111
+
112
+
113
def _check_mapping_from_recordscalartype_to_scalar(
    record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
) -> Dict[str, Scalar]:
    """Check mapping `common.*RecordValues` into `common.Scalar` is possible.

    Parameters
    ----------
    record_data : Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
        The `data` of a ConfigsRecord or MetricsRecord.

    Returns
    -------
    Dict[str, Scalar]
        `record_data` itself (not a copy), cast to `Dict[str, Scalar]`.

    Raises
    ------
    TypeError
        If any value is not an instance of one of the types in `common.Scalar`.
    """
    # Hoisted out of the loop: the set of allowed types is the same for all values.
    scalar_types = get_args(Scalar)
    for value in record_data.values():
        if not isinstance(value, scalar_types):
            raise TypeError(
                "There is not a 1:1 mapping between `common.Scalar` types and those "
                "supported in `common.ConfigsRecordValues` or "
                "`common.MetricsRecordValues`. Consider casting your values to a type "
                "supported by the `common.RecordSet` infrastructure. "
                f"You used type: {type(value)}"
            )
    return cast(Dict[str, Scalar], record_data)
127
+
128
+
129
def _recordset_to_fit_or_evaluate_ins_components(
    recordset: RecordSet,
    ins_str: str,
    keep_input: bool,
) -> Tuple[Parameters, Dict[str, Scalar]]:
    """Extract the Parameters and config dict shared by FitIns and EvaluateIns.

    The ParametersRecord is read from `"<ins_str>.parameters"` and the
    ConfigsRecord from `"<ins_str>.config"`.
    """
    params = parametersrecord_to_parameters(
        recordset.get_parameters(f"{ins_str}.parameters"), keep_input=keep_input
    )
    config = _check_mapping_from_recordscalartype_to_scalar(
        recordset.get_configs(f"{ins_str}.config").data
    )
    return params, config
148
+
149
+
150
def _fit_or_evaluate_ins_to_recordset(
    ins: Union[FitIns, EvaluateIns], keep_input: bool
) -> RecordSet:
    """Build a RecordSet holding the parameters and config of FitIns/EvaluateIns.

    Records are named with the `"fitins"` prefix for FitIns and the
    `"evaluateins"` prefix otherwise (`.parameters` and `.config`).
    """
    prefix = "fitins" if isinstance(ins, FitIns) else "evaluateins"
    parameters_record = parameters_to_parametersrecord(
        ins.parameters, keep_input=keep_input
    )

    recordset = RecordSet()
    recordset.set_parameters(name=f"{prefix}.parameters", record=parameters_record)
    recordset.set_configs(
        name=f"{prefix}.config", record=ConfigsRecord(ins.config)  # type: ignore
    )
    return recordset
166
+
167
+
168
def _embed_status_into_recordset(
    res_str: str, status: Status, recordset: RecordSet
) -> RecordSet:
    """Store `status` in `recordset` as the ConfigsRecord `"<res_str>.status"`.

    A `ConfigsRecord` is used because `status.message` is a string and `str`
    values aren't supported in `MetricsRecord`s. The same (mutated) `recordset`
    is returned.
    """
    fields: Dict[str, ConfigsRecordValues] = {
        "code": int(status.code.value),
        "message": status.message,
    }
    recordset.set_configs(f"{res_str}.status", record=ConfigsRecord(fields))
    return recordset
179
+
180
+
181
def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status:
    """Rebuild a Status from the ConfigsRecord named `"<res_str>.status"`."""
    status_record = recordset.get_configs(f"{res_str}.status")
    status_code = Code(cast(int, status_record["code"]))
    return Status(code=status_code, message=str(status_record["message"]))
185
+
186
+
187
def recordset_to_fitins(recordset: RecordSet, keep_input: bool) -> FitIns:
    """Derive FitIns from a RecordSet object.

    Reads the `"fitins.parameters"` and `"fitins.config"` records. If `keep_input`
    is False, arrays are deleted from the ParametersRecord as they are converted.
    """
    params, cfg = _recordset_to_fit_or_evaluate_ins_components(
        recordset, ins_str="fitins", keep_input=keep_input
    )
    return FitIns(parameters=params, config=cfg)
196
+
197
+
198
def fitins_to_recordset(fitins: FitIns, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from a FitIns object.

    The parameters and config are stored under the `"fitins"` prefix.
    """
    recordset = _fit_or_evaluate_ins_to_recordset(fitins, keep_input)
    return recordset
201
+
202
+
203
def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes:
    """Derive FitRes from a RecordSet object.

    Reads the records named with the `"fitres"` prefix: `.parameters`,
    `.num_examples`, `.metrics` and `.status`.
    """
    prefix = "fitres"
    parameters_record = recordset.get_parameters(f"{prefix}.parameters")
    parameters = parametersrecord_to_parameters(
        parameters_record, keep_input=keep_input
    )

    num_examples_record = recordset.get_metrics(f"{prefix}.num_examples")
    num_examples = cast(int, num_examples_record["num_examples"])

    metrics = _check_mapping_from_recordscalartype_to_scalar(
        recordset.get_configs(f"{prefix}.metrics").data
    )
    status = _extract_status_from_recordset(prefix, recordset)

    return FitRes(
        status=status, parameters=parameters, num_examples=num_examples, metrics=metrics
    )
221
+
222
+
223
def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from a FitRes object.

    Records are named with the `"fitres"` prefix: `.metrics` (ConfigsRecord),
    `.num_examples` (MetricsRecord), `.parameters` (ParametersRecord) and
    `.status` (ConfigsRecord). If `keep_input` is False, the tensors are removed
    from `fitres.parameters` during conversion.
    """
    prefix = "fitres"
    recordset = RecordSet()

    metrics_record = ConfigsRecord(fitres.metrics)  # type: ignore
    recordset.set_configs(name=f"{prefix}.metrics", record=metrics_record)

    num_examples_record = MetricsRecord({"num_examples": fitres.num_examples})
    recordset.set_metrics(name=f"{prefix}.num_examples", record=num_examples_record)

    parameters_record = parameters_to_parametersrecord(fitres.parameters, keep_input)
    recordset.set_parameters(name=f"{prefix}.parameters", record=parameters_record)

    return _embed_status_into_recordset(prefix, fitres.status, recordset)
245
+
246
+
247
def recordset_to_evaluateins(recordset: RecordSet, keep_input: bool) -> EvaluateIns:
    """Derive EvaluateIns from a RecordSet object.

    Reads the `"evaluateins.parameters"` and `"evaluateins.config"` records. If
    `keep_input` is False, arrays are deleted from the ParametersRecord as they are
    converted.
    """
    params, cfg = _recordset_to_fit_or_evaluate_ins_components(
        recordset, ins_str="evaluateins", keep_input=keep_input
    )
    return EvaluateIns(parameters=params, config=cfg)
256
+
257
+
258
def evaluateins_to_recordset(evaluateins: EvaluateIns, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from an EvaluateIns object.

    The parameters and config are stored under the `"evaluateins"` prefix.
    """
    recordset = _fit_or_evaluate_ins_to_recordset(evaluateins, keep_input)
    return recordset
261
+
262
+
263
def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes:
    """Derive EvaluateRes from a RecordSet object.

    Reads the records named with the `"evaluateres"` prefix: `.loss` and
    `.num_examples` (MetricsRecords), `.metrics` and `.status` (ConfigsRecords).
    """
    res_str = "evaluateres"

    # `cast` has no runtime effect; it only informs type checkers. The stored loss
    # need not be an integer (MetricsRecord values may be floats), so cast to
    # `float` rather than `int`.
    loss = cast(float, recordset.get_metrics(f"{res_str}.loss")["loss"])

    num_examples = cast(
        int, recordset.get_metrics(f"{res_str}.num_examples")["num_examples"]
    )
    configs_record = recordset.get_configs(f"{res_str}.metrics")

    metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data)
    status = _extract_status_from_recordset(res_str, recordset)

    return EvaluateRes(
        status=status, loss=loss, num_examples=num_examples, metrics=metrics
    )
280
+
281
+
282
def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet:
    """Construct a RecordSet from an EvaluateRes object.

    Records are named with the `"evaluateres"` prefix: `.loss` and
    `.num_examples` (single-entry MetricsRecords), `.metrics` and `.status`
    (ConfigsRecords).
    """
    prefix = "evaluateres"
    recordset = RecordSet()

    loss_record = MetricsRecord({"loss": evaluateres.loss})
    recordset.set_metrics(name=f"{prefix}.loss", record=loss_record)

    num_examples_record = MetricsRecord({"num_examples": evaluateres.num_examples})
    recordset.set_metrics(name=f"{prefix}.num_examples", record=num_examples_record)

    metrics_record = ConfigsRecord(evaluateres.metrics)  # type: ignore
    recordset.set_configs(name=f"{prefix}.metrics", record=metrics_record)

    return _embed_status_into_recordset(prefix, evaluateres.status, recordset)
311
+
312
+
313
def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns:
    """Derive GetParametersIns from a RecordSet object.

    The config is read from the ConfigsRecord named `"getparametersins.config"`.
    """
    config = _check_mapping_from_recordscalartype_to_scalar(
        recordset.get_configs("getparametersins.config").data
    )
    return GetParametersIns(config=config)
320
+
321
+
322
def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> RecordSet:
    """Construct a RecordSet from a GetParametersIns object.

    The config is stored as the ConfigsRecord named `"getparametersins.config"`.
    """
    config_record = ConfigsRecord(getparameters_ins.config)  # type: ignore
    recordset = RecordSet()
    recordset.set_configs(name="getparametersins.config", record=config_record)
    return recordset
331
+
332
+
333
def getparametersres_to_recordset(
    getparametersres: GetParametersRes, keep_input: bool = False
) -> RecordSet:
    """Construct a RecordSet from a GetParametersRes object.

    Parameters
    ----------
    getparametersres : GetParametersRes
        The message to convert. Its parameters are stored as the ParametersRecord
        `"getparametersres.parameters"` and its status as the ConfigsRecord
        `"getparametersres.status"`.
    keep_input : bool (default: False)
        Whether the tensors should be kept in `getparametersres.parameters`. If
        False (the default), they are removed from the input during conversion,
        consistent with `fitres_to_recordset`.

    Returns
    -------
    RecordSet
        The newly created RecordSet.
    """
    recordset = RecordSet()
    res_str = "getparametersres"
    parameters_record = parameters_to_parametersrecord(
        getparametersres.parameters, keep_input=keep_input
    )
    recordset.set_parameters(f"{res_str}.parameters", parameters_record)

    # status
    recordset = _embed_status_into_recordset(
        res_str, getparametersres.status, recordset
    )

    return recordset
346
+
347
+
348
def recordset_to_getparametersres(
    recordset: RecordSet, keep_input: bool = False
) -> GetParametersRes:
    """Derive GetParametersRes from a RecordSet object.

    Parameters
    ----------
    recordset : RecordSet
        Must contain the ParametersRecord `"getparametersres.parameters"` and the
        ConfigsRecord `"getparametersres.status"`.
    keep_input : bool (default: False)
        Whether the arrays should be kept in the ParametersRecord. If False (the
        default), each array is deleted from the record once converted,
        consistent with `recordset_to_fitres`.

    Returns
    -------
    GetParametersRes
        The reconstructed message.
    """
    res_str = "getparametersres"
    parameters = parametersrecord_to_parameters(
        recordset.get_parameters(f"{res_str}.parameters"), keep_input=keep_input
    )

    status = _extract_status_from_recordset(res_str, recordset)
    return GetParametersRes(status=status, parameters=parameters)
357
+
358
+
359
def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns:
    """Derive GetPropertiesIns from a RecordSet object.

    The config is read from the ConfigsRecord named `"getpropertiesins.config"`.
    """
    config = _check_mapping_from_recordscalartype_to_scalar(
        recordset.get_configs("getpropertiesins.config").data
    )
    return GetPropertiesIns(config=config)
365
+
366
+
367
def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet:
    """Construct a RecordSet from a GetPropertiesIns object.

    The config is stored as the ConfigsRecord named `"getpropertiesins.config"`.
    """
    recordset = RecordSet()
    recordset.set_configs(
        name="getpropertiesins.config",
        record=ConfigsRecord(getpropertiesins.config),  # type: ignore
    )
    return recordset
375
+
376
+
377
def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes:
    """Derive GetPropertiesRes from a RecordSet object.

    Properties come from the ConfigsRecord `"getpropertiesres.properties"` and the
    status from `"getpropertiesres.status"`.
    """
    prefix = "getpropertiesres"
    properties_record = recordset.get_configs(f"{prefix}.properties")
    return GetPropertiesRes(
        properties=_check_mapping_from_recordscalartype_to_scalar(
            properties_record.data
        ),
        status=_extract_status_from_recordset(prefix, recordset=recordset),
    )
386
+
387
+
388
def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordSet:
    """Construct a RecordSet from a GetPropertiesRes object.

    Properties are stored as the ConfigsRecord `"getpropertiesres.properties"`;
    the status is stored as `"getpropertiesres.status"`.
    """
    prefix = "getpropertiesres"
    properties_record = ConfigsRecord(getpropertiesres.properties)  # type: ignore
    recordset = RecordSet()
    recordset.set_configs(name=f"{prefix}.properties", record=properties_record)
    return _embed_status_into_recordset(prefix, getpropertiesres.status, recordset)
flwr/common/serde.py CHANGED
@@ -15,10 +15,23 @@
15
15
  """ProtoBuf serialization and deserialization."""
16
16
 
17
17
 
18
- from typing import Any, Dict, List, MutableMapping, cast
19
-
20
- from flwr.proto.task_pb2 import Value # pylint: disable=E0611
21
- from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
18
+ from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast
19
+
20
+ from google.protobuf.message import Message
21
+
22
+ # pylint: disable=E0611
23
+ from flwr.proto.recordset_pb2 import Array as ProtoArray
24
+ from flwr.proto.recordset_pb2 import BoolList, BytesList
25
+ from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
26
+ from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
27
+ from flwr.proto.recordset_pb2 import DoubleList
28
+ from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
29
+ from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
30
+ from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
31
+ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
32
+ from flwr.proto.recordset_pb2 import Sint64List, StringList
33
+ from flwr.proto.task_pb2 import Value
34
+ from flwr.proto.transport_pb2 import (
22
35
  ClientMessage,
23
36
  Code,
24
37
  Parameters,
@@ -28,7 +41,12 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
28
41
  Status,
29
42
  )
30
43
 
44
+ # pylint: enable=E0611
31
45
  from . import typing
46
+ from .configsrecord import ConfigsRecord
47
+ from .metricsrecord import MetricsRecord
48
+ from .parametersrecord import Array, ParametersRecord
49
+ from .recordset import RecordSet
32
50
 
33
51
  # === ServerMessage message ===
34
52
 
@@ -493,7 +511,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
493
511
  # === Value messages ===
494
512
 
495
513
 
496
- _python_type_to_field_name = {
514
+ _type_to_field = {
497
515
  float: "double",
498
516
  int: "sint64",
499
517
  bool: "bool",
@@ -502,22 +520,20 @@ _python_type_to_field_name = {
502
520
  }
503
521
 
504
522
 
505
- _python_list_type_to_message_and_field_name = {
506
- float: (Value.DoubleList, "double_list"),
507
- int: (Value.Sint64List, "sint64_list"),
508
- bool: (Value.BoolList, "bool_list"),
509
- str: (Value.StringList, "string_list"),
510
- bytes: (Value.BytesList, "bytes_list"),
523
+ _list_type_to_class_and_field = {
524
+ float: (DoubleList, "double_list"),
525
+ int: (Sint64List, "sint64_list"),
526
+ bool: (BoolList, "bool_list"),
527
+ str: (StringList, "string_list"),
528
+ bytes: (BytesList, "bytes_list"),
511
529
  }
512
530
 
513
531
 
514
532
  def _check_value(value: typing.Value) -> None:
515
- if isinstance(value, tuple(_python_type_to_field_name.keys())):
533
+ if isinstance(value, tuple(_type_to_field.keys())):
516
534
  return
517
535
  if isinstance(value, list):
518
- if len(value) > 0 and isinstance(
519
- value[0], tuple(_python_type_to_field_name.keys())
520
- ):
536
+ if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())):
521
537
  data_type = type(value[0])
522
538
  for element in value:
523
539
  if isinstance(element, data_type):
@@ -539,12 +555,12 @@ def value_to_proto(value: typing.Value) -> Value:
539
555
 
540
556
  arg = {}
541
557
  if isinstance(value, list):
542
- msg_class, field_name = _python_list_type_to_message_and_field_name[
558
+ msg_class, field_name = _list_type_to_class_and_field[
543
559
  type(value[0]) if len(value) > 0 else int
544
560
  ]
545
561
  arg[field_name] = msg_class(vals=value)
546
562
  else:
547
- arg[_python_type_to_field_name[type(value)]] = value
563
+ arg[_type_to_field[type(value)]] = value
548
564
  return Value(**arg)
549
565
 
550
566
 
@@ -573,3 +589,165 @@ def named_values_from_proto(
573
589
  ) -> Dict[str, typing.Value]:
574
590
  """Deserialize named values from ProtoBuf."""
575
591
  return {name: value_from_proto(value) for name, value in named_values_proto.items()}
592
+
593
+
594
+ # === Record messages ===
595
+
596
+
597
# Generic type variable used by the record-value serializers that follow: it ties
# the ProtoBuf class passed in (`Type[T]`) to the type of the message returned.
T = TypeVar("T")
598
+
599
+
600
def _record_value_to_proto(
    value: Any, allowed_types: List[type], proto_class: Type[T]
) -> T:
    """Serialize `*RecordValue` to ProtoBuf.

    Parameters
    ----------
    value : Any
        A single value, or a list whose elements all share one allowed type.
    allowed_types : List[type]
        Candidate Python types, tried in order.
    proto_class : Type[T]
        The `*RecordValue` ProtoBuf class to instantiate.

    Returns
    -------
    T
        An instance of `proto_class` with exactly one field set. An empty list is
        encoded with the list field of the first type in `allowed_types`.

    Raises
    ------
    TypeError
        If `value` is neither an allowed single type nor a list whose elements all
        belong to one allowed type.
    """

    def _list_item_matches(item: Any, t: type) -> bool:
        # `bool` is a subclass of `int`, so `isinstance(True, int)` is True.
        # Without this guard, a list of bools matches `int` whenever `int` precedes
        # `bool` in `allowed_types`. It is then serialized as a `sint64_list` and
        # comes back as a list of ints.
        if isinstance(item, bool) and t is not bool:
            return False
        return isinstance(item, t)

    for t in allowed_types:
        # Single element
        # Note: `isinstance(False, int) == True`, hence the exact type comparison.
        if type(value) == t:  # pylint: disable=C0123
            return proto_class(**{_type_to_field[t]: value})
        # List
        if isinstance(value, list) and all(
            _list_item_matches(item, t) for item in value
        ):
            list_class, field_name = _list_type_to_class_and_field[t]
            return proto_class(**{field_name: list_class(vals=value)})
    # Invalid types
    raise TypeError(
        f"The type of the following value is not allowed "
        f"in '{proto_class.__name__}':\n{value}"
    )
621
+
622
+
623
def _record_value_from_proto(value_proto: Message) -> Any:
    """Deserialize `*RecordValue` from ProtoBuf.

    Returns the scalar held by the message's `value` oneof. If the set field's
    name ends with `"list"`, returns a Python `list` of that field's `vals`.

    Raises
    ------
    ValueError
        If no field of the `value` oneof is set.
    """
    value_field = value_proto.WhichOneof("value")
    if value_field is None:
        # `WhichOneof` returns None when no field of the oneof is set. Without this
        # guard that case failed with an opaque AttributeError on `.endswith`.
        raise ValueError(
            f"No field of the 'value' oneof is set in '{type(value_proto).__name__}'."
        )
    if value_field.endswith("list"):
        return list(getattr(value_proto, value_field).vals)
    return getattr(value_proto, value_field)
631
+
632
+
633
def _record_value_dict_to_proto(
    value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T]
) -> Dict[str, T]:
    """Serialize every value of a record's data dict to `value_proto_class`.

    Keys are kept unchanged; each value is serialized by `_record_value_to_proto`
    with the given `allowed_types`.
    """
    serialized: Dict[str, T] = {}
    for key, raw_value in value_dict.items():
        serialized[key] = _record_value_to_proto(
            raw_value, allowed_types, value_proto_class
        )
    return serialized
642
+
643
+
644
def _record_value_dict_from_proto(
    value_dict_proto: MutableMapping[str, Any]
) -> Dict[str, Any]:
    """Deserialize every `*RecordValue` in a ProtoBuf map into a plain dict."""
    deserialized: Dict[str, Any] = {}
    for key, proto_value in value_dict_proto.items():
        deserialized[key] = _record_value_from_proto(proto_value)
    return deserialized
649
+
650
+
651
def array_to_proto(array: Array) -> ProtoArray:
    """Serialize Array to ProtoBuf.

    Copies `dtype`, `shape`, `stype` and `data`, the same four fields that
    `array_from_proto` reads back.
    """
    # Explicit fields instead of `ProtoArray(**vars(array))`. The `vars` form
    # requires `Array` to have an instance `__dict__` whose keys match the ProtoBuf
    # fields exactly; it breaks with `__slots__` or any extra attribute.
    return ProtoArray(
        dtype=array.dtype,
        shape=array.shape,
        stype=array.stype,
        data=array.data,
    )
654
+
655
+
656
def array_from_proto(array_proto: ProtoArray) -> Array:
    """Deserialize Array from ProtoBuf.

    The repeated `shape` field is copied into a Python list.
    """
    shape = list(array_proto.shape)
    return Array(
        data=array_proto.data,
        stype=array_proto.stype,
        shape=shape,
        dtype=array_proto.dtype,
    )
664
+
665
+
666
def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord:
    """Serialize ParametersRecord to ProtoBuf.

    Keys and Arrays are written to two parallel repeated fields (`data_keys`,
    `data_values`) in the record's iteration order.
    """
    keys = list(record.data.keys())
    values = [array_to_proto(array) for array in record.data.values()]
    return ProtoParametersRecord(data_keys=keys, data_values=values)
672
+
673
+
674
def parameters_record_from_proto(
    record_proto: ProtoParametersRecord,
) -> ParametersRecord:
    """Deserialize ParametersRecord from ProtoBuf.

    Pairs `data_keys[i]` with `data_values[i]`, preserving their order.
    """
    array_dict = OrderedDict(
        (key, array_from_proto(array_proto))
        for key, array_proto in zip(record_proto.data_keys, record_proto.data_values)
    )
    return ParametersRecord(array_dict=array_dict, keep_input=False)
684
+
685
+
686
def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord:
    """Serialize MetricsRecord to ProtoBuf.

    Values must be `float` or `int` (or lists of one of them); anything else makes
    `_record_value_to_proto` raise `TypeError`.
    """
    allowed_types: List[type] = [float, int]
    serialized = _record_value_dict_to_proto(
        record.data, allowed_types, ProtoMetricsRecordValue
    )
    return ProtoMetricsRecord(data=serialized)
693
+
694
+
695
def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord:
    """Deserialize MetricsRecord from ProtoBuf."""
    metrics = cast(
        Dict[str, typing.MetricsRecordValues],
        _record_value_dict_from_proto(record_proto.data),
    )
    return MetricsRecord(metrics_dict=metrics, keep_input=False)
704
+
705
+
706
def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord:
    """Serialize ConfigsRecord to ProtoBuf.

    Values must be `int`, `float`, `bool`, `str` or `bytes` (or lists of one of
    them); anything else makes `_record_value_to_proto` raise `TypeError`.
    """
    allowed_types: List[type] = [int, float, bool, str, bytes]
    serialized = _record_value_dict_to_proto(
        record.data, allowed_types, ProtoConfigsRecordValue
    )
    return ProtoConfigsRecord(data=serialized)
713
+
714
+
715
def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord:
    """Deserialize ConfigsRecord from ProtoBuf."""
    configs = cast(
        Dict[str, typing.ConfigsRecordValues],
        _record_value_dict_from_proto(record_proto.data),
    )
    return ConfigsRecord(configs_dict=configs, keep_input=False)
724
+
725
+
726
+ # === RecordSet message ===
727
+
728
+
729
def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
    """Serialize RecordSet to ProtoBuf.

    Each of the three record maps (`parameters`, `metrics`, `configs`) is
    serialized entry by entry, keeping the record names as map keys.
    """
    parameters = {
        name: parameters_record_to_proto(record)
        for name, record in recordset.parameters.items()
    }
    metrics = {
        name: metrics_record_to_proto(record)
        for name, record in recordset.metrics.items()
    }
    configs = {
        name: configs_record_to_proto(record)
        for name, record in recordset.configs.items()
    }
    return ProtoRecordSet(parameters=parameters, metrics=metrics, configs=configs)
738
+
739
+
740
def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
    """Deserialize RecordSet from ProtoBuf.

    Each of the three ProtoBuf maps (`parameters`, `metrics`, `configs`) is
    deserialized entry by entry, keeping the record names as keys.
    """
    parameters = {
        name: parameters_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.parameters.items()
    }
    metrics = {
        name: metrics_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.metrics.items()
    }
    configs = {
        name: configs_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.configs.items()
    }
    return RecordSet(parameters=parameters, metrics=metrics, configs=configs)