flwr-nightly 1.7.0.dev20240116__py3-none-any.whl → 1.7.0.dev20240118__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (55)
  1. flwr/client/app.py +7 -4
  2. flwr/client/dpfedavg_numpy_client.py +4 -4
  3. flwr/client/grpc_client/connection.py +7 -4
  4. flwr/client/grpc_rere_client/connection.py +4 -4
  5. flwr/client/message_handler/message_handler.py +11 -2
  6. flwr/client/message_handler/task_handler.py +8 -6
  7. flwr/client/node_state_tests.py +1 -1
  8. flwr/client/numpy_client.py +2 -2
  9. flwr/client/rest_client/connection.py +7 -3
  10. flwr/client/secure_aggregation/secaggplus_handler.py +6 -6
  11. flwr/client/typing.py +1 -1
  12. flwr/common/configsrecord.py +98 -0
  13. flwr/common/logger.py +14 -0
  14. flwr/common/metricsrecord.py +96 -0
  15. flwr/common/parametersrecord.py +110 -0
  16. flwr/common/recordset.py +8 -18
  17. flwr/common/recordset_utils.py +87 -0
  18. flwr/common/retry_invoker.py +1 -0
  19. flwr/common/serde.py +12 -8
  20. flwr/common/typing.py +9 -0
  21. flwr/driver/app.py +5 -3
  22. flwr/driver/driver.py +3 -3
  23. flwr/driver/driver_client_proxy.py +24 -15
  24. flwr/driver/grpc_driver.py +6 -6
  25. flwr/proto/driver_pb2.py +23 -88
  26. flwr/proto/fleet_pb2.py +29 -111
  27. flwr/proto/node_pb2.py +7 -15
  28. flwr/proto/task_pb2.py +33 -127
  29. flwr/proto/transport_pb2.py +69 -278
  30. flwr/server/app.py +9 -3
  31. flwr/server/driver/driver_servicer.py +4 -4
  32. flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
  33. flwr/server/fleet/grpc_bidi/grpc_bridge.py +9 -6
  34. flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
  35. flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
  36. flwr/server/fleet/grpc_bidi/ins_scheduler.py +7 -4
  37. flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
  38. flwr/server/fleet/message_handler/message_handler.py +3 -3
  39. flwr/server/fleet/rest_rere/rest_api.py +1 -1
  40. flwr/server/state/in_memory_state.py +1 -1
  41. flwr/server/state/sqlite_state.py +8 -5
  42. flwr/server/state/state.py +1 -1
  43. flwr/server/strategy/aggregate.py +8 -8
  44. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  45. flwr/server/strategy/dpfedavg_fixed.py +2 -2
  46. flwr/server/strategy/fedavg_android.py +0 -2
  47. flwr/server/strategy/fedmedian.py +1 -1
  48. flwr/server/strategy/fedxgb_nn_avg.py +9 -2
  49. flwr/server/strategy/qfedavg.py +1 -1
  50. flwr/server/utils/validator.py +1 -1
  51. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/METADATA +3 -3
  52. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/RECORD +55 -51
  53. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/LICENSE +0 -0
  54. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/WHEEL +0 -0
  55. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/entry_points.txt +0 -0
flwr/client/app.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
35
35
  TRANSPORT_TYPES,
36
36
  )
37
37
  from flwr.common.logger import log, warn_experimental_feature
38
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
39
39
 
40
40
  from .flower import load_flower_callable
41
41
  from .grpc_client.connection import grpc_connection
@@ -138,10 +138,12 @@ def _check_actionable_client(
138
138
  client: Optional[Client], client_fn: Optional[ClientFn]
139
139
  ) -> None:
140
140
  if client_fn is None and client is None:
141
- raise Exception("Both `client_fn` and `client` are `None`, but one is required")
141
+ raise ValueError(
142
+ "Both `client_fn` and `client` are `None`, but one is required"
143
+ )
142
144
 
143
145
  if client_fn is not None and client is not None:
144
- raise Exception(
146
+ raise ValueError(
145
147
  "Both `client_fn` and `client` are provided, but only one is allowed"
146
148
  )
147
149
 
@@ -150,6 +152,7 @@ def _check_actionable_client(
150
152
  # pylint: disable=too-many-branches
151
153
  # pylint: disable=too-many-locals
152
154
  # pylint: disable=too-many-statements
155
+ # pylint: disable=too-many-arguments
153
156
  def start_client(
154
157
  *,
155
158
  server_address: str,
@@ -299,7 +302,7 @@ def _start_client_internal(
299
302
  cid: str, # pylint: disable=unused-argument
300
303
  ) -> Client:
301
304
  if client is None: # Added this to keep mypy happy
302
- raise Exception(
305
+ raise ValueError(
303
306
  "Both `client_fn` and `client` are `None`, but one is required"
304
307
  )
305
308
  return client # Always return the same instance
@@ -117,16 +117,16 @@ class DPFedAvgNumPyClient(NumPyClient):
117
117
  update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)]
118
118
 
119
119
  if "dpfedavg_clip_norm" not in config:
120
- raise Exception("Clipping threshold not supplied by the server.")
120
+ raise KeyError("Clipping threshold not supplied by the server.")
121
121
  if not isinstance(config["dpfedavg_clip_norm"], float):
122
- raise Exception("Clipping threshold should be a floating point value.")
122
+ raise TypeError("Clipping threshold should be a floating point value.")
123
123
 
124
124
  # Clipping
125
125
  update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"])
126
126
 
127
127
  if "dpfedavg_noise_stddev" in config:
128
128
  if not isinstance(config["dpfedavg_noise_stddev"], float):
129
- raise Exception(
129
+ raise TypeError(
130
130
  "Scale of noise to be added should be a floating point value."
131
131
  )
132
132
  # Noising
@@ -138,7 +138,7 @@ class DPFedAvgNumPyClient(NumPyClient):
138
138
  # Calculating value of norm indicator bit, required for adaptive clipping
139
139
  if "dpfedavg_adaptive_clip_enabled" in config:
140
140
  if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool):
141
- raise Exception(
141
+ raise TypeError(
142
142
  "dpfedavg_adaptive_clip_enabled should be a boolean-valued flag."
143
143
  )
144
144
  metrics["dpfedavg_norm_bit"] = not clipped
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
25
25
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
26
26
  from flwr.common.grpc import create_channel
27
27
  from flwr.common.logger import log
28
- from flwr.proto.node_pb2 import Node
29
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
30
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
31
- from flwr.proto.transport_pb2_grpc import FlowerServiceStub
28
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
30
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
31
+ ClientMessage,
32
+ ServerMessage,
33
+ )
34
+ from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
32
35
 
33
36
  # The following flags can be uncommented for debugging. Other possible values:
34
37
  # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.grpc import create_channel
31
31
  from flwr.common.logger import log, warn_experimental_feature
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  DeleteNodeRequest,
35
35
  PullTaskInsRequest,
36
36
  PushTaskResRequest,
37
37
  )
38
- from flwr.proto.fleet_pb2_grpc import FleetStub
39
- from flwr.proto.node_pb2 import Node
40
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
39
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
40
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
41
41
 
42
42
  KEY_NODE = "node"
43
43
  KEY_TASK_INS = "current_task_ins"
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
32
32
  from flwr.client.secure_aggregation import SecureAggregationHandler
33
33
  from flwr.client.typing import ClientFn
34
34
  from flwr.common import serde
35
- from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
36
- from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
35
+ from flwr.proto.task_pb2 import ( # pylint: disable=E0611
36
+ SecureAggregation,
37
+ Task,
38
+ TaskIns,
39
+ TaskRes,
40
+ )
41
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
+ ClientMessage,
43
+ Reason,
44
+ ServerMessage,
45
+ )
37
46
 
38
47
 
39
48
  class UnexpectedServerMessage(Exception):
@@ -17,10 +17,13 @@
17
17
 
18
18
  from typing import Optional
19
19
 
20
- from flwr.proto.fleet_pb2 import PullTaskInsResponse
21
- from flwr.proto.node_pb2 import Node
22
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
23
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
20
+ from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
21
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
22
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
23
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
24
+ ClientMessage,
25
+ ServerMessage,
26
+ )
24
27
 
25
28
 
26
29
  def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
@@ -80,8 +83,7 @@ def validate_task_res(task_res: TaskRes) -> bool:
80
83
  initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()}
81
84
 
82
85
  # Check if certain fields are already initialized
83
- # pylint: disable-next=too-many-boolean-expressions
84
- if (
86
+ if ( # pylint: disable-next=too-many-boolean-expressions
85
87
  "task_id" in initialized_fields_in_task_res
86
88
  or "group_id" in initialized_fields_in_task_res
87
89
  or "run_id" in initialized_fields_in_task_res
@@ -17,7 +17,7 @@
17
17
 
18
18
  from flwr.client.node_state import NodeState
19
19
  from flwr.client.run_state import RunState
20
- from flwr.proto.task_pb2 import TaskIns
20
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
21
21
 
22
22
 
23
23
  def _run_dummy_task(state: RunState) -> RunState:
@@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes:
242
242
  and isinstance(results[1], int)
243
243
  and isinstance(results[2], dict)
244
244
  ):
245
- raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
245
+ raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
246
246
 
247
247
  # Return FitRes
248
248
  parameters_prime, num_examples, metrics = results
@@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
266
266
  and isinstance(results[1], int)
267
267
  and isinstance(results[2], dict)
268
268
  ):
269
- raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
269
+ raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
270
270
 
271
271
  # Return EvaluateRes
272
272
  loss, num_examples, metrics = results
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.constant import MISSING_EXTRA_REST
31
31
  from flwr.common.logger import log
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  CreateNodeResponse,
35
35
  DeleteNodeRequest,
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
38
38
  PushTaskResRequest,
39
39
  PushTaskResResponse,
40
40
  )
41
- from flwr.proto.node_pb2 import Node
42
- from flwr.proto.task_pb2 import TaskIns, TaskRes
41
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
42
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
43
 
44
44
  try:
45
45
  import requests
@@ -143,6 +143,7 @@ def http_request_response(
143
143
  },
144
144
  data=create_node_req_bytes,
145
145
  verify=verify,
146
+ timeout=None,
146
147
  )
147
148
 
148
149
  # Check status code and headers
@@ -185,6 +186,7 @@ def http_request_response(
185
186
  },
186
187
  data=delete_node_req_req_bytes,
187
188
  verify=verify,
189
+ timeout=None,
188
190
  )
189
191
 
190
192
  # Check status code and headers
@@ -225,6 +227,7 @@ def http_request_response(
225
227
  },
226
228
  data=pull_task_ins_req_bytes,
227
229
  verify=verify,
230
+ timeout=None,
228
231
  )
229
232
 
230
233
  # Check status code and headers
@@ -303,6 +306,7 @@ def http_request_response(
303
306
  },
304
307
  data=push_task_res_request_bytes,
305
308
  verify=verify,
309
+ timeout=None,
306
310
  )
307
311
 
308
312
  state[KEY_TASK_INS] = None
@@ -333,7 +333,7 @@ def _share_keys(
333
333
 
334
334
  # Check if the size is larger than threshold
335
335
  if len(state.public_keys_dict) < state.threshold:
336
- raise Exception("Available neighbours number smaller than threshold")
336
+ raise ValueError("Available neighbours number smaller than threshold")
337
337
 
338
338
  # Check if all public keys are unique
339
339
  pk_list: List[bytes] = []
@@ -341,14 +341,14 @@ def _share_keys(
341
341
  pk_list.append(pk1)
342
342
  pk_list.append(pk2)
343
343
  if len(set(pk_list)) != len(pk_list):
344
- raise Exception("Some public keys are identical")
344
+ raise ValueError("Some public keys are identical")
345
345
 
346
346
  # Check if public keys of this client are correct in the dictionary
347
347
  if (
348
348
  state.public_keys_dict[state.sid][0] != state.pk1
349
349
  or state.public_keys_dict[state.sid][1] != state.pk2
350
350
  ):
351
- raise Exception(
351
+ raise ValueError(
352
352
  "Own public keys are displayed in dict incorrectly, should not happen!"
353
353
  )
354
354
 
@@ -393,7 +393,7 @@ def _collect_masked_input(
393
393
  ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST])
394
394
  srcs = cast(List[int], named_values[KEY_SOURCE_LIST])
395
395
  if len(ciphertexts) + 1 < state.threshold:
396
- raise Exception("Not enough available neighbour clients.")
396
+ raise ValueError("Not enough available neighbour clients.")
397
397
 
398
398
  # Decrypt ciphertexts, verify their sources, and store shares.
399
399
  for src, ciphertext in zip(srcs, ciphertexts):
@@ -409,7 +409,7 @@ def _collect_masked_input(
409
409
  f"from {actual_src} instead of {src}."
410
410
  )
411
411
  if dst != state.sid:
412
- ValueError(
412
+ raise ValueError(
413
413
  f"Client {state.sid}: received an encrypted message"
414
414
  f"for Client {dst} from Client {src}."
415
415
  )
@@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str,
476
476
  # Send private mask seed share for every avaliable client (including itclient)
477
477
  # Send first private key share for building pairwise mask for every dropped client
478
478
  if len(active_sids) < state.threshold:
479
- raise Exception("Available neighbours number smaller than threshold")
479
+ raise ValueError("Available neighbours number smaller than threshold")
480
480
 
481
481
  sids, shares = [], []
482
482
  sids += active_sids
flwr/client/typing.py CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
18
18
  from typing import Callable
19
19
 
20
20
  from flwr.client.run_state import RunState
21
- from flwr.proto.task_pb2 import TaskIns, TaskRes
21
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
22
22
 
23
23
  from .client import Client as Client
24
24
 
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ConfigsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import ConfigsRecordValues, ConfigsScalar
22
+
23
+
24
@dataclass
class ConfigsRecord:
    """Configs record.

    Maps string keys to config values. Each value is either a scalar of one of
    the types in `ConfigsScalar` or a list of such scalars.
    """

    keep_input: bool
    data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a ConfigsRecord object.

        Parameters
        ----------
        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
            A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
            defined in `ConfigsScalar`) and lists of such types (see
            `ConfigsScalarList`).
        keep_input : bool (default: True)
            Whether the input dictionary is left untouched. When True, its entries
            are copied into the record (a shallow copy: the values themselves are
            shared), so the data is held twice in memory. When False, each entry is
            moved into the record and removed from the input dictionary; set it to
            False if memory is a concern.
        """
        self.keep_input = keep_input
        self.data = {}
        if configs_dict:
            self.set_configs(configs_dict)

    def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
        """Add configs to the record.

        New entries are merged into the ones already stored; an existing key is
        overwritten. All keys and values are validated before anything is
        modified, so a `TypeError` leaves both the record and the input
        dictionary unchanged.

        Parameters
        ----------
        configs_dict : Dict[str, ConfigsRecordValues]
            A dictionary that stores basic types (i.e. `str`,`int`, `float`, `bytes` as
            defined in `ConfigsRecordValues`) and list of such types (see
            `ConfigsScalarList`).

        Raises
        ------
        TypeError
            If a key is not a `str`, or a value (or an element of a list value) is
            not an instance of one of the types in `ConfigsScalar`.
        """
        if any(not isinstance(k, str) for k in configs_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")

        # Resolve the accepted scalar types once rather than once per value.
        valid_types = get_args(ConfigsScalar)

        def is_valid(value: ConfigsScalar) -> None:
            """Raise TypeError if `value` is not one of the accepted scalar types."""
            if not isinstance(value, valid_types):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {ConfigsRecordValues} but you passed {type(value)}."
                )

        # Check types of values
        # Split between those values that are list and those that aren't
        # then process in the same way
        for value in configs_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # 1s to check 10M element list on a M2 Pro
                # In such settings, you'd be better off treating such config as
                # an array and pass it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add configs to record
        if self.keep_input:
            # Shallow copy merged into the existing entries. Previously this
            # replaced `self.data`, silently dropping configs set earlier.
            self.data.update(configs_dict)
        else:
            # Move entries into the record without duplicating memory
            for key in list(configs_dict.keys()):
                self.data[key] = configs_dict[key]
                del configs_dict[key]
flwr/common/logger.py CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
111
111
  """,
112
112
  name,
113
113
  )
114
+
115
+
116
def warn_deprecated_feature(name: str) -> None:
    """Log, at WARN level, that the feature called `name` is deprecated.

    The message tells the user the feature will be removed in a future
    version of Flower.
    """
    deprecation_template = """
        DEPRECATED FEATURE: %s

        This is a deprecated feature. It will be removed
        entirely in future versions of Flower.
        """
    log(WARN, deprecation_template, name)
@@ -0,0 +1,96 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """MetricsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import MetricsRecordValues, MetricsScalar
22
+
23
+
24
@dataclass
class MetricsRecord:
    """Metrics record.

    Maps string keys to metric values. Each value is either a scalar of one of
    the types in `MetricsScalar` or a list of such scalars.
    """

    keep_input: bool
    data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a MetricsRecord object.

        Parameters
        ----------
        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            Whether the input dictionary is left untouched. When True, its entries
            are copied into the record (a shallow copy: the values themselves are
            shared), so the data is held twice in memory. When False, each entry is
            moved into the record and removed from the input dictionary; set it to
            False if memory is a concern.
        """
        self.keep_input = keep_input
        self.data = {}
        if metrics_dict:
            self.set_metrics(metrics_dict)

    def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
        """Add metrics to the record.

        New entries are merged into the ones already stored; an existing key is
        overwritten. All keys and values are validated before anything is
        modified, so a `TypeError` leaves both the record and the input
        dictionary unchanged.

        Parameters
        ----------
        metrics_dict : Dict[str, MetricsRecordValues]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).

        Raises
        ------
        TypeError
            If a key is not a `str`, or a value (or an element of a list value) is
            not an instance of one of the types in `MetricsScalar`.
        """
        if any(not isinstance(k, str) for k in metrics_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}.")

        # Resolve the accepted scalar types once rather than once per value.
        valid_types = get_args(MetricsScalar)

        def is_valid(value: MetricsScalar) -> None:
            """Raise TypeError if `value` is not one of the accepted scalar types."""
            if not isinstance(value, valid_types):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {MetricsRecordValues} but you passed {type(value)}."
                )

        # Check types of values
        # Split between those values that are list and those that aren't
        # then process in the same way
        for value in metrics_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # 1s to check 10M element list on a M2 Pro
                # In such settings, you'd be better off treating such metric as
                # an array and pass it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add metrics to record
        if self.keep_input:
            # Shallow copy merged into the existing entries. Previously this
            # replaced `self.data`, silently dropping metrics set earlier.
            self.data.update(metrics_dict)
        else:
            # Move entries into the record without duplicating memory
            for key in list(metrics_dict.keys()):
                self.data[key] = metrics_dict[key]
                del metrics_dict[key]
@@ -0,0 +1,110 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ParametersRecord and Array."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import List, Optional, OrderedDict
20
+
21
+
22
@dataclass
class Array:
    """Array type.

    A dataclass containing serialized data from an array-like or tensor-like object
    along with some metadata about it.

    Parameters
    ----------
    dtype : str
        A string representing the data type of the serialised object (e.g. `np.float32`)

    shape : List[int]
        A list representing the shape of the unserialized array-like object. This is
        used to deserialize the data (depending on the serialization method) or simply
        as a metadata field.

    stype : str
        A string indicating the type of serialisation mechanism used to generate the
        bytes in `data` from an array-like or tensor-like object.

    data : bytes
        A buffer of bytes containing the data.
    """

    dtype: str
    shape: List[int]
    stype: str
    data: bytes


@dataclass
class ParametersRecord:
    """Parameters record.

    A dataclass storing named Arrays in order. This means that it holds entries as an
    OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
    PyTorch's state_dict, but holding serialised tensors instead.
    """

    keep_input: bool
    # The custom __init__ below always initialises `data`; the factory only
    # documents the default for dataclass tooling.
    data: OrderedDict[str, Array] = field(default_factory=OrderedDict)

    def __init__(
        self,
        array_dict: Optional[OrderedDict[str, Array]] = None,
        keep_input: bool = False,
    ) -> None:
        """Construct a ParametersRecord object.

        Parameters
        ----------
        array_dict : Optional[OrderedDict[str, Array]]
            A dictionary that stores serialized array-like or tensor-like objects.
        keep_input : bool (default: False)
            Whether the input dictionary is left untouched. If False, each entry is
            moved into the record, and the dictionary passed to `set_parameters()`
            will be empty once exiting from that function. This is the desired
            behaviour when working with very large models/tensors/arrays. However,
            if you plan to continue working with your parameters after adding them
            to the record, set this flag to True. When set to True, the entries are
            shallow-copied, so the dictionary is held twice in memory.
        """
        self.keep_input = keep_input
        self.data = OrderedDict()
        if array_dict:
            self.set_parameters(array_dict)

    def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
        """Add parameters to record.

        New entries are appended after the ones already stored (in the input's
        order); an existing key is overwritten in place. Keys and values are
        validated before anything is modified, so a `TypeError` leaves both the
        record and the input dictionary unchanged.

        Parameters
        ----------
        array_dict : OrderedDict[str, Array]
            A dictionary that stores serialized array-like or tensor-like objects.

        Raises
        ------
        TypeError
            If a key is not a `str` or a value is not an `Array`.
        """
        if any(not isinstance(k, str) for k in array_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")
        if any(not isinstance(v, Array) for v in array_dict.values()):
            raise TypeError(f"Not all values are of valid type. Expected {Array}")

        if self.keep_input:
            # Shallow copy merged into the existing entries. Previously this
            # replaced `self.data`, silently dropping arrays set earlier.
            self.data.update(array_dict)
        else:
            # Move entries into the record without duplicating memory
            for key in list(array_dict.keys()):
                self.data[key] = array_dict[key]
                del array_dict[key]
flwr/common/recordset.py CHANGED
@@ -14,32 +14,22 @@
14
14
  # ==============================================================================
15
15
  """RecordSet."""
16
16
 
17
- from dataclasses import dataclass
18
- from typing import Dict
19
-
20
-
21
- @dataclass
22
- class ParametersRecord:
23
- """Parameters record."""
24
-
25
-
26
- @dataclass
27
- class MetricsRecord:
28
- """Metrics record."""
29
17
 
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict
30
20
 
31
- @dataclass
32
- class ConfigsRecord:
33
- """Configs record."""
21
+ from .configsrecord import ConfigsRecord
22
+ from .metricsrecord import MetricsRecord
23
+ from .parametersrecord import ParametersRecord
34
24
 
35
25
 
36
26
  @dataclass
37
27
  class RecordSet:
38
28
  """Definition of RecordSet."""
39
29
 
40
- parameters: Dict[str, ParametersRecord] = {}
41
- metrics: Dict[str, MetricsRecord] = {}
42
- configs: Dict[str, ConfigsRecord] = {}
30
+ parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
31
+ metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
32
+ configs: Dict[str, ConfigsRecord] = field(default_factory=dict)
43
33
 
44
34
  def set_parameters(self, name: str, record: ParametersRecord) -> None:
45
35
  """Add a ParametersRecord."""