flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -15,8 +15,9 @@
15
15
  """Legacy default workflows."""
16
16
 
17
17
 
18
+ import io
18
19
  import timeit
19
- from logging import DEBUG, INFO
20
+ from logging import INFO
20
21
  from typing import Optional, cast
21
22
 
22
23
  import flwr.common.recordset_compat as compat
@@ -27,11 +28,7 @@ from ..compat.app_utils import start_update_client_manager_thread
27
28
  from ..compat.legacy_context import LegacyContext
28
29
  from ..driver import Driver
29
30
  from ..typing import Workflow
30
-
31
- KEY_CURRENT_ROUND = "current_round"
32
- KEY_START_TIME = "start_time"
33
- CONFIGS_RECORD_KEY = "config"
34
- PARAMS_RECORD_KEY = "parameters"
31
+ from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
35
32
 
36
33
 
37
34
  class DefaultWorkflow:
@@ -62,17 +59,19 @@ class DefaultWorkflow:
62
59
  )
63
60
 
64
61
  # Initialize parameters
62
+ log(INFO, "[INIT]")
65
63
  default_init_params_workflow(driver, context)
66
64
 
67
65
  # Run federated learning for num_rounds
68
- log(INFO, "FL starting")
69
66
  start_time = timeit.default_timer()
70
67
  cfg = ConfigsRecord()
71
- cfg[KEY_START_TIME] = start_time
72
- context.state.configs_records[CONFIGS_RECORD_KEY] = cfg
68
+ cfg[Key.START_TIME] = start_time
69
+ context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
73
70
 
74
71
  for current_round in range(1, context.config.num_rounds + 1):
75
- cfg[KEY_CURRENT_ROUND] = current_round
72
+ log(INFO, "")
73
+ log(INFO, "[ROUND %s]", current_round)
74
+ cfg[Key.CURRENT_ROUND] = current_round
76
75
 
77
76
  # Fit round
78
77
  self.fit_workflow(driver, context)
@@ -83,22 +82,19 @@ class DefaultWorkflow:
83
82
  # Evaluate round
84
83
  self.evaluate_workflow(driver, context)
85
84
 
86
- # Bookkeeping
85
+ # Bookkeeping and log results
87
86
  end_time = timeit.default_timer()
88
87
  elapsed = end_time - start_time
89
- log(INFO, "FL finished in %s", elapsed)
90
-
91
- # Log results
92
88
  hist = context.history
93
- log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
94
- log(
95
- INFO,
96
- "app_fit: metrics_distributed_fit %s",
97
- str(hist.metrics_distributed_fit),
98
- )
99
- log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
100
- log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
101
- log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
89
+ log(INFO, "")
90
+ log(INFO, "[SUMMARY]")
91
+ log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
92
+ for idx, line in enumerate(io.StringIO(str(hist))):
93
+ if idx == 0:
94
+ log(INFO, "%s", line.strip("\n"))
95
+ else:
96
+ log(INFO, "\t%s", line.strip("\n"))
97
+ log(INFO, "")
102
98
 
103
99
  # Terminate the thread
104
100
  f_stop.set()
@@ -111,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
111
107
  if not isinstance(context, LegacyContext):
112
108
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
113
109
 
114
- log(INFO, "Initializing global parameters")
115
110
  parameters = context.strategy.initialize_parameters(
116
111
  client_manager=context.client_manager
117
112
  )
118
113
  if parameters is not None:
119
- log(INFO, "Using initial parameters provided by strategy")
114
+ log(INFO, "Using initial global parameters provided by strategy")
120
115
  paramsrecord = compat.parameters_to_parametersrecord(
121
116
  parameters, keep_input=True
122
117
  )
@@ -141,10 +136,10 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
141
136
  msg = list(messages)[0]
142
137
  paramsrecord = next(iter(msg.content.parameters_records.values()))
143
138
 
144
- context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord
139
+ context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
145
140
 
146
141
  # Evaluate initial parameters
147
- log(INFO, "Evaluating initial parameters")
142
+ log(INFO, "Evaluating initial global parameters")
148
143
  parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
149
144
  res = context.strategy.evaluate(0, parameters=parameters)
150
145
  if res is not None:
@@ -164,13 +159,13 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
164
159
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
165
160
 
166
161
  # Retrieve current_round and start_time from the context
167
- cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
168
- current_round = cast(int, cfg[KEY_CURRENT_ROUND])
169
- start_time = cast(float, cfg[KEY_START_TIME])
162
+ cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
163
+ current_round = cast(int, cfg[Key.CURRENT_ROUND])
164
+ start_time = cast(float, cfg[Key.START_TIME])
170
165
 
171
166
  # Centralized evaluation
172
167
  parameters = compat.parametersrecord_to_parameters(
173
- record=context.state.parameters_records[PARAMS_RECORD_KEY],
168
+ record=context.state.parameters_records[MAIN_PARAMS_RECORD],
174
169
  keep_input=True,
175
170
  )
176
171
  res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -190,15 +185,17 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
190
185
  )
191
186
 
192
187
 
193
- def default_fit_workflow(driver: Driver, context: Context) -> None:
188
+ def default_fit_workflow( # pylint: disable=R0914
189
+ driver: Driver, context: Context
190
+ ) -> None:
194
191
  """Execute the default workflow for a single fit round."""
195
192
  if not isinstance(context, LegacyContext):
196
193
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
197
194
 
198
195
  # Get current_round and parameters
199
- cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
200
- current_round = cast(int, cfg[KEY_CURRENT_ROUND])
201
- parametersrecord = context.state.parameters_records[PARAMS_RECORD_KEY]
196
+ cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
197
+ current_round = cast(int, cfg[Key.CURRENT_ROUND])
198
+ parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
202
199
  parameters = compat.parametersrecord_to_parameters(
203
200
  parametersrecord, keep_input=True
204
201
  )
@@ -211,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
211
208
  )
212
209
 
213
210
  if not client_instructions:
214
- log(INFO, "fit_round %s: no clients selected, cancel", current_round)
211
+ log(INFO, "configure_fit: no clients selected, cancel")
215
212
  return
216
213
  log(
217
- DEBUG,
218
- "fit_round %s: strategy sampled %s clients (out of %s)",
219
- current_round,
214
+ INFO,
215
+ "configure_fit: strategy sampled %s clients (out of %s)",
220
216
  len(client_instructions),
221
217
  context.client_manager.num_available(),
222
218
  )
@@ -240,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
240
236
  # collect `fit` results from all clients participating in this round
241
237
  messages = list(driver.send_and_receive(out_messages))
242
238
  del out_messages
239
+ num_failures = len([msg for msg in messages if msg.has_error()])
243
240
 
244
241
  # No exception/failure handling currently
245
242
  log(
246
- DEBUG,
247
- "fit_round %s received %s results and %s failures",
248
- current_round,
249
- len(messages),
250
- 0,
243
+ INFO,
244
+ "aggregate_fit: received %s results and %s failures",
245
+ len(messages) - num_failures,
246
+ num_failures,
251
247
  )
252
248
 
253
249
  # Aggregate training results
@@ -266,7 +262,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
266
262
  paramsrecord = compat.parameters_to_parametersrecord(
267
263
  parameters_aggregated, True
268
264
  )
269
- context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord
265
+ context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
270
266
  context.history.add_metrics_distributed_fit(
271
267
  server_round=current_round, metrics=metrics_aggregated
272
268
  )
@@ -278,9 +274,9 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
278
274
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
279
275
 
280
276
  # Get current_round and parameters
281
- cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
282
- current_round = cast(int, cfg[KEY_CURRENT_ROUND])
283
- parametersrecord = context.state.parameters_records[PARAMS_RECORD_KEY]
277
+ cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
278
+ current_round = cast(int, cfg[Key.CURRENT_ROUND])
279
+ parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
284
280
  parameters = compat.parametersrecord_to_parameters(
285
281
  parametersrecord, keep_input=True
286
282
  )
@@ -292,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
292
288
  client_manager=context.client_manager,
293
289
  )
294
290
  if not client_instructions:
295
- log(INFO, "evaluate_round %s: no clients selected, cancel", current_round)
291
+ log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
296
292
  return
297
293
  log(
298
- DEBUG,
299
- "evaluate_round %s: strategy sampled %s clients (out of %s)",
300
- current_round,
294
+ INFO,
295
+ "configure_evaluate: strategy sampled %s clients (out of %s)",
301
296
  len(client_instructions),
302
297
  context.client_manager.num_available(),
303
298
  )
@@ -321,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
321
316
  # collect `evaluate` results from all clients participating in this round
322
317
  messages = list(driver.send_and_receive(out_messages))
323
318
  del out_messages
319
+ num_failures = len([msg for msg in messages if msg.has_error()])
324
320
 
325
321
  # No exception/failure handling currently
326
322
  log(
327
- DEBUG,
328
- "evaluate_round %s received %s results and %s failures",
329
- current_round,
330
- len(messages),
331
- 0,
323
+ INFO,
324
+ "aggregate_evaluate: received %s results and %s failures",
325
+ len(messages) - num_failures,
326
+ num_failures,
332
327
  )
333
328
 
334
329
  # Aggregate the evaluation results
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Secure Aggregation workflows."""
16
+
17
+
18
+ from .secagg_workflow import SecAggWorkflow
19
+ from .secaggplus_workflow import SecAggPlusWorkflow
20
+
21
+ __all__ = [
22
+ "SecAggPlusWorkflow",
23
+ "SecAggWorkflow",
24
+ ]
@@ -0,0 +1,112 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Workflow for the SecAgg protocol."""
16
+
17
+
18
+ from typing import Optional, Union
19
+
20
+ from .secaggplus_workflow import SecAggPlusWorkflow
21
+
22
+
23
+ class SecAggWorkflow(SecAggPlusWorkflow):
24
+ """The workflow for the SecAgg protocol.
25
+
26
+ The SecAgg protocol ensures the secure summation of integer vectors owned by
27
+ multiple parties, without accessing any individual integer vector. This workflow
28
+ allows the server to compute the weighted average of model parameters across all
29
+ clients, ensuring individual contributions remain private. This is achieved by
30
+ clients sending both, a weighting factor and a weighted version of the locally
31
+ updated parameters, both of which are masked for privacy. Specifically, each
32
+ client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
33
+ number of examples ('num_examples') and 'params' represents the model parameters
34
+ ('parameters') from the client's `FitRes`. The server then aggregates these
35
+ contributions to compute the weighted average of model parameters.
36
+
37
+ The protocol involves four main stages:
38
+ - 'setup': Send SecAgg configuration to clients and collect their public keys.
39
+ - 'share keys': Broadcast public keys among clients and collect encrypted secret
40
+ key shares.
41
+ - 'collect masked vectors': Forward encrypted secret key shares to target clients
42
+ and collect masked model parameters.
43
+ - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
44
+
45
+ Only the aggregated model parameters are exposed and passed to
46
+ `Strategy.aggregate_fit`, ensuring individual data privacy.
47
+
48
+ Parameters
49
+ ----------
50
+ reconstruction_threshold : Union[int, float]
51
+ The minimum number of shares required to reconstruct a client's private key,
52
+ or, if specified as a float, it represents the proportion of the total number
53
+ of shares needed for reconstruction. This threshold ensures privacy by allowing
54
+ for the recovery of contributions from dropped clients during aggregation,
55
+ without compromising individual client data.
56
+ max_weight : Optional[float] (default: 1000.0)
57
+ The maximum value of the weight that can be assigned to any single client's
58
+ update during the weighted average calculation on the server side, e.g., in the
59
+ FedAvg algorithm.
60
+ clipping_range : float, optional (default: 8.0)
61
+ The range within which model parameters are clipped before quantization.
62
+ This parameter ensures each model parameter is bounded within
63
+ [-clipping_range, clipping_range], facilitating quantization.
64
+ quantization_range : int, optional (default: 4194304, this equals 2**22)
65
+ The size of the range into which floating-point model parameters are quantized,
66
+ mapping each parameter to an integer in [0, quantization_range-1]. This
67
+ facilitates cryptographic operations on the model updates.
68
+ modulus_range : int, optional (default: 4294967296, this equals 2**32)
69
+ The range of values from which random mask entries are uniformly sampled
70
+ ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
71
+ Please use 2**n values for `modulus_range` to prevent overflow issues.
72
+ timeout : Optional[float] (default: None)
73
+ The timeout duration in seconds. If specified, the workflow will wait for
74
+ replies for this duration each time. If `None`, there is no time limit and
75
+ the workflow will wait until replies for all messages are received.
76
+
77
+ Notes
78
+ -----
79
+ - Each client's private key is split into N shares under the SecAgg protocol, where
80
+ N is the number of selected clients.
81
+ - Generally, higher `reconstruction_threshold` means better privacy guarantees but
82
+ less tolerance to dropouts.
83
+ - Too large `max_weight` may compromise the precision of the quantization.
84
+ - `modulus_range` must be 2**n and larger than `quantization_range`.
85
+ - When `reconstruction_threshold` is a float, it is interpreted as the proportion of
86
+ the number of all selected clients needed for the reconstruction of a private key.
87
+ This feature enables flexibility in setting the security threshold relative to the
88
+ number of selected clients.
89
+ - `reconstruction_threshold`, and the quantization parameters
90
+ (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
91
+ balancing privacy, robustness, and efficiency within the SecAgg protocol.
92
+ """
93
+
94
+ def __init__( # pylint: disable=R0913
95
+ self,
96
+ reconstruction_threshold: Union[int, float],
97
+ *,
98
+ max_weight: float = 1000.0,
99
+ clipping_range: float = 8.0,
100
+ quantization_range: int = 4194304,
101
+ modulus_range: int = 4294967296,
102
+ timeout: Optional[float] = None,
103
+ ) -> None:
104
+ super().__init__(
105
+ num_shares=1.0,
106
+ reconstruction_threshold=reconstruction_threshold,
107
+ max_weight=max_weight,
108
+ clipping_range=clipping_range,
109
+ quantization_range=quantization_range,
110
+ modulus_range=modulus_range,
111
+ timeout=timeout,
112
+ )