flwr-nightly 1.8.0.dev20240228__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -52,31 +52,10 @@ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
52
52
  )
53
53
  from flwr.common.secure_aggregation.quantization import quantize
54
54
  from flwr.common.secure_aggregation.secaggplus_constants import (
55
- KEY_ACTIVE_SECURE_ID_LIST,
56
- KEY_CIPHERTEXT_LIST,
57
- KEY_CLIPPING_RANGE,
58
- KEY_DEAD_SECURE_ID_LIST,
59
- KEY_DESTINATION_LIST,
60
- KEY_MASKED_PARAMETERS,
61
- KEY_MOD_RANGE,
62
- KEY_PUBLIC_KEY_1,
63
- KEY_PUBLIC_KEY_2,
64
- KEY_SAMPLE_NUMBER,
65
- KEY_SECURE_ID,
66
- KEY_SECURE_ID_LIST,
67
- KEY_SHARE_LIST,
68
- KEY_SHARE_NUMBER,
69
- KEY_SOURCE_LIST,
70
- KEY_STAGE,
71
- KEY_TARGET_RANGE,
72
- KEY_THRESHOLD,
73
55
  RECORD_KEY_CONFIGS,
74
56
  RECORD_KEY_STATE,
75
- STAGE_COLLECT_MASKED_INPUT,
76
- STAGE_SETUP,
77
- STAGE_SHARE_KEYS,
78
- STAGE_UNMASK,
79
- STAGES,
57
+ Key,
58
+ Stage,
80
59
  )
81
60
  from flwr.common.secure_aggregation.secaggplus_utils import (
82
61
  pseudo_rand_gen,
@@ -91,7 +70,7 @@ from flwr.common.typing import ConfigsRecordValues, FitRes
91
70
  class SecAggPlusState:
92
71
  """State of the SecAgg+ protocol."""
93
72
 
94
- current_stage: str = STAGE_UNMASK
73
+ current_stage: str = Stage.UNMASK
95
74
 
96
75
  sid: int = 0
97
76
  sample_num: int = 0
@@ -187,20 +166,20 @@ def secaggplus_mod(
187
166
  check_stage(state.current_stage, configs)
188
167
 
189
168
  # Update the current stage
190
- state.current_stage = cast(str, configs.pop(KEY_STAGE))
169
+ state.current_stage = cast(str, configs.pop(Key.STAGE))
191
170
 
192
171
  # Check the validity of the configs based on the current stage
193
172
  check_configs(state.current_stage, configs)
194
173
 
195
174
  # Execute
196
- if state.current_stage == STAGE_SETUP:
175
+ if state.current_stage == Stage.SETUP:
197
176
  res = _setup(state, configs)
198
- elif state.current_stage == STAGE_SHARE_KEYS:
177
+ elif state.current_stage == Stage.SHARE_KEYS:
199
178
  res = _share_keys(state, configs)
200
- elif state.current_stage == STAGE_COLLECT_MASKED_INPUT:
179
+ elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
201
180
  fit = _get_fit_fn(msg, ctxt, call_next)
202
181
  res = _collect_masked_input(state, configs, fit)
203
- elif state.current_stage == STAGE_UNMASK:
182
+ elif state.current_stage == Stage.UNMASK:
204
183
  res = _unmask(state, configs)
205
184
  else:
206
185
  raise ValueError(f"Unknown secagg stage: {state.current_stage}")
@@ -215,28 +194,29 @@ def secaggplus_mod(
215
194
 
216
195
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
217
196
  """Check the validity of the next stage."""
218
- # Check the existence of KEY_STAGE
219
- if KEY_STAGE not in configs:
197
+ # Check the existence of Config.STAGE
198
+ if Key.STAGE not in configs:
220
199
  raise KeyError(
221
- f"The required key '{KEY_STAGE}' is missing from the input `named_values`."
200
+ f"The required key '{Key.STAGE}' is missing from the input `named_values`."
222
201
  )
223
202
 
224
- # Check the value type of the KEY_STAGE
225
- next_stage = configs[KEY_STAGE]
203
+ # Check the value type of the Config.STAGE
204
+ next_stage = configs[Key.STAGE]
226
205
  if not isinstance(next_stage, str):
227
206
  raise TypeError(
228
- f"The value for the key '{KEY_STAGE}' must be of type {str}, "
207
+ f"The value for the key '{Key.STAGE}' must be of type {str}, "
229
208
  f"but got {type(next_stage)} instead."
230
209
  )
231
210
 
232
211
  # Check the validity of the next stage
233
- if next_stage == STAGE_SETUP:
234
- if current_stage != STAGE_UNMASK:
212
+ if next_stage == Stage.SETUP:
213
+ if current_stage != Stage.UNMASK:
235
214
  log(WARNING, "Restart from the setup stage")
236
215
  # If stage is not "setup",
237
216
  # the stage from `named_values` should be the expected next stage
238
217
  else:
239
- expected_next_stage = STAGES[(STAGES.index(current_stage) + 1) % len(STAGES)]
218
+ stages = Stage.all()
219
+ expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
240
220
  if next_stage != expected_next_stage:
241
221
  raise ValueError(
242
222
  "Abort secure aggregation: "
@@ -248,20 +228,20 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
248
228
  def check_configs(stage: str, configs: ConfigsRecord) -> None:
249
229
  """Check the validity of the configs."""
250
230
  # Check `named_values` for the setup stage
251
- if stage == STAGE_SETUP:
231
+ if stage == Stage.SETUP:
252
232
  key_type_pairs = [
253
- (KEY_SAMPLE_NUMBER, int),
254
- (KEY_SECURE_ID, int),
255
- (KEY_SHARE_NUMBER, int),
256
- (KEY_THRESHOLD, int),
257
- (KEY_CLIPPING_RANGE, float),
258
- (KEY_TARGET_RANGE, int),
259
- (KEY_MOD_RANGE, int),
233
+ (Key.SAMPLE_NUMBER, int),
234
+ (Key.SECURE_ID, int),
235
+ (Key.SHARE_NUMBER, int),
236
+ (Key.THRESHOLD, int),
237
+ (Key.CLIPPING_RANGE, float),
238
+ (Key.TARGET_RANGE, int),
239
+ (Key.MOD_RANGE, int),
260
240
  ]
261
241
  for key, expected_type in key_type_pairs:
262
242
  if key not in configs:
263
243
  raise KeyError(
264
- f"Stage {STAGE_SETUP}: the required key '{key}' is "
244
+ f"Stage {Stage.SETUP}: the required key '{key}' is "
265
245
  "missing from the input `named_values`."
266
246
  )
267
247
  # Bool is a subclass of int in Python,
@@ -269,11 +249,11 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
269
249
  # pylint: disable-next=unidiomatic-typecheck
270
250
  if type(configs[key]) is not expected_type:
271
251
  raise TypeError(
272
- f"Stage {STAGE_SETUP}: The value for the key '{key}' "
252
+ f"Stage {Stage.SETUP}: The value for the key '{key}' "
273
253
  f"must be of type {expected_type}, "
274
254
  f"but got {type(configs[key])} instead."
275
255
  )
276
- elif stage == STAGE_SHARE_KEYS:
256
+ elif stage == Stage.SHARE_KEYS:
277
257
  for key, value in configs.items():
278
258
  if (
279
259
  not isinstance(value, list)
@@ -282,18 +262,18 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
282
262
  or not isinstance(value[1], bytes)
283
263
  ):
284
264
  raise TypeError(
285
- f"Stage {STAGE_SHARE_KEYS}: "
265
+ f"Stage {Stage.SHARE_KEYS}: "
286
266
  f"the value for the key '{key}' must be a list of two bytes."
287
267
  )
288
- elif stage == STAGE_COLLECT_MASKED_INPUT:
268
+ elif stage == Stage.COLLECT_MASKED_INPUT:
289
269
  key_type_pairs = [
290
- (KEY_CIPHERTEXT_LIST, bytes),
291
- (KEY_SOURCE_LIST, int),
270
+ (Key.CIPHERTEXT_LIST, bytes),
271
+ (Key.SOURCE_LIST, int),
292
272
  ]
293
273
  for key, expected_type in key_type_pairs:
294
274
  if key not in configs:
295
275
  raise KeyError(
296
- f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
276
+ f"Stage {Stage.COLLECT_MASKED_INPUT}: "
297
277
  f"the required key '{key}' is "
298
278
  "missing from the input `named_values`."
299
279
  )
@@ -304,19 +284,19 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
304
284
  if type(elm) is not expected_type
305
285
  ):
306
286
  raise TypeError(
307
- f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
287
+ f"Stage {Stage.COLLECT_MASKED_INPUT}: "
308
288
  f"the value for the key '{key}' "
309
289
  f"must be of type List[{expected_type.__name__}]"
310
290
  )
311
- elif stage == STAGE_UNMASK:
291
+ elif stage == Stage.UNMASK:
312
292
  key_type_pairs = [
313
- (KEY_ACTIVE_SECURE_ID_LIST, int),
314
- (KEY_DEAD_SECURE_ID_LIST, int),
293
+ (Key.ACTIVE_SECURE_ID_LIST, int),
294
+ (Key.DEAD_SECURE_ID_LIST, int),
315
295
  ]
316
296
  for key, expected_type in key_type_pairs:
317
297
  if key not in configs:
318
298
  raise KeyError(
319
- f"Stage {STAGE_UNMASK}: "
299
+ f"Stage {Stage.UNMASK}: "
320
300
  f"the required key '{key}' is "
321
301
  "missing from the input `named_values`."
322
302
  )
@@ -327,7 +307,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
327
307
  if type(elm) is not expected_type
328
308
  ):
329
309
  raise TypeError(
330
- f"Stage {STAGE_UNMASK}: "
310
+ f"Stage {Stage.UNMASK}: "
331
311
  f"the value for the key '{key}' "
332
312
  f"must be of type List[{expected_type.__name__}]"
333
313
  )
@@ -340,15 +320,15 @@ def _setup(
340
320
  ) -> Dict[str, ConfigsRecordValues]:
341
321
  # Assigning parameter values to object fields
342
322
  sec_agg_param_dict = configs
343
- state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER])
344
- state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID])
323
+ state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
324
+ state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
345
325
  log(INFO, "Client %d: starting stage 0...", state.sid)
346
326
 
347
- state.share_num = cast(int, sec_agg_param_dict[KEY_SHARE_NUMBER])
348
- state.threshold = cast(int, sec_agg_param_dict[KEY_THRESHOLD])
349
- state.clipping_range = cast(float, sec_agg_param_dict[KEY_CLIPPING_RANGE])
350
- state.target_range = cast(int, sec_agg_param_dict[KEY_TARGET_RANGE])
351
- state.mod_range = cast(int, sec_agg_param_dict[KEY_MOD_RANGE])
327
+ state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
328
+ state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
329
+ state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
330
+ state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
331
+ state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
352
332
 
353
333
  # Dictionaries containing client secure IDs as keys
354
334
  # and their respective secret shares as values.
@@ -367,7 +347,7 @@ def _setup(
367
347
  state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
368
348
  state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
369
349
  log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
370
- return {KEY_PUBLIC_KEY_1: state.pk1, KEY_PUBLIC_KEY_2: state.pk2}
350
+ return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
371
351
 
372
352
 
373
353
  # pylint: disable-next=too-many-locals
@@ -429,7 +409,7 @@ def _share_keys(
429
409
  ciphertexts.append(ciphertext)
430
410
 
431
411
  log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
432
- return {KEY_DESTINATION_LIST: dsts, KEY_CIPHERTEXT_LIST: ciphertexts}
412
+ return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
433
413
 
434
414
 
435
415
  # pylint: disable-next=too-many-locals
@@ -440,8 +420,8 @@ def _collect_masked_input(
440
420
  ) -> Dict[str, ConfigsRecordValues]:
441
421
  log(INFO, "Client %d: starting stage 2...", state.sid)
442
422
  available_clients: List[int] = []
443
- ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST])
444
- srcs = cast(List[int], configs[KEY_SOURCE_LIST])
423
+ ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
424
+ srcs = cast(List[int], configs[Key.SOURCE_LIST])
445
425
  if len(ciphertexts) + 1 < state.threshold:
446
426
  raise ValueError("Not enough available neighbour clients.")
447
427
 
@@ -505,7 +485,7 @@ def _collect_masked_input(
505
485
  quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
506
486
  log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
507
487
  return {
508
- KEY_MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
488
+ Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
509
489
  }
510
490
 
511
491
 
@@ -514,8 +494,8 @@ def _unmask(
514
494
  ) -> Dict[str, ConfigsRecordValues]:
515
495
  log(INFO, "Client %d: starting stage 3...", state.sid)
516
496
 
517
- active_sids = cast(List[int], configs[KEY_ACTIVE_SECURE_ID_LIST])
518
- dead_sids = cast(List[int], configs[KEY_DEAD_SECURE_ID_LIST])
497
+ active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
498
+ dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
519
499
  # Send private mask seed share for every avaliable client (including itclient)
520
500
  # Send first private key share for building pairwise mask for every dropped client
521
501
  if len(active_sids) < state.threshold:
@@ -528,4 +508,4 @@ def _unmask(
528
508
  shares += [state.sk1_share_dict[sid] for sid in dead_sids]
529
509
 
530
510
  log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
531
- return {KEY_SECURE_ID_LIST: sids, KEY_SHARE_LIST: shares}
511
+ return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
@@ -14,33 +14,55 @@
14
14
  # ==============================================================================
15
15
  """Constants for the SecAgg/SecAgg+ protocol."""
16
16
 
17
+
18
+ from __future__ import annotations
19
+
17
20
  RECORD_KEY_STATE = "secaggplus_state"
18
21
  RECORD_KEY_CONFIGS = "secaggplus_configs"
19
22
 
20
- # Names of stages
21
- STAGE_SETUP = "setup"
22
- STAGE_SHARE_KEYS = "share_keys"
23
- STAGE_COLLECT_MASKED_INPUT = "collect_masked_input"
24
- STAGE_UNMASK = "unmask"
25
- STAGES = (STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_COLLECT_MASKED_INPUT, STAGE_UNMASK)
26
-
27
- # All valid keys in received/replied `named_values` dictionaries
28
- KEY_STAGE = "stage"
29
- KEY_SAMPLE_NUMBER = "sample_num"
30
- KEY_SECURE_ID = "secure_id"
31
- KEY_SHARE_NUMBER = "share_num"
32
- KEY_THRESHOLD = "threshold"
33
- KEY_CLIPPING_RANGE = "clipping_range"
34
- KEY_TARGET_RANGE = "target_range"
35
- KEY_MOD_RANGE = "mod_range"
36
- KEY_PUBLIC_KEY_1 = "pk1"
37
- KEY_PUBLIC_KEY_2 = "pk2"
38
- KEY_DESTINATION_LIST = "dsts"
39
- KEY_CIPHERTEXT_LIST = "ctxts"
40
- KEY_SOURCE_LIST = "srcs"
41
- KEY_PARAMETERS = "params"
42
- KEY_MASKED_PARAMETERS = "masked_params"
43
- KEY_ACTIVE_SECURE_ID_LIST = "active_sids"
44
- KEY_DEAD_SECURE_ID_LIST = "dead_sids"
45
- KEY_SECURE_ID_LIST = "sids"
46
- KEY_SHARE_LIST = "shares"
23
+
24
class Stage:
    """Names of the SecAgg+ protocol stages.

    The class is only a namespace of string constants; constructing it
    raises ``TypeError``.
    """

    SETUP = "setup"
    SHARE_KEYS = "share_keys"
    COLLECT_MASKED_INPUT = "collect_masked_input"
    UNMASK = "unmask"
    # Ordered tuple handed out by `all()`; callers index into it to find
    # the stage that follows a given one.
    _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)

    @classmethod
    def all(cls) -> tuple[str, str, str, str]:
        """Give back every stage name in protocol order."""
        ordered = cls._stages
        return ordered

    def __new__(cls) -> Stage:
        """Refuse construction; `Stage` merely groups constants."""
        name = cls.__name__
        raise TypeError(f"{name} cannot be instantiated.")
41
+
42
+
43
class Key:
    """String keys used for entries of the SecAgg+ ConfigsRecord.

    Every attribute is a plain ``str``; the class only groups them and
    raises ``TypeError`` if someone tries to instantiate it.
    """

    # Selects which protocol stage a message belongs to.
    STAGE = "stage"
    # Parameters validated and consumed by the setup stage.
    SAMPLE_NUMBER = "sample_num"
    SECURE_ID = "secure_id"
    SHARE_NUMBER = "share_num"
    THRESHOLD = "threshold"
    CLIPPING_RANGE = "clipping_range"
    TARGET_RANGE = "target_range"
    MOD_RANGE = "mod_range"
    # Public keys replied by the setup stage.
    PUBLIC_KEY_1 = "pk1"
    PUBLIC_KEY_2 = "pk2"
    # Destinations, ciphertexts and their sources exchanged around key sharing.
    DESTINATION_LIST = "dsts"
    CIPHERTEXT_LIST = "ctxts"
    SOURCE_LIST = "srcs"
    # Model parameters (plain and masked).
    PARAMETERS = "params"
    MASKED_PARAMETERS = "masked_params"
    # Secure-ID lists and shares used by the unmask stage.
    ACTIVE_SECURE_ID_LIST = "active_sids"
    DEAD_SECURE_ID_LIST = "dead_sids"
    SECURE_ID_LIST = "sids"
    SHARE_LIST = "shares"

    def __new__(cls) -> Key:
        """Refuse construction; `Key` merely groups constants."""
        name = cls.__name__
        raise TypeError(f"{name} cannot be instantiated.")
@@ -0,0 +1,26 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: flwr/proto/error.proto
4
+ # Protobuf Python Version: 4.25.0
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals)
22
+ if _descriptor._USE_C_DESCRIPTORS == False:
23
+ DESCRIPTOR._options = None
24
+ _globals['_ERROR']._serialized_start=38
25
+ _globals['_ERROR']._serialized_end=75
26
+ # @@protoc_insertion_point(module_scope)
@@ -0,0 +1,25 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
5
+ import builtins
6
+ import google.protobuf.descriptor
7
+ import google.protobuf.message
8
+ import typing
9
+ import typing_extensions
10
+
11
+ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
12
+
13
+ class Error(google.protobuf.message.Message):
14
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
15
+ CODE_FIELD_NUMBER: builtins.int
16
+ REASON_FIELD_NUMBER: builtins.int
17
+ code: builtins.int
18
+ reason: typing.Text
19
+ def __init__(self,
20
+ *,
21
+ code: builtins.int = ...,
22
+ reason: typing.Text = ...,
23
+ ) -> None: ...
24
+ def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ...
25
+ global___Error = Error
@@ -0,0 +1,4 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
@@ -0,0 +1,4 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
flwr/proto/task_pb2.py CHANGED
@@ -15,19 +15,20 @@ _sym_db = _symbol_database.Default()
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
17
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
+ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
18
19
 
19
20
 
20
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd4\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
22
 
22
23
  _globals = globals()
23
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
24
25
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
25
26
  if _descriptor._USE_C_DESCRIPTORS == False:
26
27
  DESCRIPTOR._options = None
27
- _globals['_TASK']._serialized_start=117
28
- _globals['_TASK']._serialized_end=329
29
- _globals['_TASKINS']._serialized_start=331
30
- _globals['_TASKINS']._serialized_end=423
31
- _globals['_TASKRES']._serialized_start=425
32
- _globals['_TASKRES']._serialized_end=517
28
+ _globals['_TASK']._serialized_start=141
29
+ _globals['_TASK']._serialized_end=387
30
+ _globals['_TASKINS']._serialized_start=389
31
+ _globals['_TASKINS']._serialized_end=481
32
+ _globals['_TASKRES']._serialized_start=483
33
+ _globals['_TASKRES']._serialized_end=575
33
34
  # @@protoc_insertion_point(module_scope)
flwr/proto/task_pb2.pyi CHANGED
@@ -3,6 +3,7 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.error_pb2
6
7
  import flwr.proto.node_pb2
7
8
  import flwr.proto.recordset_pb2
8
9
  import google.protobuf.descriptor
@@ -23,6 +24,7 @@ class Task(google.protobuf.message.Message):
23
24
  ANCESTRY_FIELD_NUMBER: builtins.int
24
25
  TASK_TYPE_FIELD_NUMBER: builtins.int
25
26
  RECORDSET_FIELD_NUMBER: builtins.int
27
+ ERROR_FIELD_NUMBER: builtins.int
26
28
  @property
27
29
  def producer(self) -> flwr.proto.node_pb2.Node: ...
28
30
  @property
@@ -35,6 +37,8 @@ class Task(google.protobuf.message.Message):
35
37
  task_type: typing.Text
36
38
  @property
37
39
  def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
40
+ @property
41
+ def error(self) -> flwr.proto.error_pb2.Error: ...
38
42
  def __init__(self,
39
43
  *,
40
44
  producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
@@ -45,9 +49,10 @@ class Task(google.protobuf.message.Message):
45
49
  ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
46
50
  task_type: typing.Text = ...,
47
51
  recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
52
+ error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
48
53
  ) -> None: ...
49
- def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
50
- def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
54
+ def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
55
+ def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
51
56
  global___Task = Task
52
57
 
53
58
  class TaskIns(google.protobuf.message.Message):
@@ -16,10 +16,17 @@
16
16
 
17
17
 
18
18
  from .bulyan import Bulyan as Bulyan
19
- from .dp_adaptive_clipping import DifferentialPrivacyClientSideAdaptiveClipping
19
+ from .dp_adaptive_clipping import (
20
+ DifferentialPrivacyClientSideAdaptiveClipping as DifferentialPrivacyClientSideAdaptiveClipping,
21
+ )
22
+ from .dp_adaptive_clipping import (
23
+ DifferentialPrivacyServerSideAdaptiveClipping as DifferentialPrivacyServerSideAdaptiveClipping,
24
+ )
25
+ from .dp_fixed_clipping import (
26
+ DifferentialPrivacyClientSideFixedClipping as DifferentialPrivacyClientSideFixedClipping,
27
+ )
20
28
  from .dp_fixed_clipping import (
21
- DifferentialPrivacyClientSideFixedClipping,
22
- DifferentialPrivacyServerSideFixedClipping,
29
+ DifferentialPrivacyServerSideFixedClipping as DifferentialPrivacyServerSideFixedClipping,
23
30
  )
24
31
  from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive
25
32
  from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed
@@ -46,6 +53,7 @@ __all__ = [
46
53
  "DPFedAvgAdaptive",
47
54
  "DPFedAvgFixed",
48
55
  "DifferentialPrivacyClientSideAdaptiveClipping",
56
+ "DifferentialPrivacyServerSideAdaptiveClipping",
49
57
  "DifferentialPrivacyClientSideFixedClipping",
50
58
  "DifferentialPrivacyServerSideFixedClipping",
51
59
  "FedAdagrad",
@@ -24,8 +24,19 @@ from typing import Dict, List, Optional, Tuple, Union
24
24
 
25
25
  import numpy as np
26
26
 
27
- from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
27
+ from flwr.common import (
28
+ EvaluateIns,
29
+ EvaluateRes,
30
+ FitIns,
31
+ FitRes,
32
+ NDArrays,
33
+ Parameters,
34
+ Scalar,
35
+ ndarrays_to_parameters,
36
+ parameters_to_ndarrays,
37
+ )
28
38
  from flwr.common.differential_privacy import (
39
+ adaptive_clip_inputs_inplace,
29
40
  add_gaussian_noise_to_params,
30
41
  compute_adaptive_noise_params,
31
42
  )
@@ -40,6 +51,199 @@ from flwr.server.client_proxy import ClientProxy
40
51
  from flwr.server.strategy.strategy import Strategy
41
52
 
42
53
 
54
class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
    """Strategy wrapper for central DP with server-side adaptive clipping.

    Parameters
    ----------
    strategy: Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
    num_sampled_clients : int
        The number of clients that are sampled on each round.
    initial_clipping_norm : float
        The initial value of clipping norm. Defaults to 0.1.
        Andrew et al. recommend setting it to 0.1.
    target_clipped_quantile : float
        The desired quantile of updates which should be clipped. Defaults to 0.5.
    clip_norm_lr : float
        The learning rate for the clipping norm adaptation. Defaults to 0.2.
        Andrew et al. recommend setting it to 0.2.
    clipped_count_stddev : Optional[float]
        The standard deviation of the noise added to the count of updates below
        the estimate. Defaults to None; the value actually used is the one
        returned by `compute_adaptive_noise_params`.
        Andrew et al. recommend setting it to `expected_num_records/20`.

    Examples
    --------
    Create a strategy:

    >>> strategy = fl.server.strategy.FedAvg( ... )

    Wrap the strategy with the DifferentialPrivacyServerSideAdaptiveClipping wrapper

    >>> dp_strategy = DifferentialPrivacyServerSideAdaptiveClipping(
    ...     strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ...
    ... )
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        strategy: Strategy,
        noise_multiplier: float,
        num_sampled_clients: int,
        initial_clipping_norm: float = 0.1,
        target_clipped_quantile: float = 0.5,
        clip_norm_lr: float = 0.2,
        clipped_count_stddev: Optional[float] = None,
    ) -> None:
        """Validate the DP hyper-parameters and wrap `strategy`.

        Raises
        ------
        ValueError
            If `strategy` is None or any hyper-parameter is out of range.
        """
        super().__init__()

        if strategy is None:
            raise ValueError("The passed strategy is None.")

        if noise_multiplier < 0:
            raise ValueError("The noise multiplier should be a non-negative value.")

        if num_sampled_clients <= 0:
            raise ValueError(
                "The number of sampled clients should be a positive value."
            )

        if initial_clipping_norm <= 0:
            raise ValueError("The initial clipping norm should be a positive value.")

        if not 0 <= target_clipped_quantile <= 1:
            raise ValueError(
                "The target clipped quantile must be between 0 and 1 (inclusive)."
            )

        if clip_norm_lr <= 0:
            raise ValueError("The learning rate must be positive.")

        if clipped_count_stddev is not None and clipped_count_stddev < 0:
            raise ValueError("The `clipped_count_stddev` must be non-negative.")

        self.strategy = strategy
        self.num_sampled_clients = num_sampled_clients
        self.clipping_norm = initial_clipping_norm
        self.target_clipped_quantile = target_clipped_quantile
        self.clip_norm_lr = clip_norm_lr
        # Both the count-noise stddev and the (possibly adjusted) noise
        # multiplier stored on the instance come from this helper.
        (
            self.clipped_count_stddev,
            self.noise_multiplier,
        ) = compute_adaptive_noise_params(
            noise_multiplier,
            num_sampled_clients,
            clipped_count_stddev,
        )

        # Global parameters sent out in the latest `configure_fit`; client
        # updates are computed relative to these.
        self.current_round_params: NDArrays = []

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
        return rep

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters using given strategy."""
        return self.strategy.initialize_parameters(client_manager)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Records the outgoing parameters so `aggregate_fit` can compute each
        client's update as a difference against them.
        """
        self.current_round_params = parameters_to_ndarrays(parameters)
        return self.strategy.configure_fit(server_round, parameters, client_manager)

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        return self.strategy.configure_evaluate(
            server_round, parameters, client_manager
        )

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate training results and update clip norms.

        Each client's update (returned parameters minus the parameters sent in
        `configure_fit`) is clipped in place to this round's clipping norm.
        The client's parameters in `results` are then rebuilt from the
        clipped update. The clipping norm is adapted geometrically from a
        noised count of the bits returned by `adaptive_clip_inputs_inplace`.
        The wrapped strategy aggregates the results, and Gaussian noise
        calibrated to the norm actually used for clipping in this round is
        added to the aggregate.

        Returns `(None, {})` when any client failed or when `results` is empty.
        """
        # An empty `results` would otherwise divide by zero in the norm update.
        if failures or not results:
            return None, {}

        if len(results) != self.num_sampled_clients:
            log(
                WARNING,
                CLIENTS_DISCREPANCY_WARNING,
                len(results),
                self.num_sampled_clients,
            )

        # Updates in this round are clipped to this norm, so it bounds their
        # sensitivity. The Gaussian noise must be calibrated to it, not to the
        # adapted norm computed further down (which only applies next round).
        round_clipping_norm = self.clipping_norm

        norm_bit_set_count = 0
        for _, res in results:
            param = parameters_to_ndarrays(res.parameters)
            # Compute and clip update
            model_update = [
                np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
            ]

            norm_bit = adaptive_clip_inputs_inplace(model_update, round_clipping_norm)
            norm_bit_set_count += norm_bit

            # Rebuild the client's parameters from the clipped update
            for i, (base, delta) in enumerate(
                zip(self.current_round_params, model_update)
            ):
                param[i] = base + delta
            # Convert back to parameters
            res.parameters = ndarrays_to_parameters(param)

        # Noising the count
        noised_norm_bit_set_count = float(
            np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
        )
        noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
        # Geometric update of the clipping norm for the next round
        self.clipping_norm *= math.exp(
            -self.clip_norm_lr
            * (noised_norm_bit_set_fraction - self.target_clipped_quantile)
        )

        aggregated_params, metrics = self.strategy.aggregate_fit(
            server_round, results, failures
        )

        # Add Gaussian noise to the aggregated parameters
        if aggregated_params:
            aggregated_params = add_gaussian_noise_to_params(
                aggregated_params,
                self.noise_multiplier,
                round_clipping_norm,
                self.num_sampled_clients,
            )

        return aggregated_params, metrics

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using the given strategy."""
        return self.strategy.aggregate_evaluate(server_round, results, failures)

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function from the strategy."""
        return self.strategy.evaluate(server_round, parameters)
245
+
246
+
43
247
  class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
44
248
  """Strategy wrapper for central DP with client-side adaptive clipping.
45
249
 
@@ -15,7 +15,7 @@
15
15
  """Driver API servicer."""
16
16
 
17
17
 
18
- from logging import INFO
18
+ from logging import DEBUG, INFO
19
19
  from typing import List, Optional, Set
20
20
  from uuid import UUID
21
21
 
@@ -70,7 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
70
70
  self, request: PushTaskInsRequest, context: grpc.ServicerContext
71
71
  ) -> PushTaskInsResponse:
72
72
  """Push a set of TaskIns."""
73
- log(INFO, "DriverServicer.PushTaskIns")
73
+ log(DEBUG, "DriverServicer.PushTaskIns")
74
74
 
75
75
  # Validate request
76
76
  _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
@@ -95,7 +95,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
95
95
  self, request: PullTaskResRequest, context: grpc.ServicerContext
96
96
  ) -> PullTaskResResponse:
97
97
  """Pull a set of TaskRes."""
98
- log(INFO, "DriverServicer.PullTaskRes")
98
+ log(DEBUG, "DriverServicer.PullTaskRes")
99
99
 
100
100
  # Convert each task_id str to UUID
101
101
  task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}
@@ -105,7 +105,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
105
105
 
106
106
  # Register callback
107
107
  def on_rpc_done() -> None:
108
- log(INFO, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
108
+ log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
109
109
 
110
110
  if context.is_active():
111
111
  return
@@ -17,6 +17,8 @@
17
17
 
18
18
  import importlib
19
19
 
20
+ from flwr.simulation.run_simulation import run_simulation
21
+
20
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
21
23
 
22
24
  if is_ray_installed:
@@ -36,4 +38,5 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
38
 
37
39
  __all__ = [
38
40
  "start_simulation",
41
+ "run_simulation",
39
42
  ]
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower Simulation."""
16
+
17
+ import argparse
18
+ import asyncio
19
+ import json
20
+ import threading
21
+ import traceback
22
+ from logging import ERROR, INFO, WARNING
23
+
24
+ import grpc
25
+
26
+ from flwr.common import EventType, event, log
27
+ from flwr.common.exit_handlers import register_exit_handlers
28
+ from flwr.server.driver.driver import Driver
29
+ from flwr.server.run_serverapp import run
30
+ from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
31
+ from flwr.server.superlink.fleet import vce
32
+ from flwr.server.superlink.state import StateFactory
33
+ from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
34
+
35
+
36
def run_simulation() -> None:
    """Run the Simulation Engine from the command line.

    Parses the ``flower-simulation`` arguments, then:

    1. starts a Driver API gRPC server backed by an in-memory state,
    2. launches the SuperLink Simulation Engine (``vce.start_vce``) in a
       background, non-daemon thread sharing that state, and
    3. runs the ``ServerApp`` in the calling thread through a ``Driver``
       connected to the Driver API.

    The stop event for the Simulation Engine is always set once the
    ``ServerApp`` finishes or fails, so the background thread can exit.

    Raises:
        RuntimeError: If creating the ``Driver`` or running the ``ServerApp``
            raises; the original exception is chained as ``__cause__``.
    """
    args = _parse_args_run_simulation().parse_args()

    # Load JSON config
    backend_config_dict = json.loads(args.backend_config)

    # Enable GPU memory growth (relevant only for TF)
    if args.enable_tf_gpu_growth:
        log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
        enable_tf_gpu_growth()
        # Check that the backend config has also enabled GPU growth; force it
        # on if not, so the ClientApps behave like the main thread.
        use_tf = backend_config_dict.get("tensorflow", False)
        if not use_tf:
            log(WARNING, "Enabling GPU growth for your backend.")
            backend_config_dict["tensorflow"] = True

    # Convert back to a JSON string, the form the Simulation Engine expects
    backend_config = json.dumps(backend_config_dict)

    # Initialize StateFactory (shared by the Driver API and the engine)
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Start Driver API
    driver_server: grpc.Server = run_driver_api_grpc(
        address=args.driver_api_address,
        state_factory=state_factory,
        certificates=None,
    )

    # SuperLink with Simulation Engine
    f_stop = asyncio.Event()
    superlink_th = threading.Thread(
        target=vce.start_vce,
        kwargs={
            "num_supernodes": args.num_supernodes,
            "client_app_module_name": args.client_app,
            "backend_name": args.backend,
            "backend_config_json_stream": backend_config,
            "working_dir": args.dir,
            "state_factory": state_factory,
            "f_stop": f_stop,
        },
        daemon=False,
    )

    superlink_th.start()
    event(EventType.RUN_SUPERLINK_ENTER)

    try:
        # Initialize Driver
        driver = Driver(
            driver_service_address=args.driver_api_address,
            root_certificates=None,
        )
        # `driver` is only bound once the constructor succeeds, so release it
        # in an inner `finally`. A `del driver` in the outer `finally` would
        # raise UnboundLocalError when `Driver(...)` fails, masking the error
        # and skipping `f_stop.set()`.
        try:
            # Launch server app
            run(args.server_app, driver, args.dir)
        finally:
            # Drop our reference so the Driver can be released before the
            # engine is stopped (presumably closing its channel -- confirm
            # against Driver's cleanup logic).
            del driver

    except Exception as ex:
        log(ERROR, "An exception occurred: %s", ex)
        log(ERROR, "%s", traceback.format_exc())
        raise RuntimeError(
            "An error was encountered by the Simulation Engine. Ending simulation."
        ) from ex

    finally:
        # Trigger stop event so the Simulation Engine thread can terminate
        f_stop.set()

    register_exit_handlers(
        grpc_servers=[driver_server],
        bckg_threads=[superlink_th],
        event_type=EventType.RUN_SUPERLINK_LEAVE,
    )
    superlink_th.join()
116
+
117
+
118
+ def _parse_args_run_simulation() -> argparse.ArgumentParser:
119
+ """Parse flower-simulation command line arguments."""
120
+ parser = argparse.ArgumentParser(
121
+ description="Start a Flower simulation",
122
+ )
123
+ parser.add_argument(
124
+ "--client-app",
125
+ required=True,
126
+ help="For example: `client:app` or `project.package.module:wrapper.app`",
127
+ )
128
+ parser.add_argument(
129
+ "--server-app",
130
+ required=True,
131
+ help="For example: `server:app` or `project.package.module:wrapper.app`",
132
+ )
133
+ parser.add_argument(
134
+ "--driver-api-address",
135
+ default="0.0.0.0:9091",
136
+ type=str,
137
+ help="For example: `server:app` or `project.package.module:wrapper.app`",
138
+ )
139
+ parser.add_argument(
140
+ "--num-supernodes",
141
+ type=int,
142
+ required=True,
143
+ help="Number of simulated SuperNodes.",
144
+ )
145
+ parser.add_argument(
146
+ "--backend",
147
+ default="ray",
148
+ type=str,
149
+ help="Simulation backend that executes the ClientApp.",
150
+ )
151
+ parser.add_argument(
152
+ "--enable-tf-gpu-growth",
153
+ action="store_true",
154
+ help="Enables GPU growth on the main thread. This is desirable if you make "
155
+ "use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
156
+ "running on the same GPU. Without enabling this, you might encounter an "
157
+ "out-of-memory error becasue TensorFlow by default allocates all GPU memory."
158
+ "Read mor about how `tf.config.experimental.set_memory_growth()` works in "
159
+ "the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
160
+ )
161
+ parser.add_argument(
162
+ "--backend-config",
163
+ type=str,
164
+ default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
165
+ help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
166
+ "configure a backend. Values supported in <value> are those included by "
167
+ "`flwr.common.typing.ConfigsRecordValues`. ",
168
+ )
169
+ parser.add_argument(
170
+ "--dir",
171
+ default="",
172
+ help="Add specified directory to the PYTHONPATH and load"
173
+ "ClientApp and ServerApp from there."
174
+ " Default: current working directory.",
175
+ )
176
+
177
+ return parser
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.8.0.dev20240228
3
+ Version: 1.8.0.dev20240229
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -32,7 +32,7 @@ flwr/client/message_handler/task_handler.py,sha256=ZDJBKmrn2grRMNl1rU1iGs7FiMHL5
32
32
  flwr/client/mod/__init__.py,sha256=w6r7n6fWIrrm4lEk36lh9f1Ix6LXgAzQUrgjmMspY98,961
33
33
  flwr/client/mod/centraldp_mods.py,sha256=aHbzGjSbyRENuU5vzad_tkJ9UDb48uHEvUq-zgydBwo,4954
34
34
  flwr/client/mod/secure_aggregation/__init__.py,sha256=AzCdezuzX2BfXUuxVRwXdv8-zUIXoU-Bf6u4LRhzvg8,796
35
- flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=z_5t1YzqLs91ZLW5Yoo7Ozqw9_nyVuEpJ7Noa2a34bs,19890
35
+ flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=Zc5b2C58SYeyQVXIbLRESOIq4rMUOuzMHNAjdSAPt6I,19434
36
36
  flwr/client/mod/utils.py,sha256=lvETHcCYsSWz7h8I772hCV_kZspxqlMqzriMZ-SxmKc,1226
37
37
  flwr/client/node_state.py,sha256=KTTs_l4I0jBM7IsSsbAGjhfL_yZC3QANbzyvyfZBRDM,1778
38
38
  flwr/client/node_state_tests.py,sha256=gPwz0zf2iuDSa11jedkur_u3Xm7lokIDG5ALD2MCvSw,2195
@@ -68,7 +68,7 @@ flwr/common/secure_aggregation/crypto/shamir.py,sha256=yY35ZgHlB4YyGW_buG-1X-0M-
68
68
  flwr/common/secure_aggregation/crypto/symmetric_encryption.py,sha256=-zDyQoTsHHQjR7o-92FNIikg1zM_Ke9yynaD5u2BXbQ,3546
69
69
  flwr/common/secure_aggregation/ndarrays_arithmetic.py,sha256=KAHCEHGSTJ6mCgnC8dTpIx6URk11s5XxWTHI_7ToGIg,2979
70
70
  flwr/common/secure_aggregation/quantization.py,sha256=appui7GGrkRPsupF59TkapeV4Na_CyPi73JtJ1pimdI,2310
71
- flwr/common/secure_aggregation/secaggplus_constants.py,sha256=pITi-nuzrNvKWR42XwVFBuejv1RdGLwmuErLp0X0t_Y,1686
71
+ flwr/common/secure_aggregation/secaggplus_constants.py,sha256=2dYMKqiO2ja8DewFeZUAGqM0xujNfyYHVwShRxeRaIM,2132
72
72
  flwr/common/secure_aggregation/secaggplus_utils.py,sha256=PleDyDu7jHNAfbRoEaoQiOjxG6iMl9yA8rNKYTfnyFw,3155
73
73
  flwr/common/serde.py,sha256=0tmfTcywJVLA7Hsu4nAjMn2dVMVbjYZqePJbPcDt01Y,20726
74
74
  flwr/common/telemetry.py,sha256=JkFB6WBOskqAJfzSM-l6tQfRiSi2oiysClfg0-5T7NY,7782
@@ -79,6 +79,10 @@ flwr/proto/driver_pb2.py,sha256=JHIdjNPTgp6YHD-_lz5ZZFB0VIOR3_GmcaOTN4jndc4,3115
79
79
  flwr/proto/driver_pb2.pyi,sha256=xwl2AqIWn0SwAlg-x5RUQeqr6DC48eywnqmD7gbaaFs,4670
80
80
  flwr/proto/driver_pb2_grpc.py,sha256=qQBRdQUz4k2K4DVO7kSfWHx-62UJ85HaYKnKCr6JcU8,7304
81
81
  flwr/proto/driver_pb2_grpc.pyi,sha256=NpOM5eCrIPcuWdYrZAayQSDvvFp6cDCVflabhmuvMfo,2022
82
+ flwr/proto/error_pb2.py,sha256=LarjKL90LbwkXKlhzNrDssgl4DXcvIPve8NVCXHpsKA,1084
83
+ flwr/proto/error_pb2.pyi,sha256=ZNH4HhJTU_KfMXlyCeg8FwU-fcUYxTqEmoJPtWtHikc,734
84
+ flwr/proto/error_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
85
+ flwr/proto/error_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
82
86
  flwr/proto/fleet_pb2.py,sha256=8rKQHu6Oa9ki_NG6kRNGtfPPYZp5kKBZhPW696_kn84,3852
83
87
  flwr/proto/fleet_pb2.pyi,sha256=QXYs9M7_dABghdCMfk5Rjf4w0LsZGDeQ1ojH00XaQME,6182
84
88
  flwr/proto/fleet_pb2_grpc.py,sha256=hF1uPaioZzQMRCP9yPlv9LC0mi_DTuhn-IkQJzWIPCs,7505
@@ -91,8 +95,8 @@ flwr/proto/recordset_pb2.py,sha256=un8L0kvBcgFXQIiQweOseeIJBjlOozUvQY9uTQ42Dqo,6
91
95
  flwr/proto/recordset_pb2.pyi,sha256=NPzCJWAj1xLWzeZ_xZ6uaObQjQfWGnnqlLtn4J-SoFY,14161
92
96
  flwr/proto/recordset_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
93
97
  flwr/proto/recordset_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
94
- flwr/proto/task_pb2.py,sha256=0iZTgNnrypfuO9AqjXqvV79VGaS9lkpi6oohMFeoTso,2272
95
- flwr/proto/task_pb2.pyi,sha256=k3Bulbd7pbYszemI5ntWJtuaYlSz_wnhhBZSyGSwjF8,3937
98
+ flwr/proto/task_pb2.py,sha256=-UX3TqskOIRbPu8U3YwgW9ul2k9ZN6MJGgbIOX3pTqg,2431
99
+ flwr/proto/task_pb2.pyi,sha256=IgXggFya0RpL64z2o2K_qLnZHyZ1mg_WzLxFwEKrpL0,4171
96
100
  flwr/proto/task_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
97
101
  flwr/proto/task_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
98
102
  flwr/proto/transport_pb2.py,sha256=cURzfpCgZvH7GEvBPLvTYijE3HvhK1MePjINk4xYArk,9781
@@ -118,10 +122,10 @@ flwr/server/run_serverapp.py,sha256=7LLE1cVQz0Rl-hZnY7DLXvFxWCdep8xgLgEVC-yffi0,
118
122
  flwr/server/server.py,sha256=JLc2lg-qCchD-Jyg_hTBZQN3rXsnfAGx6qJAo0vqH2Y,17812
119
123
  flwr/server/server_app.py,sha256=avNQ7AMMKsn09ly81C3UBgOfHhM_R29l4MrzlalGoj8,5892
120
124
  flwr/server/server_config.py,sha256=yOHpkdyuhOm--Gy_4Vofvu6jCDxhyECEDpIy02beuCg,1018
121
- flwr/server/strategy/__init__.py,sha256=xKnjkTQQT4PjoCXrUXW4LJ85H_iXNszKZ4u6W__VWp8,2460
125
+ flwr/server/strategy/__init__.py,sha256=7eVZ3hQEg2BgA_usAeL6tsLp9T6XI1VYYoFy08Xn-ew,2836
122
126
  flwr/server/strategy/aggregate.py,sha256=QyRIJtI5gnuY1NbgrcrOvkHxGIxBvApq7d9Y4xl-6W4,13468
123
127
  flwr/server/strategy/bulyan.py,sha256=8GsSVJzRSoSWE2zQUKqC3Z795grdN9xpmc3MSGGXnzM,6532
124
- flwr/server/strategy/dp_adaptive_clipping.py,sha256=DIfQA4PDXNFpEYKTf5aHFzvNn7t-_6Bv1MZdKGrJsc8,9261
128
+ flwr/server/strategy/dp_adaptive_clipping.py,sha256=BVvX1LivyukvEtOZVZnpgVpzH8BBjvA3OmdGwFxgRuQ,16679
125
129
  flwr/server/strategy/dp_fixed_clipping.py,sha256=v9YyX53jt2RatGnFxTK4ZMO_3SN7EdL9YCaaJtn9Fcc,12125
126
130
  flwr/server/strategy/dpfedavg_adaptive.py,sha256=hLJkPQJl1bHjwrBNg3PSRFKf3no0hg5EHiFaWhHlWqw,4877
127
131
  flwr/server/strategy/dpfedavg_fixed.py,sha256=G0yYxrPoM-MHQ889DYN3OeNiEeU0yQrjgAzcq0G653w,7219
@@ -145,7 +149,7 @@ flwr/server/strategy/strategy.py,sha256=g6VoIFogEviRub6G4QsKdIp6M_Ek6GhBhqcdNx5u
145
149
  flwr/server/superlink/__init__.py,sha256=8tHYCfodUlRD8PCP9fHgvu8cz5N31A2QoRVL0jDJ15E,707
146
150
  flwr/server/superlink/driver/__init__.py,sha256=STB1_DASVEg7Cu6L7VYxTzV7UMkgtBkFim09Z82Dh8I,712
147
151
  flwr/server/superlink/driver/driver_grpc.py,sha256=1qSGDs1k_OVPWxp2ofxvQgtYXExrMeC3N_rNPVWH65M,1932
148
- flwr/server/superlink/driver/driver_servicer.py,sha256=69-QMkdefgd8BDDAuy4TSJHVcmZAGqq8Yy9ZPDq2kks,4563
152
+ flwr/server/superlink/driver/driver_servicer.py,sha256=p1rlocgqU2I4s5IwdU8rTZBkQ73yPmuWtK_aUlB7V84,4573
149
153
  flwr/server/superlink/fleet/__init__.py,sha256=C6GCSD5eP5Of6_dIeSe1jx9HnV0icsvWyQ5EKAUHJRU,711
150
154
  flwr/server/superlink/fleet/grpc_bidi/__init__.py,sha256=mgGJGjwT6VU7ovC1gdnnqttjyBPlNIcZnYRqx4K3IBQ,735
151
155
  flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py,sha256=57b3UL5-baGdLwgCtB0dCUTTSbmmfMAXcXV5bjPZNWQ,5993
@@ -174,14 +178,15 @@ flwr/server/utils/tensorboard.py,sha256=k0G6bqsLx7wfYbH2KtXsDYcOCfyIeE12-hefXA7l
174
178
  flwr/server/utils/validator.py,sha256=IJN2475yyD_i_9kg_SJ_JodIuZh58ufpWGUDQRAqu2s,4740
175
179
  flwr/server/workflow/__init__.py,sha256=2YrKq5wUwge8Tm1xaAdf2P3l4LbM4olka6tO_0_Mu9A,787
176
180
  flwr/server/workflow/default_workflows.py,sha256=DKSt14WY5m19ujwh6UDP4a31kRs-j6V_NZm5cXn01ZY,12705
177
- flwr/simulation/__init__.py,sha256=E2eD5FlTmZZ80u21FmWCkacrM7O4mrEHD8iXqeCaBUQ,1278
181
+ flwr/simulation/__init__.py,sha256=jdrJeTnLLj9Eyl8BRPXMewqkhTnxD7fvXDgyjfspy0Q,1359
178
182
  flwr/simulation/app.py,sha256=WqJxdXTEuehwMW605p5NMmvBbKYx5tuqnV3Mp7jSWXM,13904
179
183
  flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACkfeIam55BvW9g,734
180
184
  flwr/simulation/ray_transport/ray_actor.py,sha256=zRETW_xuCAOLRFaYnQ-q3IBSz0LIv_0RifGuhgWaYOg,19872
181
185
  flwr/simulation/ray_transport/ray_client_proxy.py,sha256=DpmrBC87_sX3J4WrrwzyEDIjONUeliBZx9T-gZGuPmQ,6799
182
186
  flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
183
- flwr_nightly-1.8.0.dev20240228.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
- flwr_nightly-1.8.0.dev20240228.dist-info/METADATA,sha256=ftdjOdLrEKow1vhHw6KFjaP741hpdqxPz2TN75s_ykc,15184
185
- flwr_nightly-1.8.0.dev20240228.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
186
- flwr_nightly-1.8.0.dev20240228.dist-info/entry_points.txt,sha256=S1zLNFLrz0uPWs4Zrgo2EPY0iQiIcCJHrIAlnQkkOBI,262
187
- flwr_nightly-1.8.0.dev20240228.dist-info/RECORD,,
187
+ flwr/simulation/run_simulation.py,sha256=NYUFJ6cG5QtuwEl6f2IWJNMRo_jDmmA41DABqISpq-Q,5950
188
+ flwr_nightly-1.8.0.dev20240229.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
189
+ flwr_nightly-1.8.0.dev20240229.dist-info/METADATA,sha256=bnkE4rD-5EY00NUOdT-ZzoeuextoRXX3zZL0MLR6PbU,15184
190
+ flwr_nightly-1.8.0.dev20240229.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
191
+ flwr_nightly-1.8.0.dev20240229.dist-info/entry_points.txt,sha256=qz6t0YdMrV_PLbEarJ6ITIWSIRrTbG_jossZAYfXBZQ,311
192
+ flwr_nightly-1.8.0.dev20240229.dist-info/RECORD,,
@@ -3,6 +3,7 @@ flower-client-app=flwr.client:run_client_app
3
3
  flower-driver-api=flwr.server:run_driver_api
4
4
  flower-fleet-api=flwr.server:run_fleet_api
5
5
  flower-server-app=flwr.server:run_server_app
6
+ flower-simulation=flwr.simulation:run_simulation
6
7
  flower-superlink=flwr.server:run_superlink
7
8
  flwr=flwr.cli.app:app
8
9