flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/flower_toml.py +4 -48
- flwr/cli/new/new.py +6 -3
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
- flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +39 -5
- flwr/client/client_app.py +1 -47
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
- flwr/common/grpc.py +3 -3
- flwr/common/logger.py +78 -15
- flwr/common/object_ref.py +140 -0
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
- flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
- flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
- flwr/server/compat/app.py +2 -1
- flwr/server/driver/grpc_driver.py +4 -4
- flwr/server/history.py +22 -15
- flwr/server/run_serverapp.py +22 -4
- flwr/server/server.py +27 -23
- flwr/server/server_app.py +1 -47
- flwr/server/server_config.py +9 -0
- flwr/server/strategy/fedavg.py +2 -0
- flwr/server/superlink/fleet/vce/vce_api.py +9 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +3 -0
- flwr/server/workflow/constant.py +32 -0
- flwr/server/workflow/default_workflows.py +52 -57
- flwr/server/workflow/secure_aggregation/__init__.py +24 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -72,13 +72,14 @@ class SecAggPlusState:
|
|
72
72
|
|
73
73
|
current_stage: str = Stage.UNMASK
|
74
74
|
|
75
|
-
|
75
|
+
nid: int = 0
|
76
76
|
sample_num: int = 0
|
77
77
|
share_num: int = 0
|
78
78
|
threshold: int = 0
|
79
79
|
clipping_range: float = 0.0
|
80
80
|
target_range: int = 0
|
81
81
|
mod_range: int = 0
|
82
|
+
max_weight: float = 0.0
|
82
83
|
|
83
84
|
# Secret key (sk) and public key (pk)
|
84
85
|
sk1: bytes = b""
|
@@ -173,12 +174,13 @@ def secaggplus_mod(
|
|
173
174
|
|
174
175
|
# Execute
|
175
176
|
if state.current_stage == Stage.SETUP:
|
177
|
+
state.nid = msg.metadata.dst_node_id
|
176
178
|
res = _setup(state, configs)
|
177
179
|
elif state.current_stage == Stage.SHARE_KEYS:
|
178
180
|
res = _share_keys(state, configs)
|
179
|
-
elif state.current_stage == Stage.
|
181
|
+
elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
|
180
182
|
fit = _get_fit_fn(msg, ctxt, call_next)
|
181
|
-
res =
|
183
|
+
res = _collect_masked_vectors(state, configs, fit)
|
182
184
|
elif state.current_stage == Stage.UNMASK:
|
183
185
|
res = _unmask(state, configs)
|
184
186
|
else:
|
@@ -197,7 +199,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
197
199
|
# Check the existence of Config.STAGE
|
198
200
|
if Key.STAGE not in configs:
|
199
201
|
raise KeyError(
|
200
|
-
f"The required key '{Key.STAGE}' is missing from the
|
202
|
+
f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
|
201
203
|
)
|
202
204
|
|
203
205
|
# Check the value type of the Config.STAGE
|
@@ -213,7 +215,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
213
215
|
if current_stage != Stage.UNMASK:
|
214
216
|
log(WARNING, "Restart from the setup stage")
|
215
217
|
# If stage is not "setup",
|
216
|
-
# the stage from
|
218
|
+
# the stage from configs should be the expected next stage
|
217
219
|
else:
|
218
220
|
stages = Stage.all()
|
219
221
|
expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
|
@@ -227,11 +229,10 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
227
229
|
# pylint: disable-next=too-many-branches
|
228
230
|
def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
229
231
|
"""Check the validity of the configs."""
|
230
|
-
# Check
|
232
|
+
# Check configs for the setup stage
|
231
233
|
if stage == Stage.SETUP:
|
232
234
|
key_type_pairs = [
|
233
235
|
(Key.SAMPLE_NUMBER, int),
|
234
|
-
(Key.SECURE_ID, int),
|
235
236
|
(Key.SHARE_NUMBER, int),
|
236
237
|
(Key.THRESHOLD, int),
|
237
238
|
(Key.CLIPPING_RANGE, float),
|
@@ -242,7 +243,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
242
243
|
if key not in configs:
|
243
244
|
raise KeyError(
|
244
245
|
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
245
|
-
"missing from the
|
246
|
+
"missing from the ConfigsRecord."
|
246
247
|
)
|
247
248
|
# Bool is a subclass of int in Python,
|
248
249
|
# so `isinstance(v, int)` will return True even if v is a boolean.
|
@@ -265,7 +266,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
265
266
|
f"Stage {Stage.SHARE_KEYS}: "
|
266
267
|
f"the value for the key '{key}' must be a list of two bytes."
|
267
268
|
)
|
268
|
-
elif stage == Stage.
|
269
|
+
elif stage == Stage.COLLECT_MASKED_VECTORS:
|
269
270
|
key_type_pairs = [
|
270
271
|
(Key.CIPHERTEXT_LIST, bytes),
|
271
272
|
(Key.SOURCE_LIST, int),
|
@@ -273,9 +274,9 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
273
274
|
for key, expected_type in key_type_pairs:
|
274
275
|
if key not in configs:
|
275
276
|
raise KeyError(
|
276
|
-
f"Stage {Stage.
|
277
|
+
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
277
278
|
f"the required key '{key}' is "
|
278
|
-
"missing from the
|
279
|
+
"missing from the ConfigsRecord."
|
279
280
|
)
|
280
281
|
if not isinstance(configs[key], list) or any(
|
281
282
|
elm
|
@@ -284,21 +285,21 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
284
285
|
if type(elm) is not expected_type
|
285
286
|
):
|
286
287
|
raise TypeError(
|
287
|
-
f"Stage {Stage.
|
288
|
+
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
288
289
|
f"the value for the key '{key}' "
|
289
290
|
f"must be of type List[{expected_type.__name__}]"
|
290
291
|
)
|
291
292
|
elif stage == Stage.UNMASK:
|
292
293
|
key_type_pairs = [
|
293
|
-
(Key.
|
294
|
-
(Key.
|
294
|
+
(Key.ACTIVE_NODE_ID_LIST, int),
|
295
|
+
(Key.DEAD_NODE_ID_LIST, int),
|
295
296
|
]
|
296
297
|
for key, expected_type in key_type_pairs:
|
297
298
|
if key not in configs:
|
298
299
|
raise KeyError(
|
299
300
|
f"Stage {Stage.UNMASK}: "
|
300
301
|
f"the required key '{key}' is "
|
301
|
-
"missing from the
|
302
|
+
"missing from the ConfigsRecord."
|
302
303
|
)
|
303
304
|
if not isinstance(configs[key], list) or any(
|
304
305
|
elm
|
@@ -321,20 +322,20 @@ def _setup(
|
|
321
322
|
# Assigning parameter values to object fields
|
322
323
|
sec_agg_param_dict = configs
|
323
324
|
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
324
|
-
|
325
|
-
log(INFO, "Client %d: starting stage 0...", state.sid)
|
325
|
+
log(INFO, "Node %d: starting stage 0...", state.nid)
|
326
326
|
|
327
327
|
state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
|
328
328
|
state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
|
329
329
|
state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
|
330
330
|
state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
|
331
331
|
state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
|
332
|
+
state.max_weight = cast(float, sec_agg_param_dict[Key.MAX_WEIGHT])
|
332
333
|
|
333
|
-
# Dictionaries containing
|
334
|
+
# Dictionaries containing node IDs as keys
|
334
335
|
# and their respective secret shares as values.
|
335
336
|
state.rd_seed_share_dict = {}
|
336
337
|
state.sk1_share_dict = {}
|
337
|
-
# Dictionary containing
|
338
|
+
# Dictionary containing node IDs as keys
|
338
339
|
# and their respective shared secrets (with this client) as values.
|
339
340
|
state.ss2_dict = {}
|
340
341
|
|
@@ -346,7 +347,7 @@ def _setup(
|
|
346
347
|
|
347
348
|
state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
|
348
349
|
state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
|
349
|
-
log(INFO, "
|
350
|
+
log(INFO, "Node %d: stage 0 completes. uploading public keys...", state.nid)
|
350
351
|
return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
|
351
352
|
|
352
353
|
|
@@ -356,7 +357,7 @@ def _share_keys(
|
|
356
357
|
) -> Dict[str, ConfigsRecordValues]:
|
357
358
|
named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
|
358
359
|
key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
|
359
|
-
log(INFO, "
|
360
|
+
log(INFO, "Node %d: starting stage 1...", state.nid)
|
360
361
|
state.public_keys_dict = key_dict
|
361
362
|
|
362
363
|
# Check if the size is larger than threshold
|
@@ -373,8 +374,8 @@ def _share_keys(
|
|
373
374
|
|
374
375
|
# Check if public keys of this client are correct in the dictionary
|
375
376
|
if (
|
376
|
-
state.public_keys_dict[state.
|
377
|
-
or state.public_keys_dict[state.
|
377
|
+
state.public_keys_dict[state.nid][0] != state.pk1
|
378
|
+
or state.public_keys_dict[state.nid][1] != state.pk2
|
378
379
|
):
|
379
380
|
raise ValueError(
|
380
381
|
"Own public keys are displayed in dict incorrectly, should not happen!"
|
@@ -390,35 +391,35 @@ def _share_keys(
|
|
390
391
|
srcs, dsts, ciphertexts = [], [], []
|
391
392
|
|
392
393
|
# Distribute shares
|
393
|
-
for idx, (
|
394
|
-
if
|
395
|
-
state.rd_seed_share_dict[state.
|
396
|
-
state.sk1_share_dict[state.
|
394
|
+
for idx, (nid, (_, pk2)) in enumerate(state.public_keys_dict.items()):
|
395
|
+
if nid == state.nid:
|
396
|
+
state.rd_seed_share_dict[state.nid] = b_shares[idx]
|
397
|
+
state.sk1_share_dict[state.nid] = sk1_shares[idx]
|
397
398
|
else:
|
398
399
|
shared_key = generate_shared_key(
|
399
400
|
bytes_to_private_key(state.sk2),
|
400
401
|
bytes_to_public_key(pk2),
|
401
402
|
)
|
402
|
-
state.ss2_dict[
|
403
|
+
state.ss2_dict[nid] = shared_key
|
403
404
|
plaintext = share_keys_plaintext_concat(
|
404
|
-
state.
|
405
|
+
state.nid, nid, b_shares[idx], sk1_shares[idx]
|
405
406
|
)
|
406
407
|
ciphertext = encrypt(shared_key, plaintext)
|
407
|
-
srcs.append(state.
|
408
|
-
dsts.append(
|
408
|
+
srcs.append(state.nid)
|
409
|
+
dsts.append(nid)
|
409
410
|
ciphertexts.append(ciphertext)
|
410
411
|
|
411
|
-
log(INFO, "
|
412
|
+
log(INFO, "Node %d: stage 1 completes. uploading key shares...", state.nid)
|
412
413
|
return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
|
413
414
|
|
414
415
|
|
415
416
|
# pylint: disable-next=too-many-locals
|
416
|
-
def
|
417
|
+
def _collect_masked_vectors(
|
417
418
|
state: SecAggPlusState,
|
418
419
|
configs: ConfigsRecord,
|
419
420
|
fit: Callable[[], FitRes],
|
420
421
|
) -> Dict[str, ConfigsRecordValues]:
|
421
|
-
log(INFO, "
|
422
|
+
log(INFO, "Node %d: starting stage 2...", state.nid)
|
422
423
|
available_clients: List[int] = []
|
423
424
|
ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
|
424
425
|
srcs = cast(List[int], configs[Key.SOURCE_LIST])
|
@@ -435,29 +436,45 @@ def _collect_masked_input(
|
|
435
436
|
available_clients.append(src)
|
436
437
|
if src != actual_src:
|
437
438
|
raise ValueError(
|
438
|
-
f"
|
439
|
+
f"Node {state.nid}: received ciphertext "
|
439
440
|
f"from {actual_src} instead of {src}."
|
440
441
|
)
|
441
|
-
if dst != state.
|
442
|
+
if dst != state.nid:
|
442
443
|
raise ValueError(
|
443
|
-
f"
|
444
|
-
f"for
|
444
|
+
f"Node {state.nid}: received an encrypted message"
|
445
|
+
f"for Node {dst} from Node {src}."
|
445
446
|
)
|
446
447
|
state.rd_seed_share_dict[src] = rd_seed_share
|
447
448
|
state.sk1_share_dict[src] = sk1_share
|
448
449
|
|
449
450
|
# Fit client
|
450
451
|
fit_res = fit()
|
451
|
-
|
452
|
+
if len(fit_res.metrics) > 0:
|
453
|
+
log(
|
454
|
+
WARNING,
|
455
|
+
"The metrics in FitRes will not be preserved or sent to the server.",
|
456
|
+
)
|
457
|
+
ratio = fit_res.num_examples / state.max_weight
|
458
|
+
if ratio > 1:
|
459
|
+
log(
|
460
|
+
WARNING,
|
461
|
+
"Potential overflow warning: the provided weight (%s) exceeds the specified"
|
462
|
+
" max_weight (%s). This may lead to overflow issues.",
|
463
|
+
fit_res.num_examples,
|
464
|
+
state.max_weight,
|
465
|
+
)
|
466
|
+
q_ratio = round(ratio * state.target_range)
|
467
|
+
dq_ratio = q_ratio / state.target_range
|
468
|
+
|
452
469
|
parameters = parameters_to_ndarrays(fit_res.parameters)
|
470
|
+
parameters = parameters_multiply(parameters, dq_ratio)
|
453
471
|
|
454
472
|
# Quantize parameter update (vector)
|
455
473
|
quantized_parameters = quantize(
|
456
474
|
parameters, state.clipping_range, state.target_range
|
457
475
|
)
|
458
476
|
|
459
|
-
quantized_parameters =
|
460
|
-
quantized_parameters = factor_combine(parameters_factor, quantized_parameters)
|
477
|
+
quantized_parameters = factor_combine(q_ratio, quantized_parameters)
|
461
478
|
|
462
479
|
dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters]
|
463
480
|
|
@@ -465,14 +482,14 @@ def _collect_masked_input(
|
|
465
482
|
private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list)
|
466
483
|
quantized_parameters = parameters_addition(quantized_parameters, private_mask)
|
467
484
|
|
468
|
-
for
|
485
|
+
for node_id in available_clients:
|
469
486
|
# Add pairwise masks
|
470
487
|
shared_key = generate_shared_key(
|
471
488
|
bytes_to_private_key(state.sk1),
|
472
|
-
bytes_to_public_key(state.public_keys_dict[
|
489
|
+
bytes_to_public_key(state.public_keys_dict[node_id][0]),
|
473
490
|
)
|
474
491
|
pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list)
|
475
|
-
if state.
|
492
|
+
if state.nid > node_id:
|
476
493
|
quantized_parameters = parameters_addition(
|
477
494
|
quantized_parameters, pairwise_mask
|
478
495
|
)
|
@@ -483,7 +500,7 @@ def _collect_masked_input(
|
|
483
500
|
|
484
501
|
# Take mod of final weight update vector and return to server
|
485
502
|
quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
|
486
|
-
log(INFO, "
|
503
|
+
log(INFO, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
|
487
504
|
return {
|
488
505
|
Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
|
489
506
|
}
|
@@ -492,20 +509,19 @@ def _collect_masked_input(
|
|
492
509
|
def _unmask(
|
493
510
|
state: SecAggPlusState, configs: ConfigsRecord
|
494
511
|
) -> Dict[str, ConfigsRecordValues]:
|
495
|
-
log(INFO, "
|
512
|
+
log(INFO, "Node %d: starting stage 3...", state.nid)
|
496
513
|
|
497
|
-
|
498
|
-
|
499
|
-
# Send private mask seed share for every avaliable client (including
|
514
|
+
active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST])
|
515
|
+
dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST])
|
516
|
+
# Send private mask seed share for every avaliable client (including itself)
|
500
517
|
# Send first private key share for building pairwise mask for every dropped client
|
501
|
-
if len(
|
518
|
+
if len(active_nids) < state.threshold:
|
502
519
|
raise ValueError("Available neighbours number smaller than threshold")
|
503
520
|
|
504
|
-
|
505
|
-
|
506
|
-
shares += [state.rd_seed_share_dict[
|
507
|
-
|
508
|
-
shares += [state.sk1_share_dict[sid] for sid in dead_sids]
|
521
|
+
all_nids, shares = [], []
|
522
|
+
all_nids = active_nids + dead_nids
|
523
|
+
shares += [state.rd_seed_share_dict[nid] for nid in active_nids]
|
524
|
+
shares += [state.sk1_share_dict[nid] for nid in dead_nids]
|
509
525
|
|
510
|
-
log(INFO, "
|
511
|
-
return {Key.
|
526
|
+
log(INFO, "Node %d: stage 3 completes. uploading key shares...", state.nid)
|
527
|
+
return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}
|
flwr/common/grpc.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Utility functions for gRPC."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import
|
18
|
+
from logging import DEBUG
|
19
19
|
from typing import Optional
|
20
20
|
|
21
21
|
import grpc
|
@@ -49,12 +49,12 @@ def create_channel(
|
|
49
49
|
|
50
50
|
if insecure:
|
51
51
|
channel = grpc.insecure_channel(server_address, options=channel_options)
|
52
|
-
log(
|
52
|
+
log(DEBUG, "Opened insecure gRPC connection (no certificates were passed)")
|
53
53
|
else:
|
54
54
|
ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates)
|
55
55
|
channel = grpc.secure_channel(
|
56
56
|
server_address, ssl_channel_credentials, options=channel_options
|
57
57
|
)
|
58
|
-
log(
|
58
|
+
log(DEBUG, "Opened secure gRPC connection using certificates")
|
59
59
|
|
60
60
|
return channel
|
flwr/common/logger.py
CHANGED
@@ -18,21 +18,86 @@
|
|
18
18
|
import logging
|
19
19
|
from logging import WARN, LogRecord
|
20
20
|
from logging.handlers import HTTPHandler
|
21
|
-
from typing import Any, Dict, Optional, Tuple
|
21
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, TextIO, Tuple
|
22
22
|
|
23
23
|
# Create logger
|
24
24
|
LOGGER_NAME = "flwr"
|
25
25
|
FLOWER_LOGGER = logging.getLogger(LOGGER_NAME)
|
26
26
|
FLOWER_LOGGER.setLevel(logging.DEBUG)
|
27
27
|
|
28
|
-
|
29
|
-
"
|
30
|
-
|
28
|
+
LOG_COLORS = {
|
29
|
+
"DEBUG": "\033[94m", # Blue
|
30
|
+
"INFO": "\033[92m", # Green
|
31
|
+
"WARNING": "\033[93m", # Yellow
|
32
|
+
"ERROR": "\033[91m", # Red
|
33
|
+
"CRITICAL": "\033[95m", # Magenta
|
34
|
+
"RESET": "\033[0m", # Reset to default
|
35
|
+
}
|
36
|
+
|
37
|
+
if TYPE_CHECKING:
|
38
|
+
StreamHandler = logging.StreamHandler[Any]
|
39
|
+
else:
|
40
|
+
StreamHandler = logging.StreamHandler
|
41
|
+
|
42
|
+
|
43
|
+
class ConsoleHandler(StreamHandler):
|
44
|
+
"""Console handler that allows configurable formatting."""
|
45
|
+
|
46
|
+
def __init__(
|
47
|
+
self,
|
48
|
+
timestamps: bool = False,
|
49
|
+
json: bool = False,
|
50
|
+
colored: bool = True,
|
51
|
+
stream: Optional[TextIO] = None,
|
52
|
+
) -> None:
|
53
|
+
super().__init__(stream)
|
54
|
+
self.timestamps = timestamps
|
55
|
+
self.json = json
|
56
|
+
self.colored = colored
|
57
|
+
|
58
|
+
def emit(self, record: LogRecord) -> None:
|
59
|
+
"""Emit a record."""
|
60
|
+
if self.json:
|
61
|
+
record.message = record.getMessage().replace("\t", "").strip()
|
62
|
+
|
63
|
+
# Check if the message is empty
|
64
|
+
if not record.message:
|
65
|
+
return
|
66
|
+
|
67
|
+
super().emit(record)
|
68
|
+
|
69
|
+
def format(self, record: LogRecord) -> str:
|
70
|
+
"""Format function that adds colors to log level."""
|
71
|
+
seperator = " " * (8 - len(record.levelname))
|
72
|
+
if self.json:
|
73
|
+
log_fmt = "{lvl='%(levelname)s', time='%(asctime)s', msg='%(message)s'}"
|
74
|
+
else:
|
75
|
+
log_fmt = (
|
76
|
+
f"{LOG_COLORS[record.levelname] if self.colored else ''}"
|
77
|
+
f"%(levelname)s {'%(asctime)s' if self.timestamps else ''}"
|
78
|
+
f"{LOG_COLORS['RESET'] if self.colored else ''}"
|
79
|
+
f": {seperator} %(message)s"
|
80
|
+
)
|
81
|
+
formatter = logging.Formatter(log_fmt)
|
82
|
+
return formatter.format(record)
|
83
|
+
|
84
|
+
|
85
|
+
def update_console_handler(level: int, timestamps: bool, colored: bool) -> None:
|
86
|
+
"""Update the logging handler."""
|
87
|
+
for handler in logging.getLogger(LOGGER_NAME).handlers:
|
88
|
+
if isinstance(handler, ConsoleHandler):
|
89
|
+
handler.setLevel(level)
|
90
|
+
handler.timestamps = timestamps
|
91
|
+
handler.colored = colored
|
92
|
+
|
31
93
|
|
32
94
|
# Configure console logger
|
33
|
-
console_handler =
|
34
|
-
|
35
|
-
|
95
|
+
console_handler = ConsoleHandler(
|
96
|
+
timestamps=False,
|
97
|
+
json=False,
|
98
|
+
colored=True,
|
99
|
+
)
|
100
|
+
console_handler.setLevel(logging.INFO)
|
36
101
|
FLOWER_LOGGER.addHandler(console_handler)
|
37
102
|
|
38
103
|
|
@@ -103,11 +168,10 @@ def warn_experimental_feature(name: str) -> None:
|
|
103
168
|
"""Warn the user when they use an experimental feature."""
|
104
169
|
log(
|
105
170
|
WARN,
|
106
|
-
"""
|
107
|
-
EXPERIMENTAL FEATURE: %s
|
171
|
+
"""EXPERIMENTAL FEATURE: %s
|
108
172
|
|
109
|
-
|
110
|
-
|
173
|
+
This is an experimental feature. It could change significantly or be removed
|
174
|
+
entirely in future versions of Flower.
|
111
175
|
""",
|
112
176
|
name,
|
113
177
|
)
|
@@ -117,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None:
|
|
117
181
|
"""Warn the user when they use a deprecated feature."""
|
118
182
|
log(
|
119
183
|
WARN,
|
120
|
-
"""
|
121
|
-
DEPRECATED FEATURE: %s
|
184
|
+
"""DEPRECATED FEATURE: %s
|
122
185
|
|
123
|
-
|
124
|
-
|
186
|
+
This is a deprecated feature. It will be removed
|
187
|
+
entirely in future versions of Flower.
|
125
188
|
""",
|
126
189
|
name,
|
127
190
|
)
|
@@ -0,0 +1,140 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Helper functions to load objects from a reference."""
|
16
|
+
|
17
|
+
|
18
|
+
import ast
|
19
|
+
import importlib
|
20
|
+
from importlib.util import find_spec
|
21
|
+
from typing import Any, Optional, Tuple, Type
|
22
|
+
|
23
|
+
OBJECT_REF_HELP_STR = """
|
24
|
+
\n\nThe object reference string should have the form <module>:<attribute>. Valid
|
25
|
+
examples include `client:app` and `project.package.module:wrapper.app`. It must
|
26
|
+
refer to a module on the PYTHONPATH and the module needs to have the specified
|
27
|
+
attribute.
|
28
|
+
"""
|
29
|
+
|
30
|
+
|
31
|
+
def validate(
|
32
|
+
module_attribute_str: str,
|
33
|
+
) -> Tuple[bool, Optional[str]]:
|
34
|
+
"""Validate object reference.
|
35
|
+
|
36
|
+
The object reference string should have the form <module>:<attribute>. Valid
|
37
|
+
examples include `client:app` and `project.package.module:wrapper.app`. It must
|
38
|
+
refer to a module on the PYTHONPATH and the module needs to have the specified
|
39
|
+
attribute.
|
40
|
+
|
41
|
+
Returns
|
42
|
+
-------
|
43
|
+
Tuple[bool, Optional[str]]
|
44
|
+
A boolean indicating whether an object reference is valid and
|
45
|
+
the reason why it might not be.
|
46
|
+
"""
|
47
|
+
module_str, _, attributes_str = module_attribute_str.partition(":")
|
48
|
+
if not module_str:
|
49
|
+
return (
|
50
|
+
False,
|
51
|
+
f"Missing module in {module_attribute_str}{OBJECT_REF_HELP_STR}",
|
52
|
+
)
|
53
|
+
if not attributes_str:
|
54
|
+
return (
|
55
|
+
False,
|
56
|
+
f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}",
|
57
|
+
)
|
58
|
+
|
59
|
+
# Load module
|
60
|
+
module = find_spec(module_str)
|
61
|
+
if module and module.origin:
|
62
|
+
if not _find_attribute_in_module(module.origin, attributes_str):
|
63
|
+
return (
|
64
|
+
False,
|
65
|
+
f"Unable to find attribute {attributes_str} in module {module_str}"
|
66
|
+
f"{OBJECT_REF_HELP_STR}",
|
67
|
+
)
|
68
|
+
return (True, None)
|
69
|
+
|
70
|
+
return (
|
71
|
+
False,
|
72
|
+
f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
|
73
|
+
)
|
74
|
+
|
75
|
+
|
76
|
+
def load_app(
|
77
|
+
module_attribute_str: str,
|
78
|
+
error_type: Type[Exception],
|
79
|
+
) -> Any:
|
80
|
+
"""Return the object specified in a module attribute string.
|
81
|
+
|
82
|
+
The module/attribute string should have the form <module>:<attribute>. Valid
|
83
|
+
examples include `client:app` and `project.package.module:wrapper.app`. It must
|
84
|
+
refer to a module on the PYTHONPATH, the module needs to have the specified
|
85
|
+
attribute.
|
86
|
+
"""
|
87
|
+
valid, error_msg = validate(module_attribute_str)
|
88
|
+
if not valid and error_msg:
|
89
|
+
raise error_type(error_msg) from None
|
90
|
+
|
91
|
+
module_str, _, attributes_str = module_attribute_str.partition(":")
|
92
|
+
|
93
|
+
try:
|
94
|
+
module = importlib.import_module(module_str)
|
95
|
+
except ModuleNotFoundError:
|
96
|
+
raise error_type(
|
97
|
+
f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
|
98
|
+
) from None
|
99
|
+
|
100
|
+
# Recursively load attribute
|
101
|
+
attribute = module
|
102
|
+
try:
|
103
|
+
for attribute_str in attributes_str.split("."):
|
104
|
+
attribute = getattr(attribute, attribute_str)
|
105
|
+
except AttributeError:
|
106
|
+
raise error_type(
|
107
|
+
f"Unable to load attribute {attributes_str} from module {module_str}"
|
108
|
+
f"{OBJECT_REF_HELP_STR}",
|
109
|
+
) from None
|
110
|
+
|
111
|
+
return attribute
|
112
|
+
|
113
|
+
|
114
|
+
def _find_attribute_in_module(file_path: str, attribute_name: str) -> bool:
|
115
|
+
"""Check if attribute_name exists in module's abstract symbolic tree."""
|
116
|
+
with open(file_path, encoding="utf-8") as file:
|
117
|
+
node = ast.parse(file.read(), filename=file_path)
|
118
|
+
|
119
|
+
for n in ast.walk(node):
|
120
|
+
if isinstance(n, ast.Assign):
|
121
|
+
for target in n.targets:
|
122
|
+
if isinstance(target, ast.Name) and target.id == attribute_name:
|
123
|
+
return True
|
124
|
+
if _is_module_in_all(attribute_name, target, n):
|
125
|
+
return True
|
126
|
+
return False
|
127
|
+
|
128
|
+
|
129
|
+
def _is_module_in_all(attribute_name: str, target: ast.expr, n: ast.Assign) -> bool:
|
130
|
+
"""Now check if attribute_name is in __all__."""
|
131
|
+
if isinstance(target, ast.Name) and target.id == "__all__":
|
132
|
+
if isinstance(n.value, ast.List):
|
133
|
+
for elt in n.value.elts:
|
134
|
+
if isinstance(elt, ast.Str) and elt.s == attribute_name:
|
135
|
+
return True
|
136
|
+
elif isinstance(n.value, ast.Tuple):
|
137
|
+
for elt in n.value.elts:
|
138
|
+
if isinstance(elt, ast.Str) and elt.s == attribute_name:
|
139
|
+
return True
|
140
|
+
return False
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Utility functions for performing operations on Numpy NDArrays."""
|
16
16
|
|
17
17
|
|
18
|
-
from typing import Any, List, Tuple
|
18
|
+
from typing import Any, List, Tuple, Union
|
19
19
|
|
20
20
|
import numpy as np
|
21
21
|
from numpy.typing import DTypeLike, NDArray
|
@@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray
|
|
68
68
|
|
69
69
|
|
70
70
|
def parameters_multiply(
|
71
|
-
parameters: List[NDArray[Any]], multiplier: int
|
71
|
+
parameters: List[NDArray[Any]], multiplier: Union[int, float]
|
72
72
|
) -> List[NDArray[Any]]:
|
73
|
-
"""Multiply parameters by an integer multiplier."""
|
73
|
+
"""Multiply parameters by an integer/float multiplier."""
|
74
74
|
return [parameters[idx] * multiplier for idx in range(len(parameters))]
|
75
75
|
|
76
76
|
|
77
77
|
def parameters_divide(
|
78
|
-
parameters: List[NDArray[Any]], divisor: int
|
78
|
+
parameters: List[NDArray[Any]], divisor: Union[int, float]
|
79
79
|
) -> List[NDArray[Any]]:
|
80
|
-
"""Divide weight by an integer divisor."""
|
80
|
+
"""Divide weight by an integer/float divisor."""
|
81
81
|
return [parameters[idx] / divisor for idx in range(len(parameters))]
|