flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -72,13 +72,14 @@ class SecAggPlusState:
72
72
 
73
73
  current_stage: str = Stage.UNMASK
74
74
 
75
- sid: int = 0
75
+ nid: int = 0
76
76
  sample_num: int = 0
77
77
  share_num: int = 0
78
78
  threshold: int = 0
79
79
  clipping_range: float = 0.0
80
80
  target_range: int = 0
81
81
  mod_range: int = 0
82
+ max_weight: float = 0.0
82
83
 
83
84
  # Secret key (sk) and public key (pk)
84
85
  sk1: bytes = b""
@@ -173,12 +174,13 @@ def secaggplus_mod(
173
174
 
174
175
  # Execute
175
176
  if state.current_stage == Stage.SETUP:
177
+ state.nid = msg.metadata.dst_node_id
176
178
  res = _setup(state, configs)
177
179
  elif state.current_stage == Stage.SHARE_KEYS:
178
180
  res = _share_keys(state, configs)
179
- elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
181
+ elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
180
182
  fit = _get_fit_fn(msg, ctxt, call_next)
181
- res = _collect_masked_input(state, configs, fit)
183
+ res = _collect_masked_vectors(state, configs, fit)
182
184
  elif state.current_stage == Stage.UNMASK:
183
185
  res = _unmask(state, configs)
184
186
  else:
@@ -197,7 +199,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
197
199
  # Check the existence of Config.STAGE
198
200
  if Key.STAGE not in configs:
199
201
  raise KeyError(
200
- f"The required key '{Key.STAGE}' is missing from the input `named_values`."
202
+ f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
201
203
  )
202
204
 
203
205
  # Check the value type of the Config.STAGE
@@ -213,7 +215,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
213
215
  if current_stage != Stage.UNMASK:
214
216
  log(WARNING, "Restart from the setup stage")
215
217
  # If stage is not "setup",
216
- # the stage from `named_values` should be the expected next stage
218
+ # the stage from configs should be the expected next stage
217
219
  else:
218
220
  stages = Stage.all()
219
221
  expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
@@ -227,11 +229,10 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
227
229
  # pylint: disable-next=too-many-branches
228
230
  def check_configs(stage: str, configs: ConfigsRecord) -> None:
229
231
  """Check the validity of the configs."""
230
- # Check `named_values` for the setup stage
232
+ # Check configs for the setup stage
231
233
  if stage == Stage.SETUP:
232
234
  key_type_pairs = [
233
235
  (Key.SAMPLE_NUMBER, int),
234
- (Key.SECURE_ID, int),
235
236
  (Key.SHARE_NUMBER, int),
236
237
  (Key.THRESHOLD, int),
237
238
  (Key.CLIPPING_RANGE, float),
@@ -242,7 +243,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
242
243
  if key not in configs:
243
244
  raise KeyError(
244
245
  f"Stage {Stage.SETUP}: the required key '{key}' is "
245
- "missing from the input `named_values`."
246
+ "missing from the ConfigsRecord."
246
247
  )
247
248
  # Bool is a subclass of int in Python,
248
249
  # so `isinstance(v, int)` will return True even if v is a boolean.
@@ -265,7 +266,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
265
266
  f"Stage {Stage.SHARE_KEYS}: "
266
267
  f"the value for the key '{key}' must be a list of two bytes."
267
268
  )
268
- elif stage == Stage.COLLECT_MASKED_INPUT:
269
+ elif stage == Stage.COLLECT_MASKED_VECTORS:
269
270
  key_type_pairs = [
270
271
  (Key.CIPHERTEXT_LIST, bytes),
271
272
  (Key.SOURCE_LIST, int),
@@ -273,9 +274,9 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
273
274
  for key, expected_type in key_type_pairs:
274
275
  if key not in configs:
275
276
  raise KeyError(
276
- f"Stage {Stage.COLLECT_MASKED_INPUT}: "
277
+ f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
277
278
  f"the required key '{key}' is "
278
- "missing from the input `named_values`."
279
+ "missing from the ConfigsRecord."
279
280
  )
280
281
  if not isinstance(configs[key], list) or any(
281
282
  elm
@@ -284,21 +285,21 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
284
285
  if type(elm) is not expected_type
285
286
  ):
286
287
  raise TypeError(
287
- f"Stage {Stage.COLLECT_MASKED_INPUT}: "
288
+ f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
288
289
  f"the value for the key '{key}' "
289
290
  f"must be of type List[{expected_type.__name__}]"
290
291
  )
291
292
  elif stage == Stage.UNMASK:
292
293
  key_type_pairs = [
293
- (Key.ACTIVE_SECURE_ID_LIST, int),
294
- (Key.DEAD_SECURE_ID_LIST, int),
294
+ (Key.ACTIVE_NODE_ID_LIST, int),
295
+ (Key.DEAD_NODE_ID_LIST, int),
295
296
  ]
296
297
  for key, expected_type in key_type_pairs:
297
298
  if key not in configs:
298
299
  raise KeyError(
299
300
  f"Stage {Stage.UNMASK}: "
300
301
  f"the required key '{key}' is "
301
- "missing from the input `named_values`."
302
+ "missing from the ConfigsRecord."
302
303
  )
303
304
  if not isinstance(configs[key], list) or any(
304
305
  elm
@@ -321,20 +322,20 @@ def _setup(
321
322
  # Assigning parameter values to object fields
322
323
  sec_agg_param_dict = configs
323
324
  state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
324
- state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
325
- log(INFO, "Client %d: starting stage 0...", state.sid)
325
+ log(INFO, "Node %d: starting stage 0...", state.nid)
326
326
 
327
327
  state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
328
328
  state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
329
329
  state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
330
330
  state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
331
331
  state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
332
+ state.max_weight = cast(float, sec_agg_param_dict[Key.MAX_WEIGHT])
332
333
 
333
- # Dictionaries containing client secure IDs as keys
334
+ # Dictionaries containing node IDs as keys
334
335
  # and their respective secret shares as values.
335
336
  state.rd_seed_share_dict = {}
336
337
  state.sk1_share_dict = {}
337
- # Dictionary containing client secure IDs as keys
338
+ # Dictionary containing node IDs as keys
338
339
  # and their respective shared secrets (with this client) as values.
339
340
  state.ss2_dict = {}
340
341
 
@@ -346,7 +347,7 @@ def _setup(
346
347
 
347
348
  state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
348
349
  state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
349
- log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
350
+ log(INFO, "Node %d: stage 0 completes. uploading public keys...", state.nid)
350
351
  return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
351
352
 
352
353
 
@@ -356,7 +357,7 @@ def _share_keys(
356
357
  ) -> Dict[str, ConfigsRecordValues]:
357
358
  named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
358
359
  key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
359
- log(INFO, "Client %d: starting stage 1...", state.sid)
360
+ log(INFO, "Node %d: starting stage 1...", state.nid)
360
361
  state.public_keys_dict = key_dict
361
362
 
362
363
  # Check if the size is larger than threshold
@@ -373,8 +374,8 @@ def _share_keys(
373
374
 
374
375
  # Check if public keys of this client are correct in the dictionary
375
376
  if (
376
- state.public_keys_dict[state.sid][0] != state.pk1
377
- or state.public_keys_dict[state.sid][1] != state.pk2
377
+ state.public_keys_dict[state.nid][0] != state.pk1
378
+ or state.public_keys_dict[state.nid][1] != state.pk2
378
379
  ):
379
380
  raise ValueError(
380
381
  "Own public keys are displayed in dict incorrectly, should not happen!"
@@ -390,35 +391,35 @@ def _share_keys(
390
391
  srcs, dsts, ciphertexts = [], [], []
391
392
 
392
393
  # Distribute shares
393
- for idx, (sid, (_, pk2)) in enumerate(state.public_keys_dict.items()):
394
- if sid == state.sid:
395
- state.rd_seed_share_dict[state.sid] = b_shares[idx]
396
- state.sk1_share_dict[state.sid] = sk1_shares[idx]
394
+ for idx, (nid, (_, pk2)) in enumerate(state.public_keys_dict.items()):
395
+ if nid == state.nid:
396
+ state.rd_seed_share_dict[state.nid] = b_shares[idx]
397
+ state.sk1_share_dict[state.nid] = sk1_shares[idx]
397
398
  else:
398
399
  shared_key = generate_shared_key(
399
400
  bytes_to_private_key(state.sk2),
400
401
  bytes_to_public_key(pk2),
401
402
  )
402
- state.ss2_dict[sid] = shared_key
403
+ state.ss2_dict[nid] = shared_key
403
404
  plaintext = share_keys_plaintext_concat(
404
- state.sid, sid, b_shares[idx], sk1_shares[idx]
405
+ state.nid, nid, b_shares[idx], sk1_shares[idx]
405
406
  )
406
407
  ciphertext = encrypt(shared_key, plaintext)
407
- srcs.append(state.sid)
408
- dsts.append(sid)
408
+ srcs.append(state.nid)
409
+ dsts.append(nid)
409
410
  ciphertexts.append(ciphertext)
410
411
 
411
- log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
412
+ log(INFO, "Node %d: stage 1 completes. uploading key shares...", state.nid)
412
413
  return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
413
414
 
414
415
 
415
416
  # pylint: disable-next=too-many-locals
416
- def _collect_masked_input(
417
+ def _collect_masked_vectors(
417
418
  state: SecAggPlusState,
418
419
  configs: ConfigsRecord,
419
420
  fit: Callable[[], FitRes],
420
421
  ) -> Dict[str, ConfigsRecordValues]:
421
- log(INFO, "Client %d: starting stage 2...", state.sid)
422
+ log(INFO, "Node %d: starting stage 2...", state.nid)
422
423
  available_clients: List[int] = []
423
424
  ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
424
425
  srcs = cast(List[int], configs[Key.SOURCE_LIST])
@@ -435,29 +436,45 @@ def _collect_masked_input(
435
436
  available_clients.append(src)
436
437
  if src != actual_src:
437
438
  raise ValueError(
438
- f"Client {state.sid}: received ciphertext "
439
+ f"Node {state.nid}: received ciphertext "
439
440
  f"from {actual_src} instead of {src}."
440
441
  )
441
- if dst != state.sid:
442
+ if dst != state.nid:
442
443
  raise ValueError(
443
- f"Client {state.sid}: received an encrypted message"
444
- f"for Client {dst} from Client {src}."
444
+ f"Node {state.nid}: received an encrypted message"
445
+ f"for Node {dst} from Node {src}."
445
446
  )
446
447
  state.rd_seed_share_dict[src] = rd_seed_share
447
448
  state.sk1_share_dict[src] = sk1_share
448
449
 
449
450
  # Fit client
450
451
  fit_res = fit()
451
- parameters_factor = fit_res.num_examples
452
+ if len(fit_res.metrics) > 0:
453
+ log(
454
+ WARNING,
455
+ "The metrics in FitRes will not be preserved or sent to the server.",
456
+ )
457
+ ratio = fit_res.num_examples / state.max_weight
458
+ if ratio > 1:
459
+ log(
460
+ WARNING,
461
+ "Potential overflow warning: the provided weight (%s) exceeds the specified"
462
+ " max_weight (%s). This may lead to overflow issues.",
463
+ fit_res.num_examples,
464
+ state.max_weight,
465
+ )
466
+ q_ratio = round(ratio * state.target_range)
467
+ dq_ratio = q_ratio / state.target_range
468
+
452
469
  parameters = parameters_to_ndarrays(fit_res.parameters)
470
+ parameters = parameters_multiply(parameters, dq_ratio)
453
471
 
454
472
  # Quantize parameter update (vector)
455
473
  quantized_parameters = quantize(
456
474
  parameters, state.clipping_range, state.target_range
457
475
  )
458
476
 
459
- quantized_parameters = parameters_multiply(quantized_parameters, parameters_factor)
460
- quantized_parameters = factor_combine(parameters_factor, quantized_parameters)
477
+ quantized_parameters = factor_combine(q_ratio, quantized_parameters)
461
478
 
462
479
  dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters]
463
480
 
@@ -465,14 +482,14 @@ def _collect_masked_input(
465
482
  private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list)
466
483
  quantized_parameters = parameters_addition(quantized_parameters, private_mask)
467
484
 
468
- for client_id in available_clients:
485
+ for node_id in available_clients:
469
486
  # Add pairwise masks
470
487
  shared_key = generate_shared_key(
471
488
  bytes_to_private_key(state.sk1),
472
- bytes_to_public_key(state.public_keys_dict[client_id][0]),
489
+ bytes_to_public_key(state.public_keys_dict[node_id][0]),
473
490
  )
474
491
  pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list)
475
- if state.sid > client_id:
492
+ if state.nid > node_id:
476
493
  quantized_parameters = parameters_addition(
477
494
  quantized_parameters, pairwise_mask
478
495
  )
@@ -483,7 +500,7 @@ def _collect_masked_input(
483
500
 
484
501
  # Take mod of final weight update vector and return to server
485
502
  quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
486
- log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
503
+ log(INFO, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
487
504
  return {
488
505
  Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
489
506
  }
@@ -492,20 +509,19 @@ def _collect_masked_input(
492
509
  def _unmask(
493
510
  state: SecAggPlusState, configs: ConfigsRecord
494
511
  ) -> Dict[str, ConfigsRecordValues]:
495
- log(INFO, "Client %d: starting stage 3...", state.sid)
512
+ log(INFO, "Node %d: starting stage 3...", state.nid)
496
513
 
497
- active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
498
- dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
499
- # Send private mask seed share for every avaliable client (including itclient)
514
+ active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST])
515
+ dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST])
516
+ # Send private mask seed share for every avaliable client (including itself)
500
517
  # Send first private key share for building pairwise mask for every dropped client
501
- if len(active_sids) < state.threshold:
518
+ if len(active_nids) < state.threshold:
502
519
  raise ValueError("Available neighbours number smaller than threshold")
503
520
 
504
- sids, shares = [], []
505
- sids += active_sids
506
- shares += [state.rd_seed_share_dict[sid] for sid in active_sids]
507
- sids += dead_sids
508
- shares += [state.sk1_share_dict[sid] for sid in dead_sids]
521
+ all_nids, shares = [], []
522
+ all_nids = active_nids + dead_nids
523
+ shares += [state.rd_seed_share_dict[nid] for nid in active_nids]
524
+ shares += [state.sk1_share_dict[nid] for nid in dead_nids]
509
525
 
510
- log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
511
- return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
526
+ log(INFO, "Node %d: stage 3 completes. uploading key shares...", state.nid)
527
+ return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}
flwr/common/grpc.py CHANGED
@@ -15,7 +15,7 @@
15
15
  """Utility functions for gRPC."""
16
16
 
17
17
 
18
- from logging import INFO
18
+ from logging import DEBUG
19
19
  from typing import Optional
20
20
 
21
21
  import grpc
@@ -49,12 +49,12 @@ def create_channel(
49
49
 
50
50
  if insecure:
51
51
  channel = grpc.insecure_channel(server_address, options=channel_options)
52
- log(INFO, "Opened insecure gRPC connection (no certificates were passed)")
52
+ log(DEBUG, "Opened insecure gRPC connection (no certificates were passed)")
53
53
  else:
54
54
  ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates)
55
55
  channel = grpc.secure_channel(
56
56
  server_address, ssl_channel_credentials, options=channel_options
57
57
  )
58
- log(INFO, "Opened secure gRPC connection using certificates")
58
+ log(DEBUG, "Opened secure gRPC connection using certificates")
59
59
 
60
60
  return channel
flwr/common/logger.py CHANGED
@@ -18,21 +18,86 @@
18
18
  import logging
19
19
  from logging import WARN, LogRecord
20
20
  from logging.handlers import HTTPHandler
21
- from typing import Any, Dict, Optional, Tuple
21
+ from typing import TYPE_CHECKING, Any, Dict, Optional, TextIO, Tuple
22
22
 
23
23
  # Create logger
24
24
  LOGGER_NAME = "flwr"
25
25
  FLOWER_LOGGER = logging.getLogger(LOGGER_NAME)
26
26
  FLOWER_LOGGER.setLevel(logging.DEBUG)
27
27
 
28
- DEFAULT_FORMATTER = logging.Formatter(
29
- "%(levelname)s %(name)s %(asctime)s | %(filename)s:%(lineno)d | %(message)s"
30
- )
28
+ LOG_COLORS = {
29
+ "DEBUG": "\033[94m", # Blue
30
+ "INFO": "\033[92m", # Green
31
+ "WARNING": "\033[93m", # Yellow
32
+ "ERROR": "\033[91m", # Red
33
+ "CRITICAL": "\033[95m", # Magenta
34
+ "RESET": "\033[0m", # Reset to default
35
+ }
36
+
37
+ if TYPE_CHECKING:
38
+ StreamHandler = logging.StreamHandler[Any]
39
+ else:
40
+ StreamHandler = logging.StreamHandler
41
+
42
+
43
+ class ConsoleHandler(StreamHandler):
44
+ """Console handler that allows configurable formatting."""
45
+
46
+ def __init__(
47
+ self,
48
+ timestamps: bool = False,
49
+ json: bool = False,
50
+ colored: bool = True,
51
+ stream: Optional[TextIO] = None,
52
+ ) -> None:
53
+ super().__init__(stream)
54
+ self.timestamps = timestamps
55
+ self.json = json
56
+ self.colored = colored
57
+
58
+ def emit(self, record: LogRecord) -> None:
59
+ """Emit a record."""
60
+ if self.json:
61
+ record.message = record.getMessage().replace("\t", "").strip()
62
+
63
+ # Check if the message is empty
64
+ if not record.message:
65
+ return
66
+
67
+ super().emit(record)
68
+
69
+ def format(self, record: LogRecord) -> str:
70
+ """Format function that adds colors to log level."""
71
+ seperator = " " * (8 - len(record.levelname))
72
+ if self.json:
73
+ log_fmt = "{lvl='%(levelname)s', time='%(asctime)s', msg='%(message)s'}"
74
+ else:
75
+ log_fmt = (
76
+ f"{LOG_COLORS[record.levelname] if self.colored else ''}"
77
+ f"%(levelname)s {'%(asctime)s' if self.timestamps else ''}"
78
+ f"{LOG_COLORS['RESET'] if self.colored else ''}"
79
+ f": {seperator} %(message)s"
80
+ )
81
+ formatter = logging.Formatter(log_fmt)
82
+ return formatter.format(record)
83
+
84
+
85
+ def update_console_handler(level: int, timestamps: bool, colored: bool) -> None:
86
+ """Update the logging handler."""
87
+ for handler in logging.getLogger(LOGGER_NAME).handlers:
88
+ if isinstance(handler, ConsoleHandler):
89
+ handler.setLevel(level)
90
+ handler.timestamps = timestamps
91
+ handler.colored = colored
92
+
31
93
 
32
94
  # Configure console logger
33
- console_handler = logging.StreamHandler()
34
- console_handler.setLevel(logging.DEBUG)
35
- console_handler.setFormatter(DEFAULT_FORMATTER)
95
+ console_handler = ConsoleHandler(
96
+ timestamps=False,
97
+ json=False,
98
+ colored=True,
99
+ )
100
+ console_handler.setLevel(logging.INFO)
36
101
  FLOWER_LOGGER.addHandler(console_handler)
37
102
 
38
103
 
@@ -103,11 +168,10 @@ def warn_experimental_feature(name: str) -> None:
103
168
  """Warn the user when they use an experimental feature."""
104
169
  log(
105
170
  WARN,
106
- """
107
- EXPERIMENTAL FEATURE: %s
171
+ """EXPERIMENTAL FEATURE: %s
108
172
 
109
- This is an experimental feature. It could change significantly or be removed
110
- entirely in future versions of Flower.
173
+ This is an experimental feature. It could change significantly or be removed
174
+ entirely in future versions of Flower.
111
175
  """,
112
176
  name,
113
177
  )
@@ -117,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None:
117
181
  """Warn the user when they use a deprecated feature."""
118
182
  log(
119
183
  WARN,
120
- """
121
- DEPRECATED FEATURE: %s
184
+ """DEPRECATED FEATURE: %s
122
185
 
123
- This is a deprecated feature. It will be removed
124
- entirely in future versions of Flower.
186
+ This is a deprecated feature. It will be removed
187
+ entirely in future versions of Flower.
125
188
  """,
126
189
  name,
127
190
  )
@@ -0,0 +1,140 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Helper functions to load objects from a reference."""
16
+
17
+
18
+ import ast
19
+ import importlib
20
+ from importlib.util import find_spec
21
+ from typing import Any, Optional, Tuple, Type
22
+
23
+ OBJECT_REF_HELP_STR = """
24
+ \n\nThe object reference string should have the form <module>:<attribute>. Valid
25
+ examples include `client:app` and `project.package.module:wrapper.app`. It must
26
+ refer to a module on the PYTHONPATH and the module needs to have the specified
27
+ attribute.
28
+ """
29
+
30
+
31
+ def validate(
32
+ module_attribute_str: str,
33
+ ) -> Tuple[bool, Optional[str]]:
34
+ """Validate object reference.
35
+
36
+ The object reference string should have the form <module>:<attribute>. Valid
37
+ examples include `client:app` and `project.package.module:wrapper.app`. It must
38
+ refer to a module on the PYTHONPATH and the module needs to have the specified
39
+ attribute.
40
+
41
+ Returns
42
+ -------
43
+ Tuple[bool, Optional[str]]
44
+ A boolean indicating whether an object reference is valid and
45
+ the reason why it might not be.
46
+ """
47
+ module_str, _, attributes_str = module_attribute_str.partition(":")
48
+ if not module_str:
49
+ return (
50
+ False,
51
+ f"Missing module in {module_attribute_str}{OBJECT_REF_HELP_STR}",
52
+ )
53
+ if not attributes_str:
54
+ return (
55
+ False,
56
+ f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}",
57
+ )
58
+
59
+ # Load module
60
+ module = find_spec(module_str)
61
+ if module and module.origin:
62
+ if not _find_attribute_in_module(module.origin, attributes_str):
63
+ return (
64
+ False,
65
+ f"Unable to find attribute {attributes_str} in module {module_str}"
66
+ f"{OBJECT_REF_HELP_STR}",
67
+ )
68
+ return (True, None)
69
+
70
+ return (
71
+ False,
72
+ f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
73
+ )
74
+
75
+
76
+ def load_app(
77
+ module_attribute_str: str,
78
+ error_type: Type[Exception],
79
+ ) -> Any:
80
+ """Return the object specified in a module attribute string.
81
+
82
+ The module/attribute string should have the form <module>:<attribute>. Valid
83
+ examples include `client:app` and `project.package.module:wrapper.app`. It must
84
+ refer to a module on the PYTHONPATH, the module needs to have the specified
85
+ attribute.
86
+ """
87
+ valid, error_msg = validate(module_attribute_str)
88
+ if not valid and error_msg:
89
+ raise error_type(error_msg) from None
90
+
91
+ module_str, _, attributes_str = module_attribute_str.partition(":")
92
+
93
+ try:
94
+ module = importlib.import_module(module_str)
95
+ except ModuleNotFoundError:
96
+ raise error_type(
97
+ f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
98
+ ) from None
99
+
100
+ # Recursively load attribute
101
+ attribute = module
102
+ try:
103
+ for attribute_str in attributes_str.split("."):
104
+ attribute = getattr(attribute, attribute_str)
105
+ except AttributeError:
106
+ raise error_type(
107
+ f"Unable to load attribute {attributes_str} from module {module_str}"
108
+ f"{OBJECT_REF_HELP_STR}",
109
+ ) from None
110
+
111
+ return attribute
112
+
113
+
114
+ def _find_attribute_in_module(file_path: str, attribute_name: str) -> bool:
115
+ """Check if attribute_name exists in module's abstract symbolic tree."""
116
+ with open(file_path, encoding="utf-8") as file:
117
+ node = ast.parse(file.read(), filename=file_path)
118
+
119
+ for n in ast.walk(node):
120
+ if isinstance(n, ast.Assign):
121
+ for target in n.targets:
122
+ if isinstance(target, ast.Name) and target.id == attribute_name:
123
+ return True
124
+ if _is_module_in_all(attribute_name, target, n):
125
+ return True
126
+ return False
127
+
128
+
129
+ def _is_module_in_all(attribute_name: str, target: ast.expr, n: ast.Assign) -> bool:
130
+ """Now check if attribute_name is in __all__."""
131
+ if isinstance(target, ast.Name) and target.id == "__all__":
132
+ if isinstance(n.value, ast.List):
133
+ for elt in n.value.elts:
134
+ if isinstance(elt, ast.Str) and elt.s == attribute_name:
135
+ return True
136
+ elif isinstance(n.value, ast.Tuple):
137
+ for elt in n.value.elts:
138
+ if isinstance(elt, ast.Str) and elt.s == attribute_name:
139
+ return True
140
+ return False
@@ -15,7 +15,7 @@
15
15
  """Utility functions for performing operations on Numpy NDArrays."""
16
16
 
17
17
 
18
- from typing import Any, List, Tuple
18
+ from typing import Any, List, Tuple, Union
19
19
 
20
20
  import numpy as np
21
21
  from numpy.typing import DTypeLike, NDArray
@@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray
68
68
 
69
69
 
70
70
  def parameters_multiply(
71
- parameters: List[NDArray[Any]], multiplier: int
71
+ parameters: List[NDArray[Any]], multiplier: Union[int, float]
72
72
  ) -> List[NDArray[Any]]:
73
- """Multiply parameters by an integer multiplier."""
73
+ """Multiply parameters by an integer/float multiplier."""
74
74
  return [parameters[idx] * multiplier for idx in range(len(parameters))]
75
75
 
76
76
 
77
77
  def parameters_divide(
78
- parameters: List[NDArray[Any]], divisor: int
78
+ parameters: List[NDArray[Any]], divisor: Union[int, float]
79
79
  ) -> List[NDArray[Any]]:
80
- """Divide weight by an integer divisor."""
80
+ """Divide weight by an integer/float divisor."""
81
81
  return [parameters[idx] / divisor for idx in range(len(parameters))]