flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250321__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +10 -12
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +1 -1
- flwr/common/context.py +4 -4
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/{parametersrecord.py → arrayrecord.py} +75 -32
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/compat/grid_client_proxy.py +31 -31
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +15 -23
- flwr/server/grid/inmemory_grid.py +14 -20
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +5 -5
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +21 -25
- flwr/server/superlink/linkstate/utils.py +18 -15
- flwr/server/superlink/serverappio/serverappio_servicer.py +3 -3
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +34 -41
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +37 -39
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +5 -5
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/RECORD +66 -66
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/entry_points.txt +0 -0
@@ -22,15 +22,15 @@ from typing import Any, cast
|
|
22
22
|
|
23
23
|
from flwr.client.typing import ClientAppCallable
|
24
24
|
from flwr.common import (
|
25
|
-
|
25
|
+
ConfigRecord,
|
26
26
|
Context,
|
27
27
|
Message,
|
28
28
|
Parameters,
|
29
|
-
|
29
|
+
RecordDict,
|
30
30
|
ndarray_to_bytes,
|
31
31
|
parameters_to_ndarrays,
|
32
32
|
)
|
33
|
-
from flwr.common import
|
33
|
+
from flwr.common import recorddict_compat as compat
|
34
34
|
from flwr.common.constant import MessageType
|
35
35
|
from flwr.common.logger import log
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
@@ -63,7 +63,7 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
|
|
63
63
|
share_keys_plaintext_concat,
|
64
64
|
share_keys_plaintext_separate,
|
65
65
|
)
|
66
|
-
from flwr.common.typing import
|
66
|
+
from flwr.common.typing import ConfigRecordValues
|
67
67
|
|
68
68
|
|
69
69
|
@dataclass
|
@@ -97,7 +97,7 @@ class SecAggPlusState:
|
|
97
97
|
ss2_dict: dict[int, bytes] = field(default_factory=dict)
|
98
98
|
public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict)
|
99
99
|
|
100
|
-
def __init__(self, **kwargs:
|
100
|
+
def __init__(self, **kwargs: ConfigRecordValues) -> None:
|
101
101
|
for k, v in kwargs.items():
|
102
102
|
if k.endswith(":V"):
|
103
103
|
continue
|
@@ -115,7 +115,7 @@ class SecAggPlusState:
|
|
115
115
|
new_v = dict(zip(keys, values))
|
116
116
|
self.__setattr__(k, new_v)
|
117
117
|
|
118
|
-
def to_dict(self) -> dict[str,
|
118
|
+
def to_dict(self) -> dict[str, ConfigRecordValues]:
|
119
119
|
"""Convert the state to a dictionary."""
|
120
120
|
ret = vars(self)
|
121
121
|
for k in list(ret.keys()):
|
@@ -144,13 +144,13 @@ def secaggplus_mod(
|
|
144
144
|
return call_next(msg, ctxt)
|
145
145
|
|
146
146
|
# Retrieve local state
|
147
|
-
if RECORD_KEY_STATE not in ctxt.state.
|
148
|
-
ctxt.state.
|
149
|
-
state_dict = ctxt.state.
|
147
|
+
if RECORD_KEY_STATE not in ctxt.state.config_records:
|
148
|
+
ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord({})
|
149
|
+
state_dict = ctxt.state.config_records[RECORD_KEY_STATE]
|
150
150
|
state = SecAggPlusState(**state_dict)
|
151
151
|
|
152
152
|
# Retrieve incoming configs
|
153
|
-
configs = msg.content.
|
153
|
+
configs = msg.content.config_records[RECORD_KEY_CONFIGS]
|
154
154
|
|
155
155
|
# Check the validity of the next stage
|
156
156
|
check_stage(state.current_stage, configs)
|
@@ -162,7 +162,7 @@ def secaggplus_mod(
|
|
162
162
|
check_configs(state.current_stage, configs)
|
163
163
|
|
164
164
|
# Execute
|
165
|
-
out_content =
|
165
|
+
out_content = RecordDict()
|
166
166
|
if state.current_stage == Stage.SETUP:
|
167
167
|
state.nid = msg.metadata.dst_node_id
|
168
168
|
res = _setup(state, configs)
|
@@ -171,31 +171,31 @@ def secaggplus_mod(
|
|
171
171
|
elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
|
172
172
|
out_msg = call_next(msg, ctxt)
|
173
173
|
out_content = out_msg.content
|
174
|
-
fitres = compat.
|
174
|
+
fitres = compat.recorddict_to_fitres(out_content, keep_input=True)
|
175
175
|
res = _collect_masked_vectors(
|
176
176
|
state, configs, fitres.num_examples, fitres.parameters
|
177
177
|
)
|
178
|
-
for
|
179
|
-
|
178
|
+
for arr_record in out_content.array_records.values():
|
179
|
+
arr_record.clear()
|
180
180
|
elif state.current_stage == Stage.UNMASK:
|
181
181
|
res = _unmask(state, configs)
|
182
182
|
else:
|
183
183
|
raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")
|
184
184
|
|
185
185
|
# Save state
|
186
|
-
ctxt.state.
|
186
|
+
ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord(state.to_dict())
|
187
187
|
|
188
188
|
# Return message
|
189
|
-
out_content.
|
190
|
-
return
|
189
|
+
out_content.config_records[RECORD_KEY_CONFIGS] = ConfigRecord(res, False)
|
190
|
+
return Message(out_content, reply_to=msg)
|
191
191
|
|
192
192
|
|
193
|
-
def check_stage(current_stage: str, configs:
|
193
|
+
def check_stage(current_stage: str, configs: ConfigRecord) -> None:
|
194
194
|
"""Check the validity of the next stage."""
|
195
195
|
# Check the existence of Config.STAGE
|
196
196
|
if Key.STAGE not in configs:
|
197
197
|
raise KeyError(
|
198
|
-
f"The required key '{Key.STAGE}' is missing from the
|
198
|
+
f"The required key '{Key.STAGE}' is missing from the ConfigRecord."
|
199
199
|
)
|
200
200
|
|
201
201
|
# Check the value type of the Config.STAGE
|
@@ -223,7 +223,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
223
223
|
|
224
224
|
|
225
225
|
# pylint: disable-next=too-many-branches
|
226
|
-
def check_configs(stage: str, configs:
|
226
|
+
def check_configs(stage: str, configs: ConfigRecord) -> None:
|
227
227
|
"""Check the validity of the configs."""
|
228
228
|
# Check configs for the setup stage
|
229
229
|
if stage == Stage.SETUP:
|
@@ -239,7 +239,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
239
239
|
if key not in configs:
|
240
240
|
raise KeyError(
|
241
241
|
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
242
|
-
"missing from the
|
242
|
+
"missing from the ConfigRecord."
|
243
243
|
)
|
244
244
|
# Bool is a subclass of int in Python,
|
245
245
|
# so `isinstance(v, int)` will return True even if v is a boolean.
|
@@ -272,7 +272,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
272
272
|
raise KeyError(
|
273
273
|
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
274
274
|
f"the required key '{key}' is "
|
275
|
-
"missing from the
|
275
|
+
"missing from the ConfigRecord."
|
276
276
|
)
|
277
277
|
if not isinstance(configs[key], list) or any(
|
278
278
|
elm
|
@@ -295,7 +295,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
295
295
|
raise KeyError(
|
296
296
|
f"Stage {Stage.UNMASK}: "
|
297
297
|
f"the required key '{key}' is "
|
298
|
-
"missing from the
|
298
|
+
"missing from the ConfigRecord."
|
299
299
|
)
|
300
300
|
if not isinstance(configs[key], list) or any(
|
301
301
|
elm
|
@@ -313,8 +313,8 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
313
313
|
|
314
314
|
|
315
315
|
def _setup(
|
316
|
-
state: SecAggPlusState, configs:
|
317
|
-
) -> dict[str,
|
316
|
+
state: SecAggPlusState, configs: ConfigRecord
|
317
|
+
) -> dict[str, ConfigRecordValues]:
|
318
318
|
# Assigning parameter values to object fields
|
319
319
|
sec_agg_param_dict = configs
|
320
320
|
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
@@ -349,8 +349,8 @@ def _setup(
|
|
349
349
|
|
350
350
|
# pylint: disable-next=too-many-locals
|
351
351
|
def _share_keys(
|
352
|
-
state: SecAggPlusState, configs:
|
353
|
-
) -> dict[str,
|
352
|
+
state: SecAggPlusState, configs: ConfigRecord
|
353
|
+
) -> dict[str, ConfigRecordValues]:
|
354
354
|
named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs)
|
355
355
|
key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
|
356
356
|
log(DEBUG, "Node %d: starting stage 1...", state.nid)
|
@@ -412,10 +412,10 @@ def _share_keys(
|
|
412
412
|
# pylint: disable-next=too-many-locals
|
413
413
|
def _collect_masked_vectors(
|
414
414
|
state: SecAggPlusState,
|
415
|
-
configs:
|
415
|
+
configs: ConfigRecord,
|
416
416
|
num_examples: int,
|
417
417
|
updated_parameters: Parameters,
|
418
|
-
) -> dict[str,
|
418
|
+
) -> dict[str, ConfigRecordValues]:
|
419
419
|
log(DEBUG, "Node %d: starting stage 2...", state.nid)
|
420
420
|
available_clients: list[int] = []
|
421
421
|
ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST])
|
@@ -498,8 +498,8 @@ def _collect_masked_vectors(
|
|
498
498
|
|
499
499
|
|
500
500
|
def _unmask(
|
501
|
-
state: SecAggPlusState, configs:
|
502
|
-
) -> dict[str,
|
501
|
+
state: SecAggPlusState, configs: ConfigRecord
|
502
|
+
) -> dict[str, ConfigRecordValues]:
|
503
503
|
log(DEBUG, "Node %d: starting stage 3...", state.nid)
|
504
504
|
|
505
505
|
active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST])
|
flwr/client/run_info_store.py
CHANGED
@@ -19,7 +19,7 @@ from dataclasses import dataclass
|
|
19
19
|
from pathlib import Path
|
20
20
|
from typing import Optional
|
21
21
|
|
22
|
-
from flwr.common import Context,
|
22
|
+
from flwr.common import Context, RecordDict
|
23
23
|
from flwr.common.config import (
|
24
24
|
get_fused_config,
|
25
25
|
get_fused_config_from_dir,
|
@@ -86,7 +86,7 @@ class DeprecatedRunInfoStore:
|
|
86
86
|
run_id=run_id,
|
87
87
|
node_id=self.node_id,
|
88
88
|
node_config=self.node_config,
|
89
|
-
state=
|
89
|
+
state=RecordDict(),
|
90
90
|
run_config=initial_run_config.copy(),
|
91
91
|
),
|
92
92
|
)
|
flwr/common/__init__.py
CHANGED
@@ -31,9 +31,13 @@ from .parameter import ndarray_to_bytes as ndarray_to_bytes
|
|
31
31
|
from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
|
32
32
|
from .parameter import parameters_to_ndarrays as parameters_to_ndarrays
|
33
33
|
from .record import Array as Array
|
34
|
+
from .record import ArrayRecord as ArrayRecord
|
35
|
+
from .record import ConfigRecord as ConfigRecord
|
34
36
|
from .record import ConfigsRecord as ConfigsRecord
|
37
|
+
from .record import MetricRecord as MetricRecord
|
35
38
|
from .record import MetricsRecord as MetricsRecord
|
36
39
|
from .record import ParametersRecord as ParametersRecord
|
40
|
+
from .record import RecordDict as RecordDict
|
37
41
|
from .record import RecordSet as RecordSet
|
38
42
|
from .record import array_from_numpy as array_from_numpy
|
39
43
|
from .telemetry import EventType as EventType
|
@@ -41,7 +45,7 @@ from .telemetry import event as event
|
|
41
45
|
from .typing import ClientMessage as ClientMessage
|
42
46
|
from .typing import Code as Code
|
43
47
|
from .typing import Config as Config
|
44
|
-
from .typing import
|
48
|
+
from .typing import ConfigRecordValues as ConfigRecordValues
|
45
49
|
from .typing import DisconnectRes as DisconnectRes
|
46
50
|
from .typing import EvaluateIns as EvaluateIns
|
47
51
|
from .typing import EvaluateRes as EvaluateRes
|
@@ -51,9 +55,9 @@ from .typing import GetParametersIns as GetParametersIns
|
|
51
55
|
from .typing import GetParametersRes as GetParametersRes
|
52
56
|
from .typing import GetPropertiesIns as GetPropertiesIns
|
53
57
|
from .typing import GetPropertiesRes as GetPropertiesRes
|
58
|
+
from .typing import MetricRecordValues as MetricRecordValues
|
54
59
|
from .typing import Metrics as Metrics
|
55
60
|
from .typing import MetricsAggregationFn as MetricsAggregationFn
|
56
|
-
from .typing import MetricsRecordValues as MetricsRecordValues
|
57
61
|
from .typing import NDArray as NDArray
|
58
62
|
from .typing import NDArrays as NDArrays
|
59
63
|
from .typing import Parameters as Parameters
|
@@ -65,11 +69,13 @@ from .typing import Status as Status
|
|
65
69
|
|
66
70
|
__all__ = [
|
67
71
|
"Array",
|
72
|
+
"ArrayRecord",
|
68
73
|
"ClientMessage",
|
69
74
|
"Code",
|
70
75
|
"Config",
|
76
|
+
"ConfigRecord",
|
77
|
+
"ConfigRecordValues",
|
71
78
|
"ConfigsRecord",
|
72
|
-
"ConfigsRecordValues",
|
73
79
|
"Context",
|
74
80
|
"DEFAULT_TTL",
|
75
81
|
"DisconnectRes",
|
@@ -88,16 +94,18 @@ __all__ = [
|
|
88
94
|
"MessageType",
|
89
95
|
"MessageTypeLegacy",
|
90
96
|
"Metadata",
|
97
|
+
"MetricRecord",
|
98
|
+
"MetricRecordValues",
|
91
99
|
"Metrics",
|
92
100
|
"MetricsAggregationFn",
|
93
101
|
"MetricsRecord",
|
94
|
-
"MetricsRecordValues",
|
95
102
|
"NDArray",
|
96
103
|
"NDArrays",
|
97
104
|
"Parameters",
|
98
105
|
"ParametersRecord",
|
99
106
|
"Properties",
|
100
107
|
"ReconnectIns",
|
108
|
+
"RecordDict",
|
101
109
|
"RecordSet",
|
102
110
|
"Scalar",
|
103
111
|
"ServerMessage",
|
flwr/common/config.py
CHANGED
@@ -34,7 +34,7 @@ from flwr.common.constant import (
|
|
34
34
|
)
|
35
35
|
from flwr.common.typing import Run, UserConfig, UserConfigValue
|
36
36
|
|
37
|
-
from . import
|
37
|
+
from . import ConfigRecord, object_ref
|
38
38
|
|
39
39
|
T_dict = TypeVar("T_dict", bound=dict[str, Any]) # pylint: disable=invalid-name
|
40
40
|
|
@@ -260,9 +260,9 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
|
|
260
260
|
)
|
261
261
|
|
262
262
|
|
263
|
-
def
|
264
|
-
"""Construct a `
|
265
|
-
c_record =
|
263
|
+
def user_config_to_configrecord(config: UserConfig) -> ConfigRecord:
|
264
|
+
"""Construct a `ConfigRecord` out of a `UserConfig`."""
|
265
|
+
c_record = ConfigRecord()
|
266
266
|
for k, v in config.items():
|
267
267
|
c_record[k] = v
|
268
268
|
|
flwr/common/constant.py
CHANGED
@@ -121,7 +121,7 @@ TIMESTAMP_HEADER = "flwr-timestamp"
|
|
121
121
|
TIMESTAMP_TOLERANCE = 10 # General tolerance for timestamp verification
|
122
122
|
SYSTEM_TIME_TOLERANCE = 5 # Allowance for system time drift
|
123
123
|
|
124
|
-
# Constants for
|
124
|
+
# Constants for ArrayRecord
|
125
125
|
GC_THRESHOLD = 200_000_000 # 200 MB
|
126
126
|
|
127
127
|
|
flwr/common/context.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from dataclasses import dataclass
|
19
19
|
|
20
|
-
from .record import
|
20
|
+
from .record import RecordDict
|
21
21
|
from .typing import UserConfig
|
22
22
|
|
23
23
|
|
@@ -34,7 +34,7 @@ class Context:
|
|
34
34
|
node_config : UserConfig
|
35
35
|
A config (key/value mapping) unique to the node and independent of the
|
36
36
|
`run_config`. This config persists across all runs this node participates in.
|
37
|
-
state :
|
37
|
+
state : RecordDict
|
38
38
|
Holds records added by the entity in a given `run_id` and that will stay local.
|
39
39
|
This means that the data it holds will never leave the system it's running from.
|
40
40
|
This can be used as an intermediate storage or scratchpad when
|
@@ -50,7 +50,7 @@ class Context:
|
|
50
50
|
run_id: int
|
51
51
|
node_id: int
|
52
52
|
node_config: UserConfig
|
53
|
-
state:
|
53
|
+
state: RecordDict
|
54
54
|
run_config: UserConfig
|
55
55
|
|
56
56
|
def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments
|
@@ -58,7 +58,7 @@ class Context:
|
|
58
58
|
run_id: int,
|
59
59
|
node_id: int,
|
60
60
|
node_config: UserConfig,
|
61
|
-
state:
|
61
|
+
state: RecordDict,
|
62
62
|
run_config: UserConfig,
|
63
63
|
) -> None:
|
64
64
|
self.run_id = run_id
|