flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250321__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +10 -12
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +1 -1
- flwr/common/context.py +4 -4
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/{parametersrecord.py → arrayrecord.py} +75 -32
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/compat/grid_client_proxy.py +31 -31
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +15 -23
- flwr/server/grid/inmemory_grid.py +14 -20
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +5 -5
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +21 -25
- flwr/server/superlink/linkstate/utils.py +18 -15
- flwr/server/superlink/serverappio/serverappio_servicer.py +3 -3
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +34 -41
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +37 -39
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +5 -5
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/RECORD +66 -66
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/entry_points.txt +0 -0
@@ -20,15 +20,15 @@ from dataclasses import dataclass, field
|
|
20
20
|
from logging import DEBUG, ERROR, INFO, WARN
|
21
21
|
from typing import Optional, Union, cast
|
22
22
|
|
23
|
-
import flwr.common.
|
23
|
+
import flwr.common.recorddict_compat as compat
|
24
24
|
from flwr.common import (
|
25
|
-
|
25
|
+
ConfigRecord,
|
26
26
|
Context,
|
27
27
|
FitRes,
|
28
28
|
Message,
|
29
29
|
MessageType,
|
30
30
|
NDArrays,
|
31
|
-
|
31
|
+
RecordDict,
|
32
32
|
bytes_to_ndarray,
|
33
33
|
log,
|
34
34
|
ndarrays_to_parameters,
|
@@ -66,7 +66,7 @@ class WorkflowState: # pylint: disable=R0902
|
|
66
66
|
"""The state of the SecAgg+ protocol."""
|
67
67
|
|
68
68
|
nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
|
69
|
-
nid_to_fitins: dict[int,
|
69
|
+
nid_to_fitins: dict[int, RecordDict] = field(default_factory=dict)
|
70
70
|
sampled_node_ids: set[int] = field(default_factory=set)
|
71
71
|
active_node_ids: set[int] = field(default_factory=set)
|
72
72
|
num_shares: int = 0
|
@@ -283,10 +283,10 @@ class SecAggPlusWorkflow:
|
|
283
283
|
) -> bool:
|
284
284
|
"""Execute the 'setup' stage."""
|
285
285
|
# Obtain fit instructions
|
286
|
-
cfg = context.state.
|
286
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
287
287
|
current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
|
288
|
-
parameters = compat.
|
289
|
-
context.state.
|
288
|
+
parameters = compat.arrayrecord_to_parameters(
|
289
|
+
context.state.array_records[MAIN_PARAMS_RECORD],
|
290
290
|
keep_input=True,
|
291
291
|
)
|
292
292
|
proxy_fitins_lst = context.strategy.configure_fit(
|
@@ -303,7 +303,7 @@ class SecAggPlusWorkflow:
|
|
303
303
|
)
|
304
304
|
|
305
305
|
state.nid_to_fitins = {
|
306
|
-
proxy.node_id: compat.
|
306
|
+
proxy.node_id: compat.fitins_to_recorddict(fitins, True)
|
307
307
|
for proxy, fitins in proxy_fitins_lst
|
308
308
|
}
|
309
309
|
state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
|
@@ -366,14 +366,14 @@ class SecAggPlusWorkflow:
|
|
366
366
|
state.sampled_node_ids = state.active_node_ids
|
367
367
|
|
368
368
|
# Send setup configuration to clients
|
369
|
-
|
370
|
-
content =
|
369
|
+
cfg_record = ConfigRecord(sa_params_dict) # type: ignore
|
370
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
371
371
|
|
372
372
|
def make(nid: int) -> Message:
|
373
|
-
return
|
373
|
+
return Message(
|
374
374
|
content=content,
|
375
|
-
message_type=MessageType.TRAIN,
|
376
375
|
dst_node_id=nid,
|
376
|
+
message_type=MessageType.TRAIN,
|
377
377
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
378
378
|
)
|
379
379
|
|
@@ -398,7 +398,7 @@ class SecAggPlusWorkflow:
|
|
398
398
|
if msg.has_error():
|
399
399
|
state.failures.append(Exception(msg.error))
|
400
400
|
continue
|
401
|
-
key_dict = msg.content.
|
401
|
+
key_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
402
402
|
node_id = msg.metadata.src_node_id
|
403
403
|
pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
|
404
404
|
state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
|
@@ -409,19 +409,19 @@ class SecAggPlusWorkflow:
|
|
409
409
|
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
410
410
|
) -> bool:
|
411
411
|
"""Execute the 'share keys' stage."""
|
412
|
-
cfg = context.state.
|
412
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
413
413
|
|
414
414
|
def make(nid: int) -> Message:
|
415
415
|
neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
|
416
|
-
|
416
|
+
cfg_record = ConfigRecord(
|
417
417
|
{str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
|
418
418
|
)
|
419
|
-
|
420
|
-
content =
|
421
|
-
return
|
419
|
+
cfg_record[Key.STAGE] = Stage.SHARE_KEYS
|
420
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
421
|
+
return Message(
|
422
422
|
content=content,
|
423
|
-
message_type=MessageType.TRAIN,
|
424
423
|
dst_node_id=nid,
|
424
|
+
message_type=MessageType.TRAIN,
|
425
425
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
426
426
|
)
|
427
427
|
|
@@ -458,7 +458,7 @@ class SecAggPlusWorkflow:
|
|
458
458
|
state.failures.append(Exception(msg.error))
|
459
459
|
continue
|
460
460
|
node_id = msg.metadata.src_node_id
|
461
|
-
res_dict = msg.content.
|
461
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
462
462
|
dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
|
463
463
|
ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
|
464
464
|
srcs += [node_id] * len(dst_lst)
|
@@ -479,22 +479,22 @@ class SecAggPlusWorkflow:
|
|
479
479
|
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
480
480
|
) -> bool:
|
481
481
|
"""Execute the 'collect masked vectors' stage."""
|
482
|
-
cfg = context.state.
|
482
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
483
483
|
|
484
484
|
# Send secret key shares to clients (plus FitIns) and collect masked vectors
|
485
485
|
def make(nid: int) -> Message:
|
486
|
-
|
486
|
+
cfg_dict = {
|
487
487
|
Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
|
488
488
|
Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
|
489
489
|
Key.SOURCE_LIST: state.forward_srcs[nid],
|
490
490
|
}
|
491
|
-
|
491
|
+
cfg_record = ConfigRecord(cfg_dict) # type: ignore
|
492
492
|
content = state.nid_to_fitins[nid]
|
493
|
-
content.
|
494
|
-
return
|
493
|
+
content.config_records[RECORD_KEY_CONFIGS] = cfg_record
|
494
|
+
return Message(
|
495
495
|
content=content,
|
496
|
-
message_type=MessageType.TRAIN,
|
497
496
|
dst_node_id=nid,
|
497
|
+
message_type=MessageType.TRAIN,
|
498
498
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
499
499
|
)
|
500
500
|
|
@@ -524,7 +524,7 @@ class SecAggPlusWorkflow:
|
|
524
524
|
if msg.has_error():
|
525
525
|
state.failures.append(Exception(msg.error))
|
526
526
|
continue
|
527
|
-
res_dict = msg.content.
|
527
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
528
528
|
bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
|
529
529
|
client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
|
530
530
|
if masked_vector is None:
|
@@ -540,7 +540,7 @@ class SecAggPlusWorkflow:
|
|
540
540
|
if msg.has_error():
|
541
541
|
state.failures.append(Exception(msg.error))
|
542
542
|
continue
|
543
|
-
fitres = compat.
|
543
|
+
fitres = compat.recorddict_to_fitres(msg.content, True)
|
544
544
|
proxy = state.nid_to_proxies[msg.metadata.src_node_id]
|
545
545
|
state.legacy_results.append((proxy, fitres))
|
546
546
|
|
@@ -550,7 +550,7 @@ class SecAggPlusWorkflow:
|
|
550
550
|
self, grid: Grid, context: LegacyContext, state: WorkflowState
|
551
551
|
) -> bool:
|
552
552
|
"""Execute the 'unmask' stage."""
|
553
|
-
cfg = context.state.
|
553
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
554
554
|
current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
|
555
555
|
|
556
556
|
# Construct active node IDs and dead node IDs
|
@@ -560,17 +560,17 @@ class SecAggPlusWorkflow:
|
|
560
560
|
# Send secure IDs of active and dead clients and collect key shares from clients
|
561
561
|
def make(nid: int) -> Message:
|
562
562
|
neighbours = state.nid_to_neighbours[nid]
|
563
|
-
|
563
|
+
cfg_dict = {
|
564
564
|
Key.STAGE: Stage.UNMASK,
|
565
565
|
Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
|
566
566
|
Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
|
567
567
|
}
|
568
|
-
|
569
|
-
content =
|
570
|
-
return
|
568
|
+
cfg_record = ConfigRecord(cfg_dict) # type: ignore
|
569
|
+
content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
|
570
|
+
return Message(
|
571
571
|
content=content,
|
572
|
-
message_type=MessageType.TRAIN,
|
573
572
|
dst_node_id=nid,
|
573
|
+
message_type=MessageType.TRAIN,
|
574
574
|
group_id=str(current_round),
|
575
575
|
)
|
576
576
|
|
@@ -599,7 +599,7 @@ class SecAggPlusWorkflow:
|
|
599
599
|
if msg.has_error():
|
600
600
|
state.failures.append(Exception(msg.error))
|
601
601
|
continue
|
602
|
-
res_dict = msg.content.
|
602
|
+
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
603
603
|
nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
|
604
604
|
shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
|
605
605
|
for owner_nid, share in zip(nids, shares):
|
@@ -676,10 +676,8 @@ class SecAggPlusWorkflow:
|
|
676
676
|
|
677
677
|
# Update the parameters and write history
|
678
678
|
if parameters_aggregated:
|
679
|
-
|
680
|
-
|
681
|
-
)
|
682
|
-
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
679
|
+
arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
|
680
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
683
681
|
context.history.add_metrics_distributed_fit(
|
684
682
|
server_round=current_round, metrics=metrics_aggregated
|
685
683
|
)
|
flwr/simulation/app.py
CHANGED
@@ -47,7 +47,7 @@ from flwr.common.logger import (
|
|
47
47
|
stop_log_uploader,
|
48
48
|
)
|
49
49
|
from flwr.common.serde import (
|
50
|
-
|
50
|
+
config_record_from_proto,
|
51
51
|
context_from_proto,
|
52
52
|
context_to_proto,
|
53
53
|
fab_from_proto,
|
@@ -184,7 +184,7 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
|
|
184
184
|
fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
|
185
185
|
GetFederationOptionsRequest(run_id=run.run_id)
|
186
186
|
)
|
187
|
-
federation_options =
|
187
|
+
federation_options = config_record_from_proto(
|
188
188
|
fed_opt_res.federation_options
|
189
189
|
)
|
190
190
|
|
@@ -105,8 +105,10 @@ def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) ->
|
|
105
105
|
if not node_resources:
|
106
106
|
continue
|
107
107
|
|
108
|
-
|
109
|
-
|
108
|
+
# Fallback to zero when resource quantity is not configured on the ray node
|
109
|
+
# e.g.: node without GPU; head node set up not to run tasks (zero resources)
|
110
|
+
num_cpus = node_resources.get("CPU", 0)
|
111
|
+
num_gpus = node_resources.get("GPU", 0)
|
110
112
|
num_actors = int(num_cpus / client_resources["num_cpus"])
|
111
113
|
|
112
114
|
# If a GPU is present and client resources do require one
|
@@ -23,7 +23,7 @@ from flwr import common
|
|
23
23
|
from flwr.client import ClientFnExt
|
24
24
|
from flwr.client.client_app import ClientApp
|
25
25
|
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
26
|
-
from flwr.common import DEFAULT_TTL, Message, Metadata,
|
26
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordDict, now
|
27
27
|
from flwr.common.constant import (
|
28
28
|
NUM_PARTITIONS_KEY,
|
29
29
|
PARTITION_ID_KEY,
|
@@ -31,15 +31,16 @@ from flwr.common.constant import (
|
|
31
31
|
MessageTypeLegacy,
|
32
32
|
)
|
33
33
|
from flwr.common.logger import log
|
34
|
-
from flwr.common.
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
34
|
+
from flwr.common.message import make_message
|
35
|
+
from flwr.common.recorddict_compat import (
|
36
|
+
evaluateins_to_recorddict,
|
37
|
+
fitins_to_recorddict,
|
38
|
+
getparametersins_to_recorddict,
|
39
|
+
getpropertiesins_to_recorddict,
|
40
|
+
recorddict_to_evaluateres,
|
41
|
+
recorddict_to_fitres,
|
42
|
+
recorddict_to_getparametersres,
|
43
|
+
recorddict_to_getpropertiesres,
|
43
44
|
)
|
44
45
|
from flwr.server.client_proxy import ClientProxy
|
45
46
|
from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
|
@@ -109,23 +110,24 @@ class RayActorClientProxy(ClientProxy):
|
|
109
110
|
|
110
111
|
return out_mssg
|
111
112
|
|
112
|
-
def
|
113
|
+
def _wrap_recorddict_in_message(
|
113
114
|
self,
|
114
|
-
|
115
|
+
recorddict: RecordDict,
|
115
116
|
message_type: str,
|
116
117
|
timeout: Optional[float],
|
117
118
|
group_id: Optional[int],
|
118
119
|
) -> Message:
|
119
|
-
"""Wrap a
|
120
|
-
return
|
121
|
-
content=
|
120
|
+
"""Wrap a RecordDict inside a Message."""
|
121
|
+
return make_message(
|
122
|
+
content=recorddict,
|
122
123
|
metadata=Metadata(
|
123
124
|
run_id=0,
|
124
125
|
message_id="",
|
125
126
|
group_id=str(group_id) if group_id is not None else "",
|
126
127
|
src_node_id=0,
|
127
128
|
dst_node_id=self.node_id,
|
128
|
-
|
129
|
+
reply_to_message_id="",
|
130
|
+
created_at=now().timestamp(),
|
129
131
|
ttl=timeout if timeout else DEFAULT_TTL,
|
130
132
|
message_type=message_type,
|
131
133
|
),
|
@@ -138,9 +140,9 @@ class RayActorClientProxy(ClientProxy):
|
|
138
140
|
group_id: Optional[int],
|
139
141
|
) -> common.GetPropertiesRes:
|
140
142
|
"""Return client's properties."""
|
141
|
-
|
142
|
-
message = self.
|
143
|
-
|
143
|
+
recorddict = getpropertiesins_to_recorddict(ins)
|
144
|
+
message = self._wrap_recorddict_in_message(
|
145
|
+
recorddict,
|
144
146
|
message_type=MessageTypeLegacy.GET_PROPERTIES,
|
145
147
|
timeout=timeout,
|
146
148
|
group_id=group_id,
|
@@ -148,7 +150,7 @@ class RayActorClientProxy(ClientProxy):
|
|
148
150
|
|
149
151
|
message_out = self._submit_job(message, timeout)
|
150
152
|
|
151
|
-
return
|
153
|
+
return recorddict_to_getpropertiesres(message_out.content)
|
152
154
|
|
153
155
|
def get_parameters(
|
154
156
|
self,
|
@@ -157,9 +159,9 @@ class RayActorClientProxy(ClientProxy):
|
|
157
159
|
group_id: Optional[int],
|
158
160
|
) -> common.GetParametersRes:
|
159
161
|
"""Return the current local model parameters."""
|
160
|
-
|
161
|
-
message = self.
|
162
|
-
|
162
|
+
recorddict = getparametersins_to_recorddict(ins)
|
163
|
+
message = self._wrap_recorddict_in_message(
|
164
|
+
recorddict,
|
163
165
|
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
164
166
|
timeout=timeout,
|
165
167
|
group_id=group_id,
|
@@ -167,17 +169,17 @@ class RayActorClientProxy(ClientProxy):
|
|
167
169
|
|
168
170
|
message_out = self._submit_job(message, timeout)
|
169
171
|
|
170
|
-
return
|
172
|
+
return recorddict_to_getparametersres(message_out.content, keep_input=False)
|
171
173
|
|
172
174
|
def fit(
|
173
175
|
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
174
176
|
) -> common.FitRes:
|
175
177
|
"""Train model parameters on the locally held dataset."""
|
176
|
-
|
178
|
+
recorddict = fitins_to_recorddict(
|
177
179
|
ins, keep_input=True
|
178
180
|
) # This must stay TRUE since ins are in-memory
|
179
|
-
message = self.
|
180
|
-
|
181
|
+
message = self._wrap_recorddict_in_message(
|
182
|
+
recorddict,
|
181
183
|
message_type=MessageType.TRAIN,
|
182
184
|
timeout=timeout,
|
183
185
|
group_id=group_id,
|
@@ -185,17 +187,17 @@ class RayActorClientProxy(ClientProxy):
|
|
185
187
|
|
186
188
|
message_out = self._submit_job(message, timeout)
|
187
189
|
|
188
|
-
return
|
190
|
+
return recorddict_to_fitres(message_out.content, keep_input=False)
|
189
191
|
|
190
192
|
def evaluate(
|
191
193
|
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
192
194
|
) -> common.EvaluateRes:
|
193
195
|
"""Evaluate model parameters on the locally held dataset."""
|
194
|
-
|
196
|
+
recorddict = evaluateins_to_recorddict(
|
195
197
|
ins, keep_input=True
|
196
198
|
) # This must stay TRUE since ins are in-memory
|
197
|
-
message = self.
|
198
|
-
|
199
|
+
message = self._wrap_recorddict_in_message(
|
200
|
+
recorddict,
|
199
201
|
message_type=MessageType.EVALUATE,
|
200
202
|
timeout=timeout,
|
201
203
|
group_id=group_id,
|
@@ -203,7 +205,7 @@ class RayActorClientProxy(ClientProxy):
|
|
203
205
|
|
204
206
|
message_out = self._submit_job(message, timeout)
|
205
207
|
|
206
|
-
return
|
208
|
+
return recorddict_to_evaluateres(message_out.content)
|
207
209
|
|
208
210
|
def reconnect(
|
209
211
|
self,
|
@@ -30,7 +30,7 @@ from typing import Any, Optional
|
|
30
30
|
from flwr.cli.config_utils import load_and_validate
|
31
31
|
from flwr.cli.utils import get_sha256_hash
|
32
32
|
from flwr.client import ClientApp
|
33
|
-
from flwr.common import Context, EventType,
|
33
|
+
from flwr.common import Context, EventType, RecordDict, event, log, now
|
34
34
|
from flwr.common.config import get_fused_config_from_dir, parse_config_args
|
35
35
|
from flwr.common.constant import RUN_ID_NUM_BYTES, Status
|
36
36
|
from flwr.common.logger import (
|
@@ -180,7 +180,7 @@ def run_simulation(
|
|
180
180
|
for values parsed to initialisation of backend, `client_resources`
|
181
181
|
to define the resources for clients, and `actor` to define the actor
|
182
182
|
parameters. Values supported in <value> are those included by
|
183
|
-
`flwr.common.typing.
|
183
|
+
`flwr.common.typing.ConfigRecordValues`.
|
184
184
|
|
185
185
|
enable_tf_gpu_growth : bool (default: False)
|
186
186
|
A boolean to indicate whether to enable GPU growth on the main thread. This is
|
@@ -260,7 +260,7 @@ def run_serverapp_th(
|
|
260
260
|
run_id=run_id,
|
261
261
|
node_id=0,
|
262
262
|
node_config={},
|
263
|
-
state=
|
263
|
+
state=RecordDict(),
|
264
264
|
run_config=_server_app_run_config,
|
265
265
|
)
|
266
266
|
|
@@ -333,7 +333,7 @@ def _main_loop(
|
|
333
333
|
run_id=run.run_id,
|
334
334
|
node_id=0,
|
335
335
|
node_config=UserConfig(),
|
336
|
-
state=
|
336
|
+
state=RecordDict(),
|
337
337
|
run_config=UserConfig(),
|
338
338
|
)
|
339
339
|
try:
|
@@ -546,7 +546,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
546
546
|
default="{}",
|
547
547
|
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
548
548
|
"configure a backend. Values supported in <value> are those included by "
|
549
|
-
"`flwr.common.typing.
|
549
|
+
"`flwr.common.typing.ConfigRecordValues`. ",
|
550
550
|
)
|
551
551
|
parser.add_argument(
|
552
552
|
"--enable-tf-gpu-growth",
|
flwr/superexec/deployment.py
CHANGED
@@ -23,7 +23,7 @@ from typing import Optional
|
|
23
23
|
from typing_extensions import override
|
24
24
|
|
25
25
|
from flwr.cli.config_utils import get_fab_metadata
|
26
|
-
from flwr.common import
|
26
|
+
from flwr.common import ConfigRecord, Context, RecordDict
|
27
27
|
from flwr.common.constant import (
|
28
28
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
29
29
|
Status,
|
@@ -141,7 +141,7 @@ class DeploymentEngine(Executor):
|
|
141
141
|
fab_id, fab_version = get_fab_metadata(fab.content)
|
142
142
|
|
143
143
|
run_id = self.linkstate.create_run(
|
144
|
-
fab_id, fab_version, fab_hash, override_config,
|
144
|
+
fab_id, fab_version, fab_hash, override_config, ConfigRecord()
|
145
145
|
)
|
146
146
|
return run_id
|
147
147
|
|
@@ -149,7 +149,7 @@ class DeploymentEngine(Executor):
|
|
149
149
|
"""Register a Context for a Run."""
|
150
150
|
# Create an empty context for the Run
|
151
151
|
context = Context(
|
152
|
-
run_id=run_id, node_id=0, node_config={}, state=
|
152
|
+
run_id=run_id, node_id=0, node_config={}, state=RecordDict(), run_config={}
|
153
153
|
)
|
154
154
|
|
155
155
|
# Register the context at the LinkState
|
@@ -160,7 +160,7 @@ class DeploymentEngine(Executor):
|
|
160
160
|
self,
|
161
161
|
fab_file: bytes,
|
162
162
|
override_config: UserConfig,
|
163
|
-
federation_options:
|
163
|
+
federation_options: ConfigRecord,
|
164
164
|
) -> Optional[int]:
|
165
165
|
"""Start run using the Flower Deployment Engine."""
|
166
166
|
run_id = None
|
flwr/superexec/exec_servicer.py
CHANGED
@@ -28,7 +28,7 @@ from flwr.common.auth_plugin import ExecAuthPlugin
|
|
28
28
|
from flwr.common.constant import LOG_STREAM_INTERVAL, Status, SubStatus
|
29
29
|
from flwr.common.logger import log
|
30
30
|
from flwr.common.serde import (
|
31
|
-
|
31
|
+
config_record_from_proto,
|
32
32
|
run_to_proto,
|
33
33
|
user_config_from_proto,
|
34
34
|
)
|
@@ -79,7 +79,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
79
79
|
run_id = self.executor.start_run(
|
80
80
|
request.fab.content,
|
81
81
|
user_config_from_proto(request.override_config),
|
82
|
-
|
82
|
+
config_record_from_proto(request.federation_options),
|
83
83
|
)
|
84
84
|
|
85
85
|
if run_id is None:
|
flwr/superexec/executor.py
CHANGED
@@ -20,7 +20,7 @@ from dataclasses import dataclass, field
|
|
20
20
|
from subprocess import Popen
|
21
21
|
from typing import Optional
|
22
22
|
|
23
|
-
from flwr.common import
|
23
|
+
from flwr.common import ConfigRecord
|
24
24
|
from flwr.common.typing import UserConfig
|
25
25
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
26
26
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
@@ -73,7 +73,7 @@ class Executor(ABC):
|
|
73
73
|
self,
|
74
74
|
fab_file: bytes,
|
75
75
|
override_config: UserConfig,
|
76
|
-
federation_options:
|
76
|
+
federation_options: ConfigRecord,
|
77
77
|
) -> Optional[int]:
|
78
78
|
"""Start a run using the given Flower FAB ID and version.
|
79
79
|
|
@@ -86,7 +86,7 @@ class Executor(ABC):
|
|
86
86
|
The Flower App Bundle file bytes.
|
87
87
|
override_config: UserConfig
|
88
88
|
The config overrides dict sent by the user (using `flwr run`).
|
89
|
-
federation_options:
|
89
|
+
federation_options: ConfigRecord
|
90
90
|
The federation options sent by the user (using `flwr run`).
|
91
91
|
|
92
92
|
Returns
|
flwr/superexec/simulation.py
CHANGED
@@ -22,7 +22,7 @@ from typing import Optional
|
|
22
22
|
from typing_extensions import override
|
23
23
|
|
24
24
|
from flwr.cli.config_utils import get_fab_metadata
|
25
|
-
from flwr.common import
|
25
|
+
from flwr.common import ConfigRecord, Context, RecordDict
|
26
26
|
from flwr.common.logger import log
|
27
27
|
from flwr.common.typing import Fab, UserConfig
|
28
28
|
from flwr.server.superlink.ffs import Ffs
|
@@ -76,7 +76,7 @@ class SimulationEngine(Executor):
|
|
76
76
|
self,
|
77
77
|
fab_file: bytes,
|
78
78
|
override_config: UserConfig,
|
79
|
-
federation_options:
|
79
|
+
federation_options: ConfigRecord,
|
80
80
|
) -> Optional[int]:
|
81
81
|
"""Start run using the Flower Simulation Engine."""
|
82
82
|
try:
|
@@ -104,7 +104,7 @@ class SimulationEngine(Executor):
|
|
104
104
|
run_id=run_id,
|
105
105
|
node_id=0,
|
106
106
|
node_config={},
|
107
|
-
state=
|
107
|
+
state=RecordDict(),
|
108
108
|
run_config={},
|
109
109
|
)
|
110
110
|
|