flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (39) hide show
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -72,13 +72,14 @@ class SecAggPlusState:
72
72
 
73
73
  current_stage: str = Stage.UNMASK
74
74
 
75
- sid: int = 0
75
+ nid: int = 0
76
76
  sample_num: int = 0
77
77
  share_num: int = 0
78
78
  threshold: int = 0
79
79
  clipping_range: float = 0.0
80
80
  target_range: int = 0
81
81
  mod_range: int = 0
82
+ max_weight: float = 0.0
82
83
 
83
84
  # Secret key (sk) and public key (pk)
84
85
  sk1: bytes = b""
@@ -173,12 +174,13 @@ def secaggplus_mod(
173
174
 
174
175
  # Execute
175
176
  if state.current_stage == Stage.SETUP:
177
+ state.nid = msg.metadata.dst_node_id
176
178
  res = _setup(state, configs)
177
179
  elif state.current_stage == Stage.SHARE_KEYS:
178
180
  res = _share_keys(state, configs)
179
- elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
181
+ elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
180
182
  fit = _get_fit_fn(msg, ctxt, call_next)
181
- res = _collect_masked_input(state, configs, fit)
183
+ res = _collect_masked_vectors(state, configs, fit)
182
184
  elif state.current_stage == Stage.UNMASK:
183
185
  res = _unmask(state, configs)
184
186
  else:
@@ -197,7 +199,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
197
199
  # Check the existence of Config.STAGE
198
200
  if Key.STAGE not in configs:
199
201
  raise KeyError(
200
- f"The required key '{Key.STAGE}' is missing from the input `named_values`."
202
+ f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
201
203
  )
202
204
 
203
205
  # Check the value type of the Config.STAGE
@@ -213,7 +215,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
213
215
  if current_stage != Stage.UNMASK:
214
216
  log(WARNING, "Restart from the setup stage")
215
217
  # If stage is not "setup",
216
- # the stage from `named_values` should be the expected next stage
218
+ # the stage from configs should be the expected next stage
217
219
  else:
218
220
  stages = Stage.all()
219
221
  expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
@@ -227,11 +229,10 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
227
229
  # pylint: disable-next=too-many-branches
228
230
  def check_configs(stage: str, configs: ConfigsRecord) -> None:
229
231
  """Check the validity of the configs."""
230
- # Check `named_values` for the setup stage
232
+ # Check configs for the setup stage
231
233
  if stage == Stage.SETUP:
232
234
  key_type_pairs = [
233
235
  (Key.SAMPLE_NUMBER, int),
234
- (Key.SECURE_ID, int),
235
236
  (Key.SHARE_NUMBER, int),
236
237
  (Key.THRESHOLD, int),
237
238
  (Key.CLIPPING_RANGE, float),
@@ -242,7 +243,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
242
243
  if key not in configs:
243
244
  raise KeyError(
244
245
  f"Stage {Stage.SETUP}: the required key '{key}' is "
245
- "missing from the input `named_values`."
246
+ "missing from the ConfigsRecord."
246
247
  )
247
248
  # Bool is a subclass of int in Python,
248
249
  # so `isinstance(v, int)` will return True even if v is a boolean.
@@ -265,7 +266,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
265
266
  f"Stage {Stage.SHARE_KEYS}: "
266
267
  f"the value for the key '{key}' must be a list of two bytes."
267
268
  )
268
- elif stage == Stage.COLLECT_MASKED_INPUT:
269
+ elif stage == Stage.COLLECT_MASKED_VECTORS:
269
270
  key_type_pairs = [
270
271
  (Key.CIPHERTEXT_LIST, bytes),
271
272
  (Key.SOURCE_LIST, int),
@@ -273,9 +274,9 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
273
274
  for key, expected_type in key_type_pairs:
274
275
  if key not in configs:
275
276
  raise KeyError(
276
- f"Stage {Stage.COLLECT_MASKED_INPUT}: "
277
+ f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
277
278
  f"the required key '{key}' is "
278
- "missing from the input `named_values`."
279
+ "missing from the ConfigsRecord."
279
280
  )
280
281
  if not isinstance(configs[key], list) or any(
281
282
  elm
@@ -284,21 +285,21 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
284
285
  if type(elm) is not expected_type
285
286
  ):
286
287
  raise TypeError(
287
- f"Stage {Stage.COLLECT_MASKED_INPUT}: "
288
+ f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
288
289
  f"the value for the key '{key}' "
289
290
  f"must be of type List[{expected_type.__name__}]"
290
291
  )
291
292
  elif stage == Stage.UNMASK:
292
293
  key_type_pairs = [
293
- (Key.ACTIVE_SECURE_ID_LIST, int),
294
- (Key.DEAD_SECURE_ID_LIST, int),
294
+ (Key.ACTIVE_NODE_ID_LIST, int),
295
+ (Key.DEAD_NODE_ID_LIST, int),
295
296
  ]
296
297
  for key, expected_type in key_type_pairs:
297
298
  if key not in configs:
298
299
  raise KeyError(
299
300
  f"Stage {Stage.UNMASK}: "
300
301
  f"the required key '{key}' is "
301
- "missing from the input `named_values`."
302
+ "missing from the ConfigsRecord."
302
303
  )
303
304
  if not isinstance(configs[key], list) or any(
304
305
  elm
@@ -321,20 +322,20 @@ def _setup(
321
322
  # Assigning parameter values to object fields
322
323
  sec_agg_param_dict = configs
323
324
  state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
324
- state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
325
- log(INFO, "Client %d: starting stage 0...", state.sid)
325
+ log(INFO, "Node %d: starting stage 0...", state.nid)
326
326
 
327
327
  state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
328
328
  state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
329
329
  state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
330
330
  state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
331
331
  state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
332
+ state.max_weight = cast(float, sec_agg_param_dict[Key.MAX_WEIGHT])
332
333
 
333
- # Dictionaries containing client secure IDs as keys
334
+ # Dictionaries containing node IDs as keys
334
335
  # and their respective secret shares as values.
335
336
  state.rd_seed_share_dict = {}
336
337
  state.sk1_share_dict = {}
337
- # Dictionary containing client secure IDs as keys
338
+ # Dictionary containing node IDs as keys
338
339
  # and their respective shared secrets (with this client) as values.
339
340
  state.ss2_dict = {}
340
341
 
@@ -346,7 +347,7 @@ def _setup(
346
347
 
347
348
  state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
348
349
  state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
349
- log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
350
+ log(INFO, "Node %d: stage 0 completes. uploading public keys...", state.nid)
350
351
  return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
351
352
 
352
353
 
@@ -356,7 +357,7 @@ def _share_keys(
356
357
  ) -> Dict[str, ConfigsRecordValues]:
357
358
  named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
358
359
  key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
359
- log(INFO, "Client %d: starting stage 1...", state.sid)
360
+ log(INFO, "Node %d: starting stage 1...", state.nid)
360
361
  state.public_keys_dict = key_dict
361
362
 
362
363
  # Check if the size is larger than threshold
@@ -373,8 +374,8 @@ def _share_keys(
373
374
 
374
375
  # Check if public keys of this client are correct in the dictionary
375
376
  if (
376
- state.public_keys_dict[state.sid][0] != state.pk1
377
- or state.public_keys_dict[state.sid][1] != state.pk2
377
+ state.public_keys_dict[state.nid][0] != state.pk1
378
+ or state.public_keys_dict[state.nid][1] != state.pk2
378
379
  ):
379
380
  raise ValueError(
380
381
  "Own public keys are displayed in dict incorrectly, should not happen!"
@@ -390,35 +391,35 @@ def _share_keys(
390
391
  srcs, dsts, ciphertexts = [], [], []
391
392
 
392
393
  # Distribute shares
393
- for idx, (sid, (_, pk2)) in enumerate(state.public_keys_dict.items()):
394
- if sid == state.sid:
395
- state.rd_seed_share_dict[state.sid] = b_shares[idx]
396
- state.sk1_share_dict[state.sid] = sk1_shares[idx]
394
+ for idx, (nid, (_, pk2)) in enumerate(state.public_keys_dict.items()):
395
+ if nid == state.nid:
396
+ state.rd_seed_share_dict[state.nid] = b_shares[idx]
397
+ state.sk1_share_dict[state.nid] = sk1_shares[idx]
397
398
  else:
398
399
  shared_key = generate_shared_key(
399
400
  bytes_to_private_key(state.sk2),
400
401
  bytes_to_public_key(pk2),
401
402
  )
402
- state.ss2_dict[sid] = shared_key
403
+ state.ss2_dict[nid] = shared_key
403
404
  plaintext = share_keys_plaintext_concat(
404
- state.sid, sid, b_shares[idx], sk1_shares[idx]
405
+ state.nid, nid, b_shares[idx], sk1_shares[idx]
405
406
  )
406
407
  ciphertext = encrypt(shared_key, plaintext)
407
- srcs.append(state.sid)
408
- dsts.append(sid)
408
+ srcs.append(state.nid)
409
+ dsts.append(nid)
409
410
  ciphertexts.append(ciphertext)
410
411
 
411
- log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
412
+ log(INFO, "Node %d: stage 1 completes. uploading key shares...", state.nid)
412
413
  return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
413
414
 
414
415
 
415
416
  # pylint: disable-next=too-many-locals
416
- def _collect_masked_input(
417
+ def _collect_masked_vectors(
417
418
  state: SecAggPlusState,
418
419
  configs: ConfigsRecord,
419
420
  fit: Callable[[], FitRes],
420
421
  ) -> Dict[str, ConfigsRecordValues]:
421
- log(INFO, "Client %d: starting stage 2...", state.sid)
422
+ log(INFO, "Node %d: starting stage 2...", state.nid)
422
423
  available_clients: List[int] = []
423
424
  ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
424
425
  srcs = cast(List[int], configs[Key.SOURCE_LIST])
@@ -435,29 +436,45 @@ def _collect_masked_input(
435
436
  available_clients.append(src)
436
437
  if src != actual_src:
437
438
  raise ValueError(
438
- f"Client {state.sid}: received ciphertext "
439
+ f"Node {state.nid}: received ciphertext "
439
440
  f"from {actual_src} instead of {src}."
440
441
  )
441
- if dst != state.sid:
442
+ if dst != state.nid:
442
443
  raise ValueError(
443
- f"Client {state.sid}: received an encrypted message"
444
- f"for Client {dst} from Client {src}."
444
+ f"Node {state.nid}: received an encrypted message"
445
+ f"for Node {dst} from Node {src}."
445
446
  )
446
447
  state.rd_seed_share_dict[src] = rd_seed_share
447
448
  state.sk1_share_dict[src] = sk1_share
448
449
 
449
450
  # Fit client
450
451
  fit_res = fit()
451
- parameters_factor = fit_res.num_examples
452
+ if len(fit_res.metrics) > 0:
453
+ log(
454
+ WARNING,
455
+ "The metrics in FitRes will not be preserved or sent to the server.",
456
+ )
457
+ ratio = fit_res.num_examples / state.max_weight
458
+ if ratio > 1:
459
+ log(
460
+ WARNING,
461
+ "Potential overflow warning: the provided weight (%s) exceeds the specified"
462
+ " max_weight (%s). This may lead to overflow issues.",
463
+ fit_res.num_examples,
464
+ state.max_weight,
465
+ )
466
+ q_ratio = round(ratio * state.target_range)
467
+ dq_ratio = q_ratio / state.target_range
468
+
452
469
  parameters = parameters_to_ndarrays(fit_res.parameters)
470
+ parameters = parameters_multiply(parameters, dq_ratio)
453
471
 
454
472
  # Quantize parameter update (vector)
455
473
  quantized_parameters = quantize(
456
474
  parameters, state.clipping_range, state.target_range
457
475
  )
458
476
 
459
- quantized_parameters = parameters_multiply(quantized_parameters, parameters_factor)
460
- quantized_parameters = factor_combine(parameters_factor, quantized_parameters)
477
+ quantized_parameters = factor_combine(q_ratio, quantized_parameters)
461
478
 
462
479
  dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters]
463
480
 
@@ -465,14 +482,14 @@ def _collect_masked_input(
465
482
  private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list)
466
483
  quantized_parameters = parameters_addition(quantized_parameters, private_mask)
467
484
 
468
- for client_id in available_clients:
485
+ for node_id in available_clients:
469
486
  # Add pairwise masks
470
487
  shared_key = generate_shared_key(
471
488
  bytes_to_private_key(state.sk1),
472
- bytes_to_public_key(state.public_keys_dict[client_id][0]),
489
+ bytes_to_public_key(state.public_keys_dict[node_id][0]),
473
490
  )
474
491
  pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list)
475
- if state.sid > client_id:
492
+ if state.nid > node_id:
476
493
  quantized_parameters = parameters_addition(
477
494
  quantized_parameters, pairwise_mask
478
495
  )
@@ -483,7 +500,7 @@ def _collect_masked_input(
483
500
 
484
501
  # Take mod of final weight update vector and return to server
485
502
  quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
486
- log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
503
+ log(INFO, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
487
504
  return {
488
505
  Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
489
506
  }
@@ -492,20 +509,19 @@ def _collect_masked_input(
492
509
def _unmask(
    state: SecAggPlusState, configs: ConfigsRecord
) -> Dict[str, ConfigsRecordValues]:
    """Handle the unmask stage (stage 3) of the SecAgg+ protocol.

    Upload the private-mask seed share for every available node (including
    itself) and the first private-key share for every dropped node, so the
    server can reconstruct the masks and unmask the aggregate.

    Raises
    ------
    ValueError
        If fewer than ``state.threshold`` nodes are still active, in which
        case mask reconstruction is impossible.
    """
    log(INFO, "Node %d: starting stage 3...", state.nid)

    active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST])
    dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST])
    # Send private mask seed share for every available client (including itself)
    # Send first private key share for building pairwise mask for every dropped client
    if len(active_nids) < state.threshold:
        raise ValueError("Available neighbours number smaller than threshold")

    # Order matters: `shares` must stay aligned with `all_nids` — seed shares
    # for active nodes first, then sk1 shares for dead nodes.
    all_nids = active_nids + dead_nids
    shares = [state.rd_seed_share_dict[nid] for nid in active_nids]
    shares += [state.sk1_share_dict[nid] for nid in dead_nids]

    log(INFO, "Node %d: stage 3 completes. uploading key shares...", state.nid)
    return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}
flwr/common/grpc.py CHANGED
@@ -15,7 +15,7 @@
15
15
  """Utility functions for gRPC."""
16
16
 
17
17
 
18
- from logging import INFO
18
+ from logging import DEBUG
19
19
  from typing import Optional
20
20
 
21
21
  import grpc
@@ -49,12 +49,12 @@ def create_channel(
49
49
 
50
50
  if insecure:
51
51
  channel = grpc.insecure_channel(server_address, options=channel_options)
52
- log(INFO, "Opened insecure gRPC connection (no certificates were passed)")
52
+ log(DEBUG, "Opened insecure gRPC connection (no certificates were passed)")
53
53
  else:
54
54
  ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates)
55
55
  channel = grpc.secure_channel(
56
56
  server_address, ssl_channel_credentials, options=channel_options
57
57
  )
58
- log(INFO, "Opened secure gRPC connection using certificates")
58
+ log(DEBUG, "Opened secure gRPC connection using certificates")
59
59
 
60
60
  return channel
flwr/common/logger.py CHANGED
@@ -18,21 +18,86 @@
18
18
  import logging
19
19
  from logging import WARN, LogRecord
20
20
  from logging.handlers import HTTPHandler
21
- from typing import Any, Dict, Optional, Tuple
21
+ from typing import TYPE_CHECKING, Any, Dict, Optional, TextIO, Tuple
22
22
 
23
23
  # Create logger
24
24
  LOGGER_NAME = "flwr"
25
25
  FLOWER_LOGGER = logging.getLogger(LOGGER_NAME)
26
26
  FLOWER_LOGGER.setLevel(logging.DEBUG)
27
27
 
28
- DEFAULT_FORMATTER = logging.Formatter(
29
- "%(levelname)s %(name)s %(asctime)s | %(filename)s:%(lineno)d | %(message)s"
30
- )
28
# ANSI escape sequences used to colorize log levels on the console.
LOG_COLORS = {
    "DEBUG": "\033[94m",  # Blue
    "INFO": "\033[92m",  # Green
    "WARNING": "\033[93m",  # Yellow
    "ERROR": "\033[91m",  # Red
    "CRITICAL": "\033[95m",  # Magenta
    "RESET": "\033[0m",  # Reset to default
}

# `logging.StreamHandler` is only subscriptable while type checking;
# at runtime the plain (non-generic) class must be used.
if TYPE_CHECKING:
    StreamHandler = logging.StreamHandler[Any]
else:
    StreamHandler = logging.StreamHandler


class ConsoleHandler(StreamHandler):
    """Console handler that allows configurable formatting."""

    def __init__(
        self,
        timestamps: bool = False,
        json: bool = False,
        colored: bool = True,
        stream: Optional[TextIO] = None,
    ) -> None:
        super().__init__(stream)
        self.timestamps = timestamps
        self.json = json
        self.colored = colored

    def emit(self, record: LogRecord) -> None:
        """Emit a record, dropping records that are empty in JSON mode."""
        if self.json:
            stripped = record.getMessage().replace("\t", "").strip()
            record.message = stripped
            # Nothing left to log once tabs/whitespace are removed
            if not stripped:
                return
        super().emit(record)

    def format(self, record: LogRecord) -> str:
        """Format function that adds colors to log level."""
        # Pad so messages align regardless of level-name length (max 8 chars)
        separator = " " * (8 - len(record.levelname))
        if self.json:
            log_fmt = "{lvl='%(levelname)s', time='%(asctime)s', msg='%(message)s'}"
        else:
            color = LOG_COLORS[record.levelname] if self.colored else ""
            reset = LOG_COLORS["RESET"] if self.colored else ""
            timestamp = "%(asctime)s" if self.timestamps else ""
            log_fmt = f"{color}%(levelname)s {timestamp}{reset}: {separator} %(message)s"
        return logging.Formatter(log_fmt).format(record)
83
+
84
+
85
def update_console_handler(level: int, timestamps: bool, colored: bool) -> None:
    """Update the logging handler."""
    # Only reconfigure Flower's own console handler(s); leave any other
    # handlers attached to the logger untouched.
    for handler in logging.getLogger(LOGGER_NAME).handlers:
        if not isinstance(handler, ConsoleHandler):
            continue
        handler.setLevel(level)
        handler.timestamps = timestamps
        handler.colored = colored
92
+
31
93
 
32
94
  # Configure console logger
33
- console_handler = logging.StreamHandler()
34
- console_handler.setLevel(logging.DEBUG)
35
- console_handler.setFormatter(DEFAULT_FORMATTER)
95
+ console_handler = ConsoleHandler(
96
+ timestamps=False,
97
+ json=False,
98
+ colored=True,
99
+ )
100
+ console_handler.setLevel(logging.INFO)
36
101
  FLOWER_LOGGER.addHandler(console_handler)
37
102
 
38
103
 
@@ -103,11 +168,10 @@ def warn_experimental_feature(name: str) -> None:
103
168
  """Warn the user when they use an experimental feature."""
104
169
  log(
105
170
  WARN,
106
- """
107
- EXPERIMENTAL FEATURE: %s
171
+ """EXPERIMENTAL FEATURE: %s
108
172
 
109
- This is an experimental feature. It could change significantly or be removed
110
- entirely in future versions of Flower.
173
+ This is an experimental feature. It could change significantly or be removed
174
+ entirely in future versions of Flower.
111
175
  """,
112
176
  name,
113
177
  )
@@ -117,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None:
117
181
  """Warn the user when they use a deprecated feature."""
118
182
  log(
119
183
  WARN,
120
- """
121
- DEPRECATED FEATURE: %s
184
+ """DEPRECATED FEATURE: %s
122
185
 
123
- This is a deprecated feature. It will be removed
124
- entirely in future versions of Flower.
186
+ This is a deprecated feature. It will be removed
187
+ entirely in future versions of Flower.
125
188
  """,
126
189
  name,
127
190
  )
@@ -0,0 +1,140 @@
1
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to load objects from a reference."""


import ast
import importlib
from importlib.util import find_spec
from typing import Any, Optional, Tuple, Type

OBJECT_REF_HELP_STR = """
\n\nThe object reference string should have the form <module>:<attribute>. Valid
examples include `client:app` and `project.package.module:wrapper.app`. It must
refer to a module on the PYTHONPATH and the module needs to have the specified
attribute.
"""


def validate(
    module_attribute_str: str,
) -> Tuple[bool, Optional[str]]:
    """Validate object reference.

    The object reference string should have the form <module>:<attribute>. Valid
    examples include `client:app` and `project.package.module:wrapper.app`. It must
    refer to a module on the PYTHONPATH and the module needs to have the specified
    attribute.

    Returns
    -------
    Tuple[bool, Optional[str]]
        A boolean indicating whether an object reference is valid and
        the reason why it might not be.
    """
    module_str, _, attributes_str = module_attribute_str.partition(":")
    if not module_str:
        return (
            False,
            f"Missing module in {module_attribute_str}{OBJECT_REF_HELP_STR}",
        )
    if not attributes_str:
        return (
            False,
            f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}",
        )

    # Load module spec (without importing) and statically check the attribute
    module = find_spec(module_str)
    if module and module.origin:
        if not _find_attribute_in_module(module.origin, attributes_str):
            return (
                False,
                f"Unable to find attribute {attributes_str} in module {module_str}"
                f"{OBJECT_REF_HELP_STR}",
            )
        return (True, None)

    return (
        False,
        f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
    )


def load_app(
    module_attribute_str: str,
    error_type: Type[Exception],
) -> Any:
    """Return the object specified in a module attribute string.

    The module/attribute string should have the form <module>:<attribute>. Valid
    examples include `client:app` and `project.package.module:wrapper.app`. It must
    refer to a module on the PYTHONPATH, the module needs to have the specified
    attribute.

    Raises
    ------
    error_type
        If the reference is malformed, the module cannot be imported, or the
        attribute cannot be resolved.
    """
    valid, error_msg = validate(module_attribute_str)
    if not valid and error_msg:
        raise error_type(error_msg) from None

    module_str, _, attributes_str = module_attribute_str.partition(":")

    try:
        module = importlib.import_module(module_str)
    except ModuleNotFoundError:
        raise error_type(
            f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
        ) from None

    # Recursively load attribute (supports dotted paths like `wrapper.app`)
    attribute = module
    try:
        for attribute_str in attributes_str.split("."):
            attribute = getattr(attribute, attribute_str)
    except AttributeError:
        raise error_type(
            f"Unable to load attribute {attributes_str} from module {module_str}"
            f"{OBJECT_REF_HELP_STR}",
        ) from None

    return attribute


def _find_attribute_in_module(file_path: str, attribute_name: str) -> bool:
    """Check if attribute_name exists in module's abstract symbolic tree."""
    with open(file_path, encoding="utf-8") as file:
        node = ast.parse(file.read(), filename=file_path)

    # Only module-level assignments (e.g. `app = ClientApp(...)`) and
    # `__all__` entries are recognized — function/class defs are not scanned.
    for n in ast.walk(node):
        if isinstance(n, ast.Assign):
            for target in n.targets:
                if isinstance(target, ast.Name) and target.id == attribute_name:
                    return True
                if _is_module_in_all(attribute_name, target, n):
                    return True
    return False


def _is_module_in_all(attribute_name: str, target: ast.expr, n: ast.Assign) -> bool:
    """Now check if attribute_name is in __all__."""
    if isinstance(target, ast.Name) and target.id == "__all__":
        # `__all__` may be declared as a list or a tuple of string literals.
        # `ast.Str` is deprecated since Python 3.8; string literals parse as
        # `ast.Constant`, so match on that instead.
        if isinstance(n.value, (ast.List, ast.Tuple)):
            for elt in n.value.elts:
                if isinstance(elt, ast.Constant) and elt.value == attribute_name:
                    return True
    return False
@@ -15,7 +15,7 @@
15
15
  """Utility functions for performing operations on Numpy NDArrays."""
16
16
 
17
17
 
18
- from typing import Any, List, Tuple
18
+ from typing import Any, List, Tuple, Union
19
19
 
20
20
  import numpy as np
21
21
  from numpy.typing import DTypeLike, NDArray
@@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray
68
68
 
69
69
 
70
70
def parameters_multiply(
    parameters: List[NDArray[Any]], multiplier: Union[int, float]
) -> List[NDArray[Any]]:
    """Multiply parameters by an integer/float multiplier.

    Returns a new list of element-wise scaled arrays; the input arrays are
    not modified.
    """
    # Iterate directly over the arrays instead of indexing by position
    return [layer * multiplier for layer in parameters]
75
75
 
76
76
 
77
77
def parameters_divide(
    parameters: List[NDArray[Any]], divisor: Union[int, float]
) -> List[NDArray[Any]]:
    """Divide weight by an integer/float divisor.

    Returns a new list of element-wise divided arrays; the input arrays are
    not modified.
    """
    # Iterate directly over the arrays instead of indexing by position
    return [layer / divisor for layer in parameters]