flwr-nightly 1.7.0.dev20240116__py3-none-any.whl → 1.7.0.dev20240118__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +7 -4
- flwr/client/dpfedavg_numpy_client.py +4 -4
- flwr/client/grpc_client/connection.py +7 -4
- flwr/client/grpc_rere_client/connection.py +4 -4
- flwr/client/message_handler/message_handler.py +11 -2
- flwr/client/message_handler/task_handler.py +8 -6
- flwr/client/node_state_tests.py +1 -1
- flwr/client/numpy_client.py +2 -2
- flwr/client/rest_client/connection.py +7 -3
- flwr/client/secure_aggregation/secaggplus_handler.py +6 -6
- flwr/client/typing.py +1 -1
- flwr/common/configsrecord.py +98 -0
- flwr/common/logger.py +14 -0
- flwr/common/metricsrecord.py +96 -0
- flwr/common/parametersrecord.py +110 -0
- flwr/common/recordset.py +8 -18
- flwr/common/recordset_utils.py +87 -0
- flwr/common/retry_invoker.py +1 -0
- flwr/common/serde.py +12 -8
- flwr/common/typing.py +9 -0
- flwr/driver/app.py +5 -3
- flwr/driver/driver.py +3 -3
- flwr/driver/driver_client_proxy.py +24 -15
- flwr/driver/grpc_driver.py +6 -6
- flwr/proto/driver_pb2.py +23 -88
- flwr/proto/fleet_pb2.py +29 -111
- flwr/proto/node_pb2.py +7 -15
- flwr/proto/task_pb2.py +33 -127
- flwr/proto/transport_pb2.py +69 -278
- flwr/server/app.py +9 -3
- flwr/server/driver/driver_servicer.py +4 -4
- flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
- flwr/server/fleet/grpc_bidi/grpc_bridge.py +9 -6
- flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
- flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/fleet/grpc_bidi/ins_scheduler.py +7 -4
- flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/fleet/message_handler/message_handler.py +3 -3
- flwr/server/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/state/in_memory_state.py +1 -1
- flwr/server/state/sqlite_state.py +8 -5
- flwr/server/state/state.py +1 -1
- flwr/server/strategy/aggregate.py +8 -8
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +2 -2
- flwr/server/strategy/fedavg_android.py +0 -2
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +9 -2
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/utils/validator.py +1 -1
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/METADATA +3 -3
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/RECORD +55 -51
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/entry_points.txt +0 -0
flwr/client/app.py
CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
     TRANSPORT_TYPES,
 )
 from flwr.common.logger import log, warn_experimental_feature
-from flwr.proto.task_pb2 import TaskIns, TaskRes
+from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 
 from .flower import load_flower_callable
 from .grpc_client.connection import grpc_connection
@@ -138,10 +138,12 @@ def _check_actionable_client(
     client: Optional[Client], client_fn: Optional[ClientFn]
 ) -> None:
     if client_fn is None and client is None:
-        raise
+        raise ValueError(
+            "Both `client_fn` and `client` are `None`, but one is required"
+        )
 
     if client_fn is not None and client is not None:
-        raise
+        raise ValueError(
             "Both `client_fn` and `client` are provided, but only one is allowed"
         )
 
@@ -150,6 +152,7 @@ def _check_actionable_client(
 # pylint: disable=too-many-branches
 # pylint: disable=too-many-locals
 # pylint: disable=too-many-statements
+# pylint: disable=too-many-arguments
 def start_client(
     *,
     server_address: str,
@@ -299,7 +302,7 @@ def _start_client_internal(
         cid: str,  # pylint: disable=unused-argument
     ) -> Client:
         if client is None:  # Added this to keep mypy happy
-            raise
+            raise ValueError(
                 "Both `client_fn` and `client` are `None`, but one is required"
             )
         return client  # Always return the same instance
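The old error paths above are replaced with descriptive `ValueError`s. A minimal, standalone sketch of the resulting validation behaviour (the `Client`/`ClientFn` stand-ins below are placeholders, not flwr's actual types):

```python
from typing import Callable, Optional

# Placeholder stand-ins so the sketch is self-contained; flwr defines its own
# Client class and ClientFn alias.
Client = object
ClientFn = Callable[[str], Client]


def check_actionable_client(
    client: Optional[Client], client_fn: Optional[ClientFn]
) -> None:
    """Mirror the exactly-one-of-`client`/`client_fn` check from the diff above."""
    if client_fn is None and client is None:
        raise ValueError(
            "Both `client_fn` and `client` are `None`, but one is required"
        )
    if client_fn is not None and client is not None:
        raise ValueError(
            "Both `client_fn` and `client` are provided, but only one is allowed"
        )


try:
    check_actionable_client(client=None, client_fn=None)
except ValueError as err:
    print(err)  # Both `client_fn` and `client` are `None`, but one is required
```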
flwr/client/dpfedavg_numpy_client.py
CHANGED
@@ -117,16 +117,16 @@ class DPFedAvgNumPyClient(NumPyClient):
         update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)]
 
         if "dpfedavg_clip_norm" not in config:
-            raise
+            raise KeyError("Clipping threshold not supplied by the server.")
         if not isinstance(config["dpfedavg_clip_norm"], float):
-            raise
+            raise TypeError("Clipping threshold should be a floating point value.")
 
         # Clipping
         update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"])
 
         if "dpfedavg_noise_stddev" in config:
             if not isinstance(config["dpfedavg_noise_stddev"], float):
-                raise
+                raise TypeError(
                     "Scale of noise to be added should be a floating point value."
                 )
             # Noising
@@ -138,7 +138,7 @@ class DPFedAvgNumPyClient(NumPyClient):
         # Calculating value of norm indicator bit, required for adaptive clipping
         if "dpfedavg_adaptive_clip_enabled" in config:
             if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool):
-                raise
+                raise TypeError(
                     "dpfedavg_adaptive_clip_enabled should be a boolean-valued flag."
                 )
             metrics["dpfedavg_norm_bit"] = not clipped
flwr/client/grpc_client/connection.py
CHANGED
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
 from flwr.common import GRPC_MAX_MESSAGE_LENGTH
 from flwr.common.grpc import create_channel
 from flwr.common.logger import log
-from flwr.proto.node_pb2 import Node
-from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
-from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
-from flwr.proto.transport_pb2_grpc import FlowerServiceStub
+from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
+from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
+from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
+    ClientMessage,
+    ServerMessage,
+)
+from flwr.proto.transport_pb2_grpc import FlowerServiceStub  # pylint: disable=E0611
 
 # The following flags can be uncommented for debugging. Other possible values:
 # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
flwr/client/grpc_rere_client/connection.py
CHANGED
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
 from flwr.common import GRPC_MAX_MESSAGE_LENGTH
 from flwr.common.grpc import create_channel
 from flwr.common.logger import log, warn_experimental_feature
-from flwr.proto.fleet_pb2 import (
+from flwr.proto.fleet_pb2 import (  # pylint: disable=E0611
     CreateNodeRequest,
     DeleteNodeRequest,
     PullTaskInsRequest,
     PushTaskResRequest,
 )
-from flwr.proto.fleet_pb2_grpc import FleetStub
-from flwr.proto.node_pb2 import Node
-from flwr.proto.task_pb2 import TaskIns, TaskRes
+from flwr.proto.fleet_pb2_grpc import FleetStub  # pylint: disable=E0611
+from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
+from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 
 KEY_NODE = "node"
 KEY_TASK_INS = "current_task_ins"
flwr/client/message_handler/message_handler.py
CHANGED
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
 from flwr.client.secure_aggregation import SecureAggregationHandler
 from flwr.client.typing import ClientFn
 from flwr.common import serde
-from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
-from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
+from flwr.proto.task_pb2 import (  # pylint: disable=E0611
+    SecureAggregation,
+    Task,
+    TaskIns,
+    TaskRes,
+)
+from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
+    ClientMessage,
+    Reason,
+    ServerMessage,
+)
 
 
 class UnexpectedServerMessage(Exception):
flwr/client/message_handler/task_handler.py
CHANGED
@@ -17,10 +17,13 @@
 
 from typing import Optional
 
-from flwr.proto.fleet_pb2 import PullTaskInsResponse
-from flwr.proto.node_pb2 import Node
-from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
-from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
+from flwr.proto.fleet_pb2 import PullTaskInsResponse  # pylint: disable=E0611
+from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
+from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
+from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
+    ClientMessage,
+    ServerMessage,
+)
 
 
 def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
@@ -80,8 +83,7 @@ def validate_task_res(task_res: TaskRes) -> bool:
     initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()}
 
     # Check if certain fields are already initialized
-    # pylint: disable-next=too-many-boolean-expressions
-    if (
+    if (  # pylint: disable-next=too-many-boolean-expressions
         "task_id" in initialized_fields_in_task_res
         or "group_id" in initialized_fields_in_task_res
         or "run_id" in initialized_fields_in_task_res
flwr/client/node_state_tests.py
CHANGED
@@ -17,7 +17,7 @@
 
 from flwr.client.node_state import NodeState
 from flwr.client.run_state import RunState
-from flwr.proto.task_pb2 import TaskIns
+from flwr.proto.task_pb2 import TaskIns  # pylint: disable=E0611
 
 
 def _run_dummy_task(state: RunState) -> RunState:
flwr/client/numpy_client.py
CHANGED
@@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes:
         and isinstance(results[1], int)
         and isinstance(results[2], dict)
     ):
-        raise
+        raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
 
     # Return FitRes
     parameters_prime, num_examples, metrics = results
@@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
         and isinstance(results[1], int)
         and isinstance(results[2], dict)
     ):
-        raise
+        raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
 
     # Return EvaluateRes
     loss, num_examples, metrics = results
flwr/client/rest_client/connection.py
CHANGED
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
 from flwr.common import GRPC_MAX_MESSAGE_LENGTH
 from flwr.common.constant import MISSING_EXTRA_REST
 from flwr.common.logger import log
-from flwr.proto.fleet_pb2 import (
+from flwr.proto.fleet_pb2 import (  # pylint: disable=E0611
     CreateNodeRequest,
     CreateNodeResponse,
     DeleteNodeRequest,
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
     PushTaskResRequest,
     PushTaskResResponse,
 )
-from flwr.proto.node_pb2 import Node
-from flwr.proto.task_pb2 import TaskIns, TaskRes
+from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
+from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 
 try:
     import requests
@@ -143,6 +143,7 @@ def http_request_response(
             },
             data=create_node_req_bytes,
             verify=verify,
+            timeout=None,
         )
 
         # Check status code and headers
@@ -185,6 +186,7 @@ def http_request_response(
             },
             data=delete_node_req_req_bytes,
             verify=verify,
+            timeout=None,
         )
 
         # Check status code and headers
@@ -225,6 +227,7 @@ def http_request_response(
             },
             data=pull_task_ins_req_bytes,
             verify=verify,
+            timeout=None,
         )
 
         # Check status code and headers
@@ -303,6 +306,7 @@ def http_request_response(
             },
             data=push_task_res_request_bytes,
             verify=verify,
+            timeout=None,
         )
 
         state[KEY_TASK_INS] = None
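The REST transport now passes an explicit `timeout=None` to every `requests.post` call: behaviour is unchanged (block until the server responds), but the choice is explicit and should satisfy linters that flag `requests` calls without a timeout. A hedged sketch of the call shape, with a hypothetical endpoint and headers:

```python
import requests


def post_protobuf(url: str, payload: bytes) -> bytes:
    """Illustrative POST in the style of the Fleet REST client.

    The URL and header values here are hypothetical; the real client builds
    them from the configured server address and serialized request messages.
    """
    res = requests.post(
        url,
        headers={
            "Accept": "application/protobuf",
            "Content-Type": "application/protobuf",
        },
        data=payload,
        verify=True,
        timeout=None,  # explicit: wait indefinitely, matching the old behaviour
    )
    return res.content
```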
flwr/client/secure_aggregation/secaggplus_handler.py
CHANGED
@@ -333,7 +333,7 @@ def _share_keys(
 
     # Check if the size is larger than threshold
     if len(state.public_keys_dict) < state.threshold:
-        raise
+        raise ValueError("Available neighbours number smaller than threshold")
 
     # Check if all public keys are unique
     pk_list: List[bytes] = []
@@ -341,14 +341,14 @@ def _share_keys(
         pk_list.append(pk1)
         pk_list.append(pk2)
     if len(set(pk_list)) != len(pk_list):
-        raise
+        raise ValueError("Some public keys are identical")
 
     # Check if public keys of this client are correct in the dictionary
     if (
         state.public_keys_dict[state.sid][0] != state.pk1
         or state.public_keys_dict[state.sid][1] != state.pk2
     ):
-        raise
+        raise ValueError(
             "Own public keys are displayed in dict incorrectly, should not happen!"
         )
 
@@ -393,7 +393,7 @@ def _collect_masked_input(
     ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST])
     srcs = cast(List[int], named_values[KEY_SOURCE_LIST])
     if len(ciphertexts) + 1 < state.threshold:
-        raise
+        raise ValueError("Not enough available neighbour clients.")
 
     # Decrypt ciphertexts, verify their sources, and store shares.
     for src, ciphertext in zip(srcs, ciphertexts):
@@ -409,7 +409,7 @@ def _collect_masked_input(
                 f"from {actual_src} instead of {src}."
             )
         if dst != state.sid:
-            ValueError(
+            raise ValueError(
                 f"Client {state.sid}: received an encrypted message"
                 f"for Client {dst} from Client {src}."
             )
@@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str,
     # Send private mask seed share for every avaliable client (including itclient)
     # Send first private key share for building pairwise mask for every dropped client
     if len(active_sids) < state.threshold:
-        raise
+        raise ValueError("Available neighbours number smaller than threshold")
 
     sids, shares = [], []
     sids += active_sids
flwr/client/typing.py
CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
 from typing import Callable
 
 from flwr.client.run_state import RunState
-from flwr.proto.task_pb2 import TaskIns, TaskRes
+from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 
 from .client import Client as Client
 
flwr/common/configsrecord.py
ADDED
@@ -0,0 +1,98 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ConfigsRecord."""
+
+
+from dataclasses import dataclass, field
+from typing import Dict, Optional, get_args
+
+from .typing import ConfigsRecordValues, ConfigsScalar
+
+
+@dataclass
+class ConfigsRecord:
+    """Configs record."""
+
+    keep_input: bool
+    data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)
+
+    def __init__(
+        self,
+        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
+        keep_input: bool = True,
+    ):
+        """Construct a ConfigsRecord object.
+
+        Parameters
+        ----------
+        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
+            A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
+            defined in `ConfigsScalar`) and lists of such types (see
+            `ConfigsScalarList`).
+        keep_input : bool (default: True)
+            A boolean indicating whether config passed should be deleted from the input
+            dictionary immediately after adding them to the record. When set
+            to True, the data is duplicated in memory. If memory is a concern, set
+            it to False.
+        """
+        self.keep_input = keep_input
+        self.data = {}
+        if configs_dict:
+            self.set_configs(configs_dict)
+
+    def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
+        """Add configs to the record.
+
+        Parameters
+        ----------
+        configs_dict : Dict[str, ConfigsRecordValues]
+            A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as
+            defined in `ConfigsRecordValues`) and list of such types (see
+            `ConfigsScalarList`).
+        """
+        if any(not isinstance(k, str) for k in configs_dict.keys()):
+            raise TypeError(f"Not all keys are of valid type. Expected {str}")
+
+        def is_valid(value: ConfigsScalar) -> None:
+            """Check if value is of expected type."""
+            if not isinstance(value, get_args(ConfigsScalar)):
+                raise TypeError(
+                    "Not all values are of valid type."
+                    f" Expected {ConfigsRecordValues} but you passed {type(value)}."
+                )
+
+        # Check types of values
+        # Split between those values that are list and those that aren't
+        # then process in the same way
+        for value in configs_dict.values():
+            if isinstance(value, list):
+                # If your lists are large (e.g. 1M+ elements) this will be slow
+                # 1s to check 10M element list on a M2 Pro
+                # In such settings, you'd be better of treating such config as
+                # an array and pass it to a ParametersRecord.
+                for list_value in value:
+                    is_valid(list_value)
+            else:
+                is_valid(value)
+
+        # Add configs to record
+        if self.keep_input:
+            # Copy
+            self.data = configs_dict.copy()
+        else:
+            # Add entries to dataclass without duplicating memory
+            for key in list(configs_dict.keys()):
+                self.data[key] = configs_dict[key]
+                del configs_dict[key]
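A brief usage sketch of the new `ConfigsRecord` (import path taken from the file location above; later releases may re-export it elsewhere):

```python
from flwr.common.configsrecord import ConfigsRecord

configs = {"optimizer": "sgd", "learning-rate": 0.01, "round-seeds": [1, 2, 3]}

# With keep_input=False the entries are moved into the record and the input
# dictionary is emptied, avoiding a duplicate copy in memory.
record = ConfigsRecord(configs_dict=configs, keep_input=False)
print(record.data["optimizer"])  # sgd
print(configs)                   # {}

# Values outside str/int/float/bytes (or lists of those) are rejected.
try:
    ConfigsRecord(configs_dict={"bad": object()})
except TypeError as err:
    print(err)
```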
flwr/common/logger.py
CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
         """,
         name,
     )
+
+
+def warn_deprecated_feature(name: str) -> None:
+    """Warn the user when they use a deprecated feature."""
+    log(
+        WARN,
+        """
+        DEPRECATED FEATURE: %s
+
+        This is a deprecated feature. It will be removed
+        entirely in future versions of Flower.
+        """,
+        name,
+    )
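The new helper mirrors the existing `warn_experimental_feature`: calling it logs a WARN-level message naming the deprecated feature. For example (the feature name here is arbitrary):

```python
from flwr.common.logger import warn_deprecated_feature

# Logs a "DEPRECATED FEATURE: ..." warning via Flower's logger.
warn_deprecated_feature("my_old_api")
```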
flwr/common/metricsrecord.py
ADDED
@@ -0,0 +1,96 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MetricsRecord."""
+
+
+from dataclasses import dataclass, field
+from typing import Dict, Optional, get_args
+
+from .typing import MetricsRecordValues, MetricsScalar
+
+
+@dataclass
+class MetricsRecord:
+    """Metrics record."""
+
+    keep_input: bool
+    data: Dict[str, MetricsRecordValues] = field(default_factory=dict)
+
+    def __init__(
+        self,
+        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
+        keep_input: bool = True,
+    ):
+        """Construct a MetricsRecord object.
+
+        Parameters
+        ----------
+        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
+            A dictionary that stores basic types (i.e. `int`, `float` as defined
+            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
+        keep_input : bool (default: True)
+            A boolean indicating whether metrics should be deleted from the input
+            dictionary immediately after adding them to the record. When set
+            to True, the data is duplicated in memory. If memory is a concern, set
+            it to False.
+        """
+        self.keep_input = keep_input
+        self.data = {}
+        if metrics_dict:
+            self.set_metrics(metrics_dict)
+
+    def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
+        """Add metrics to the record.
+
+        Parameters
+        ----------
+        metrics_dict : Dict[str, MetricsRecordValues]
+            A dictionary that stores basic types (i.e. `int`, `float` as defined
+            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
+        """
+        if any(not isinstance(k, str) for k in metrics_dict.keys()):
+            raise TypeError(f"Not all keys are of valid type. Expected {str}.")
+
+        def is_valid(value: MetricsScalar) -> None:
+            """Check if value is of expected type."""
+            if not isinstance(value, get_args(MetricsScalar)):
+                raise TypeError(
+                    "Not all values are of valid type."
+                    f" Expected {MetricsRecordValues} but you passed {type(value)}."
+                )
+
+        # Check types of values
+        # Split between those values that are list and those that aren't
+        # then process in the same way
+        for value in metrics_dict.values():
+            if isinstance(value, list):
+                # If your lists are large (e.g. 1M+ elements) this will be slow
+                # 1s to check 10M element list on a M2 Pro
+                # In such settings, you'd be better of treating such metric as
+                # an array and pass it to a ParametersRecord.
+                for list_value in value:
+                    is_valid(list_value)
+            else:
+                is_valid(value)
+
+        # Add metrics to record
+        if self.keep_input:
+            # Copy
+            self.data = metrics_dict.copy()
+        else:
+            # Add entries to dataclass without duplicating memory
+            for key in list(metrics_dict.keys()):
+                self.data[key] = metrics_dict[key]
+                del metrics_dict[key]
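A usage sketch for the new `MetricsRecord`, analogous to `ConfigsRecord` but restricted to numeric values (import path taken from the file location above):

```python
from flwr.common.metricsrecord import MetricsRecord

record = MetricsRecord(metrics_dict={"accuracy": 0.94, "num_examples": 512})
print(record.data["accuracy"])  # 0.94

# Only int/float values (and lists of them) are accepted; strings are not.
try:
    MetricsRecord(metrics_dict={"status": "ok"})
except TypeError as err:
    print(err)
```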
flwr/common/parametersrecord.py
ADDED
@@ -0,0 +1,110 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ParametersRecord and Array."""
+
+
+from dataclasses import dataclass, field
+from typing import List, Optional, OrderedDict
+
+
+@dataclass
+class Array:
+    """Array type.
+
+    A dataclass containing serialized data from an array-like or tensor-like object
+    along with some metadata about it.
+
+    Parameters
+    ----------
+    dtype : str
+        A string representing the data type of the serialised object (e.g. `np.float32`)
+
+    shape : List[int]
+        A list representing the shape of the unserialized array-like object. This is
+        used to deserialize the data (depending on the serialization method) or simply
+        as a metadata field.
+
+    stype : str
+        A string indicating the type of serialisation mechanism used to generate the
+        bytes in `data` from an array-like or tensor-like object.
+
+    data: bytes
+        A buffer of bytes containing the data.
+    """
+
+    dtype: str
+    shape: List[int]
+    stype: str
+    data: bytes
+
+
+@dataclass
+class ParametersRecord:
+    """Parameters record.
+
+    A dataclass storing named Arrays in order. This means that it holds entries as an
+    OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
+    PyTorch's state_dict, but holding serialised tensors instead.
+    """
+
+    keep_input: bool
+    data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])
+
+    def __init__(
+        self,
+        array_dict: Optional[OrderedDict[str, Array]] = None,
+        keep_input: bool = False,
+    ) -> None:
+        """Construct a ParametersRecord object.
+
+        Parameters
+        ----------
+        array_dict : Optional[OrderedDict[str, Array]]
+            A dictionary that stores serialized array-like or tensor-like objects.
+        keep_input : bool (default: False)
+            A boolean indicating whether parameters should be deleted from the input
+            dictionary immediately after adding them to the record. If False, the
+            dictionary passed to `set_parameters()` will be empty once exiting from that
+            function. This is the desired behaviour when working with very large
+            models/tensors/arrays. However, if you plan to continue working with your
+            parameters after adding it to the record, set this flag to True. When set
+            to True, the data is duplicated in memory.
+        """
+        self.keep_input = keep_input
+        self.data = OrderedDict()
+        if array_dict:
+            self.set_parameters(array_dict)
+
+    def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
+        """Add parameters to record.
+
+        Parameters
+        ----------
+        array_dict : OrderedDict[str, Array]
+            A dictionary that stores serialized array-like or tensor-like objects.
+        """
+        if any(not isinstance(k, str) for k in array_dict.keys()):
+            raise TypeError(f"Not all keys are of valid type. Expected {str}")
+        if any(not isinstance(v, Array) for v in array_dict.values()):
+            raise TypeError(f"Not all values are of valid type. Expected {Array}")
+
+        if self.keep_input:
+            # Copy
+            self.data = OrderedDict(array_dict)
+        else:
+            # Add entries to dataclass without duplicating memory
+            for key in list(array_dict.keys()):
+                self.data[key] = array_dict[key]
+                del array_dict[key]
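A sketch of how an `Array` might be filled and stored in a `ParametersRecord`. The `np.save`-based serialization and the `"numpy.save"` stype label are illustrative assumptions; the conversion helpers added in `flwr/common/recordset_utils.py` (also part of this diff) are not shown here:

```python
from collections import OrderedDict
from io import BytesIO

import numpy as np

from flwr.common.parametersrecord import Array, ParametersRecord

weights = np.random.randn(2, 3).astype(np.float32)

buffer = BytesIO()
np.save(buffer, weights, allow_pickle=False)
arr = Array(
    dtype=str(weights.dtype),   # e.g. "float32"
    shape=list(weights.shape),  # e.g. [2, 3]
    stype="numpy.save",         # assumed label for the serialization scheme
    data=buffer.getvalue(),
)

# keep_input defaults to False: entries are moved into the record and the
# input OrderedDict is emptied, so large tensors are not duplicated in memory.
arrays = OrderedDict({"fc1.weight": arr})
record = ParametersRecord(array_dict=arrays)
print(list(record.data.keys()))  # ['fc1.weight']
print(len(arrays))               # 0
```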
flwr/common/recordset.py
CHANGED
@@ -14,32 +14,22 @@
 # ==============================================================================
 """RecordSet."""
 
-from dataclasses import dataclass
-from typing import Dict
-
-
-@dataclass
-class ParametersRecord:
-    """Parameters record."""
-
-
-@dataclass
-class MetricsRecord:
-    """Metrics record."""
 
+from dataclasses import dataclass, field
+from typing import Dict
 
-
-
-
+from .configsrecord import ConfigsRecord
+from .metricsrecord import MetricsRecord
+from .parametersrecord import ParametersRecord
 
 
 @dataclass
 class RecordSet:
     """Definition of RecordSet."""
 
-    parameters: Dict[str, ParametersRecord] =
-    metrics: Dict[str, MetricsRecord] =
-    configs: Dict[str, ConfigsRecord] =
+    parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
+    metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
+    configs: Dict[str, ConfigsRecord] = field(default_factory=dict)
 
     def set_parameters(self, name: str, record: ParametersRecord) -> None:
         """Add a ParametersRecord."""