flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
|
@@ -26,17 +26,17 @@ from flwr.client.client import (
|
|
|
26
26
|
)
|
|
27
27
|
from flwr.client.numpy_client import NumPyClient
|
|
28
28
|
from flwr.client.typing import ClientFnExt
|
|
29
|
-
from flwr.common import
|
|
29
|
+
from flwr.common import ConfigRecord, Context, Message, Metadata, RecordDict, log
|
|
30
30
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
31
|
-
from flwr.common.
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
31
|
+
from flwr.common.recorddict_compat import (
|
|
32
|
+
evaluateres_to_recorddict,
|
|
33
|
+
fitres_to_recorddict,
|
|
34
|
+
getparametersres_to_recorddict,
|
|
35
|
+
getpropertiesres_to_recorddict,
|
|
36
|
+
recorddict_to_evaluateins,
|
|
37
|
+
recorddict_to_fitins,
|
|
38
|
+
recorddict_to_getparametersins,
|
|
39
|
+
recorddict_to_getpropertiesins,
|
|
40
40
|
)
|
|
41
41
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
42
42
|
ClientMessage,
|
|
@@ -70,19 +70,19 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
|
|
|
70
70
|
Number of seconds that the client should disconnect from the server.
|
|
71
71
|
"""
|
|
72
72
|
if message.metadata.message_type == "reconnect":
|
|
73
|
-
# Retrieve ReconnectIns from
|
|
74
|
-
|
|
75
|
-
seconds = cast(int,
|
|
73
|
+
# Retrieve ReconnectIns from RecordDict
|
|
74
|
+
recorddict = message.content
|
|
75
|
+
seconds = cast(int, recorddict.config_records["config"]["seconds"])
|
|
76
76
|
# Construct ReconnectIns and call _reconnect
|
|
77
77
|
disconnect_msg, sleep_duration = _reconnect(
|
|
78
78
|
ServerMessage.ReconnectIns(seconds=seconds)
|
|
79
79
|
)
|
|
80
|
-
# Store DisconnectRes in
|
|
80
|
+
# Store DisconnectRes in RecordDict
|
|
81
81
|
reason = cast(int, disconnect_msg.disconnect_res.reason)
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
out_message = message
|
|
85
|
-
# Return
|
|
82
|
+
recorddict = RecordDict()
|
|
83
|
+
recorddict.config_records["config"] = ConfigRecord({"reason": reason})
|
|
84
|
+
out_message = Message(recorddict, reply_to=message)
|
|
85
|
+
# Return Message and sleep duration
|
|
86
86
|
return out_message, sleep_duration
|
|
87
87
|
|
|
88
88
|
# Any other message
|
|
@@ -111,37 +111,37 @@ def handle_legacy_message_from_msgtype(
|
|
|
111
111
|
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
|
112
112
|
get_properties_res = maybe_call_get_properties(
|
|
113
113
|
client=client,
|
|
114
|
-
get_properties_ins=
|
|
114
|
+
get_properties_ins=recorddict_to_getpropertiesins(message.content),
|
|
115
115
|
)
|
|
116
|
-
|
|
116
|
+
out_recorddict = getpropertiesres_to_recorddict(get_properties_res)
|
|
117
117
|
# Handle GetParametersIns
|
|
118
118
|
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
|
119
119
|
get_parameters_res = maybe_call_get_parameters(
|
|
120
120
|
client=client,
|
|
121
|
-
get_parameters_ins=
|
|
121
|
+
get_parameters_ins=recorddict_to_getparametersins(message.content),
|
|
122
122
|
)
|
|
123
|
-
|
|
123
|
+
out_recorddict = getparametersres_to_recorddict(
|
|
124
124
|
get_parameters_res, keep_input=False
|
|
125
125
|
)
|
|
126
126
|
# Handle FitIns
|
|
127
127
|
elif message_type == MessageType.TRAIN:
|
|
128
128
|
fit_res = maybe_call_fit(
|
|
129
129
|
client=client,
|
|
130
|
-
fit_ins=
|
|
130
|
+
fit_ins=recorddict_to_fitins(message.content, keep_input=True),
|
|
131
131
|
)
|
|
132
|
-
|
|
132
|
+
out_recorddict = fitres_to_recorddict(fit_res, keep_input=False)
|
|
133
133
|
# Handle EvaluateIns
|
|
134
134
|
elif message_type == MessageType.EVALUATE:
|
|
135
135
|
evaluate_res = maybe_call_evaluate(
|
|
136
136
|
client=client,
|
|
137
|
-
evaluate_ins=
|
|
137
|
+
evaluate_ins=recorddict_to_evaluateins(message.content, keep_input=True),
|
|
138
138
|
)
|
|
139
|
-
|
|
139
|
+
out_recorddict = evaluateres_to_recorddict(evaluate_res)
|
|
140
140
|
else:
|
|
141
141
|
raise ValueError(f"Invalid message type: {message_type}")
|
|
142
142
|
|
|
143
143
|
# Return Message
|
|
144
|
-
return message
|
|
144
|
+
return Message(out_recorddict, reply_to=message)
|
|
145
145
|
|
|
146
146
|
|
|
147
147
|
def _reconnect(
|
|
@@ -167,7 +167,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
|
|
|
167
167
|
and out_meta.message_id == "" # This will be generated by the server
|
|
168
168
|
and out_meta.src_node_id == in_meta.dst_node_id
|
|
169
169
|
and out_meta.dst_node_id == in_meta.src_node_id
|
|
170
|
-
and out_meta.
|
|
170
|
+
and out_meta.reply_to_message_id == in_meta.message_id
|
|
171
171
|
and out_meta.group_id == in_meta.group_id
|
|
172
172
|
and out_meta.message_type == in_meta.message_type
|
|
173
173
|
and out_meta.created_at > in_meta.created_at
|
flwr/client/mod/__init__.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
|
19
|
-
from .comms_mods import
|
|
19
|
+
from .comms_mods import arrays_size_mod, message_size_mod
|
|
20
20
|
from .localdp_mod import LocalDpMod
|
|
21
21
|
from .secure_aggregation import secagg_mod, secaggplus_mod
|
|
22
22
|
from .utils import make_ffn
|
|
@@ -24,10 +24,10 @@ from .utils import make_ffn
|
|
|
24
24
|
__all__ = [
|
|
25
25
|
"LocalDpMod",
|
|
26
26
|
"adaptiveclipping_mod",
|
|
27
|
+
"arrays_size_mod",
|
|
27
28
|
"fixedclipping_mod",
|
|
28
29
|
"make_ffn",
|
|
29
30
|
"message_size_mod",
|
|
30
|
-
"parameters_size_mod",
|
|
31
31
|
"secagg_mod",
|
|
32
32
|
"secaggplus_mod",
|
|
33
33
|
]
|
|
@@ -19,7 +19,7 @@ from logging import INFO
|
|
|
19
19
|
|
|
20
20
|
from flwr.client.typing import ClientAppCallable
|
|
21
21
|
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
|
22
|
-
from flwr.common import
|
|
22
|
+
from flwr.common import recorddict_compat as compat
|
|
23
23
|
from flwr.common.constant import MessageType
|
|
24
24
|
from flwr.common.context import Context
|
|
25
25
|
from flwr.common.differential_privacy import (
|
|
@@ -53,7 +53,7 @@ def fixedclipping_mod(
|
|
|
53
53
|
"""
|
|
54
54
|
if msg.metadata.message_type != MessageType.TRAIN:
|
|
55
55
|
return call_next(msg, ctxt)
|
|
56
|
-
fit_ins = compat.
|
|
56
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
|
57
57
|
if KEY_CLIPPING_NORM not in fit_ins.config:
|
|
58
58
|
raise KeyError(
|
|
59
59
|
f"The {KEY_CLIPPING_NORM} value is not supplied by the "
|
|
@@ -71,7 +71,7 @@ def fixedclipping_mod(
|
|
|
71
71
|
if out_msg.has_error():
|
|
72
72
|
return out_msg
|
|
73
73
|
|
|
74
|
-
fit_res = compat.
|
|
74
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
|
75
75
|
|
|
76
76
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
|
77
77
|
|
|
@@ -87,7 +87,7 @@ def fixedclipping_mod(
|
|
|
87
87
|
)
|
|
88
88
|
|
|
89
89
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
90
|
-
out_msg.content = compat.
|
|
90
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
|
91
91
|
return out_msg
|
|
92
92
|
|
|
93
93
|
|
|
@@ -116,7 +116,7 @@ def adaptiveclipping_mod(
|
|
|
116
116
|
if msg.metadata.message_type != MessageType.TRAIN:
|
|
117
117
|
return call_next(msg, ctxt)
|
|
118
118
|
|
|
119
|
-
fit_ins = compat.
|
|
119
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
|
120
120
|
|
|
121
121
|
if KEY_CLIPPING_NORM not in fit_ins.config:
|
|
122
122
|
raise KeyError(
|
|
@@ -136,7 +136,7 @@ def adaptiveclipping_mod(
|
|
|
136
136
|
if out_msg.has_error():
|
|
137
137
|
return out_msg
|
|
138
138
|
|
|
139
|
-
fit_res = compat.
|
|
139
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
|
140
140
|
|
|
141
141
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
|
142
142
|
|
|
@@ -155,5 +155,5 @@ def adaptiveclipping_mod(
|
|
|
155
155
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
156
156
|
|
|
157
157
|
fit_res.metrics[KEY_NORM_BIT] = norm_bit
|
|
158
|
-
out_msg.content = compat.
|
|
158
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
|
159
159
|
return out_msg
|
flwr/client/mod/comms_mods.py
CHANGED
|
@@ -34,47 +34,41 @@ def message_size_mod(
|
|
|
34
34
|
"""
|
|
35
35
|
message_size_in_bytes = 0
|
|
36
36
|
|
|
37
|
-
for
|
|
38
|
-
message_size_in_bytes +=
|
|
39
|
-
|
|
40
|
-
for c_record in msg.content.configs_records.values():
|
|
41
|
-
message_size_in_bytes += c_record.count_bytes()
|
|
42
|
-
|
|
43
|
-
for m_record in msg.content.metrics_records.values():
|
|
44
|
-
message_size_in_bytes += m_record.count_bytes()
|
|
37
|
+
for record in msg.content.values():
|
|
38
|
+
message_size_in_bytes += record.count_bytes()
|
|
45
39
|
|
|
46
40
|
log(INFO, "Message size: %i bytes", message_size_in_bytes)
|
|
47
41
|
|
|
48
42
|
return call_next(msg, ctxt)
|
|
49
43
|
|
|
50
44
|
|
|
51
|
-
def
|
|
45
|
+
def arrays_size_mod(
|
|
52
46
|
msg: Message, ctxt: Context, call_next: ClientAppCallable
|
|
53
47
|
) -> Message:
|
|
54
|
-
"""
|
|
48
|
+
"""Arrays size mod.
|
|
55
49
|
|
|
56
|
-
This mod logs the number of
|
|
57
|
-
|
|
50
|
+
This mod logs the number of array elements transmitted in ``ArrayRecord``s of
|
|
51
|
+
the message as well as their sizes in bytes.
|
|
58
52
|
"""
|
|
59
53
|
model_size_stats = {}
|
|
60
|
-
|
|
61
|
-
for record_name,
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
for array in
|
|
66
|
-
|
|
54
|
+
arrays_size_in_bytes = 0
|
|
55
|
+
for record_name, arr_record in msg.content.array_records.items():
|
|
56
|
+
arr_record_bytes = arr_record.count_bytes()
|
|
57
|
+
arrays_size_in_bytes += arr_record_bytes
|
|
58
|
+
element_count = 0
|
|
59
|
+
for array in arr_record.values():
|
|
60
|
+
element_count += (
|
|
67
61
|
int(np.prod(array.shape)) if array.shape else array.numpy().size
|
|
68
62
|
)
|
|
69
63
|
|
|
70
64
|
model_size_stats[f"{record_name}"] = {
|
|
71
|
-
"
|
|
72
|
-
"bytes":
|
|
65
|
+
"elements": element_count,
|
|
66
|
+
"bytes": arr_record_bytes,
|
|
73
67
|
}
|
|
74
68
|
|
|
75
69
|
if model_size_stats:
|
|
76
70
|
log(INFO, model_size_stats)
|
|
77
71
|
|
|
78
|
-
log(INFO, "Total
|
|
72
|
+
log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
|
|
79
73
|
|
|
80
74
|
return call_next(msg, ctxt)
|
flwr/client/mod/localdp_mod.py
CHANGED
|
@@ -21,7 +21,7 @@ import numpy as np
|
|
|
21
21
|
|
|
22
22
|
from flwr.client.typing import ClientAppCallable
|
|
23
23
|
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
|
24
|
-
from flwr.common import
|
|
24
|
+
from flwr.common import recorddict_compat as compat
|
|
25
25
|
from flwr.common.constant import MessageType
|
|
26
26
|
from flwr.common.context import Context
|
|
27
27
|
from flwr.common.differential_privacy import (
|
|
@@ -107,7 +107,7 @@ class LocalDpMod:
|
|
|
107
107
|
if msg.metadata.message_type != MessageType.TRAIN:
|
|
108
108
|
return call_next(msg, ctxt)
|
|
109
109
|
|
|
110
|
-
fit_ins = compat.
|
|
110
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
|
111
111
|
server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)
|
|
112
112
|
|
|
113
113
|
# Call inner app
|
|
@@ -117,7 +117,7 @@ class LocalDpMod:
|
|
|
117
117
|
if out_msg.has_error():
|
|
118
118
|
return out_msg
|
|
119
119
|
|
|
120
|
-
fit_res = compat.
|
|
120
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
|
121
121
|
|
|
122
122
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
|
123
123
|
|
|
@@ -149,5 +149,5 @@ class LocalDpMod:
|
|
|
149
149
|
noise_value_sd,
|
|
150
150
|
)
|
|
151
151
|
|
|
152
|
-
out_msg.content = compat.
|
|
152
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
|
153
153
|
return out_msg
|
|
@@ -22,15 +22,15 @@ from typing import Any, cast
|
|
|
22
22
|
|
|
23
23
|
from flwr.client.typing import ClientAppCallable
|
|
24
24
|
from flwr.common import (
|
|
25
|
-
|
|
25
|
+
ConfigRecord,
|
|
26
26
|
Context,
|
|
27
27
|
Message,
|
|
28
28
|
Parameters,
|
|
29
|
-
|
|
29
|
+
RecordDict,
|
|
30
30
|
ndarray_to_bytes,
|
|
31
31
|
parameters_to_ndarrays,
|
|
32
32
|
)
|
|
33
|
-
from flwr.common import
|
|
33
|
+
from flwr.common import recorddict_compat as compat
|
|
34
34
|
from flwr.common.constant import MessageType
|
|
35
35
|
from flwr.common.logger import log
|
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
|
@@ -63,7 +63,7 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
|
|
|
63
63
|
share_keys_plaintext_concat,
|
|
64
64
|
share_keys_plaintext_separate,
|
|
65
65
|
)
|
|
66
|
-
from flwr.common.typing import
|
|
66
|
+
from flwr.common.typing import ConfigRecordValues
|
|
67
67
|
|
|
68
68
|
|
|
69
69
|
@dataclass
|
|
@@ -97,7 +97,7 @@ class SecAggPlusState:
|
|
|
97
97
|
ss2_dict: dict[int, bytes] = field(default_factory=dict)
|
|
98
98
|
public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict)
|
|
99
99
|
|
|
100
|
-
def __init__(self, **kwargs:
|
|
100
|
+
def __init__(self, **kwargs: ConfigRecordValues) -> None:
|
|
101
101
|
for k, v in kwargs.items():
|
|
102
102
|
if k.endswith(":V"):
|
|
103
103
|
continue
|
|
@@ -115,7 +115,7 @@ class SecAggPlusState:
|
|
|
115
115
|
new_v = dict(zip(keys, values))
|
|
116
116
|
self.__setattr__(k, new_v)
|
|
117
117
|
|
|
118
|
-
def to_dict(self) -> dict[str,
|
|
118
|
+
def to_dict(self) -> dict[str, ConfigRecordValues]:
|
|
119
119
|
"""Convert the state to a dictionary."""
|
|
120
120
|
ret = vars(self)
|
|
121
121
|
for k in list(ret.keys()):
|
|
@@ -144,13 +144,13 @@ def secaggplus_mod(
|
|
|
144
144
|
return call_next(msg, ctxt)
|
|
145
145
|
|
|
146
146
|
# Retrieve local state
|
|
147
|
-
if RECORD_KEY_STATE not in ctxt.state.
|
|
148
|
-
ctxt.state.
|
|
149
|
-
state_dict = ctxt.state.
|
|
147
|
+
if RECORD_KEY_STATE not in ctxt.state.config_records:
|
|
148
|
+
ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord({})
|
|
149
|
+
state_dict = ctxt.state.config_records[RECORD_KEY_STATE]
|
|
150
150
|
state = SecAggPlusState(**state_dict)
|
|
151
151
|
|
|
152
152
|
# Retrieve incoming configs
|
|
153
|
-
configs = msg.content.
|
|
153
|
+
configs = msg.content.config_records[RECORD_KEY_CONFIGS]
|
|
154
154
|
|
|
155
155
|
# Check the validity of the next stage
|
|
156
156
|
check_stage(state.current_stage, configs)
|
|
@@ -162,7 +162,7 @@ def secaggplus_mod(
|
|
|
162
162
|
check_configs(state.current_stage, configs)
|
|
163
163
|
|
|
164
164
|
# Execute
|
|
165
|
-
out_content =
|
|
165
|
+
out_content = RecordDict()
|
|
166
166
|
if state.current_stage == Stage.SETUP:
|
|
167
167
|
state.nid = msg.metadata.dst_node_id
|
|
168
168
|
res = _setup(state, configs)
|
|
@@ -171,31 +171,31 @@ def secaggplus_mod(
|
|
|
171
171
|
elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
|
|
172
172
|
out_msg = call_next(msg, ctxt)
|
|
173
173
|
out_content = out_msg.content
|
|
174
|
-
fitres = compat.
|
|
174
|
+
fitres = compat.recorddict_to_fitres(out_content, keep_input=True)
|
|
175
175
|
res = _collect_masked_vectors(
|
|
176
176
|
state, configs, fitres.num_examples, fitres.parameters
|
|
177
177
|
)
|
|
178
|
-
for
|
|
179
|
-
|
|
178
|
+
for arr_record in out_content.array_records.values():
|
|
179
|
+
arr_record.clear()
|
|
180
180
|
elif state.current_stage == Stage.UNMASK:
|
|
181
181
|
res = _unmask(state, configs)
|
|
182
182
|
else:
|
|
183
183
|
raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")
|
|
184
184
|
|
|
185
185
|
# Save state
|
|
186
|
-
ctxt.state.
|
|
186
|
+
ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord(state.to_dict())
|
|
187
187
|
|
|
188
188
|
# Return message
|
|
189
|
-
out_content.
|
|
190
|
-
return
|
|
189
|
+
out_content.config_records[RECORD_KEY_CONFIGS] = ConfigRecord(res, False)
|
|
190
|
+
return Message(out_content, reply_to=msg)
|
|
191
191
|
|
|
192
192
|
|
|
193
|
-
def check_stage(current_stage: str, configs:
|
|
193
|
+
def check_stage(current_stage: str, configs: ConfigRecord) -> None:
|
|
194
194
|
"""Check the validity of the next stage."""
|
|
195
195
|
# Check the existence of Config.STAGE
|
|
196
196
|
if Key.STAGE not in configs:
|
|
197
197
|
raise KeyError(
|
|
198
|
-
f"The required key '{Key.STAGE}' is missing from the
|
|
198
|
+
f"The required key '{Key.STAGE}' is missing from the ConfigRecord."
|
|
199
199
|
)
|
|
200
200
|
|
|
201
201
|
# Check the value type of the Config.STAGE
|
|
@@ -223,7 +223,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
|
223
223
|
|
|
224
224
|
|
|
225
225
|
# pylint: disable-next=too-many-branches
|
|
226
|
-
def check_configs(stage: str, configs:
|
|
226
|
+
def check_configs(stage: str, configs: ConfigRecord) -> None:
|
|
227
227
|
"""Check the validity of the configs."""
|
|
228
228
|
# Check configs for the setup stage
|
|
229
229
|
if stage == Stage.SETUP:
|
|
@@ -239,7 +239,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
|
239
239
|
if key not in configs:
|
|
240
240
|
raise KeyError(
|
|
241
241
|
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
|
242
|
-
"missing from the
|
|
242
|
+
"missing from the ConfigRecord."
|
|
243
243
|
)
|
|
244
244
|
# Bool is a subclass of int in Python,
|
|
245
245
|
# so `isinstance(v, int)` will return True even if v is a boolean.
|
|
@@ -272,7 +272,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
|
272
272
|
raise KeyError(
|
|
273
273
|
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
|
274
274
|
f"the required key '{key}' is "
|
|
275
|
-
"missing from the
|
|
275
|
+
"missing from the ConfigRecord."
|
|
276
276
|
)
|
|
277
277
|
if not isinstance(configs[key], list) or any(
|
|
278
278
|
elm
|
|
@@ -295,7 +295,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
|
295
295
|
raise KeyError(
|
|
296
296
|
f"Stage {Stage.UNMASK}: "
|
|
297
297
|
f"the required key '{key}' is "
|
|
298
|
-
"missing from the
|
|
298
|
+
"missing from the ConfigRecord."
|
|
299
299
|
)
|
|
300
300
|
if not isinstance(configs[key], list) or any(
|
|
301
301
|
elm
|
|
@@ -313,8 +313,8 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
|
313
313
|
|
|
314
314
|
|
|
315
315
|
def _setup(
|
|
316
|
-
state: SecAggPlusState, configs:
|
|
317
|
-
) -> dict[str,
|
|
316
|
+
state: SecAggPlusState, configs: ConfigRecord
|
|
317
|
+
) -> dict[str, ConfigRecordValues]:
|
|
318
318
|
# Assigning parameter values to object fields
|
|
319
319
|
sec_agg_param_dict = configs
|
|
320
320
|
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
|
@@ -349,8 +349,8 @@ def _setup(
|
|
|
349
349
|
|
|
350
350
|
# pylint: disable-next=too-many-locals
|
|
351
351
|
def _share_keys(
|
|
352
|
-
state: SecAggPlusState, configs:
|
|
353
|
-
) -> dict[str,
|
|
352
|
+
state: SecAggPlusState, configs: ConfigRecord
|
|
353
|
+
) -> dict[str, ConfigRecordValues]:
|
|
354
354
|
named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs)
|
|
355
355
|
key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
|
|
356
356
|
log(DEBUG, "Node %d: starting stage 1...", state.nid)
|
|
@@ -412,10 +412,10 @@ def _share_keys(
|
|
|
412
412
|
# pylint: disable-next=too-many-locals
|
|
413
413
|
def _collect_masked_vectors(
|
|
414
414
|
state: SecAggPlusState,
|
|
415
|
-
configs:
|
|
415
|
+
configs: ConfigRecord,
|
|
416
416
|
num_examples: int,
|
|
417
417
|
updated_parameters: Parameters,
|
|
418
|
-
) -> dict[str,
|
|
418
|
+
) -> dict[str, ConfigRecordValues]:
|
|
419
419
|
log(DEBUG, "Node %d: starting stage 2...", state.nid)
|
|
420
420
|
available_clients: list[int] = []
|
|
421
421
|
ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST])
|
|
@@ -498,8 +498,8 @@ def _collect_masked_vectors(
|
|
|
498
498
|
|
|
499
499
|
|
|
500
500
|
def _unmask(
|
|
501
|
-
state: SecAggPlusState, configs:
|
|
502
|
-
) -> dict[str,
|
|
501
|
+
state: SecAggPlusState, configs: ConfigRecord
|
|
502
|
+
) -> dict[str, ConfigRecordValues]:
|
|
503
503
|
log(DEBUG, "Node %d: starting stage 3...", state.nid)
|
|
504
504
|
|
|
505
505
|
active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST])
|
|
@@ -66,9 +66,7 @@ except ModuleNotFoundError:
|
|
|
66
66
|
|
|
67
67
|
PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
|
|
68
68
|
PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
|
|
69
|
-
PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
|
|
70
69
|
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
|
71
|
-
PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
|
|
72
70
|
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
|
73
71
|
PATH_PING: str = "api/v0/fleet/ping"
|
|
74
72
|
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
|
|
@@ -280,7 +278,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
280
278
|
node = None
|
|
281
279
|
|
|
282
280
|
def receive() -> Optional[Message]:
|
|
283
|
-
"""Receive next
|
|
281
|
+
"""Receive next Message from server."""
|
|
284
282
|
# Get Node
|
|
285
283
|
if node is None:
|
|
286
284
|
log(ERROR, "Node instance missing")
|
|
@@ -309,11 +307,11 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
309
307
|
if message_proto is not None:
|
|
310
308
|
message = message_from_proto(message_proto)
|
|
311
309
|
metadata = copy(message.metadata)
|
|
312
|
-
log(INFO, "[Node] POST /%s: success",
|
|
310
|
+
log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
|
|
313
311
|
return message
|
|
314
312
|
|
|
315
313
|
def send(message: Message) -> None:
|
|
316
|
-
"""Send
|
|
314
|
+
"""Send Message result back to server."""
|
|
317
315
|
# Get Node
|
|
318
316
|
if node is None:
|
|
319
317
|
log(ERROR, "Node instance missing")
|
|
@@ -345,7 +343,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
345
343
|
log(
|
|
346
344
|
INFO,
|
|
347
345
|
"[Node] POST /%s: success, created result %s",
|
|
348
|
-
|
|
346
|
+
PATH_PUSH_MESSAGES,
|
|
349
347
|
res.results, # pylint: disable=no-member
|
|
350
348
|
)
|
|
351
349
|
|
flwr/client/run_info_store.py
CHANGED
|
@@ -19,7 +19,7 @@ from dataclasses import dataclass
|
|
|
19
19
|
from pathlib import Path
|
|
20
20
|
from typing import Optional
|
|
21
21
|
|
|
22
|
-
from flwr.common import Context,
|
|
22
|
+
from flwr.common import Context, RecordDict
|
|
23
23
|
from flwr.common.config import (
|
|
24
24
|
get_fused_config,
|
|
25
25
|
get_fused_config_from_dir,
|
|
@@ -86,7 +86,7 @@ class DeprecatedRunInfoStore:
|
|
|
86
86
|
run_id=run_id,
|
|
87
87
|
node_id=self.node_id,
|
|
88
88
|
node_config=self.node_config,
|
|
89
|
-
state=
|
|
89
|
+
state=RecordDict(),
|
|
90
90
|
run_config=initial_run_config.copy(),
|
|
91
91
|
),
|
|
92
92
|
)
|
flwr/client/supernode/app.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
-
from logging import DEBUG,
|
|
19
|
+
from logging import DEBUG, INFO, WARN
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from typing import Optional
|
|
22
22
|
|
|
@@ -98,16 +98,6 @@ def run_supernode() -> None:
|
|
|
98
98
|
)
|
|
99
99
|
|
|
100
100
|
|
|
101
|
-
def run_client_app() -> None:
|
|
102
|
-
"""Run Flower client app."""
|
|
103
|
-
event(EventType.RUN_CLIENT_APP_ENTER)
|
|
104
|
-
log(
|
|
105
|
-
ERROR,
|
|
106
|
-
"The command `flower-client-app` has been replaced by `flwr run`.",
|
|
107
|
-
)
|
|
108
|
-
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
|
109
|
-
|
|
110
|
-
|
|
111
101
|
def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
112
102
|
"""Parse flower-supernode command line arguments."""
|
|
113
103
|
parser = argparse.ArgumentParser(
|
flwr/common/__init__.py
CHANGED
|
@@ -31,9 +31,13 @@ from .parameter import ndarray_to_bytes as ndarray_to_bytes
|
|
|
31
31
|
from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
|
|
32
32
|
from .parameter import parameters_to_ndarrays as parameters_to_ndarrays
|
|
33
33
|
from .record import Array as Array
|
|
34
|
+
from .record import ArrayRecord as ArrayRecord
|
|
35
|
+
from .record import ConfigRecord as ConfigRecord
|
|
34
36
|
from .record import ConfigsRecord as ConfigsRecord
|
|
37
|
+
from .record import MetricRecord as MetricRecord
|
|
35
38
|
from .record import MetricsRecord as MetricsRecord
|
|
36
39
|
from .record import ParametersRecord as ParametersRecord
|
|
40
|
+
from .record import RecordDict as RecordDict
|
|
37
41
|
from .record import RecordSet as RecordSet
|
|
38
42
|
from .record import array_from_numpy as array_from_numpy
|
|
39
43
|
from .telemetry import EventType as EventType
|
|
@@ -41,7 +45,7 @@ from .telemetry import event as event
|
|
|
41
45
|
from .typing import ClientMessage as ClientMessage
|
|
42
46
|
from .typing import Code as Code
|
|
43
47
|
from .typing import Config as Config
|
|
44
|
-
from .typing import
|
|
48
|
+
from .typing import ConfigRecordValues as ConfigRecordValues
|
|
45
49
|
from .typing import DisconnectRes as DisconnectRes
|
|
46
50
|
from .typing import EvaluateIns as EvaluateIns
|
|
47
51
|
from .typing import EvaluateRes as EvaluateRes
|
|
@@ -51,9 +55,9 @@ from .typing import GetParametersIns as GetParametersIns
|
|
|
51
55
|
from .typing import GetParametersRes as GetParametersRes
|
|
52
56
|
from .typing import GetPropertiesIns as GetPropertiesIns
|
|
53
57
|
from .typing import GetPropertiesRes as GetPropertiesRes
|
|
58
|
+
from .typing import MetricRecordValues as MetricRecordValues
|
|
54
59
|
from .typing import Metrics as Metrics
|
|
55
60
|
from .typing import MetricsAggregationFn as MetricsAggregationFn
|
|
56
|
-
from .typing import MetricsRecordValues as MetricsRecordValues
|
|
57
61
|
from .typing import NDArray as NDArray
|
|
58
62
|
from .typing import NDArrays as NDArrays
|
|
59
63
|
from .typing import Parameters as Parameters
|
|
@@ -65,11 +69,13 @@ from .typing import Status as Status
|
|
|
65
69
|
|
|
66
70
|
__all__ = [
|
|
67
71
|
"Array",
|
|
72
|
+
"ArrayRecord",
|
|
68
73
|
"ClientMessage",
|
|
69
74
|
"Code",
|
|
70
75
|
"Config",
|
|
76
|
+
"ConfigRecord",
|
|
77
|
+
"ConfigRecordValues",
|
|
71
78
|
"ConfigsRecord",
|
|
72
|
-
"ConfigsRecordValues",
|
|
73
79
|
"Context",
|
|
74
80
|
"DEFAULT_TTL",
|
|
75
81
|
"DisconnectRes",
|
|
@@ -88,16 +94,18 @@ __all__ = [
|
|
|
88
94
|
"MessageType",
|
|
89
95
|
"MessageTypeLegacy",
|
|
90
96
|
"Metadata",
|
|
97
|
+
"MetricRecord",
|
|
98
|
+
"MetricRecordValues",
|
|
91
99
|
"Metrics",
|
|
92
100
|
"MetricsAggregationFn",
|
|
93
101
|
"MetricsRecord",
|
|
94
|
-
"MetricsRecordValues",
|
|
95
102
|
"NDArray",
|
|
96
103
|
"NDArrays",
|
|
97
104
|
"Parameters",
|
|
98
105
|
"ParametersRecord",
|
|
99
106
|
"Properties",
|
|
100
107
|
"ReconnectIns",
|
|
108
|
+
"RecordDict",
|
|
101
109
|
"RecordSet",
|
|
102
110
|
"Scalar",
|
|
103
111
|
"ServerMessage",
|