flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (39) hide show
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -15,8 +15,9 @@
15
15
  """Legacy default workflows."""
16
16
 
17
17
 
18
+ import io
18
19
  import timeit
19
- from logging import DEBUG, INFO
20
+ from logging import INFO
20
21
  from typing import Optional, cast
21
22
 
22
23
  import flwr.common.recordset_compat as compat
@@ -27,11 +28,7 @@ from ..compat.app_utils import start_update_client_manager_thread
27
28
  from ..compat.legacy_context import LegacyContext
28
29
  from ..driver import Driver
29
30
  from ..typing import Workflow
30
-
31
- KEY_CURRENT_ROUND = "current_round"
32
- KEY_START_TIME = "start_time"
33
- CONFIGS_RECORD_KEY = "config"
34
- PARAMS_RECORD_KEY = "parameters"
31
+ from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
35
32
 
36
33
 
37
34
  class DefaultWorkflow:
@@ -62,17 +59,19 @@ class DefaultWorkflow:
62
59
  )
63
60
 
64
61
  # Initialize parameters
62
+ log(INFO, "[INIT]")
65
63
  default_init_params_workflow(driver, context)
66
64
 
67
65
  # Run federated learning for num_rounds
68
- log(INFO, "FL starting")
69
66
  start_time = timeit.default_timer()
70
67
  cfg = ConfigsRecord()
71
- cfg[KEY_START_TIME] = start_time
72
- context.state.configs_records[CONFIGS_RECORD_KEY] = cfg
68
+ cfg[Key.START_TIME] = start_time
69
+ context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
73
70
 
74
71
  for current_round in range(1, context.config.num_rounds + 1):
75
- cfg[KEY_CURRENT_ROUND] = current_round
72
+ log(INFO, "")
73
+ log(INFO, "[ROUND %s]", current_round)
74
+ cfg[Key.CURRENT_ROUND] = current_round
76
75
 
77
76
  # Fit round
78
77
  self.fit_workflow(driver, context)
@@ -83,22 +82,19 @@ class DefaultWorkflow:
83
82
  # Evaluate round
84
83
  self.evaluate_workflow(driver, context)
85
84
 
86
- # Bookkeeping
85
+ # Bookkeeping and log results
87
86
  end_time = timeit.default_timer()
88
87
  elapsed = end_time - start_time
89
- log(INFO, "FL finished in %s", elapsed)
90
-
91
- # Log results
92
88
  hist = context.history
93
- log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
94
- log(
95
- INFO,
96
- "app_fit: metrics_distributed_fit %s",
97
- str(hist.metrics_distributed_fit),
98
- )
99
- log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
100
- log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
101
- log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
89
+ log(INFO, "")
90
+ log(INFO, "[SUMMARY]")
91
+ log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
92
+ for idx, line in enumerate(io.StringIO(str(hist))):
93
+ if idx == 0:
94
+ log(INFO, "%s", line.strip("\n"))
95
+ else:
96
+ log(INFO, "\t%s", line.strip("\n"))
97
+ log(INFO, "")
102
98
 
103
99
  # Terminate the thread
104
100
  f_stop.set()
@@ -111,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
111
107
  if not isinstance(context, LegacyContext):
112
108
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
113
109
 
114
- log(INFO, "Initializing global parameters")
115
110
  parameters = context.strategy.initialize_parameters(
116
111
  client_manager=context.client_manager
117
112
  )
118
113
  if parameters is not None:
119
- log(INFO, "Using initial parameters provided by strategy")
114
+ log(INFO, "Using initial global parameters provided by strategy")
120
115
  paramsrecord = compat.parameters_to_parametersrecord(
121
116
  parameters, keep_input=True
122
117
  )
@@ -141,10 +136,10 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
141
136
  msg = list(messages)[0]
142
137
  paramsrecord = next(iter(msg.content.parameters_records.values()))
143
138
 
144
- context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord
139
+ context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
145
140
 
146
141
  # Evaluate initial parameters
147
- log(INFO, "Evaluating initial parameters")
142
+ log(INFO, "Evaluating initial global parameters")
148
143
  parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
149
144
  res = context.strategy.evaluate(0, parameters=parameters)
150
145
  if res is not None:
@@ -164,13 +159,13 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
164
159
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
165
160
 
166
161
  # Retrieve current_round and start_time from the context
167
- cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
168
- current_round = cast(int, cfg[KEY_CURRENT_ROUND])
169
- start_time = cast(float, cfg[KEY_START_TIME])
162
+ cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
163
+ current_round = cast(int, cfg[Key.CURRENT_ROUND])
164
+ start_time = cast(float, cfg[Key.START_TIME])
170
165
 
171
166
  # Centralized evaluation
172
167
  parameters = compat.parametersrecord_to_parameters(
173
- record=context.state.parameters_records[PARAMS_RECORD_KEY],
168
+ record=context.state.parameters_records[MAIN_PARAMS_RECORD],
174
169
  keep_input=True,
175
170
  )
176
171
  res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -190,15 +185,17 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
190
185
  )
191
186
 
192
187
 
193
- def default_fit_workflow(driver: Driver, context: Context) -> None:
188
+ def default_fit_workflow( # pylint: disable=R0914
189
+ driver: Driver, context: Context
190
+ ) -> None:
194
191
  """Execute the default workflow for a single fit round."""
195
192
  if not isinstance(context, LegacyContext):
196
193
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
197
194
 
198
195
  # Get current_round and parameters
199
- cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
200
- current_round = cast(int, cfg[KEY_CURRENT_ROUND])
201
- parametersrecord = context.state.parameters_records[PARAMS_RECORD_KEY]
196
+ cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
197
+ current_round = cast(int, cfg[Key.CURRENT_ROUND])
198
+ parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
202
199
  parameters = compat.parametersrecord_to_parameters(
203
200
  parametersrecord, keep_input=True
204
201
  )
@@ -211,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
211
208
  )
212
209
 
213
210
  if not client_instructions:
214
- log(INFO, "fit_round %s: no clients selected, cancel", current_round)
211
+ log(INFO, "configure_fit: no clients selected, cancel")
215
212
  return
216
213
  log(
217
- DEBUG,
218
- "fit_round %s: strategy sampled %s clients (out of %s)",
219
- current_round,
214
+ INFO,
215
+ "configure_fit: strategy sampled %s clients (out of %s)",
220
216
  len(client_instructions),
221
217
  context.client_manager.num_available(),
222
218
  )
@@ -240,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
240
236
  # collect `fit` results from all clients participating in this round
241
237
  messages = list(driver.send_and_receive(out_messages))
242
238
  del out_messages
239
+ num_failures = len([msg for msg in messages if msg.has_error()])
243
240
 
244
241
  # No exception/failure handling currently
245
242
  log(
246
- DEBUG,
247
- "fit_round %s received %s results and %s failures",
248
- current_round,
249
- len(messages),
250
- 0,
243
+ INFO,
244
+ "aggregate_fit: received %s results and %s failures",
245
+ len(messages) - num_failures,
246
+ num_failures,
251
247
  )
252
248
 
253
249
  # Aggregate training results
@@ -266,7 +262,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
266
262
  paramsrecord = compat.parameters_to_parametersrecord(
267
263
  parameters_aggregated, True
268
264
  )
269
- context.state.parameters_records[PARAMS_RECORD_KEY] = paramsrecord
265
+ context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
270
266
  context.history.add_metrics_distributed_fit(
271
267
  server_round=current_round, metrics=metrics_aggregated
272
268
  )
@@ -278,9 +274,9 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
278
274
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
279
275
 
280
276
  # Get current_round and parameters
281
- cfg = context.state.configs_records[CONFIGS_RECORD_KEY]
282
- current_round = cast(int, cfg[KEY_CURRENT_ROUND])
283
- parametersrecord = context.state.parameters_records[PARAMS_RECORD_KEY]
277
+ cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
278
+ current_round = cast(int, cfg[Key.CURRENT_ROUND])
279
+ parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
284
280
  parameters = compat.parametersrecord_to_parameters(
285
281
  parametersrecord, keep_input=True
286
282
  )
@@ -292,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
292
288
  client_manager=context.client_manager,
293
289
  )
294
290
  if not client_instructions:
295
- log(INFO, "evaluate_round %s: no clients selected, cancel", current_round)
291
+ log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
296
292
  return
297
293
  log(
298
- DEBUG,
299
- "evaluate_round %s: strategy sampled %s clients (out of %s)",
300
- current_round,
294
+ INFO,
295
+ "configure_evaluate: strategy sampled %s clients (out of %s)",
301
296
  len(client_instructions),
302
297
  context.client_manager.num_available(),
303
298
  )
@@ -321,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
321
316
  # collect `evaluate` results from all clients participating in this round
322
317
  messages = list(driver.send_and_receive(out_messages))
323
318
  del out_messages
319
+ num_failures = len([msg for msg in messages if msg.has_error()])
324
320
 
325
321
  # No exception/failure handling currently
326
322
  log(
327
- DEBUG,
328
- "evaluate_round %s received %s results and %s failures",
329
- current_round,
330
- len(messages),
331
- 0,
323
+ INFO,
324
+ "aggregate_evaluate: received %s results and %s failures",
325
+ len(messages) - num_failures,
326
+ num_failures,
332
327
  )
333
328
 
334
329
  # Aggregate the evaluation results
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Secure Aggregation workflows."""
16
+
17
+
18
+ from .secagg_workflow import SecAggWorkflow
19
+ from .secaggplus_workflow import SecAggPlusWorkflow
20
+
21
+ __all__ = [
22
+ "SecAggPlusWorkflow",
23
+ "SecAggWorkflow",
24
+ ]
@@ -0,0 +1,112 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Workflow for the SecAgg protocol."""
16
+
17
+
18
+ from typing import Optional, Union
19
+
20
+ from .secaggplus_workflow import SecAggPlusWorkflow
21
+
22
+
23
+ class SecAggWorkflow(SecAggPlusWorkflow):
24
+ """The workflow for the SecAgg protocol.
25
+
26
+ The SecAgg protocol ensures the secure summation of integer vectors owned by
27
+ multiple parties, without accessing any individual integer vector. This workflow
28
+ allows the server to compute the weighted average of model parameters across all
29
+ clients, ensuring individual contributions remain private. This is achieved by
30
+ clients sending both, a weighting factor and a weighted version of the locally
31
+ updated parameters, both of which are masked for privacy. Specifically, each
32
+ client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
33
+ number of examples ('num_examples') and 'params' represents the model parameters
34
+ ('parameters') from the client's `FitRes`. The server then aggregates these
35
+ contributions to compute the weighted average of model parameters.
36
+
37
+ The protocol involves four main stages:
38
+ - 'setup': Send SecAgg configuration to clients and collect their public keys.
39
+ - 'share keys': Broadcast public keys among clients and collect encrypted secret
40
+ key shares.
41
+ - 'collect masked vectors': Forward encrypted secret key shares to target clients
42
+ and collect masked model parameters.
43
+ - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
44
+
45
+ Only the aggregated model parameters are exposed and passed to
46
+ `Strategy.aggregate_fit`, ensuring individual data privacy.
47
+
48
+ Parameters
49
+ ----------
50
+ reconstruction_threshold : Union[int, float]
51
+ The minimum number of shares required to reconstruct a client's private key,
52
+ or, if specified as a float, it represents the proportion of the total number
53
+ of shares needed for reconstruction. This threshold ensures privacy by allowing
54
+ for the recovery of contributions from dropped clients during aggregation,
55
+ without compromising individual client data.
56
+ max_weight : Optional[float] (default: 1000.0)
57
+ The maximum value of the weight that can be assigned to any single client's
58
+ update during the weighted average calculation on the server side, e.g., in the
59
+ FedAvg algorithm.
60
+ clipping_range : float, optional (default: 8.0)
61
+ The range within which model parameters are clipped before quantization.
62
+ This parameter ensures each model parameter is bounded within
63
+ [-clipping_range, clipping_range], facilitating quantization.
64
+ quantization_range : int, optional (default: 4194304, this equals 2**22)
65
+ The size of the range into which floating-point model parameters are quantized,
66
+ mapping each parameter to an integer in [0, quantization_range-1]. This
67
+ facilitates cryptographic operations on the model updates.
68
+ modulus_range : int, optional (default: 4294967296, this equals 2**32)
69
+ The range of values from which random mask entries are uniformly sampled
70
+ ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
71
+ Please use 2**n values for `modulus_range` to prevent overflow issues.
72
+ timeout : Optional[float] (default: None)
73
+ The timeout duration in seconds. If specified, the workflow will wait for
74
+ replies for this duration each time. If `None`, there is no time limit and
75
+ the workflow will wait until replies for all messages are received.
76
+
77
+ Notes
78
+ -----
79
+ - Each client's private key is split into N shares under the SecAgg protocol, where
80
+ N is the number of selected clients.
81
+ - Generally, higher `reconstruction_threshold` means better privacy guarantees but
82
+ less tolerance to dropouts.
83
+ - Too large `max_weight` may compromise the precision of the quantization.
84
+ - `modulus_range` must be 2**n and larger than `quantization_range`.
85
+ - When `reconstruction_threshold` is a float, it is interpreted as the proportion of
86
+ the number of all selected clients needed for the reconstruction of a private key.
87
+ This feature enables flexibility in setting the security threshold relative to the
88
+ number of selected clients.
89
+ - `reconstruction_threshold`, and the quantization parameters
90
+ (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
91
+ balancing privacy, robustness, and efficiency within the SecAgg protocol.
92
+ """
93
+
94
+ def __init__( # pylint: disable=R0913
95
+ self,
96
+ reconstruction_threshold: Union[int, float],
97
+ *,
98
+ max_weight: float = 1000.0,
99
+ clipping_range: float = 8.0,
100
+ quantization_range: int = 4194304,
101
+ modulus_range: int = 4294967296,
102
+ timeout: Optional[float] = None,
103
+ ) -> None:
104
+ super().__init__(
105
+ num_shares=1.0,
106
+ reconstruction_threshold=reconstruction_threshold,
107
+ max_weight=max_weight,
108
+ clipping_range=clipping_range,
109
+ quantization_range=quantization_range,
110
+ modulus_range=modulus_range,
111
+ timeout=timeout,
112
+ )