flwr 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/cli/__init__.py +1 -1
- flwr/cli/app.py +21 -2
- flwr/cli/build.py +1 -1
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +53 -17
- flwr/cli/example.py +1 -1
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +1 -1
- flwr/cli/login/__init__.py +1 -1
- flwr/cli/login/login.py +12 -1
- flwr/cli/ls.py +1 -1
- flwr/cli/new/__init__.py +1 -1
- flwr/cli/new/new.py +4 -4
- flwr/cli/new/templates/__init__.py +1 -1
- flwr/cli/new/templates/app/__init__.py +1 -1
- flwr/cli/new/templates/app/code/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +5 -5
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -1
- flwr/cli/run/run.py +6 -10
- flwr/cli/stop.py +1 -1
- flwr/cli/utils.py +11 -12
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +58 -56
- flwr/client/client.py +1 -1
- flwr/client/client_app.py +231 -166
- flwr/client/clientapp/__init__.py +1 -1
- flwr/client/clientapp/app.py +3 -3
- flwr/client/clientapp/clientappio_servicer.py +1 -1
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +1 -1
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/__init__.py +1 -1
- flwr/client/grpc_client/connection.py +37 -34
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
- flwr/client/heartbeat.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +3 -3
- flwr/client/mod/centraldp_mods.py +8 -8
- flwr/client/mod/comms_mods.py +17 -23
- flwr/client/mod/localdp_mod.py +10 -10
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +32 -32
- flwr/client/mod/utils.py +1 -1
- flwr/client/nodestate/__init__.py +1 -1
- flwr/client/nodestate/in_memory_nodestate.py +1 -1
- flwr/client/nodestate/nodestate.py +1 -1
- flwr/client/nodestate/nodestate_factory.py +1 -1
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/run_info_store.py +3 -3
- flwr/client/supernode/__init__.py +1 -1
- flwr/client/supernode/app.py +1 -1
- flwr/client/typing.py +1 -1
- flwr/common/__init__.py +13 -5
- flwr/common/address.py +1 -1
- flwr/common/args.py +1 -1
- flwr/common/auth_plugin/__init__.py +1 -1
- flwr/common/auth_plugin/auth_plugin.py +1 -1
- flwr/common/config.py +5 -5
- flwr/common/constant.py +7 -7
- flwr/common/context.py +5 -5
- flwr/common/date.py +1 -1
- flwr/common/differential_privacy.py +1 -1
- flwr/common/differential_privacy_constants.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit.py +6 -6
- flwr/common/exit_handlers.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +3 -3
- flwr/common/message.py +344 -102
- flwr/common/object_ref.py +1 -1
- flwr/common/parameter.py +1 -1
- flwr/common/pyproject.py +1 -1
- flwr/common/record/__init__.py +9 -5
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +83 -37
- flwr/common/record/conversion_utils.py +2 -2
- flwr/common/record/{metricsrecord.py → metricrecord.py} +90 -44
- flwr/common/record/recorddict.py +337 -0
- flwr/common/record/typeddict.py +1 -1
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/retry_invoker.py +10 -10
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +52 -30
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +67 -72
- flwr/common/telemetry.py +2 -2
- flwr/common/typing.py +9 -9
- flwr/common/version.py +1 -1
- flwr/proto/__init__.py +1 -1
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +4 -2
- flwr/server/app.py +67 -12
- flwr/server/client_manager.py +1 -1
- flwr/server/client_proxy.py +1 -1
- flwr/server/compat/__init__.py +3 -3
- flwr/server/compat/app.py +12 -12
- flwr/server/compat/app_utils.py +17 -17
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/compat/legacy_context.py +1 -1
- flwr/server/criterion.py +1 -1
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/history.py +1 -1
- flwr/server/run_serverapp.py +5 -5
- flwr/server/server.py +1 -1
- flwr/server/server_app.py +98 -71
- flwr/server/server_config.py +1 -1
- flwr/server/serverapp/__init__.py +1 -1
- flwr/server/serverapp/app.py +11 -11
- flwr/server/serverapp_components.py +1 -1
- flwr/server/strategy/__init__.py +1 -1
- flwr/server/strategy/aggregate.py +1 -1
- flwr/server/strategy/bulyan.py +2 -2
- flwr/server/strategy/dp_adaptive_clipping.py +17 -17
- flwr/server/strategy/dp_fixed_clipping.py +17 -17
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedtrimmedavg.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +3 -2
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/strategy/strategy.py +1 -1
- flwr/server/superlink/__init__.py +1 -1
- flwr/server/superlink/ffs/__init__.py +1 -1
- flwr/server/superlink/ffs/disk_ffs.py +1 -1
- flwr/server/superlink/ffs/ffs.py +1 -1
- flwr/server/superlink/ffs/ffs_factory.py +1 -1
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -3
- flwr/server/superlink/fleet/vce/vce_api.py +2 -4
- flwr/server/superlink/linkstate/__init__.py +1 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -9
- flwr/server/superlink/linkstate/linkstate.py +5 -5
- flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +62 -28
- flwr/server/superlink/linkstate/utils.py +94 -28
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/__init__.py +1 -1
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +3 -3
- flwr/server/superlink/utils.py +1 -1
- flwr/server/typing.py +4 -4
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +5 -5
- flwr/server/workflow/__init__.py +1 -1
- flwr/server/workflow/constant.py +1 -1
- flwr/server/workflow/default_workflows.py +49 -58
- flwr/server/workflow/secure_aggregation/__init__.py +1 -1
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +49 -51
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +3 -3
- flwr/simulation/legacy_app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +5 -3
- flwr/simulation/ray_transport/ray_client_proxy.py +35 -33
- flwr/simulation/ray_transport/utils.py +1 -1
- flwr/simulation/run_simulation.py +17 -17
- flwr/simulation/simulationio_connection.py +1 -1
- flwr/superexec/__init__.py +1 -1
- flwr/superexec/app.py +1 -1
- flwr/superexec/deployment.py +5 -5
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +11 -5
- flwr/superexec/exec_servicer.py +3 -3
- flwr/superexec/exec_user_auth_interceptor.py +19 -3
- flwr/superexec/executor.py +4 -4
- flwr/superexec/simulation.py +4 -4
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/METADATA +3 -3
- flwr-1.18.0.dist-info/RECORD +332 -0
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr-1.16.0.dist-info/LICENSE +0 -202
- flwr-1.16.0.dist-info/RECORD +0 -331
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/entry_points.txt +0 -0
flwr/client/clientapp/app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -152,8 +152,8 @@ def run_clientapp( # pylint: disable=R0914
|
|
|
152
152
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
153
153
|
|
|
154
154
|
# Create error message
|
|
155
|
-
reply_message =
|
|
156
|
-
|
|
155
|
+
reply_message = Message(
|
|
156
|
+
Error(code=e_code, reason=reason), reply_to=message
|
|
157
157
|
)
|
|
158
158
|
|
|
159
159
|
# Push Message and Context to SuperNode
|
flwr/client/grpc_client/connection.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -28,16 +28,18 @@ from cryptography.hazmat.primitives.asymmetric import ec
|
|
|
28
28
|
from flwr.common import (
|
|
29
29
|
DEFAULT_TTL,
|
|
30
30
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
31
|
-
|
|
31
|
+
ConfigRecord,
|
|
32
32
|
Message,
|
|
33
33
|
Metadata,
|
|
34
|
-
|
|
34
|
+
RecordDict,
|
|
35
|
+
now,
|
|
35
36
|
)
|
|
36
|
-
from flwr.common import
|
|
37
|
+
from flwr.common import recorddict_compat as compat
|
|
37
38
|
from flwr.common import serde
|
|
38
39
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
39
40
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
40
41
|
from flwr.common.logger import log
|
|
42
|
+
from flwr.common.message import make_message
|
|
41
43
|
from flwr.common.retry_invoker import RetryInvoker
|
|
42
44
|
from flwr.common.typing import Fab, Run
|
|
43
45
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
@@ -102,18 +104,18 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
|
102
104
|
|
|
103
105
|
Examples
|
|
104
106
|
--------
|
|
105
|
-
Establishing a
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
107
|
+
Establishing a TLS-enabled connection to the server::
|
|
108
|
+
|
|
109
|
+
from pathlib import Path
|
|
110
|
+
with grpc_connection(
|
|
111
|
+
server_address,
|
|
112
|
+
max_message_length=max_message_length,
|
|
113
|
+
root_certificates=Path("/crts/root.pem").read_bytes(),
|
|
114
|
+
) as conn:
|
|
115
|
+
receive, send = conn
|
|
116
|
+
server_message = receive()
|
|
117
|
+
# do something here
|
|
118
|
+
send(client_message)
|
|
117
119
|
"""
|
|
118
120
|
if isinstance(root_certificates, str):
|
|
119
121
|
root_certificates = Path(root_certificates).read_bytes()
|
|
@@ -139,32 +141,32 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
|
139
141
|
# Receive ServerMessage proto
|
|
140
142
|
proto = next(server_message_iterator)
|
|
141
143
|
|
|
142
|
-
# ServerMessage proto --> *Ins -->
|
|
144
|
+
# ServerMessage proto --> *Ins --> RecordDict
|
|
143
145
|
field = proto.WhichOneof("msg")
|
|
144
146
|
message_type = ""
|
|
145
147
|
if field == "get_properties_ins":
|
|
146
|
-
|
|
148
|
+
recorddict = compat.getpropertiesins_to_recorddict(
|
|
147
149
|
serde.get_properties_ins_from_proto(proto.get_properties_ins)
|
|
148
150
|
)
|
|
149
151
|
message_type = MessageTypeLegacy.GET_PROPERTIES
|
|
150
152
|
elif field == "get_parameters_ins":
|
|
151
|
-
|
|
153
|
+
recorddict = compat.getparametersins_to_recorddict(
|
|
152
154
|
serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
|
|
153
155
|
)
|
|
154
156
|
message_type = MessageTypeLegacy.GET_PARAMETERS
|
|
155
157
|
elif field == "fit_ins":
|
|
156
|
-
|
|
158
|
+
recorddict = compat.fitins_to_recorddict(
|
|
157
159
|
serde.fit_ins_from_proto(proto.fit_ins), False
|
|
158
160
|
)
|
|
159
161
|
message_type = MessageType.TRAIN
|
|
160
162
|
elif field == "evaluate_ins":
|
|
161
|
-
|
|
163
|
+
recorddict = compat.evaluateins_to_recorddict(
|
|
162
164
|
serde.evaluate_ins_from_proto(proto.evaluate_ins), False
|
|
163
165
|
)
|
|
164
166
|
message_type = MessageType.EVALUATE
|
|
165
167
|
elif field == "reconnect_ins":
|
|
166
|
-
|
|
167
|
-
|
|
168
|
+
recorddict = RecordDict()
|
|
169
|
+
recorddict.config_records["config"] = ConfigRecord(
|
|
168
170
|
{"seconds": proto.reconnect_ins.seconds}
|
|
169
171
|
)
|
|
170
172
|
message_type = "reconnect"
|
|
@@ -175,45 +177,46 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
|
175
177
|
)
|
|
176
178
|
|
|
177
179
|
# Construct Message
|
|
178
|
-
return
|
|
180
|
+
return make_message(
|
|
179
181
|
metadata=Metadata(
|
|
180
182
|
run_id=0,
|
|
181
183
|
message_id=str(uuid.uuid4()),
|
|
182
184
|
src_node_id=0,
|
|
183
185
|
dst_node_id=0,
|
|
184
|
-
|
|
186
|
+
reply_to_message_id="",
|
|
185
187
|
group_id="",
|
|
188
|
+
created_at=now().timestamp(),
|
|
186
189
|
ttl=DEFAULT_TTL,
|
|
187
190
|
message_type=message_type,
|
|
188
191
|
),
|
|
189
|
-
content=
|
|
192
|
+
content=recorddict,
|
|
190
193
|
)
|
|
191
194
|
|
|
192
195
|
def send(message: Message) -> None:
|
|
193
|
-
# Retrieve
|
|
194
|
-
|
|
196
|
+
# Retrieve RecordDict and message_type
|
|
197
|
+
recorddict = message.content
|
|
195
198
|
message_type = message.metadata.message_type
|
|
196
199
|
|
|
197
|
-
#
|
|
200
|
+
# RecordDict --> *Res --> *Res proto -> ClientMessage proto
|
|
198
201
|
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
|
199
|
-
getpropres = compat.
|
|
202
|
+
getpropres = compat.recorddict_to_getpropertiesres(recorddict)
|
|
200
203
|
msg_proto = ClientMessage(
|
|
201
204
|
get_properties_res=serde.get_properties_res_to_proto(getpropres)
|
|
202
205
|
)
|
|
203
206
|
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
|
204
|
-
getparamres = compat.
|
|
207
|
+
getparamres = compat.recorddict_to_getparametersres(recorddict, False)
|
|
205
208
|
msg_proto = ClientMessage(
|
|
206
209
|
get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
|
|
207
210
|
)
|
|
208
211
|
elif message_type == MessageType.TRAIN:
|
|
209
|
-
fitres = compat.
|
|
212
|
+
fitres = compat.recorddict_to_fitres(recorddict, False)
|
|
210
213
|
msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
|
|
211
214
|
elif message_type == MessageType.EVALUATE:
|
|
212
|
-
evalres = compat.
|
|
215
|
+
evalres = compat.recorddict_to_evaluateres(recorddict)
|
|
213
216
|
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
|
|
214
217
|
elif message_type == "reconnect":
|
|
215
218
|
reason = cast(
|
|
216
|
-
Reason.ValueType,
|
|
219
|
+
Reason.ValueType, recorddict.config_records["config"]["reason"]
|
|
217
220
|
)
|
|
218
221
|
msg_proto = ClientMessage(
|
|
219
222
|
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
|
flwr/client/message_handler/message_handler.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -26,17 +26,17 @@ from flwr.client.client import (
|
|
|
26
26
|
)
|
|
27
27
|
from flwr.client.numpy_client import NumPyClient
|
|
28
28
|
from flwr.client.typing import ClientFnExt
|
|
29
|
-
from flwr.common import
|
|
29
|
+
from flwr.common import ConfigRecord, Context, Message, Metadata, RecordDict, log
|
|
30
30
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
31
|
-
from flwr.common.
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
31
|
+
from flwr.common.recorddict_compat import (
|
|
32
|
+
evaluateres_to_recorddict,
|
|
33
|
+
fitres_to_recorddict,
|
|
34
|
+
getparametersres_to_recorddict,
|
|
35
|
+
getpropertiesres_to_recorddict,
|
|
36
|
+
recorddict_to_evaluateins,
|
|
37
|
+
recorddict_to_fitins,
|
|
38
|
+
recorddict_to_getparametersins,
|
|
39
|
+
recorddict_to_getpropertiesins,
|
|
40
40
|
)
|
|
41
41
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
42
42
|
ClientMessage,
|
|
@@ -70,18 +70,18 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
|
|
|
70
70
|
Number of seconds that the client should disconnect from the server.
|
|
71
71
|
"""
|
|
72
72
|
if message.metadata.message_type == "reconnect":
|
|
73
|
-
# Retrieve ReconnectIns from
|
|
74
|
-
|
|
75
|
-
seconds = cast(int,
|
|
73
|
+
# Retrieve ReconnectIns from RecordDict
|
|
74
|
+
recorddict = message.content
|
|
75
|
+
seconds = cast(int, recorddict.config_records["config"]["seconds"])
|
|
76
76
|
# Construct ReconnectIns and call _reconnect
|
|
77
77
|
disconnect_msg, sleep_duration = _reconnect(
|
|
78
78
|
ServerMessage.ReconnectIns(seconds=seconds)
|
|
79
79
|
)
|
|
80
|
-
# Store DisconnectRes in
|
|
80
|
+
# Store DisconnectRes in RecordDict
|
|
81
81
|
reason = cast(int, disconnect_msg.disconnect_res.reason)
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
out_message = message
|
|
82
|
+
recorddict = RecordDict()
|
|
83
|
+
recorddict.config_records["config"] = ConfigRecord({"reason": reason})
|
|
84
|
+
out_message = Message(recorddict, reply_to=message)
|
|
85
85
|
# Return Message and sleep duration
|
|
86
86
|
return out_message, sleep_duration
|
|
87
87
|
|
|
@@ -111,37 +111,37 @@ def handle_legacy_message_from_msgtype(
|
|
|
111
111
|
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
|
112
112
|
get_properties_res = maybe_call_get_properties(
|
|
113
113
|
client=client,
|
|
114
|
-
get_properties_ins=
|
|
114
|
+
get_properties_ins=recorddict_to_getpropertiesins(message.content),
|
|
115
115
|
)
|
|
116
|
-
|
|
116
|
+
out_recorddict = getpropertiesres_to_recorddict(get_properties_res)
|
|
117
117
|
# Handle GetParametersIns
|
|
118
118
|
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
|
119
119
|
get_parameters_res = maybe_call_get_parameters(
|
|
120
120
|
client=client,
|
|
121
|
-
get_parameters_ins=
|
|
121
|
+
get_parameters_ins=recorddict_to_getparametersins(message.content),
|
|
122
122
|
)
|
|
123
|
-
|
|
123
|
+
out_recorddict = getparametersres_to_recorddict(
|
|
124
124
|
get_parameters_res, keep_input=False
|
|
125
125
|
)
|
|
126
126
|
# Handle FitIns
|
|
127
127
|
elif message_type == MessageType.TRAIN:
|
|
128
128
|
fit_res = maybe_call_fit(
|
|
129
129
|
client=client,
|
|
130
|
-
fit_ins=
|
|
130
|
+
fit_ins=recorddict_to_fitins(message.content, keep_input=True),
|
|
131
131
|
)
|
|
132
|
-
|
|
132
|
+
out_recorddict = fitres_to_recorddict(fit_res, keep_input=False)
|
|
133
133
|
# Handle EvaluateIns
|
|
134
134
|
elif message_type == MessageType.EVALUATE:
|
|
135
135
|
evaluate_res = maybe_call_evaluate(
|
|
136
136
|
client=client,
|
|
137
|
-
evaluate_ins=
|
|
137
|
+
evaluate_ins=recorddict_to_evaluateins(message.content, keep_input=True),
|
|
138
138
|
)
|
|
139
|
-
|
|
139
|
+
out_recorddict = evaluateres_to_recorddict(evaluate_res)
|
|
140
140
|
else:
|
|
141
141
|
raise ValueError(f"Invalid message type: {message_type}")
|
|
142
142
|
|
|
143
143
|
# Return Message
|
|
144
|
-
return message
|
|
144
|
+
return Message(out_recorddict, reply_to=message)
|
|
145
145
|
|
|
146
146
|
|
|
147
147
|
def _reconnect(
|
|
@@ -167,7 +167,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
|
|
|
167
167
|
and out_meta.message_id == "" # This will be generated by the server
|
|
168
168
|
and out_meta.src_node_id == in_meta.dst_node_id
|
|
169
169
|
and out_meta.dst_node_id == in_meta.src_node_id
|
|
170
|
-
and out_meta.
|
|
170
|
+
and out_meta.reply_to_message_id == in_meta.message_id
|
|
171
171
|
and out_meta.group_id == in_meta.group_id
|
|
172
172
|
and out_meta.message_type == in_meta.message_type
|
|
173
173
|
and out_meta.created_at > in_meta.created_at
|
flwr/client/mod/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
|
19
|
-
from .comms_mods import
|
|
19
|
+
from .comms_mods import arrays_size_mod, message_size_mod
|
|
20
20
|
from .localdp_mod import LocalDpMod
|
|
21
21
|
from .secure_aggregation import secagg_mod, secaggplus_mod
|
|
22
22
|
from .utils import make_ffn
|
|
@@ -24,10 +24,10 @@ from .utils import make_ffn
|
|
|
24
24
|
__all__ = [
|
|
25
25
|
"LocalDpMod",
|
|
26
26
|
"adaptiveclipping_mod",
|
|
27
|
+
"arrays_size_mod",
|
|
27
28
|
"fixedclipping_mod",
|
|
28
29
|
"make_ffn",
|
|
29
30
|
"message_size_mod",
|
|
30
|
-
"parameters_size_mod",
|
|
31
31
|
"secagg_mod",
|
|
32
32
|
"secaggplus_mod",
|
|
33
33
|
]
|
|
flwr/client/mod/centraldp_mods.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -19,7 +19,7 @@ from logging import INFO
|
|
|
19
19
|
|
|
20
20
|
from flwr.client.typing import ClientAppCallable
|
|
21
21
|
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
|
22
|
-
from flwr.common import
|
|
22
|
+
from flwr.common import recorddict_compat as compat
|
|
23
23
|
from flwr.common.constant import MessageType
|
|
24
24
|
from flwr.common.context import Context
|
|
25
25
|
from flwr.common.differential_privacy import (
|
|
@@ -53,7 +53,7 @@ def fixedclipping_mod(
|
|
|
53
53
|
"""
|
|
54
54
|
if msg.metadata.message_type != MessageType.TRAIN:
|
|
55
55
|
return call_next(msg, ctxt)
|
|
56
|
-
fit_ins = compat.
|
|
56
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
|
57
57
|
if KEY_CLIPPING_NORM not in fit_ins.config:
|
|
58
58
|
raise KeyError(
|
|
59
59
|
f"The {KEY_CLIPPING_NORM} value is not supplied by the "
|
|
@@ -71,7 +71,7 @@ def fixedclipping_mod(
|
|
|
71
71
|
if out_msg.has_error():
|
|
72
72
|
return out_msg
|
|
73
73
|
|
|
74
|
-
fit_res = compat.
|
|
74
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
|
75
75
|
|
|
76
76
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
|
77
77
|
|
|
@@ -87,7 +87,7 @@ def fixedclipping_mod(
|
|
|
87
87
|
)
|
|
88
88
|
|
|
89
89
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
90
|
-
out_msg.content = compat.
|
|
90
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
|
91
91
|
return out_msg
|
|
92
92
|
|
|
93
93
|
|
|
@@ -116,7 +116,7 @@ def adaptiveclipping_mod(
|
|
|
116
116
|
if msg.metadata.message_type != MessageType.TRAIN:
|
|
117
117
|
return call_next(msg, ctxt)
|
|
118
118
|
|
|
119
|
-
fit_ins = compat.
|
|
119
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
|
120
120
|
|
|
121
121
|
if KEY_CLIPPING_NORM not in fit_ins.config:
|
|
122
122
|
raise KeyError(
|
|
@@ -136,7 +136,7 @@ def adaptiveclipping_mod(
|
|
|
136
136
|
if out_msg.has_error():
|
|
137
137
|
return out_msg
|
|
138
138
|
|
|
139
|
-
fit_res = compat.
|
|
139
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
|
140
140
|
|
|
141
141
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
|
142
142
|
|
|
@@ -155,5 +155,5 @@ def adaptiveclipping_mod(
|
|
|
155
155
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
156
156
|
|
|
157
157
|
fit_res.metrics[KEY_NORM_BIT] = norm_bit
|
|
158
|
-
out_msg.content = compat.
|
|
158
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
|
159
159
|
return out_msg
|
flwr/client/mod/comms_mods.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -34,47 +34,41 @@ def message_size_mod(
|
|
|
34
34
|
"""
|
|
35
35
|
message_size_in_bytes = 0
|
|
36
36
|
|
|
37
|
-
for
|
|
38
|
-
message_size_in_bytes +=
|
|
39
|
-
|
|
40
|
-
for c_record in msg.content.configs_records.values():
|
|
41
|
-
message_size_in_bytes += c_record.count_bytes()
|
|
42
|
-
|
|
43
|
-
for m_record in msg.content.metrics_records.values():
|
|
44
|
-
message_size_in_bytes += m_record.count_bytes()
|
|
37
|
+
for record in msg.content.values():
|
|
38
|
+
message_size_in_bytes += record.count_bytes()
|
|
45
39
|
|
|
46
40
|
log(INFO, "Message size: %i bytes", message_size_in_bytes)
|
|
47
41
|
|
|
48
42
|
return call_next(msg, ctxt)
|
|
49
43
|
|
|
50
44
|
|
|
51
|
-
def
|
|
45
|
+
def arrays_size_mod(
|
|
52
46
|
msg: Message, ctxt: Context, call_next: ClientAppCallable
|
|
53
47
|
) -> Message:
|
|
54
|
-
"""
|
|
48
|
+
"""Arrays size mod.
|
|
55
49
|
|
|
56
|
-
This mod logs the number of
|
|
57
|
-
|
|
50
|
+
This mod logs the number of array elements transmitted in ``ArrayRecord`` objects
|
|
51
|
+
of the message as well as their sizes in bytes.
|
|
58
52
|
"""
|
|
59
53
|
model_size_stats = {}
|
|
60
|
-
|
|
61
|
-
for record_name,
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
for array in
|
|
66
|
-
|
|
54
|
+
arrays_size_in_bytes = 0
|
|
55
|
+
for record_name, arr_record in msg.content.array_records.items():
|
|
56
|
+
arr_record_bytes = arr_record.count_bytes()
|
|
57
|
+
arrays_size_in_bytes += arr_record_bytes
|
|
58
|
+
element_count = 0
|
|
59
|
+
for array in arr_record.values():
|
|
60
|
+
element_count += (
|
|
67
61
|
int(np.prod(array.shape)) if array.shape else array.numpy().size
|
|
68
62
|
)
|
|
69
63
|
|
|
70
64
|
model_size_stats[f"{record_name}"] = {
|
|
71
|
-
"
|
|
72
|
-
"bytes":
|
|
65
|
+
"elements": element_count,
|
|
66
|
+
"bytes": arr_record_bytes,
|
|
73
67
|
}
|
|
74
68
|
|
|
75
69
|
if model_size_stats:
|
|
76
70
|
log(INFO, model_size_stats)
|
|
77
71
|
|
|
78
|
-
log(INFO, "Total
|
|
72
|
+
log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
|
|
79
73
|
|
|
80
74
|
return call_next(msg, ctxt)
|
flwr/client/mod/localdp_mod.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -21,7 +21,7 @@ import numpy as np
|
|
|
21
21
|
|
|
22
22
|
from flwr.client.typing import ClientAppCallable
|
|
23
23
|
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
|
24
|
-
from flwr.common import
|
|
24
|
+
from flwr.common import recorddict_compat as compat
|
|
25
25
|
from flwr.common.constant import MessageType
|
|
26
26
|
from flwr.common.context import Context
|
|
27
27
|
from flwr.common.differential_privacy import (
|
|
@@ -57,12 +57,12 @@ class LocalDpMod:
|
|
|
57
57
|
|
|
58
58
|
Examples
|
|
59
59
|
--------
|
|
60
|
-
Create an instance of the local DP mod and add it to the client-side mods
|
|
60
|
+
Create an instance of the local DP mod and add it to the client-side mods::
|
|
61
61
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
62
|
+
local_dp_mod = LocalDpMod( ... )
|
|
63
|
+
app = fl.client.ClientApp(
|
|
64
|
+
client_fn=client_fn, mods=[local_dp_mod]
|
|
65
|
+
)
|
|
66
66
|
"""
|
|
67
67
|
|
|
68
68
|
def __init__(
|
|
@@ -107,7 +107,7 @@ class LocalDpMod:
|
|
|
107
107
|
if msg.metadata.message_type != MessageType.TRAIN:
|
|
108
108
|
return call_next(msg, ctxt)
|
|
109
109
|
|
|
110
|
-
fit_ins = compat.
|
|
110
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
|
111
111
|
server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)
|
|
112
112
|
|
|
113
113
|
# Call inner app
|
|
@@ -117,7 +117,7 @@ class LocalDpMod:
|
|
|
117
117
|
if out_msg.has_error():
|
|
118
118
|
return out_msg
|
|
119
119
|
|
|
120
|
-
fit_res = compat.
|
|
120
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
|
121
121
|
|
|
122
122
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
|
123
123
|
|
|
@@ -149,5 +149,5 @@ class LocalDpMod:
|
|
|
149
149
|
noise_value_sd,
|
|
150
150
|
)
|
|
151
151
|
|
|
152
|
-
out_msg.content = compat.
|
|
152
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
|
153
153
|
return out_msg
|