flwr-nightly 1.7.0.dev20240119__py3-none-any.whl → 1.7.0.dev20240123__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/common/configsrecord.py +11 -2
- flwr/common/metricsrecord.py +11 -2
- flwr/common/recordset_compat.py +401 -0
- flwr/common/serde.py +195 -17
- flwr/proto/recordset_pb2.py +68 -0
- flwr/proto/recordset_pb2.pyi +305 -0
- flwr/proto/recordset_pb2_grpc.py +4 -0
- flwr/proto/recordset_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +16 -23
- flwr/proto/task_pb2.pyi +20 -70
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/METADATA +1 -1
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/RECORD +15 -11
- flwr/common/recordset_utils.py +0 -87
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.7.0.dev20240119.dist-info → flwr_nightly-1.7.0.dev20240123.dist-info}/entry_points.txt +0 -0
flwr/common/configsrecord.py
CHANGED
@@ -87,8 +87,17 @@ class ConfigsRecord:
|
|
87
87
|
# 1s to check 10M element list on a M2 Pro
|
88
88
|
# In such settings, you'd be better of treating such config as
|
89
89
|
# an array and pass it to a ParametersRecord.
|
90
|
-
|
91
|
-
|
90
|
+
# Empty lists are valid
|
91
|
+
if len(value) > 0:
|
92
|
+
is_valid(value[0])
|
93
|
+
# all elements in the list must be of the same valid type
|
94
|
+
# this is needed for protobuf
|
95
|
+
value_type = type(value[0])
|
96
|
+
if not all(isinstance(v, value_type) for v in value):
|
97
|
+
raise TypeError(
|
98
|
+
"All values in a list must be of the same valid type. "
|
99
|
+
f"One of {ConfigsScalar}."
|
100
|
+
)
|
92
101
|
else:
|
93
102
|
is_valid(value)
|
94
103
|
|
flwr/common/metricsrecord.py
CHANGED
@@ -87,8 +87,17 @@ class MetricsRecord:
|
|
87
87
|
# 1s to check 10M element list on a M2 Pro
|
88
88
|
# In such settings, you'd be better of treating such metric as
|
89
89
|
# an array and pass it to a ParametersRecord.
|
90
|
-
|
91
|
-
|
90
|
+
# Empty lists are valid
|
91
|
+
if len(value) > 0:
|
92
|
+
is_valid(value[0])
|
93
|
+
# all elements in the list must be of the same valid type
|
94
|
+
# this is needed for protobuf
|
95
|
+
value_type = type(value[0])
|
96
|
+
if not all(isinstance(v, value_type) for v in value):
|
97
|
+
raise TypeError(
|
98
|
+
"All values in a list must be of the same valid type. "
|
99
|
+
f"One of {MetricsScalar}."
|
100
|
+
)
|
92
101
|
else:
|
93
102
|
is_valid(value)
|
94
103
|
|
@@ -0,0 +1,401 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""RecordSet utilities."""
|
16
|
+
|
17
|
+
|
18
|
+
from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args
|
19
|
+
|
20
|
+
from .configsrecord import ConfigsRecord
|
21
|
+
from .metricsrecord import MetricsRecord
|
22
|
+
from .parametersrecord import Array, ParametersRecord
|
23
|
+
from .recordset import RecordSet
|
24
|
+
from .typing import (
|
25
|
+
Code,
|
26
|
+
ConfigsRecordValues,
|
27
|
+
EvaluateIns,
|
28
|
+
EvaluateRes,
|
29
|
+
FitIns,
|
30
|
+
FitRes,
|
31
|
+
GetParametersIns,
|
32
|
+
GetParametersRes,
|
33
|
+
GetPropertiesIns,
|
34
|
+
GetPropertiesRes,
|
35
|
+
MetricsRecordValues,
|
36
|
+
Parameters,
|
37
|
+
Scalar,
|
38
|
+
Status,
|
39
|
+
)
|
40
|
+
|
41
|
+
|
42
|
+
def parametersrecord_to_parameters(
    record: ParametersRecord, keep_input: bool = False
) -> Parameters:
    """Convert a ParametersRecord to legacy Parameters.

    Warning: Because `Arrays` in `ParametersRecord` encode more information about
    the array-like or tensor-like data (e.g. their datatype, shape) than
    `Parameters` does, it might not be possible to reconstruct such data structures
    from `Parameters` objects alone. Additional information or metadata must be
    provided from elsewhere.

    Parameters
    ----------
    record : ParametersRecord
        The record to be converted into Parameters.
    keep_input : bool (default: False)
        A boolean indicating whether entries in the record should be kept. If
        False, each entry is deleted from the record immediately after its data
        has been added to the returned Parameters.

    Returns
    -------
    Parameters
        Parameters whose `tensors` hold the `data` of every Array in the record,
        in record order, and whose `tensor_type` is the first non-empty `stype`
        encountered (empty string if there is none).
    """
    parameters = Parameters(tensors=[], tensor_type="")

    # Iterate over a snapshot of the keys because entries may be deleted from
    # the record inside the loop.
    for key in list(record.data.keys()):
        array = record[key]
        parameters.tensors.append(array.data)

        if not parameters.tensor_type:
            # Taken from the first array that has one. Recall the warning in the
            # docstring of this function.
            parameters.tensor_type = array.stype

        if not keep_input:
            del record.data[key]

    return parameters
|
74
|
+
|
75
|
+
|
76
|
+
def parameters_to_parametersrecord(
    parameters: Parameters, keep_input: bool = False
) -> ParametersRecord:
    """Convert legacy Parameters into a single ParametersRecord.

    Because there is no concept of names in the legacy Parameters, arbitrary keys will
    be used when constructing the ParametersRecord (the position of each tensor, as a
    string). Similarly, the shape and data type won't be recorded in the Array
    objects.

    Parameters
    ----------
    parameters : Parameters
        Parameters object to be represented as a ParametersRecord.
    keep_input : bool (default: False)
        A boolean indicating whether the input should be left untouched. If False,
        the tensors (i.e. a list of serialized NumPy arrays) are removed from the
        input Parameters object, leaving its `tensors` list empty, once they have
        been added to the record.

    Returns
    -------
    ParametersRecord
        A record mapping "0", "1", ... to Arrays wrapping each tensor, all using
        the input's `tensor_type` as `stype`.
    """
    tensor_type = parameters.tensor_type

    # Build every entry up front and hand them to the record in a single call.
    # This avoids the O(n^2) cost of repeatedly popping from the front of a list.
    array_dict = OrderedDict(
        (str(idx), Array(data=tensor, dtype="", stype=tensor_type, shape=[]))
        for idx, tensor in enumerate(parameters.tensors)
    )

    if not keep_input:
        # The record now holds the only references we need, so drop them from
        # the input.
        parameters.tensors.clear()

    p_record = ParametersRecord()
    if array_dict:
        p_record.set_parameters(array_dict)

    return p_record
|
111
|
+
|
112
|
+
|
113
|
+
def _check_mapping_from_recordscalartype_to_scalar(
    record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
) -> Dict[str, Scalar]:
    """Check mapping `common.*RecordValues` into `common.Scalar` is possible.

    Parameters
    ----------
    record_data : Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]]
        The `data` of a ConfigsRecord or MetricsRecord.

    Returns
    -------
    Dict[str, Scalar]
        The very same mapping object (no copy is made), typed as a Scalar dict.

    Raises
    ------
    TypeError
        If any value is not an instance of one of the types in `common.Scalar`.
    """
    # Resolve the Union members once instead of on every iteration.
    scalar_types = get_args(Scalar)
    for value in record_data.values():
        if not isinstance(value, scalar_types):
            raise TypeError(
                "There is not a 1:1 mapping between `common.Scalar` types and those "
                "supported in `common.ConfigsRecordValues` or "
                "`common.MetricsRecordValues`. Consider casting your values to a type "
                "supported by the `common.RecordSet` infrastructure. "
                f"You used type: {type(value)}"
            )
    return cast(Dict[str, Scalar], record_data)
|
127
|
+
|
128
|
+
|
129
|
+
def _recordset_to_fit_or_evaluate_ins_components(
    recordset: RecordSet,
    ins_str: str,
    keep_input: bool,
) -> Tuple[Parameters, Dict[str, Scalar]]:
    """Extract the Parameters and config dict stored under a Fit/Evaluate prefix.

    Reads the ParametersRecord named ``"<ins_str>.parameters"`` and the
    ConfigsRecord named ``"<ins_str>.config"`` from `recordset`. The parameters
    are converted first; with `keep_input=False` their record is emptied.
    """
    params_record = recordset.get_parameters(f"{ins_str}.parameters")
    params = parametersrecord_to_parameters(params_record, keep_input=keep_input)

    configs = recordset.get_configs(f"{ins_str}.config")
    return params, _check_mapping_from_recordscalartype_to_scalar(configs.data)
|
148
|
+
|
149
|
+
|
150
|
+
def _fit_or_evaluate_ins_to_recordset(
    ins: Union[FitIns, EvaluateIns], keep_input: bool
) -> RecordSet:
    """Build a RecordSet holding the parameters and config of a Fit/EvaluateIns.

    Records are named with the ``"fitins."`` prefix for a `FitIns` and with
    ``"evaluateins."`` otherwise.
    """
    result = RecordSet()

    prefix = "evaluateins"
    if isinstance(ins, FitIns):
        prefix = "fitins"

    params_record = parameters_to_parametersrecord(
        ins.parameters, keep_input=keep_input
    )
    result.set_parameters(name=f"{prefix}.parameters", record=params_record)

    result.set_configs(
        name=f"{prefix}.config", record=ConfigsRecord(ins.config)  # type: ignore
    )
    return result
|
166
|
+
|
167
|
+
|
168
|
+
def _embed_status_into_recordset(
    res_str: str, status: Status, recordset: RecordSet
) -> RecordSet:
    """Store `status` in `recordset` as the ConfigsRecord ``"<res_str>.status"``.

    The record holds ``code`` (the integer value of the status code) and
    ``message``. The same, mutated `recordset` is returned.
    """
    # A ConfigsRecord is used because `status.message` is a string, and `str`
    # values aren't supported in a MetricsRecord.
    payload: Dict[str, ConfigsRecordValues] = {
        "code": int(status.code.value),
        "message": status.message,
    }
    recordset.set_configs(f"{res_str}.status", record=ConfigsRecord(payload))
    return recordset
|
179
|
+
|
180
|
+
|
181
|
+
def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status:
    """Rebuild a Status from the ConfigsRecord ``"<res_str>.status"``."""
    status_record = recordset.get_configs(f"{res_str}.status")
    return Status(
        code=Code(cast(int, status_record["code"])),
        message=str(status_record["message"]),
    )
|
185
|
+
|
186
|
+
|
187
|
+
def recordset_to_fitins(recordset: RecordSet, keep_input: bool) -> FitIns:
    """Derive FitIns from a RecordSet object.

    Reads the ``"fitins.parameters"`` and ``"fitins.config"`` records. When
    `keep_input` is False, the parameter entries are removed from `recordset`.
    """
    components = _recordset_to_fit_or_evaluate_ins_components(
        recordset, ins_str="fitins", keep_input=keep_input
    )
    return FitIns(parameters=components[0], config=components[1])
|
196
|
+
|
197
|
+
|
198
|
+
def fitins_to_recordset(fitins: FitIns, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from a FitIns object.

    The result holds the ``"fitins.parameters"`` and ``"fitins.config"`` records.
    """
    recordset = _fit_or_evaluate_ins_to_recordset(ins=fitins, keep_input=keep_input)
    return recordset
|
201
|
+
|
202
|
+
|
203
|
+
def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes:
    """Derive FitRes from a RecordSet object.

    Reads the ``"fitres.parameters"``, ``"fitres.num_examples"``,
    ``"fitres.metrics"`` and ``"fitres.status"`` records.
    """
    prefix = "fitres"

    params_record = recordset.get_parameters(f"{prefix}.parameters")
    parameters = parametersrecord_to_parameters(params_record, keep_input=keep_input)

    examples_record = recordset.get_metrics(f"{prefix}.num_examples")
    num_examples = cast(int, examples_record["num_examples"])

    metrics_record = recordset.get_configs(f"{prefix}.metrics")
    metrics = _check_mapping_from_recordscalartype_to_scalar(metrics_record.data)

    status = _extract_status_from_recordset(prefix, recordset)

    return FitRes(
        status=status,
        parameters=parameters,
        num_examples=num_examples,
        metrics=metrics,
    )
|
221
|
+
|
222
|
+
|
223
|
+
def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from a FitRes object.

    Writes the ``"fitres.metrics"`` (ConfigsRecord), ``"fitres.num_examples"``
    (MetricsRecord), ``"fitres.parameters"`` and ``"fitres.status"`` records.
    When `keep_input` is False, the tensors are removed from `fitres.parameters`.
    """
    prefix = "fitres"
    out = RecordSet()

    out.set_configs(
        name=f"{prefix}.metrics",
        record=ConfigsRecord(fitres.metrics),  # type: ignore
    )
    out.set_metrics(
        name=f"{prefix}.num_examples",
        record=MetricsRecord({"num_examples": fitres.num_examples}),
    )
    out.set_parameters(
        name=f"{prefix}.parameters",
        record=parameters_to_parametersrecord(fitres.parameters, keep_input),
    )

    return _embed_status_into_recordset(prefix, fitres.status, out)
|
245
|
+
|
246
|
+
|
247
|
+
def recordset_to_evaluateins(recordset: RecordSet, keep_input: bool) -> EvaluateIns:
    """Derive EvaluateIns from a RecordSet object.

    Reads the ``"evaluateins.parameters"`` and ``"evaluateins.config"`` records.
    When `keep_input` is False, the parameter entries are removed from `recordset`.
    """
    components = _recordset_to_fit_or_evaluate_ins_components(
        recordset, ins_str="evaluateins", keep_input=keep_input
    )
    return EvaluateIns(parameters=components[0], config=components[1])
|
256
|
+
|
257
|
+
|
258
|
+
def evaluateins_to_recordset(evaluateins: EvaluateIns, keep_input: bool) -> RecordSet:
    """Construct a RecordSet from an EvaluateIns object.

    The result holds the ``"evaluateins.parameters"`` and ``"evaluateins.config"``
    records.
    """
    recordset = _fit_or_evaluate_ins_to_recordset(
        ins=evaluateins, keep_input=keep_input
    )
    return recordset
|
261
|
+
|
262
|
+
|
263
|
+
def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes:
    """Derive EvaluateRes from a RecordSet object.

    Reads the ``"evaluateres.loss"`` and ``"evaluateres.num_examples"``
    MetricsRecords, the ``"evaluateres.metrics"`` ConfigsRecord and the
    ``"evaluateres.status"`` record.
    """
    ins_str = "evaluateres"

    # The loss is a floating-point value; casting it to `int` misinformed type
    # checkers (the cast has no runtime effect either way).
    loss = cast(float, recordset.get_metrics(f"{ins_str}.loss")["loss"])

    num_examples = cast(
        int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"]
    )
    configs_record = recordset.get_configs(f"{ins_str}.metrics")

    metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data)
    status = _extract_status_from_recordset(ins_str, recordset)

    return EvaluateRes(
        status=status, loss=loss, num_examples=num_examples, metrics=metrics
    )
|
280
|
+
|
281
|
+
|
282
|
+
def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet:
    """Construct a RecordSet from an EvaluateRes object.

    ``loss`` and ``num_examples`` are stored as MetricsRecords, ``metrics`` as
    the ConfigsRecord ``"evaluateres.metrics"``, and the status under
    ``"evaluateres.status"``.
    """
    prefix = "evaluateres"
    out = RecordSet()

    loss_record = MetricsRecord({"loss": evaluateres.loss})
    out.set_metrics(name=f"{prefix}.loss", record=loss_record)

    examples_record = MetricsRecord({"num_examples": evaluateres.num_examples})
    out.set_metrics(name=f"{prefix}.num_examples", record=examples_record)

    # NOTE(review): presumably a ConfigsRecord because metric values may be
    # strings, which MetricsRecord does not accept — confirm.
    out.set_configs(
        name=f"{prefix}.metrics",
        record=ConfigsRecord(evaluateres.metrics),  # type: ignore
    )

    return _embed_status_into_recordset(prefix, evaluateres.status, out)
|
311
|
+
|
312
|
+
|
313
|
+
def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns:
    """Derive GetParametersIns from a RecordSet object.

    The config comes from the ConfigsRecord ``"getparametersins.config"``.
    """
    configs = recordset.get_configs("getparametersins.config").data
    return GetParametersIns(
        config=_check_mapping_from_recordscalartype_to_scalar(configs)
    )
|
320
|
+
|
321
|
+
|
322
|
+
def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> RecordSet:
    """Construct a RecordSet from a GetParametersIns object.

    The ``config`` dict is stored as the ConfigsRecord ``"getparametersins.config"``.
    """
    config_record = ConfigsRecord(getparameters_ins.config)  # type: ignore
    out = RecordSet()
    out.set_configs(name="getparametersins.config", record=config_record)
    return out
|
331
|
+
|
332
|
+
|
333
|
+
def getparametersres_to_recordset(
    getparametersres: GetParametersRes, keep_input: bool = False
) -> RecordSet:
    """Construct a RecordSet from a GetParametersRes object.

    Parameters
    ----------
    getparametersres : GetParametersRes
        The response to convert. Its parameters are stored under
        ``"getparametersres.parameters"`` and its status under
        ``"getparametersres.status"``.
    keep_input : bool (default: False)
        Whether to leave `getparametersres.parameters` untouched. If False (the
        default), its tensors are removed once they have been added to the record.
        This matches the other `*_to_recordset` converters that accept
        `keep_input`.

    Returns
    -------
    RecordSet
        The newly created RecordSet.
    """
    recordset = RecordSet()
    res_str = "getparametersres"
    parameters_record = parameters_to_parametersrecord(
        getparametersres.parameters, keep_input=keep_input
    )
    recordset.set_parameters(f"{res_str}.parameters", parameters_record)

    # status
    recordset = _embed_status_into_recordset(
        res_str, getparametersres.status, recordset
    )

    return recordset
|
346
|
+
|
347
|
+
|
348
|
+
def recordset_to_getparametersres(
    recordset: RecordSet, keep_input: bool = False
) -> GetParametersRes:
    """Derive GetParametersRes from a RecordSet object.

    Parameters
    ----------
    recordset : RecordSet
        Must contain the ``"getparametersres.parameters"`` and
        ``"getparametersres.status"`` records.
    keep_input : bool (default: False)
        Whether to keep the entries of the ``"getparametersres.parameters"``
        record. If False (the default), each entry is deleted from the record as
        it is converted.

    Returns
    -------
    GetParametersRes
        The reconstructed response.
    """
    res_str = "getparametersres"
    parameters = parametersrecord_to_parameters(
        recordset.get_parameters(f"{res_str}.parameters"), keep_input=keep_input
    )

    status = _extract_status_from_recordset(res_str, recordset)
    return GetParametersRes(status=status, parameters=parameters)
|
357
|
+
|
358
|
+
|
359
|
+
def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns:
    """Derive GetPropertiesIns from a RecordSet object.

    The config comes from the ConfigsRecord ``"getpropertiesins.config"``.
    """
    configs = recordset.get_configs("getpropertiesins.config").data
    return GetPropertiesIns(
        config=_check_mapping_from_recordscalartype_to_scalar(configs)
    )
|
365
|
+
|
366
|
+
|
367
|
+
def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet:
    """Construct a RecordSet from a GetPropertiesIns object.

    The ``config`` dict is stored as the ConfigsRecord ``"getpropertiesins.config"``.
    """
    recordset = RecordSet()
    recordset.set_configs(
        name="getpropertiesins.config",
        record=ConfigsRecord(getpropertiesins.config),  # type: ignore
    )
    return recordset
|
375
|
+
|
376
|
+
|
377
|
+
def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes:
    """Derive GetPropertiesRes from a RecordSet object.

    Reads the ``"getpropertiesres.properties"`` and ``"getpropertiesres.status"``
    records.
    """
    prefix = "getpropertiesres"
    props_record = recordset.get_configs(f"{prefix}.properties")
    props = _check_mapping_from_recordscalartype_to_scalar(props_record.data)

    return GetPropertiesRes(
        status=_extract_status_from_recordset(prefix, recordset=recordset),
        properties=props,
    )
|
386
|
+
|
387
|
+
|
388
|
+
def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordSet:
    """Construct a RecordSet from a GetPropertiesRes object.

    ``properties`` are stored as the ConfigsRecord
    ``"getpropertiesres.properties"`` and the status under
    ``"getpropertiesres.status"``.
    """
    prefix = "getpropertiesres"
    out = RecordSet()

    props_record = ConfigsRecord(getpropertiesres.properties)  # type: ignore
    out.set_configs(name=f"{prefix}.properties", record=props_record)

    return _embed_status_into_recordset(prefix, getpropertiesres.status, out)
|
flwr/common/serde.py
CHANGED
@@ -15,10 +15,23 @@
|
|
15
15
|
"""ProtoBuf serialization and deserialization."""
|
16
16
|
|
17
17
|
|
18
|
-
from typing import Any, Dict, List, MutableMapping, cast
|
19
|
-
|
20
|
-
from
|
21
|
-
|
18
|
+
from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast
|
19
|
+
|
20
|
+
from google.protobuf.message import Message
|
21
|
+
|
22
|
+
# pylint: disable=E0611
|
23
|
+
from flwr.proto.recordset_pb2 import Array as ProtoArray
|
24
|
+
from flwr.proto.recordset_pb2 import BoolList, BytesList
|
25
|
+
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
26
|
+
from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
|
27
|
+
from flwr.proto.recordset_pb2 import DoubleList
|
28
|
+
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
|
29
|
+
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
|
30
|
+
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
|
31
|
+
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
32
|
+
from flwr.proto.recordset_pb2 import Sint64List, StringList
|
33
|
+
from flwr.proto.task_pb2 import Value
|
34
|
+
from flwr.proto.transport_pb2 import (
|
22
35
|
ClientMessage,
|
23
36
|
Code,
|
24
37
|
Parameters,
|
@@ -28,7 +41,12 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
28
41
|
Status,
|
29
42
|
)
|
30
43
|
|
44
|
+
# pylint: enable=E0611
|
31
45
|
from . import typing
|
46
|
+
from .configsrecord import ConfigsRecord
|
47
|
+
from .metricsrecord import MetricsRecord
|
48
|
+
from .parametersrecord import Array, ParametersRecord
|
49
|
+
from .recordset import RecordSet
|
32
50
|
|
33
51
|
# === ServerMessage message ===
|
34
52
|
|
@@ -493,7 +511,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
|
|
493
511
|
# === Value messages ===
|
494
512
|
|
495
513
|
|
496
|
-
|
514
|
+
_type_to_field = {
|
497
515
|
float: "double",
|
498
516
|
int: "sint64",
|
499
517
|
bool: "bool",
|
@@ -502,22 +520,20 @@ _python_type_to_field_name = {
|
|
502
520
|
}
|
503
521
|
|
504
522
|
|
505
|
-
|
506
|
-
float: (
|
507
|
-
int: (
|
508
|
-
bool: (
|
509
|
-
str: (
|
510
|
-
bytes: (
|
523
|
+
_list_type_to_class_and_field = {
|
524
|
+
float: (DoubleList, "double_list"),
|
525
|
+
int: (Sint64List, "sint64_list"),
|
526
|
+
bool: (BoolList, "bool_list"),
|
527
|
+
str: (StringList, "string_list"),
|
528
|
+
bytes: (BytesList, "bytes_list"),
|
511
529
|
}
|
512
530
|
|
513
531
|
|
514
532
|
def _check_value(value: typing.Value) -> None:
|
515
|
-
if isinstance(value, tuple(
|
533
|
+
if isinstance(value, tuple(_type_to_field.keys())):
|
516
534
|
return
|
517
535
|
if isinstance(value, list):
|
518
|
-
if len(value) > 0 and isinstance(
|
519
|
-
value[0], tuple(_python_type_to_field_name.keys())
|
520
|
-
):
|
536
|
+
if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())):
|
521
537
|
data_type = type(value[0])
|
522
538
|
for element in value:
|
523
539
|
if isinstance(element, data_type):
|
@@ -539,12 +555,12 @@ def value_to_proto(value: typing.Value) -> Value:
|
|
539
555
|
|
540
556
|
arg = {}
|
541
557
|
if isinstance(value, list):
|
542
|
-
msg_class, field_name =
|
558
|
+
msg_class, field_name = _list_type_to_class_and_field[
|
543
559
|
type(value[0]) if len(value) > 0 else int
|
544
560
|
]
|
545
561
|
arg[field_name] = msg_class(vals=value)
|
546
562
|
else:
|
547
|
-
arg[
|
563
|
+
arg[_type_to_field[type(value)]] = value
|
548
564
|
return Value(**arg)
|
549
565
|
|
550
566
|
|
@@ -573,3 +589,165 @@ def named_values_from_proto(
|
|
573
589
|
) -> Dict[str, typing.Value]:
|
574
590
|
"""Deserialize named values from ProtoBuf."""
|
575
591
|
return {name: value_from_proto(value) for name, value in named_values_proto.items()}
|
592
|
+
|
593
|
+
|
594
|
+
# === Record messages ===
|
595
|
+
|
596
|
+
|
597
|
+
T = TypeVar("T")
|
598
|
+
|
599
|
+
|
600
|
+
def _record_value_to_proto(
    value: Any, allowed_types: List[type], proto_class: Type[T]
) -> T:
    """Serialize `*RecordValue` to ProtoBuf.

    `allowed_types` are tried in order. A single value matches a type only if its
    type is exactly that type. A list matches only if every element's type is
    exactly that type; an empty list therefore matches the first allowed type.

    Raises
    ------
    TypeError
        If `value` matches none of `allowed_types`.
    """
    for t in allowed_types:
        # Exact type comparison is used for both single values and list elements
        # because `isinstance(False, int) == True`. With `isinstance`, a list of
        # bools would be encoded as `sint64_list` whenever `int` precedes `bool`
        # in `allowed_types`, and would come back as ints.
        # Single element
        if type(value) == t:  # pylint: disable=C0123
            return proto_class(**{_type_to_field[t]: value})
        # List
        if isinstance(value, list) and all(
            type(item) == t for item in value  # pylint: disable=C0123
        ):
            list_class, field_name = _list_type_to_class_and_field[t]
            return proto_class(**{field_name: list_class(vals=value)})
    # Invalid types
    raise TypeError(
        "The type of the following value is not allowed "
        f"in '{proto_class.__name__}':\n{value}"
    )
|
621
|
+
|
622
|
+
|
623
|
+
def _record_value_from_proto(value_proto: Message) -> Any:
    """Deserialize `*RecordValue` from ProtoBuf.

    Returns whichever field of the ``value`` oneof is set. For the ``*list``
    fields, the repeated ``vals`` are converted to a plain Python list.
    """
    field_name = cast(str, value_proto.WhichOneof("value"))
    is_list = field_name.endswith("list")
    field_value = getattr(value_proto, field_name)
    return list(field_value.vals) if is_list else field_value
|
631
|
+
|
632
|
+
|
633
|
+
def _record_value_dict_to_proto(
    value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T]
) -> Dict[str, T]:
    """Serialize every value of a record dict to `value_proto_class`, keeping keys."""
    serialized: Dict[str, T] = {}
    for key, value in value_dict.items():
        serialized[key] = _record_value_to_proto(
            value, allowed_types, value_proto_class
        )
    return serialized
|
642
|
+
|
643
|
+
|
644
|
+
def _record_value_dict_from_proto(
|
645
|
+
value_dict_proto: MutableMapping[str, Any]
|
646
|
+
) -> Dict[str, Any]:
|
647
|
+
"""Deserialize the record value dict from ProtoBuf."""
|
648
|
+
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
649
|
+
|
650
|
+
|
651
|
+
def array_to_proto(array: Array) -> ProtoArray:
    """Serialize Array to ProtoBuf.

    The four fields are copied explicitly, mirroring `array_from_proto`. This
    way, an extra attribute on `Array` that has no ProtoBuf counterpart cannot
    break serialization.
    """
    return ProtoArray(
        dtype=array.dtype,
        shape=array.shape,
        stype=array.stype,
        data=array.data,
    )
|
654
|
+
|
655
|
+
|
656
|
+
def array_from_proto(array_proto: ProtoArray) -> Array:
    """Deserialize Array from ProtoBuf.

    The repeated `shape` field is converted into a plain list.
    """
    shape = list(array_proto.shape)
    return Array(
        data=array_proto.data,
        dtype=array_proto.dtype,
        stype=array_proto.stype,
        shape=shape,
    )
|
664
|
+
|
665
|
+
|
666
|
+
def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord:
    """Serialize ParametersRecord to ProtoBuf.

    Keys and Arrays are written, in the record's iteration order, to the
    parallel repeated fields `data_keys` and `data_values`.
    """
    keys = list(record.data.keys())
    values = [array_to_proto(array) for array in record.data.values()]
    return ProtoParametersRecord(data_keys=keys, data_values=values)
|
672
|
+
|
673
|
+
|
674
|
+
def parameters_record_from_proto(
    record_proto: ProtoParametersRecord,
) -> ParametersRecord:
    """Deserialize ParametersRecord from ProtoBuf.

    Pairs ``data_keys[i]`` with ``data_values[i]``, preserving order. Any surplus
    entries in the longer field are ignored, as with `zip`.
    """
    arrays: OrderedDict[str, Array] = OrderedDict()
    for key, array_proto in zip(record_proto.data_keys, record_proto.data_values):
        arrays[key] = array_from_proto(array_proto)
    return ParametersRecord(array_dict=arrays, keep_input=False)
|
684
|
+
|
685
|
+
|
686
|
+
def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord:
    """Serialize MetricsRecord to ProtoBuf.

    Only `float` and `int` values (or lists of them) are accepted. Anything else
    makes `_record_value_to_proto` raise TypeError.
    """
    allowed = [float, int]
    serialized = _record_value_dict_to_proto(
        record.data, allowed, ProtoMetricsRecordValue
    )
    return ProtoMetricsRecord(data=serialized)
|
693
|
+
|
694
|
+
|
695
|
+
def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord:
    """Deserialize MetricsRecord from ProtoBuf."""
    decoded = _record_value_dict_from_proto(record_proto.data)
    return MetricsRecord(
        metrics_dict=cast(Dict[str, typing.MetricsRecordValues], decoded),
        keep_input=False,
    )
|
704
|
+
|
705
|
+
|
706
|
+
def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord:
    """Serialize ConfigsRecord to ProtoBuf.

    Accepted value types, tried in this order: `int`, `float`, `bool`, `str` and
    `bytes`, or lists of them.
    """
    allowed = [int, float, bool, str, bytes]
    serialized = _record_value_dict_to_proto(
        record.data, allowed, ProtoConfigsRecordValue
    )
    return ProtoConfigsRecord(data=serialized)
|
713
|
+
|
714
|
+
|
715
|
+
def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord:
    """Deserialize ConfigsRecord from ProtoBuf."""
    decoded = _record_value_dict_from_proto(record_proto.data)
    return ConfigsRecord(
        configs_dict=cast(Dict[str, typing.ConfigsRecordValues], decoded),
        keep_input=False,
    )
|
724
|
+
|
725
|
+
|
726
|
+
# === RecordSet message ===
|
727
|
+
|
728
|
+
|
729
|
+
def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
    """Serialize RecordSet to ProtoBuf, converting each named record by kind."""
    parameters = {
        name: parameters_record_to_proto(record)
        for name, record in recordset.parameters.items()
    }
    metrics = {
        name: metrics_record_to_proto(record)
        for name, record in recordset.metrics.items()
    }
    configs = {
        name: configs_record_to_proto(record)
        for name, record in recordset.configs.items()
    }
    return ProtoRecordSet(parameters=parameters, metrics=metrics, configs=configs)
|
738
|
+
|
739
|
+
|
740
|
+
def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
    """Deserialize RecordSet from ProtoBuf, rebuilding each named record by kind."""
    parameters = {
        name: parameters_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.parameters.items()
    }
    metrics = {
        name: metrics_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.metrics.items()
    }
    configs = {
        name: configs_record_from_proto(record_proto)
        for name, record_proto in recordset_proto.configs.items()
    }
    return RecordSet(parameters=parameters, metrics=metrics, configs=configs)
|