flwr-nightly 1.7.0.dev20240116__py3-none-any.whl → 1.7.0.dev20240118__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/client/app.py +7 -4
- flwr/client/dpfedavg_numpy_client.py +4 -4
- flwr/client/grpc_client/connection.py +7 -4
- flwr/client/grpc_rere_client/connection.py +4 -4
- flwr/client/message_handler/message_handler.py +11 -2
- flwr/client/message_handler/task_handler.py +8 -6
- flwr/client/node_state_tests.py +1 -1
- flwr/client/numpy_client.py +2 -2
- flwr/client/rest_client/connection.py +7 -3
- flwr/client/secure_aggregation/secaggplus_handler.py +6 -6
- flwr/client/typing.py +1 -1
- flwr/common/configsrecord.py +98 -0
- flwr/common/logger.py +14 -0
- flwr/common/metricsrecord.py +96 -0
- flwr/common/parametersrecord.py +110 -0
- flwr/common/recordset.py +8 -18
- flwr/common/recordset_utils.py +87 -0
- flwr/common/retry_invoker.py +1 -0
- flwr/common/serde.py +12 -8
- flwr/common/typing.py +9 -0
- flwr/driver/app.py +5 -3
- flwr/driver/driver.py +3 -3
- flwr/driver/driver_client_proxy.py +24 -15
- flwr/driver/grpc_driver.py +6 -6
- flwr/proto/driver_pb2.py +23 -88
- flwr/proto/fleet_pb2.py +29 -111
- flwr/proto/node_pb2.py +7 -15
- flwr/proto/task_pb2.py +33 -127
- flwr/proto/transport_pb2.py +69 -278
- flwr/server/app.py +9 -3
- flwr/server/driver/driver_servicer.py +4 -4
- flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
- flwr/server/fleet/grpc_bidi/grpc_bridge.py +9 -6
- flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
- flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/fleet/grpc_bidi/ins_scheduler.py +7 -4
- flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/fleet/message_handler/message_handler.py +3 -3
- flwr/server/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/state/in_memory_state.py +1 -1
- flwr/server/state/sqlite_state.py +8 -5
- flwr/server/state/state.py +1 -1
- flwr/server/strategy/aggregate.py +8 -8
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +2 -2
- flwr/server/strategy/fedavg_android.py +0 -2
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +9 -2
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/utils/validator.py +1 -1
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/METADATA +3 -3
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/RECORD +55 -51
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/entry_points.txt +0 -0
flwr/client/app.py
CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
|
|
35
35
|
TRANSPORT_TYPES,
|
36
36
|
)
|
37
37
|
from flwr.common.logger import log, warn_experimental_feature
|
38
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
38
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
39
39
|
|
40
40
|
from .flower import load_flower_callable
|
41
41
|
from .grpc_client.connection import grpc_connection
|
@@ -138,10 +138,12 @@ def _check_actionable_client(
|
|
138
138
|
client: Optional[Client], client_fn: Optional[ClientFn]
|
139
139
|
) -> None:
|
140
140
|
if client_fn is None and client is None:
|
141
|
-
raise
|
141
|
+
raise ValueError(
|
142
|
+
"Both `client_fn` and `client` are `None`, but one is required"
|
143
|
+
)
|
142
144
|
|
143
145
|
if client_fn is not None and client is not None:
|
144
|
-
raise
|
146
|
+
raise ValueError(
|
145
147
|
"Both `client_fn` and `client` are provided, but only one is allowed"
|
146
148
|
)
|
147
149
|
|
@@ -150,6 +152,7 @@ def _check_actionable_client(
|
|
150
152
|
# pylint: disable=too-many-branches
|
151
153
|
# pylint: disable=too-many-locals
|
152
154
|
# pylint: disable=too-many-statements
|
155
|
+
# pylint: disable=too-many-arguments
|
153
156
|
def start_client(
|
154
157
|
*,
|
155
158
|
server_address: str,
|
@@ -299,7 +302,7 @@ def _start_client_internal(
|
|
299
302
|
cid: str, # pylint: disable=unused-argument
|
300
303
|
) -> Client:
|
301
304
|
if client is None: # Added this to keep mypy happy
|
302
|
-
raise
|
305
|
+
raise ValueError(
|
303
306
|
"Both `client_fn` and `client` are `None`, but one is required"
|
304
307
|
)
|
305
308
|
return client # Always return the same instance
|
@@ -117,16 +117,16 @@ class DPFedAvgNumPyClient(NumPyClient):
|
|
117
117
|
update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)]
|
118
118
|
|
119
119
|
if "dpfedavg_clip_norm" not in config:
|
120
|
-
raise
|
120
|
+
raise KeyError("Clipping threshold not supplied by the server.")
|
121
121
|
if not isinstance(config["dpfedavg_clip_norm"], float):
|
122
|
-
raise
|
122
|
+
raise TypeError("Clipping threshold should be a floating point value.")
|
123
123
|
|
124
124
|
# Clipping
|
125
125
|
update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"])
|
126
126
|
|
127
127
|
if "dpfedavg_noise_stddev" in config:
|
128
128
|
if not isinstance(config["dpfedavg_noise_stddev"], float):
|
129
|
-
raise
|
129
|
+
raise TypeError(
|
130
130
|
"Scale of noise to be added should be a floating point value."
|
131
131
|
)
|
132
132
|
# Noising
|
@@ -138,7 +138,7 @@ class DPFedAvgNumPyClient(NumPyClient):
|
|
138
138
|
# Calculating value of norm indicator bit, required for adaptive clipping
|
139
139
|
if "dpfedavg_adaptive_clip_enabled" in config:
|
140
140
|
if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool):
|
141
|
-
raise
|
141
|
+
raise TypeError(
|
142
142
|
"dpfedavg_adaptive_clip_enabled should be a boolean-valued flag."
|
143
143
|
)
|
144
144
|
metrics["dpfedavg_norm_bit"] = not clipped
|
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
|
|
25
25
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
26
26
|
from flwr.common.grpc import create_channel
|
27
27
|
from flwr.common.logger import log
|
28
|
-
from flwr.proto.node_pb2 import Node
|
29
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
30
|
-
from flwr.proto.transport_pb2 import
|
31
|
-
|
28
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
29
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
30
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
31
|
+
ClientMessage,
|
32
|
+
ServerMessage,
|
33
|
+
)
|
34
|
+
from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
|
32
35
|
|
33
36
|
# The following flags can be uncommented for debugging. Other possible values:
|
34
37
|
# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
|
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
|
|
29
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
30
30
|
from flwr.common.grpc import create_channel
|
31
31
|
from flwr.common.logger import log, warn_experimental_feature
|
32
|
-
from flwr.proto.fleet_pb2 import (
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
33
33
|
CreateNodeRequest,
|
34
34
|
DeleteNodeRequest,
|
35
35
|
PullTaskInsRequest,
|
36
36
|
PushTaskResRequest,
|
37
37
|
)
|
38
|
-
from flwr.proto.fleet_pb2_grpc import FleetStub
|
39
|
-
from flwr.proto.node_pb2 import Node
|
40
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
38
|
+
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
39
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
40
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
41
41
|
|
42
42
|
KEY_NODE = "node"
|
43
43
|
KEY_TASK_INS = "current_task_ins"
|
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
|
|
32
32
|
from flwr.client.secure_aggregation import SecureAggregationHandler
|
33
33
|
from flwr.client.typing import ClientFn
|
34
34
|
from flwr.common import serde
|
35
|
-
from flwr.proto.task_pb2 import
|
36
|
-
|
35
|
+
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
|
36
|
+
SecureAggregation,
|
37
|
+
Task,
|
38
|
+
TaskIns,
|
39
|
+
TaskRes,
|
40
|
+
)
|
41
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
42
|
+
ClientMessage,
|
43
|
+
Reason,
|
44
|
+
ServerMessage,
|
45
|
+
)
|
37
46
|
|
38
47
|
|
39
48
|
class UnexpectedServerMessage(Exception):
|
@@ -17,10 +17,13 @@
|
|
17
17
|
|
18
18
|
from typing import Optional
|
19
19
|
|
20
|
-
from flwr.proto.fleet_pb2 import PullTaskInsResponse
|
21
|
-
from flwr.proto.node_pb2 import Node
|
22
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
23
|
-
from flwr.proto.transport_pb2 import
|
20
|
+
from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
|
21
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
22
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
23
|
+
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
24
|
+
ClientMessage,
|
25
|
+
ServerMessage,
|
26
|
+
)
|
24
27
|
|
25
28
|
|
26
29
|
def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
|
@@ -80,8 +83,7 @@ def validate_task_res(task_res: TaskRes) -> bool:
|
|
80
83
|
initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()}
|
81
84
|
|
82
85
|
# Check if certain fields are already initialized
|
83
|
-
# pylint: disable-next=too-many-boolean-expressions
|
84
|
-
if (
|
86
|
+
if ( # pylint: disable-next=too-many-boolean-expressions
|
85
87
|
"task_id" in initialized_fields_in_task_res
|
86
88
|
or "group_id" in initialized_fields_in_task_res
|
87
89
|
or "run_id" in initialized_fields_in_task_res
|
flwr/client/node_state_tests.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from flwr.client.node_state import NodeState
|
19
19
|
from flwr.client.run_state import RunState
|
20
|
-
from flwr.proto.task_pb2 import TaskIns
|
20
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
21
21
|
|
22
22
|
|
23
23
|
def _run_dummy_task(state: RunState) -> RunState:
|
flwr/client/numpy_client.py
CHANGED
@@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes:
|
|
242
242
|
and isinstance(results[1], int)
|
243
243
|
and isinstance(results[2], dict)
|
244
244
|
):
|
245
|
-
raise
|
245
|
+
raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
|
246
246
|
|
247
247
|
# Return FitRes
|
248
248
|
parameters_prime, num_examples, metrics = results
|
@@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
|
|
266
266
|
and isinstance(results[1], int)
|
267
267
|
and isinstance(results[2], dict)
|
268
268
|
):
|
269
|
-
raise
|
269
|
+
raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
|
270
270
|
|
271
271
|
# Return EvaluateRes
|
272
272
|
loss, num_examples, metrics = results
|
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
|
|
29
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
30
30
|
from flwr.common.constant import MISSING_EXTRA_REST
|
31
31
|
from flwr.common.logger import log
|
32
|
-
from flwr.proto.fleet_pb2 import (
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
33
33
|
CreateNodeRequest,
|
34
34
|
CreateNodeResponse,
|
35
35
|
DeleteNodeRequest,
|
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
|
|
38
38
|
PushTaskResRequest,
|
39
39
|
PushTaskResResponse,
|
40
40
|
)
|
41
|
-
from flwr.proto.node_pb2 import Node
|
42
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
41
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
42
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
43
43
|
|
44
44
|
try:
|
45
45
|
import requests
|
@@ -143,6 +143,7 @@ def http_request_response(
|
|
143
143
|
},
|
144
144
|
data=create_node_req_bytes,
|
145
145
|
verify=verify,
|
146
|
+
timeout=None,
|
146
147
|
)
|
147
148
|
|
148
149
|
# Check status code and headers
|
@@ -185,6 +186,7 @@ def http_request_response(
|
|
185
186
|
},
|
186
187
|
data=delete_node_req_req_bytes,
|
187
188
|
verify=verify,
|
189
|
+
timeout=None,
|
188
190
|
)
|
189
191
|
|
190
192
|
# Check status code and headers
|
@@ -225,6 +227,7 @@ def http_request_response(
|
|
225
227
|
},
|
226
228
|
data=pull_task_ins_req_bytes,
|
227
229
|
verify=verify,
|
230
|
+
timeout=None,
|
228
231
|
)
|
229
232
|
|
230
233
|
# Check status code and headers
|
@@ -303,6 +306,7 @@ def http_request_response(
|
|
303
306
|
},
|
304
307
|
data=push_task_res_request_bytes,
|
305
308
|
verify=verify,
|
309
|
+
timeout=None,
|
306
310
|
)
|
307
311
|
|
308
312
|
state[KEY_TASK_INS] = None
|
@@ -333,7 +333,7 @@ def _share_keys(
|
|
333
333
|
|
334
334
|
# Check if the size is larger than threshold
|
335
335
|
if len(state.public_keys_dict) < state.threshold:
|
336
|
-
raise
|
336
|
+
raise ValueError("Available neighbours number smaller than threshold")
|
337
337
|
|
338
338
|
# Check if all public keys are unique
|
339
339
|
pk_list: List[bytes] = []
|
@@ -341,14 +341,14 @@ def _share_keys(
|
|
341
341
|
pk_list.append(pk1)
|
342
342
|
pk_list.append(pk2)
|
343
343
|
if len(set(pk_list)) != len(pk_list):
|
344
|
-
raise
|
344
|
+
raise ValueError("Some public keys are identical")
|
345
345
|
|
346
346
|
# Check if public keys of this client are correct in the dictionary
|
347
347
|
if (
|
348
348
|
state.public_keys_dict[state.sid][0] != state.pk1
|
349
349
|
or state.public_keys_dict[state.sid][1] != state.pk2
|
350
350
|
):
|
351
|
-
raise
|
351
|
+
raise ValueError(
|
352
352
|
"Own public keys are displayed in dict incorrectly, should not happen!"
|
353
353
|
)
|
354
354
|
|
@@ -393,7 +393,7 @@ def _collect_masked_input(
|
|
393
393
|
ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST])
|
394
394
|
srcs = cast(List[int], named_values[KEY_SOURCE_LIST])
|
395
395
|
if len(ciphertexts) + 1 < state.threshold:
|
396
|
-
raise
|
396
|
+
raise ValueError("Not enough available neighbour clients.")
|
397
397
|
|
398
398
|
# Decrypt ciphertexts, verify their sources, and store shares.
|
399
399
|
for src, ciphertext in zip(srcs, ciphertexts):
|
@@ -409,7 +409,7 @@ def _collect_masked_input(
|
|
409
409
|
f"from {actual_src} instead of {src}."
|
410
410
|
)
|
411
411
|
if dst != state.sid:
|
412
|
-
ValueError(
|
412
|
+
raise ValueError(
|
413
413
|
f"Client {state.sid}: received an encrypted message"
|
414
414
|
f"for Client {dst} from Client {src}."
|
415
415
|
)
|
@@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str,
|
|
476
476
|
# Send private mask seed share for every available client (including this client)
|
477
477
|
# Send first private key share for building pairwise mask for every dropped client
|
478
478
|
if len(active_sids) < state.threshold:
|
479
|
-
raise
|
479
|
+
raise ValueError("Available neighbours number smaller than threshold")
|
480
480
|
|
481
481
|
sids, shares = [], []
|
482
482
|
sids += active_sids
|
flwr/client/typing.py
CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
|
|
18
18
|
from typing import Callable
|
19
19
|
|
20
20
|
from flwr.client.run_state import RunState
|
21
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
21
|
+
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
22
22
|
|
23
23
|
from .client import Client as Client
|
24
24
|
|
@@ -0,0 +1,98 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""ConfigsRecord."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Dict, Optional, get_args
|
20
|
+
|
21
|
+
from .typing import ConfigsRecordValues, ConfigsScalar
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
|
25
|
+
class ConfigsRecord:
|
26
|
+
"""Configs record."""
|
27
|
+
|
28
|
+
keep_input: bool
|
29
|
+
data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)
|
30
|
+
|
31
|
+
def __init__(
|
32
|
+
self,
|
33
|
+
configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
|
34
|
+
keep_input: bool = True,
|
35
|
+
):
|
36
|
+
"""Construct a ConfigsRecord object.
|
37
|
+
|
38
|
+
Parameters
|
39
|
+
----------
|
40
|
+
configs_dict : Optional[Dict[str, ConfigsRecordValues]]
|
41
|
+
A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
|
42
|
+
defined in `ConfigsScalar`) and lists of such types (see
|
43
|
+
`ConfigsScalarList`).
|
44
|
+
keep_input : bool (default: True)
|
45
|
+
A boolean indicating whether the configs passed should be deleted from the input
|
46
|
+
dictionary immediately after adding them to the record. When set
|
47
|
+
to True, the data is duplicated in memory. If memory is a concern, set
|
48
|
+
it to False.
|
49
|
+
"""
|
50
|
+
self.keep_input = keep_input
|
51
|
+
self.data = {}
|
52
|
+
if configs_dict:
|
53
|
+
self.set_configs(configs_dict)
|
54
|
+
|
55
|
+
def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
|
56
|
+
"""Add configs to the record.
|
57
|
+
|
58
|
+
Parameters
|
59
|
+
----------
|
60
|
+
configs_dict : Dict[str, ConfigsRecordValues]
|
61
|
+
A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as
|
62
|
+
defined in `ConfigsScalar`) and lists of such types (see
|
63
|
+
`ConfigsScalarList`).
|
64
|
+
"""
|
65
|
+
if any(not isinstance(k, str) for k in configs_dict.keys()):
|
66
|
+
raise TypeError(f"Not all keys are of valid type. Expected {str}")
|
67
|
+
|
68
|
+
def is_valid(value: ConfigsScalar) -> None:
|
69
|
+
"""Check if value is of expected type."""
|
70
|
+
if not isinstance(value, get_args(ConfigsScalar)):
|
71
|
+
raise TypeError(
|
72
|
+
"Not all values are of valid type."
|
73
|
+
f" Expected {ConfigsRecordValues} but you passed {type(value)}."
|
74
|
+
)
|
75
|
+
|
76
|
+
# Check types of values
|
77
|
+
# Split between those values that are list and those that aren't
|
78
|
+
# then process in the same way
|
79
|
+
for value in configs_dict.values():
|
80
|
+
if isinstance(value, list):
|
81
|
+
# If your lists are large (e.g. 1M+ elements) this will be slow
|
82
|
+
# 1s to check 10M element list on a M2 Pro
|
83
|
+
# In such settings, you'd be better off treating such config as
|
84
|
+
# an array and pass it to a ParametersRecord.
|
85
|
+
for list_value in value:
|
86
|
+
is_valid(list_value)
|
87
|
+
else:
|
88
|
+
is_valid(value)
|
89
|
+
|
90
|
+
# Add configs to record
|
91
|
+
if self.keep_input:
|
92
|
+
# Copy
|
93
|
+
self.data = configs_dict.copy()
|
94
|
+
else:
|
95
|
+
# Add entries to dataclass without duplicating memory
|
96
|
+
for key in list(configs_dict.keys()):
|
97
|
+
self.data[key] = configs_dict[key]
|
98
|
+
del configs_dict[key]
|
flwr/common/logger.py
CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
|
|
111
111
|
""",
|
112
112
|
name,
|
113
113
|
)
|
114
|
+
|
115
|
+
|
116
|
+
def warn_deprecated_feature(name: str) -> None:
|
117
|
+
"""Warn the user when they use a deprecated feature."""
|
118
|
+
log(
|
119
|
+
WARN,
|
120
|
+
"""
|
121
|
+
DEPRECATED FEATURE: %s
|
122
|
+
|
123
|
+
This is a deprecated feature. It will be removed
|
124
|
+
entirely in future versions of Flower.
|
125
|
+
""",
|
126
|
+
name,
|
127
|
+
)
|
@@ -0,0 +1,96 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""MetricsRecord."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Dict, Optional, get_args
|
20
|
+
|
21
|
+
from .typing import MetricsRecordValues, MetricsScalar
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
|
25
|
+
class MetricsRecord:
|
26
|
+
"""Metrics record."""
|
27
|
+
|
28
|
+
keep_input: bool
|
29
|
+
data: Dict[str, MetricsRecordValues] = field(default_factory=dict)
|
30
|
+
|
31
|
+
def __init__(
|
32
|
+
self,
|
33
|
+
metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
|
34
|
+
keep_input: bool = True,
|
35
|
+
):
|
36
|
+
"""Construct a MetricsRecord object.
|
37
|
+
|
38
|
+
Parameters
|
39
|
+
----------
|
40
|
+
metrics_dict : Optional[Dict[str, MetricsRecordValues]]
|
41
|
+
A dictionary that stores basic types (i.e. `int`, `float` as defined
|
42
|
+
in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
|
43
|
+
keep_input : bool (default: True)
|
44
|
+
A boolean indicating whether metrics should be deleted from the input
|
45
|
+
dictionary immediately after adding them to the record. When set
|
46
|
+
to True, the data is duplicated in memory. If memory is a concern, set
|
47
|
+
it to False.
|
48
|
+
"""
|
49
|
+
self.keep_input = keep_input
|
50
|
+
self.data = {}
|
51
|
+
if metrics_dict:
|
52
|
+
self.set_metrics(metrics_dict)
|
53
|
+
|
54
|
+
def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
|
55
|
+
"""Add metrics to the record.
|
56
|
+
|
57
|
+
Parameters
|
58
|
+
----------
|
59
|
+
metrics_dict : Dict[str, MetricsRecordValues]
|
60
|
+
A dictionary that stores basic types (i.e. `int`, `float` as defined
|
61
|
+
in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
|
62
|
+
"""
|
63
|
+
if any(not isinstance(k, str) for k in metrics_dict.keys()):
|
64
|
+
raise TypeError(f"Not all keys are of valid type. Expected {str}.")
|
65
|
+
|
66
|
+
def is_valid(value: MetricsScalar) -> None:
|
67
|
+
"""Check if value is of expected type."""
|
68
|
+
if not isinstance(value, get_args(MetricsScalar)):
|
69
|
+
raise TypeError(
|
70
|
+
"Not all values are of valid type."
|
71
|
+
f" Expected {MetricsRecordValues} but you passed {type(value)}."
|
72
|
+
)
|
73
|
+
|
74
|
+
# Check types of values
|
75
|
+
# Split between those values that are list and those that aren't
|
76
|
+
# then process in the same way
|
77
|
+
for value in metrics_dict.values():
|
78
|
+
if isinstance(value, list):
|
79
|
+
# If your lists are large (e.g. 1M+ elements) this will be slow
|
80
|
+
# 1s to check 10M element list on a M2 Pro
|
81
|
+
# In such settings, you'd be better off treating such metric as
|
82
|
+
# an array and pass it to a ParametersRecord.
|
83
|
+
for list_value in value:
|
84
|
+
is_valid(list_value)
|
85
|
+
else:
|
86
|
+
is_valid(value)
|
87
|
+
|
88
|
+
# Add metrics to record
|
89
|
+
if self.keep_input:
|
90
|
+
# Copy
|
91
|
+
self.data = metrics_dict.copy()
|
92
|
+
else:
|
93
|
+
# Add entries to dataclass without duplicating memory
|
94
|
+
for key in list(metrics_dict.keys()):
|
95
|
+
self.data[key] = metrics_dict[key]
|
96
|
+
del metrics_dict[key]
|
@@ -0,0 +1,110 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""ParametersRecord and Array."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import List, Optional, OrderedDict
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
|
23
|
+
class Array:
|
24
|
+
"""Array type.
|
25
|
+
|
26
|
+
A dataclass containing serialized data from an array-like or tensor-like object
|
27
|
+
along with some metadata about it.
|
28
|
+
|
29
|
+
Parameters
|
30
|
+
----------
|
31
|
+
dtype : str
|
32
|
+
A string representing the data type of the serialised object (e.g. `np.float32`)
|
33
|
+
|
34
|
+
shape : List[int]
|
35
|
+
A list representing the shape of the unserialized array-like object. This is
|
36
|
+
used to deserialize the data (depending on the serialization method) or simply
|
37
|
+
as a metadata field.
|
38
|
+
|
39
|
+
stype : str
|
40
|
+
A string indicating the type of serialisation mechanism used to generate the
|
41
|
+
bytes in `data` from an array-like or tensor-like object.
|
42
|
+
|
43
|
+
data: bytes
|
44
|
+
A buffer of bytes containing the data.
|
45
|
+
"""
|
46
|
+
|
47
|
+
dtype: str
|
48
|
+
shape: List[int]
|
49
|
+
stype: str
|
50
|
+
data: bytes
|
51
|
+
|
52
|
+
|
53
|
+
@dataclass
|
54
|
+
class ParametersRecord:
|
55
|
+
"""Parameters record.
|
56
|
+
|
57
|
+
A dataclass storing named Arrays in order. This means that it holds entries as an
|
58
|
+
OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
|
59
|
+
PyTorch's state_dict, but holding serialised tensors instead.
|
60
|
+
"""
|
61
|
+
|
62
|
+
keep_input: bool
|
63
|
+
data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])
|
64
|
+
|
65
|
+
def __init__(
|
66
|
+
self,
|
67
|
+
array_dict: Optional[OrderedDict[str, Array]] = None,
|
68
|
+
keep_input: bool = False,
|
69
|
+
) -> None:
|
70
|
+
"""Construct a ParametersRecord object.
|
71
|
+
|
72
|
+
Parameters
|
73
|
+
----------
|
74
|
+
array_dict : Optional[OrderedDict[str, Array]]
|
75
|
+
A dictionary that stores serialized array-like or tensor-like objects.
|
76
|
+
keep_input : bool (default: False)
|
77
|
+
A boolean indicating whether parameters should be deleted from the input
|
78
|
+
dictionary immediately after adding them to the record. If False, the
|
79
|
+
dictionary passed to `set_parameters()` will be empty once exiting from that
|
80
|
+
function. This is the desired behaviour when working with very large
|
81
|
+
models/tensors/arrays. However, if you plan to continue working with your
|
82
|
+
parameters after adding them to the record, set this flag to True. When set
|
83
|
+
to True, the data is duplicated in memory.
|
84
|
+
"""
|
85
|
+
self.keep_input = keep_input
|
86
|
+
self.data = OrderedDict()
|
87
|
+
if array_dict:
|
88
|
+
self.set_parameters(array_dict)
|
89
|
+
|
90
|
+
def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
|
91
|
+
"""Add parameters to record.
|
92
|
+
|
93
|
+
Parameters
|
94
|
+
----------
|
95
|
+
array_dict : OrderedDict[str, Array]
|
96
|
+
A dictionary that stores serialized array-like or tensor-like objects.
|
97
|
+
"""
|
98
|
+
if any(not isinstance(k, str) for k in array_dict.keys()):
|
99
|
+
raise TypeError(f"Not all keys are of valid type. Expected {str}")
|
100
|
+
if any(not isinstance(v, Array) for v in array_dict.values()):
|
101
|
+
raise TypeError(f"Not all values are of valid type. Expected {Array}")
|
102
|
+
|
103
|
+
if self.keep_input:
|
104
|
+
# Copy
|
105
|
+
self.data = OrderedDict(array_dict)
|
106
|
+
else:
|
107
|
+
# Add entries to dataclass without duplicating memory
|
108
|
+
for key in list(array_dict.keys()):
|
109
|
+
self.data[key] = array_dict[key]
|
110
|
+
del array_dict[key]
|
flwr/common/recordset.py
CHANGED
@@ -14,32 +14,22 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""RecordSet."""
|
16
16
|
|
17
|
-
from dataclasses import dataclass
|
18
|
-
from typing import Dict
|
19
|
-
|
20
|
-
|
21
|
-
@dataclass
|
22
|
-
class ParametersRecord:
|
23
|
-
"""Parameters record."""
|
24
|
-
|
25
|
-
|
26
|
-
@dataclass
|
27
|
-
class MetricsRecord:
|
28
|
-
"""Metrics record."""
|
29
17
|
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Dict
|
30
20
|
|
31
|
-
|
32
|
-
|
33
|
-
|
21
|
+
from .configsrecord import ConfigsRecord
|
22
|
+
from .metricsrecord import MetricsRecord
|
23
|
+
from .parametersrecord import ParametersRecord
|
34
24
|
|
35
25
|
|
36
26
|
@dataclass
|
37
27
|
class RecordSet:
|
38
28
|
"""Definition of RecordSet."""
|
39
29
|
|
40
|
-
parameters: Dict[str, ParametersRecord] =
|
41
|
-
metrics: Dict[str, MetricsRecord] =
|
42
|
-
configs: Dict[str, ConfigsRecord] =
|
30
|
+
parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
|
31
|
+
metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
|
32
|
+
configs: Dict[str, ConfigsRecord] = field(default_factory=dict)
|
43
33
|
|
44
34
|
def set_parameters(self, name: str, record: ParametersRecord) -> None:
|
45
35
|
"""Add a ParametersRecord."""
|