flwr-nightly 1.17.0.dev20250320__py3-none-any.whl → 1.17.0.dev20250322__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/run/run.py +5 -9
- flwr/client/client_app.py +10 -12
- flwr/client/grpc_client/connection.py +3 -3
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
- flwr/common/__init__.py +10 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +1 -1
- flwr/common/record/__init__.py +6 -3
- flwr/common/record/{parametersrecord.py → arrayrecord.py} +74 -31
- flwr/common/record/{configsrecord.py → configrecord.py} +73 -27
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +77 -31
- flwr/common/record/recorddict.py +95 -56
- flwr/common/recorddict_compat.py +54 -62
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +42 -43
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +30 -30
- flwr/proto/exec_pb2.pyi +2 -2
- flwr/proto/recorddict_pb2.py +29 -29
- flwr/proto/recorddict_pb2.pyi +33 -33
- flwr/proto/run_pb2.py +2 -2
- flwr/proto/run_pb2.pyi +2 -2
- flwr/server/compat/grid_client_proxy.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +4 -4
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +7 -7
- flwr/server/superlink/linkstate/utils.py +9 -9
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/workflow/default_workflows.py +27 -34
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +32 -34
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/run_simulation.py +2 -2
- flwr/superexec/deployment.py +3 -3
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +2 -2
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/RECORD +49 -49
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/entry_points.txt +0 -0
@@ -26,7 +26,7 @@ from flwr.common.constant import PARTITION_ID_KEY
|
|
26
26
|
from flwr.common.context import Context
|
27
27
|
from flwr.common.logger import log
|
28
28
|
from flwr.common.message import Message
|
29
|
-
from flwr.common.typing import
|
29
|
+
from flwr.common.typing import ConfigRecordValues
|
30
30
|
from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
|
31
31
|
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
32
32
|
|
@@ -104,7 +104,7 @@ class RayBackend(Backend):
|
|
104
104
|
if not ray.is_initialized():
|
105
105
|
ray_init_args: dict[
|
106
106
|
str,
|
107
|
-
|
107
|
+
ConfigRecordValues,
|
108
108
|
] = {}
|
109
109
|
|
110
110
|
if backend_config.get(self.init_args_key):
|
@@ -32,7 +32,7 @@ from flwr.common.constant import (
|
|
32
32
|
SUPERLINK_NODE_ID,
|
33
33
|
Status,
|
34
34
|
)
|
35
|
-
from flwr.common.record import
|
35
|
+
from flwr.common.record import ConfigRecord
|
36
36
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
37
37
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
38
38
|
from flwr.server.utils import validate_message
|
@@ -69,7 +69,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
69
69
|
# Map run_id to RunRecord
|
70
70
|
self.run_ids: dict[int, RunRecord] = {}
|
71
71
|
self.contexts: dict[int, Context] = {}
|
72
|
-
self.federation_options: dict[int,
|
72
|
+
self.federation_options: dict[int, ConfigRecord] = {}
|
73
73
|
self.message_ins_store: dict[UUID, Message] = {}
|
74
74
|
self.message_res_store: dict[UUID, Message] = {}
|
75
75
|
self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
|
@@ -399,7 +399,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
399
399
|
fab_version: Optional[str],
|
400
400
|
fab_hash: Optional[str],
|
401
401
|
override_config: UserConfig,
|
402
|
-
federation_options:
|
402
|
+
federation_options: ConfigRecord,
|
403
403
|
) -> int:
|
404
404
|
"""Create a new run for the specified `fab_hash`."""
|
405
405
|
# Sample a random int64 as run_id
|
@@ -528,7 +528,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
528
528
|
|
529
529
|
return pending_run_id
|
530
530
|
|
531
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
531
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
532
532
|
"""Retrieve the federation options for the specified `run_id`."""
|
533
533
|
with self.lock:
|
534
534
|
if run_id not in self.run_ids:
|
@@ -20,7 +20,7 @@ from typing import Optional
|
|
20
20
|
from uuid import UUID
|
21
21
|
|
22
22
|
from flwr.common import Context, Message
|
23
|
-
from flwr.common.record import
|
23
|
+
from flwr.common.record import ConfigRecord
|
24
24
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
25
25
|
|
26
26
|
|
@@ -164,7 +164,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
164
164
|
fab_version: Optional[str],
|
165
165
|
fab_hash: Optional[str],
|
166
166
|
override_config: UserConfig,
|
167
|
-
federation_options:
|
167
|
+
federation_options: ConfigRecord,
|
168
168
|
) -> int:
|
169
169
|
"""Create a new run for the specified `fab_hash`."""
|
170
170
|
|
@@ -236,7 +236,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
236
236
|
"""
|
237
237
|
|
238
238
|
@abc.abstractmethod
|
239
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
239
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
240
240
|
"""Retrieve the federation options for the specified `run_id`.
|
241
241
|
|
242
242
|
Parameters
|
@@ -246,7 +246,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
246
246
|
|
247
247
|
Returns
|
248
248
|
-------
|
249
|
-
Optional[
|
249
|
+
Optional[ConfigRecord]
|
250
250
|
The federation options for the run if it exists; None otherwise.
|
251
251
|
"""
|
252
252
|
|
@@ -36,7 +36,7 @@ from flwr.common.constant import (
|
|
36
36
|
Status,
|
37
37
|
)
|
38
38
|
from flwr.common.message import make_message
|
39
|
-
from flwr.common.record import
|
39
|
+
from flwr.common.record import ConfigRecord
|
40
40
|
from flwr.common.serde import (
|
41
41
|
error_from_proto,
|
42
42
|
error_to_proto,
|
@@ -55,8 +55,8 @@ from flwr.server.utils.validator import validate_message
|
|
55
55
|
from .linkstate import LinkState
|
56
56
|
from .utils import (
|
57
57
|
check_node_availability_for_in_message,
|
58
|
-
|
59
|
-
|
58
|
+
configrecord_from_bytes,
|
59
|
+
configrecord_to_bytes,
|
60
60
|
context_from_bytes,
|
61
61
|
context_to_bytes,
|
62
62
|
convert_sint64_to_uint64,
|
@@ -727,7 +727,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
727
727
|
fab_version: Optional[str],
|
728
728
|
fab_hash: Optional[str],
|
729
729
|
override_config: UserConfig,
|
730
|
-
federation_options:
|
730
|
+
federation_options: ConfigRecord,
|
731
731
|
) -> int:
|
732
732
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
733
733
|
# Sample a random int64 as run_id
|
@@ -753,7 +753,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
753
753
|
fab_version,
|
754
754
|
fab_hash,
|
755
755
|
override_config_json,
|
756
|
-
|
756
|
+
configrecord_to_bytes(federation_options),
|
757
757
|
]
|
758
758
|
data += [
|
759
759
|
now().isoformat(),
|
@@ -911,7 +911,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
911
911
|
|
912
912
|
return pending_run_id
|
913
913
|
|
914
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
914
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
915
915
|
"""Retrieve the federation options for the specified `run_id`."""
|
916
916
|
# Convert the uint64 value to sint64 for SQLite
|
917
917
|
sint64_run_id = convert_uint64_to_sint64(run_id)
|
@@ -924,7 +924,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
924
924
|
return None
|
925
925
|
|
926
926
|
row = rows[0]
|
927
|
-
return
|
927
|
+
return configrecord_from_bytes(row["federation_options"])
|
928
928
|
|
929
929
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
930
930
|
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
@@ -19,7 +19,7 @@ from os import urandom
|
|
19
19
|
from typing import Optional
|
20
20
|
from uuid import UUID, uuid4
|
21
21
|
|
22
|
-
from flwr.common import
|
22
|
+
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
23
23
|
from flwr.common.constant import (
|
24
24
|
SUPERLINK_NODE_ID,
|
25
25
|
ErrorCode,
|
@@ -32,7 +32,7 @@ from flwr.common.typing import RunStatus
|
|
32
32
|
|
33
33
|
# pylint: disable=E0611
|
34
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
35
|
-
from flwr.proto.recorddict_pb2 import
|
35
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
36
36
|
|
37
37
|
# pylint: enable=E0611
|
38
38
|
VALID_RUN_STATUS_TRANSITIONS = {
|
@@ -172,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
|
|
172
172
|
return serde.context_from_proto(ProtoContext.FromString(context_bytes))
|
173
173
|
|
174
174
|
|
175
|
-
def
|
176
|
-
"""Serialize a `
|
177
|
-
return serde.
|
175
|
+
def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
|
176
|
+
"""Serialize a `ConfigRecord` to bytes."""
|
177
|
+
return serde.config_record_to_proto(config_record).SerializeToString()
|
178
178
|
|
179
179
|
|
180
|
-
def
|
181
|
-
"""Deserialize `
|
182
|
-
return serde.
|
183
|
-
|
180
|
+
def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
|
181
|
+
"""Deserialize `ConfigRecord` from bytes."""
|
182
|
+
return serde.config_record_from_proto(
|
183
|
+
ProtoConfigRecord.FromString(configrecord_bytes)
|
184
184
|
)
|
185
185
|
|
186
186
|
|
@@ -22,7 +22,7 @@ from uuid import UUID
|
|
22
22
|
|
23
23
|
import grpc
|
24
24
|
|
25
|
-
from flwr.common import
|
25
|
+
from flwr.common import ConfigRecord, Message
|
26
26
|
from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
27
27
|
from flwr.common.logger import log
|
28
28
|
from flwr.common.serde import (
|
@@ -127,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
127
127
|
request.fab_version,
|
128
128
|
fab_hash,
|
129
129
|
user_config_from_proto(request.override_config),
|
130
|
-
|
130
|
+
ConfigRecord(),
|
131
131
|
)
|
132
132
|
return CreateRunResponse(run_id=run_id)
|
133
133
|
|
@@ -24,7 +24,7 @@ from grpc import ServicerContext
|
|
24
24
|
from flwr.common.constant import Status
|
25
25
|
from flwr.common.logger import log
|
26
26
|
from flwr.common.serde import (
|
27
|
-
|
27
|
+
config_record_to_proto,
|
28
28
|
context_from_proto,
|
29
29
|
context_to_proto,
|
30
30
|
fab_to_proto,
|
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
182
182
|
)
|
183
183
|
return GetFederationOptionsResponse()
|
184
184
|
return GetFederationOptionsResponse(
|
185
|
-
federation_options=
|
185
|
+
federation_options=config_record_to_proto(federation_options)
|
186
186
|
)
|
@@ -22,13 +22,14 @@ from typing import Optional, Union, cast
|
|
22
22
|
|
23
23
|
import flwr.common.recorddict_compat as compat
|
24
24
|
from flwr.common import (
|
25
|
+
ArrayRecord,
|
25
26
|
Code,
|
26
|
-
|
27
|
+
ConfigRecord,
|
27
28
|
Context,
|
28
29
|
EvaluateRes,
|
29
30
|
FitRes,
|
30
31
|
GetParametersIns,
|
31
|
-
|
32
|
+
Message,
|
32
33
|
log,
|
33
34
|
)
|
34
35
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
@@ -77,9 +78,9 @@ class DefaultWorkflow:
|
|
77
78
|
|
78
79
|
# Run federated learning for num_rounds
|
79
80
|
start_time = timeit.default_timer()
|
80
|
-
cfg =
|
81
|
+
cfg = ConfigRecord()
|
81
82
|
cfg[Key.START_TIME] = start_time
|
82
|
-
context.state.
|
83
|
+
context.state.config_records[MAIN_CONFIGS_RECORD] = cfg
|
83
84
|
|
84
85
|
for current_round in range(1, context.config.num_rounds + 1):
|
85
86
|
log(INFO, "")
|
@@ -129,9 +130,7 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
|
129
130
|
)
|
130
131
|
if parameters is not None:
|
131
132
|
log(INFO, "Using initial global parameters provided by strategy")
|
132
|
-
|
133
|
-
parameters, keep_input=True
|
134
|
-
)
|
133
|
+
arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
|
135
134
|
else:
|
136
135
|
# Get initial parameters from one of the clients
|
137
136
|
log(INFO, "Requesting initial parameters from one random client")
|
@@ -140,10 +139,10 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
|
140
139
|
content = compat.getparametersins_to_recorddict(GetParametersIns({}))
|
141
140
|
messages = grid.send_and_receive(
|
142
141
|
[
|
143
|
-
|
142
|
+
Message(
|
144
143
|
content=content,
|
145
|
-
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
146
144
|
dst_node_id=random_client.node_id,
|
145
|
+
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
147
146
|
group_id="0",
|
148
147
|
)
|
149
148
|
]
|
@@ -158,20 +157,20 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
|
158
157
|
== Code.OK
|
159
158
|
):
|
160
159
|
log(INFO, "Received initial parameters from one random client")
|
161
|
-
|
160
|
+
arr_record = next(iter(msg.content.array_records.values()))
|
162
161
|
else:
|
163
162
|
log(
|
164
163
|
WARN,
|
165
164
|
"Failed to receive initial parameters from the client."
|
166
165
|
" Empty initial parameters will be used.",
|
167
166
|
)
|
168
|
-
|
167
|
+
arr_record = ArrayRecord()
|
169
168
|
|
170
|
-
context.state.
|
169
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
171
170
|
|
172
171
|
# Evaluate initial parameters
|
173
172
|
log(INFO, "Starting evaluation of initial global parameters")
|
174
|
-
parameters = compat.
|
173
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
175
174
|
res = context.strategy.evaluate(0, parameters=parameters)
|
176
175
|
if res is not None:
|
177
176
|
log(
|
@@ -192,13 +191,13 @@ def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
|
|
192
191
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
193
192
|
|
194
193
|
# Retrieve current_round and start_time from the context
|
195
|
-
cfg = context.state.
|
194
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
196
195
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
197
196
|
start_time = cast(float, cfg[Key.START_TIME])
|
198
197
|
|
199
198
|
# Centralized evaluation
|
200
|
-
parameters = compat.
|
201
|
-
record=context.state.
|
199
|
+
parameters = compat.arrayrecord_to_parameters(
|
200
|
+
record=context.state.array_records[MAIN_PARAMS_RECORD],
|
202
201
|
keep_input=True,
|
203
202
|
)
|
204
203
|
res_cen = context.strategy.evaluate(current_round, parameters=parameters)
|
@@ -224,12 +223,10 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
224
223
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
225
224
|
|
226
225
|
# Get current_round and parameters
|
227
|
-
cfg = context.state.
|
226
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
228
227
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
229
|
-
|
230
|
-
parameters = compat.
|
231
|
-
parametersrecord, keep_input=True
|
232
|
-
)
|
228
|
+
arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
|
229
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
233
230
|
|
234
231
|
# Get clients and their respective instructions from strategy
|
235
232
|
client_instructions = context.strategy.configure_fit(
|
@@ -253,10 +250,10 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
253
250
|
|
254
251
|
# Build out messages
|
255
252
|
out_messages = [
|
256
|
-
|
253
|
+
Message(
|
257
254
|
content=compat.fitins_to_recorddict(fitins, True),
|
258
|
-
message_type=MessageType.TRAIN,
|
259
255
|
dst_node_id=proxy.node_id,
|
256
|
+
message_type=MessageType.TRAIN,
|
260
257
|
group_id=str(current_round),
|
261
258
|
)
|
262
259
|
for proxy, fitins in client_instructions
|
@@ -295,10 +292,8 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
295
292
|
|
296
293
|
# Update the parameters and write history
|
297
294
|
if parameters_aggregated:
|
298
|
-
|
299
|
-
|
300
|
-
)
|
301
|
-
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
295
|
+
arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
|
296
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
302
297
|
context.history.add_metrics_distributed_fit(
|
303
298
|
server_round=current_round, metrics=metrics_aggregated
|
304
299
|
)
|
@@ -311,12 +306,10 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
311
306
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
312
307
|
|
313
308
|
# Get current_round and parameters
|
314
|
-
cfg = context.state.
|
309
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
315
310
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
316
|
-
|
317
|
-
parameters = compat.
|
318
|
-
parametersrecord, keep_input=True
|
319
|
-
)
|
311
|
+
arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
|
312
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
320
313
|
|
321
314
|
# Get clients and their respective instructions from strategy
|
322
315
|
client_instructions = context.strategy.configure_evaluate(
|
@@ -339,10 +332,10 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
339
332
|
|
340
333
|
# Build out messages
|
341
334
|
out_messages = [
|
342
|
-
|
335
|
+
Message(
|
343
336
|
content=compat.evaluateins_to_recorddict(evalins, True),
|
344
|
-
message_type=MessageType.EVALUATE,
|
345
337
|
dst_node_id=proxy.node_id,
|
338
|
+
message_type=MessageType.EVALUATE,
|
346
339
|
group_id=str(current_round),
|
347
340
|
)
|
348
341
|
for proxy, evalins in client_instructions
|
@@ -22,7 +22,7 @@ from typing import Optional, Union, cast
|
|
22
22
|
|
23
23
|
import flwr.common.recorddict_compat as compat
|
24
24
|
from flwr.common import (
|
25
|
-
|
25
|
+
ConfigRecord,
|
26
26
|
Context,
|
27
27
|
FitRes,
|
28
28
|
Message,
|
@@ -283,10 +283,10 @@ class SecAggPlusWorkflow:
|
|
283
283
|
) -> bool:
|
284
284
|
"""Execute the 'setup' stage."""
|
285
285
|
# Obtain fit instructions
|
286
|
-
cfg = context.state.
|
286
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
287
287
|
current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
|
288
|
-
parameters = compat.
|
289
|
-
context.state.
|
288
|
+
parameters = compat.arrayrecord_to_parameters(
|
289
|
+
context.state.array_records[MAIN_PARAMS_RECORD],
|
290
290
|
keep_input=True,
|
291
291
|
)
|
292
292
|
proxy_fitins_lst = context.strategy.configure_fit(
|
@@ -366,14 +366,14 @@ class SecAggPlusWorkflow:
|
|
366
366
|
state.sampled_node_ids = state.active_node_ids
|
367
367
|
|
368
368
|
# Send setup configuration to clients
|
369
|
-
|
370
|
-
content = RecordDict({RECORD_KEY_CONFIGS:
|
369
|
+
cfg_record = ConfigRecord(sa_params_dict) # type: ignore
|
370
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
371
371
|
|
372
372
|
def make(nid: int) -> Message:
|
373
|
-
return
|
373
|
+
return Message(
|
374
374
|
content=content,
|
375
|
-
message_type=MessageType.TRAIN,
|
376
375
|
dst_node_id=nid,
|
376
|
+
message_type=MessageType.TRAIN,
|
377
377
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
378
378
|
)
|
379
379
|
|
@@ -398,7 +398,7 @@ class SecAggPlusWorkflow:
|
|
398
398
|
if msg.has_error():
|
399
399
|
state.failures.append(Exception(msg.error))
|
400
400
|
continue
|
401
|
-
key_dict = msg.content.
|
401
|
+
key_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
402
402
|
node_id = msg.metadata.src_node_id
|
403
403
|
pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
|
404
404
|
state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
|
@@ -409,19 +409,19 @@ class SecAggPlusWorkflow:
|
|
409
409
|
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
410
410
|
) -> bool:
|
411
411
|
"""Execute the 'share keys' stage."""
|
412
|
-
cfg = context.state.
|
412
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
413
413
|
|
414
414
|
def make(nid: int) -> Message:
|
415
415
|
neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
|
416
|
-
|
416
|
+
cfg_record = ConfigRecord(
|
417
417
|
{str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
|
418
418
|
)
|
419
|
-
|
420
|
-
content = RecordDict({RECORD_KEY_CONFIGS:
|
421
|
-
return
|
419
|
+
cfg_record[Key.STAGE] = Stage.SHARE_KEYS
|
420
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
421
|
+
return Message(
|
422
422
|
content=content,
|
423
|
-
message_type=MessageType.TRAIN,
|
424
423
|
dst_node_id=nid,
|
424
|
+
message_type=MessageType.TRAIN,
|
425
425
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
426
426
|
)
|
427
427
|
|
@@ -458,7 +458,7 @@ class SecAggPlusWorkflow:
|
|
458
458
|
state.failures.append(Exception(msg.error))
|
459
459
|
continue
|
460
460
|
node_id = msg.metadata.src_node_id
|
461
|
-
res_dict = msg.content.
|
461
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
462
462
|
dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
|
463
463
|
ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
|
464
464
|
srcs += [node_id] * len(dst_lst)
|
@@ -479,22 +479,22 @@ class SecAggPlusWorkflow:
|
|
479
479
|
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
480
480
|
) -> bool:
|
481
481
|
"""Execute the 'collect masked vectors' stage."""
|
482
|
-
cfg = context.state.
|
482
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
483
483
|
|
484
484
|
# Send secret key shares to clients (plus FitIns) and collect masked vectors
|
485
485
|
def make(nid: int) -> Message:
|
486
|
-
|
486
|
+
cfg_dict = {
|
487
487
|
Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
|
488
488
|
Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
|
489
489
|
Key.SOURCE_LIST: state.forward_srcs[nid],
|
490
490
|
}
|
491
|
-
|
491
|
+
cfg_record = ConfigRecord(cfg_dict) # type: ignore
|
492
492
|
content = state.nid_to_fitins[nid]
|
493
|
-
content.
|
494
|
-
return
|
493
|
+
content.config_records[RECORD_KEY_CONFIGS] = cfg_record
|
494
|
+
return Message(
|
495
495
|
content=content,
|
496
|
-
message_type=MessageType.TRAIN,
|
497
496
|
dst_node_id=nid,
|
497
|
+
message_type=MessageType.TRAIN,
|
498
498
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
499
499
|
)
|
500
500
|
|
@@ -524,7 +524,7 @@ class SecAggPlusWorkflow:
|
|
524
524
|
if msg.has_error():
|
525
525
|
state.failures.append(Exception(msg.error))
|
526
526
|
continue
|
527
|
-
res_dict = msg.content.
|
527
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
528
528
|
bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
|
529
529
|
client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
|
530
530
|
if masked_vector is None:
|
@@ -550,7 +550,7 @@ class SecAggPlusWorkflow:
|
|
550
550
|
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
551
551
|
) -> bool:
|
552
552
|
"""Execute the 'unmask' stage."""
|
553
|
-
cfg = context.state.
|
553
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
554
554
|
current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
|
555
555
|
|
556
556
|
# Construct active node IDs and dead node IDs
|
@@ -560,17 +560,17 @@ class SecAggPlusWorkflow:
|
|
560
560
|
# Send secure IDs of active and dead clients and collect key shares from clients
|
561
561
|
def make(nid: int) -> Message:
|
562
562
|
neighbours = state.nid_to_neighbours[nid]
|
563
|
-
|
563
|
+
cfg_dict = {
|
564
564
|
Key.STAGE: Stage.UNMASK,
|
565
565
|
Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
|
566
566
|
Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
|
567
567
|
}
|
568
|
-
|
569
|
-
content = RecordDict({RECORD_KEY_CONFIGS:
|
570
|
-
return
|
568
|
+
cfg_record = ConfigRecord(cfg_dict) # type: ignore
|
569
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
570
|
+
return Message(
|
571
571
|
content=content,
|
572
|
-
message_type=MessageType.TRAIN,
|
573
572
|
dst_node_id=nid,
|
573
|
+
message_type=MessageType.TRAIN,
|
574
574
|
group_id=str(current_round),
|
575
575
|
)
|
576
576
|
|
@@ -599,7 +599,7 @@ class SecAggPlusWorkflow:
|
|
599
599
|
if msg.has_error():
|
600
600
|
state.failures.append(Exception(msg.error))
|
601
601
|
continue
|
602
|
-
res_dict = msg.content.
|
602
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
603
603
|
nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
|
604
604
|
shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
|
605
605
|
for owner_nid, share in zip(nids, shares):
|
@@ -676,10 +676,8 @@ class SecAggPlusWorkflow:
|
|
676
676
|
|
677
677
|
# Update the parameters and write history
|
678
678
|
if parameters_aggregated:
|
679
|
-
|
680
|
-
|
681
|
-
)
|
682
|
-
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
679
|
+
arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
|
680
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
683
681
|
context.history.add_metrics_distributed_fit(
|
684
682
|
server_round=current_round, metrics=metrics_aggregated
|
685
683
|
)
|
flwr/simulation/app.py
CHANGED
@@ -47,7 +47,7 @@ from flwr.common.logger import (
|
|
47
47
|
stop_log_uploader,
|
48
48
|
)
|
49
49
|
from flwr.common.serde import (
|
50
|
-
|
50
|
+
config_record_from_proto,
|
51
51
|
context_from_proto,
|
52
52
|
context_to_proto,
|
53
53
|
fab_from_proto,
|
@@ -184,7 +184,7 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
|
|
184
184
|
fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
|
185
185
|
GetFederationOptionsRequest(run_id=run.run_id)
|
186
186
|
)
|
187
|
-
federation_options =
|
187
|
+
federation_options = config_record_from_proto(
|
188
188
|
fed_opt_res.federation_options
|
189
189
|
)
|
190
190
|
|
@@ -105,8 +105,10 @@ def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) ->
|
|
105
105
|
if not node_resources:
|
106
106
|
continue
|
107
107
|
|
108
|
-
|
109
|
-
|
108
|
+
# Fallback to zero when resource quantity is not configured on the ray node
|
109
|
+
# e.g.: node without GPU; head node set up not to run tasks (zero resources)
|
110
|
+
num_cpus = node_resources.get("CPU", 0)
|
111
|
+
num_gpus = node_resources.get("GPU", 0)
|
110
112
|
num_actors = int(num_cpus / client_resources["num_cpus"])
|
111
113
|
|
112
114
|
# If a GPU is present and client resources do require one
|
@@ -180,7 +180,7 @@ def run_simulation(
|
|
180
180
|
for values parsed to initialisation of backend, `client_resources`
|
181
181
|
to define the resources for clients, and `actor` to define the actor
|
182
182
|
parameters. Values supported in <value> are those included by
|
183
|
-
`flwr.common.typing.
|
183
|
+
`flwr.common.typing.ConfigRecordValues`.
|
184
184
|
|
185
185
|
enable_tf_gpu_growth : bool (default: False)
|
186
186
|
A boolean to indicate whether to enable GPU growth on the main thread. This is
|
@@ -546,7 +546,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
546
546
|
default="{}",
|
547
547
|
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
548
548
|
"configure a backend. Values supported in <value> are those included by "
|
549
|
-
"`flwr.common.typing.
|
549
|
+
"`flwr.common.typing.ConfigRecordValues`. ",
|
550
550
|
)
|
551
551
|
parser.add_argument(
|
552
552
|
"--enable-tf-gpu-growth",
|
flwr/superexec/deployment.py
CHANGED
@@ -23,7 +23,7 @@ from typing import Optional
|
|
23
23
|
from typing_extensions import override
|
24
24
|
|
25
25
|
from flwr.cli.config_utils import get_fab_metadata
|
26
|
-
from flwr.common import
|
26
|
+
from flwr.common import ConfigRecord, Context, RecordDict
|
27
27
|
from flwr.common.constant import (
|
28
28
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
29
29
|
Status,
|
@@ -141,7 +141,7 @@ class DeploymentEngine(Executor):
|
|
141
141
|
fab_id, fab_version = get_fab_metadata(fab.content)
|
142
142
|
|
143
143
|
run_id = self.linkstate.create_run(
|
144
|
-
fab_id, fab_version, fab_hash, override_config,
|
144
|
+
fab_id, fab_version, fab_hash, override_config, ConfigRecord()
|
145
145
|
)
|
146
146
|
return run_id
|
147
147
|
|
@@ -160,7 +160,7 @@ class DeploymentEngine(Executor):
|
|
160
160
|
self,
|
161
161
|
fab_file: bytes,
|
162
162
|
override_config: UserConfig,
|
163
|
-
federation_options:
|
163
|
+
federation_options: ConfigRecord,
|
164
164
|
) -> Optional[int]:
|
165
165
|
"""Start run using the Flower Deployment Engine."""
|
166
166
|
run_id = None
|
flwr/superexec/exec_servicer.py
CHANGED
@@ -28,7 +28,7 @@ from flwr.common.auth_plugin import ExecAuthPlugin
|
|
28
28
|
from flwr.common.constant import LOG_STREAM_INTERVAL, Status, SubStatus
|
29
29
|
from flwr.common.logger import log
|
30
30
|
from flwr.common.serde import (
|
31
|
-
|
31
|
+
config_record_from_proto,
|
32
32
|
run_to_proto,
|
33
33
|
user_config_from_proto,
|
34
34
|
)
|
@@ -79,7 +79,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
79
79
|
run_id = self.executor.start_run(
|
80
80
|
request.fab.content,
|
81
81
|
user_config_from_proto(request.override_config),
|
82
|
-
|
82
|
+
config_record_from_proto(request.federation_options),
|
83
83
|
)
|
84
84
|
|
85
85
|
if run_id is None:
|