flwr-nightly 1.19.0.dev20250508__py3-none-any.whl → 1.19.0.dev20250510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/common/constant.py +3 -0
- flwr/common/inflatable.py +16 -0
- flwr/common/record/array.py +39 -1
- flwr/common/record/configrecord.py +49 -4
- flwr/common/record/metricrecord.py +45 -4
- flwr/common/serde.py +7 -104
- flwr/common/serde_utils.py +123 -0
- {flwr_nightly-1.19.0.dev20250508.dist-info → flwr_nightly-1.19.0.dev20250510.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250508.dist-info → flwr_nightly-1.19.0.dev20250510.dist-info}/RECORD +11 -10
- {flwr_nightly-1.19.0.dev20250508.dist-info → flwr_nightly-1.19.0.dev20250510.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250508.dist-info → flwr_nightly-1.19.0.dev20250510.dist-info}/entry_points.txt +0 -0
flwr/common/constant.py
CHANGED
@@ -131,6 +131,9 @@ GC_THRESHOLD = 200_000_000 # 200 MB
|
|
131
131
|
HEAD_BODY_DIVIDER = b"\x00"
|
132
132
|
TYPE_BODY_LEN_DIVIDER = " "
|
133
133
|
|
134
|
+
# Constants for serialization
|
135
|
+
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
136
|
+
|
134
137
|
|
135
138
|
class MessageType:
|
136
139
|
"""Message type."""
|
flwr/common/inflatable.py
CHANGED
@@ -30,6 +30,22 @@ class InflatableObject:
|
|
30
30
|
"""Deflate object."""
|
31
31
|
raise NotImplementedError()
|
32
32
|
|
33
|
+
@classmethod
|
34
|
+
def inflate(cls, object_content: bytes) -> "InflatableObject":
|
35
|
+
"""Inflate the object from bytes.
|
36
|
+
|
37
|
+
Parameters
|
38
|
+
----------
|
39
|
+
object_content : bytes
|
40
|
+
The deflated object content.
|
41
|
+
|
42
|
+
Returns
|
43
|
+
-------
|
44
|
+
InflatableObject
|
45
|
+
The inflated object.
|
46
|
+
"""
|
47
|
+
raise NotImplementedError()
|
48
|
+
|
33
49
|
@property
|
34
50
|
def object_id(self) -> str:
|
35
51
|
"""Get object_id."""
|
flwr/common/record/array.py
CHANGED
@@ -24,7 +24,10 @@ from typing import TYPE_CHECKING, Any, cast, overload
|
|
24
24
|
|
25
25
|
import numpy as np
|
26
26
|
|
27
|
+
from flwr.proto.recorddict_pb2 import Array as ArrayProto # pylint: disable=E0611
|
28
|
+
|
27
29
|
from ..constant import SType
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
28
31
|
from ..typing import NDArray
|
29
32
|
|
30
33
|
if TYPE_CHECKING:
|
@@ -40,7 +43,7 @@ def _raise_array_init_error() -> None:
|
|
40
43
|
|
41
44
|
|
42
45
|
@dataclass
|
43
|
-
class Array:
|
46
|
+
class Array(InflatableObject):
|
44
47
|
"""Array type.
|
45
48
|
|
46
49
|
A dataclass containing serialized data from an array-like or tensor-like object
|
@@ -248,3 +251,38 @@ class Array:
|
|
248
251
|
# Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
|
249
252
|
ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
|
250
253
|
return cast(NDArray, ndarray_deserialized)
|
254
|
+
|
255
|
+
def deflate(self) -> bytes:
|
256
|
+
"""Deflate the Array."""
|
257
|
+
array_proto = ArrayProto(
|
258
|
+
dtype=self.dtype,
|
259
|
+
shape=self.shape,
|
260
|
+
stype=self.stype,
|
261
|
+
data=self.data,
|
262
|
+
)
|
263
|
+
|
264
|
+
obj_body = array_proto.SerializeToString(deterministic=True)
|
265
|
+
return add_header_to_object_body(object_body=obj_body, cls=self)
|
266
|
+
|
267
|
+
@classmethod
|
268
|
+
def inflate(cls, object_content: bytes) -> Array:
|
269
|
+
"""Inflate an Array from bytes.
|
270
|
+
|
271
|
+
Parameters
|
272
|
+
----------
|
273
|
+
object_content : bytes
|
274
|
+
The deflated object content of the Array.
|
275
|
+
|
276
|
+
Returns
|
277
|
+
-------
|
278
|
+
Array
|
279
|
+
The inflated Array.
|
280
|
+
"""
|
281
|
+
obj_body = get_object_body(object_content, cls)
|
282
|
+
proto_array = ArrayProto.FromString(obj_body)
|
283
|
+
return cls(
|
284
|
+
dtype=proto_array.dtype,
|
285
|
+
shape=list(proto_array.shape),
|
286
|
+
stype=proto_array.stype,
|
287
|
+
data=proto_array.data,
|
288
|
+
)
|
flwr/common/record/configrecord.py
CHANGED
@@ -15,12 +15,21 @@
|
|
15
15
|
"""ConfigRecord."""
|
16
16
|
|
17
17
|
|
18
|
+
from __future__ import annotations
|
19
|
+
|
18
20
|
from logging import WARN
|
19
|
-
from typing import
|
21
|
+
from typing import cast, get_args
|
20
22
|
|
21
23
|
from flwr.common.typing import ConfigRecordValues, ConfigScalar
|
22
24
|
|
25
|
+
# pylint: disable=E0611
|
26
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
27
|
+
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
|
28
|
+
|
29
|
+
# pylint: enable=E0611
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
23
31
|
from ..logger import log
|
32
|
+
from ..serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
24
33
|
from .typeddict import TypedDict
|
25
34
|
|
26
35
|
|
@@ -59,7 +68,7 @@ def _check_value(value: ConfigRecordValues) -> None:
|
|
59
68
|
is_valid(value)
|
60
69
|
|
61
70
|
|
62
|
-
class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
71
|
+
class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
|
63
72
|
"""Config record.
|
64
73
|
|
65
74
|
A :code:`ConfigRecord` is a Python dictionary designed to ensure that
|
@@ -111,7 +120,7 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
111
120
|
|
112
121
|
def __init__(
|
113
122
|
self,
|
114
|
-
config_dict:
|
123
|
+
config_dict: dict[str, ConfigRecordValues] | None = None,
|
115
124
|
keep_input: bool = True,
|
116
125
|
) -> None:
|
117
126
|
|
@@ -164,6 +173,42 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
164
173
|
|
165
174
|
return num_bytes
|
166
175
|
|
176
|
+
def deflate(self) -> bytes:
|
177
|
+
"""Deflate object."""
|
178
|
+
obj_body = ProtoConfigRecord(
|
179
|
+
data=record_value_dict_to_proto(
|
180
|
+
self,
|
181
|
+
[bool, int, float, str, bytes],
|
182
|
+
ProtoConfigRecordValue,
|
183
|
+
)
|
184
|
+
).SerializeToString(deterministic=True)
|
185
|
+
return add_header_to_object_body(object_body=obj_body, cls=self)
|
186
|
+
|
187
|
+
@classmethod
|
188
|
+
def inflate(cls, object_content: bytes) -> ConfigRecord:
|
189
|
+
"""Inflate a ConfigRecord from bytes.
|
190
|
+
|
191
|
+
Parameters
|
192
|
+
----------
|
193
|
+
object_content : bytes
|
194
|
+
The deflated object content of the ConfigRecord.
|
195
|
+
|
196
|
+
Returns
|
197
|
+
-------
|
198
|
+
ConfigRecord
|
199
|
+
The inflated ConfigRecord.
|
200
|
+
"""
|
201
|
+
obj_body = get_object_body(object_content, cls)
|
202
|
+
config_record_proto = ProtoConfigRecord.FromString(obj_body)
|
203
|
+
|
204
|
+
return ConfigRecord(
|
205
|
+
config_dict=cast(
|
206
|
+
dict[str, ConfigRecordValues],
|
207
|
+
record_value_dict_from_proto(config_record_proto.data),
|
208
|
+
),
|
209
|
+
keep_input=False,
|
210
|
+
)
|
211
|
+
|
167
212
|
|
168
213
|
class ConfigsRecord(ConfigRecord):
|
169
214
|
"""Deprecated class ``ConfigsRecord``, use ``ConfigRecord`` instead.
|
@@ -195,7 +240,7 @@ class ConfigsRecord(ConfigRecord):
|
|
195
240
|
|
196
241
|
def __init__(
|
197
242
|
self,
|
198
|
-
config_dict:
|
243
|
+
config_dict: dict[str, ConfigRecordValues] | None = None,
|
199
244
|
keep_input: bool = True,
|
200
245
|
):
|
201
246
|
if not ConfigsRecord._warning_logged:
|
flwr/common/record/metricrecord.py
CHANGED
@@ -15,12 +15,21 @@
|
|
15
15
|
"""MetricRecord."""
|
16
16
|
|
17
17
|
|
18
|
+
from __future__ import annotations
|
19
|
+
|
18
20
|
from logging import WARN
|
19
|
-
from typing import
|
21
|
+
from typing import cast, get_args
|
20
22
|
|
21
23
|
from flwr.common.typing import MetricRecordValues, MetricScalar
|
22
24
|
|
25
|
+
# pylint: disable=E0611
|
26
|
+
from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
|
27
|
+
from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
|
28
|
+
|
29
|
+
# pylint: enable=E0611
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
23
31
|
from ..logger import log
|
32
|
+
from ..serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
24
33
|
from .typeddict import TypedDict
|
25
34
|
|
26
35
|
|
@@ -59,7 +68,7 @@ def _check_value(value: MetricRecordValues) -> None:
|
|
59
68
|
is_valid(value)
|
60
69
|
|
61
70
|
|
62
|
-
class MetricRecord(TypedDict[str, MetricRecordValues]):
|
71
|
+
class MetricRecord(TypedDict[str, MetricRecordValues], InflatableObject):
|
63
72
|
"""Metric record.
|
64
73
|
|
65
74
|
A :code:`MetricRecord` is a Python dictionary designed to ensure that
|
@@ -117,7 +126,7 @@ class MetricRecord(TypedDict[str, MetricRecordValues]):
|
|
117
126
|
|
118
127
|
def __init__(
|
119
128
|
self,
|
120
|
-
metric_dict:
|
129
|
+
metric_dict: dict[str, MetricRecordValues] | None = None,
|
121
130
|
keep_input: bool = True,
|
122
131
|
) -> None:
|
123
132
|
super().__init__(_check_key, _check_value)
|
@@ -143,6 +152,38 @@ class MetricRecord(TypedDict[str, MetricRecordValues]):
|
|
143
152
|
num_bytes += len(k)
|
144
153
|
return num_bytes
|
145
154
|
|
155
|
+
def deflate(self) -> bytes:
|
156
|
+
"""Deflate object."""
|
157
|
+
obj_body = ProtoMetricRecord(
|
158
|
+
data=record_value_dict_to_proto(self, [float, int], ProtoMetricRecordValue)
|
159
|
+
).SerializeToString(deterministic=True)
|
160
|
+
return add_header_to_object_body(object_body=obj_body, cls=self)
|
161
|
+
|
162
|
+
@classmethod
|
163
|
+
def inflate(cls, object_content: bytes) -> MetricRecord:
|
164
|
+
"""Inflate a MetricRecord from bytes.
|
165
|
+
|
166
|
+
Parameters
|
167
|
+
----------
|
168
|
+
object_content : bytes
|
169
|
+
The deflated object content of the MetricRecord.
|
170
|
+
|
171
|
+
Returns
|
172
|
+
-------
|
173
|
+
MetricRecord
|
174
|
+
The inflated MetricRecord.
|
175
|
+
"""
|
176
|
+
obj_body = get_object_body(object_content, cls)
|
177
|
+
metric_record_proto = ProtoMetricRecord.FromString(obj_body)
|
178
|
+
|
179
|
+
return cls(
|
180
|
+
metric_dict=cast(
|
181
|
+
dict[str, MetricRecordValues],
|
182
|
+
record_value_dict_from_proto(metric_record_proto.data),
|
183
|
+
),
|
184
|
+
keep_input=False,
|
185
|
+
)
|
186
|
+
|
146
187
|
|
147
188
|
class MetricsRecord(MetricRecord):
|
148
189
|
"""Deprecated class ``MetricsRecord``, use ``MetricRecord`` instead.
|
@@ -174,7 +215,7 @@ class MetricsRecord(MetricRecord):
|
|
174
215
|
|
175
216
|
def __init__(
|
176
217
|
self,
|
177
|
-
metric_dict:
|
218
|
+
metric_dict: dict[str, MetricRecordValues] | None = None,
|
178
219
|
keep_input: bool = True,
|
179
220
|
):
|
180
221
|
if not MetricsRecord._warning_logged:
|
flwr/common/serde.py
CHANGED
@@ -16,10 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from collections import OrderedDict
|
19
|
-
from
|
20
|
-
from typing import Any, TypeVar, cast
|
21
|
-
|
22
|
-
from google.protobuf.message import Message as GrpcMessage
|
19
|
+
from typing import Any, cast
|
23
20
|
|
24
21
|
# pylint: disable=E0611
|
25
22
|
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
@@ -30,14 +27,11 @@ from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
30
27
|
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
31
28
|
from flwr.proto.recorddict_pb2 import Array as ProtoArray
|
32
29
|
from flwr.proto.recorddict_pb2 import ArrayRecord as ProtoArrayRecord
|
33
|
-
from flwr.proto.recorddict_pb2 import BoolList, BytesList
|
34
30
|
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
35
31
|
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
|
36
|
-
from flwr.proto.recorddict_pb2 import DoubleList
|
37
32
|
from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
|
38
33
|
from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
|
39
34
|
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
40
|
-
from flwr.proto.recorddict_pb2 import SintList, StringList, UintList
|
41
35
|
from flwr.proto.run_pb2 import Run as ProtoRun
|
42
36
|
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
|
43
37
|
from flwr.proto.transport_pb2 import (
|
@@ -60,8 +54,9 @@ from . import (
|
|
60
54
|
RecordDict,
|
61
55
|
typing,
|
62
56
|
)
|
57
|
+
from .constant import INT64_MAX_VALUE
|
63
58
|
from .message import Error, Message, Metadata, make_message
|
64
|
-
from .
|
59
|
+
from .serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
65
60
|
|
66
61
|
# === Parameters message ===
|
67
62
|
|
@@ -339,7 +334,6 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
|
|
339
334
|
|
340
335
|
|
341
336
|
# === Scalar messages ===
|
342
|
-
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
343
337
|
|
344
338
|
|
345
339
|
def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
|
@@ -377,97 +371,6 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
|
|
377
371
|
# === Record messages ===
|
378
372
|
|
379
373
|
|
380
|
-
_type_to_field: dict[type, str] = {
|
381
|
-
float: "double",
|
382
|
-
int: "sint64",
|
383
|
-
bool: "bool",
|
384
|
-
str: "string",
|
385
|
-
bytes: "bytes",
|
386
|
-
}
|
387
|
-
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
|
388
|
-
float: (DoubleList, "double_list"),
|
389
|
-
int: (SintList, "sint_list"),
|
390
|
-
bool: (BoolList, "bool_list"),
|
391
|
-
str: (StringList, "string_list"),
|
392
|
-
bytes: (BytesList, "bytes_list"),
|
393
|
-
}
|
394
|
-
T = TypeVar("T")
|
395
|
-
|
396
|
-
|
397
|
-
def _is_uint64(value: Any) -> bool:
|
398
|
-
"""Check if a value is uint64."""
|
399
|
-
return isinstance(value, int) and value > INT64_MAX_VALUE
|
400
|
-
|
401
|
-
|
402
|
-
def _record_value_to_proto(
|
403
|
-
value: Any, allowed_types: list[type], proto_class: type[T]
|
404
|
-
) -> T:
|
405
|
-
"""Serialize `*RecordValue` to ProtoBuf.
|
406
|
-
|
407
|
-
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
408
|
-
"""
|
409
|
-
arg = {}
|
410
|
-
for t in allowed_types:
|
411
|
-
# Single element
|
412
|
-
# Note: `isinstance(False, int) == True`.
|
413
|
-
if isinstance(value, t):
|
414
|
-
fld = _type_to_field[t]
|
415
|
-
if t is int and _is_uint64(value):
|
416
|
-
fld = "uint64"
|
417
|
-
arg[fld] = value
|
418
|
-
return proto_class(**arg)
|
419
|
-
# List
|
420
|
-
if isinstance(value, list) and all(isinstance(item, t) for item in value):
|
421
|
-
list_class, fld = _list_type_to_class_and_field[t]
|
422
|
-
# Use UintList if any element is of type `uint64`.
|
423
|
-
if t is int and any(_is_uint64(v) for v in value):
|
424
|
-
list_class, fld = UintList, "uint_list"
|
425
|
-
arg[fld] = list_class(vals=value)
|
426
|
-
return proto_class(**arg)
|
427
|
-
# Invalid types
|
428
|
-
raise TypeError(
|
429
|
-
f"The type of the following value is not allowed "
|
430
|
-
f"in '{proto_class.__name__}':\n{value}"
|
431
|
-
)
|
432
|
-
|
433
|
-
|
434
|
-
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
|
435
|
-
"""Deserialize `*RecordValue` from ProtoBuf."""
|
436
|
-
value_field = cast(str, value_proto.WhichOneof("value"))
|
437
|
-
if value_field.endswith("list"):
|
438
|
-
value = list(getattr(value_proto, value_field).vals)
|
439
|
-
else:
|
440
|
-
value = getattr(value_proto, value_field)
|
441
|
-
return value
|
442
|
-
|
443
|
-
|
444
|
-
def _record_value_dict_to_proto(
|
445
|
-
value_dict: TypedDict[str, Any],
|
446
|
-
allowed_types: list[type],
|
447
|
-
value_proto_class: type[T],
|
448
|
-
) -> dict[str, T]:
|
449
|
-
"""Serialize the record value dict to ProtoBuf.
|
450
|
-
|
451
|
-
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
452
|
-
"""
|
453
|
-
# Move bool to the front
|
454
|
-
if bool in allowed_types and allowed_types[0] != bool:
|
455
|
-
allowed_types.remove(bool)
|
456
|
-
allowed_types.insert(0, bool)
|
457
|
-
|
458
|
-
def proto(_v: Any) -> T:
|
459
|
-
return _record_value_to_proto(_v, allowed_types, value_proto_class)
|
460
|
-
|
461
|
-
return {k: proto(v) for k, v in value_dict.items()}
|
462
|
-
|
463
|
-
|
464
|
-
def _record_value_dict_from_proto(
|
465
|
-
value_dict_proto: MutableMapping[str, Any]
|
466
|
-
) -> dict[str, Any]:
|
467
|
-
"""Deserialize the record value dict from ProtoBuf."""
|
468
|
-
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
469
|
-
|
470
|
-
|
471
374
|
def array_to_proto(array: Array) -> ProtoArray:
|
472
375
|
"""Serialize Array to ProtoBuf."""
|
473
376
|
return ProtoArray(**vars(array))
|
@@ -506,7 +409,7 @@ def array_record_from_proto(
|
|
506
409
|
def metric_record_to_proto(record: MetricRecord) -> ProtoMetricRecord:
|
507
410
|
"""Serialize MetricRecord to ProtoBuf."""
|
508
411
|
return ProtoMetricRecord(
|
509
|
-
data=
|
412
|
+
data=record_value_dict_to_proto(record, [float, int], ProtoMetricRecordValue)
|
510
413
|
)
|
511
414
|
|
512
415
|
|
@@ -515,7 +418,7 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
|
|
515
418
|
return MetricRecord(
|
516
419
|
metric_dict=cast(
|
517
420
|
dict[str, typing.MetricRecordValues],
|
518
|
-
|
421
|
+
record_value_dict_from_proto(record_proto.data),
|
519
422
|
),
|
520
423
|
keep_input=False,
|
521
424
|
)
|
@@ -524,7 +427,7 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
|
|
524
427
|
def config_record_to_proto(record: ConfigRecord) -> ProtoConfigRecord:
|
525
428
|
"""Serialize ConfigRecord to ProtoBuf."""
|
526
429
|
return ProtoConfigRecord(
|
527
|
-
data=
|
430
|
+
data=record_value_dict_to_proto(
|
528
431
|
record,
|
529
432
|
[bool, int, float, str, bytes],
|
530
433
|
ProtoConfigRecordValue,
|
@@ -537,7 +440,7 @@ def config_record_from_proto(record_proto: ProtoConfigRecord) -> ConfigRecord:
|
|
537
440
|
return ConfigRecord(
|
538
441
|
config_dict=cast(
|
539
442
|
dict[str, typing.ConfigRecordValues],
|
540
|
-
|
443
|
+
record_value_dict_from_proto(record_proto.data),
|
541
444
|
),
|
542
445
|
keep_input=False,
|
543
446
|
)
|
flwr/common/serde_utils.py
ADDED
@@ -0,0 +1,123 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Utils for serde."""
|
16
|
+
|
17
|
+
from collections.abc import MutableMapping
|
18
|
+
from typing import Any, TypeVar, cast
|
19
|
+
|
20
|
+
from google.protobuf.message import Message as GrpcMessage
|
21
|
+
|
22
|
+
# pylint: disable=E0611
|
23
|
+
from flwr.proto.recorddict_pb2 import (
|
24
|
+
BoolList,
|
25
|
+
BytesList,
|
26
|
+
DoubleList,
|
27
|
+
SintList,
|
28
|
+
StringList,
|
29
|
+
UintList,
|
30
|
+
)
|
31
|
+
|
32
|
+
from .constant import INT64_MAX_VALUE
|
33
|
+
from .record.typeddict import TypedDict
|
34
|
+
|
35
|
+
_type_to_field: dict[type, str] = {
|
36
|
+
float: "double",
|
37
|
+
int: "sint64",
|
38
|
+
bool: "bool",
|
39
|
+
str: "string",
|
40
|
+
bytes: "bytes",
|
41
|
+
}
|
42
|
+
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
|
43
|
+
float: (DoubleList, "double_list"),
|
44
|
+
int: (SintList, "sint_list"),
|
45
|
+
bool: (BoolList, "bool_list"),
|
46
|
+
str: (StringList, "string_list"),
|
47
|
+
bytes: (BytesList, "bytes_list"),
|
48
|
+
}
|
49
|
+
T = TypeVar("T")
|
50
|
+
|
51
|
+
|
52
|
+
def _is_uint64(value: Any) -> bool:
|
53
|
+
"""Check if a value is uint64."""
|
54
|
+
return isinstance(value, int) and value > INT64_MAX_VALUE
|
55
|
+
|
56
|
+
|
57
|
+
def _record_value_to_proto(
|
58
|
+
value: Any, allowed_types: list[type], proto_class: type[T]
|
59
|
+
) -> T:
|
60
|
+
"""Serialize `*RecordValue` to ProtoBuf.
|
61
|
+
|
62
|
+
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
63
|
+
"""
|
64
|
+
arg = {}
|
65
|
+
for t in allowed_types:
|
66
|
+
# Single element
|
67
|
+
# Note: `isinstance(False, int) == True`.
|
68
|
+
if isinstance(value, t):
|
69
|
+
fld = _type_to_field[t]
|
70
|
+
if t is int and _is_uint64(value):
|
71
|
+
fld = "uint64"
|
72
|
+
arg[fld] = value
|
73
|
+
return proto_class(**arg)
|
74
|
+
# List
|
75
|
+
if isinstance(value, list) and all(isinstance(item, t) for item in value):
|
76
|
+
list_class, fld = _list_type_to_class_and_field[t]
|
77
|
+
# Use UintList if any element is of type `uint64`.
|
78
|
+
if t is int and any(_is_uint64(v) for v in value):
|
79
|
+
list_class, fld = UintList, "uint_list"
|
80
|
+
arg[fld] = list_class(vals=value)
|
81
|
+
return proto_class(**arg)
|
82
|
+
# Invalid types
|
83
|
+
raise TypeError(
|
84
|
+
f"The type of the following value is not allowed "
|
85
|
+
f"in '{proto_class.__name__}':\n{value}"
|
86
|
+
)
|
87
|
+
|
88
|
+
|
89
|
+
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
|
90
|
+
"""Deserialize `*RecordValue` from ProtoBuf."""
|
91
|
+
value_field = cast(str, value_proto.WhichOneof("value"))
|
92
|
+
if value_field.endswith("list"):
|
93
|
+
value = list(getattr(value_proto, value_field).vals)
|
94
|
+
else:
|
95
|
+
value = getattr(value_proto, value_field)
|
96
|
+
return value
|
97
|
+
|
98
|
+
|
99
|
+
def record_value_dict_to_proto(
|
100
|
+
value_dict: TypedDict[str, Any],
|
101
|
+
allowed_types: list[type],
|
102
|
+
value_proto_class: type[T],
|
103
|
+
) -> dict[str, T]:
|
104
|
+
"""Serialize the record value dict to ProtoBuf.
|
105
|
+
|
106
|
+
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
107
|
+
"""
|
108
|
+
# Move bool to the front
|
109
|
+
if bool in allowed_types and allowed_types[0] != bool:
|
110
|
+
allowed_types.remove(bool)
|
111
|
+
allowed_types.insert(0, bool)
|
112
|
+
|
113
|
+
def proto(_v: Any) -> T:
|
114
|
+
return _record_value_to_proto(_v, allowed_types, value_proto_class)
|
115
|
+
|
116
|
+
return {k: proto(v) for k, v in value_dict.items()}
|
117
|
+
|
118
|
+
|
119
|
+
def record_value_dict_from_proto(
|
120
|
+
value_dict_proto: MutableMapping[str, Any]
|
121
|
+
) -> dict[str, Any]:
|
122
|
+
"""Deserialize the record value dict from ProtoBuf."""
|
123
|
+
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
{flwr_nightly-1.19.0.dev20250508.dist-info → flwr_nightly-1.19.0.dev20250510.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: flwr-nightly
|
3
|
-
Version: 1.19.0.dev20250508
|
3
|
+
Version: 1.19.0.dev20250510
|
4
4
|
Summary: Flower: A Friendly Federated AI Framework
|
5
5
|
License: Apache-2.0
|
6
6
|
Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
|
{flwr_nightly-1.19.0.dev20250508.dist-info → flwr_nightly-1.19.0.dev20250510.dist-info}/RECORD
RENAMED
@@ -115,7 +115,7 @@ flwr/common/args.py,sha256=-aX_jVnSaDrJR2KZ8Wq0Y3dQHII4R4MJtJOIXzVUA0c,5417
|
|
115
115
|
flwr/common/auth_plugin/__init__.py,sha256=m271m9YjK2QfKDOuIIhcTvGmv1GWh1PL97QB05NTSHs,887
|
116
116
|
flwr/common/auth_plugin/auth_plugin.py,sha256=GaXw4IiU2DkVNkp5S9ue821sbkU9zWSu6HSVZetEdjs,3938
|
117
117
|
flwr/common/config.py,sha256=glcZDjco-amw1YfQcYTFJ4S1pt9APoexT-mf1QscuHs,13960
|
118
|
-
flwr/common/constant.py,sha256=
|
118
|
+
flwr/common/constant.py,sha256=RmVW2YLGosdBzyePgj2EMdZnHrT1PKEtNScaNK_FHZ0,7302
|
119
119
|
flwr/common/context.py,sha256=Be8obQR_OvEDy1OmshuUKxGRQ7Qx89mf5F4xlhkR10s,2407
|
120
120
|
flwr/common/date.py,sha256=1ZT2cRSpC2DJqprOVTLXYCR_O2_OZR0zXO_brJ3LqWc,1554
|
121
121
|
flwr/common/differential_privacy.py,sha256=FdlpdpPl_H_2HJa8CQM1iCUGBBQ5Dc8CzxmHERM-EoE,6148
|
@@ -129,18 +129,18 @@ flwr/common/exit/exit_code.py,sha256=PNEnCrZfOILjfDAFu5m-2YWEJBrk97xglq4zCUlqV7E
|
|
129
129
|
flwr/common/exit_handlers.py,sha256=MEk5_savTLphn-6lW57UQlos-XrFA39XEBn-OF1vXXg,3174
|
130
130
|
flwr/common/grpc.py,sha256=manTaHaPiyYngUq1ErZvvV2B2GxlXUUUGRy3jc3TBIQ,9798
|
131
131
|
flwr/common/heartbeat.py,sha256=yzi-gWH5wswdg0hfQwxwGkjI5twxIHBBVW45MD5QITI,3924
|
132
|
-
flwr/common/inflatable.py,sha256=
|
132
|
+
flwr/common/inflatable.py,sha256=yCfnRYj4xeUqV2m-K5hcQPeVhL7gdSGw7CewPYKnjnE,3156
|
133
133
|
flwr/common/logger.py,sha256=JbRf6E2vQxXzpDBq1T8IDUJo_usu3gjWEBPQ6uKcmdg,13049
|
134
134
|
flwr/common/message.py,sha256=znr205Erq2hkxwFbvNNCsQTRS2UKv_Qsyu0sFNEhEAw,23721
|
135
135
|
flwr/common/object_ref.py,sha256=p3SfTeqo3Aj16SkB-vsnNn01zswOPdGNBitcbRnqmUk,9134
|
136
136
|
flwr/common/parameter.py,sha256=UVw6sOgehEFhFs4uUCMl2kfVq1PD6ncmWgPLMsZPKPE,2095
|
137
137
|
flwr/common/pyproject.py,sha256=2SU6yJW7059SbMXgzjOdK1GZRWO6AixDH7BmdxbMvHI,1386
|
138
138
|
flwr/common/record/__init__.py,sha256=cNGccdDoxttqgnUgyKRIqLWULjW-NaSmOufVxtXq-sw,1197
|
139
|
-
flwr/common/record/array.py,sha256=
|
139
|
+
flwr/common/record/array.py,sha256=tPTT6cw7B1Fo626LOVaA_sfj2_EtkxdnvSkRTyPrVRY,10469
|
140
140
|
flwr/common/record/arrayrecord.py,sha256=KbehV2yXJ_6ZWcHPPrC-MNkE00DRCObxgyrVLwBQ5OY,14389
|
141
|
-
flwr/common/record/configrecord.py,sha256=
|
141
|
+
flwr/common/record/configrecord.py,sha256=lXVGjNfQD3lqvQTstGPFfQjeEHl29alfoL9trCKKlY4,9269
|
142
142
|
flwr/common/record/conversion_utils.py,sha256=wbNCzy7oAqaA3-arhls_EqRZYXRC4YrWIoE-Gy82fJ0,1191
|
143
|
-
flwr/common/record/metricrecord.py,sha256=
|
143
|
+
flwr/common/record/metricrecord.py,sha256=MRMv0fSmJvHlg0HtX_s4IBqxHAh8QHgFN75CmR6fBOU,8444
|
144
144
|
flwr/common/record/recorddict.py,sha256=zo7TiVZCH_LB9gwUP7-Jo-jLpFLrvxYSryovwZANQiw,12386
|
145
145
|
flwr/common/record/typeddict.py,sha256=dDKgUThs2BscYUNcgP82KP8-qfAYXYftDrf2LszAC_o,3599
|
146
146
|
flwr/common/recorddict_compat.py,sha256=Znn1xRGiqLpPPgviVqyb-GPTM-pCK6tpnEmhWSXafy8,14119
|
@@ -153,7 +153,8 @@ flwr/common/secure_aggregation/ndarrays_arithmetic.py,sha256=TrggOlizlny3V2KS7-3
|
|
153
153
|
flwr/common/secure_aggregation/quantization.py,sha256=ssFZpiRyj9ltIh0Ai3vGkDqWFO4SoqgoD1mDU9XqMEM,2400
|
154
154
|
flwr/common/secure_aggregation/secaggplus_constants.py,sha256=dGYhWOBMMDJcQH4_tQNC8-Efqm-ecEUNN9ANz59UnCk,2182
|
155
155
|
flwr/common/secure_aggregation/secaggplus_utils.py,sha256=E_xU-Zd45daO1em7M6C2wOjFXVtJf-6tl7fp-7xq1wo,3214
|
156
|
-
flwr/common/serde.py,sha256=
|
156
|
+
flwr/common/serde.py,sha256=B9PpDh_f3sO0okztebD_6bgX1dK-pJtDr6CNtZ-gJIQ,23910
|
157
|
+
flwr/common/serde_utils.py,sha256=ofmrgVHRBfrE1MtQwLQk0x12JS9vL-u8wHXrgZE2ueg,3985
|
157
158
|
flwr/common/telemetry.py,sha256=jF47v0SbnBd43XamHtl3wKxs3knFUY2p77cm_2lzZ8M,8762
|
158
159
|
flwr/common/typing.py,sha256=97QRfRRS7sQnjkAI5FDZ01-38oQUSz4i1qqewQmBWRg,6886
|
159
160
|
flwr/common/version.py,sha256=7GAGzPn73Mkh09qhrjbmjZQtcqVhBuzhFBaK4Mk4VRk,1325
|
@@ -332,7 +333,7 @@ flwr/superexec/exec_servicer.py,sha256=Z0YYfs6eNPhqn8rY0x_R04XgR2mKFpggt07IH0EhU
|
|
332
333
|
flwr/superexec/exec_user_auth_interceptor.py,sha256=iqygALkOMBUu_s_R9G0mFThZA7HTUzuXCLgxLCefiwI,4440
|
333
334
|
flwr/superexec/executor.py,sha256=M5ucqSE53jfRtuCNf59WFLqQvA1Mln4741TySeZE7qQ,3112
|
334
335
|
flwr/superexec/simulation.py,sha256=j6YwUvBN7EQ09ID7MYOCVZ70PGbuyBy8f9bXU0EszEM,4088
|
335
|
-
flwr_nightly-1.19.0.
|
336
|
-
flwr_nightly-1.19.0.dev20250508.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
337
|
-
flwr_nightly-1.19.0.dev20250508.dist-info/entry_points.txt,sha256=2-1L-GNKhwGw2_7_RoH55vHw2SIHjdAQy3HAVAWl9PY,374
|
338
|
-
flwr_nightly-1.19.0.dev20250508.dist-info/RECORD,,
|
336
|
+
flwr_nightly-1.19.0.dev20250510.dist-info/METADATA,sha256=xpLHVo5qUcXHQZ8I7pPAqABpFf4KBBoRQxccdZ8GmzY,15880
|
337
|
+
flwr_nightly-1.19.0.dev20250510.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
338
|
+
flwr_nightly-1.19.0.dev20250510.dist-info/entry_points.txt,sha256=2-1L-GNKhwGw2_7_RoH55vHw2SIHjdAQy3HAVAWl9PY,374
|
339
|
+
flwr_nightly-1.19.0.dev20250510.dist-info/RECORD,,
|
{flwr_nightly-1.19.0.dev20250508.dist-info → flwr_nightly-1.19.0.dev20250510.dist-info}/WHEEL
RENAMED
File without changes
|
File without changes
|