flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
|
@@ -20,15 +20,15 @@ from dataclasses import dataclass, field
|
|
|
20
20
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
21
|
from typing import Optional, Union, cast
|
|
22
22
|
|
|
23
|
-
import flwr.common.
|
|
23
|
+
import flwr.common.recorddict_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
25
|
-
|
|
25
|
+
ConfigRecord,
|
|
26
26
|
Context,
|
|
27
27
|
FitRes,
|
|
28
28
|
Message,
|
|
29
29
|
MessageType,
|
|
30
30
|
NDArrays,
|
|
31
|
-
|
|
31
|
+
RecordDict,
|
|
32
32
|
bytes_to_ndarray,
|
|
33
33
|
log,
|
|
34
34
|
ndarrays_to_parameters,
|
|
@@ -55,7 +55,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
|
|
|
55
55
|
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
|
|
56
56
|
from flwr.server.client_proxy import ClientProxy
|
|
57
57
|
from flwr.server.compat.legacy_context import LegacyContext
|
|
58
|
-
from flwr.server.
|
|
58
|
+
from flwr.server.grid import Grid
|
|
59
59
|
|
|
60
60
|
from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
|
|
61
61
|
from ..constant import Key as WorkflowKey
|
|
@@ -66,7 +66,7 @@ class WorkflowState: # pylint: disable=R0902
|
|
|
66
66
|
"""The state of the SecAgg+ protocol."""
|
|
67
67
|
|
|
68
68
|
nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
|
|
69
|
-
nid_to_fitins: dict[int,
|
|
69
|
+
nid_to_fitins: dict[int, RecordDict] = field(default_factory=dict)
|
|
70
70
|
sampled_node_ids: set[int] = field(default_factory=set)
|
|
71
71
|
active_node_ids: set[int] = field(default_factory=set)
|
|
72
72
|
num_shares: int = 0
|
|
@@ -186,7 +186,7 @@ class SecAggPlusWorkflow:
|
|
|
186
186
|
|
|
187
187
|
self._check_init_params()
|
|
188
188
|
|
|
189
|
-
def __call__(self,
|
|
189
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
|
190
190
|
"""Run the SecAgg+ protocol."""
|
|
191
191
|
if not isinstance(context, LegacyContext):
|
|
192
192
|
raise TypeError(
|
|
@@ -202,7 +202,7 @@ class SecAggPlusWorkflow:
|
|
|
202
202
|
)
|
|
203
203
|
log(INFO, "Secure aggregation commencing.")
|
|
204
204
|
for step in steps:
|
|
205
|
-
if not step(
|
|
205
|
+
if not step(grid, context, state):
|
|
206
206
|
log(INFO, "Secure aggregation halted.")
|
|
207
207
|
return
|
|
208
208
|
log(INFO, "Secure aggregation completed.")
|
|
@@ -279,14 +279,14 @@ class SecAggPlusWorkflow:
|
|
|
279
279
|
return True
|
|
280
280
|
|
|
281
281
|
def setup_stage( # pylint: disable=R0912, R0914, R0915
|
|
282
|
-
self,
|
|
282
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
|
283
283
|
) -> bool:
|
|
284
284
|
"""Execute the 'setup' stage."""
|
|
285
285
|
# Obtain fit instructions
|
|
286
|
-
cfg = context.state.
|
|
286
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
|
287
287
|
current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
|
|
288
|
-
parameters = compat.
|
|
289
|
-
context.state.
|
|
288
|
+
parameters = compat.arrayrecord_to_parameters(
|
|
289
|
+
context.state.array_records[MAIN_PARAMS_RECORD],
|
|
290
290
|
keep_input=True,
|
|
291
291
|
)
|
|
292
292
|
proxy_fitins_lst = context.strategy.configure_fit(
|
|
@@ -303,7 +303,7 @@ class SecAggPlusWorkflow:
|
|
|
303
303
|
)
|
|
304
304
|
|
|
305
305
|
state.nid_to_fitins = {
|
|
306
|
-
proxy.node_id: compat.
|
|
306
|
+
proxy.node_id: compat.fitins_to_recorddict(fitins, True)
|
|
307
307
|
for proxy, fitins in proxy_fitins_lst
|
|
308
308
|
}
|
|
309
309
|
state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
|
|
@@ -366,14 +366,14 @@ class SecAggPlusWorkflow:
|
|
|
366
366
|
state.sampled_node_ids = state.active_node_ids
|
|
367
367
|
|
|
368
368
|
# Send setup configuration to clients
|
|
369
|
-
|
|
370
|
-
content =
|
|
369
|
+
cfg_record = ConfigRecord(sa_params_dict) # type: ignore
|
|
370
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
|
371
371
|
|
|
372
372
|
def make(nid: int) -> Message:
|
|
373
|
-
return
|
|
373
|
+
return Message(
|
|
374
374
|
content=content,
|
|
375
|
-
message_type=MessageType.TRAIN,
|
|
376
375
|
dst_node_id=nid,
|
|
376
|
+
message_type=MessageType.TRAIN,
|
|
377
377
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
378
378
|
)
|
|
379
379
|
|
|
@@ -382,7 +382,7 @@ class SecAggPlusWorkflow:
|
|
|
382
382
|
"[Stage 0] Sending configurations to %s clients.",
|
|
383
383
|
len(state.active_node_ids),
|
|
384
384
|
)
|
|
385
|
-
msgs =
|
|
385
|
+
msgs = grid.send_and_receive(
|
|
386
386
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
|
387
387
|
)
|
|
388
388
|
state.active_node_ids = {
|
|
@@ -398,7 +398,7 @@ class SecAggPlusWorkflow:
|
|
|
398
398
|
if msg.has_error():
|
|
399
399
|
state.failures.append(Exception(msg.error))
|
|
400
400
|
continue
|
|
401
|
-
key_dict = msg.content.
|
|
401
|
+
key_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
|
402
402
|
node_id = msg.metadata.src_node_id
|
|
403
403
|
pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
|
|
404
404
|
state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
|
|
@@ -406,22 +406,22 @@ class SecAggPlusWorkflow:
|
|
|
406
406
|
return self._check_threshold(state)
|
|
407
407
|
|
|
408
408
|
def share_keys_stage( # pylint: disable=R0914
|
|
409
|
-
self,
|
|
409
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
|
410
410
|
) -> bool:
|
|
411
411
|
"""Execute the 'share keys' stage."""
|
|
412
|
-
cfg = context.state.
|
|
412
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
|
413
413
|
|
|
414
414
|
def make(nid: int) -> Message:
|
|
415
415
|
neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
|
|
416
|
-
|
|
416
|
+
cfg_record = ConfigRecord(
|
|
417
417
|
{str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
|
|
418
418
|
)
|
|
419
|
-
|
|
420
|
-
content =
|
|
421
|
-
return
|
|
419
|
+
cfg_record[Key.STAGE] = Stage.SHARE_KEYS
|
|
420
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
|
421
|
+
return Message(
|
|
422
422
|
content=content,
|
|
423
|
-
message_type=MessageType.TRAIN,
|
|
424
423
|
dst_node_id=nid,
|
|
424
|
+
message_type=MessageType.TRAIN,
|
|
425
425
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
426
426
|
)
|
|
427
427
|
|
|
@@ -431,7 +431,7 @@ class SecAggPlusWorkflow:
|
|
|
431
431
|
"[Stage 1] Forwarding public keys to %s clients.",
|
|
432
432
|
len(state.active_node_ids),
|
|
433
433
|
)
|
|
434
|
-
msgs =
|
|
434
|
+
msgs = grid.send_and_receive(
|
|
435
435
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
|
436
436
|
)
|
|
437
437
|
state.active_node_ids = {
|
|
@@ -458,7 +458,7 @@ class SecAggPlusWorkflow:
|
|
|
458
458
|
state.failures.append(Exception(msg.error))
|
|
459
459
|
continue
|
|
460
460
|
node_id = msg.metadata.src_node_id
|
|
461
|
-
res_dict = msg.content.
|
|
461
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
|
462
462
|
dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
|
|
463
463
|
ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
|
|
464
464
|
srcs += [node_id] * len(dst_lst)
|
|
@@ -476,25 +476,25 @@ class SecAggPlusWorkflow:
|
|
|
476
476
|
return self._check_threshold(state)
|
|
477
477
|
|
|
478
478
|
def collect_masked_vectors_stage(
|
|
479
|
-
self,
|
|
479
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
|
480
480
|
) -> bool:
|
|
481
481
|
"""Execute the 'collect masked vectors' stage."""
|
|
482
|
-
cfg = context.state.
|
|
482
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
|
483
483
|
|
|
484
484
|
# Send secret key shares to clients (plus FitIns) and collect masked vectors
|
|
485
485
|
def make(nid: int) -> Message:
|
|
486
|
-
|
|
486
|
+
cfg_dict = {
|
|
487
487
|
Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
|
|
488
488
|
Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
|
|
489
489
|
Key.SOURCE_LIST: state.forward_srcs[nid],
|
|
490
490
|
}
|
|
491
|
-
|
|
491
|
+
cfg_record = ConfigRecord(cfg_dict) # type: ignore
|
|
492
492
|
content = state.nid_to_fitins[nid]
|
|
493
|
-
content.
|
|
494
|
-
return
|
|
493
|
+
content.config_records[RECORD_KEY_CONFIGS] = cfg_record
|
|
494
|
+
return Message(
|
|
495
495
|
content=content,
|
|
496
|
-
message_type=MessageType.TRAIN,
|
|
497
496
|
dst_node_id=nid,
|
|
497
|
+
message_type=MessageType.TRAIN,
|
|
498
498
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
499
499
|
)
|
|
500
500
|
|
|
@@ -503,7 +503,7 @@ class SecAggPlusWorkflow:
|
|
|
503
503
|
"[Stage 2] Forwarding encrypted key shares to %s clients.",
|
|
504
504
|
len(state.active_node_ids),
|
|
505
505
|
)
|
|
506
|
-
msgs =
|
|
506
|
+
msgs = grid.send_and_receive(
|
|
507
507
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
|
508
508
|
)
|
|
509
509
|
state.active_node_ids = {
|
|
@@ -524,7 +524,7 @@ class SecAggPlusWorkflow:
|
|
|
524
524
|
if msg.has_error():
|
|
525
525
|
state.failures.append(Exception(msg.error))
|
|
526
526
|
continue
|
|
527
|
-
res_dict = msg.content.
|
|
527
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
|
528
528
|
bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
|
|
529
529
|
client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
|
|
530
530
|
if masked_vector is None:
|
|
@@ -540,17 +540,17 @@ class SecAggPlusWorkflow:
|
|
|
540
540
|
if msg.has_error():
|
|
541
541
|
state.failures.append(Exception(msg.error))
|
|
542
542
|
continue
|
|
543
|
-
fitres = compat.
|
|
543
|
+
fitres = compat.recorddict_to_fitres(msg.content, True)
|
|
544
544
|
proxy = state.nid_to_proxies[msg.metadata.src_node_id]
|
|
545
545
|
state.legacy_results.append((proxy, fitres))
|
|
546
546
|
|
|
547
547
|
return self._check_threshold(state)
|
|
548
548
|
|
|
549
549
|
def unmask_stage( # pylint: disable=R0912, R0914, R0915
|
|
550
|
-
self,
|
|
550
|
+
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
|
551
551
|
) -> bool:
|
|
552
552
|
"""Execute the 'unmask' stage."""
|
|
553
|
-
cfg = context.state.
|
|
553
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
|
554
554
|
current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
|
|
555
555
|
|
|
556
556
|
# Construct active node IDs and dead node IDs
|
|
@@ -560,17 +560,17 @@ class SecAggPlusWorkflow:
|
|
|
560
560
|
# Send secure IDs of active and dead clients and collect key shares from clients
|
|
561
561
|
def make(nid: int) -> Message:
|
|
562
562
|
neighbours = state.nid_to_neighbours[nid]
|
|
563
|
-
|
|
563
|
+
cfg_dict = {
|
|
564
564
|
Key.STAGE: Stage.UNMASK,
|
|
565
565
|
Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
|
|
566
566
|
Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
|
|
567
567
|
}
|
|
568
|
-
|
|
569
|
-
content =
|
|
570
|
-
return
|
|
568
|
+
cfg_record = ConfigRecord(cfg_dict) # type: ignore
|
|
569
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
|
570
|
+
return Message(
|
|
571
571
|
content=content,
|
|
572
|
-
message_type=MessageType.TRAIN,
|
|
573
572
|
dst_node_id=nid,
|
|
573
|
+
message_type=MessageType.TRAIN,
|
|
574
574
|
group_id=str(current_round),
|
|
575
575
|
)
|
|
576
576
|
|
|
@@ -579,7 +579,7 @@ class SecAggPlusWorkflow:
|
|
|
579
579
|
"[Stage 3] Requesting key shares from %s clients to remove masks.",
|
|
580
580
|
len(state.active_node_ids),
|
|
581
581
|
)
|
|
582
|
-
msgs =
|
|
582
|
+
msgs = grid.send_and_receive(
|
|
583
583
|
[make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
|
|
584
584
|
)
|
|
585
585
|
state.active_node_ids = {
|
|
@@ -599,7 +599,7 @@ class SecAggPlusWorkflow:
|
|
|
599
599
|
if msg.has_error():
|
|
600
600
|
state.failures.append(Exception(msg.error))
|
|
601
601
|
continue
|
|
602
|
-
res_dict = msg.content.
|
|
602
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
|
603
603
|
nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
|
|
604
604
|
shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
|
|
605
605
|
for owner_nid, share in zip(nids, shares):
|
|
@@ -676,10 +676,8 @@ class SecAggPlusWorkflow:
|
|
|
676
676
|
|
|
677
677
|
# Update the parameters and write history
|
|
678
678
|
if parameters_aggregated:
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
)
|
|
682
|
-
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
|
679
|
+
arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
|
|
680
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
|
683
681
|
context.history.add_metrics_distributed_fit(
|
|
684
682
|
server_round=current_round, metrics=metrics_aggregated
|
|
685
683
|
)
|
flwr/simulation/app.py
CHANGED
|
@@ -47,7 +47,7 @@ from flwr.common.logger import (
|
|
|
47
47
|
stop_log_uploader,
|
|
48
48
|
)
|
|
49
49
|
from flwr.common.serde import (
|
|
50
|
-
|
|
50
|
+
config_record_from_proto,
|
|
51
51
|
context_from_proto,
|
|
52
52
|
context_to_proto,
|
|
53
53
|
fab_from_proto,
|
|
@@ -184,7 +184,7 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
|
|
|
184
184
|
fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
|
|
185
185
|
GetFederationOptionsRequest(run_id=run.run_id)
|
|
186
186
|
)
|
|
187
|
-
federation_options =
|
|
187
|
+
federation_options = config_record_from_proto(
|
|
188
188
|
fed_opt_res.federation_options
|
|
189
189
|
)
|
|
190
190
|
|
|
@@ -105,8 +105,10 @@ def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) ->
|
|
|
105
105
|
if not node_resources:
|
|
106
106
|
continue
|
|
107
107
|
|
|
108
|
-
|
|
109
|
-
|
|
108
|
+
# Fallback to zero when resource quantity is not configured on the ray node
|
|
109
|
+
# e.g.: node without GPU; head node set up not to run tasks (zero resources)
|
|
110
|
+
num_cpus = node_resources.get("CPU", 0)
|
|
111
|
+
num_gpus = node_resources.get("GPU", 0)
|
|
110
112
|
num_actors = int(num_cpus / client_resources["num_cpus"])
|
|
111
113
|
|
|
112
114
|
# If a GPU is present and client resources do require one
|
|
@@ -23,7 +23,7 @@ from flwr import common
|
|
|
23
23
|
from flwr.client import ClientFnExt
|
|
24
24
|
from flwr.client.client_app import ClientApp
|
|
25
25
|
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
26
|
-
from flwr.common import DEFAULT_TTL, Message, Metadata,
|
|
26
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordDict, now
|
|
27
27
|
from flwr.common.constant import (
|
|
28
28
|
NUM_PARTITIONS_KEY,
|
|
29
29
|
PARTITION_ID_KEY,
|
|
@@ -31,15 +31,16 @@ from flwr.common.constant import (
|
|
|
31
31
|
MessageTypeLegacy,
|
|
32
32
|
)
|
|
33
33
|
from flwr.common.logger import log
|
|
34
|
-
from flwr.common.
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
34
|
+
from flwr.common.message import make_message
|
|
35
|
+
from flwr.common.recorddict_compat import (
|
|
36
|
+
evaluateins_to_recorddict,
|
|
37
|
+
fitins_to_recorddict,
|
|
38
|
+
getparametersins_to_recorddict,
|
|
39
|
+
getpropertiesins_to_recorddict,
|
|
40
|
+
recorddict_to_evaluateres,
|
|
41
|
+
recorddict_to_fitres,
|
|
42
|
+
recorddict_to_getparametersres,
|
|
43
|
+
recorddict_to_getpropertiesres,
|
|
43
44
|
)
|
|
44
45
|
from flwr.server.client_proxy import ClientProxy
|
|
45
46
|
from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
|
|
@@ -109,23 +110,24 @@ class RayActorClientProxy(ClientProxy):
|
|
|
109
110
|
|
|
110
111
|
return out_mssg
|
|
111
112
|
|
|
112
|
-
def
|
|
113
|
+
def _wrap_recorddict_in_message(
|
|
113
114
|
self,
|
|
114
|
-
|
|
115
|
+
recorddict: RecordDict,
|
|
115
116
|
message_type: str,
|
|
116
117
|
timeout: Optional[float],
|
|
117
118
|
group_id: Optional[int],
|
|
118
119
|
) -> Message:
|
|
119
|
-
"""Wrap a
|
|
120
|
-
return
|
|
121
|
-
content=
|
|
120
|
+
"""Wrap a RecordDict inside a Message."""
|
|
121
|
+
return make_message(
|
|
122
|
+
content=recorddict,
|
|
122
123
|
metadata=Metadata(
|
|
123
124
|
run_id=0,
|
|
124
125
|
message_id="",
|
|
125
126
|
group_id=str(group_id) if group_id is not None else "",
|
|
126
127
|
src_node_id=0,
|
|
127
128
|
dst_node_id=self.node_id,
|
|
128
|
-
|
|
129
|
+
reply_to_message_id="",
|
|
130
|
+
created_at=now().timestamp(),
|
|
129
131
|
ttl=timeout if timeout else DEFAULT_TTL,
|
|
130
132
|
message_type=message_type,
|
|
131
133
|
),
|
|
@@ -138,9 +140,9 @@ class RayActorClientProxy(ClientProxy):
|
|
|
138
140
|
group_id: Optional[int],
|
|
139
141
|
) -> common.GetPropertiesRes:
|
|
140
142
|
"""Return client's properties."""
|
|
141
|
-
|
|
142
|
-
message = self.
|
|
143
|
-
|
|
143
|
+
recorddict = getpropertiesins_to_recorddict(ins)
|
|
144
|
+
message = self._wrap_recorddict_in_message(
|
|
145
|
+
recorddict,
|
|
144
146
|
message_type=MessageTypeLegacy.GET_PROPERTIES,
|
|
145
147
|
timeout=timeout,
|
|
146
148
|
group_id=group_id,
|
|
@@ -148,7 +150,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
148
150
|
|
|
149
151
|
message_out = self._submit_job(message, timeout)
|
|
150
152
|
|
|
151
|
-
return
|
|
153
|
+
return recorddict_to_getpropertiesres(message_out.content)
|
|
152
154
|
|
|
153
155
|
def get_parameters(
|
|
154
156
|
self,
|
|
@@ -157,9 +159,9 @@ class RayActorClientProxy(ClientProxy):
|
|
|
157
159
|
group_id: Optional[int],
|
|
158
160
|
) -> common.GetParametersRes:
|
|
159
161
|
"""Return the current local model parameters."""
|
|
160
|
-
|
|
161
|
-
message = self.
|
|
162
|
-
|
|
162
|
+
recorddict = getparametersins_to_recorddict(ins)
|
|
163
|
+
message = self._wrap_recorddict_in_message(
|
|
164
|
+
recorddict,
|
|
163
165
|
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
|
164
166
|
timeout=timeout,
|
|
165
167
|
group_id=group_id,
|
|
@@ -167,17 +169,17 @@ class RayActorClientProxy(ClientProxy):
|
|
|
167
169
|
|
|
168
170
|
message_out = self._submit_job(message, timeout)
|
|
169
171
|
|
|
170
|
-
return
|
|
172
|
+
return recorddict_to_getparametersres(message_out.content, keep_input=False)
|
|
171
173
|
|
|
172
174
|
def fit(
|
|
173
175
|
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
|
174
176
|
) -> common.FitRes:
|
|
175
177
|
"""Train model parameters on the locally held dataset."""
|
|
176
|
-
|
|
178
|
+
recorddict = fitins_to_recorddict(
|
|
177
179
|
ins, keep_input=True
|
|
178
180
|
) # This must stay TRUE since ins are in-memory
|
|
179
|
-
message = self.
|
|
180
|
-
|
|
181
|
+
message = self._wrap_recorddict_in_message(
|
|
182
|
+
recorddict,
|
|
181
183
|
message_type=MessageType.TRAIN,
|
|
182
184
|
timeout=timeout,
|
|
183
185
|
group_id=group_id,
|
|
@@ -185,17 +187,17 @@ class RayActorClientProxy(ClientProxy):
|
|
|
185
187
|
|
|
186
188
|
message_out = self._submit_job(message, timeout)
|
|
187
189
|
|
|
188
|
-
return
|
|
190
|
+
return recorddict_to_fitres(message_out.content, keep_input=False)
|
|
189
191
|
|
|
190
192
|
def evaluate(
|
|
191
193
|
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
|
192
194
|
) -> common.EvaluateRes:
|
|
193
195
|
"""Evaluate model parameters on the locally held dataset."""
|
|
194
|
-
|
|
196
|
+
recorddict = evaluateins_to_recorddict(
|
|
195
197
|
ins, keep_input=True
|
|
196
198
|
) # This must stay TRUE since ins are in-memory
|
|
197
|
-
message = self.
|
|
198
|
-
|
|
199
|
+
message = self._wrap_recorddict_in_message(
|
|
200
|
+
recorddict,
|
|
199
201
|
message_type=MessageType.EVALUATE,
|
|
200
202
|
timeout=timeout,
|
|
201
203
|
group_id=group_id,
|
|
@@ -203,7 +205,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
203
205
|
|
|
204
206
|
message_out = self._submit_job(message, timeout)
|
|
205
207
|
|
|
206
|
-
return
|
|
208
|
+
return recorddict_to_evaluateres(message_out.content)
|
|
207
209
|
|
|
208
210
|
def reconnect(
|
|
209
211
|
self,
|
|
@@ -30,7 +30,7 @@ from typing import Any, Optional
|
|
|
30
30
|
from flwr.cli.config_utils import load_and_validate
|
|
31
31
|
from flwr.cli.utils import get_sha256_hash
|
|
32
32
|
from flwr.client import ClientApp
|
|
33
|
-
from flwr.common import Context, EventType,
|
|
33
|
+
from flwr.common import Context, EventType, RecordDict, event, log, now
|
|
34
34
|
from flwr.common.config import get_fused_config_from_dir, parse_config_args
|
|
35
35
|
from flwr.common.constant import RUN_ID_NUM_BYTES, Status
|
|
36
36
|
from flwr.common.logger import (
|
|
@@ -39,7 +39,7 @@ from flwr.common.logger import (
|
|
|
39
39
|
warn_deprecated_feature_with_example,
|
|
40
40
|
)
|
|
41
41
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
42
|
-
from flwr.server.
|
|
42
|
+
from flwr.server.grid import Grid, InMemoryGrid
|
|
43
43
|
from flwr.server.run_serverapp import run as _run
|
|
44
44
|
from flwr.server.server_app import ServerApp
|
|
45
45
|
from flwr.server.superlink.fleet import vce
|
|
@@ -168,7 +168,7 @@ def run_simulation(
|
|
|
168
168
|
messages sent by the `ServerApp`.
|
|
169
169
|
|
|
170
170
|
num_supernodes : int
|
|
171
|
-
Number of nodes that run a ClientApp. They can be sampled by a
|
|
171
|
+
Number of nodes that run a ClientApp. They can be sampled by a Grid in the
|
|
172
172
|
ServerApp and receive a Message describing what the ClientApp should perform.
|
|
173
173
|
|
|
174
174
|
backend_name : str (default: ray)
|
|
@@ -180,7 +180,7 @@ def run_simulation(
|
|
|
180
180
|
for values parsed to initialisation of backend, `client_resources`
|
|
181
181
|
to define the resources for clients, and `actor` to define the actor
|
|
182
182
|
parameters. Values supported in <value> are those included by
|
|
183
|
-
`flwr.common.typing.
|
|
183
|
+
`flwr.common.typing.ConfigRecordValues`.
|
|
184
184
|
|
|
185
185
|
enable_tf_gpu_growth : bool (default: False)
|
|
186
186
|
A boolean to indicate whether to enable GPU growth on the main thread. This is
|
|
@@ -225,7 +225,7 @@ def run_serverapp_th(
|
|
|
225
225
|
server_app_attr: Optional[str],
|
|
226
226
|
server_app: Optional[ServerApp],
|
|
227
227
|
server_app_run_config: UserConfig,
|
|
228
|
-
|
|
228
|
+
grid: Grid,
|
|
229
229
|
app_dir: str,
|
|
230
230
|
f_stop: threading.Event,
|
|
231
231
|
has_exception: threading.Event,
|
|
@@ -239,7 +239,7 @@ def run_serverapp_th(
|
|
|
239
239
|
tf_gpu_growth: bool,
|
|
240
240
|
stop_event: threading.Event,
|
|
241
241
|
exception_event: threading.Event,
|
|
242
|
-
|
|
242
|
+
_grid: Grid,
|
|
243
243
|
_server_app_dir: str,
|
|
244
244
|
_server_app_run_config: UserConfig,
|
|
245
245
|
_server_app_attr: Optional[str],
|
|
@@ -260,13 +260,13 @@ def run_serverapp_th(
|
|
|
260
260
|
run_id=run_id,
|
|
261
261
|
node_id=0,
|
|
262
262
|
node_config={},
|
|
263
|
-
state=
|
|
263
|
+
state=RecordDict(),
|
|
264
264
|
run_config=_server_app_run_config,
|
|
265
265
|
)
|
|
266
266
|
|
|
267
267
|
# Run ServerApp
|
|
268
268
|
updated_context = _run(
|
|
269
|
-
|
|
269
|
+
grid=_grid,
|
|
270
270
|
context=context,
|
|
271
271
|
server_app_dir=_server_app_dir,
|
|
272
272
|
server_app_attr=_server_app_attr,
|
|
@@ -291,7 +291,7 @@ def run_serverapp_th(
|
|
|
291
291
|
enable_tf_gpu_growth,
|
|
292
292
|
f_stop,
|
|
293
293
|
has_exception,
|
|
294
|
-
|
|
294
|
+
grid,
|
|
295
295
|
app_dir,
|
|
296
296
|
server_app_run_config,
|
|
297
297
|
server_app_attr,
|
|
@@ -333,7 +333,7 @@ def _main_loop(
|
|
|
333
333
|
run_id=run.run_id,
|
|
334
334
|
node_id=0,
|
|
335
335
|
node_config=UserConfig(),
|
|
336
|
-
state=
|
|
336
|
+
state=RecordDict(),
|
|
337
337
|
run_config=UserConfig(),
|
|
338
338
|
)
|
|
339
339
|
try:
|
|
@@ -347,9 +347,9 @@ def _main_loop(
|
|
|
347
347
|
if server_app_run_config is None:
|
|
348
348
|
server_app_run_config = {}
|
|
349
349
|
|
|
350
|
-
# Initialize
|
|
351
|
-
|
|
352
|
-
|
|
350
|
+
# Initialize Grid
|
|
351
|
+
grid = InMemoryGrid(state_factory=state_factory)
|
|
352
|
+
grid.set_run(run_id=run.run_id)
|
|
353
353
|
output_context_queue: Queue[Context] = Queue()
|
|
354
354
|
|
|
355
355
|
# Get and run ServerApp thread
|
|
@@ -357,7 +357,7 @@ def _main_loop(
|
|
|
357
357
|
server_app_attr=server_app_attr,
|
|
358
358
|
server_app=server_app,
|
|
359
359
|
server_app_run_config=server_app_run_config,
|
|
360
|
-
|
|
360
|
+
grid=grid,
|
|
361
361
|
app_dir=app_dir,
|
|
362
362
|
f_stop=f_stop,
|
|
363
363
|
has_exception=server_app_thread_has_exception,
|
|
@@ -546,7 +546,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
|
546
546
|
default="{}",
|
|
547
547
|
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
|
548
548
|
"configure a backend. Values supported in <value> are those included by "
|
|
549
|
-
"`flwr.common.typing.
|
|
549
|
+
"`flwr.common.typing.ConfigRecordValues`. ",
|
|
550
550
|
)
|
|
551
551
|
parser.add_argument(
|
|
552
552
|
"--enable-tf-gpu-growth",
|
flwr/superexec/app.py
CHANGED
|
@@ -16,26 +16,12 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
-
import sys
|
|
20
|
-
from logging import INFO
|
|
21
19
|
|
|
22
|
-
from flwr.common import log
|
|
23
20
|
from flwr.common.object_ref import load_app, validate
|
|
24
21
|
|
|
25
22
|
from .executor import Executor
|
|
26
23
|
|
|
27
24
|
|
|
28
|
-
def run_superexec() -> None:
|
|
29
|
-
"""Run Flower SuperExec."""
|
|
30
|
-
log(INFO, "Starting Flower SuperExec")
|
|
31
|
-
|
|
32
|
-
sys.exit(
|
|
33
|
-
"Manually launching the SuperExec is deprecated. Since `flwr 1.13.0` "
|
|
34
|
-
"the executor service runs in the SuperLink. Launching it manually is not "
|
|
35
|
-
"recommended."
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
|
|
39
25
|
def load_executor(
|
|
40
26
|
args: argparse.Namespace,
|
|
41
27
|
) -> Executor:
|
flwr/superexec/deployment.py
CHANGED
|
@@ -23,7 +23,7 @@ from typing import Optional
|
|
|
23
23
|
from typing_extensions import override
|
|
24
24
|
|
|
25
25
|
from flwr.cli.config_utils import get_fab_metadata
|
|
26
|
-
from flwr.common import
|
|
26
|
+
from flwr.common import ConfigRecord, Context, RecordDict
|
|
27
27
|
from flwr.common.constant import (
|
|
28
28
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
29
29
|
Status,
|
|
@@ -141,7 +141,7 @@ class DeploymentEngine(Executor):
|
|
|
141
141
|
fab_id, fab_version = get_fab_metadata(fab.content)
|
|
142
142
|
|
|
143
143
|
run_id = self.linkstate.create_run(
|
|
144
|
-
fab_id, fab_version, fab_hash, override_config,
|
|
144
|
+
fab_id, fab_version, fab_hash, override_config, ConfigRecord()
|
|
145
145
|
)
|
|
146
146
|
return run_id
|
|
147
147
|
|
|
@@ -149,7 +149,7 @@ class DeploymentEngine(Executor):
|
|
|
149
149
|
"""Register a Context for a Run."""
|
|
150
150
|
# Create an empty context for the Run
|
|
151
151
|
context = Context(
|
|
152
|
-
run_id=run_id, node_id=0, node_config={}, state=
|
|
152
|
+
run_id=run_id, node_id=0, node_config={}, state=RecordDict(), run_config={}
|
|
153
153
|
)
|
|
154
154
|
|
|
155
155
|
# Register the context at the LinkState
|
|
@@ -160,7 +160,7 @@ class DeploymentEngine(Executor):
|
|
|
160
160
|
self,
|
|
161
161
|
fab_file: bytes,
|
|
162
162
|
override_config: UserConfig,
|
|
163
|
-
federation_options:
|
|
163
|
+
federation_options: ConfigRecord,
|
|
164
164
|
) -> Optional[int]:
|
|
165
165
|
"""Start run using the Flower Deployment Engine."""
|
|
166
166
|
run_id = None
|