flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +17 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +196 -42
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +109 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/clientapp/mod/__init__.py +4 -1
- flwr/clientapp/mod/centraldp_mods.py +156 -40
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +56 -13
- flwr/common/exit/exit_code.py +24 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -31
- flwr/proto/control_pb2.pyi +95 -5
- flwr/proto/control_pb2_grpc.py +136 -0
- flwr/proto/control_pb2_grpc.pyi +52 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +152 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/serverapp/strategy/__init__.py +26 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
- flwr/serverapp/strategy/fedadagrad.py +0 -3
- flwr/serverapp/strategy/fedadam.py +0 -3
- flwr/serverapp/strategy/fedavg.py +89 -64
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +0 -3
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +28 -32
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +41 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +16 -11
- flwr/superlink/servicer/control/control_servicer.py +207 -58
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/serverapp/dp_fixed_clipping.py +0 -352
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
- /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
|
@@ -19,7 +19,7 @@ from collections.abc import Iterator, Sequence
|
|
|
19
19
|
from contextlib import contextmanager
|
|
20
20
|
from logging import ERROR
|
|
21
21
|
from pathlib import Path
|
|
22
|
-
from typing import Callable, Optional, Union
|
|
22
|
+
from typing import Callable, Optional, Union
|
|
23
23
|
|
|
24
24
|
import grpc
|
|
25
25
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -36,19 +36,24 @@ from flwr.common.inflatable_protobuf_utils import (
|
|
|
36
36
|
from flwr.common.logger import log
|
|
37
37
|
from flwr.common.message import Message, remove_content_from_message
|
|
38
38
|
from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
|
|
39
|
-
from flwr.common.
|
|
40
|
-
|
|
39
|
+
from flwr.common.serde import (
|
|
40
|
+
fab_from_proto,
|
|
41
|
+
message_from_proto,
|
|
42
|
+
message_to_proto,
|
|
43
|
+
run_from_proto,
|
|
41
44
|
)
|
|
42
|
-
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
43
45
|
from flwr.common.typing import Fab, Run
|
|
44
46
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
45
47
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
46
|
-
|
|
47
|
-
|
|
48
|
+
ActivateNodeRequest,
|
|
49
|
+
ActivateNodeResponse,
|
|
50
|
+
DeactivateNodeRequest,
|
|
48
51
|
PullMessagesRequest,
|
|
49
52
|
PullMessagesResponse,
|
|
50
53
|
PushMessagesRequest,
|
|
51
54
|
PushMessagesResponse,
|
|
55
|
+
RegisterNodeFleetRequest,
|
|
56
|
+
UnregisterNodeFleetRequest,
|
|
52
57
|
)
|
|
53
58
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
54
59
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
@@ -58,9 +63,10 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
|
58
63
|
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
|
59
64
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
60
65
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
66
|
+
from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
|
|
61
67
|
|
|
62
|
-
from .client_interceptor import AuthenticateClientInterceptor
|
|
63
68
|
from .grpc_adapter import GrpcAdapter
|
|
69
|
+
from .node_auth_client_interceptor import NodeAuthClientInterceptor
|
|
64
70
|
|
|
65
71
|
|
|
66
72
|
@contextmanager
|
|
@@ -76,10 +82,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
76
82
|
adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None,
|
|
77
83
|
) -> Iterator[
|
|
78
84
|
tuple[
|
|
85
|
+
int,
|
|
79
86
|
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
80
87
|
Callable[[Message, ObjectTree], set[str]],
|
|
81
|
-
Callable[[], Optional[int]],
|
|
82
|
-
Callable[[], None],
|
|
83
88
|
Callable[[int], Run],
|
|
84
89
|
Callable[[str, int], Fab],
|
|
85
90
|
Callable[[int, str], bytes],
|
|
@@ -122,11 +127,11 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
122
127
|
|
|
123
128
|
Returns
|
|
124
129
|
-------
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
+
node_id : int
|
|
131
|
+
receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
|
|
132
|
+
send : Callable[[Message, ObjectTree], set[str]]
|
|
133
|
+
get_run : Callable[[int], Run]
|
|
134
|
+
get_fab : Callable[[str, int], Fab]
|
|
130
135
|
pull_object : Callable[[str], bytes]
|
|
131
136
|
push_object : Callable[[str, bytes], None]
|
|
132
137
|
confirm_message_received : Callable[[str], None]
|
|
@@ -135,13 +140,16 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
135
140
|
root_certificates = Path(root_certificates).read_bytes()
|
|
136
141
|
|
|
137
142
|
# Automatic node auth: generate keys if user didn't provide any
|
|
143
|
+
self_registered = False
|
|
138
144
|
if authentication_keys is None:
|
|
145
|
+
self_registered = True
|
|
139
146
|
authentication_keys = generate_key_pairs()
|
|
140
147
|
|
|
141
148
|
# Always configure auth interceptor, with either user-provided or generated keys
|
|
142
149
|
interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
|
|
143
|
-
|
|
150
|
+
NodeAuthClientInterceptor(*authentication_keys),
|
|
144
151
|
]
|
|
152
|
+
node_pk = public_key_to_bytes(authentication_keys[1])
|
|
145
153
|
channel = create_channel(
|
|
146
154
|
server_address=server_address,
|
|
147
155
|
insecure=insecure,
|
|
@@ -160,7 +168,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
160
168
|
# Wrap stub
|
|
161
169
|
_wrap_stub(stub, retry_invoker)
|
|
162
170
|
###########################################################################
|
|
163
|
-
#
|
|
171
|
+
# SuperNode functions
|
|
164
172
|
###########################################################################
|
|
165
173
|
|
|
166
174
|
def send_node_heartbeat() -> bool:
|
|
@@ -197,22 +205,26 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
197
205
|
|
|
198
206
|
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
199
207
|
|
|
200
|
-
def
|
|
201
|
-
"""
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
208
|
+
def register_node() -> None:
|
|
209
|
+
"""Register node with SuperLink."""
|
|
210
|
+
stub.RegisterNode(RegisterNodeFleetRequest(public_key=node_pk))
|
|
211
|
+
|
|
212
|
+
def activate_node() -> int:
|
|
213
|
+
"""Activate node and start heartbeat."""
|
|
214
|
+
req = ActivateNodeRequest(
|
|
215
|
+
public_key=node_pk,
|
|
216
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
|
205
217
|
)
|
|
206
|
-
|
|
218
|
+
res: ActivateNodeResponse = stub.ActivateNode(req)
|
|
207
219
|
|
|
208
220
|
# Remember the node and start the heartbeat sender
|
|
209
221
|
nonlocal node
|
|
210
|
-
node =
|
|
222
|
+
node = Node(node_id=res.node_id)
|
|
211
223
|
heartbeat_sender.start()
|
|
212
224
|
return node.node_id
|
|
213
225
|
|
|
214
|
-
def
|
|
215
|
-
"""
|
|
226
|
+
def deactivate_node() -> None:
|
|
227
|
+
"""Deactivate node and stop heartbeat."""
|
|
216
228
|
# Get Node
|
|
217
229
|
nonlocal node
|
|
218
230
|
if node is None:
|
|
@@ -223,8 +235,20 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
223
235
|
heartbeat_sender.stop()
|
|
224
236
|
|
|
225
237
|
# Call FleetAPI
|
|
226
|
-
|
|
227
|
-
stub.
|
|
238
|
+
req = DeactivateNodeRequest(node_id=node.node_id)
|
|
239
|
+
stub.DeactivateNode(req)
|
|
240
|
+
|
|
241
|
+
def unregister_node() -> None:
|
|
242
|
+
"""Unregister node from SuperLink."""
|
|
243
|
+
# Get Node
|
|
244
|
+
nonlocal node
|
|
245
|
+
if node is None:
|
|
246
|
+
log(ERROR, "Node instance missing")
|
|
247
|
+
return
|
|
248
|
+
|
|
249
|
+
# Call FleetAPI
|
|
250
|
+
req = UnregisterNodeFleetRequest(node_id=node.node_id)
|
|
251
|
+
stub.UnregisterNode(req)
|
|
228
252
|
|
|
229
253
|
# Cleanup
|
|
230
254
|
node = None
|
|
@@ -289,7 +313,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
289
313
|
get_fab_request = GetFabRequest(node=node, hash_str=fab_hash, run_id=run_id)
|
|
290
314
|
get_fab_response: GetFabResponse = stub.GetFab(request=get_fab_request)
|
|
291
315
|
|
|
292
|
-
return
|
|
316
|
+
return fab_from_proto(get_fab_response.fab)
|
|
293
317
|
|
|
294
318
|
def pull_object(run_id: int, object_id: str) -> bytes:
|
|
295
319
|
"""Pull the object from the SuperLink."""
|
|
@@ -331,12 +355,14 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
331
355
|
fn(object_id)
|
|
332
356
|
|
|
333
357
|
try:
|
|
358
|
+
if self_registered:
|
|
359
|
+
register_node()
|
|
360
|
+
node_id = activate_node()
|
|
334
361
|
# Yield methods
|
|
335
362
|
yield (
|
|
363
|
+
node_id,
|
|
336
364
|
receive,
|
|
337
365
|
send,
|
|
338
|
-
create_node,
|
|
339
|
-
delete_node,
|
|
340
366
|
get_run,
|
|
341
367
|
get_fab,
|
|
342
368
|
pull_object,
|
|
@@ -351,7 +377,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
351
377
|
if node is not None:
|
|
352
378
|
# Disable retrying
|
|
353
379
|
retry_invoker.max_tries = 1
|
|
354
|
-
|
|
380
|
+
deactivate_node()
|
|
381
|
+
if self_registered:
|
|
382
|
+
unregister_node()
|
|
355
383
|
except grpc.RpcError:
|
|
356
384
|
pass
|
|
357
385
|
channel.close()
|
|
@@ -34,14 +34,18 @@ from flwr.common.constant import (
|
|
|
34
34
|
from flwr.common.version import package_name, package_version
|
|
35
35
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
36
36
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
ActivateNodeRequest,
|
|
38
|
+
ActivateNodeResponse,
|
|
39
|
+
DeactivateNodeRequest,
|
|
40
|
+
DeactivateNodeResponse,
|
|
41
41
|
PullMessagesRequest,
|
|
42
42
|
PullMessagesResponse,
|
|
43
43
|
PushMessagesRequest,
|
|
44
44
|
PushMessagesResponse,
|
|
45
|
+
RegisterNodeFleetRequest,
|
|
46
|
+
RegisterNodeFleetResponse,
|
|
47
|
+
UnregisterNodeFleetRequest,
|
|
48
|
+
UnregisterNodeFleetResponse,
|
|
45
49
|
)
|
|
46
50
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
47
51
|
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
|
|
@@ -118,17 +122,29 @@ class GrpcAdapter:
|
|
|
118
122
|
response.ParseFromString(container_res.grpc_message_content)
|
|
119
123
|
return response
|
|
120
124
|
|
|
121
|
-
def
|
|
122
|
-
self, request:
|
|
123
|
-
) ->
|
|
125
|
+
def RegisterNode( # pylint: disable=C0103
|
|
126
|
+
self, request: RegisterNodeFleetRequest, **kwargs: Any
|
|
127
|
+
) -> RegisterNodeFleetResponse:
|
|
124
128
|
"""."""
|
|
125
|
-
return self._send_and_receive(request,
|
|
129
|
+
return self._send_and_receive(request, RegisterNodeFleetResponse, **kwargs)
|
|
126
130
|
|
|
127
|
-
def
|
|
128
|
-
self, request:
|
|
129
|
-
) ->
|
|
131
|
+
def ActivateNode( # pylint: disable=C0103
|
|
132
|
+
self, request: ActivateNodeRequest, **kwargs: Any
|
|
133
|
+
) -> ActivateNodeResponse:
|
|
130
134
|
"""."""
|
|
131
|
-
return self._send_and_receive(request,
|
|
135
|
+
return self._send_and_receive(request, ActivateNodeResponse, **kwargs)
|
|
136
|
+
|
|
137
|
+
def DeactivateNode( # pylint: disable=C0103
|
|
138
|
+
self, request: DeactivateNodeRequest, **kwargs: Any
|
|
139
|
+
) -> DeactivateNodeResponse:
|
|
140
|
+
"""."""
|
|
141
|
+
return self._send_and_receive(request, DeactivateNodeResponse, **kwargs)
|
|
142
|
+
|
|
143
|
+
def UnregisterNode( # pylint: disable=C0103
|
|
144
|
+
self, request: UnregisterNodeFleetRequest, **kwargs: Any
|
|
145
|
+
) -> UnregisterNodeFleetResponse:
|
|
146
|
+
"""."""
|
|
147
|
+
return self._send_and_receive(request, UnregisterNodeFleetResponse, **kwargs)
|
|
132
148
|
|
|
133
149
|
def SendNodeHeartbeat( # pylint: disable=C0103
|
|
134
150
|
self, request: SendNodeHeartbeatRequest, **kwargs: Any
|
|
@@ -23,14 +23,11 @@ from google.protobuf.message import Message as GrpcMessage
|
|
|
23
23
|
|
|
24
24
|
from flwr.common import now
|
|
25
25
|
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
|
|
26
|
-
from flwr.
|
|
27
|
-
public_key_to_bytes,
|
|
28
|
-
sign_message,
|
|
29
|
-
)
|
|
26
|
+
from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
|
|
30
27
|
|
|
31
28
|
|
|
32
|
-
class
|
|
33
|
-
"""Client interceptor for
|
|
29
|
+
class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
|
30
|
+
"""Client interceptor for node authentication."""
|
|
34
31
|
|
|
35
32
|
def __init__(
|
|
36
33
|
self,
|
|
@@ -35,14 +35,9 @@ from flwr.common.constant import MessageType
|
|
|
35
35
|
from flwr.common.logger import log
|
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
|
37
37
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
38
|
-
bytes_to_private_key,
|
|
39
|
-
bytes_to_public_key,
|
|
40
38
|
decrypt,
|
|
41
39
|
encrypt,
|
|
42
|
-
generate_key_pairs,
|
|
43
40
|
generate_shared_key,
|
|
44
|
-
private_key_to_bytes,
|
|
45
|
-
public_key_to_bytes,
|
|
46
41
|
)
|
|
47
42
|
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
|
48
43
|
factor_combine,
|
|
@@ -64,6 +59,13 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
|
|
|
64
59
|
share_keys_plaintext_separate,
|
|
65
60
|
)
|
|
66
61
|
from flwr.common.typing import ConfigRecordValues
|
|
62
|
+
from flwr.supercore.primitives.asymmetric import (
|
|
63
|
+
bytes_to_private_key,
|
|
64
|
+
bytes_to_public_key,
|
|
65
|
+
generate_key_pairs,
|
|
66
|
+
private_key_to_bytes,
|
|
67
|
+
public_key_to_bytes,
|
|
68
|
+
)
|
|
67
69
|
|
|
68
70
|
|
|
69
71
|
@dataclass
|
|
@@ -36,18 +36,27 @@ from flwr.common.inflatable_protobuf_utils import (
|
|
|
36
36
|
from flwr.common.logger import log
|
|
37
37
|
from flwr.common.message import Message, remove_content_from_message
|
|
38
38
|
from flwr.common.retry_invoker import RetryInvoker
|
|
39
|
-
from flwr.common.serde import
|
|
39
|
+
from flwr.common.serde import (
|
|
40
|
+
fab_from_proto,
|
|
41
|
+
message_from_proto,
|
|
42
|
+
message_to_proto,
|
|
43
|
+
run_from_proto,
|
|
44
|
+
)
|
|
40
45
|
from flwr.common.typing import Fab, Run
|
|
41
46
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
42
47
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
48
|
+
ActivateNodeRequest,
|
|
49
|
+
ActivateNodeResponse,
|
|
50
|
+
DeactivateNodeRequest,
|
|
51
|
+
DeactivateNodeResponse,
|
|
47
52
|
PullMessagesRequest,
|
|
48
53
|
PullMessagesResponse,
|
|
49
54
|
PushMessagesRequest,
|
|
50
55
|
PushMessagesResponse,
|
|
56
|
+
RegisterNodeFleetRequest,
|
|
57
|
+
RegisterNodeFleetResponse,
|
|
58
|
+
UnregisterNodeFleetRequest,
|
|
59
|
+
UnregisterNodeFleetResponse,
|
|
51
60
|
)
|
|
52
61
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
53
62
|
SendNodeHeartbeatRequest,
|
|
@@ -64,6 +73,7 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
|
64
73
|
)
|
|
65
74
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
66
75
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
76
|
+
from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
|
|
67
77
|
|
|
68
78
|
try:
|
|
69
79
|
import requests
|
|
@@ -71,8 +81,10 @@ except ModuleNotFoundError:
|
|
|
71
81
|
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
|
72
82
|
|
|
73
83
|
|
|
74
|
-
|
|
75
|
-
|
|
84
|
+
PATH_REGISTER_NODE: str = "/api/v0/fleet/register-node"
|
|
85
|
+
PATH_ACTIVATE_NODE: str = "/api/v0/fleet/activate-node"
|
|
86
|
+
PATH_DEACTIVATE_NODE: str = "/api/v0/fleet/deactivate-node"
|
|
87
|
+
PATH_UNREGISTER_NODE: str = "/api/v0/fleet/unregister-node"
|
|
76
88
|
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
|
77
89
|
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
|
78
90
|
PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
|
|
@@ -99,10 +111,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
99
111
|
] = None,
|
|
100
112
|
) -> Iterator[
|
|
101
113
|
tuple[
|
|
114
|
+
int,
|
|
102
115
|
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
103
116
|
Callable[[Message, ObjectTree], set[str]],
|
|
104
|
-
Callable[[], Optional[int]],
|
|
105
|
-
Callable[[], None],
|
|
106
117
|
Callable[[int], Run],
|
|
107
118
|
Callable[[str, int], Fab],
|
|
108
119
|
Callable[[int, str], bytes],
|
|
@@ -134,15 +145,15 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
134
145
|
connection using the certificates will be established to an SSL-enabled
|
|
135
146
|
Flower server. Bytes won't work for the REST API.
|
|
136
147
|
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
137
|
-
|
|
148
|
+
SuperNode authentication is not supported for this transport type.
|
|
138
149
|
|
|
139
150
|
Returns
|
|
140
151
|
-------
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
152
|
+
node_id : int
|
|
153
|
+
receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
|
|
154
|
+
send : Callable[[Message, ObjectTree], set[str]]
|
|
155
|
+
get_run : Callable[[int], Run]
|
|
156
|
+
get_fab : Callable[[str, int], Fab]
|
|
146
157
|
pull_object : Callable[[str], bytes]
|
|
147
158
|
push_object : Callable[[str, bytes], None]
|
|
148
159
|
confirm_message_received : Callable[[str], None]
|
|
@@ -171,7 +182,14 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
171
182
|
"must be provided as a string path to the client.",
|
|
172
183
|
)
|
|
173
184
|
if authentication_keys is not None:
|
|
174
|
-
log(ERROR, "
|
|
185
|
+
log(ERROR, "SuperNode authentication is not supported for this transport type.")
|
|
186
|
+
|
|
187
|
+
# REST does NOT support node authentication
|
|
188
|
+
self_registered = False
|
|
189
|
+
if authentication_keys is None:
|
|
190
|
+
self_registered = True
|
|
191
|
+
authentication_keys = generate_key_pairs()
|
|
192
|
+
node_pk = public_key_to_bytes(authentication_keys[1])
|
|
175
193
|
|
|
176
194
|
# Shared variables for inner functions
|
|
177
195
|
node: Optional[Node] = None
|
|
@@ -180,7 +198,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
180
198
|
retry_invoker.should_giveup = None
|
|
181
199
|
|
|
182
200
|
###########################################################################
|
|
183
|
-
#
|
|
201
|
+
# SuperNode functions
|
|
184
202
|
###########################################################################
|
|
185
203
|
|
|
186
204
|
def _request(
|
|
@@ -290,23 +308,35 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
290
308
|
|
|
291
309
|
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
292
310
|
|
|
293
|
-
def
|
|
294
|
-
"""
|
|
295
|
-
req =
|
|
311
|
+
def register_node() -> None:
|
|
312
|
+
"""Register node with SuperLink."""
|
|
313
|
+
req = RegisterNodeFleetRequest(public_key=node_pk)
|
|
296
314
|
|
|
297
315
|
# Send the request
|
|
298
|
-
res = _request(req,
|
|
316
|
+
res = _request(req, RegisterNodeFleetResponse, PATH_REGISTER_NODE)
|
|
299
317
|
if res is None:
|
|
300
|
-
|
|
318
|
+
raise RuntimeError("Failed to register node")
|
|
319
|
+
|
|
320
|
+
def activate_node() -> int:
|
|
321
|
+
"""Activate node and start heartbeat."""
|
|
322
|
+
req = ActivateNodeRequest(
|
|
323
|
+
public_key=node_pk,
|
|
324
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
# Send the request
|
|
328
|
+
res = _request(req, ActivateNodeResponse, PATH_ACTIVATE_NODE)
|
|
329
|
+
if res is None:
|
|
330
|
+
raise RuntimeError("Failed to activate node")
|
|
301
331
|
|
|
302
332
|
# Remember the node and start the heartbeat sender
|
|
303
333
|
nonlocal node
|
|
304
|
-
node = res.
|
|
334
|
+
node = Node(node_id=res.node_id)
|
|
305
335
|
heartbeat_sender.start()
|
|
306
336
|
return node.node_id
|
|
307
337
|
|
|
308
|
-
def
|
|
309
|
-
"""
|
|
338
|
+
def deactivate_node() -> None:
|
|
339
|
+
"""Deactivate node and stop heartbeat."""
|
|
310
340
|
nonlocal node
|
|
311
341
|
if node is None:
|
|
312
342
|
raise RuntimeError("Node instance missing")
|
|
@@ -314,13 +344,27 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
314
344
|
# Stop the heartbeat sender
|
|
315
345
|
heartbeat_sender.stop()
|
|
316
346
|
|
|
317
|
-
# Send
|
|
318
|
-
req =
|
|
347
|
+
# Send DeactivateNode request
|
|
348
|
+
req = DeactivateNodeRequest(node_id=node.node_id)
|
|
319
349
|
|
|
320
350
|
# Send the request
|
|
321
|
-
res = _request(req,
|
|
351
|
+
res = _request(req, DeactivateNodeResponse, PATH_DEACTIVATE_NODE)
|
|
322
352
|
if res is None:
|
|
323
|
-
|
|
353
|
+
raise RuntimeError("Failed to deactivate node")
|
|
354
|
+
|
|
355
|
+
def unregister_node() -> None:
|
|
356
|
+
"""Unregister node from SuperLink."""
|
|
357
|
+
nonlocal node
|
|
358
|
+
if node is None:
|
|
359
|
+
raise RuntimeError("Node instance missing")
|
|
360
|
+
|
|
361
|
+
# Send UnregisterNode request
|
|
362
|
+
req = UnregisterNodeFleetRequest(node_id=node.node_id)
|
|
363
|
+
|
|
364
|
+
# Send the request
|
|
365
|
+
res = _request(req, UnregisterNodeFleetResponse, PATH_UNREGISTER_NODE)
|
|
366
|
+
if res is None:
|
|
367
|
+
raise RuntimeError("Failed to unregister node")
|
|
324
368
|
|
|
325
369
|
# Cleanup
|
|
326
370
|
node = None
|
|
@@ -392,12 +436,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
392
436
|
# Send the request
|
|
393
437
|
res = _request(req, GetFabResponse, PATH_GET_FAB)
|
|
394
438
|
if res is None:
|
|
395
|
-
return Fab("", b"")
|
|
439
|
+
return Fab("", b"", {})
|
|
396
440
|
|
|
397
|
-
return
|
|
398
|
-
res.fab.hash_str,
|
|
399
|
-
res.fab.content,
|
|
400
|
-
)
|
|
441
|
+
return fab_from_proto(res.fab)
|
|
401
442
|
|
|
402
443
|
def pull_object(run_id: int, object_id: str) -> bytes:
|
|
403
444
|
"""Pull the object from the SuperLink."""
|
|
@@ -439,12 +480,14 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
439
480
|
fn(object_id)
|
|
440
481
|
|
|
441
482
|
try:
|
|
483
|
+
if self_registered:
|
|
484
|
+
register_node()
|
|
485
|
+
node_id = activate_node()
|
|
442
486
|
# Yield methods
|
|
443
487
|
yield (
|
|
488
|
+
node_id,
|
|
444
489
|
receive,
|
|
445
490
|
send,
|
|
446
|
-
create_node,
|
|
447
|
-
delete_node,
|
|
448
491
|
get_run,
|
|
449
492
|
get_fab,
|
|
450
493
|
pull_object,
|
|
@@ -459,6 +502,8 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
459
502
|
if node is not None:
|
|
460
503
|
# Disable retrying
|
|
461
504
|
retry_invoker.max_tries = 1
|
|
462
|
-
|
|
505
|
+
deactivate_node()
|
|
506
|
+
if self_registered:
|
|
507
|
+
unregister_node()
|
|
463
508
|
except RequestsConnectionError:
|
|
464
509
|
pass
|
flwr/clientapp/__init__.py
CHANGED
flwr/clientapp/mod/__init__.py
CHANGED
|
@@ -17,9 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
from flwr.client.mod.comms_mods import arrays_size_mod, message_size_mod
|
|
19
19
|
|
|
20
|
-
from .centraldp_mods import fixedclipping_mod
|
|
20
|
+
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
|
21
|
+
from .localdp_mod import LocalDpMod
|
|
21
22
|
|
|
22
23
|
__all__ = [
|
|
24
|
+
"LocalDpMod",
|
|
25
|
+
"adaptiveclipping_mod",
|
|
23
26
|
"arrays_size_mod",
|
|
24
27
|
"fixedclipping_mod",
|
|
25
28
|
"message_size_mod",
|