flwr-nightly 1.7.0.dev20240116__py3-none-any.whl → 1.7.0.dev20240118__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55):
  1. flwr/client/app.py +7 -4
  2. flwr/client/dpfedavg_numpy_client.py +4 -4
  3. flwr/client/grpc_client/connection.py +7 -4
  4. flwr/client/grpc_rere_client/connection.py +4 -4
  5. flwr/client/message_handler/message_handler.py +11 -2
  6. flwr/client/message_handler/task_handler.py +8 -6
  7. flwr/client/node_state_tests.py +1 -1
  8. flwr/client/numpy_client.py +2 -2
  9. flwr/client/rest_client/connection.py +7 -3
  10. flwr/client/secure_aggregation/secaggplus_handler.py +6 -6
  11. flwr/client/typing.py +1 -1
  12. flwr/common/configsrecord.py +98 -0
  13. flwr/common/logger.py +14 -0
  14. flwr/common/metricsrecord.py +96 -0
  15. flwr/common/parametersrecord.py +110 -0
  16. flwr/common/recordset.py +8 -18
  17. flwr/common/recordset_utils.py +87 -0
  18. flwr/common/retry_invoker.py +1 -0
  19. flwr/common/serde.py +12 -8
  20. flwr/common/typing.py +9 -0
  21. flwr/driver/app.py +5 -3
  22. flwr/driver/driver.py +3 -3
  23. flwr/driver/driver_client_proxy.py +24 -15
  24. flwr/driver/grpc_driver.py +6 -6
  25. flwr/proto/driver_pb2.py +23 -88
  26. flwr/proto/fleet_pb2.py +29 -111
  27. flwr/proto/node_pb2.py +7 -15
  28. flwr/proto/task_pb2.py +33 -127
  29. flwr/proto/transport_pb2.py +69 -278
  30. flwr/server/app.py +9 -3
  31. flwr/server/driver/driver_servicer.py +4 -4
  32. flwr/server/fleet/grpc_bidi/flower_service_servicer.py +5 -2
  33. flwr/server/fleet/grpc_bidi/grpc_bridge.py +9 -6
  34. flwr/server/fleet/grpc_bidi/grpc_client_proxy.py +4 -1
  35. flwr/server/fleet/grpc_bidi/grpc_server.py +3 -1
  36. flwr/server/fleet/grpc_bidi/ins_scheduler.py +7 -4
  37. flwr/server/fleet/grpc_rere/fleet_servicer.py +2 -2
  38. flwr/server/fleet/message_handler/message_handler.py +3 -3
  39. flwr/server/fleet/rest_rere/rest_api.py +1 -1
  40. flwr/server/state/in_memory_state.py +1 -1
  41. flwr/server/state/sqlite_state.py +8 -5
  42. flwr/server/state/state.py +1 -1
  43. flwr/server/strategy/aggregate.py +8 -8
  44. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  45. flwr/server/strategy/dpfedavg_fixed.py +2 -2
  46. flwr/server/strategy/fedavg_android.py +0 -2
  47. flwr/server/strategy/fedmedian.py +1 -1
  48. flwr/server/strategy/fedxgb_nn_avg.py +9 -2
  49. flwr/server/strategy/qfedavg.py +1 -1
  50. flwr/server/utils/validator.py +1 -1
  51. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/METADATA +3 -3
  52. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/RECORD +55 -51
  53. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/LICENSE +0 -0
  54. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/WHEEL +0 -0
  55. {flwr_nightly-1.7.0.dev20240116.dist-info → flwr_nightly-1.7.0.dev20240118.dist-info}/entry_points.txt +0 -0
flwr/client/app.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.constant import (
35
35
  TRANSPORT_TYPES,
36
36
  )
37
37
  from flwr.common.logger import log, warn_experimental_feature
38
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
39
39
 
40
40
  from .flower import load_flower_callable
41
41
  from .grpc_client.connection import grpc_connection
@@ -138,10 +138,12 @@ def _check_actionable_client(
138
138
  client: Optional[Client], client_fn: Optional[ClientFn]
139
139
  ) -> None:
140
140
  if client_fn is None and client is None:
141
- raise Exception("Both `client_fn` and `client` are `None`, but one is required")
141
+ raise ValueError(
142
+ "Both `client_fn` and `client` are `None`, but one is required"
143
+ )
142
144
 
143
145
  if client_fn is not None and client is not None:
144
- raise Exception(
146
+ raise ValueError(
145
147
  "Both `client_fn` and `client` are provided, but only one is allowed"
146
148
  )
147
149
 
@@ -150,6 +152,7 @@ def _check_actionable_client(
150
152
  # pylint: disable=too-many-branches
151
153
  # pylint: disable=too-many-locals
152
154
  # pylint: disable=too-many-statements
155
+ # pylint: disable=too-many-arguments
153
156
  def start_client(
154
157
  *,
155
158
  server_address: str,
@@ -299,7 +302,7 @@ def _start_client_internal(
299
302
  cid: str, # pylint: disable=unused-argument
300
303
  ) -> Client:
301
304
  if client is None: # Added this to keep mypy happy
302
- raise Exception(
305
+ raise ValueError(
303
306
  "Both `client_fn` and `client` are `None`, but one is required"
304
307
  )
305
308
  return client # Always return the same instance
@@ -117,16 +117,16 @@ class DPFedAvgNumPyClient(NumPyClient):
117
117
  update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)]
118
118
 
119
119
  if "dpfedavg_clip_norm" not in config:
120
- raise Exception("Clipping threshold not supplied by the server.")
120
+ raise KeyError("Clipping threshold not supplied by the server.")
121
121
  if not isinstance(config["dpfedavg_clip_norm"], float):
122
- raise Exception("Clipping threshold should be a floating point value.")
122
+ raise TypeError("Clipping threshold should be a floating point value.")
123
123
 
124
124
  # Clipping
125
125
  update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"])
126
126
 
127
127
  if "dpfedavg_noise_stddev" in config:
128
128
  if not isinstance(config["dpfedavg_noise_stddev"], float):
129
- raise Exception(
129
+ raise TypeError(
130
130
  "Scale of noise to be added should be a floating point value."
131
131
  )
132
132
  # Noising
@@ -138,7 +138,7 @@ class DPFedAvgNumPyClient(NumPyClient):
138
138
  # Calculating value of norm indicator bit, required for adaptive clipping
139
139
  if "dpfedavg_adaptive_clip_enabled" in config:
140
140
  if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool):
141
- raise Exception(
141
+ raise TypeError(
142
142
  "dpfedavg_adaptive_clip_enabled should be a boolean-valued flag."
143
143
  )
144
144
  metrics["dpfedavg_norm_bit"] = not clipped
@@ -25,10 +25,13 @@ from typing import Callable, Iterator, Optional, Tuple, Union
25
25
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
26
26
  from flwr.common.grpc import create_channel
27
27
  from flwr.common.logger import log
28
- from flwr.proto.node_pb2 import Node
29
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
30
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
31
- from flwr.proto.transport_pb2_grpc import FlowerServiceStub
28
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
30
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
31
+ ClientMessage,
32
+ ServerMessage,
33
+ )
34
+ from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
32
35
 
33
36
  # The following flags can be uncommented for debugging. Other possible values:
34
37
  # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
@@ -29,15 +29,15 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.grpc import create_channel
31
31
  from flwr.common.logger import log, warn_experimental_feature
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  DeleteNodeRequest,
35
35
  PullTaskInsRequest,
36
36
  PushTaskResRequest,
37
37
  )
38
- from flwr.proto.fleet_pb2_grpc import FleetStub
39
- from flwr.proto.node_pb2 import Node
40
- from flwr.proto.task_pb2 import TaskIns, TaskRes
38
+ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
39
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
40
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
41
41
 
42
42
  KEY_NODE = "node"
43
43
  KEY_TASK_INS = "current_task_ins"
@@ -32,8 +32,17 @@ from flwr.client.run_state import RunState
32
32
  from flwr.client.secure_aggregation import SecureAggregationHandler
33
33
  from flwr.client.typing import ClientFn
34
34
  from flwr.common import serde
35
- from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
36
- from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
35
+ from flwr.proto.task_pb2 import ( # pylint: disable=E0611
36
+ SecureAggregation,
37
+ Task,
38
+ TaskIns,
39
+ TaskRes,
40
+ )
41
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
+ ClientMessage,
43
+ Reason,
44
+ ServerMessage,
45
+ )
37
46
 
38
47
 
39
48
  class UnexpectedServerMessage(Exception):
@@ -17,10 +17,13 @@
17
17
 
18
18
  from typing import Optional
19
19
 
20
- from flwr.proto.fleet_pb2 import PullTaskInsResponse
21
- from flwr.proto.node_pb2 import Node
22
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
23
- from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
20
+ from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
21
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
22
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
23
+ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
24
+ ClientMessage,
25
+ ServerMessage,
26
+ )
24
27
 
25
28
 
26
29
  def validate_task_ins(task_ins: TaskIns, discard_reconnect_ins: bool) -> bool:
@@ -80,8 +83,7 @@ def validate_task_res(task_res: TaskRes) -> bool:
80
83
  initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()}
81
84
 
82
85
  # Check if certain fields are already initialized
83
- # pylint: disable-next=too-many-boolean-expressions
84
- if (
86
+ if ( # pylint: disable-next=too-many-boolean-expressions
85
87
  "task_id" in initialized_fields_in_task_res
86
88
  or "group_id" in initialized_fields_in_task_res
87
89
  or "run_id" in initialized_fields_in_task_res
@@ -17,7 +17,7 @@
17
17
 
18
18
  from flwr.client.node_state import NodeState
19
19
  from flwr.client.run_state import RunState
20
- from flwr.proto.task_pb2 import TaskIns
20
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
21
21
 
22
22
 
23
23
  def _run_dummy_task(state: RunState) -> RunState:
@@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes:
242
242
  and isinstance(results[1], int)
243
243
  and isinstance(results[2], dict)
244
244
  ):
245
- raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
245
+ raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
246
246
 
247
247
  # Return FitRes
248
248
  parameters_prime, num_examples, metrics = results
@@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
266
266
  and isinstance(results[1], int)
267
267
  and isinstance(results[2], dict)
268
268
  ):
269
- raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
269
+ raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
270
270
 
271
271
  # Return EvaluateRes
272
272
  loss, num_examples, metrics = results
@@ -29,7 +29,7 @@ from flwr.client.message_handler.task_handler import (
29
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
30
  from flwr.common.constant import MISSING_EXTRA_REST
31
31
  from flwr.common.logger import log
32
- from flwr.proto.fleet_pb2 import (
32
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  CreateNodeRequest,
34
34
  CreateNodeResponse,
35
35
  DeleteNodeRequest,
@@ -38,8 +38,8 @@ from flwr.proto.fleet_pb2 import (
38
38
  PushTaskResRequest,
39
39
  PushTaskResResponse,
40
40
  )
41
- from flwr.proto.node_pb2 import Node
42
- from flwr.proto.task_pb2 import TaskIns, TaskRes
41
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
42
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
43
 
44
44
  try:
45
45
  import requests
@@ -143,6 +143,7 @@ def http_request_response(
143
143
  },
144
144
  data=create_node_req_bytes,
145
145
  verify=verify,
146
+ timeout=None,
146
147
  )
147
148
 
148
149
  # Check status code and headers
@@ -185,6 +186,7 @@ def http_request_response(
185
186
  },
186
187
  data=delete_node_req_req_bytes,
187
188
  verify=verify,
189
+ timeout=None,
188
190
  )
189
191
 
190
192
  # Check status code and headers
@@ -225,6 +227,7 @@ def http_request_response(
225
227
  },
226
228
  data=pull_task_ins_req_bytes,
227
229
  verify=verify,
230
+ timeout=None,
228
231
  )
229
232
 
230
233
  # Check status code and headers
@@ -303,6 +306,7 @@ def http_request_response(
303
306
  },
304
307
  data=push_task_res_request_bytes,
305
308
  verify=verify,
309
+ timeout=None,
306
310
  )
307
311
 
308
312
  state[KEY_TASK_INS] = None
@@ -333,7 +333,7 @@ def _share_keys(
333
333
 
334
334
  # Check if the size is larger than threshold
335
335
  if len(state.public_keys_dict) < state.threshold:
336
- raise Exception("Available neighbours number smaller than threshold")
336
+ raise ValueError("Available neighbours number smaller than threshold")
337
337
 
338
338
  # Check if all public keys are unique
339
339
  pk_list: List[bytes] = []
@@ -341,14 +341,14 @@ def _share_keys(
341
341
  pk_list.append(pk1)
342
342
  pk_list.append(pk2)
343
343
  if len(set(pk_list)) != len(pk_list):
344
- raise Exception("Some public keys are identical")
344
+ raise ValueError("Some public keys are identical")
345
345
 
346
346
  # Check if public keys of this client are correct in the dictionary
347
347
  if (
348
348
  state.public_keys_dict[state.sid][0] != state.pk1
349
349
  or state.public_keys_dict[state.sid][1] != state.pk2
350
350
  ):
351
- raise Exception(
351
+ raise ValueError(
352
352
  "Own public keys are displayed in dict incorrectly, should not happen!"
353
353
  )
354
354
 
@@ -393,7 +393,7 @@ def _collect_masked_input(
393
393
  ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST])
394
394
  srcs = cast(List[int], named_values[KEY_SOURCE_LIST])
395
395
  if len(ciphertexts) + 1 < state.threshold:
396
- raise Exception("Not enough available neighbour clients.")
396
+ raise ValueError("Not enough available neighbour clients.")
397
397
 
398
398
  # Decrypt ciphertexts, verify their sources, and store shares.
399
399
  for src, ciphertext in zip(srcs, ciphertexts):
@@ -409,7 +409,7 @@ def _collect_masked_input(
409
409
  f"from {actual_src} instead of {src}."
410
410
  )
411
411
  if dst != state.sid:
412
- ValueError(
412
+ raise ValueError(
413
413
  f"Client {state.sid}: received an encrypted message"
414
414
  f"for Client {dst} from Client {src}."
415
415
  )
@@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str,
476
476
  # Send private mask seed share for every avaliable client (including itclient)
477
477
  # Send first private key share for building pairwise mask for every dropped client
478
478
  if len(active_sids) < state.threshold:
479
- raise Exception("Available neighbours number smaller than threshold")
479
+ raise ValueError("Available neighbours number smaller than threshold")
480
480
 
481
481
  sids, shares = [], []
482
482
  sids += active_sids
flwr/client/typing.py CHANGED
@@ -18,7 +18,7 @@ from dataclasses import dataclass
18
18
  from typing import Callable
19
19
 
20
20
  from flwr.client.run_state import RunState
21
- from flwr.proto.task_pb2 import TaskIns, TaskRes
21
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
22
22
 
23
23
  from .client import Client as Client
24
24
 
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ConfigsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import ConfigsRecordValues, ConfigsScalar
22
+
23
+
24
@dataclass
class ConfigsRecord:
    """Configs record.

    A dataclass holding named configuration values. Each value is either a scalar
    accepted by `ConfigsScalar` or a list of such scalars.
    """

    keep_input: bool
    data: Dict[str, ConfigsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None,
        keep_input: bool = True,
    ) -> None:
        """Construct a ConfigsRecord object.

        Parameters
        ----------
        configs_dict : Optional[Dict[str, ConfigsRecordValues]]
            A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
            defined in `ConfigsScalar`) and lists of such types (see
            `ConfigsScalarList`). If None or empty, the record starts empty.
        keep_input : bool (default: True)
            A boolean indicating whether configs should be kept in the input
            dictionary after adding them to the record. If True, the input
            dictionary is left untouched and the record holds a shallow copy of
            its entries. If False, entries are moved into the record and the input
            dictionary is left empty; set it to False if memory is a concern.
        """
        self.keep_input = keep_input
        self.data = {}
        if configs_dict:
            self.set_configs(configs_dict)

    def set_configs(self, configs_dict: Dict[str, ConfigsRecordValues]) -> None:
        """Add configs to the record.

        Entries are merged into `data`; a key already present in the record is
        overwritten by the new value. All keys and values are validated before
        either the record or `configs_dict` is modified, so a `TypeError` leaves
        both unchanged.

        Parameters
        ----------
        configs_dict : Dict[str, ConfigsRecordValues]
            A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
            defined in `ConfigsScalar`) and lists of such types (see
            `ConfigsScalarList`).

        Raises
        ------
        TypeError
            If a key is not a `str`, or if a value (or an element of a list value)
            is not an instance of one of the types in `ConfigsScalar`.
        """
        if any(not isinstance(k, str) for k in configs_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")

        # Resolve the accepted scalar types once instead of once per element.
        allowed_types = get_args(ConfigsScalar)

        def is_valid(value: ConfigsScalar) -> None:
            """Raise `TypeError` if `value` is not one of the `ConfigsScalar` types."""
            if not isinstance(value, allowed_types):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {ConfigsRecordValues} but you passed {type(value)}."
                )

        # Validate every value before touching the record or the input dict.
        # Lists are checked element by element: with very large lists (e.g. 1M+
        # elements) this is slow (about 1s for a 10M-element list on an M2 Pro),
        # so such data is better stored as an array in a ParametersRecord.
        for value in configs_dict.values():
            if isinstance(value, list):
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        if self.keep_input:
            # Merge (shallow copy) so configs added by earlier calls are kept;
            # the caller's dictionary is not modified.
            self.data.update(configs_dict)
        else:
            # Move entries into the record so they are not held twice.
            for key in list(configs_dict.keys()):
                self.data[key] = configs_dict.pop(key)
flwr/common/logger.py CHANGED
@@ -111,3 +111,17 @@ def warn_experimental_feature(name: str) -> None:
111
111
  """,
112
112
  name,
113
113
  )
114
+
115
+
116
def warn_deprecated_feature(name: str) -> None:
    """Warn the user when they use a deprecated feature.

    Emits one WARN-level message through this module's `log` function. The
    message template contains a single `%s` placeholder, filled with `name`.

    Parameters
    ----------
    name : str
        Name of the deprecated feature, shown in the warning text.
    """
    log(
        WARN,
        """
        DEPRECATED FEATURE: %s

        This is a deprecated feature. It will be removed
        entirely in future versions of Flower.
        """,
        name,
    )
@@ -0,0 +1,96 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """MetricsRecord."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict, Optional, get_args
20
+
21
+ from .typing import MetricsRecordValues, MetricsScalar
22
+
23
+
24
@dataclass
class MetricsRecord:
    """Metrics record.

    A dataclass holding named metric values. Each value is either a scalar
    accepted by `MetricsScalar` or a list of such scalars.
    """

    keep_input: bool
    data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
        keep_input: bool = True,
    ) -> None:
        """Construct a MetricsRecord object.

        Parameters
        ----------
        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
            If None or empty, the record starts empty.
        keep_input : bool (default: True)
            A boolean indicating whether metrics should be kept in the input
            dictionary after adding them to the record. If True, the input
            dictionary is left untouched and the record holds a shallow copy of
            its entries. If False, entries are moved into the record and the input
            dictionary is left empty; set it to False if memory is a concern.
        """
        self.keep_input = keep_input
        self.data = {}
        if metrics_dict:
            self.set_metrics(metrics_dict)

    def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
        """Add metrics to the record.

        Entries are merged into `data`; a key already present in the record is
        overwritten by the new value. All keys and values are validated before
        either the record or `metrics_dict` is modified, so a `TypeError` leaves
        both unchanged.

        Parameters
        ----------
        metrics_dict : Dict[str, MetricsRecordValues]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).

        Raises
        ------
        TypeError
            If a key is not a `str`, or if a value (or an element of a list value)
            is not an instance of one of the types in `MetricsScalar`.
        """
        if any(not isinstance(k, str) for k in metrics_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}.")

        # Resolve the accepted scalar types once instead of once per element.
        allowed_types = get_args(MetricsScalar)

        def is_valid(value: MetricsScalar) -> None:
            """Raise `TypeError` if `value` is not one of the `MetricsScalar` types."""
            if not isinstance(value, allowed_types):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {MetricsRecordValues} but you passed {type(value)}."
                )

        # Validate every value before touching the record or the input dict.
        # Lists are checked element by element: with very large lists (e.g. 1M+
        # elements) this is slow (about 1s for a 10M-element list on an M2 Pro),
        # so such data is better stored as an array in a ParametersRecord.
        for value in metrics_dict.values():
            if isinstance(value, list):
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        if self.keep_input:
            # Merge (shallow copy) so metrics added by earlier calls are kept;
            # the caller's dictionary is not modified.
            self.data.update(metrics_dict)
        else:
            # Move entries into the record so they are not held twice.
            for key in list(metrics_dict.keys()):
                self.data[key] = metrics_dict.pop(key)
@@ -0,0 +1,110 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ParametersRecord and Array."""
16
+
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import List, Optional, OrderedDict
20
+
21
+
22
@dataclass
class Array:
    """Array type.

    A dataclass containing serialized data from an array-like or tensor-like object
    along with some metadata about it.

    Parameters
    ----------
    dtype : str
        A string representing the data type of the serialized object (e.g. `np.float32`)

    shape : List[int]
        A list representing the shape of the unserialized array-like object. This is
        used to deserialize the data (depending on the serialization method) or simply
        as a metadata field.

    stype : str
        A string indicating the type of serialization mechanism used to generate the
        bytes in `data` from an array-like or tensor-like object.

    data : bytes
        A buffer of bytes containing the data.
    """

    dtype: str
    shape: List[int]
    stype: str
    data: bytes


@dataclass
class ParametersRecord:
    """Parameters record.

    A dataclass storing named Arrays in order. This means that it holds entries as an
    OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
    PyTorch's state_dict, but holding serialized tensors instead.
    """

    keep_input: bool
    data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])

    def __init__(
        self,
        array_dict: Optional[OrderedDict[str, Array]] = None,
        keep_input: bool = False,
    ) -> None:
        """Construct a ParametersRecord object.

        Parameters
        ----------
        array_dict : Optional[OrderedDict[str, Array]]
            A dictionary that stores serialized array-like or tensor-like objects.
            If None or empty, the record starts empty.
        keep_input : bool (default: False)
            A boolean indicating whether parameters should be kept in the input
            dictionary after adding them to the record. If False, the
            dictionary passed to `set_parameters()` will be empty once exiting from that
            function. This is the desired behaviour when working with very large
            models/tensors/arrays. However, if you plan to continue working with your
            parameters after adding it to the record, set this flag to True. When set
            to True, the record holds a shallow copy of the input's entries and the
            input dictionary is left untouched.
        """
        self.keep_input = keep_input
        self.data = OrderedDict()
        if array_dict:
            self.set_parameters(array_dict)

    def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
        """Add parameters to record.

        Entries are appended to `data` in the input's order; a key already present
        in the record is overwritten in place. Keys and values are validated before
        either the record or `array_dict` is modified, so a `TypeError` leaves both
        unchanged.

        Parameters
        ----------
        array_dict : OrderedDict[str, Array]
            A dictionary that stores serialized array-like or tensor-like objects.

        Raises
        ------
        TypeError
            If a key is not a `str` or a value is not an `Array`.
        """
        if any(not isinstance(k, str) for k in array_dict.keys()):
            raise TypeError(f"Not all keys are of valid type. Expected {str}")
        if any(not isinstance(v, Array) for v in array_dict.values()):
            raise TypeError(f"Not all values are of valid type. Expected {Array}")

        if self.keep_input:
            # Merge (shallow copy) so arrays added by earlier calls are kept;
            # the caller's dictionary is not modified.
            self.data.update(array_dict)
        else:
            # Move entries into the record without duplicating memory.
            for key in list(array_dict.keys()):
                self.data[key] = array_dict.pop(key)
flwr/common/recordset.py CHANGED
@@ -14,32 +14,22 @@
14
14
  # ==============================================================================
15
15
  """RecordSet."""
16
16
 
17
- from dataclasses import dataclass
18
- from typing import Dict
19
-
20
-
21
- @dataclass
22
- class ParametersRecord:
23
- """Parameters record."""
24
-
25
-
26
- @dataclass
27
- class MetricsRecord:
28
- """Metrics record."""
29
17
 
18
+ from dataclasses import dataclass, field
19
+ from typing import Dict
30
20
 
31
- @dataclass
32
- class ConfigsRecord:
33
- """Configs record."""
21
+ from .configsrecord import ConfigsRecord
22
+ from .metricsrecord import MetricsRecord
23
+ from .parametersrecord import ParametersRecord
34
24
 
35
25
 
36
26
  @dataclass
37
27
  class RecordSet:
38
28
  """Definition of RecordSet."""
39
29
 
40
- parameters: Dict[str, ParametersRecord] = {}
41
- metrics: Dict[str, MetricsRecord] = {}
42
- configs: Dict[str, ConfigsRecord] = {}
30
+ parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
31
+ metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
32
+ configs: Dict[str, ConfigsRecord] = field(default_factory=dict)
43
33
 
44
34
  def set_parameters(self, name: str, record: ParametersRecord) -> None:
45
35
  """Add a ParametersRecord."""