flwr-nightly 1.19.0.dev20250507__py3-none-any.whl → 1.19.0.dev20250509__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/common/constant.py +7 -0
- flwr/common/inflatable.py +96 -0
- flwr/common/record/array.py +39 -1
- flwr/common/record/configrecord.py +49 -4
- flwr/common/record/metricrecord.py +45 -4
- flwr/common/serde.py +7 -104
- flwr/common/serde_utils.py +123 -0
- {flwr_nightly-1.19.0.dev20250507.dist-info → flwr_nightly-1.19.0.dev20250509.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250507.dist-info → flwr_nightly-1.19.0.dev20250509.dist-info}/RECORD +11 -9
- {flwr_nightly-1.19.0.dev20250507.dist-info → flwr_nightly-1.19.0.dev20250509.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250507.dist-info → flwr_nightly-1.19.0.dev20250509.dist-info}/entry_points.txt +0 -0
flwr/common/constant.py
CHANGED
@@ -127,6 +127,13 @@ GRPC_RETRY_MAX_DELAY = 20 # Maximum delay duration between two consecutive retr
|
|
127
127
|
# Constants for ArrayRecord
|
128
128
|
GC_THRESHOLD = 200_000_000 # 200 MB
|
129
129
|
|
130
|
+
# Constants for Inflatable
HEAD_BODY_DIVIDER = b"\x00"  # Byte placed between an object's header and its body
TYPE_BODY_LEN_DIVIDER = " "  # Separates the object type from the body length in a header

# Constants for serialization
INT64_MAX_VALUE = (1 << 63) - 1  # == 9223372036854775807, largest signed 64-bit int
|
136
|
+
|
130
137
|
|
131
138
|
class MessageType:
|
132
139
|
"""Message type."""
|
@@ -0,0 +1,96 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""InflatableObject base class."""
|
16
|
+
|
17
|
+
|
18
|
+
import hashlib
|
19
|
+
from typing import TypeVar
|
20
|
+
|
21
|
+
from .constant import HEAD_BODY_DIVIDER, TYPE_BODY_LEN_DIVIDER
|
22
|
+
|
23
|
+
T = TypeVar("T", bound="InflatableObject")


class InflatableObject:
    """Base class for objects that can be deflated to, and inflated from, bytes.

    Subclasses are expected to override both :meth:`deflate` and
    :meth:`inflate`; the base implementations only raise
    ``NotImplementedError``.
    """

    def deflate(self) -> bytes:
        """Deflate object.

        Returns
        -------
        bytes
            The deflated object content.
        """
        raise NotImplementedError()

    @classmethod
    def inflate(cls, object_content: bytes) -> "InflatableObject":
        """Inflate the object from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content.

        Returns
        -------
        InflatableObject
            The inflated object.
        """
        raise NotImplementedError()

    @property
    def object_id(self) -> str:
        """Get object_id.

        The ID is derived from the deflated content, so the object is
        deflated again on every access.
        """
        content = self.deflate()
        return get_object_id(content)
|
53
|
+
|
54
|
+
|
55
|
+
def get_object_id(object_content: bytes) -> str:
    """Return the object ID for the given (deflated) object content.

    The ID is the hexadecimal SHA-256 digest of ``object_content``.
    """
    hasher = hashlib.sha256()
    hasher.update(object_content)
    return hasher.hexdigest()
|
58
|
+
|
59
|
+
|
60
|
+
def get_object_body(object_content: bytes, cls: type[T]) -> bytes:
    """Return the body of ``object_content`` after checking its declared type.

    Parameters
    ----------
    object_content : bytes
        Deflated object content (header, divider, body).
    cls : type[T]
        The class the content is expected to belong to. Its ``__qualname__``
        must equal the object type recorded in the content's header.

    Returns
    -------
    bytes
        Everything after the first header/body divider.

    Raises
    ------
    ValueError
        If the object type in the header differs from ``cls.__qualname__``.
    """
    expected = cls.__qualname__
    actual = get_object_type_from_object_content(object_content)
    if actual != expected:
        raise ValueError(
            f"Class name ({expected}) and object type "
            f"({actual}) do not match."
        )
    return _get_object_body(object_content)
|
72
|
+
|
73
|
+
|
74
|
+
def add_header_to_object_body(object_body: bytes, cls: T) -> bytes:
    """Prefix ``object_body`` with a header describing it.

    The result is ``<type><TYPE_BODY_LEN_DIVIDER><len>`` (UTF-8 encoded),
    followed by ``HEAD_BODY_DIVIDER``, followed by ``object_body``. Here
    ``<type>`` is the ``__qualname__`` of the class of ``cls`` and ``<len>``
    is ``len(object_body)`` in decimal.

    Parameters
    ----------
    object_body : bytes
        The serialized object body.
    cls : T
        The object being deflated. Despite its name this is an *instance*;
        the name of its class is written into the header.

    Returns
    -------
    bytes
        Header, divider and body concatenated.
    """
    type_name = cls.__class__.__qualname__
    body_len = str(len(object_body))
    header = TYPE_BODY_LEN_DIVIDER.join((type_name, body_len)).encode(encoding="utf-8")
    return b"".join((header, HEAD_BODY_DIVIDER, object_body))
|
81
|
+
|
82
|
+
|
83
|
+
def _get_object_head(object_content: bytes) -> bytes:
    """Return the header of ``object_content``.

    The header is everything before the first ``HEAD_BODY_DIVIDER``; if the
    divider is absent, the whole content is returned.
    """
    head, _, _ = object_content.partition(HEAD_BODY_DIVIDER)
    return head
|
86
|
+
|
87
|
+
|
88
|
+
def _get_object_body(object_content: bytes) -> bytes:
    """Return the body of ``object_content``.

    The body is everything after the first ``HEAD_BODY_DIVIDER``. If the
    divider is not present, indexing the split result raises ``IndexError``.
    """
    pieces = object_content.split(HEAD_BODY_DIVIDER, maxsplit=1)
    return pieces[1]
|
91
|
+
|
92
|
+
|
93
|
+
def get_object_type_from_object_content(object_content: bytes) -> str:
    """Return the object type recorded in the header of ``object_content``.

    The header is decoded as UTF-8 and the type is the text before the first
    ``TYPE_BODY_LEN_DIVIDER`` (or the whole header if the divider is absent).
    """
    header = _get_object_head(object_content).decode(encoding="utf-8")
    object_type, _, _ = header.partition(TYPE_BODY_LEN_DIVIDER)
    return object_type
|
flwr/common/record/array.py
CHANGED
@@ -24,7 +24,10 @@ from typing import TYPE_CHECKING, Any, cast, overload
|
|
24
24
|
|
25
25
|
import numpy as np
|
26
26
|
|
27
|
+
from flwr.proto.recorddict_pb2 import Array as ArrayProto # pylint: disable=E0611
|
28
|
+
|
27
29
|
from ..constant import SType
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
28
31
|
from ..typing import NDArray
|
29
32
|
|
30
33
|
if TYPE_CHECKING:
|
@@ -40,7 +43,7 @@ def _raise_array_init_error() -> None:
|
|
40
43
|
|
41
44
|
|
42
45
|
@dataclass
|
43
|
-
class Array:
|
46
|
+
class Array(InflatableObject):
|
44
47
|
"""Array type.
|
45
48
|
|
46
49
|
A dataclass containing serialized data from an array-like or tensor-like object
|
@@ -248,3 +251,38 @@ class Array:
|
|
248
251
|
# Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
|
249
252
|
ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
|
250
253
|
return cast(NDArray, ndarray_deserialized)
|
254
|
+
|
255
|
+
def deflate(self) -> bytes:
    """Deflate the Array.

    The Array's fields are packed into an ``ArrayProto`` message, serialized
    deterministically, and prefixed with the object header.

    Returns
    -------
    bytes
        The deflated object content.
    """
    serialized = ArrayProto(
        dtype=self.dtype,
        shape=self.shape,
        stype=self.stype,
        data=self.data,
    ).SerializeToString(deterministic=True)
    return add_header_to_object_body(object_body=serialized, cls=self)
|
266
|
+
|
267
|
+
@classmethod
def inflate(cls, object_content: bytes) -> Array:
    """Inflate an Array from bytes.

    Parameters
    ----------
    object_content : bytes
        The deflated object content of the Array.

    Returns
    -------
    Array
        The inflated Array.

    Raises
    ------
    ValueError
        If the object type in ``object_content`` does not match ``cls``.
    """
    message = ArrayProto.FromString(get_object_body(object_content, cls))
    return cls(
        dtype=message.dtype,
        shape=list(message.shape),
        stype=message.stype,
        data=message.data,
    )
|
@@ -15,12 +15,21 @@
|
|
15
15
|
"""ConfigRecord."""
|
16
16
|
|
17
17
|
|
18
|
+
from __future__ import annotations
|
19
|
+
|
18
20
|
from logging import WARN
|
19
|
-
from typing import
|
21
|
+
from typing import cast, get_args
|
20
22
|
|
21
23
|
from flwr.common.typing import ConfigRecordValues, ConfigScalar
|
22
24
|
|
25
|
+
# pylint: disable=E0611
|
26
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
27
|
+
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
|
28
|
+
|
29
|
+
# pylint: enable=E0611
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
23
31
|
from ..logger import log
|
32
|
+
from ..serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
24
33
|
from .typeddict import TypedDict
|
25
34
|
|
26
35
|
|
@@ -59,7 +68,7 @@ def _check_value(value: ConfigRecordValues) -> None:
|
|
59
68
|
is_valid(value)
|
60
69
|
|
61
70
|
|
62
|
-
class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
71
|
+
class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
|
63
72
|
"""Config record.
|
64
73
|
|
65
74
|
A :code:`ConfigRecord` is a Python dictionary designed to ensure that
|
@@ -111,7 +120,7 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
111
120
|
|
112
121
|
def __init__(
|
113
122
|
self,
|
114
|
-
config_dict:
|
123
|
+
config_dict: dict[str, ConfigRecordValues] | None = None,
|
115
124
|
keep_input: bool = True,
|
116
125
|
) -> None:
|
117
126
|
|
@@ -164,6 +173,42 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
164
173
|
|
165
174
|
return num_bytes
|
166
175
|
|
176
|
+
def deflate(self) -> bytes:
    """Deflate the ConfigRecord into bytes.

    Each value is converted to a ``ProtoConfigRecordValue``. Allowed value
    types are bool, int, float, str, bytes, and lists of those. The resulting
    ``ProtoConfigRecord`` is serialized deterministically and the object
    header is prepended.
    """
    proto_values = record_value_dict_to_proto(
        self,
        [bool, int, float, str, bytes],
        ProtoConfigRecordValue,
    )
    serialized = ProtoConfigRecord(data=proto_values).SerializeToString(
        deterministic=True
    )
    return add_header_to_object_body(object_body=serialized, cls=self)
|
186
|
+
|
187
|
+
@classmethod
def inflate(cls, object_content: bytes) -> ConfigRecord:
    """Inflate a ConfigRecord from bytes.

    Parameters
    ----------
    object_content : bytes
        The deflated object content of the ConfigRecord.

    Returns
    -------
    ConfigRecord
        The inflated record, an instance of ``cls`` (so a subclass inflates
        to an instance of itself).

    Raises
    ------
    ValueError
        If the object type in ``object_content`` does not match ``cls``.
    """
    obj_body = get_object_body(object_content, cls)
    config_record_proto = ProtoConfigRecord.FromString(obj_body)

    # Build `cls`, not a hard-coded `ConfigRecord`: `get_object_body` has
    # already verified the content was deflated from `cls`, so returning the
    # base class would hand subclasses (e.g. `ConfigsRecord`) the wrong type.
    # This also matches `MetricRecord.inflate` and `Array.inflate`.
    return cls(
        config_dict=cast(
            dict[str, ConfigRecordValues],
            record_value_dict_from_proto(config_record_proto.data),
        ),
        keep_input=False,
    )
|
211
|
+
|
167
212
|
|
168
213
|
class ConfigsRecord(ConfigRecord):
|
169
214
|
"""Deprecated class ``ConfigsRecord``, use ``ConfigRecord`` instead.
|
@@ -195,7 +240,7 @@ class ConfigsRecord(ConfigRecord):
|
|
195
240
|
|
196
241
|
def __init__(
|
197
242
|
self,
|
198
|
-
config_dict:
|
243
|
+
config_dict: dict[str, ConfigRecordValues] | None = None,
|
199
244
|
keep_input: bool = True,
|
200
245
|
):
|
201
246
|
if not ConfigsRecord._warning_logged:
|
@@ -15,12 +15,21 @@
|
|
15
15
|
"""MetricRecord."""
|
16
16
|
|
17
17
|
|
18
|
+
from __future__ import annotations
|
19
|
+
|
18
20
|
from logging import WARN
|
19
|
-
from typing import
|
21
|
+
from typing import cast, get_args
|
20
22
|
|
21
23
|
from flwr.common.typing import MetricRecordValues, MetricScalar
|
22
24
|
|
25
|
+
# pylint: disable=E0611
|
26
|
+
from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
|
27
|
+
from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
|
28
|
+
|
29
|
+
# pylint: enable=E0611
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
23
31
|
from ..logger import log
|
32
|
+
from ..serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
24
33
|
from .typeddict import TypedDict
|
25
34
|
|
26
35
|
|
@@ -59,7 +68,7 @@ def _check_value(value: MetricRecordValues) -> None:
|
|
59
68
|
is_valid(value)
|
60
69
|
|
61
70
|
|
62
|
-
class MetricRecord(TypedDict[str, MetricRecordValues]):
|
71
|
+
class MetricRecord(TypedDict[str, MetricRecordValues], InflatableObject):
|
63
72
|
"""Metric record.
|
64
73
|
|
65
74
|
A :code:`MetricRecord` is a Python dictionary designed to ensure that
|
@@ -117,7 +126,7 @@ class MetricRecord(TypedDict[str, MetricRecordValues]):
|
|
117
126
|
|
118
127
|
def __init__(
|
119
128
|
self,
|
120
|
-
metric_dict:
|
129
|
+
metric_dict: dict[str, MetricRecordValues] | None = None,
|
121
130
|
keep_input: bool = True,
|
122
131
|
) -> None:
|
123
132
|
super().__init__(_check_key, _check_value)
|
@@ -143,6 +152,38 @@ class MetricRecord(TypedDict[str, MetricRecordValues]):
|
|
143
152
|
num_bytes += len(k)
|
144
153
|
return num_bytes
|
145
154
|
|
155
|
+
def deflate(self) -> bytes:
    """Deflate the MetricRecord into bytes.

    Values are converted to ``ProtoMetricRecordValue`` (allowed types: float
    and int, or lists of them). The ``ProtoMetricRecord`` is serialized
    deterministically and the object header is prepended.
    """
    proto_values = record_value_dict_to_proto(
        self, [float, int], ProtoMetricRecordValue
    )
    record_proto = ProtoMetricRecord(data=proto_values)
    return add_header_to_object_body(
        object_body=record_proto.SerializeToString(deterministic=True),
        cls=self,
    )
|
161
|
+
|
162
|
+
@classmethod
def inflate(cls, object_content: bytes) -> MetricRecord:
    """Inflate a MetricRecord from bytes.

    Parameters
    ----------
    object_content : bytes
        The deflated object content of the MetricRecord.

    Returns
    -------
    MetricRecord
        The inflated MetricRecord (an instance of ``cls``).

    Raises
    ------
    ValueError
        If the object type in ``object_content`` does not match ``cls``.
    """
    record_proto = ProtoMetricRecord.FromString(
        get_object_body(object_content, cls)
    )
    metric_dict = cast(
        dict[str, MetricRecordValues],
        record_value_dict_from_proto(record_proto.data),
    )
    return cls(metric_dict=metric_dict, keep_input=False)
|
186
|
+
|
146
187
|
|
147
188
|
class MetricsRecord(MetricRecord):
|
148
189
|
"""Deprecated class ``MetricsRecord``, use ``MetricRecord`` instead.
|
@@ -174,7 +215,7 @@ class MetricsRecord(MetricRecord):
|
|
174
215
|
|
175
216
|
def __init__(
|
176
217
|
self,
|
177
|
-
metric_dict:
|
218
|
+
metric_dict: dict[str, MetricRecordValues] | None = None,
|
178
219
|
keep_input: bool = True,
|
179
220
|
):
|
180
221
|
if not MetricsRecord._warning_logged:
|
flwr/common/serde.py
CHANGED
@@ -16,10 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from collections import OrderedDict
|
19
|
-
from
|
20
|
-
from typing import Any, TypeVar, cast
|
21
|
-
|
22
|
-
from google.protobuf.message import Message as GrpcMessage
|
19
|
+
from typing import Any, cast
|
23
20
|
|
24
21
|
# pylint: disable=E0611
|
25
22
|
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
@@ -30,14 +27,11 @@ from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
30
27
|
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
31
28
|
from flwr.proto.recorddict_pb2 import Array as ProtoArray
|
32
29
|
from flwr.proto.recorddict_pb2 import ArrayRecord as ProtoArrayRecord
|
33
|
-
from flwr.proto.recorddict_pb2 import BoolList, BytesList
|
34
30
|
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
35
31
|
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
|
36
|
-
from flwr.proto.recorddict_pb2 import DoubleList
|
37
32
|
from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
|
38
33
|
from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
|
39
34
|
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
40
|
-
from flwr.proto.recorddict_pb2 import SintList, StringList, UintList
|
41
35
|
from flwr.proto.run_pb2 import Run as ProtoRun
|
42
36
|
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
|
43
37
|
from flwr.proto.transport_pb2 import (
|
@@ -60,8 +54,9 @@ from . import (
|
|
60
54
|
RecordDict,
|
61
55
|
typing,
|
62
56
|
)
|
57
|
+
from .constant import INT64_MAX_VALUE
|
63
58
|
from .message import Error, Message, Metadata, make_message
|
64
|
-
from .
|
59
|
+
from .serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
65
60
|
|
66
61
|
# === Parameters message ===
|
67
62
|
|
@@ -339,7 +334,6 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
|
|
339
334
|
|
340
335
|
|
341
336
|
# === Scalar messages ===
|
342
|
-
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
343
337
|
|
344
338
|
|
345
339
|
def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
|
@@ -377,97 +371,6 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
|
|
377
371
|
# === Record messages ===
|
378
372
|
|
379
373
|
|
380
|
-
_type_to_field: dict[type, str] = {
|
381
|
-
float: "double",
|
382
|
-
int: "sint64",
|
383
|
-
bool: "bool",
|
384
|
-
str: "string",
|
385
|
-
bytes: "bytes",
|
386
|
-
}
|
387
|
-
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
|
388
|
-
float: (DoubleList, "double_list"),
|
389
|
-
int: (SintList, "sint_list"),
|
390
|
-
bool: (BoolList, "bool_list"),
|
391
|
-
str: (StringList, "string_list"),
|
392
|
-
bytes: (BytesList, "bytes_list"),
|
393
|
-
}
|
394
|
-
T = TypeVar("T")
|
395
|
-
|
396
|
-
|
397
|
-
def _is_uint64(value: Any) -> bool:
|
398
|
-
"""Check if a value is uint64."""
|
399
|
-
return isinstance(value, int) and value > INT64_MAX_VALUE
|
400
|
-
|
401
|
-
|
402
|
-
def _record_value_to_proto(
|
403
|
-
value: Any, allowed_types: list[type], proto_class: type[T]
|
404
|
-
) -> T:
|
405
|
-
"""Serialize `*RecordValue` to ProtoBuf.
|
406
|
-
|
407
|
-
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
408
|
-
"""
|
409
|
-
arg = {}
|
410
|
-
for t in allowed_types:
|
411
|
-
# Single element
|
412
|
-
# Note: `isinstance(False, int) == True`.
|
413
|
-
if isinstance(value, t):
|
414
|
-
fld = _type_to_field[t]
|
415
|
-
if t is int and _is_uint64(value):
|
416
|
-
fld = "uint64"
|
417
|
-
arg[fld] = value
|
418
|
-
return proto_class(**arg)
|
419
|
-
# List
|
420
|
-
if isinstance(value, list) and all(isinstance(item, t) for item in value):
|
421
|
-
list_class, fld = _list_type_to_class_and_field[t]
|
422
|
-
# Use UintList if any element is of type `uint64`.
|
423
|
-
if t is int and any(_is_uint64(v) for v in value):
|
424
|
-
list_class, fld = UintList, "uint_list"
|
425
|
-
arg[fld] = list_class(vals=value)
|
426
|
-
return proto_class(**arg)
|
427
|
-
# Invalid types
|
428
|
-
raise TypeError(
|
429
|
-
f"The type of the following value is not allowed "
|
430
|
-
f"in '{proto_class.__name__}':\n{value}"
|
431
|
-
)
|
432
|
-
|
433
|
-
|
434
|
-
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
|
435
|
-
"""Deserialize `*RecordValue` from ProtoBuf."""
|
436
|
-
value_field = cast(str, value_proto.WhichOneof("value"))
|
437
|
-
if value_field.endswith("list"):
|
438
|
-
value = list(getattr(value_proto, value_field).vals)
|
439
|
-
else:
|
440
|
-
value = getattr(value_proto, value_field)
|
441
|
-
return value
|
442
|
-
|
443
|
-
|
444
|
-
def _record_value_dict_to_proto(
|
445
|
-
value_dict: TypedDict[str, Any],
|
446
|
-
allowed_types: list[type],
|
447
|
-
value_proto_class: type[T],
|
448
|
-
) -> dict[str, T]:
|
449
|
-
"""Serialize the record value dict to ProtoBuf.
|
450
|
-
|
451
|
-
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
452
|
-
"""
|
453
|
-
# Move bool to the front
|
454
|
-
if bool in allowed_types and allowed_types[0] != bool:
|
455
|
-
allowed_types.remove(bool)
|
456
|
-
allowed_types.insert(0, bool)
|
457
|
-
|
458
|
-
def proto(_v: Any) -> T:
|
459
|
-
return _record_value_to_proto(_v, allowed_types, value_proto_class)
|
460
|
-
|
461
|
-
return {k: proto(v) for k, v in value_dict.items()}
|
462
|
-
|
463
|
-
|
464
|
-
def _record_value_dict_from_proto(
|
465
|
-
value_dict_proto: MutableMapping[str, Any]
|
466
|
-
) -> dict[str, Any]:
|
467
|
-
"""Deserialize the record value dict from ProtoBuf."""
|
468
|
-
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
469
|
-
|
470
|
-
|
471
374
|
def array_to_proto(array: Array) -> ProtoArray:
|
472
375
|
"""Serialize Array to ProtoBuf."""
|
473
376
|
return ProtoArray(**vars(array))
|
@@ -506,7 +409,7 @@ def array_record_from_proto(
|
|
506
409
|
def metric_record_to_proto(record: MetricRecord) -> ProtoMetricRecord:
|
507
410
|
"""Serialize MetricRecord to ProtoBuf."""
|
508
411
|
return ProtoMetricRecord(
|
509
|
-
data=
|
412
|
+
data=record_value_dict_to_proto(record, [float, int], ProtoMetricRecordValue)
|
510
413
|
)
|
511
414
|
|
512
415
|
|
@@ -515,7 +418,7 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
|
|
515
418
|
return MetricRecord(
|
516
419
|
metric_dict=cast(
|
517
420
|
dict[str, typing.MetricRecordValues],
|
518
|
-
|
421
|
+
record_value_dict_from_proto(record_proto.data),
|
519
422
|
),
|
520
423
|
keep_input=False,
|
521
424
|
)
|
@@ -524,7 +427,7 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
|
|
524
427
|
def config_record_to_proto(record: ConfigRecord) -> ProtoConfigRecord:
|
525
428
|
"""Serialize ConfigRecord to ProtoBuf."""
|
526
429
|
return ProtoConfigRecord(
|
527
|
-
data=
|
430
|
+
data=record_value_dict_to_proto(
|
528
431
|
record,
|
529
432
|
[bool, int, float, str, bytes],
|
530
433
|
ProtoConfigRecordValue,
|
@@ -537,7 +440,7 @@ def config_record_from_proto(record_proto: ProtoConfigRecord) -> ConfigRecord:
|
|
537
440
|
return ConfigRecord(
|
538
441
|
config_dict=cast(
|
539
442
|
dict[str, typing.ConfigRecordValues],
|
540
|
-
|
443
|
+
record_value_dict_from_proto(record_proto.data),
|
541
444
|
),
|
542
445
|
keep_input=False,
|
543
446
|
)
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Utils for serde."""
|
16
|
+
|
17
|
+
from collections.abc import MutableMapping
|
18
|
+
from typing import Any, TypeVar, cast
|
19
|
+
|
20
|
+
from google.protobuf.message import Message as GrpcMessage
|
21
|
+
|
22
|
+
# pylint: disable=E0611
|
23
|
+
from flwr.proto.recorddict_pb2 import (
|
24
|
+
BoolList,
|
25
|
+
BytesList,
|
26
|
+
DoubleList,
|
27
|
+
SintList,
|
28
|
+
StringList,
|
29
|
+
UintList,
|
30
|
+
)
|
31
|
+
|
32
|
+
from .constant import INT64_MAX_VALUE
|
33
|
+
from .record.typeddict import TypedDict
|
34
|
+
|
35
|
+
# Python scalar type -> name of the oneof field that stores it in a
# `*RecordValue` message. (Ints above INT64_MAX_VALUE use "uint64" instead.)
_type_to_field: dict[type, str] = {
    float: "double",
    int: "sint64",
    bool: "bool",
    str: "string",
    bytes: "bytes",
}

# Python element type -> (list message class, oneof field name) used when the
# record value is a list. (Int lists with any uint64 element use UintList.)
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
    float: (DoubleList, "double_list"),
    int: (SintList, "sint_list"),
    bool: (BoolList, "bool_list"),
    str: (StringList, "string_list"),
    bytes: (BytesList, "bytes_list"),
}

# Generic ProtoBuf value-message type used by the helpers below.
T = TypeVar("T")
|
50
|
+
|
51
|
+
|
52
|
+
def _is_uint64(value: Any) -> bool:
    """Return True if ``value`` is an int too large for a signed 64-bit field.

    Such values exceed ``INT64_MAX_VALUE`` and are serialized via the uint64
    field/list instead of sint64.
    """
    if not isinstance(value, int):
        return False
    return value > INT64_MAX_VALUE
|
55
|
+
|
56
|
+
|
57
|
+
def _record_value_to_proto(
    value: Any, allowed_types: list[type], proto_class: type[T]
) -> T:
    """Serialize a single `*RecordValue` to ProtoBuf.

    ``allowed_types`` is scanned in order. The first type that matches
    ``value`` (as a scalar, or as a list whose items are all of that type)
    decides which oneof field of ``proto_class`` is populated. An empty list
    therefore matches the first allowed type.

    Note: `bool` MUST be put in the front of ``allowed_types`` if it exists,
    because ``isinstance(False, int)`` is ``True``.

    Raises
    ------
    TypeError
        If ``value`` matches none of ``allowed_types``.
    """
    for allowed in allowed_types:
        # Scalar value
        if isinstance(value, allowed):
            field = _type_to_field[allowed]
            if allowed is int and _is_uint64(value):
                field = "uint64"
            return proto_class(**{field: value})
        # List value: every item must be of the same allowed type
        if isinstance(value, list) and all(isinstance(v, allowed) for v in value):
            list_cls, field = _list_type_to_class_and_field[allowed]
            # Any uint64-sized element forces the whole list into UintList
            if allowed is int and any(_is_uint64(v) for v in value):
                list_cls, field = UintList, "uint_list"
            return proto_class(**{field: list_cls(vals=value)})
    raise TypeError(
        "The type of the following value is not allowed "
        f"in '{proto_class.__name__}':\n{value}"
    )
|
87
|
+
|
88
|
+
|
89
|
+
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
    """Deserialize a single `*RecordValue` from ProtoBuf.

    Reads whichever field of the ``value`` oneof is set. Fields whose name ends
    in ``list`` hold list messages; their ``vals`` are copied into a Python
    list. Other fields are returned as-is.

    NOTE(review): if no oneof field is set, ``WhichOneof`` presumably returns
    ``None`` and ``.endswith`` raises ``AttributeError`` — confirm whether a
    clearer error is wanted.
    """
    field_name = cast(str, value_proto.WhichOneof("value"))
    is_list = field_name.endswith("list")
    field_value = getattr(value_proto, field_name)
    return list(field_value.vals) if is_list else field_value
|
97
|
+
|
98
|
+
|
99
|
+
def record_value_dict_to_proto(
    value_dict: TypedDict[str, Any],
    allowed_types: list[type],
    value_proto_class: type[T],
) -> dict[str, T]:
    """Serialize the record value dict to ProtoBuf.

    Parameters
    ----------
    value_dict : TypedDict[str, Any]
        Mapping of keys to record values.
    allowed_types : list[type]
        Value types accepted for ``value_proto_class``. If ``bool`` is present
        it is tried first (``isinstance(True, int)`` is ``True``). The
        caller's list is not modified.
    value_proto_class : type[T]
        The `*RecordValue` message class to build for each value.

    Returns
    -------
    dict[str, T]
        Mapping of the same keys to their serialized values.

    Raises
    ------
    TypeError
        If any value's type is not among ``allowed_types``.
    """
    # Reorder a local copy so `bool` comes first; previously the caller's list
    # was mutated in place via remove()/insert().
    ordered_types = list(allowed_types)
    if bool in ordered_types and ordered_types[0] is not bool:
        ordered_types.remove(bool)
        ordered_types.insert(0, bool)

    return {
        key: _record_value_to_proto(value, ordered_types, value_proto_class)
        for key, value in value_dict.items()
    }
|
117
|
+
|
118
|
+
|
119
|
+
def record_value_dict_from_proto(
    value_dict_proto: MutableMapping[str, Any]
) -> dict[str, Any]:
    """Deserialize the record value dict from ProtoBuf.

    Each entry of the ProtoBuf map is converted back to a Python value; keys
    are kept unchanged.
    """
    result: dict[str, Any] = {}
    for key, value_proto in value_dict_proto.items():
        result[key] = _record_value_from_proto(value_proto)
    return result
|
{flwr_nightly-1.19.0.dev20250507.dist-info → flwr_nightly-1.19.0.dev20250509.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: flwr-nightly
|
3
|
-
Version: 1.19.0.dev20250507
|
3
|
+
Version: 1.19.0.dev20250509
|
4
4
|
Summary: Flower: A Friendly Federated AI Framework
|
5
5
|
License: Apache-2.0
|
6
6
|
Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
|
{flwr_nightly-1.19.0.dev20250507.dist-info → flwr_nightly-1.19.0.dev20250509.dist-info}/RECORD
RENAMED
@@ -115,7 +115,7 @@ flwr/common/args.py,sha256=-aX_jVnSaDrJR2KZ8Wq0Y3dQHII4R4MJtJOIXzVUA0c,5417
|
|
115
115
|
flwr/common/auth_plugin/__init__.py,sha256=m271m9YjK2QfKDOuIIhcTvGmv1GWh1PL97QB05NTSHs,887
|
116
116
|
flwr/common/auth_plugin/auth_plugin.py,sha256=GaXw4IiU2DkVNkp5S9ue821sbkU9zWSu6HSVZetEdjs,3938
|
117
117
|
flwr/common/config.py,sha256=glcZDjco-amw1YfQcYTFJ4S1pt9APoexT-mf1QscuHs,13960
|
118
|
-
flwr/common/constant.py,sha256=
|
118
|
+
flwr/common/constant.py,sha256=RmVW2YLGosdBzyePgj2EMdZnHrT1PKEtNScaNK_FHZ0,7302
|
119
119
|
flwr/common/context.py,sha256=Be8obQR_OvEDy1OmshuUKxGRQ7Qx89mf5F4xlhkR10s,2407
|
120
120
|
flwr/common/date.py,sha256=1ZT2cRSpC2DJqprOVTLXYCR_O2_OZR0zXO_brJ3LqWc,1554
|
121
121
|
flwr/common/differential_privacy.py,sha256=FdlpdpPl_H_2HJa8CQM1iCUGBBQ5Dc8CzxmHERM-EoE,6148
|
@@ -129,17 +129,18 @@ flwr/common/exit/exit_code.py,sha256=PNEnCrZfOILjfDAFu5m-2YWEJBrk97xglq4zCUlqV7E
|
|
129
129
|
flwr/common/exit_handlers.py,sha256=MEk5_savTLphn-6lW57UQlos-XrFA39XEBn-OF1vXXg,3174
|
130
130
|
flwr/common/grpc.py,sha256=manTaHaPiyYngUq1ErZvvV2B2GxlXUUUGRy3jc3TBIQ,9798
|
131
131
|
flwr/common/heartbeat.py,sha256=yzi-gWH5wswdg0hfQwxwGkjI5twxIHBBVW45MD5QITI,3924
|
132
|
+
flwr/common/inflatable.py,sha256=yCfnRYj4xeUqV2m-K5hcQPeVhL7gdSGw7CewPYKnjnE,3156
|
132
133
|
flwr/common/logger.py,sha256=JbRf6E2vQxXzpDBq1T8IDUJo_usu3gjWEBPQ6uKcmdg,13049
|
133
134
|
flwr/common/message.py,sha256=znr205Erq2hkxwFbvNNCsQTRS2UKv_Qsyu0sFNEhEAw,23721
|
134
135
|
flwr/common/object_ref.py,sha256=p3SfTeqo3Aj16SkB-vsnNn01zswOPdGNBitcbRnqmUk,9134
|
135
136
|
flwr/common/parameter.py,sha256=UVw6sOgehEFhFs4uUCMl2kfVq1PD6ncmWgPLMsZPKPE,2095
|
136
137
|
flwr/common/pyproject.py,sha256=2SU6yJW7059SbMXgzjOdK1GZRWO6AixDH7BmdxbMvHI,1386
|
137
138
|
flwr/common/record/__init__.py,sha256=cNGccdDoxttqgnUgyKRIqLWULjW-NaSmOufVxtXq-sw,1197
|
138
|
-
flwr/common/record/array.py,sha256=
|
139
|
+
flwr/common/record/array.py,sha256=tPTT6cw7B1Fo626LOVaA_sfj2_EtkxdnvSkRTyPrVRY,10469
|
139
140
|
flwr/common/record/arrayrecord.py,sha256=KbehV2yXJ_6ZWcHPPrC-MNkE00DRCObxgyrVLwBQ5OY,14389
|
140
|
-
flwr/common/record/configrecord.py,sha256=
|
141
|
+
flwr/common/record/configrecord.py,sha256=lXVGjNfQD3lqvQTstGPFfQjeEHl29alfoL9trCKKlY4,9269
|
141
142
|
flwr/common/record/conversion_utils.py,sha256=wbNCzy7oAqaA3-arhls_EqRZYXRC4YrWIoE-Gy82fJ0,1191
|
142
|
-
flwr/common/record/metricrecord.py,sha256=
|
143
|
+
flwr/common/record/metricrecord.py,sha256=MRMv0fSmJvHlg0HtX_s4IBqxHAh8QHgFN75CmR6fBOU,8444
|
143
144
|
flwr/common/record/recorddict.py,sha256=zo7TiVZCH_LB9gwUP7-Jo-jLpFLrvxYSryovwZANQiw,12386
|
144
145
|
flwr/common/record/typeddict.py,sha256=dDKgUThs2BscYUNcgP82KP8-qfAYXYftDrf2LszAC_o,3599
|
145
146
|
flwr/common/recorddict_compat.py,sha256=Znn1xRGiqLpPPgviVqyb-GPTM-pCK6tpnEmhWSXafy8,14119
|
@@ -152,7 +153,8 @@ flwr/common/secure_aggregation/ndarrays_arithmetic.py,sha256=TrggOlizlny3V2KS7-3
|
|
152
153
|
flwr/common/secure_aggregation/quantization.py,sha256=ssFZpiRyj9ltIh0Ai3vGkDqWFO4SoqgoD1mDU9XqMEM,2400
|
153
154
|
flwr/common/secure_aggregation/secaggplus_constants.py,sha256=dGYhWOBMMDJcQH4_tQNC8-Efqm-ecEUNN9ANz59UnCk,2182
|
154
155
|
flwr/common/secure_aggregation/secaggplus_utils.py,sha256=E_xU-Zd45daO1em7M6C2wOjFXVtJf-6tl7fp-7xq1wo,3214
|
155
|
-
flwr/common/serde.py,sha256=
|
156
|
+
flwr/common/serde.py,sha256=B9PpDh_f3sO0okztebD_6bgX1dK-pJtDr6CNtZ-gJIQ,23910
|
157
|
+
flwr/common/serde_utils.py,sha256=ofmrgVHRBfrE1MtQwLQk0x12JS9vL-u8wHXrgZE2ueg,3985
|
156
158
|
flwr/common/telemetry.py,sha256=jF47v0SbnBd43XamHtl3wKxs3knFUY2p77cm_2lzZ8M,8762
|
157
159
|
flwr/common/typing.py,sha256=97QRfRRS7sQnjkAI5FDZ01-38oQUSz4i1qqewQmBWRg,6886
|
158
160
|
flwr/common/version.py,sha256=7GAGzPn73Mkh09qhrjbmjZQtcqVhBuzhFBaK4Mk4VRk,1325
|
@@ -331,7 +333,7 @@ flwr/superexec/exec_servicer.py,sha256=Z0YYfs6eNPhqn8rY0x_R04XgR2mKFpggt07IH0EhU
|
|
331
333
|
flwr/superexec/exec_user_auth_interceptor.py,sha256=iqygALkOMBUu_s_R9G0mFThZA7HTUzuXCLgxLCefiwI,4440
|
332
334
|
flwr/superexec/executor.py,sha256=M5ucqSE53jfRtuCNf59WFLqQvA1Mln4741TySeZE7qQ,3112
|
333
335
|
flwr/superexec/simulation.py,sha256=j6YwUvBN7EQ09ID7MYOCVZ70PGbuyBy8f9bXU0EszEM,4088
|
334
|
-
flwr_nightly-1.19.0.
|
335
|
-
flwr_nightly-1.19.0.
|
336
|
-
flwr_nightly-1.19.0.
|
337
|
-
flwr_nightly-1.19.0.
|
336
|
+
flwr_nightly-1.19.0.dev20250509.dist-info/METADATA,sha256=myNc-raWTpUXv9cnHt3ustV6ebzfD9VlMh_bpGCsJ9c,15880
|
337
|
+
flwr_nightly-1.19.0.dev20250509.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
338
|
+
flwr_nightly-1.19.0.dev20250509.dist-info/entry_points.txt,sha256=2-1L-GNKhwGw2_7_RoH55vHw2SIHjdAQy3HAVAWl9PY,374
|
339
|
+
flwr_nightly-1.19.0.dev20250509.dist-info/RECORD,,
|
{flwr_nightly-1.19.0.dev20250507.dist-info → flwr_nightly-1.19.0.dev20250509.dist-info}/WHEEL
RENAMED
File without changes
|
File without changes
|