flwr-nightly 1.7.0.dev20240119__py3-none-any.whl → 1.7.0.dev20240123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/common/configsrecord.py +11 -2
- flwr/common/metricsrecord.py +11 -2
- flwr/common/recordset_compat.py +401 -0
- flwr/common/serde.py +195 -17
- flwr/proto/recordset_pb2.py +68 -0
- flwr/proto/recordset_pb2.pyi +305 -0
- flwr/proto/recordset_pb2_grpc.py +4 -0
- flwr/proto/recordset_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +16 -23
- flwr/proto/task_pb2.pyi +20 -70
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/METADATA +1 -1
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/RECORD +15 -11
- flwr/common/recordset_utils.py +0 -87
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/entry_points.txt +0 -0
flwr/common/configsrecord.py
CHANGED
@@ -87,8 +87,17 @@ class ConfigsRecord:
|
|
87
87
|
# 1s to check 10M element list on a M2 Pro
|
88
88
|
# In such settings, you'd be better off treating such config as
|
89
89
|
# an array and pass it to a ParametersRecord.
|
90
|
-
|
91
|
-
|
90
|
+
# Empty lists are valid
|
91
|
+
if len(value) > 0:
|
92
|
+
is_valid(value[0])
|
93
|
+
# all elements in the list must be of the same valid type
|
94
|
+
# this is needed for protobuf
|
95
|
+
value_type = type(value[0])
|
96
|
+
if not all(isinstance(v, value_type) for v in value):
|
97
|
+
raise TypeError(
|
98
|
+
"All values in a list must be of the same valid type. "
|
99
|
+
f"One of {ConfigsScalar}."
|
100
|
+
)
|
92
101
|
else:
|
93
102
|
is_valid(value)
|
94
103
|
|
flwr/common/metricsrecord.py
CHANGED
@@ -87,8 +87,17 @@ class MetricsRecord:
|
|
87
87
|
# 1s to check 10M element list on a M2 Pro
|
88
88
|
# In such settings, you'd be better off treating such metric as
|
89
89
|
# an array and pass it to a ParametersRecord.
|
90
|
-
|
91
|
-
|
90
|
+
# Empty lists are valid
|
91
|
+
if len(value) > 0:
|
92
|
+
is_valid(value[0])
|
93
|
+
# all elements in the list must be of the same valid type
|
94
|
+
# this is needed for protobuf
|
95
|
+
value_type = type(value[0])
|
96
|
+
if not all(isinstance(v, value_type) for v in value):
|
97
|
+
raise TypeError(
|
98
|
+
"All values in a list must be of the same valid type. "
|
99
|
+
f"One of {MetricsScalar}."
|
100
|
+
)
|
92
101
|
else:
|
93
102
|
is_valid(value)
|
94
103
|
|
@@ -0,0 +1,401 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""RecordSet utilities."""
|
16
|
+
|
17
|
+
|
18
|
+
from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args
|
19
|
+
|
20
|
+
from .configsrecord import ConfigsRecord
|
21
|
+
from .metricsrecord import MetricsRecord
|
22
|
+
from .parametersrecord import Array, ParametersRecord
|
23
|
+
from .recordset import RecordSet
|
24
|
+
from .typing import (
|
25
|
+
Code,
|
26
|
+
ConfigsRecordValues,
|
27
|
+
EvaluateIns,
|
28
|
+
EvaluateRes,
|
29
|
+
FitIns,
|
30
|
+
FitRes,
|
31
|
+
GetParametersIns,
|
32
|
+
GetParametersRes,
|
33
|
+
GetPropertiesIns,
|
34
|
+
GetPropertiesRes,
|
35
|
+
MetricsRecordValues,
|
36
|
+
Parameters,
|
37
|
+
Scalar,
|
38
|
+
Status,
|
39
|
+
)
|
40
|
+
|
41
|
+
|
42
|
+
def parametersrecord_to_parameters(
    record: ParametersRecord, keep_input: bool = False
) -> Parameters:
    """Convert ParametersRecord to legacy Parameters.

    Warning: Because `Arrays` in `ParametersRecord` encode more information of the
    array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it
    might not be possible to reconstruct such data structures from `Parameters` objects
    alone. Additional information or metadata must be provided from elsewhere.

    Parameters
    ----------
    record : ParametersRecord
        The record to be converted into Parameters.
    keep_input : bool (default: False)
        A boolean indicating whether entries in the record should be kept. If
        False, each entry is deleted from the record immediately after its data
        has been added to the returned Parameters.

    Returns
    -------
    Parameters
        The `data` bytes of every array in the record, in record order. The
        `tensor_type` is the `stype` of the first array whose `stype` is
        non-empty (an empty string if there is none).
    """
    parameters = Parameters(tensors=[], tensor_type="")

    # Iterate over a snapshot of the keys because entries may be deleted below.
    for key in list(record.data.keys()):
        array = record[key]
        parameters.tensors.append(array.data)

        if not parameters.tensor_type:
            # Setting from the first array in the record. Recall the warning in
            # the docstring of this function.
            parameters.tensor_type = array.stype

        if not keep_input:
            del record.data[key]

    return parameters
|
74
|
+
|
75
|
+
|
76
|
+
def parameters_to_parametersrecord(
    parameters: Parameters, keep_input: bool = False
) -> ParametersRecord:
    """Convert legacy Parameters into a single ParametersRecord.

    Because there is no concept of names in the legacy Parameters, arbitrary keys will
    be used when constructing the ParametersRecord. Similarly, the shape and data type
    won't be recorded in the Array objects.

    Parameters
    ----------
    parameters : Parameters
        Parameters object to be represented as a ParametersRecord.
    keep_input : bool (default: False)
        A boolean indicating whether the tensors in the input Parameters object
        (i.e. a list of serialized NumPy arrays) should be kept. If False, the
        input's tensor list is emptied once all tensors have been added to the
        record.

    Returns
    -------
    ParametersRecord
        One `Array` per tensor, keyed "0", "1", ... in input order, each with
        `stype` set to `parameters.tensor_type`, an empty `dtype` and an empty
        `shape`.
    """
    tensor_type = parameters.tensor_type

    p_record = ParametersRecord()

    for idx, tensor in enumerate(parameters.tensors):
        p_record.set_parameters(
            OrderedDict(
                {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])}
            )
        )

    if not keep_input:
        # Clear the list once instead of `pop(0)` on every iteration, which is
        # O(n) per call and O(n^2) overall. The record holds its own references
        # to the tensor bytes, so nothing is lost.
        parameters.tensors.clear()

    return p_record
|
111
|
+
|
112
|
+
|
113
|
+
def _check_mapping_from_recordscalartype_to_scalar(
    record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
) -> Dict[str, Scalar]:
    """Check mapping `common.*RecordValues` into `common.Scalar` is possible.

    Parameters
    ----------
    record_data : Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
        The `data` of a ConfigsRecord or MetricsRecord.

    Returns
    -------
    Dict[str, Scalar]
        The very same mapping object, only re-typed (no copy is made).

    Raises
    ------
    TypeError
        If any value is not an instance of one of the `common.Scalar` types
        (e.g. a list value).
    """
    scalar_types = get_args(Scalar)
    for value in record_data.values():
        if not isinstance(value, scalar_types):
            raise TypeError(
                "There is not a 1:1 mapping between `common.Scalar` types and those "
                "supported in `common.ConfigsRecordValues` or "
                "`common.MetricsRecordValues`. Consider casting your values to a type "
                "supported by the `common.RecordSet` infrastructure. "
                f"You used type: {type(value)}"
            )
    return cast(Dict[str, Scalar], record_data)
|
127
|
+
|
128
|
+
|
129
|
+
def _recordset_to_fit_or_evaluate_ins_components(
    recordset: RecordSet,
    ins_str: str,
    keep_input: bool,
) -> Tuple[Parameters, Dict[str, Scalar]]:
    """Extract the (parameters, config) pair of a Fit/Evaluate instruction.

    Parameters
    ----------
    recordset : RecordSet
        Source RecordSet.
    ins_str : str
        Name prefix of the records, e.g. "fitins" or "evaluateins". The
        ParametersRecord is read from "<ins_str>.parameters" and the
        ConfigsRecord from "<ins_str>.config".
    keep_input : bool
        Forwarded to `parametersrecord_to_parameters`; if False the
        ParametersRecord is emptied while converting.

    Returns
    -------
    Tuple[Parameters, Dict[str, Scalar]]
        The legacy Parameters and the config dict.
    """
    legacy_params = parametersrecord_to_parameters(
        recordset.get_parameters(f"{ins_str}.parameters"), keep_input=keep_input
    )
    configs = recordset.get_configs(f"{ins_str}.config")
    return legacy_params, _check_mapping_from_recordscalartype_to_scalar(configs.data)
|
148
|
+
|
149
|
+
|
150
|
+
def _fit_or_evaluate_ins_to_recordset(
    ins: Union[FitIns, EvaluateIns], keep_input: bool
) -> RecordSet:
    """Pack a FitIns or EvaluateIns into a new RecordSet.

    The parameters are stored as a ParametersRecord named "<prefix>.parameters"
    and the config as a ConfigsRecord named "<prefix>.config", where the prefix
    is "fitins" for FitIns and "evaluateins" otherwise.

    Parameters
    ----------
    ins : Union[FitIns, EvaluateIns]
        The instruction to convert.
    keep_input : bool
        Forwarded to `parameters_to_parametersrecord`; if False the tensors are
        removed from `ins.parameters`.
    """
    out = RecordSet()

    prefix = "evaluateins"
    if isinstance(ins, FitIns):
        prefix = "fitins"

    params_record = parameters_to_parametersrecord(ins.parameters, keep_input=keep_input)
    out.set_parameters(name=f"{prefix}.parameters", record=params_record)

    configs_record = ConfigsRecord(ins.config)  # type: ignore
    out.set_configs(name=f"{prefix}.config", record=configs_record)

    return out
|
166
|
+
|
167
|
+
|
168
|
+
def _embed_status_into_recordset(
    res_str: str, status: Status, recordset: RecordSet
) -> RecordSet:
    """Store `status` in `recordset` as the ConfigsRecord "<res_str>.status".

    The record holds two entries: "code" (the integer value of `status.code`)
    and "message". The passed-in `recordset` is modified in place and returned.
    """
    # A `ConfigsRecord` is used because `status.message` is a string, and `str`
    # values are not supported in `MetricsRecord`s.
    status_entries: Dict[str, ConfigsRecordValues] = {}
    status_entries["code"] = int(status.code.value)
    status_entries["message"] = status.message

    recordset.set_configs(f"{res_str}.status", record=ConfigsRecord(status_entries))
    return recordset
|
179
|
+
|
180
|
+
|
181
|
+
def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status:
    """Rebuild a Status from the ConfigsRecord "<res_str>.status".

    Reads the "code" entry (converted to `Code`) and the "message" entry
    (converted to `str`).
    """
    status_record = recordset.get_configs(f"{res_str}.status")
    return Status(
        code=Code(cast(int, status_record["code"])),
        message=str(status_record["message"]),
    )
|
185
|
+
|
186
|
+
|
187
|
+
def recordset_to_fitins(recordset: RecordSet, keep_input: bool) -> FitIns:
    """Derive FitIns from a RecordSet object.

    Reads the "fitins.parameters" and "fitins.config" records; if `keep_input`
    is False the parameters record is emptied while converting.
    """
    fit_parameters, fit_config = _recordset_to_fit_or_evaluate_ins_components(
        recordset, ins_str="fitins", keep_input=keep_input
    )
    return FitIns(parameters=fit_parameters, config=fit_config)
|
196
|
+
|
197
|
+
|
198
|
+
def fitins_to_recordset(fitins: FitIns, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from a FitIns object.

    The records are named "fitins.parameters" and "fitins.config".
    """
    recordset = _fit_or_evaluate_ins_to_recordset(fitins, keep_input)
    return recordset
|
201
|
+
|
202
|
+
|
203
|
+
def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes:
    """Derive FitRes from a RecordSet object.

    Reads these records:
    - "fitres.parameters" (ParametersRecord, emptied if `keep_input` is False)
    - "fitres.num_examples" (MetricsRecord, key "num_examples")
    - "fitres.metrics" (ConfigsRecord)
    - "fitres.status" (ConfigsRecord)

    Raises
    ------
    TypeError
        If a metric value is not a `common.Scalar`.
    """
    res_str = "fitres"

    fit_parameters = parametersrecord_to_parameters(
        recordset.get_parameters(f"{res_str}.parameters"), keep_input=keep_input
    )

    num_examples_record = recordset.get_metrics(f"{res_str}.num_examples")
    num_examples = cast(int, num_examples_record["num_examples"])

    metrics_record = recordset.get_configs(f"{res_str}.metrics")
    fit_metrics = _check_mapping_from_recordscalartype_to_scalar(metrics_record.data)

    fit_status = _extract_status_from_recordset(res_str, recordset)

    return FitRes(
        status=fit_status,
        parameters=fit_parameters,
        num_examples=num_examples,
        metrics=fit_metrics,
    )
|
221
|
+
|
222
|
+
|
223
|
+
def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from a FitRes object.

    Writes these records:
    - "fitres.metrics" (ConfigsRecord)
    - "fitres.num_examples" (MetricsRecord, key "num_examples")
    - "fitres.parameters" (ParametersRecord; tensors removed from
      `fitres.parameters` if `keep_input` is False)
    - "fitres.status" (ConfigsRecord)
    """
    res_str = "fitres"
    out = RecordSet()

    metrics_record = ConfigsRecord(fitres.metrics)  # type: ignore
    out.set_configs(name=f"{res_str}.metrics", record=metrics_record)

    num_examples_record = MetricsRecord({"num_examples": fitres.num_examples})
    out.set_metrics(name=f"{res_str}.num_examples", record=num_examples_record)

    params_record = parameters_to_parametersrecord(fitres.parameters, keep_input)
    out.set_parameters(name=f"{res_str}.parameters", record=params_record)

    return _embed_status_into_recordset(res_str, fitres.status, out)
|
245
|
+
|
246
|
+
|
247
|
+
def recordset_to_evaluateins(recordset: RecordSet, keep_input: bool) -> EvaluateIns:
    """Derive EvaluateIns from a RecordSet object.

    Reads the "evaluateins.parameters" and "evaluateins.config" records; if
    `keep_input` is False the parameters record is emptied while converting.
    """
    eval_parameters, eval_config = _recordset_to_fit_or_evaluate_ins_components(
        recordset, ins_str="evaluateins", keep_input=keep_input
    )
    return EvaluateIns(parameters=eval_parameters, config=eval_config)
|
256
|
+
|
257
|
+
|
258
|
+
def evaluateins_to_recordset(evaluateins: EvaluateIns, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from an EvaluateIns object.

    The records are named "evaluateins.parameters" and "evaluateins.config".
    """
    recordset = _fit_or_evaluate_ins_to_recordset(evaluateins, keep_input)
    return recordset
|
261
|
+
|
262
|
+
|
263
|
+
def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes:
    """Derive EvaluateRes from a RecordSet object.

    Reads these records:
    - "evaluateres.loss" (MetricsRecord, key "loss")
    - "evaluateres.num_examples" (MetricsRecord, key "num_examples")
    - "evaluateres.metrics" (ConfigsRecord)
    - "evaluateres.status" (ConfigsRecord)

    Raises
    ------
    TypeError
        If a metric value is not a `common.Scalar`.
    """
    ins_str = "evaluateres"

    # `EvaluateRes.loss` is a float; the previous `cast(int, ...)` mis-typed it.
    loss = cast(float, recordset.get_metrics(f"{ins_str}.loss")["loss"])

    num_examples = cast(
        int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"]
    )
    configs_record = recordset.get_configs(f"{ins_str}.metrics")

    metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data)
    status = _extract_status_from_recordset(ins_str, recordset)

    return EvaluateRes(
        status=status, loss=loss, num_examples=num_examples, metrics=metrics
    )
|
280
|
+
|
281
|
+
|
282
|
+
def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet:
    """Construct a RecordSet from an EvaluateRes object.

    Writes these records:
    - "evaluateres.loss" (MetricsRecord, key "loss")
    - "evaluateres.num_examples" (MetricsRecord, key "num_examples")
    - "evaluateres.metrics" (ConfigsRecord)
    - "evaluateres.status" (ConfigsRecord)
    """
    res_str = "evaluateres"
    out = RecordSet()

    loss_record = MetricsRecord({"loss": evaluateres.loss})
    out.set_metrics(name=f"{res_str}.loss", record=loss_record)

    count_record = MetricsRecord({"num_examples": evaluateres.num_examples})
    out.set_metrics(name=f"{res_str}.num_examples", record=count_record)

    metrics_record = ConfigsRecord(evaluateres.metrics)  # type: ignore
    out.set_configs(name=f"{res_str}.metrics", record=metrics_record)

    return _embed_status_into_recordset(res_str, evaluateres.status, out)
|
311
|
+
|
312
|
+
|
313
|
+
def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns:
    """Derive GetParametersIns from a RecordSet object.

    The config is read from the ConfigsRecord "getparametersins.config".

    Raises
    ------
    TypeError
        If a config value is not a `common.Scalar`.
    """
    configs = recordset.get_configs("getparametersins.config").data
    return GetParametersIns(
        config=_check_mapping_from_recordscalartype_to_scalar(configs)
    )
|
320
|
+
|
321
|
+
|
322
|
+
def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> RecordSet:
    """Construct a RecordSet from a GetParametersIns object.

    The config is stored in the ConfigsRecord "getparametersins.config".
    """
    out = RecordSet()
    configs_record = ConfigsRecord(getparameters_ins.config)  # type: ignore
    out.set_configs(name="getparametersins.config", record=configs_record)
    return out
|
331
|
+
|
332
|
+
|
333
|
+
def getparametersres_to_recordset(
    getparametersres: GetParametersRes, keep_input: bool = False
) -> RecordSet:
    """Construct a RecordSet from a GetParametersRes object.

    Writes the ParametersRecord "getparametersres.parameters" and the
    ConfigsRecord "getparametersres.status".

    Parameters
    ----------
    getparametersres : GetParametersRes
        The result to convert.
    keep_input : bool (default: False)
        Whether to keep the tensors in `getparametersres.parameters`. If False
        (the default, matching the previous behavior), they are removed from the
        input once moved into the record. This mirrors the `keep_input` option
        of `fitres_to_recordset`.
    """
    recordset = RecordSet()
    res_str = "getparametersres"
    parameters_record = parameters_to_parametersrecord(
        getparametersres.parameters, keep_input=keep_input
    )
    recordset.set_parameters(f"{res_str}.parameters", parameters_record)

    # status
    recordset = _embed_status_into_recordset(
        res_str, getparametersres.status, recordset
    )

    return recordset
|
346
|
+
|
347
|
+
|
348
|
+
def recordset_to_getparametersres(
    recordset: RecordSet, keep_input: bool = False
) -> GetParametersRes:
    """Derive GetParametersRes from a RecordSet object.

    Reads the ParametersRecord "getparametersres.parameters" and the
    ConfigsRecord "getparametersres.status".

    Parameters
    ----------
    recordset : RecordSet
        The RecordSet to convert.
    keep_input : bool (default: False)
        Whether to keep the entries of the parameters record. If False (the
        default, matching the previous behavior), they are deleted from the
        record while converting. This mirrors the `keep_input` option of
        `recordset_to_fitres`.
    """
    res_str = "getparametersres"
    parameters = parametersrecord_to_parameters(
        recordset.get_parameters(f"{res_str}.parameters"), keep_input=keep_input
    )

    status = _extract_status_from_recordset(res_str, recordset)
    return GetParametersRes(status=status, parameters=parameters)
|
357
|
+
|
358
|
+
|
359
|
+
def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns:
    """Derive GetPropertiesIns from a RecordSet object.

    The config is read from the ConfigsRecord "getpropertiesins.config".

    Raises
    ------
    TypeError
        If a config value is not a `common.Scalar`.
    """
    configs = recordset.get_configs("getpropertiesins.config").data
    return GetPropertiesIns(
        config=_check_mapping_from_recordscalartype_to_scalar(configs)
    )
|
365
|
+
|
366
|
+
|
367
|
+
def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet:
    """Construct a RecordSet from a GetPropertiesIns object.

    The config is stored in the ConfigsRecord "getpropertiesins.config".
    """
    out = RecordSet()
    configs_record = ConfigsRecord(getpropertiesins.config)  # type: ignore
    out.set_configs(name="getpropertiesins.config", record=configs_record)
    return out
|
375
|
+
|
376
|
+
|
377
|
+
def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes:
    """Derive GetPropertiesRes from a RecordSet object.

    Reads the ConfigsRecords "getpropertiesres.properties" and
    "getpropertiesres.status".

    Raises
    ------
    TypeError
        If a property value is not a `common.Scalar`.
    """
    res_str = "getpropertiesres"

    props_data = recordset.get_configs(f"{res_str}.properties").data
    props = _check_mapping_from_recordscalartype_to_scalar(props_data)

    res_status = _extract_status_from_recordset(res_str, recordset=recordset)
    return GetPropertiesRes(status=res_status, properties=props)
|
386
|
+
|
387
|
+
|
388
|
+
def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordSet:
    """Construct a RecordSet from a GetPropertiesRes object.

    Writes the ConfigsRecords "getpropertiesres.properties" and
    "getpropertiesres.status".
    """
    res_str = "getpropertiesres"
    out = RecordSet()

    props_record = ConfigsRecord(getpropertiesres.properties)  # type: ignore
    out.set_configs(name=f"{res_str}.properties", record=props_record)

    return _embed_status_into_recordset(res_str, getpropertiesres.status, out)
|
flwr/common/serde.py
CHANGED
@@ -15,10 +15,23 @@
|
|
15
15
|
"""ProtoBuf serialization and deserialization."""
|
16
16
|
|
17
17
|
|
18
|
-
from typing import Any, Dict, List, MutableMapping, cast
|
19
|
-
|
20
|
-
from
|
21
|
-
|
18
|
+
from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast
|
19
|
+
|
20
|
+
from google.protobuf.message import Message
|
21
|
+
|
22
|
+
# pylint: disable=E0611
|
23
|
+
from flwr.proto.recordset_pb2 import Array as ProtoArray
|
24
|
+
from flwr.proto.recordset_pb2 import BoolList, BytesList
|
25
|
+
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
26
|
+
from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
|
27
|
+
from flwr.proto.recordset_pb2 import DoubleList
|
28
|
+
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
|
29
|
+
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
|
30
|
+
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
|
31
|
+
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
32
|
+
from flwr.proto.recordset_pb2 import Sint64List, StringList
|
33
|
+
from flwr.proto.task_pb2 import Value
|
34
|
+
from flwr.proto.transport_pb2 import (
|
22
35
|
ClientMessage,
|
23
36
|
Code,
|
24
37
|
Parameters,
|
@@ -28,7 +41,12 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
28
41
|
Status,
|
29
42
|
)
|
30
43
|
|
44
|
+
# pylint: enable=E0611
|
31
45
|
from . import typing
|
46
|
+
from .configsrecord import ConfigsRecord
|
47
|
+
from .metricsrecord import MetricsRecord
|
48
|
+
from .parametersrecord import Array, ParametersRecord
|
49
|
+
from .recordset import RecordSet
|
32
50
|
|
33
51
|
# === ServerMessage message ===
|
34
52
|
|
@@ -493,7 +511,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
|
|
493
511
|
# === Value messages ===
|
494
512
|
|
495
513
|
|
496
|
-
|
514
|
+
_type_to_field = {
|
497
515
|
float: "double",
|
498
516
|
int: "sint64",
|
499
517
|
bool: "bool",
|
@@ -502,22 +520,20 @@ _python_type_to_field_name = {
|
|
502
520
|
}
|
503
521
|
|
504
522
|
|
505
|
-
|
506
|
-
float: (
|
507
|
-
int: (
|
508
|
-
bool: (
|
509
|
-
str: (
|
510
|
-
bytes: (
|
523
|
+
# Maps each supported Python element type to the ProtoBuf list wrapper class and
# the name of the oneof field that holds a list of that element type.
_list_type_to_class_and_field = dict(
    [
        (float, (DoubleList, "double_list")),
        (int, (Sint64List, "sint64_list")),
        (bool, (BoolList, "bool_list")),
        (str, (StringList, "string_list")),
        (bytes, (BytesList, "bytes_list")),
    ]
)
|
512
530
|
|
513
531
|
|
514
532
|
def _check_value(value: typing.Value) -> None:
|
515
|
-
if isinstance(value, tuple(
|
533
|
+
if isinstance(value, tuple(_type_to_field.keys())):
|
516
534
|
return
|
517
535
|
if isinstance(value, list):
|
518
|
-
if len(value) > 0 and isinstance(
|
519
|
-
value[0], tuple(_python_type_to_field_name.keys())
|
520
|
-
):
|
536
|
+
if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())):
|
521
537
|
data_type = type(value[0])
|
522
538
|
for element in value:
|
523
539
|
if isinstance(element, data_type):
|
@@ -539,12 +555,12 @@ def value_to_proto(value: typing.Value) -> Value:
|
|
539
555
|
|
540
556
|
arg = {}
|
541
557
|
if isinstance(value, list):
|
542
|
-
msg_class, field_name =
|
558
|
+
msg_class, field_name = _list_type_to_class_and_field[
|
543
559
|
type(value[0]) if len(value) > 0 else int
|
544
560
|
]
|
545
561
|
arg[field_name] = msg_class(vals=value)
|
546
562
|
else:
|
547
|
-
arg[
|
563
|
+
arg[_type_to_field[type(value)]] = value
|
548
564
|
return Value(**arg)
|
549
565
|
|
550
566
|
|
@@ -573,3 +589,165 @@ def named_values_from_proto(
|
|
573
589
|
) -> Dict[str, typing.Value]:
|
574
590
|
"""Deserialize named values from ProtoBuf."""
|
575
591
|
return {name: value_from_proto(value) for name, value in named_values_proto.items()}
|
592
|
+
|
593
|
+
|
594
|
+
# === Record messages ===
|
595
|
+
|
596
|
+
|
597
|
+
# Generic type variable used to tie a ProtoBuf message class argument
# (`Type[T]`) to the type of the message instance(s) returned.
T = TypeVar("T")
|
598
|
+
|
599
|
+
|
600
|
+
def _record_value_to_proto(
    value: Any, allowed_types: List[type], proto_class: Type[T]
) -> T:
    """Serialize `*RecordValue` to ProtoBuf.

    Parameters
    ----------
    value : Any
        A single scalar or a list of scalars.
    allowed_types : List[type]
        The scalar types accepted by `proto_class`, in order of preference. For
        lists, the first type that matches decides the list field used; an empty
        list is therefore stored in the list field of `allowed_types[0]`.
    proto_class : Type[T]
        The ProtoBuf message class to instantiate.

    Returns
    -------
    T
        A `proto_class` instance with the matching oneof field set.

    Raises
    ------
    TypeError
        If `value` is neither an allowed scalar nor a list of allowed scalars.
    """
    is_list = isinstance(value, list)

    # First pass: exact type matches only.
    # Note: `isinstance(False, int) == True`. With plain `isinstance` checks a
    # list of `bool`s would land in the `int` list field whenever `int` precedes
    # `bool` in `allowed_types` (as it does for ConfigsRecord) and would come
    # back from ProtoBuf as a list of ints.
    for t in allowed_types:
        # Single element
        if type(value) == t:  # pylint: disable=C0123
            return proto_class(**{_type_to_field[t]: value})
        # List whose elements all have exactly type `t`
        if is_list and all(
            type(item) == t for item in value  # pylint: disable=C0123
        ):
            list_class, field_name = _list_type_to_class_and_field[t]
            return proto_class(**{field_name: list_class(vals=value)})

    # Second pass: lists whose elements are instances of an allowed type only via
    # subclassing (e.g. a list mixing `int` and `bool` values).
    if is_list:
        for t in allowed_types:
            if all(isinstance(item, t) for item in value):
                list_class, field_name = _list_type_to_class_and_field[t]
                return proto_class(**{field_name: list_class(vals=value)})

    # Invalid types
    raise TypeError(
        "The type of the following value is not allowed "
        f"in '{proto_class.__name__}':\n{value}"
    )
|
621
|
+
|
622
|
+
|
623
|
+
def _record_value_from_proto(value_proto: Message) -> Any:
    """Deserialize `*RecordValue` from ProtoBuf.

    Reads whichever member of the `value` oneof is set. Members whose field name
    ends in "list" are list wrapper messages: their `vals` are copied into a
    Python list. Any other member is returned as-is.
    """
    field_name = cast(str, value_proto.WhichOneof("value"))
    if not field_name.endswith("list"):
        return getattr(value_proto, field_name)
    return list(getattr(value_proto, field_name).vals)
|
631
|
+
|
632
|
+
|
633
|
+
def _record_value_dict_to_proto(
    value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T]
) -> Dict[str, T]:
    """Serialize the record value dict to ProtoBuf.

    Each value is converted with `_record_value_to_proto`; keys are kept as-is.
    """
    serialized: Dict[str, T] = {}
    for key, value in value_dict.items():
        serialized[key] = _record_value_to_proto(
            value, allowed_types, value_proto_class
        )
    return serialized
|
642
|
+
|
643
|
+
|
644
|
+
def _record_value_dict_from_proto(
    value_dict_proto: MutableMapping[str, Any]
) -> Dict[str, Any]:
    """Deserialize the record value dict from ProtoBuf.

    Each value is converted with `_record_value_from_proto`; keys are kept as-is.
    """
    result: Dict[str, Any] = {}
    for key, value_proto in value_dict_proto.items():
        result[key] = _record_value_from_proto(value_proto)
    return result
|
649
|
+
|
650
|
+
|
651
|
+
def array_to_proto(array: Array) -> ProtoArray:
    """Serialize Array to ProtoBuf.

    The four fields are copied explicitly, mirroring `array_from_proto`, rather
    than via `vars(array)`. This way an extra instance attribute on the Array
    cannot be forwarded to the ProtoBuf constructor (which rejects unknown
    fields), and Array instances without a `__dict__` are still supported.
    """
    return ProtoArray(
        dtype=array.dtype,
        shape=array.shape,
        stype=array.stype,
        data=array.data,
    )
|
654
|
+
|
655
|
+
|
656
|
+
def array_from_proto(array_proto: ProtoArray) -> Array:
    """Deserialize Array from ProtoBuf.

    The repeated `shape` field is copied into a plain Python list.
    """
    shape = list(array_proto.shape)
    return Array(
        data=array_proto.data,
        dtype=array_proto.dtype,
        stype=array_proto.stype,
        shape=shape,
    )
|
664
|
+
|
665
|
+
|
666
|
+
def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord:
    """Serialize ParametersRecord to ProtoBuf.

    Keys and arrays are stored in two parallel repeated fields (`data_keys`,
    `data_values`) in the record's iteration order.
    """
    keys = list(record.data.keys())
    values = [array_to_proto(arr) for arr in record.data.values()]
    return ProtoParametersRecord(data_keys=keys, data_values=values)
|
672
|
+
|
673
|
+
|
674
|
+
def parameters_record_from_proto(
    record_proto: ProtoParametersRecord,
) -> ParametersRecord:
    """Deserialize ParametersRecord from ProtoBuf.

    Pairs `data_keys` with `data_values` positionally (stopping at the shorter
    of the two) and preserves their order.
    """
    arrays = OrderedDict(
        (key, array_from_proto(arr_proto))
        for key, arr_proto in zip(record_proto.data_keys, record_proto.data_values)
    )
    return ParametersRecord(array_dict=arrays, keep_input=False)
|
684
|
+
|
685
|
+
|
686
|
+
def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord:
    """Serialize MetricsRecord to ProtoBuf.

    Only `float` and `int` values (and lists of them) are accepted here.
    """
    metric_types: List[type] = [float, int]
    proto_values = _record_value_dict_to_proto(
        record.data, metric_types, ProtoMetricsRecordValue
    )
    return ProtoMetricsRecord(data=proto_values)
|
693
|
+
|
694
|
+
|
695
|
+
def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord:
    """Deserialize MetricsRecord from ProtoBuf."""
    values = _record_value_dict_from_proto(record_proto.data)
    return MetricsRecord(
        metrics_dict=cast(Dict[str, typing.MetricsRecordValues], values),
        keep_input=False,
    )
|
704
|
+
|
705
|
+
|
706
|
+
def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord:
    """Serialize ConfigsRecord to ProtoBuf.

    Accepts `int`, `float`, `bool`, `str` and `bytes` values (and lists of
    them). The order matters: an empty list is stored as an `int` list.
    """
    config_types: List[type] = [int, float, bool, str, bytes]
    proto_values = _record_value_dict_to_proto(
        record.data, config_types, ProtoConfigsRecordValue
    )
    return ProtoConfigsRecord(data=proto_values)
|
713
|
+
|
714
|
+
|
715
|
+
def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord:
    """Deserialize ConfigsRecord from ProtoBuf."""
    values = _record_value_dict_from_proto(record_proto.data)
    return ConfigsRecord(
        configs_dict=cast(Dict[str, typing.ConfigsRecordValues], values),
        keep_input=False,
    )
|
724
|
+
|
725
|
+
|
726
|
+
# === RecordSet message ===
|
727
|
+
|
728
|
+
|
729
|
+
def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
    """Serialize RecordSet to ProtoBuf.

    Each of the three named-record maps (parameters, metrics, configs) is
    converted record by record, keeping the record names.
    """
    proto_parameters = {}
    for name, p_record in recordset.parameters.items():
        proto_parameters[name] = parameters_record_to_proto(p_record)

    proto_metrics = {}
    for name, m_record in recordset.metrics.items():
        proto_metrics[name] = metrics_record_to_proto(m_record)

    proto_configs = {}
    for name, c_record in recordset.configs.items():
        proto_configs[name] = configs_record_to_proto(c_record)

    return ProtoRecordSet(
        parameters=proto_parameters, metrics=proto_metrics, configs=proto_configs
    )
|
738
|
+
|
739
|
+
|
740
|
+
def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
    """Deserialize RecordSet from ProtoBuf.

    Each of the three named-record maps (parameters, metrics, configs) is
    converted record by record, keeping the record names.
    """
    parameters = {}
    for name, p_proto in recordset_proto.parameters.items():
        parameters[name] = parameters_record_from_proto(p_proto)

    metrics = {}
    for name, m_proto in recordset_proto.metrics.items():
        metrics[name] = metrics_record_from_proto(m_proto)

    configs = {}
    for name, c_proto in recordset_proto.configs.items():
        configs[name] = configs_record_from_proto(c_proto)

    return RecordSet(parameters=parameters, metrics=metrics, configs=configs)
|