flwr 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/build.py +15 -5
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +23 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
- flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +53 -50
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +29 -11
- flwr/client/grpc_adapter_client/connection.py +11 -4
- flwr/client/grpc_rere_client/connection.py +93 -129
- flwr/client/rest_client/connection.py +134 -164
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +26 -0
- flwr/clientapp/mod/centraldp_mods.py +132 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -5
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +42 -8
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +1 -1
- flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
- flwr/common/inflatable_utils.py +191 -24
- flwr/common/logger.py +1 -1
- flwr/common/record/array.py +101 -22
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/serde.py +0 -28
- flwr/common/telemetry.py +4 -0
- flwr/compat/client/app.py +14 -31
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +51 -0
- flwr/proto/appio_pb2.pyi +195 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +4 -19
- flwr/proto/clientappio_pb2.pyi +0 -125
- flwr/proto/clientappio_pb2_grpc.py +269 -29
- flwr/proto/clientappio_pb2_grpc.pyi +114 -21
- flwr/proto/control_pb2.py +62 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
- flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
- flwr/proto/fleet_pb2.py +12 -20
- flwr/proto/fleet_pb2.pyi +6 -36
- flwr/proto/serverappio_pb2.py +8 -31
- flwr/proto/serverappio_pb2.pyi +0 -152
- flwr/proto/serverappio_pb2_grpc.py +107 -38
- flwr/proto/serverappio_pb2_grpc.pyi +47 -20
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +130 -153
- flwr/server/fleet_event_log_interceptor.py +4 -0
- flwr/server/grid/grpc_grid.py +94 -54
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +165 -144
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +2 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +95 -48
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +98 -22
- flwr/server/superlink/utils.py +0 -35
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/dp_fixed_clipping.py +352 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +38 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
- flwr/serverapp/strategy/fedadagrad.py +162 -0
- flwr/serverapp/strategy/fedadam.py +181 -0
- flwr/serverapp/strategy/fedavg.py +295 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedyogi.py +173 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +251 -0
- flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
- flwr/simulation/app.py +159 -154
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/app_utils.py +58 -0
- flwr/supercore/cli/__init__.py +22 -0
- flwr/supercore/cli/flower_superexec.py +141 -0
- flwr/supercore/corestate/__init__.py +22 -0
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +25 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/in_memory_object_store.py +31 -31
- flwr/supercore/object_store/object_store.py +20 -42
- flwr/supercore/object_store/utils.py +43 -0
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +53 -0
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +71 -0
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +185 -0
- flwr/supercore/utils.py +32 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +9 -5
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +39 -28
- flwr/superlink/servicer/control/control_license_interceptor.py +82 -0
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +79 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +18 -10
- flwr/supernode/cli/flower_supernode.py +3 -7
- flwr/supernode/cli/flwr_clientapp.py +20 -16
- flwr/supernode/nodestate/in_memory_nodestate.py +13 -4
- flwr/supernode/nodestate/nodestate.py +3 -44
- flwr/supernode/runtime/run_clientapp.py +129 -115
- flwr/supernode/servicer/clientappio/__init__.py +1 -3
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +217 -165
- flwr/supernode/start_client_internal.py +205 -148
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/METADATA +5 -3
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/RECORD +161 -117
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
- flwr/common/inflatable_rest_utils.py +0 -99
- flwr/proto/exec_pb2.py +0 -62
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -192
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -130
- /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
- /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
|
@@ -14,40 +14,29 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Contextmanager for a REST request-response channel to the Flower server."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
from collections.abc import Iterator
|
|
18
19
|
from contextlib import contextmanager
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
from typing import Callable, Optional, TypeVar, Union, cast
|
|
20
|
+
from logging import ERROR, WARN
|
|
21
|
+
from typing import Callable, Optional, TypeVar, Union
|
|
22
22
|
|
|
23
23
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
24
|
from google.protobuf.message import Message as GrpcMessage
|
|
25
25
|
from requests.exceptions import ConnectionError as RequestsConnectionError
|
|
26
26
|
|
|
27
|
-
from flwr.app.metadata import Metadata
|
|
28
|
-
from flwr.client.message_handler.message_handler import validate_out_message
|
|
29
27
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
30
28
|
from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
|
|
31
29
|
from flwr.common.exit import ExitCode, flwr_exit
|
|
32
30
|
from flwr.common.heartbeat import HeartbeatSender
|
|
33
|
-
from flwr.common.
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
)
|
|
38
|
-
from flwr.common.inflatable_rest_utils import (
|
|
39
|
-
make_pull_object_fn_rest,
|
|
40
|
-
make_push_object_fn_rest,
|
|
41
|
-
)
|
|
42
|
-
from flwr.common.inflatable_utils import (
|
|
43
|
-
inflate_object_from_contents,
|
|
44
|
-
pull_objects,
|
|
45
|
-
push_objects,
|
|
31
|
+
from flwr.common.inflatable_protobuf_utils import (
|
|
32
|
+
make_confirm_message_received_fn_protobuf,
|
|
33
|
+
make_pull_object_fn_protobuf,
|
|
34
|
+
make_push_object_fn_protobuf,
|
|
46
35
|
)
|
|
47
36
|
from flwr.common.logger import log
|
|
48
37
|
from flwr.common.message import Message, remove_content_from_message
|
|
49
38
|
from flwr.common.retry_invoker import RetryInvoker
|
|
50
|
-
from flwr.common.serde import message_to_proto, run_from_proto
|
|
39
|
+
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
51
40
|
from flwr.common.typing import Fab, Run
|
|
52
41
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
53
42
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -67,6 +56,7 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
|
67
56
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
68
57
|
ConfirmMessageReceivedRequest,
|
|
69
58
|
ConfirmMessageReceivedResponse,
|
|
59
|
+
ObjectTree,
|
|
70
60
|
PullObjectRequest,
|
|
71
61
|
PullObjectResponse,
|
|
72
62
|
PushObjectRequest,
|
|
@@ -109,12 +99,15 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
109
99
|
] = None,
|
|
110
100
|
) -> Iterator[
|
|
111
101
|
tuple[
|
|
112
|
-
Callable[[], Optional[Message]],
|
|
113
|
-
Callable[[Message],
|
|
102
|
+
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
103
|
+
Callable[[Message, ObjectTree], set[str]],
|
|
114
104
|
Callable[[], Optional[int]],
|
|
115
105
|
Callable[[], None],
|
|
116
106
|
Callable[[int], Run],
|
|
117
107
|
Callable[[str, int], Fab],
|
|
108
|
+
Callable[[int, str], bytes],
|
|
109
|
+
Callable[[int, str, bytes], None],
|
|
110
|
+
Callable[[int, str], None],
|
|
118
111
|
]
|
|
119
112
|
]:
|
|
120
113
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -150,6 +143,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
150
143
|
create_node : Optional[Callable]
|
|
151
144
|
delete_node : Optional[Callable]
|
|
152
145
|
get_run : Optional[Callable]
|
|
146
|
+
pull_object : Callable[[str], bytes]
|
|
147
|
+
push_object : Callable[[str, bytes], None]
|
|
148
|
+
confirm_message_received : Callable[[str], None]
|
|
153
149
|
"""
|
|
154
150
|
log(
|
|
155
151
|
WARN,
|
|
@@ -178,9 +174,11 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
178
174
|
log(ERROR, "Client authentication is not supported for this transport type.")
|
|
179
175
|
|
|
180
176
|
# Shared variables for inner functions
|
|
181
|
-
metadata: Optional[Metadata] = None
|
|
182
177
|
node: Optional[Node] = None
|
|
183
178
|
|
|
179
|
+
# Remove should_giveup from RetryInvoker as REST does not support gRPC status codes
|
|
180
|
+
retry_invoker.should_giveup = None
|
|
181
|
+
|
|
184
182
|
###########################################################################
|
|
185
183
|
# heartbeat/create_node/delete_node/receive/send/get_run functions
|
|
186
184
|
###########################################################################
|
|
@@ -232,6 +230,38 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
232
230
|
grpc_res.ParseFromString(res.content)
|
|
233
231
|
return grpc_res
|
|
234
232
|
|
|
233
|
+
def _pull_object_protobuf(request: PullObjectRequest) -> PullObjectResponse:
|
|
234
|
+
res = _request(
|
|
235
|
+
req=request,
|
|
236
|
+
res_type=PullObjectResponse,
|
|
237
|
+
api_path=PATH_PULL_OBJECT,
|
|
238
|
+
)
|
|
239
|
+
if res is None:
|
|
240
|
+
raise ValueError(f"{PullObjectResponse.__name__} is None.")
|
|
241
|
+
return res
|
|
242
|
+
|
|
243
|
+
def _push_object_protobuf(request: PushObjectRequest) -> PushObjectResponse:
|
|
244
|
+
res = _request(
|
|
245
|
+
req=request,
|
|
246
|
+
res_type=PushObjectResponse,
|
|
247
|
+
api_path=PATH_PUSH_OBJECT,
|
|
248
|
+
)
|
|
249
|
+
if res is None:
|
|
250
|
+
raise ValueError(f"{PushObjectResponse.__name__} is None.")
|
|
251
|
+
return res
|
|
252
|
+
|
|
253
|
+
def _confirm_message_received_protobuf(
|
|
254
|
+
request: ConfirmMessageReceivedRequest,
|
|
255
|
+
) -> ConfirmMessageReceivedResponse:
|
|
256
|
+
res = _request(
|
|
257
|
+
req=request,
|
|
258
|
+
res_type=ConfirmMessageReceivedResponse,
|
|
259
|
+
api_path=PATH_CONFIRM_MESSAGE_RECEIVED,
|
|
260
|
+
)
|
|
261
|
+
if res is None:
|
|
262
|
+
raise ValueError(f"{ConfirmMessageReceivedResponse.__name__} is None.")
|
|
263
|
+
return res
|
|
264
|
+
|
|
235
265
|
def send_node_heartbeat() -> bool:
|
|
236
266
|
# Get Node
|
|
237
267
|
if node is None:
|
|
@@ -279,8 +309,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
279
309
|
"""Set delete_node."""
|
|
280
310
|
nonlocal node
|
|
281
311
|
if node is None:
|
|
282
|
-
|
|
283
|
-
return
|
|
312
|
+
raise RuntimeError("Node instance missing")
|
|
284
313
|
|
|
285
314
|
# Stop the heartbeat sender
|
|
286
315
|
heartbeat_sender.stop()
|
|
@@ -296,162 +325,54 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
296
325
|
# Cleanup
|
|
297
326
|
node = None
|
|
298
327
|
|
|
299
|
-
def receive() -> Optional[Message]:
|
|
300
|
-
"""
|
|
328
|
+
def receive() -> Optional[tuple[Message, ObjectTree]]:
|
|
329
|
+
"""Pull a message with its ObjectTree from SuperLink."""
|
|
301
330
|
# Get Node
|
|
302
331
|
if node is None:
|
|
303
|
-
|
|
304
|
-
return None
|
|
332
|
+
raise RuntimeError("Node instance missing")
|
|
305
333
|
|
|
306
|
-
#
|
|
334
|
+
# Try to pull a message with its object tree from SuperLink
|
|
307
335
|
req = PullMessagesRequest(node=node)
|
|
308
|
-
|
|
309
|
-
# Send the request
|
|
310
336
|
res = _request(req, PullMessagesResponse, PATH_PULL_MESSAGES)
|
|
311
337
|
if res is None:
|
|
312
|
-
|
|
338
|
+
raise ValueError("PushMessagesResponse is None.")
|
|
313
339
|
|
|
314
|
-
#
|
|
315
|
-
|
|
340
|
+
# If no messages are available, return None
|
|
341
|
+
if len(res.messages_list) == 0:
|
|
342
|
+
return None
|
|
316
343
|
|
|
317
|
-
#
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
):
|
|
321
|
-
message_proto = None
|
|
344
|
+
# Get the current Message and its object tree
|
|
345
|
+
message_proto = res.messages_list[0]
|
|
346
|
+
object_tree = res.message_object_trees[0]
|
|
322
347
|
|
|
323
348
|
# Construct the Message
|
|
324
|
-
in_message
|
|
325
|
-
|
|
326
|
-
if message_proto:
|
|
327
|
-
log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
|
|
328
|
-
msg_id = message_proto.metadata.message_id
|
|
329
|
-
run_id = message_proto.metadata.run_id
|
|
330
|
-
|
|
331
|
-
def fn(request: PullObjectRequest) -> PullObjectResponse:
|
|
332
|
-
res = _request(
|
|
333
|
-
req=request, res_type=PullObjectResponse, api_path=PATH_PULL_OBJECT
|
|
334
|
-
)
|
|
335
|
-
if res is None:
|
|
336
|
-
raise ValueError("PushObjectResponse is None.")
|
|
337
|
-
return res
|
|
338
|
-
|
|
339
|
-
try:
|
|
340
|
-
all_object_contents = pull_objects(
|
|
341
|
-
list(res.objects_to_pull[msg_id].object_ids) + [msg_id],
|
|
342
|
-
pull_object_fn=make_pull_object_fn_rest(
|
|
343
|
-
pull_object_rest=fn,
|
|
344
|
-
node=node,
|
|
345
|
-
run_id=run_id,
|
|
346
|
-
),
|
|
347
|
-
)
|
|
348
|
-
|
|
349
|
-
# Confirm that the message has been received
|
|
350
|
-
_request(
|
|
351
|
-
req=ConfirmMessageReceivedRequest(
|
|
352
|
-
node=node, run_id=run_id, message_object_id=msg_id
|
|
353
|
-
),
|
|
354
|
-
res_type=ConfirmMessageReceivedResponse,
|
|
355
|
-
api_path=PATH_CONFIRM_MESSAGE_RECEIVED,
|
|
356
|
-
)
|
|
357
|
-
except ValueError as e:
|
|
358
|
-
log(
|
|
359
|
-
ERROR,
|
|
360
|
-
"Pulling objects failed. Potential irrecoverable error: %s",
|
|
361
|
-
str(e),
|
|
362
|
-
)
|
|
363
|
-
in_message = cast(
|
|
364
|
-
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
|
365
|
-
)
|
|
366
|
-
# The deflated message doesn't contain the message_id (its own object_id)
|
|
367
|
-
# Inject
|
|
368
|
-
in_message.metadata.__dict__["_message_id"] = msg_id
|
|
369
|
-
|
|
370
|
-
# Remember `metadata` of the in message
|
|
371
|
-
nonlocal metadata
|
|
372
|
-
metadata = copy(in_message.metadata) if in_message else None
|
|
349
|
+
in_message = message_from_proto(message_proto)
|
|
373
350
|
|
|
374
|
-
|
|
351
|
+
# Return the Message and its object tree
|
|
352
|
+
return in_message, object_tree
|
|
375
353
|
|
|
376
|
-
def send(message: Message) ->
|
|
377
|
-
"""Send
|
|
354
|
+
def send(message: Message, object_tree: ObjectTree) -> set[str]:
|
|
355
|
+
"""Send the message with its ObjectTree to SuperLink."""
|
|
378
356
|
# Get Node
|
|
379
357
|
if node is None:
|
|
380
|
-
|
|
381
|
-
return
|
|
358
|
+
raise RuntimeError("Node instance missing")
|
|
382
359
|
|
|
383
|
-
#
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
log(ERROR, "No current message")
|
|
387
|
-
return
|
|
360
|
+
# Remove the content from the message if it has
|
|
361
|
+
if message.has_content():
|
|
362
|
+
message = remove_content_from_message(message)
|
|
388
363
|
|
|
389
|
-
#
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
all_objects = get_all_nested_objects(message)
|
|
399
|
-
object_tree = get_object_tree(message)
|
|
400
|
-
|
|
401
|
-
# Serialize Message
|
|
402
|
-
message_proto = message_to_proto(
|
|
403
|
-
message=remove_content_from_message(message)
|
|
404
|
-
)
|
|
405
|
-
req = PushMessagesRequest(
|
|
406
|
-
node=node,
|
|
407
|
-
messages_list=[message_proto],
|
|
408
|
-
message_object_trees=[object_tree],
|
|
409
|
-
)
|
|
410
|
-
|
|
411
|
-
# Send the request
|
|
412
|
-
res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
|
|
413
|
-
if res:
|
|
414
|
-
log(
|
|
415
|
-
INFO,
|
|
416
|
-
"[Node] POST /%s: success, created result %s",
|
|
417
|
-
PATH_PUSH_MESSAGES,
|
|
418
|
-
res.results, # pylint: disable=no-member
|
|
419
|
-
)
|
|
420
|
-
|
|
421
|
-
if res and res.objects_to_push:
|
|
422
|
-
objs_to_push = set(res.objects_to_push[message.object_id].object_ids)
|
|
423
|
-
|
|
424
|
-
def fn(request: PushObjectRequest) -> PushObjectResponse:
|
|
425
|
-
res = _request(
|
|
426
|
-
req=request,
|
|
427
|
-
res_type=PushObjectResponse,
|
|
428
|
-
api_path=PATH_PUSH_OBJECT,
|
|
429
|
-
)
|
|
430
|
-
if res is None:
|
|
431
|
-
raise ValueError("PushObjectResponse is None.")
|
|
432
|
-
return res
|
|
433
|
-
|
|
434
|
-
try:
|
|
435
|
-
push_objects(
|
|
436
|
-
all_objects,
|
|
437
|
-
push_object_fn=make_push_object_fn_rest(
|
|
438
|
-
push_object_rest=fn,
|
|
439
|
-
node=node,
|
|
440
|
-
run_id=message_proto.metadata.run_id,
|
|
441
|
-
),
|
|
442
|
-
object_ids_to_push=objs_to_push,
|
|
443
|
-
)
|
|
444
|
-
log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
|
|
445
|
-
except ValueError as e:
|
|
446
|
-
log(
|
|
447
|
-
ERROR,
|
|
448
|
-
"Pushing objects failed. Potential irrecoverable error: %s",
|
|
449
|
-
str(e),
|
|
450
|
-
)
|
|
451
|
-
log(ERROR, str(e))
|
|
364
|
+
# Send the message with its ObjectTree to SuperLink
|
|
365
|
+
req = PushMessagesRequest(
|
|
366
|
+
node=node,
|
|
367
|
+
messages_list=[message_to_proto(message)],
|
|
368
|
+
message_object_trees=[object_tree],
|
|
369
|
+
)
|
|
370
|
+
res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
|
|
371
|
+
if res is None:
|
|
372
|
+
raise ValueError("PushMessagesResponse is None.")
|
|
452
373
|
|
|
453
|
-
#
|
|
454
|
-
|
|
374
|
+
# Get and return the object IDs to push
|
|
375
|
+
return set(res.objects_to_push)
|
|
455
376
|
|
|
456
377
|
def get_run(run_id: int) -> Run:
|
|
457
378
|
# Construct the request
|
|
@@ -478,9 +399,58 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
478
399
|
res.fab.content,
|
|
479
400
|
)
|
|
480
401
|
|
|
402
|
+
def pull_object(run_id: int, object_id: str) -> bytes:
|
|
403
|
+
"""Pull the object from the SuperLink."""
|
|
404
|
+
# Check Node
|
|
405
|
+
if node is None:
|
|
406
|
+
raise RuntimeError("Node instance missing")
|
|
407
|
+
|
|
408
|
+
fn = make_pull_object_fn_protobuf(
|
|
409
|
+
pull_object_protobuf=_pull_object_protobuf,
|
|
410
|
+
node=node,
|
|
411
|
+
run_id=run_id,
|
|
412
|
+
)
|
|
413
|
+
return fn(object_id)
|
|
414
|
+
|
|
415
|
+
def push_object(run_id: int, object_id: str, contents: bytes) -> None:
|
|
416
|
+
"""Push the object to the SuperLink."""
|
|
417
|
+
# Check Node
|
|
418
|
+
if node is None:
|
|
419
|
+
raise RuntimeError("Node instance missing")
|
|
420
|
+
|
|
421
|
+
fn = make_push_object_fn_protobuf(
|
|
422
|
+
push_object_protobuf=_push_object_protobuf,
|
|
423
|
+
node=node,
|
|
424
|
+
run_id=run_id,
|
|
425
|
+
)
|
|
426
|
+
fn(object_id, contents)
|
|
427
|
+
|
|
428
|
+
def confirm_message_received(run_id: int, object_id: str) -> None:
|
|
429
|
+
"""Confirm that the message has been received."""
|
|
430
|
+
# Check Node
|
|
431
|
+
if node is None:
|
|
432
|
+
raise RuntimeError("Node instance missing")
|
|
433
|
+
|
|
434
|
+
fn = make_confirm_message_received_fn_protobuf(
|
|
435
|
+
confirm_message_received_protobuf=_confirm_message_received_protobuf,
|
|
436
|
+
node=node,
|
|
437
|
+
run_id=run_id,
|
|
438
|
+
)
|
|
439
|
+
fn(object_id)
|
|
440
|
+
|
|
481
441
|
try:
|
|
482
442
|
# Yield methods
|
|
483
|
-
yield (
|
|
443
|
+
yield (
|
|
444
|
+
receive,
|
|
445
|
+
send,
|
|
446
|
+
create_node,
|
|
447
|
+
delete_node,
|
|
448
|
+
get_run,
|
|
449
|
+
get_fab,
|
|
450
|
+
pull_object,
|
|
451
|
+
push_object,
|
|
452
|
+
confirm_message_received,
|
|
453
|
+
)
|
|
484
454
|
except Exception as exc: # pylint: disable=broad-except
|
|
485
455
|
log(ERROR, exc)
|
|
486
456
|
# Cleanup
|
flwr/clientapp/__init__.py
CHANGED
|
@@ -13,3 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Public Flower ClientApp APIs."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from flwr.client.client_app import ClientApp
|
|
19
|
+
|
|
20
|
+
from . import mod
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"ClientApp",
|
|
24
|
+
"mod",
|
|
25
|
+
]
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower Built-in Mods."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from flwr.client.mod.comms_mods import arrays_size_mod, message_size_mod
|
|
19
|
+
|
|
20
|
+
from .centraldp_mods import fixedclipping_mod
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"arrays_size_mod",
|
|
24
|
+
"fixedclipping_mod",
|
|
25
|
+
"message_size_mod",
|
|
26
|
+
]
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Clipping modifiers for central DP with client-side clipping."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
from logging import INFO, WARN
|
|
20
|
+
from typing import cast
|
|
21
|
+
|
|
22
|
+
from flwr.client.typing import ClientAppCallable
|
|
23
|
+
from flwr.common import Array, ArrayRecord, Context, Message, MessageType, log
|
|
24
|
+
from flwr.common.differential_privacy import compute_clip_model_update
|
|
25
|
+
from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# pylint: disable=too-many-return-statements
|
|
29
|
+
def fixedclipping_mod(
|
|
30
|
+
msg: Message, ctxt: Context, call_next: ClientAppCallable
|
|
31
|
+
) -> Message:
|
|
32
|
+
"""Client-side fixed clipping modifier.
|
|
33
|
+
|
|
34
|
+
This mod needs to be used with the `DifferentialPrivacyClientSideFixedClipping`
|
|
35
|
+
server-side strategy wrapper.
|
|
36
|
+
|
|
37
|
+
The wrapper sends the clipping_norm value to the client.
|
|
38
|
+
|
|
39
|
+
This mod clips the client model updates before sending them to the server.
|
|
40
|
+
|
|
41
|
+
It operates on messages of type `MessageType.TRAIN`.
|
|
42
|
+
|
|
43
|
+
Notes
|
|
44
|
+
-----
|
|
45
|
+
Consider the order of mods when using multiple.
|
|
46
|
+
|
|
47
|
+
Typically, fixedclipping_mod should be the last to operate on params.
|
|
48
|
+
"""
|
|
49
|
+
if msg.metadata.message_type != MessageType.TRAIN:
|
|
50
|
+
return call_next(msg, ctxt)
|
|
51
|
+
|
|
52
|
+
if len(msg.content.array_records) != 1:
|
|
53
|
+
log(
|
|
54
|
+
WARN,
|
|
55
|
+
"fixedclipping_mod is designed to work with a single ArrayRecord. "
|
|
56
|
+
"Skipping.",
|
|
57
|
+
)
|
|
58
|
+
return call_next(msg, ctxt)
|
|
59
|
+
|
|
60
|
+
if len(msg.content.config_records) != 1:
|
|
61
|
+
log(
|
|
62
|
+
WARN,
|
|
63
|
+
"fixedclipping_mod is designed to work with a single ConfigRecord. "
|
|
64
|
+
"Skipping.",
|
|
65
|
+
)
|
|
66
|
+
return call_next(msg, ctxt)
|
|
67
|
+
|
|
68
|
+
# Get keys in the single ConfigRecord
|
|
69
|
+
keys_in_config = set(next(iter(msg.content.config_records.values())).keys())
|
|
70
|
+
if KEY_CLIPPING_NORM not in keys_in_config:
|
|
71
|
+
raise KeyError(
|
|
72
|
+
f"The {KEY_CLIPPING_NORM} value is not supplied by the "
|
|
73
|
+
f"`DifferentialPrivacyClientSideFixedClipping` wrapper at"
|
|
74
|
+
f" the server side."
|
|
75
|
+
)
|
|
76
|
+
# Record array record communicated to client and clipping norm
|
|
77
|
+
original_array_record = next(iter(msg.content.array_records.values()))
|
|
78
|
+
clipping_norm = cast(
|
|
79
|
+
float, next(iter(msg.content.config_records.values()))[KEY_CLIPPING_NORM]
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
# Call inner app
|
|
83
|
+
out_msg = call_next(msg, ctxt)
|
|
84
|
+
|
|
85
|
+
# Check if the msg has error
|
|
86
|
+
if out_msg.has_error():
|
|
87
|
+
return out_msg
|
|
88
|
+
|
|
89
|
+
# Ensure there is a single ArrayRecord
|
|
90
|
+
if len(out_msg.content.array_records) != 1:
|
|
91
|
+
log(
|
|
92
|
+
WARN,
|
|
93
|
+
"fixedclipping_mod is designed to work with a single ArrayRecord. "
|
|
94
|
+
"Skipping.",
|
|
95
|
+
)
|
|
96
|
+
return out_msg
|
|
97
|
+
|
|
98
|
+
new_array_record_key, client_to_server_arrecord = next(
|
|
99
|
+
iter(out_msg.content.array_records.items())
|
|
100
|
+
)
|
|
101
|
+
# Ensure keys in returned ArrayRecord match those in the one sent from server
|
|
102
|
+
if set(original_array_record.keys()) != set(client_to_server_arrecord.keys()):
|
|
103
|
+
log(
|
|
104
|
+
WARN,
|
|
105
|
+
"fixedclipping_mod: Keys in ArrayRecord must match those from the model "
|
|
106
|
+
"that the ClientApp received. Skipping.",
|
|
107
|
+
)
|
|
108
|
+
return out_msg
|
|
109
|
+
|
|
110
|
+
client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays()
|
|
111
|
+
# Clip the client update
|
|
112
|
+
compute_clip_model_update(
|
|
113
|
+
param1=client_to_server_ndarrays,
|
|
114
|
+
param2=original_array_record.to_numpy_ndarrays(),
|
|
115
|
+
clipping_norm=clipping_norm,
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
log(
|
|
119
|
+
INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm
|
|
120
|
+
)
|
|
121
|
+
# Replace outgoing ArrayRecord's Array while preserving their keys
|
|
122
|
+
out_msg.content.array_records[new_array_record_key] = ArrayRecord(
|
|
123
|
+
OrderedDict(
|
|
124
|
+
{
|
|
125
|
+
k: Array(v)
|
|
126
|
+
for k, v in zip(
|
|
127
|
+
client_to_server_arrecord.keys(), client_to_server_ndarrays
|
|
128
|
+
)
|
|
129
|
+
}
|
|
130
|
+
)
|
|
131
|
+
)
|
|
132
|
+
return out_msg
|
flwr/common/args.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
19
|
import sys
|
|
20
|
-
from logging import DEBUG, ERROR, WARN
|
|
20
|
+
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
21
|
from os.path import isfile
|
|
22
22
|
from pathlib import Path
|
|
23
23
|
from typing import Optional, Union
|
|
@@ -28,6 +28,12 @@ from flwr.common.logger import log
|
|
|
28
28
|
|
|
29
29
|
def add_args_flwr_app_common(parser: argparse.ArgumentParser) -> None:
|
|
30
30
|
"""Add common Flower arguments for flwr-*app to the provided parser."""
|
|
31
|
+
parser.add_argument(
|
|
32
|
+
"--token",
|
|
33
|
+
type=str,
|
|
34
|
+
required=False,
|
|
35
|
+
help="Unique token generated by AppIo API for each app execution",
|
|
36
|
+
)
|
|
31
37
|
parser.add_argument(
|
|
32
38
|
"--flwr-dir",
|
|
33
39
|
default=None,
|
|
@@ -47,6 +53,18 @@ def add_args_flwr_app_common(parser: argparse.ArgumentParser) -> None:
|
|
|
47
53
|
"is not encrypted. By default, the server runs with HTTPS enabled. "
|
|
48
54
|
"Use this flag only if you understand the risks.",
|
|
49
55
|
)
|
|
56
|
+
parser.add_argument(
|
|
57
|
+
"--parent-pid",
|
|
58
|
+
type=int,
|
|
59
|
+
default=None,
|
|
60
|
+
help="The PID of the parent process. When set, the process will terminate "
|
|
61
|
+
"when the parent process exits.",
|
|
62
|
+
)
|
|
63
|
+
parser.add_argument(
|
|
64
|
+
"--run-once",
|
|
65
|
+
action="store_true",
|
|
66
|
+
help="This flag is deprecated and will be removed in a future release.",
|
|
67
|
+
)
|
|
50
68
|
|
|
51
69
|
|
|
52
70
|
def try_obtain_root_certificates(
|
|
@@ -72,11 +90,7 @@ def try_obtain_root_certificates(
|
|
|
72
90
|
else:
|
|
73
91
|
# Load the certificates if provided, or load the system certificates
|
|
74
92
|
if root_cert_path is None:
|
|
75
|
-
log(
|
|
76
|
-
WARN,
|
|
77
|
-
"Both `--insecure` and `--root-certificates` were not set. "
|
|
78
|
-
"Using system certificates.",
|
|
79
|
-
)
|
|
93
|
+
log(INFO, "Using system certificates")
|
|
80
94
|
root_certificates = None
|
|
81
95
|
elif not isfile(root_cert_path):
|
|
82
96
|
log(ERROR, "Path argument `--root-certificates` does not point to a file.")
|
|
@@ -16,11 +16,11 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .auth_plugin import CliAuthPlugin as CliAuthPlugin
|
|
19
|
-
from .auth_plugin import
|
|
20
|
-
from .auth_plugin import
|
|
19
|
+
from .auth_plugin import ControlAuthPlugin as ControlAuthPlugin
|
|
20
|
+
from .auth_plugin import ControlAuthzPlugin as ControlAuthzPlugin
|
|
21
21
|
|
|
22
22
|
__all__ = [
|
|
23
23
|
"CliAuthPlugin",
|
|
24
|
-
"
|
|
25
|
-
"
|
|
24
|
+
"ControlAuthPlugin",
|
|
25
|
+
"ControlAuthzPlugin",
|
|
26
26
|
]
|
|
@@ -21,13 +21,13 @@ from pathlib import Path
|
|
|
21
21
|
from typing import Optional, Union
|
|
22
22
|
|
|
23
23
|
from flwr.common.typing import AccountInfo
|
|
24
|
-
from flwr.proto.
|
|
24
|
+
from flwr.proto.control_pb2_grpc import ControlStub
|
|
25
25
|
|
|
26
26
|
from ..typing import UserAuthCredentials, UserAuthLoginDetails
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
class
|
|
30
|
-
"""Abstract Flower Auth Plugin class for
|
|
29
|
+
class ControlAuthPlugin(ABC):
|
|
30
|
+
"""Abstract Flower Auth Plugin class for ControlServicer.
|
|
31
31
|
|
|
32
32
|
Parameters
|
|
33
33
|
----------
|
|
@@ -69,8 +69,8 @@ class ExecAuthPlugin(ABC):
|
|
|
69
69
|
"""Refresh authentication tokens in the provided metadata."""
|
|
70
70
|
|
|
71
71
|
|
|
72
|
-
class
|
|
73
|
-
"""Abstract Flower Authorization Plugin class for
|
|
72
|
+
class ControlAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
|
|
73
|
+
"""Abstract Flower Authorization Plugin class for ControlServicer.
|
|
74
74
|
|
|
75
75
|
Parameters
|
|
76
76
|
----------
|
|
@@ -103,7 +103,7 @@ class CliAuthPlugin(ABC):
|
|
|
103
103
|
@abstractmethod
|
|
104
104
|
def login(
|
|
105
105
|
login_details: UserAuthLoginDetails,
|
|
106
|
-
|
|
106
|
+
control_stub: ControlStub,
|
|
107
107
|
) -> UserAuthCredentials:
|
|
108
108
|
"""Authenticate the user and retrieve authentication credentials.
|
|
109
109
|
|
|
@@ -111,7 +111,7 @@ class CliAuthPlugin(ABC):
|
|
|
111
111
|
----------
|
|
112
112
|
login_details : UserAuthLoginDetails
|
|
113
113
|
An object containing the user's login details.
|
|
114
|
-
|
|
114
|
+
control_stub : ControlStub
|
|
115
115
|
A stub for executing RPC calls to the server.
|
|
116
116
|
|
|
117
117
|
Returns
|