flwr-nightly 1.17.0.dev20250320__py3-none-any.whl → 1.17.0.dev20250322__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/run/run.py +5 -9
- flwr/client/client_app.py +10 -12
- flwr/client/grpc_client/connection.py +3 -3
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
- flwr/common/__init__.py +10 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +1 -1
- flwr/common/record/__init__.py +6 -3
- flwr/common/record/{parametersrecord.py → arrayrecord.py} +74 -31
- flwr/common/record/{configsrecord.py → configrecord.py} +73 -27
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +77 -31
- flwr/common/record/recorddict.py +95 -56
- flwr/common/recorddict_compat.py +54 -62
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +42 -43
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +30 -30
- flwr/proto/exec_pb2.pyi +2 -2
- flwr/proto/recorddict_pb2.py +29 -29
- flwr/proto/recorddict_pb2.pyi +33 -33
- flwr/proto/run_pb2.py +2 -2
- flwr/proto/run_pb2.pyi +2 -2
- flwr/server/compat/grid_client_proxy.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +4 -4
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +7 -7
- flwr/server/superlink/linkstate/utils.py +9 -9
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/workflow/default_workflows.py +27 -34
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +32 -34
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/run_simulation.py +2 -2
- flwr/superexec/deployment.py +3 -3
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +2 -2
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/RECORD +49 -49
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py
CHANGED
@@ -35,15 +35,11 @@ from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
|
|
35
35
|
from flwr.common.config import (
|
36
36
|
flatten_dict,
|
37
37
|
parse_config_args,
|
38
|
-
|
38
|
+
user_config_to_configrecord,
|
39
39
|
)
|
40
40
|
from flwr.common.constant import CliOutputFormat
|
41
41
|
from flwr.common.logger import print_json_error, redirect_output, restore_output
|
42
|
-
from flwr.common.serde import
|
43
|
-
configs_record_to_proto,
|
44
|
-
fab_to_proto,
|
45
|
-
user_config_to_proto,
|
46
|
-
)
|
42
|
+
from flwr.common.serde import config_record_to_proto, fab_to_proto, user_config_to_proto
|
47
43
|
from flwr.common.typing import Fab
|
48
44
|
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
|
49
45
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
@@ -171,14 +167,14 @@ def _run_with_exec_api(
|
|
171
167
|
|
172
168
|
fab = Fab(fab_hash, content)
|
173
169
|
|
174
|
-
# Construct a `
|
170
|
+
# Construct a `ConfigRecord` out of a flattened `UserConfig`
|
175
171
|
fed_conf = flatten_dict(federation_config.get("options", {}))
|
176
|
-
c_record =
|
172
|
+
c_record = user_config_to_configrecord(fed_conf)
|
177
173
|
|
178
174
|
req = StartRunRequest(
|
179
175
|
fab=fab_to_proto(fab),
|
180
176
|
override_config=user_config_to_proto(parse_config_args(config_overrides)),
|
181
|
-
federation_options=
|
177
|
+
federation_options=config_record_to_proto(c_record),
|
182
178
|
)
|
183
179
|
with unauthenticated_exc_handler():
|
184
180
|
res = stub.StartRun(req)
|
flwr/client/client_app.py
CHANGED
@@ -189,7 +189,7 @@ class ClientApp:
|
|
189
189
|
>>> def train(message: Message, context: Context) -> Message:
|
190
190
|
>>> print("Executing default train function")
|
191
191
|
>>> # Create and return an echo reply message
|
192
|
-
>>> return message.
|
192
|
+
>>> return Message(message.content, reply_to=message)
|
193
193
|
|
194
194
|
Registering a train function with a custom action name:
|
195
195
|
|
@@ -200,7 +200,7 @@ class ClientApp:
|
|
200
200
|
>>> @app.train("custom_action")
|
201
201
|
>>> def custom_action(message: Message, context: Context) -> Message:
|
202
202
|
>>> print("Executing train function for custom action")
|
203
|
-
>>> return message.
|
203
|
+
>>> return Message(message.content, reply_to=message)
|
204
204
|
|
205
205
|
Registering a train function with a function-specific Flower Mod:
|
206
206
|
|
@@ -213,7 +213,7 @@ class ClientApp:
|
|
213
213
|
>>> def train(message: Message, context: Context) -> Message:
|
214
214
|
>>> print("Executing train function with message size mod")
|
215
215
|
>>> # Create and return an echo reply message
|
216
|
-
>>> return message.
|
216
|
+
>>> return Message(message.content, reply_to=message)
|
217
217
|
"""
|
218
218
|
return _get_decorator(self, MessageType.TRAIN, action, mods)
|
219
219
|
|
@@ -244,7 +244,7 @@ class ClientApp:
|
|
244
244
|
>>> def evaluate(message: Message, context: Context) -> Message:
|
245
245
|
>>> print("Executing default evaluate function")
|
246
246
|
>>> # Create and return an echo reply message
|
247
|
-
>>> return message.
|
247
|
+
>>> return Message(message.content, reply_to=message)
|
248
248
|
|
249
249
|
Registering an evaluate function with a custom action name:
|
250
250
|
|
@@ -255,7 +255,7 @@ class ClientApp:
|
|
255
255
|
>>> @app.evaluate("custom_action")
|
256
256
|
>>> def custom_action(message: Message, context: Context) -> Message:
|
257
257
|
>>> print("Executing evaluate function for custom action")
|
258
|
-
>>> return message.
|
258
|
+
>>> return Message(message.content, reply_to=message)
|
259
259
|
|
260
260
|
Registering an evaluate function with a function-specific Flower Mod:
|
261
261
|
|
@@ -268,7 +268,7 @@ class ClientApp:
|
|
268
268
|
>>> def evaluate(message: Message, context: Context) -> Message:
|
269
269
|
>>> print("Executing evaluate function with message size mod")
|
270
270
|
>>> # Create and return an echo reply message
|
271
|
-
>>> return message.
|
271
|
+
>>> return Message(message.content, reply_to=message)
|
272
272
|
"""
|
273
273
|
return _get_decorator(self, MessageType.EVALUATE, action, mods)
|
274
274
|
|
@@ -299,7 +299,7 @@ class ClientApp:
|
|
299
299
|
>>> def query(message: Message, context: Context) -> Message:
|
300
300
|
>>> print("Executing default query function")
|
301
301
|
>>> # Create and return an echo reply message
|
302
|
-
>>> return message.
|
302
|
+
>>> return Message(message.content, reply_to=message)
|
303
303
|
|
304
304
|
Registering a query function with a custom action name:
|
305
305
|
|
@@ -310,7 +310,7 @@ class ClientApp:
|
|
310
310
|
>>> @app.query("custom_action")
|
311
311
|
>>> def custom_action(message: Message, context: Context) -> Message:
|
312
312
|
>>> print("Executing query function for custom action")
|
313
|
-
>>> return message.
|
313
|
+
>>> return Message(message.content, reply_to=message)
|
314
314
|
|
315
315
|
Registering a query function with a function-specific Flower Mod:
|
316
316
|
|
@@ -323,7 +323,7 @@ class ClientApp:
|
|
323
323
|
>>> def query(message: Message, context: Context) -> Message:
|
324
324
|
>>> print("Executing query function with message size mod")
|
325
325
|
>>> # Create and return an echo reply message
|
326
|
-
>>> return message.
|
326
|
+
>>> return Message(message.content, reply_to=message)
|
327
327
|
"""
|
328
328
|
return _get_decorator(self, MessageType.QUERY, action, mods)
|
329
329
|
|
@@ -454,8 +454,6 @@ def _registration_error(fn_name: str) -> ValueError:
|
|
454
454
|
>>> def {fn_name}(message: Message, context: Context) -> Message:
|
455
455
|
>>> print("ClientApp {fn_name} running")
|
456
456
|
>>> # Create and return an echo reply message
|
457
|
-
>>> return message.
|
458
|
-
>>> content=message.content()
|
459
|
-
>>> )
|
457
|
+
>>> return Message(message.content, reply_to=message)
|
460
458
|
""",
|
461
459
|
)
|
@@ -28,7 +28,7 @@ from cryptography.hazmat.primitives.asymmetric import ec
|
|
28
28
|
from flwr.common import (
|
29
29
|
DEFAULT_TTL,
|
30
30
|
GRPC_MAX_MESSAGE_LENGTH,
|
31
|
-
|
31
|
+
ConfigRecord,
|
32
32
|
Message,
|
33
33
|
Metadata,
|
34
34
|
RecordDict,
|
@@ -166,7 +166,7 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
166
166
|
message_type = MessageType.EVALUATE
|
167
167
|
elif field == "reconnect_ins":
|
168
168
|
recorddict = RecordDict()
|
169
|
-
recorddict.
|
169
|
+
recorddict.config_records["config"] = ConfigRecord(
|
170
170
|
{"seconds": proto.reconnect_ins.seconds}
|
171
171
|
)
|
172
172
|
message_type = "reconnect"
|
@@ -216,7 +216,7 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
216
216
|
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
|
217
217
|
elif message_type == "reconnect":
|
218
218
|
reason = cast(
|
219
|
-
Reason.ValueType, recorddict.
|
219
|
+
Reason.ValueType, recorddict.config_records["config"]["reason"]
|
220
220
|
)
|
221
221
|
msg_proto = ClientMessage(
|
222
222
|
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
|
@@ -26,7 +26,7 @@ from flwr.client.client import (
|
|
26
26
|
)
|
27
27
|
from flwr.client.numpy_client import NumPyClient
|
28
28
|
from flwr.client.typing import ClientFnExt
|
29
|
-
from flwr.common import
|
29
|
+
from flwr.common import ConfigRecord, Context, Message, Metadata, RecordDict, log
|
30
30
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
31
31
|
from flwr.common.recorddict_compat import (
|
32
32
|
evaluateres_to_recorddict,
|
@@ -72,7 +72,7 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
|
|
72
72
|
if message.metadata.message_type == "reconnect":
|
73
73
|
# Retrieve ReconnectIns from RecordDict
|
74
74
|
recorddict = message.content
|
75
|
-
seconds = cast(int, recorddict.
|
75
|
+
seconds = cast(int, recorddict.config_records["config"]["seconds"])
|
76
76
|
# Construct ReconnectIns and call _reconnect
|
77
77
|
disconnect_msg, sleep_duration = _reconnect(
|
78
78
|
ServerMessage.ReconnectIns(seconds=seconds)
|
@@ -80,7 +80,7 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
|
|
80
80
|
# Store DisconnectRes in RecordDict
|
81
81
|
reason = cast(int, disconnect_msg.disconnect_res.reason)
|
82
82
|
recorddict = RecordDict()
|
83
|
-
recorddict.
|
83
|
+
recorddict.config_records["config"] = ConfigRecord({"reason": reason})
|
84
84
|
out_message = Message(recorddict, reply_to=message)
|
85
85
|
# Return Message and sleep duration
|
86
86
|
return out_message, sleep_duration
|
flwr/client/mod/__init__.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
19
|
-
from .comms_mods import
|
19
|
+
from .comms_mods import arrays_size_mod, message_size_mod
|
20
20
|
from .localdp_mod import LocalDpMod
|
21
21
|
from .secure_aggregation import secagg_mod, secaggplus_mod
|
22
22
|
from .utils import make_ffn
|
@@ -24,10 +24,10 @@ from .utils import make_ffn
|
|
24
24
|
__all__ = [
|
25
25
|
"LocalDpMod",
|
26
26
|
"adaptiveclipping_mod",
|
27
|
+
"arrays_size_mod",
|
27
28
|
"fixedclipping_mod",
|
28
29
|
"make_ffn",
|
29
30
|
"message_size_mod",
|
30
|
-
"parameters_size_mod",
|
31
31
|
"secagg_mod",
|
32
32
|
"secaggplus_mod",
|
33
33
|
]
|
flwr/client/mod/comms_mods.py
CHANGED
@@ -34,47 +34,41 @@ def message_size_mod(
|
|
34
34
|
"""
|
35
35
|
message_size_in_bytes = 0
|
36
36
|
|
37
|
-
for
|
38
|
-
message_size_in_bytes +=
|
39
|
-
|
40
|
-
for c_record in msg.content.configs_records.values():
|
41
|
-
message_size_in_bytes += c_record.count_bytes()
|
42
|
-
|
43
|
-
for m_record in msg.content.metrics_records.values():
|
44
|
-
message_size_in_bytes += m_record.count_bytes()
|
37
|
+
for record in msg.content.values():
|
38
|
+
message_size_in_bytes += record.count_bytes()
|
45
39
|
|
46
40
|
log(INFO, "Message size: %i bytes", message_size_in_bytes)
|
47
41
|
|
48
42
|
return call_next(msg, ctxt)
|
49
43
|
|
50
44
|
|
51
|
-
def
|
45
|
+
def arrays_size_mod(
|
52
46
|
msg: Message, ctxt: Context, call_next: ClientAppCallable
|
53
47
|
) -> Message:
|
54
|
-
"""
|
48
|
+
"""Arrays size mod.
|
55
49
|
|
56
|
-
This mod logs the number of
|
57
|
-
|
50
|
+
This mod logs the number of array elements transmitted in ``ArrayRecord``s of
|
51
|
+
the message as well as their sizes in bytes.
|
58
52
|
"""
|
59
53
|
model_size_stats = {}
|
60
|
-
|
61
|
-
for record_name,
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
for array in
|
66
|
-
|
54
|
+
arrays_size_in_bytes = 0
|
55
|
+
for record_name, arr_record in msg.content.array_records.items():
|
56
|
+
arr_record_bytes = arr_record.count_bytes()
|
57
|
+
arrays_size_in_bytes += arr_record_bytes
|
58
|
+
element_count = 0
|
59
|
+
for array in arr_record.values():
|
60
|
+
element_count += (
|
67
61
|
int(np.prod(array.shape)) if array.shape else array.numpy().size
|
68
62
|
)
|
69
63
|
|
70
64
|
model_size_stats[f"{record_name}"] = {
|
71
|
-
"
|
72
|
-
"bytes":
|
65
|
+
"elements": element_count,
|
66
|
+
"bytes": arr_record_bytes,
|
73
67
|
}
|
74
68
|
|
75
69
|
if model_size_stats:
|
76
70
|
log(INFO, model_size_stats)
|
77
71
|
|
78
|
-
log(INFO, "Total
|
72
|
+
log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
|
79
73
|
|
80
74
|
return call_next(msg, ctxt)
|
@@ -22,7 +22,7 @@ from typing import Any, cast
|
|
22
22
|
|
23
23
|
from flwr.client.typing import ClientAppCallable
|
24
24
|
from flwr.common import (
|
25
|
-
|
25
|
+
ConfigRecord,
|
26
26
|
Context,
|
27
27
|
Message,
|
28
28
|
Parameters,
|
@@ -63,7 +63,7 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
|
|
63
63
|
share_keys_plaintext_concat,
|
64
64
|
share_keys_plaintext_separate,
|
65
65
|
)
|
66
|
-
from flwr.common.typing import
|
66
|
+
from flwr.common.typing import ConfigRecordValues
|
67
67
|
|
68
68
|
|
69
69
|
@dataclass
|
@@ -97,7 +97,7 @@ class SecAggPlusState:
|
|
97
97
|
ss2_dict: dict[int, bytes] = field(default_factory=dict)
|
98
98
|
public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict)
|
99
99
|
|
100
|
-
def __init__(self, **kwargs:
|
100
|
+
def __init__(self, **kwargs: ConfigRecordValues) -> None:
|
101
101
|
for k, v in kwargs.items():
|
102
102
|
if k.endswith(":V"):
|
103
103
|
continue
|
@@ -115,7 +115,7 @@ class SecAggPlusState:
|
|
115
115
|
new_v = dict(zip(keys, values))
|
116
116
|
self.__setattr__(k, new_v)
|
117
117
|
|
118
|
-
def to_dict(self) -> dict[str,
|
118
|
+
def to_dict(self) -> dict[str, ConfigRecordValues]:
|
119
119
|
"""Convert the state to a dictionary."""
|
120
120
|
ret = vars(self)
|
121
121
|
for k in list(ret.keys()):
|
@@ -144,13 +144,13 @@ def secaggplus_mod(
|
|
144
144
|
return call_next(msg, ctxt)
|
145
145
|
|
146
146
|
# Retrieve local state
|
147
|
-
if RECORD_KEY_STATE not in ctxt.state.
|
148
|
-
ctxt.state.
|
149
|
-
state_dict = ctxt.state.
|
147
|
+
if RECORD_KEY_STATE not in ctxt.state.config_records:
|
148
|
+
ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord({})
|
149
|
+
state_dict = ctxt.state.config_records[RECORD_KEY_STATE]
|
150
150
|
state = SecAggPlusState(**state_dict)
|
151
151
|
|
152
152
|
# Retrieve incoming configs
|
153
|
-
configs = msg.content.
|
153
|
+
configs = msg.content.config_records[RECORD_KEY_CONFIGS]
|
154
154
|
|
155
155
|
# Check the validity of the next stage
|
156
156
|
check_stage(state.current_stage, configs)
|
@@ -175,27 +175,27 @@ def secaggplus_mod(
|
|
175
175
|
res = _collect_masked_vectors(
|
176
176
|
state, configs, fitres.num_examples, fitres.parameters
|
177
177
|
)
|
178
|
-
for
|
179
|
-
|
178
|
+
for arr_record in out_content.array_records.values():
|
179
|
+
arr_record.clear()
|
180
180
|
elif state.current_stage == Stage.UNMASK:
|
181
181
|
res = _unmask(state, configs)
|
182
182
|
else:
|
183
183
|
raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")
|
184
184
|
|
185
185
|
# Save state
|
186
|
-
ctxt.state.
|
186
|
+
ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord(state.to_dict())
|
187
187
|
|
188
188
|
# Return message
|
189
|
-
out_content.
|
189
|
+
out_content.config_records[RECORD_KEY_CONFIGS] = ConfigRecord(res, False)
|
190
190
|
return Message(out_content, reply_to=msg)
|
191
191
|
|
192
192
|
|
193
|
-
def check_stage(current_stage: str, configs:
|
193
|
+
def check_stage(current_stage: str, configs: ConfigRecord) -> None:
|
194
194
|
"""Check the validity of the next stage."""
|
195
195
|
# Check the existence of Config.STAGE
|
196
196
|
if Key.STAGE not in configs:
|
197
197
|
raise KeyError(
|
198
|
-
f"The required key '{Key.STAGE}' is missing from the
|
198
|
+
f"The required key '{Key.STAGE}' is missing from the ConfigRecord."
|
199
199
|
)
|
200
200
|
|
201
201
|
# Check the value type of the Config.STAGE
|
@@ -223,7 +223,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
223
223
|
|
224
224
|
|
225
225
|
# pylint: disable-next=too-many-branches
|
226
|
-
def check_configs(stage: str, configs:
|
226
|
+
def check_configs(stage: str, configs: ConfigRecord) -> None:
|
227
227
|
"""Check the validity of the configs."""
|
228
228
|
# Check configs for the setup stage
|
229
229
|
if stage == Stage.SETUP:
|
@@ -239,7 +239,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
239
239
|
if key not in configs:
|
240
240
|
raise KeyError(
|
241
241
|
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
242
|
-
"missing from the
|
242
|
+
"missing from the ConfigRecord."
|
243
243
|
)
|
244
244
|
# Bool is a subclass of int in Python,
|
245
245
|
# so `isinstance(v, int)` will return True even if v is a boolean.
|
@@ -272,7 +272,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
272
272
|
raise KeyError(
|
273
273
|
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
274
274
|
f"the required key '{key}' is "
|
275
|
-
"missing from the
|
275
|
+
"missing from the ConfigRecord."
|
276
276
|
)
|
277
277
|
if not isinstance(configs[key], list) or any(
|
278
278
|
elm
|
@@ -295,7 +295,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
295
295
|
raise KeyError(
|
296
296
|
f"Stage {Stage.UNMASK}: "
|
297
297
|
f"the required key '{key}' is "
|
298
|
-
"missing from the
|
298
|
+
"missing from the ConfigRecord."
|
299
299
|
)
|
300
300
|
if not isinstance(configs[key], list) or any(
|
301
301
|
elm
|
@@ -313,8 +313,8 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
313
313
|
|
314
314
|
|
315
315
|
def _setup(
|
316
|
-
state: SecAggPlusState, configs:
|
317
|
-
) -> dict[str,
|
316
|
+
state: SecAggPlusState, configs: ConfigRecord
|
317
|
+
) -> dict[str, ConfigRecordValues]:
|
318
318
|
# Assigning parameter values to object fields
|
319
319
|
sec_agg_param_dict = configs
|
320
320
|
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
@@ -349,8 +349,8 @@ def _setup(
|
|
349
349
|
|
350
350
|
# pylint: disable-next=too-many-locals
|
351
351
|
def _share_keys(
|
352
|
-
state: SecAggPlusState, configs:
|
353
|
-
) -> dict[str,
|
352
|
+
state: SecAggPlusState, configs: ConfigRecord
|
353
|
+
) -> dict[str, ConfigRecordValues]:
|
354
354
|
named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs)
|
355
355
|
key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
|
356
356
|
log(DEBUG, "Node %d: starting stage 1...", state.nid)
|
@@ -412,10 +412,10 @@ def _share_keys(
|
|
412
412
|
# pylint: disable-next=too-many-locals
|
413
413
|
def _collect_masked_vectors(
|
414
414
|
state: SecAggPlusState,
|
415
|
-
configs:
|
415
|
+
configs: ConfigRecord,
|
416
416
|
num_examples: int,
|
417
417
|
updated_parameters: Parameters,
|
418
|
-
) -> dict[str,
|
418
|
+
) -> dict[str, ConfigRecordValues]:
|
419
419
|
log(DEBUG, "Node %d: starting stage 2...", state.nid)
|
420
420
|
available_clients: list[int] = []
|
421
421
|
ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST])
|
@@ -498,8 +498,8 @@ def _collect_masked_vectors(
|
|
498
498
|
|
499
499
|
|
500
500
|
def _unmask(
|
501
|
-
state: SecAggPlusState, configs:
|
502
|
-
) -> dict[str,
|
501
|
+
state: SecAggPlusState, configs: ConfigRecord
|
502
|
+
) -> dict[str, ConfigRecordValues]:
|
503
503
|
log(DEBUG, "Node %d: starting stage 3...", state.nid)
|
504
504
|
|
505
505
|
active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST])
|
flwr/common/__init__.py
CHANGED
@@ -31,7 +31,10 @@ from .parameter import ndarray_to_bytes as ndarray_to_bytes
|
|
31
31
|
from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
|
32
32
|
from .parameter import parameters_to_ndarrays as parameters_to_ndarrays
|
33
33
|
from .record import Array as Array
|
34
|
+
from .record import ArrayRecord as ArrayRecord
|
35
|
+
from .record import ConfigRecord as ConfigRecord
|
34
36
|
from .record import ConfigsRecord as ConfigsRecord
|
37
|
+
from .record import MetricRecord as MetricRecord
|
35
38
|
from .record import MetricsRecord as MetricsRecord
|
36
39
|
from .record import ParametersRecord as ParametersRecord
|
37
40
|
from .record import RecordDict as RecordDict
|
@@ -42,7 +45,7 @@ from .telemetry import event as event
|
|
42
45
|
from .typing import ClientMessage as ClientMessage
|
43
46
|
from .typing import Code as Code
|
44
47
|
from .typing import Config as Config
|
45
|
-
from .typing import
|
48
|
+
from .typing import ConfigRecordValues as ConfigRecordValues
|
46
49
|
from .typing import DisconnectRes as DisconnectRes
|
47
50
|
from .typing import EvaluateIns as EvaluateIns
|
48
51
|
from .typing import EvaluateRes as EvaluateRes
|
@@ -52,9 +55,9 @@ from .typing import GetParametersIns as GetParametersIns
|
|
52
55
|
from .typing import GetParametersRes as GetParametersRes
|
53
56
|
from .typing import GetPropertiesIns as GetPropertiesIns
|
54
57
|
from .typing import GetPropertiesRes as GetPropertiesRes
|
58
|
+
from .typing import MetricRecordValues as MetricRecordValues
|
55
59
|
from .typing import Metrics as Metrics
|
56
60
|
from .typing import MetricsAggregationFn as MetricsAggregationFn
|
57
|
-
from .typing import MetricsRecordValues as MetricsRecordValues
|
58
61
|
from .typing import NDArray as NDArray
|
59
62
|
from .typing import NDArrays as NDArrays
|
60
63
|
from .typing import Parameters as Parameters
|
@@ -66,11 +69,13 @@ from .typing import Status as Status
|
|
66
69
|
|
67
70
|
__all__ = [
|
68
71
|
"Array",
|
72
|
+
"ArrayRecord",
|
69
73
|
"ClientMessage",
|
70
74
|
"Code",
|
71
75
|
"Config",
|
76
|
+
"ConfigRecord",
|
77
|
+
"ConfigRecordValues",
|
72
78
|
"ConfigsRecord",
|
73
|
-
"ConfigsRecordValues",
|
74
79
|
"Context",
|
75
80
|
"DEFAULT_TTL",
|
76
81
|
"DisconnectRes",
|
@@ -89,10 +94,11 @@ __all__ = [
|
|
89
94
|
"MessageType",
|
90
95
|
"MessageTypeLegacy",
|
91
96
|
"Metadata",
|
97
|
+
"MetricRecord",
|
98
|
+
"MetricRecordValues",
|
92
99
|
"Metrics",
|
93
100
|
"MetricsAggregationFn",
|
94
101
|
"MetricsRecord",
|
95
|
-
"MetricsRecordValues",
|
96
102
|
"NDArray",
|
97
103
|
"NDArrays",
|
98
104
|
"Parameters",
|
flwr/common/config.py
CHANGED
@@ -34,7 +34,7 @@ from flwr.common.constant import (
|
|
34
34
|
)
|
35
35
|
from flwr.common.typing import Run, UserConfig, UserConfigValue
|
36
36
|
|
37
|
-
from . import
|
37
|
+
from . import ConfigRecord, object_ref
|
38
38
|
|
39
39
|
T_dict = TypeVar("T_dict", bound=dict[str, Any]) # pylint: disable=invalid-name
|
40
40
|
|
@@ -260,9 +260,9 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
|
|
260
260
|
)
|
261
261
|
|
262
262
|
|
263
|
-
def
|
264
|
-
"""Construct a `
|
265
|
-
c_record =
|
263
|
+
def user_config_to_configrecord(config: UserConfig) -> ConfigRecord:
|
264
|
+
"""Construct a `ConfigRecord` out of a `UserConfig`."""
|
265
|
+
c_record = ConfigRecord()
|
266
266
|
for k, v in config.items():
|
267
267
|
c_record[k] = v
|
268
268
|
|
flwr/common/constant.py
CHANGED
@@ -121,7 +121,7 @@ TIMESTAMP_HEADER = "flwr-timestamp"
|
|
121
121
|
TIMESTAMP_TOLERANCE = 10 # General tolerance for timestamp verification
|
122
122
|
SYSTEM_TIME_TOLERANCE = 5 # Allowance for system time drift
|
123
123
|
|
124
|
-
# Constants for
|
124
|
+
# Constants for ArrayRecord
|
125
125
|
GC_THRESHOLD = 200_000_000 # 200 MB
|
126
126
|
|
127
127
|
|
flwr/common/record/__init__.py
CHANGED
@@ -15,15 +15,18 @@
|
|
15
15
|
"""Record APIs."""
|
16
16
|
|
17
17
|
|
18
|
-
from .
|
18
|
+
from .arrayrecord import Array, ArrayRecord, ParametersRecord
|
19
|
+
from .configrecord import ConfigRecord, ConfigsRecord
|
19
20
|
from .conversion_utils import array_from_numpy
|
20
|
-
from .
|
21
|
-
from .parametersrecord import Array, ParametersRecord
|
21
|
+
from .metricrecord import MetricRecord, MetricsRecord
|
22
22
|
from .recorddict import RecordDict, RecordSet
|
23
23
|
|
24
24
|
__all__ = [
|
25
25
|
"Array",
|
26
|
+
"ArrayRecord",
|
27
|
+
"ConfigRecord",
|
26
28
|
"ConfigsRecord",
|
29
|
+
"MetricRecord",
|
27
30
|
"MetricsRecord",
|
28
31
|
"ParametersRecord",
|
29
32
|
"RecordDict",
|