flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/cli/flower_toml.py +4 -48
- flwr/cli/new/new.py +6 -3
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
- flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +39 -5
- flwr/client/client_app.py +1 -47
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
- flwr/common/grpc.py +3 -3
- flwr/common/logger.py +78 -15
- flwr/common/object_ref.py +140 -0
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
- flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
- flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
- flwr/server/compat/app.py +2 -1
- flwr/server/driver/grpc_driver.py +4 -4
- flwr/server/history.py +22 -15
- flwr/server/run_serverapp.py +22 -4
- flwr/server/server.py +27 -23
- flwr/server/server_app.py +1 -47
- flwr/server/server_config.py +9 -0
- flwr/server/strategy/fedavg.py +2 -0
- flwr/server/superlink/fleet/vce/vce_api.py +9 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +3 -0
- flwr/server/workflow/constant.py +32 -0
- flwr/server/workflow/default_workflows.py +52 -57
- flwr/server/workflow/secure_aggregation/__init__.py +24 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -72,13 +72,14 @@ class SecAggPlusState:
|
|
72
72
|
|
73
73
|
current_stage: str = Stage.UNMASK
|
74
74
|
|
75
|
-
|
75
|
+
nid: int = 0
|
76
76
|
sample_num: int = 0
|
77
77
|
share_num: int = 0
|
78
78
|
threshold: int = 0
|
79
79
|
clipping_range: float = 0.0
|
80
80
|
target_range: int = 0
|
81
81
|
mod_range: int = 0
|
82
|
+
max_weight: float = 0.0
|
82
83
|
|
83
84
|
# Secret key (sk) and public key (pk)
|
84
85
|
sk1: bytes = b""
|
@@ -173,12 +174,13 @@ def secaggplus_mod(
|
|
173
174
|
|
174
175
|
# Execute
|
175
176
|
if state.current_stage == Stage.SETUP:
|
177
|
+
state.nid = msg.metadata.dst_node_id
|
176
178
|
res = _setup(state, configs)
|
177
179
|
elif state.current_stage == Stage.SHARE_KEYS:
|
178
180
|
res = _share_keys(state, configs)
|
179
|
-
elif state.current_stage == Stage.
|
181
|
+
elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
|
180
182
|
fit = _get_fit_fn(msg, ctxt, call_next)
|
181
|
-
res =
|
183
|
+
res = _collect_masked_vectors(state, configs, fit)
|
182
184
|
elif state.current_stage == Stage.UNMASK:
|
183
185
|
res = _unmask(state, configs)
|
184
186
|
else:
|
@@ -197,7 +199,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
197
199
|
# Check the existence of Config.STAGE
|
198
200
|
if Key.STAGE not in configs:
|
199
201
|
raise KeyError(
|
200
|
-
f"The required key '{Key.STAGE}' is missing from the
|
202
|
+
f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
|
201
203
|
)
|
202
204
|
|
203
205
|
# Check the value type of the Config.STAGE
|
@@ -213,7 +215,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
213
215
|
if current_stage != Stage.UNMASK:
|
214
216
|
log(WARNING, "Restart from the setup stage")
|
215
217
|
# If stage is not "setup",
|
216
|
-
# the stage from
|
218
|
+
# the stage from configs should be the expected next stage
|
217
219
|
else:
|
218
220
|
stages = Stage.all()
|
219
221
|
expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
|
@@ -227,11 +229,10 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
227
229
|
# pylint: disable-next=too-many-branches
|
228
230
|
def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
229
231
|
"""Check the validity of the configs."""
|
230
|
-
# Check
|
232
|
+
# Check configs for the setup stage
|
231
233
|
if stage == Stage.SETUP:
|
232
234
|
key_type_pairs = [
|
233
235
|
(Key.SAMPLE_NUMBER, int),
|
234
|
-
(Key.SECURE_ID, int),
|
235
236
|
(Key.SHARE_NUMBER, int),
|
236
237
|
(Key.THRESHOLD, int),
|
237
238
|
(Key.CLIPPING_RANGE, float),
|
@@ -242,7 +243,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
242
243
|
if key not in configs:
|
243
244
|
raise KeyError(
|
244
245
|
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
245
|
-
"missing from the
|
246
|
+
"missing from the ConfigsRecord."
|
246
247
|
)
|
247
248
|
# Bool is a subclass of int in Python,
|
248
249
|
# so `isinstance(v, int)` will return True even if v is a boolean.
|
@@ -265,7 +266,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
265
266
|
f"Stage {Stage.SHARE_KEYS}: "
|
266
267
|
f"the value for the key '{key}' must be a list of two bytes."
|
267
268
|
)
|
268
|
-
elif stage == Stage.
|
269
|
+
elif stage == Stage.COLLECT_MASKED_VECTORS:
|
269
270
|
key_type_pairs = [
|
270
271
|
(Key.CIPHERTEXT_LIST, bytes),
|
271
272
|
(Key.SOURCE_LIST, int),
|
@@ -273,9 +274,9 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
273
274
|
for key, expected_type in key_type_pairs:
|
274
275
|
if key not in configs:
|
275
276
|
raise KeyError(
|
276
|
-
f"Stage {Stage.
|
277
|
+
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
277
278
|
f"the required key '{key}' is "
|
278
|
-
"missing from the
|
279
|
+
"missing from the ConfigsRecord."
|
279
280
|
)
|
280
281
|
if not isinstance(configs[key], list) or any(
|
281
282
|
elm
|
@@ -284,21 +285,21 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
284
285
|
if type(elm) is not expected_type
|
285
286
|
):
|
286
287
|
raise TypeError(
|
287
|
-
f"Stage {Stage.
|
288
|
+
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
288
289
|
f"the value for the key '{key}' "
|
289
290
|
f"must be of type List[{expected_type.__name__}]"
|
290
291
|
)
|
291
292
|
elif stage == Stage.UNMASK:
|
292
293
|
key_type_pairs = [
|
293
|
-
(Key.
|
294
|
-
(Key.
|
294
|
+
(Key.ACTIVE_NODE_ID_LIST, int),
|
295
|
+
(Key.DEAD_NODE_ID_LIST, int),
|
295
296
|
]
|
296
297
|
for key, expected_type in key_type_pairs:
|
297
298
|
if key not in configs:
|
298
299
|
raise KeyError(
|
299
300
|
f"Stage {Stage.UNMASK}: "
|
300
301
|
f"the required key '{key}' is "
|
301
|
-
"missing from the
|
302
|
+
"missing from the ConfigsRecord."
|
302
303
|
)
|
303
304
|
if not isinstance(configs[key], list) or any(
|
304
305
|
elm
|
@@ -321,20 +322,20 @@ def _setup(
|
|
321
322
|
# Assigning parameter values to object fields
|
322
323
|
sec_agg_param_dict = configs
|
323
324
|
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
324
|
-
|
325
|
-
log(INFO, "Client %d: starting stage 0...", state.sid)
|
325
|
+
log(INFO, "Node %d: starting stage 0...", state.nid)
|
326
326
|
|
327
327
|
state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
|
328
328
|
state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
|
329
329
|
state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
|
330
330
|
state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
|
331
331
|
state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
|
332
|
+
state.max_weight = cast(float, sec_agg_param_dict[Key.MAX_WEIGHT])
|
332
333
|
|
333
|
-
# Dictionaries containing
|
334
|
+
# Dictionaries containing node IDs as keys
|
334
335
|
# and their respective secret shares as values.
|
335
336
|
state.rd_seed_share_dict = {}
|
336
337
|
state.sk1_share_dict = {}
|
337
|
-
# Dictionary containing
|
338
|
+
# Dictionary containing node IDs as keys
|
338
339
|
# and their respective shared secrets (with this client) as values.
|
339
340
|
state.ss2_dict = {}
|
340
341
|
|
@@ -346,7 +347,7 @@ def _setup(
|
|
346
347
|
|
347
348
|
state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
|
348
349
|
state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
|
349
|
-
log(INFO, "
|
350
|
+
log(INFO, "Node %d: stage 0 completes. uploading public keys...", state.nid)
|
350
351
|
return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
|
351
352
|
|
352
353
|
|
@@ -356,7 +357,7 @@ def _share_keys(
|
|
356
357
|
) -> Dict[str, ConfigsRecordValues]:
|
357
358
|
named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
|
358
359
|
key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
|
359
|
-
log(INFO, "
|
360
|
+
log(INFO, "Node %d: starting stage 1...", state.nid)
|
360
361
|
state.public_keys_dict = key_dict
|
361
362
|
|
362
363
|
# Check if the size is larger than threshold
|
@@ -373,8 +374,8 @@ def _share_keys(
|
|
373
374
|
|
374
375
|
# Check if public keys of this client are correct in the dictionary
|
375
376
|
if (
|
376
|
-
state.public_keys_dict[state.
|
377
|
-
or state.public_keys_dict[state.
|
377
|
+
state.public_keys_dict[state.nid][0] != state.pk1
|
378
|
+
or state.public_keys_dict[state.nid][1] != state.pk2
|
378
379
|
):
|
379
380
|
raise ValueError(
|
380
381
|
"Own public keys are displayed in dict incorrectly, should not happen!"
|
@@ -390,35 +391,35 @@ def _share_keys(
|
|
390
391
|
srcs, dsts, ciphertexts = [], [], []
|
391
392
|
|
392
393
|
# Distribute shares
|
393
|
-
for idx, (
|
394
|
-
if
|
395
|
-
state.rd_seed_share_dict[state.
|
396
|
-
state.sk1_share_dict[state.
|
394
|
+
for idx, (nid, (_, pk2)) in enumerate(state.public_keys_dict.items()):
|
395
|
+
if nid == state.nid:
|
396
|
+
state.rd_seed_share_dict[state.nid] = b_shares[idx]
|
397
|
+
state.sk1_share_dict[state.nid] = sk1_shares[idx]
|
397
398
|
else:
|
398
399
|
shared_key = generate_shared_key(
|
399
400
|
bytes_to_private_key(state.sk2),
|
400
401
|
bytes_to_public_key(pk2),
|
401
402
|
)
|
402
|
-
state.ss2_dict[
|
403
|
+
state.ss2_dict[nid] = shared_key
|
403
404
|
plaintext = share_keys_plaintext_concat(
|
404
|
-
state.
|
405
|
+
state.nid, nid, b_shares[idx], sk1_shares[idx]
|
405
406
|
)
|
406
407
|
ciphertext = encrypt(shared_key, plaintext)
|
407
|
-
srcs.append(state.
|
408
|
-
dsts.append(
|
408
|
+
srcs.append(state.nid)
|
409
|
+
dsts.append(nid)
|
409
410
|
ciphertexts.append(ciphertext)
|
410
411
|
|
411
|
-
log(INFO, "
|
412
|
+
log(INFO, "Node %d: stage 1 completes. uploading key shares...", state.nid)
|
412
413
|
return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
|
413
414
|
|
414
415
|
|
415
416
|
# pylint: disable-next=too-many-locals
|
416
|
-
def
|
417
|
+
def _collect_masked_vectors(
|
417
418
|
state: SecAggPlusState,
|
418
419
|
configs: ConfigsRecord,
|
419
420
|
fit: Callable[[], FitRes],
|
420
421
|
) -> Dict[str, ConfigsRecordValues]:
|
421
|
-
log(INFO, "
|
422
|
+
log(INFO, "Node %d: starting stage 2...", state.nid)
|
422
423
|
available_clients: List[int] = []
|
423
424
|
ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
|
424
425
|
srcs = cast(List[int], configs[Key.SOURCE_LIST])
|
@@ -435,29 +436,45 @@ def _collect_masked_input(
|
|
435
436
|
available_clients.append(src)
|
436
437
|
if src != actual_src:
|
437
438
|
raise ValueError(
|
438
|
-
f"
|
439
|
+
f"Node {state.nid}: received ciphertext "
|
439
440
|
f"from {actual_src} instead of {src}."
|
440
441
|
)
|
441
|
-
if dst != state.
|
442
|
+
if dst != state.nid:
|
442
443
|
raise ValueError(
|
443
|
-
f"
|
444
|
-
f"for
|
444
|
+
f"Node {state.nid}: received an encrypted message"
|
445
|
+
f"for Node {dst} from Node {src}."
|
445
446
|
)
|
446
447
|
state.rd_seed_share_dict[src] = rd_seed_share
|
447
448
|
state.sk1_share_dict[src] = sk1_share
|
448
449
|
|
449
450
|
# Fit client
|
450
451
|
fit_res = fit()
|
451
|
-
|
452
|
+
if len(fit_res.metrics) > 0:
|
453
|
+
log(
|
454
|
+
WARNING,
|
455
|
+
"The metrics in FitRes will not be preserved or sent to the server.",
|
456
|
+
)
|
457
|
+
ratio = fit_res.num_examples / state.max_weight
|
458
|
+
if ratio > 1:
|
459
|
+
log(
|
460
|
+
WARNING,
|
461
|
+
"Potential overflow warning: the provided weight (%s) exceeds the specified"
|
462
|
+
" max_weight (%s). This may lead to overflow issues.",
|
463
|
+
fit_res.num_examples,
|
464
|
+
state.max_weight,
|
465
|
+
)
|
466
|
+
q_ratio = round(ratio * state.target_range)
|
467
|
+
dq_ratio = q_ratio / state.target_range
|
468
|
+
|
452
469
|
parameters = parameters_to_ndarrays(fit_res.parameters)
|
470
|
+
parameters = parameters_multiply(parameters, dq_ratio)
|
453
471
|
|
454
472
|
# Quantize parameter update (vector)
|
455
473
|
quantized_parameters = quantize(
|
456
474
|
parameters, state.clipping_range, state.target_range
|
457
475
|
)
|
458
476
|
|
459
|
-
quantized_parameters =
|
460
|
-
quantized_parameters = factor_combine(parameters_factor, quantized_parameters)
|
477
|
+
quantized_parameters = factor_combine(q_ratio, quantized_parameters)
|
461
478
|
|
462
479
|
dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters]
|
463
480
|
|
@@ -465,14 +482,14 @@ def _collect_masked_input(
|
|
465
482
|
private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list)
|
466
483
|
quantized_parameters = parameters_addition(quantized_parameters, private_mask)
|
467
484
|
|
468
|
-
for
|
485
|
+
for node_id in available_clients:
|
469
486
|
# Add pairwise masks
|
470
487
|
shared_key = generate_shared_key(
|
471
488
|
bytes_to_private_key(state.sk1),
|
472
|
-
bytes_to_public_key(state.public_keys_dict[
|
489
|
+
bytes_to_public_key(state.public_keys_dict[node_id][0]),
|
473
490
|
)
|
474
491
|
pairwise_mask = pseudo_rand_gen(shared_key, state.mod_range, dimensions_list)
|
475
|
-
if state.
|
492
|
+
if state.nid > node_id:
|
476
493
|
quantized_parameters = parameters_addition(
|
477
494
|
quantized_parameters, pairwise_mask
|
478
495
|
)
|
@@ -483,7 +500,7 @@ def _collect_masked_input(
|
|
483
500
|
|
484
501
|
# Take mod of final weight update vector and return to server
|
485
502
|
quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
|
486
|
-
log(INFO, "
|
503
|
+
log(INFO, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
|
487
504
|
return {
|
488
505
|
Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
|
489
506
|
}
|
@@ -492,20 +509,19 @@ def _collect_masked_input(
|
|
492
509
|
def _unmask(
|
493
510
|
state: SecAggPlusState, configs: ConfigsRecord
|
494
511
|
) -> Dict[str, ConfigsRecordValues]:
|
495
|
-
log(INFO, "
|
512
|
+
log(INFO, "Node %d: starting stage 3...", state.nid)
|
496
513
|
|
497
|
-
|
498
|
-
|
499
|
-
# Send private mask seed share for every avaliable client (including
|
514
|
+
active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST])
|
515
|
+
dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST])
|
516
|
+
# Send private mask seed share for every avaliable client (including itself)
|
500
517
|
# Send first private key share for building pairwise mask for every dropped client
|
501
|
-
if len(
|
518
|
+
if len(active_nids) < state.threshold:
|
502
519
|
raise ValueError("Available neighbours number smaller than threshold")
|
503
520
|
|
504
|
-
|
505
|
-
|
506
|
-
shares += [state.rd_seed_share_dict[
|
507
|
-
|
508
|
-
shares += [state.sk1_share_dict[sid] for sid in dead_sids]
|
521
|
+
all_nids, shares = [], []
|
522
|
+
all_nids = active_nids + dead_nids
|
523
|
+
shares += [state.rd_seed_share_dict[nid] for nid in active_nids]
|
524
|
+
shares += [state.sk1_share_dict[nid] for nid in dead_nids]
|
509
525
|
|
510
|
-
log(INFO, "
|
511
|
-
return {Key.
|
526
|
+
log(INFO, "Node %d: stage 3 completes. uploading key shares...", state.nid)
|
527
|
+
return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}
|
flwr/common/grpc.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Utility functions for gRPC."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import
|
18
|
+
from logging import DEBUG
|
19
19
|
from typing import Optional
|
20
20
|
|
21
21
|
import grpc
|
@@ -49,12 +49,12 @@ def create_channel(
|
|
49
49
|
|
50
50
|
if insecure:
|
51
51
|
channel = grpc.insecure_channel(server_address, options=channel_options)
|
52
|
-
log(
|
52
|
+
log(DEBUG, "Opened insecure gRPC connection (no certificates were passed)")
|
53
53
|
else:
|
54
54
|
ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates)
|
55
55
|
channel = grpc.secure_channel(
|
56
56
|
server_address, ssl_channel_credentials, options=channel_options
|
57
57
|
)
|
58
|
-
log(
|
58
|
+
log(DEBUG, "Opened secure gRPC connection using certificates")
|
59
59
|
|
60
60
|
return channel
|
flwr/common/logger.py
CHANGED
@@ -18,21 +18,86 @@
|
|
18
18
|
import logging
|
19
19
|
from logging import WARN, LogRecord
|
20
20
|
from logging.handlers import HTTPHandler
|
21
|
-
from typing import Any, Dict, Optional, Tuple
|
21
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, TextIO, Tuple
|
22
22
|
|
23
23
|
# Create logger
|
24
24
|
LOGGER_NAME = "flwr"
|
25
25
|
FLOWER_LOGGER = logging.getLogger(LOGGER_NAME)
|
26
26
|
FLOWER_LOGGER.setLevel(logging.DEBUG)
|
27
27
|
|
28
|
-
|
29
|
-
"
|
30
|
-
|
28
|
+
LOG_COLORS = {
|
29
|
+
"DEBUG": "\033[94m", # Blue
|
30
|
+
"INFO": "\033[92m", # Green
|
31
|
+
"WARNING": "\033[93m", # Yellow
|
32
|
+
"ERROR": "\033[91m", # Red
|
33
|
+
"CRITICAL": "\033[95m", # Magenta
|
34
|
+
"RESET": "\033[0m", # Reset to default
|
35
|
+
}
|
36
|
+
|
37
|
+
if TYPE_CHECKING:
|
38
|
+
StreamHandler = logging.StreamHandler[Any]
|
39
|
+
else:
|
40
|
+
StreamHandler = logging.StreamHandler
|
41
|
+
|
42
|
+
|
43
|
+
class ConsoleHandler(StreamHandler):
|
44
|
+
"""Console handler that allows configurable formatting."""
|
45
|
+
|
46
|
+
def __init__(
|
47
|
+
self,
|
48
|
+
timestamps: bool = False,
|
49
|
+
json: bool = False,
|
50
|
+
colored: bool = True,
|
51
|
+
stream: Optional[TextIO] = None,
|
52
|
+
) -> None:
|
53
|
+
super().__init__(stream)
|
54
|
+
self.timestamps = timestamps
|
55
|
+
self.json = json
|
56
|
+
self.colored = colored
|
57
|
+
|
58
|
+
def emit(self, record: LogRecord) -> None:
|
59
|
+
"""Emit a record."""
|
60
|
+
if self.json:
|
61
|
+
record.message = record.getMessage().replace("\t", "").strip()
|
62
|
+
|
63
|
+
# Check if the message is empty
|
64
|
+
if not record.message:
|
65
|
+
return
|
66
|
+
|
67
|
+
super().emit(record)
|
68
|
+
|
69
|
+
def format(self, record: LogRecord) -> str:
|
70
|
+
"""Format function that adds colors to log level."""
|
71
|
+
seperator = " " * (8 - len(record.levelname))
|
72
|
+
if self.json:
|
73
|
+
log_fmt = "{lvl='%(levelname)s', time='%(asctime)s', msg='%(message)s'}"
|
74
|
+
else:
|
75
|
+
log_fmt = (
|
76
|
+
f"{LOG_COLORS[record.levelname] if self.colored else ''}"
|
77
|
+
f"%(levelname)s {'%(asctime)s' if self.timestamps else ''}"
|
78
|
+
f"{LOG_COLORS['RESET'] if self.colored else ''}"
|
79
|
+
f": {seperator} %(message)s"
|
80
|
+
)
|
81
|
+
formatter = logging.Formatter(log_fmt)
|
82
|
+
return formatter.format(record)
|
83
|
+
|
84
|
+
|
85
|
+
def update_console_handler(level: int, timestamps: bool, colored: bool) -> None:
|
86
|
+
"""Update the logging handler."""
|
87
|
+
for handler in logging.getLogger(LOGGER_NAME).handlers:
|
88
|
+
if isinstance(handler, ConsoleHandler):
|
89
|
+
handler.setLevel(level)
|
90
|
+
handler.timestamps = timestamps
|
91
|
+
handler.colored = colored
|
92
|
+
|
31
93
|
|
32
94
|
# Configure console logger
|
33
|
-
console_handler =
|
34
|
-
|
35
|
-
|
95
|
+
console_handler = ConsoleHandler(
|
96
|
+
timestamps=False,
|
97
|
+
json=False,
|
98
|
+
colored=True,
|
99
|
+
)
|
100
|
+
console_handler.setLevel(logging.INFO)
|
36
101
|
FLOWER_LOGGER.addHandler(console_handler)
|
37
102
|
|
38
103
|
|
@@ -103,11 +168,10 @@ def warn_experimental_feature(name: str) -> None:
|
|
103
168
|
"""Warn the user when they use an experimental feature."""
|
104
169
|
log(
|
105
170
|
WARN,
|
106
|
-
"""
|
107
|
-
EXPERIMENTAL FEATURE: %s
|
171
|
+
"""EXPERIMENTAL FEATURE: %s
|
108
172
|
|
109
|
-
|
110
|
-
|
173
|
+
This is an experimental feature. It could change significantly or be removed
|
174
|
+
entirely in future versions of Flower.
|
111
175
|
""",
|
112
176
|
name,
|
113
177
|
)
|
@@ -117,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None:
|
|
117
181
|
"""Warn the user when they use a deprecated feature."""
|
118
182
|
log(
|
119
183
|
WARN,
|
120
|
-
"""
|
121
|
-
DEPRECATED FEATURE: %s
|
184
|
+
"""DEPRECATED FEATURE: %s
|
122
185
|
|
123
|
-
|
124
|
-
|
186
|
+
This is a deprecated feature. It will be removed
|
187
|
+
entirely in future versions of Flower.
|
125
188
|
""",
|
126
189
|
name,
|
127
190
|
)
|
@@ -0,0 +1,140 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Helper functions to load objects from a reference."""
|
16
|
+
|
17
|
+
|
18
|
+
import ast
|
19
|
+
import importlib
|
20
|
+
from importlib.util import find_spec
|
21
|
+
from typing import Any, Optional, Tuple, Type
|
22
|
+
|
23
|
+
OBJECT_REF_HELP_STR = """
|
24
|
+
\n\nThe object reference string should have the form <module>:<attribute>. Valid
|
25
|
+
examples include `client:app` and `project.package.module:wrapper.app`. It must
|
26
|
+
refer to a module on the PYTHONPATH and the module needs to have the specified
|
27
|
+
attribute.
|
28
|
+
"""
|
29
|
+
|
30
|
+
|
31
|
+
def validate(
|
32
|
+
module_attribute_str: str,
|
33
|
+
) -> Tuple[bool, Optional[str]]:
|
34
|
+
"""Validate object reference.
|
35
|
+
|
36
|
+
The object reference string should have the form <module>:<attribute>. Valid
|
37
|
+
examples include `client:app` and `project.package.module:wrapper.app`. It must
|
38
|
+
refer to a module on the PYTHONPATH and the module needs to have the specified
|
39
|
+
attribute.
|
40
|
+
|
41
|
+
Returns
|
42
|
+
-------
|
43
|
+
Tuple[bool, Optional[str]]
|
44
|
+
A boolean indicating whether an object reference is valid and
|
45
|
+
the reason why it might not be.
|
46
|
+
"""
|
47
|
+
module_str, _, attributes_str = module_attribute_str.partition(":")
|
48
|
+
if not module_str:
|
49
|
+
return (
|
50
|
+
False,
|
51
|
+
f"Missing module in {module_attribute_str}{OBJECT_REF_HELP_STR}",
|
52
|
+
)
|
53
|
+
if not attributes_str:
|
54
|
+
return (
|
55
|
+
False,
|
56
|
+
f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}",
|
57
|
+
)
|
58
|
+
|
59
|
+
# Load module
|
60
|
+
module = find_spec(module_str)
|
61
|
+
if module and module.origin:
|
62
|
+
if not _find_attribute_in_module(module.origin, attributes_str):
|
63
|
+
return (
|
64
|
+
False,
|
65
|
+
f"Unable to find attribute {attributes_str} in module {module_str}"
|
66
|
+
f"{OBJECT_REF_HELP_STR}",
|
67
|
+
)
|
68
|
+
return (True, None)
|
69
|
+
|
70
|
+
return (
|
71
|
+
False,
|
72
|
+
f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
|
73
|
+
)
|
74
|
+
|
75
|
+
|
76
|
+
def load_app(
|
77
|
+
module_attribute_str: str,
|
78
|
+
error_type: Type[Exception],
|
79
|
+
) -> Any:
|
80
|
+
"""Return the object specified in a module attribute string.
|
81
|
+
|
82
|
+
The module/attribute string should have the form <module>:<attribute>. Valid
|
83
|
+
examples include `client:app` and `project.package.module:wrapper.app`. It must
|
84
|
+
refer to a module on the PYTHONPATH, the module needs to have the specified
|
85
|
+
attribute.
|
86
|
+
"""
|
87
|
+
valid, error_msg = validate(module_attribute_str)
|
88
|
+
if not valid and error_msg:
|
89
|
+
raise error_type(error_msg) from None
|
90
|
+
|
91
|
+
module_str, _, attributes_str = module_attribute_str.partition(":")
|
92
|
+
|
93
|
+
try:
|
94
|
+
module = importlib.import_module(module_str)
|
95
|
+
except ModuleNotFoundError:
|
96
|
+
raise error_type(
|
97
|
+
f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
|
98
|
+
) from None
|
99
|
+
|
100
|
+
# Recursively load attribute
|
101
|
+
attribute = module
|
102
|
+
try:
|
103
|
+
for attribute_str in attributes_str.split("."):
|
104
|
+
attribute = getattr(attribute, attribute_str)
|
105
|
+
except AttributeError:
|
106
|
+
raise error_type(
|
107
|
+
f"Unable to load attribute {attributes_str} from module {module_str}"
|
108
|
+
f"{OBJECT_REF_HELP_STR}",
|
109
|
+
) from None
|
110
|
+
|
111
|
+
return attribute
|
112
|
+
|
113
|
+
|
114
|
+
def _find_attribute_in_module(file_path: str, attribute_name: str) -> bool:
|
115
|
+
"""Check if attribute_name exists in module's abstract symbolic tree."""
|
116
|
+
with open(file_path, encoding="utf-8") as file:
|
117
|
+
node = ast.parse(file.read(), filename=file_path)
|
118
|
+
|
119
|
+
for n in ast.walk(node):
|
120
|
+
if isinstance(n, ast.Assign):
|
121
|
+
for target in n.targets:
|
122
|
+
if isinstance(target, ast.Name) and target.id == attribute_name:
|
123
|
+
return True
|
124
|
+
if _is_module_in_all(attribute_name, target, n):
|
125
|
+
return True
|
126
|
+
return False
|
127
|
+
|
128
|
+
|
129
|
+
def _is_module_in_all(attribute_name: str, target: ast.expr, n: ast.Assign) -> bool:
|
130
|
+
"""Now check if attribute_name is in __all__."""
|
131
|
+
if isinstance(target, ast.Name) and target.id == "__all__":
|
132
|
+
if isinstance(n.value, ast.List):
|
133
|
+
for elt in n.value.elts:
|
134
|
+
if isinstance(elt, ast.Str) and elt.s == attribute_name:
|
135
|
+
return True
|
136
|
+
elif isinstance(n.value, ast.Tuple):
|
137
|
+
for elt in n.value.elts:
|
138
|
+
if isinstance(elt, ast.Str) and elt.s == attribute_name:
|
139
|
+
return True
|
140
|
+
return False
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Utility functions for performing operations on Numpy NDArrays."""
|
16
16
|
|
17
17
|
|
18
|
-
from typing import Any, List, Tuple
|
18
|
+
from typing import Any, List, Tuple, Union
|
19
19
|
|
20
20
|
import numpy as np
|
21
21
|
from numpy.typing import DTypeLike, NDArray
|
@@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray
|
|
68
68
|
|
69
69
|
|
70
70
|
def parameters_multiply(
|
71
|
-
parameters: List[NDArray[Any]], multiplier: int
|
71
|
+
parameters: List[NDArray[Any]], multiplier: Union[int, float]
|
72
72
|
) -> List[NDArray[Any]]:
|
73
|
-
"""Multiply parameters by an integer multiplier."""
|
73
|
+
"""Multiply parameters by an integer/float multiplier."""
|
74
74
|
return [parameters[idx] * multiplier for idx in range(len(parameters))]
|
75
75
|
|
76
76
|
|
77
77
|
def parameters_divide(
|
78
|
-
parameters: List[NDArray[Any]], divisor: int
|
78
|
+
parameters: List[NDArray[Any]], divisor: Union[int, float]
|
79
79
|
) -> List[NDArray[Any]]:
|
80
|
-
"""Divide weight by an integer divisor."""
|
80
|
+
"""Divide weight by an integer/float divisor."""
|
81
81
|
return [parameters[idx] / divisor for idx in range(len(parameters))]
|