flwr-nightly 1.8.0.dev20240228__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -52,31 +52,10 @@ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
52
52
  )
53
53
  from flwr.common.secure_aggregation.quantization import quantize
54
54
  from flwr.common.secure_aggregation.secaggplus_constants import (
55
- KEY_ACTIVE_SECURE_ID_LIST,
56
- KEY_CIPHERTEXT_LIST,
57
- KEY_CLIPPING_RANGE,
58
- KEY_DEAD_SECURE_ID_LIST,
59
- KEY_DESTINATION_LIST,
60
- KEY_MASKED_PARAMETERS,
61
- KEY_MOD_RANGE,
62
- KEY_PUBLIC_KEY_1,
63
- KEY_PUBLIC_KEY_2,
64
- KEY_SAMPLE_NUMBER,
65
- KEY_SECURE_ID,
66
- KEY_SECURE_ID_LIST,
67
- KEY_SHARE_LIST,
68
- KEY_SHARE_NUMBER,
69
- KEY_SOURCE_LIST,
70
- KEY_STAGE,
71
- KEY_TARGET_RANGE,
72
- KEY_THRESHOLD,
73
55
  RECORD_KEY_CONFIGS,
74
56
  RECORD_KEY_STATE,
75
- STAGE_COLLECT_MASKED_INPUT,
76
- STAGE_SETUP,
77
- STAGE_SHARE_KEYS,
78
- STAGE_UNMASK,
79
- STAGES,
57
+ Key,
58
+ Stage,
80
59
  )
81
60
  from flwr.common.secure_aggregation.secaggplus_utils import (
82
61
  pseudo_rand_gen,
@@ -91,7 +70,7 @@ from flwr.common.typing import ConfigsRecordValues, FitRes
91
70
  class SecAggPlusState:
92
71
  """State of the SecAgg+ protocol."""
93
72
 
94
- current_stage: str = STAGE_UNMASK
73
+ current_stage: str = Stage.UNMASK
95
74
 
96
75
  sid: int = 0
97
76
  sample_num: int = 0
@@ -187,20 +166,20 @@ def secaggplus_mod(
187
166
  check_stage(state.current_stage, configs)
188
167
 
189
168
  # Update the current stage
190
- state.current_stage = cast(str, configs.pop(KEY_STAGE))
169
+ state.current_stage = cast(str, configs.pop(Key.STAGE))
191
170
 
192
171
  # Check the validity of the configs based on the current stage
193
172
  check_configs(state.current_stage, configs)
194
173
 
195
174
  # Execute
196
- if state.current_stage == STAGE_SETUP:
175
+ if state.current_stage == Stage.SETUP:
197
176
  res = _setup(state, configs)
198
- elif state.current_stage == STAGE_SHARE_KEYS:
177
+ elif state.current_stage == Stage.SHARE_KEYS:
199
178
  res = _share_keys(state, configs)
200
- elif state.current_stage == STAGE_COLLECT_MASKED_INPUT:
179
+ elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
201
180
  fit = _get_fit_fn(msg, ctxt, call_next)
202
181
  res = _collect_masked_input(state, configs, fit)
203
- elif state.current_stage == STAGE_UNMASK:
182
+ elif state.current_stage == Stage.UNMASK:
204
183
  res = _unmask(state, configs)
205
184
  else:
206
185
  raise ValueError(f"Unknown secagg stage: {state.current_stage}")
@@ -215,28 +194,29 @@ def secaggplus_mod(
215
194
 
216
195
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
217
196
  """Check the validity of the next stage."""
218
- # Check the existence of KEY_STAGE
219
- if KEY_STAGE not in configs:
197
+ # Check the existence of Config.STAGE
198
+ if Key.STAGE not in configs:
220
199
  raise KeyError(
221
- f"The required key '{KEY_STAGE}' is missing from the input `named_values`."
200
+ f"The required key '{Key.STAGE}' is missing from the input `named_values`."
222
201
  )
223
202
 
224
- # Check the value type of the KEY_STAGE
225
- next_stage = configs[KEY_STAGE]
203
+ # Check the value type of the Config.STAGE
204
+ next_stage = configs[Key.STAGE]
226
205
  if not isinstance(next_stage, str):
227
206
  raise TypeError(
228
- f"The value for the key '{KEY_STAGE}' must be of type {str}, "
207
+ f"The value for the key '{Key.STAGE}' must be of type {str}, "
229
208
  f"but got {type(next_stage)} instead."
230
209
  )
231
210
 
232
211
  # Check the validity of the next stage
233
- if next_stage == STAGE_SETUP:
234
- if current_stage != STAGE_UNMASK:
212
+ if next_stage == Stage.SETUP:
213
+ if current_stage != Stage.UNMASK:
235
214
  log(WARNING, "Restart from the setup stage")
236
215
  # If stage is not "setup",
237
216
  # the stage from `named_values` should be the expected next stage
238
217
  else:
239
- expected_next_stage = STAGES[(STAGES.index(current_stage) + 1) % len(STAGES)]
218
+ stages = Stage.all()
219
+ expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
240
220
  if next_stage != expected_next_stage:
241
221
  raise ValueError(
242
222
  "Abort secure aggregation: "
@@ -248,20 +228,20 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
248
228
  def check_configs(stage: str, configs: ConfigsRecord) -> None:
249
229
  """Check the validity of the configs."""
250
230
  # Check `named_values` for the setup stage
251
- if stage == STAGE_SETUP:
231
+ if stage == Stage.SETUP:
252
232
  key_type_pairs = [
253
- (KEY_SAMPLE_NUMBER, int),
254
- (KEY_SECURE_ID, int),
255
- (KEY_SHARE_NUMBER, int),
256
- (KEY_THRESHOLD, int),
257
- (KEY_CLIPPING_RANGE, float),
258
- (KEY_TARGET_RANGE, int),
259
- (KEY_MOD_RANGE, int),
233
+ (Key.SAMPLE_NUMBER, int),
234
+ (Key.SECURE_ID, int),
235
+ (Key.SHARE_NUMBER, int),
236
+ (Key.THRESHOLD, int),
237
+ (Key.CLIPPING_RANGE, float),
238
+ (Key.TARGET_RANGE, int),
239
+ (Key.MOD_RANGE, int),
260
240
  ]
261
241
  for key, expected_type in key_type_pairs:
262
242
  if key not in configs:
263
243
  raise KeyError(
264
- f"Stage {STAGE_SETUP}: the required key '{key}' is "
244
+ f"Stage {Stage.SETUP}: the required key '{key}' is "
265
245
  "missing from the input `named_values`."
266
246
  )
267
247
  # Bool is a subclass of int in Python,
@@ -269,11 +249,11 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
269
249
  # pylint: disable-next=unidiomatic-typecheck
270
250
  if type(configs[key]) is not expected_type:
271
251
  raise TypeError(
272
- f"Stage {STAGE_SETUP}: The value for the key '{key}' "
252
+ f"Stage {Stage.SETUP}: The value for the key '{key}' "
273
253
  f"must be of type {expected_type}, "
274
254
  f"but got {type(configs[key])} instead."
275
255
  )
276
- elif stage == STAGE_SHARE_KEYS:
256
+ elif stage == Stage.SHARE_KEYS:
277
257
  for key, value in configs.items():
278
258
  if (
279
259
  not isinstance(value, list)
@@ -282,18 +262,18 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
282
262
  or not isinstance(value[1], bytes)
283
263
  ):
284
264
  raise TypeError(
285
- f"Stage {STAGE_SHARE_KEYS}: "
265
+ f"Stage {Stage.SHARE_KEYS}: "
286
266
  f"the value for the key '{key}' must be a list of two bytes."
287
267
  )
288
- elif stage == STAGE_COLLECT_MASKED_INPUT:
268
+ elif stage == Stage.COLLECT_MASKED_INPUT:
289
269
  key_type_pairs = [
290
- (KEY_CIPHERTEXT_LIST, bytes),
291
- (KEY_SOURCE_LIST, int),
270
+ (Key.CIPHERTEXT_LIST, bytes),
271
+ (Key.SOURCE_LIST, int),
292
272
  ]
293
273
  for key, expected_type in key_type_pairs:
294
274
  if key not in configs:
295
275
  raise KeyError(
296
- f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
276
+ f"Stage {Stage.COLLECT_MASKED_INPUT}: "
297
277
  f"the required key '{key}' is "
298
278
  "missing from the input `named_values`."
299
279
  )
@@ -304,19 +284,19 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
304
284
  if type(elm) is not expected_type
305
285
  ):
306
286
  raise TypeError(
307
- f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
287
+ f"Stage {Stage.COLLECT_MASKED_INPUT}: "
308
288
  f"the value for the key '{key}' "
309
289
  f"must be of type List[{expected_type.__name__}]"
310
290
  )
311
- elif stage == STAGE_UNMASK:
291
+ elif stage == Stage.UNMASK:
312
292
  key_type_pairs = [
313
- (KEY_ACTIVE_SECURE_ID_LIST, int),
314
- (KEY_DEAD_SECURE_ID_LIST, int),
293
+ (Key.ACTIVE_SECURE_ID_LIST, int),
294
+ (Key.DEAD_SECURE_ID_LIST, int),
315
295
  ]
316
296
  for key, expected_type in key_type_pairs:
317
297
  if key not in configs:
318
298
  raise KeyError(
319
- f"Stage {STAGE_UNMASK}: "
299
+ f"Stage {Stage.UNMASK}: "
320
300
  f"the required key '{key}' is "
321
301
  "missing from the input `named_values`."
322
302
  )
@@ -327,7 +307,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
327
307
  if type(elm) is not expected_type
328
308
  ):
329
309
  raise TypeError(
330
- f"Stage {STAGE_UNMASK}: "
310
+ f"Stage {Stage.UNMASK}: "
331
311
  f"the value for the key '{key}' "
332
312
  f"must be of type List[{expected_type.__name__}]"
333
313
  )
@@ -340,15 +320,15 @@ def _setup(
340
320
  ) -> Dict[str, ConfigsRecordValues]:
341
321
  # Assigning parameter values to object fields
342
322
  sec_agg_param_dict = configs
343
- state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER])
344
- state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID])
323
+ state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
324
+ state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
345
325
  log(INFO, "Client %d: starting stage 0...", state.sid)
346
326
 
347
- state.share_num = cast(int, sec_agg_param_dict[KEY_SHARE_NUMBER])
348
- state.threshold = cast(int, sec_agg_param_dict[KEY_THRESHOLD])
349
- state.clipping_range = cast(float, sec_agg_param_dict[KEY_CLIPPING_RANGE])
350
- state.target_range = cast(int, sec_agg_param_dict[KEY_TARGET_RANGE])
351
- state.mod_range = cast(int, sec_agg_param_dict[KEY_MOD_RANGE])
327
+ state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
328
+ state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
329
+ state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
330
+ state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
331
+ state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
352
332
 
353
333
  # Dictionaries containing client secure IDs as keys
354
334
  # and their respective secret shares as values.
@@ -367,7 +347,7 @@ def _setup(
367
347
  state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
368
348
  state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
369
349
  log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
370
- return {KEY_PUBLIC_KEY_1: state.pk1, KEY_PUBLIC_KEY_2: state.pk2}
350
+ return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
371
351
 
372
352
 
373
353
  # pylint: disable-next=too-many-locals
@@ -429,7 +409,7 @@ def _share_keys(
429
409
  ciphertexts.append(ciphertext)
430
410
 
431
411
  log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
432
- return {KEY_DESTINATION_LIST: dsts, KEY_CIPHERTEXT_LIST: ciphertexts}
412
+ return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
433
413
 
434
414
 
435
415
  # pylint: disable-next=too-many-locals
@@ -440,8 +420,8 @@ def _collect_masked_input(
440
420
  ) -> Dict[str, ConfigsRecordValues]:
441
421
  log(INFO, "Client %d: starting stage 2...", state.sid)
442
422
  available_clients: List[int] = []
443
- ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST])
444
- srcs = cast(List[int], configs[KEY_SOURCE_LIST])
423
+ ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
424
+ srcs = cast(List[int], configs[Key.SOURCE_LIST])
445
425
  if len(ciphertexts) + 1 < state.threshold:
446
426
  raise ValueError("Not enough available neighbour clients.")
447
427
 
@@ -505,7 +485,7 @@ def _collect_masked_input(
505
485
  quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
506
486
  log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
507
487
  return {
508
- KEY_MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
488
+ Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
509
489
  }
510
490
 
511
491
 
@@ -514,8 +494,8 @@ def _unmask(
514
494
  ) -> Dict[str, ConfigsRecordValues]:
515
495
  log(INFO, "Client %d: starting stage 3...", state.sid)
516
496
 
517
- active_sids = cast(List[int], configs[KEY_ACTIVE_SECURE_ID_LIST])
518
- dead_sids = cast(List[int], configs[KEY_DEAD_SECURE_ID_LIST])
497
+ active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
498
+ dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
519
499
  # Send private mask seed share for every avaliable client (including itclient)
520
500
  # Send first private key share for building pairwise mask for every dropped client
521
501
  if len(active_sids) < state.threshold:
@@ -528,4 +508,4 @@ def _unmask(
528
508
  shares += [state.sk1_share_dict[sid] for sid in dead_sids]
529
509
 
530
510
  log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
531
- return {KEY_SECURE_ID_LIST: sids, KEY_SHARE_LIST: shares}
511
+ return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
@@ -14,33 +14,55 @@
14
14
  # ==============================================================================
15
15
  """Constants for the SecAgg/SecAgg+ protocol."""
16
16
 
17
+
18
+ from __future__ import annotations
19
+
17
20
  RECORD_KEY_STATE = "secaggplus_state"
18
21
  RECORD_KEY_CONFIGS = "secaggplus_configs"
19
22
 
20
- # Names of stages
21
- STAGE_SETUP = "setup"
22
- STAGE_SHARE_KEYS = "share_keys"
23
- STAGE_COLLECT_MASKED_INPUT = "collect_masked_input"
24
- STAGE_UNMASK = "unmask"
25
- STAGES = (STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_COLLECT_MASKED_INPUT, STAGE_UNMASK)
26
-
27
- # All valid keys in received/replied `named_values` dictionaries
28
- KEY_STAGE = "stage"
29
- KEY_SAMPLE_NUMBER = "sample_num"
30
- KEY_SECURE_ID = "secure_id"
31
- KEY_SHARE_NUMBER = "share_num"
32
- KEY_THRESHOLD = "threshold"
33
- KEY_CLIPPING_RANGE = "clipping_range"
34
- KEY_TARGET_RANGE = "target_range"
35
- KEY_MOD_RANGE = "mod_range"
36
- KEY_PUBLIC_KEY_1 = "pk1"
37
- KEY_PUBLIC_KEY_2 = "pk2"
38
- KEY_DESTINATION_LIST = "dsts"
39
- KEY_CIPHERTEXT_LIST = "ctxts"
40
- KEY_SOURCE_LIST = "srcs"
41
- KEY_PARAMETERS = "params"
42
- KEY_MASKED_PARAMETERS = "masked_params"
43
- KEY_ACTIVE_SECURE_ID_LIST = "active_sids"
44
- KEY_DEAD_SECURE_ID_LIST = "dead_sids"
45
- KEY_SECURE_ID_LIST = "sids"
46
- KEY_SHARE_LIST = "shares"
23
+
24
+ class Stage:
25
+ """Stages for the SecAgg+ protocol."""
26
+
27
+ SETUP = "setup"
28
+ SHARE_KEYS = "share_keys"
29
+ COLLECT_MASKED_INPUT = "collect_masked_input"
30
+ UNMASK = "unmask"
31
+ _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)
32
+
33
+ @classmethod
34
+ def all(cls) -> tuple[str, str, str, str]:
35
+ """Return all stages."""
36
+ return cls._stages
37
+
38
+ def __new__(cls) -> Stage:
39
+ """Prevent instantiation."""
40
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
41
+
42
+
43
+ class Key:
44
+ """Keys for the configs in the ConfigsRecord."""
45
+
46
+ STAGE = "stage"
47
+ SAMPLE_NUMBER = "sample_num"
48
+ SECURE_ID = "secure_id"
49
+ SHARE_NUMBER = "share_num"
50
+ THRESHOLD = "threshold"
51
+ CLIPPING_RANGE = "clipping_range"
52
+ TARGET_RANGE = "target_range"
53
+ MOD_RANGE = "mod_range"
54
+ PUBLIC_KEY_1 = "pk1"
55
+ PUBLIC_KEY_2 = "pk2"
56
+ DESTINATION_LIST = "dsts"
57
+ CIPHERTEXT_LIST = "ctxts"
58
+ SOURCE_LIST = "srcs"
59
+ PARAMETERS = "params"
60
+ MASKED_PARAMETERS = "masked_params"
61
+ ACTIVE_SECURE_ID_LIST = "active_sids"
62
+ DEAD_SECURE_ID_LIST = "dead_sids"
63
+ SECURE_ID_LIST = "sids"
64
+ SHARE_LIST = "shares"
65
+
66
+ def __new__(cls) -> Key:
67
+ """Prevent instantiation."""
68
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
@@ -0,0 +1,26 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: flwr/proto/error.proto
4
+ # Protobuf Python Version: 4.25.0
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals)
22
+ if _descriptor._USE_C_DESCRIPTORS == False:
23
+ DESCRIPTOR._options = None
24
+ _globals['_ERROR']._serialized_start=38
25
+ _globals['_ERROR']._serialized_end=75
26
+ # @@protoc_insertion_point(module_scope)
@@ -0,0 +1,25 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
5
+ import builtins
6
+ import google.protobuf.descriptor
7
+ import google.protobuf.message
8
+ import typing
9
+ import typing_extensions
10
+
11
+ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
12
+
13
+ class Error(google.protobuf.message.Message):
14
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
15
+ CODE_FIELD_NUMBER: builtins.int
16
+ REASON_FIELD_NUMBER: builtins.int
17
+ code: builtins.int
18
+ reason: typing.Text
19
+ def __init__(self,
20
+ *,
21
+ code: builtins.int = ...,
22
+ reason: typing.Text = ...,
23
+ ) -> None: ...
24
+ def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ...
25
+ global___Error = Error
@@ -0,0 +1,4 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
@@ -0,0 +1,4 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
flwr/proto/task_pb2.py CHANGED
@@ -15,19 +15,20 @@ _sym_db = _symbol_database.Default()
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
17
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
+ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
18
19
 
19
20
 
20
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd4\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
22
 
22
23
  _globals = globals()
23
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
24
25
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
25
26
  if _descriptor._USE_C_DESCRIPTORS == False:
26
27
  DESCRIPTOR._options = None
27
- _globals['_TASK']._serialized_start=117
28
- _globals['_TASK']._serialized_end=329
29
- _globals['_TASKINS']._serialized_start=331
30
- _globals['_TASKINS']._serialized_end=423
31
- _globals['_TASKRES']._serialized_start=425
32
- _globals['_TASKRES']._serialized_end=517
28
+ _globals['_TASK']._serialized_start=141
29
+ _globals['_TASK']._serialized_end=387
30
+ _globals['_TASKINS']._serialized_start=389
31
+ _globals['_TASKINS']._serialized_end=481
32
+ _globals['_TASKRES']._serialized_start=483
33
+ _globals['_TASKRES']._serialized_end=575
33
34
  # @@protoc_insertion_point(module_scope)
flwr/proto/task_pb2.pyi CHANGED
@@ -3,6 +3,7 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.error_pb2
6
7
  import flwr.proto.node_pb2
7
8
  import flwr.proto.recordset_pb2
8
9
  import google.protobuf.descriptor
@@ -23,6 +24,7 @@ class Task(google.protobuf.message.Message):
23
24
  ANCESTRY_FIELD_NUMBER: builtins.int
24
25
  TASK_TYPE_FIELD_NUMBER: builtins.int
25
26
  RECORDSET_FIELD_NUMBER: builtins.int
27
+ ERROR_FIELD_NUMBER: builtins.int
26
28
  @property
27
29
  def producer(self) -> flwr.proto.node_pb2.Node: ...
28
30
  @property
@@ -35,6 +37,8 @@ class Task(google.protobuf.message.Message):
35
37
  task_type: typing.Text
36
38
  @property
37
39
  def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
40
+ @property
41
+ def error(self) -> flwr.proto.error_pb2.Error: ...
38
42
  def __init__(self,
39
43
  *,
40
44
  producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
@@ -45,9 +49,10 @@ class Task(google.protobuf.message.Message):
45
49
  ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
46
50
  task_type: typing.Text = ...,
47
51
  recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
52
+ error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
48
53
  ) -> None: ...
49
- def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
50
- def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
54
+ def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
55
+ def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
51
56
  global___Task = Task
52
57
 
53
58
  class TaskIns(google.protobuf.message.Message):
@@ -16,10 +16,17 @@
16
16
 
17
17
 
18
18
  from .bulyan import Bulyan as Bulyan
19
- from .dp_adaptive_clipping import DifferentialPrivacyClientSideAdaptiveClipping
19
+ from .dp_adaptive_clipping import (
20
+ DifferentialPrivacyClientSideAdaptiveClipping as DifferentialPrivacyClientSideAdaptiveClipping,
21
+ )
22
+ from .dp_adaptive_clipping import (
23
+ DifferentialPrivacyServerSideAdaptiveClipping as DifferentialPrivacyServerSideAdaptiveClipping,
24
+ )
25
+ from .dp_fixed_clipping import (
26
+ DifferentialPrivacyClientSideFixedClipping as DifferentialPrivacyClientSideFixedClipping,
27
+ )
20
28
  from .dp_fixed_clipping import (
21
- DifferentialPrivacyClientSideFixedClipping,
22
- DifferentialPrivacyServerSideFixedClipping,
29
+ DifferentialPrivacyServerSideFixedClipping as DifferentialPrivacyServerSideFixedClipping,
23
30
  )
24
31
  from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive
25
32
  from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed
@@ -46,6 +53,7 @@ __all__ = [
46
53
  "DPFedAvgAdaptive",
47
54
  "DPFedAvgFixed",
48
55
  "DifferentialPrivacyClientSideAdaptiveClipping",
56
+ "DifferentialPrivacyServerSideAdaptiveClipping",
49
57
  "DifferentialPrivacyClientSideFixedClipping",
50
58
  "DifferentialPrivacyServerSideFixedClipping",
51
59
  "FedAdagrad",
@@ -24,8 +24,19 @@ from typing import Dict, List, Optional, Tuple, Union
24
24
 
25
25
  import numpy as np
26
26
 
27
- from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
27
+ from flwr.common import (
28
+ EvaluateIns,
29
+ EvaluateRes,
30
+ FitIns,
31
+ FitRes,
32
+ NDArrays,
33
+ Parameters,
34
+ Scalar,
35
+ ndarrays_to_parameters,
36
+ parameters_to_ndarrays,
37
+ )
28
38
  from flwr.common.differential_privacy import (
39
+ adaptive_clip_inputs_inplace,
29
40
  add_gaussian_noise_to_params,
30
41
  compute_adaptive_noise_params,
31
42
  )
@@ -40,6 +51,199 @@ from flwr.server.client_proxy import ClientProxy
40
51
  from flwr.server.strategy.strategy import Strategy
41
52
 
42
53
 
54
+ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
55
+ """Strategy wrapper for central DP with server-side adaptive clipping.
56
+
57
+ Parameters
58
+ ----------
59
+ strategy: Strategy
60
+ The strategy to which DP functionalities will be added by this wrapper.
61
+ noise_multiplier : float
62
+ The noise multiplier for the Gaussian mechanism for model updates.
63
+ num_sampled_clients : int
64
+ The number of clients that are sampled on each round.
65
+ initial_clipping_norm : float
66
+ The initial value of clipping norm. Deafults to 0.1.
67
+ Andrew et al. recommends to set to 0.1.
68
+ target_clipped_quantile : float
69
+ The desired quantile of updates which should be clipped. Defaults to 0.5.
70
+ clip_norm_lr : float
71
+ The learning rate for the clipping norm adaptation. Defaults to 0.2.
72
+ Andrew et al. recommends to set to 0.2.
73
+ clipped_count_stddev : float
74
+ The standard deviation of the noise added to the count of updates below the estimate.
75
+ Andrew et al. recommends to set to `expected_num_records/20`
76
+
77
+ Examples
78
+ --------
79
+ Create a strategy:
80
+
81
+ >>> strategy = fl.server.strategy.FedAvg( ... )
82
+
83
+ Wrap the strategy with the DifferentialPrivacyServerSideAdaptiveClipping wrapper
84
+
85
+ >>> dp_strategy = DifferentialPrivacyServerSideAdaptiveClipping(
86
+ >>> strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ...
87
+ >>> )
88
+ """
89
+
90
+ # pylint: disable=too-many-arguments,too-many-instance-attributes
91
+ def __init__(
92
+ self,
93
+ strategy: Strategy,
94
+ noise_multiplier: float,
95
+ num_sampled_clients: int,
96
+ initial_clipping_norm: float = 0.1,
97
+ target_clipped_quantile: float = 0.5,
98
+ clip_norm_lr: float = 0.2,
99
+ clipped_count_stddev: Optional[float] = None,
100
+ ) -> None:
101
+ super().__init__()
102
+
103
+ if strategy is None:
104
+ raise ValueError("The passed strategy is None.")
105
+
106
+ if noise_multiplier < 0:
107
+ raise ValueError("The noise multiplier should be a non-negative value.")
108
+
109
+ if num_sampled_clients <= 0:
110
+ raise ValueError(
111
+ "The number of sampled clients should be a positive value."
112
+ )
113
+
114
+ if initial_clipping_norm <= 0:
115
+ raise ValueError("The initial clipping norm should be a positive value.")
116
+
117
+ if not 0 <= target_clipped_quantile <= 1:
118
+ raise ValueError(
119
+ "The target clipped quantile must be between 0 and 1 (inclusive)."
120
+ )
121
+
122
+ if clip_norm_lr <= 0:
123
+ raise ValueError("The learning rate must be positive.")
124
+
125
+ if clipped_count_stddev is not None:
126
+ if clipped_count_stddev < 0:
127
+ raise ValueError("The `clipped_count_stddev` must be non-negative.")
128
+
129
+ self.strategy = strategy
130
+ self.num_sampled_clients = num_sampled_clients
131
+ self.clipping_norm = initial_clipping_norm
132
+ self.target_clipped_quantile = target_clipped_quantile
133
+ self.clip_norm_lr = clip_norm_lr
134
+ (
135
+ self.clipped_count_stddev,
136
+ self.noise_multiplier,
137
+ ) = compute_adaptive_noise_params(
138
+ noise_multiplier,
139
+ num_sampled_clients,
140
+ clipped_count_stddev,
141
+ )
142
+
143
+ self.current_round_params: NDArrays = []
144
+
145
+ def __repr__(self) -> str:
146
+ """Compute a string representation of the strategy."""
147
+ rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
148
+ return rep
149
+
150
+ def initialize_parameters(
151
+ self, client_manager: ClientManager
152
+ ) -> Optional[Parameters]:
153
+ """Initialize global model parameters using given strategy."""
154
+ return self.strategy.initialize_parameters(client_manager)
155
+
156
+ def configure_fit(
157
+ self, server_round: int, parameters: Parameters, client_manager: ClientManager
158
+ ) -> List[Tuple[ClientProxy, FitIns]]:
159
+ """Configure the next round of training."""
160
+ self.current_round_params = parameters_to_ndarrays(parameters)
161
+ return self.strategy.configure_fit(server_round, parameters, client_manager)
162
+
163
+ def configure_evaluate(
164
+ self, server_round: int, parameters: Parameters, client_manager: ClientManager
165
+ ) -> List[Tuple[ClientProxy, EvaluateIns]]:
166
+ """Configure the next round of evaluation."""
167
+ return self.strategy.configure_evaluate(
168
+ server_round, parameters, client_manager
169
+ )
170
+
171
+ def aggregate_fit(
172
+ self,
173
+ server_round: int,
174
+ results: List[Tuple[ClientProxy, FitRes]],
175
+ failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
176
+ ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
177
+ """Aggregate training results and update clip norms."""
178
+ if failures:
179
+ return None, {}
180
+
181
+ if len(results) != self.num_sampled_clients:
182
+ log(
183
+ WARNING,
184
+ CLIENTS_DISCREPANCY_WARNING,
185
+ len(results),
186
+ self.num_sampled_clients,
187
+ )
188
+
189
+ norm_bit_set_count = 0
190
+ for _, res in results:
191
+ param = parameters_to_ndarrays(res.parameters)
192
+ # Compute and clip update
193
+ model_update = [
194
+ np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
195
+ ]
196
+
197
+ norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm)
198
+ norm_bit_set_count += norm_bit
199
+
200
+ for i, _ in enumerate(self.current_round_params):
201
+ param[i] = self.current_round_params[i] + model_update[i]
202
+ # Convert back to parameters
203
+ res.parameters = ndarrays_to_parameters(param)
204
+
205
+ # Noising the count
206
+ noised_norm_bit_set_count = float(
207
+ np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
208
+ )
209
+ noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
210
+ # Geometric update
211
+ self.clipping_norm *= math.exp(
212
+ -self.clip_norm_lr
213
+ * (noised_norm_bit_set_fraction - self.target_clipped_quantile)
214
+ )
215
+
216
+ aggregated_params, metrics = self.strategy.aggregate_fit(
217
+ server_round, results, failures
218
+ )
219
+
220
+ # Add Gaussian noise to the aggregated parameters
221
+ if aggregated_params:
222
+ aggregated_params = add_gaussian_noise_to_params(
223
+ aggregated_params,
224
+ self.noise_multiplier,
225
+ self.clipping_norm,
226
+ self.num_sampled_clients,
227
+ )
228
+
229
+ return aggregated_params, metrics
230
+
231
+ def aggregate_evaluate(
232
+ self,
233
+ server_round: int,
234
+ results: List[Tuple[ClientProxy, EvaluateRes]],
235
+ failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
236
+ ) -> Tuple[Optional[float], Dict[str, Scalar]]:
237
+ """Aggregate evaluation losses using the given strategy."""
238
+ return self.strategy.aggregate_evaluate(server_round, results, failures)
239
+
240
+ def evaluate(
241
+ self, server_round: int, parameters: Parameters
242
+ ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
243
+ """Evaluate model parameters using an evaluation function from the strategy."""
244
+ return self.strategy.evaluate(server_round, parameters)
245
+
246
+
43
247
  class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
44
248
  """Strategy wrapper for central DP with client-side adaptive clipping.
45
249
 
@@ -15,7 +15,7 @@
15
15
  """Driver API servicer."""
16
16
 
17
17
 
18
- from logging import INFO
18
+ from logging import DEBUG, INFO
19
19
  from typing import List, Optional, Set
20
20
  from uuid import UUID
21
21
 
@@ -70,7 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
70
70
  self, request: PushTaskInsRequest, context: grpc.ServicerContext
71
71
  ) -> PushTaskInsResponse:
72
72
  """Push a set of TaskIns."""
73
- log(INFO, "DriverServicer.PushTaskIns")
73
+ log(DEBUG, "DriverServicer.PushTaskIns")
74
74
 
75
75
  # Validate request
76
76
  _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
@@ -95,7 +95,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
95
95
  self, request: PullTaskResRequest, context: grpc.ServicerContext
96
96
  ) -> PullTaskResResponse:
97
97
  """Pull a set of TaskRes."""
98
- log(INFO, "DriverServicer.PullTaskRes")
98
+ log(DEBUG, "DriverServicer.PullTaskRes")
99
99
 
100
100
  # Convert each task_id str to UUID
101
101
  task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}
@@ -105,7 +105,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
105
105
 
106
106
  # Register callback
107
107
  def on_rpc_done() -> None:
108
- log(INFO, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
108
+ log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
109
109
 
110
110
  if context.is_active():
111
111
  return
@@ -17,6 +17,8 @@
17
17
 
18
18
  import importlib
19
19
 
20
+ from flwr.simulation.run_simulation import run_simulation
21
+
20
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
21
23
 
22
24
  if is_ray_installed:
@@ -36,4 +38,5 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
38
 
37
39
  __all__ = [
38
40
  "start_simulation",
41
+ "run_simulation",
39
42
  ]
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower Simulation."""
16
+
17
+ import argparse
18
+ import asyncio
19
+ import json
20
+ import threading
21
+ import traceback
22
+ from logging import ERROR, INFO, WARNING
23
+
24
+ import grpc
25
+
26
+ from flwr.common import EventType, event, log
27
+ from flwr.common.exit_handlers import register_exit_handlers
28
+ from flwr.server.driver.driver import Driver
29
+ from flwr.server.run_serverapp import run
30
+ from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
31
+ from flwr.server.superlink.fleet import vce
32
+ from flwr.server.superlink.state import StateFactory
33
+ from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
34
+
35
+
36
def run_simulation() -> None:
    """Run Simulation Engine.

    Parses command-line arguments, starts the Driver API gRPC server and a
    SuperLink running the Simulation Engine in a background thread, then
    executes the ServerApp against them. On any failure the simulation is
    stopped and a ``RuntimeError`` is raised chaining the original exception.
    """
    args = _parse_args_run_simulation().parse_args()

    # Load JSON config
    backend_config_dict = json.loads(args.backend_config)

    # Enable GPU memory growth (relevant only for TF)
    if args.enable_tf_gpu_growth:
        log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
        enable_tf_gpu_growth()
        # Check that Backend config has also enabled using GPU growth
        use_tf = backend_config_dict.get("tensorflow", False)
        if not use_tf:
            log(WARNING, "Enabling GPU growth for your backend.")
            backend_config_dict["tensorflow"] = True

    # Convert back to JSON stream
    backend_config = json.dumps(backend_config_dict)

    # Initialize StateFactory
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Start Driver API
    driver_server: grpc.Server = run_driver_api_grpc(
        address=args.driver_api_address,
        state_factory=state_factory,
        certificates=None,
    )

    # SuperLink with Simulation Engine
    f_stop = asyncio.Event()
    superlink_th = threading.Thread(
        target=vce.start_vce,
        kwargs={
            "num_supernodes": args.num_supernodes,
            "client_app_module_name": args.client_app,
            "backend_name": args.backend,
            "backend_config_json_stream": backend_config,
            "working_dir": args.dir,
            "state_factory": state_factory,
            "f_stop": f_stop,
        },
        daemon=False,
    )

    superlink_th.start()
    event(EventType.RUN_SUPERLINK_ENTER)

    # Pre-bind so the `finally` cleanup below never hits an unbound name
    # if `Driver(...)` itself raises.
    driver = None
    try:
        # Initialize Driver
        driver = Driver(
            driver_service_address=args.driver_api_address,
            root_certificates=None,
        )

        # Launch server app
        run(args.server_app, driver, args.dir)

    except Exception as ex:
        log(ERROR, "An exception occurred: %s", ex)
        log(ERROR, traceback.format_exc())
        raise RuntimeError(
            "An error was encountered by the Simulation Engine. Ending simulation."
        ) from ex

    finally:
        # Release the Driver (and its gRPC channel) before shutting down.
        if driver is not None:
            del driver

        # Trigger stop event
        f_stop.set()

        register_exit_handlers(
            grpc_servers=[driver_server],
            bckg_threads=[superlink_th],
            event_type=EventType.RUN_SUPERLINK_LEAVE,
        )
        superlink_th.join()
116
+
117
+
118
+ def _parse_args_run_simulation() -> argparse.ArgumentParser:
119
+ """Parse flower-simulation command line arguments."""
120
+ parser = argparse.ArgumentParser(
121
+ description="Start a Flower simulation",
122
+ )
123
+ parser.add_argument(
124
+ "--client-app",
125
+ required=True,
126
+ help="For example: `client:app` or `project.package.module:wrapper.app`",
127
+ )
128
+ parser.add_argument(
129
+ "--server-app",
130
+ required=True,
131
+ help="For example: `server:app` or `project.package.module:wrapper.app`",
132
+ )
133
+ parser.add_argument(
134
+ "--driver-api-address",
135
+ default="0.0.0.0:9091",
136
+ type=str,
137
+ help="For example: `server:app` or `project.package.module:wrapper.app`",
138
+ )
139
+ parser.add_argument(
140
+ "--num-supernodes",
141
+ type=int,
142
+ required=True,
143
+ help="Number of simulated SuperNodes.",
144
+ )
145
+ parser.add_argument(
146
+ "--backend",
147
+ default="ray",
148
+ type=str,
149
+ help="Simulation backend that executes the ClientApp.",
150
+ )
151
+ parser.add_argument(
152
+ "--enable-tf-gpu-growth",
153
+ action="store_true",
154
+ help="Enables GPU growth on the main thread. This is desirable if you make "
155
+ "use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
156
+ "running on the same GPU. Without enabling this, you might encounter an "
157
+ "out-of-memory error becasue TensorFlow by default allocates all GPU memory."
158
+ "Read mor about how `tf.config.experimental.set_memory_growth()` works in "
159
+ "the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
160
+ )
161
+ parser.add_argument(
162
+ "--backend-config",
163
+ type=str,
164
+ default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
165
+ help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
166
+ "configure a backend. Values supported in <value> are those included by "
167
+ "`flwr.common.typing.ConfigsRecordValues`. ",
168
+ )
169
+ parser.add_argument(
170
+ "--dir",
171
+ default="",
172
+ help="Add specified directory to the PYTHONPATH and load"
173
+ "ClientApp and ServerApp from there."
174
+ " Default: current working directory.",
175
+ )
176
+
177
+ return parser
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.8.0.dev20240228
3
+ Version: 1.8.0.dev20240229
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -32,7 +32,7 @@ flwr/client/message_handler/task_handler.py,sha256=ZDJBKmrn2grRMNl1rU1iGs7FiMHL5
32
32
  flwr/client/mod/__init__.py,sha256=w6r7n6fWIrrm4lEk36lh9f1Ix6LXgAzQUrgjmMspY98,961
33
33
  flwr/client/mod/centraldp_mods.py,sha256=aHbzGjSbyRENuU5vzad_tkJ9UDb48uHEvUq-zgydBwo,4954
34
34
  flwr/client/mod/secure_aggregation/__init__.py,sha256=AzCdezuzX2BfXUuxVRwXdv8-zUIXoU-Bf6u4LRhzvg8,796
35
- flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=z_5t1YzqLs91ZLW5Yoo7Ozqw9_nyVuEpJ7Noa2a34bs,19890
35
+ flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=Zc5b2C58SYeyQVXIbLRESOIq4rMUOuzMHNAjdSAPt6I,19434
36
36
  flwr/client/mod/utils.py,sha256=lvETHcCYsSWz7h8I772hCV_kZspxqlMqzriMZ-SxmKc,1226
37
37
  flwr/client/node_state.py,sha256=KTTs_l4I0jBM7IsSsbAGjhfL_yZC3QANbzyvyfZBRDM,1778
38
38
  flwr/client/node_state_tests.py,sha256=gPwz0zf2iuDSa11jedkur_u3Xm7lokIDG5ALD2MCvSw,2195
@@ -68,7 +68,7 @@ flwr/common/secure_aggregation/crypto/shamir.py,sha256=yY35ZgHlB4YyGW_buG-1X-0M-
68
68
  flwr/common/secure_aggregation/crypto/symmetric_encryption.py,sha256=-zDyQoTsHHQjR7o-92FNIikg1zM_Ke9yynaD5u2BXbQ,3546
69
69
  flwr/common/secure_aggregation/ndarrays_arithmetic.py,sha256=KAHCEHGSTJ6mCgnC8dTpIx6URk11s5XxWTHI_7ToGIg,2979
70
70
  flwr/common/secure_aggregation/quantization.py,sha256=appui7GGrkRPsupF59TkapeV4Na_CyPi73JtJ1pimdI,2310
71
- flwr/common/secure_aggregation/secaggplus_constants.py,sha256=pITi-nuzrNvKWR42XwVFBuejv1RdGLwmuErLp0X0t_Y,1686
71
+ flwr/common/secure_aggregation/secaggplus_constants.py,sha256=2dYMKqiO2ja8DewFeZUAGqM0xujNfyYHVwShRxeRaIM,2132
72
72
  flwr/common/secure_aggregation/secaggplus_utils.py,sha256=PleDyDu7jHNAfbRoEaoQiOjxG6iMl9yA8rNKYTfnyFw,3155
73
73
  flwr/common/serde.py,sha256=0tmfTcywJVLA7Hsu4nAjMn2dVMVbjYZqePJbPcDt01Y,20726
74
74
  flwr/common/telemetry.py,sha256=JkFB6WBOskqAJfzSM-l6tQfRiSi2oiysClfg0-5T7NY,7782
@@ -79,6 +79,10 @@ flwr/proto/driver_pb2.py,sha256=JHIdjNPTgp6YHD-_lz5ZZFB0VIOR3_GmcaOTN4jndc4,3115
79
79
  flwr/proto/driver_pb2.pyi,sha256=xwl2AqIWn0SwAlg-x5RUQeqr6DC48eywnqmD7gbaaFs,4670
80
80
  flwr/proto/driver_pb2_grpc.py,sha256=qQBRdQUz4k2K4DVO7kSfWHx-62UJ85HaYKnKCr6JcU8,7304
81
81
  flwr/proto/driver_pb2_grpc.pyi,sha256=NpOM5eCrIPcuWdYrZAayQSDvvFp6cDCVflabhmuvMfo,2022
82
+ flwr/proto/error_pb2.py,sha256=LarjKL90LbwkXKlhzNrDssgl4DXcvIPve8NVCXHpsKA,1084
83
+ flwr/proto/error_pb2.pyi,sha256=ZNH4HhJTU_KfMXlyCeg8FwU-fcUYxTqEmoJPtWtHikc,734
84
+ flwr/proto/error_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
85
+ flwr/proto/error_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
82
86
  flwr/proto/fleet_pb2.py,sha256=8rKQHu6Oa9ki_NG6kRNGtfPPYZp5kKBZhPW696_kn84,3852
83
87
  flwr/proto/fleet_pb2.pyi,sha256=QXYs9M7_dABghdCMfk5Rjf4w0LsZGDeQ1ojH00XaQME,6182
84
88
  flwr/proto/fleet_pb2_grpc.py,sha256=hF1uPaioZzQMRCP9yPlv9LC0mi_DTuhn-IkQJzWIPCs,7505
@@ -91,8 +95,8 @@ flwr/proto/recordset_pb2.py,sha256=un8L0kvBcgFXQIiQweOseeIJBjlOozUvQY9uTQ42Dqo,6
91
95
  flwr/proto/recordset_pb2.pyi,sha256=NPzCJWAj1xLWzeZ_xZ6uaObQjQfWGnnqlLtn4J-SoFY,14161
92
96
  flwr/proto/recordset_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
93
97
  flwr/proto/recordset_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
94
- flwr/proto/task_pb2.py,sha256=0iZTgNnrypfuO9AqjXqvV79VGaS9lkpi6oohMFeoTso,2272
95
- flwr/proto/task_pb2.pyi,sha256=k3Bulbd7pbYszemI5ntWJtuaYlSz_wnhhBZSyGSwjF8,3937
98
+ flwr/proto/task_pb2.py,sha256=-UX3TqskOIRbPu8U3YwgW9ul2k9ZN6MJGgbIOX3pTqg,2431
99
+ flwr/proto/task_pb2.pyi,sha256=IgXggFya0RpL64z2o2K_qLnZHyZ1mg_WzLxFwEKrpL0,4171
96
100
  flwr/proto/task_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
97
101
  flwr/proto/task_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
98
102
  flwr/proto/transport_pb2.py,sha256=cURzfpCgZvH7GEvBPLvTYijE3HvhK1MePjINk4xYArk,9781
@@ -118,10 +122,10 @@ flwr/server/run_serverapp.py,sha256=7LLE1cVQz0Rl-hZnY7DLXvFxWCdep8xgLgEVC-yffi0,
118
122
  flwr/server/server.py,sha256=JLc2lg-qCchD-Jyg_hTBZQN3rXsnfAGx6qJAo0vqH2Y,17812
119
123
  flwr/server/server_app.py,sha256=avNQ7AMMKsn09ly81C3UBgOfHhM_R29l4MrzlalGoj8,5892
120
124
  flwr/server/server_config.py,sha256=yOHpkdyuhOm--Gy_4Vofvu6jCDxhyECEDpIy02beuCg,1018
121
- flwr/server/strategy/__init__.py,sha256=xKnjkTQQT4PjoCXrUXW4LJ85H_iXNszKZ4u6W__VWp8,2460
125
+ flwr/server/strategy/__init__.py,sha256=7eVZ3hQEg2BgA_usAeL6tsLp9T6XI1VYYoFy08Xn-ew,2836
122
126
  flwr/server/strategy/aggregate.py,sha256=QyRIJtI5gnuY1NbgrcrOvkHxGIxBvApq7d9Y4xl-6W4,13468
123
127
  flwr/server/strategy/bulyan.py,sha256=8GsSVJzRSoSWE2zQUKqC3Z795grdN9xpmc3MSGGXnzM,6532
124
- flwr/server/strategy/dp_adaptive_clipping.py,sha256=DIfQA4PDXNFpEYKTf5aHFzvNn7t-_6Bv1MZdKGrJsc8,9261
128
+ flwr/server/strategy/dp_adaptive_clipping.py,sha256=BVvX1LivyukvEtOZVZnpgVpzH8BBjvA3OmdGwFxgRuQ,16679
125
129
  flwr/server/strategy/dp_fixed_clipping.py,sha256=v9YyX53jt2RatGnFxTK4ZMO_3SN7EdL9YCaaJtn9Fcc,12125
126
130
  flwr/server/strategy/dpfedavg_adaptive.py,sha256=hLJkPQJl1bHjwrBNg3PSRFKf3no0hg5EHiFaWhHlWqw,4877
127
131
  flwr/server/strategy/dpfedavg_fixed.py,sha256=G0yYxrPoM-MHQ889DYN3OeNiEeU0yQrjgAzcq0G653w,7219
@@ -145,7 +149,7 @@ flwr/server/strategy/strategy.py,sha256=g6VoIFogEviRub6G4QsKdIp6M_Ek6GhBhqcdNx5u
145
149
  flwr/server/superlink/__init__.py,sha256=8tHYCfodUlRD8PCP9fHgvu8cz5N31A2QoRVL0jDJ15E,707
146
150
  flwr/server/superlink/driver/__init__.py,sha256=STB1_DASVEg7Cu6L7VYxTzV7UMkgtBkFim09Z82Dh8I,712
147
151
  flwr/server/superlink/driver/driver_grpc.py,sha256=1qSGDs1k_OVPWxp2ofxvQgtYXExrMeC3N_rNPVWH65M,1932
148
- flwr/server/superlink/driver/driver_servicer.py,sha256=69-QMkdefgd8BDDAuy4TSJHVcmZAGqq8Yy9ZPDq2kks,4563
152
+ flwr/server/superlink/driver/driver_servicer.py,sha256=p1rlocgqU2I4s5IwdU8rTZBkQ73yPmuWtK_aUlB7V84,4573
149
153
  flwr/server/superlink/fleet/__init__.py,sha256=C6GCSD5eP5Of6_dIeSe1jx9HnV0icsvWyQ5EKAUHJRU,711
150
154
  flwr/server/superlink/fleet/grpc_bidi/__init__.py,sha256=mgGJGjwT6VU7ovC1gdnnqttjyBPlNIcZnYRqx4K3IBQ,735
151
155
  flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py,sha256=57b3UL5-baGdLwgCtB0dCUTTSbmmfMAXcXV5bjPZNWQ,5993
@@ -174,14 +178,15 @@ flwr/server/utils/tensorboard.py,sha256=k0G6bqsLx7wfYbH2KtXsDYcOCfyIeE12-hefXA7l
174
178
  flwr/server/utils/validator.py,sha256=IJN2475yyD_i_9kg_SJ_JodIuZh58ufpWGUDQRAqu2s,4740
175
179
  flwr/server/workflow/__init__.py,sha256=2YrKq5wUwge8Tm1xaAdf2P3l4LbM4olka6tO_0_Mu9A,787
176
180
  flwr/server/workflow/default_workflows.py,sha256=DKSt14WY5m19ujwh6UDP4a31kRs-j6V_NZm5cXn01ZY,12705
177
- flwr/simulation/__init__.py,sha256=E2eD5FlTmZZ80u21FmWCkacrM7O4mrEHD8iXqeCaBUQ,1278
181
+ flwr/simulation/__init__.py,sha256=jdrJeTnLLj9Eyl8BRPXMewqkhTnxD7fvXDgyjfspy0Q,1359
178
182
  flwr/simulation/app.py,sha256=WqJxdXTEuehwMW605p5NMmvBbKYx5tuqnV3Mp7jSWXM,13904
179
183
  flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACkfeIam55BvW9g,734
180
184
  flwr/simulation/ray_transport/ray_actor.py,sha256=zRETW_xuCAOLRFaYnQ-q3IBSz0LIv_0RifGuhgWaYOg,19872
181
185
  flwr/simulation/ray_transport/ray_client_proxy.py,sha256=DpmrBC87_sX3J4WrrwzyEDIjONUeliBZx9T-gZGuPmQ,6799
182
186
  flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
183
- flwr_nightly-1.8.0.dev20240228.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
- flwr_nightly-1.8.0.dev20240228.dist-info/METADATA,sha256=ftdjOdLrEKow1vhHw6KFjaP741hpdqxPz2TN75s_ykc,15184
185
- flwr_nightly-1.8.0.dev20240228.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
186
- flwr_nightly-1.8.0.dev20240228.dist-info/entry_points.txt,sha256=S1zLNFLrz0uPWs4Zrgo2EPY0iQiIcCJHrIAlnQkkOBI,262
187
- flwr_nightly-1.8.0.dev20240228.dist-info/RECORD,,
187
+ flwr/simulation/run_simulation.py,sha256=NYUFJ6cG5QtuwEl6f2IWJNMRo_jDmmA41DABqISpq-Q,5950
188
+ flwr_nightly-1.8.0.dev20240229.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
189
+ flwr_nightly-1.8.0.dev20240229.dist-info/METADATA,sha256=bnkE4rD-5EY00NUOdT-ZzoeuextoRXX3zZL0MLR6PbU,15184
190
+ flwr_nightly-1.8.0.dev20240229.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
191
+ flwr_nightly-1.8.0.dev20240229.dist-info/entry_points.txt,sha256=qz6t0YdMrV_PLbEarJ6ITIWSIRrTbG_jossZAYfXBZQ,311
192
+ flwr_nightly-1.8.0.dev20240229.dist-info/RECORD,,
@@ -3,6 +3,7 @@ flower-client-app=flwr.client:run_client_app
3
3
  flower-driver-api=flwr.server:run_driver_api
4
4
  flower-fleet-api=flwr.server:run_fleet_api
5
5
  flower-server-app=flwr.server:run_server_app
6
+ flower-simulation=flwr.simulation:run_simulation
6
7
  flower-superlink=flwr.server:run_superlink
7
8
  flwr=flwr.cli.app:app
8
9