flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +170 -130
- flwr/cli/new/new.py +33 -50
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +77 -30
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +25 -7
- flwr/cli/supernode/ls.py +16 -8
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +376 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +6 -7
- flwr/client/grpc_rere_client/connection.py +10 -11
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +12 -14
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +5 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +52 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +14 -4
- flwr/proto/fleet_pb2.pyi +137 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +38 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
- flwr/server/superlink/linkstate/linkstate.py +59 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +31 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +5 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +102 -18
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +158 -42
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""NoOp implementation of FederationManager."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from flwr.common.constant import NOOP_FLWR_AID
|
|
19
|
+
from flwr.common.typing import Federation
|
|
20
|
+
from flwr.supercore.constant import NOOP_FEDERATION
|
|
21
|
+
|
|
22
|
+
from .federation_manager import FederationManager
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class NoOpFederationManager(FederationManager):
    """No-Op FederationManager implementation.

    Exposes exactly one federation, ``NOOP_FEDERATION``, whose only member is
    the account ``NOOP_FLWR_AID``. Every node is treated as belonging to that
    federation.
    """

    def _check_exists(self, federation: str) -> None:
        """Raise ``ValueError`` if ``federation`` does not exist."""
        if not self.exists(federation):
            raise ValueError(f"Federation '{federation}' does not exist.")

    def exists(self, federation: str) -> bool:
        """Check if a federation exists.

        Only ``NOOP_FEDERATION`` exists in this implementation.
        """
        return federation == NOOP_FEDERATION

    def has_member(self, flwr_aid: str, federation: str) -> bool:
        """Check if the given account is a member of the federation.

        Raises
        ------
        ValueError
            If the federation does not exist.
        """
        self._check_exists(federation)
        return flwr_aid == NOOP_FLWR_AID

    def filter_nodes(self, node_ids: set[int], federation: str) -> set[int]:
        """Given a set of node IDs, return the subset that is in the federation.

        Every node belongs to the NoOp federation, so all IDs are kept. A new set
        is returned so that callers can modify the result without mutating the
        ``node_ids`` they passed in.

        Raises
        ------
        ValueError
            If the federation does not exist.
        """
        self._check_exists(federation)
        return set(node_ids)

    def has_node(self, node_id: int, federation: str) -> bool:
        """Given a node ID, check if it is in the federation.

        Always ``True`` for an existing federation: every node belongs to it.

        Raises
        ------
        ValueError
            If the federation does not exist.
        """
        self._check_exists(federation)
        return True

    def get_federations(self, flwr_aid: str) -> list[str]:
        """Get federations of which the account is a member.

        Only ``NOOP_FLWR_AID`` is a member of anything (the NoOp federation);
        any other account gets an empty list.
        """
        if flwr_aid != NOOP_FLWR_AID:
            return []
        return [NOOP_FEDERATION]

    def get_details(self, federation: str) -> Federation:
        """Get details of the federation.

        Collects the runs and nodes owned by ``NOOP_FLWR_AID`` from
        ``self.linkstate``.

        Raises
        ------
        ValueError
            If the federation does not exist.
        """
        self._check_exists(federation)

        run_ids = self.linkstate.get_run_ids(flwr_aid=NOOP_FLWR_AID)
        nodes = list(self.linkstate.get_node_info(owner_aids=[NOOP_FLWR_AID]))
        # Skip run IDs for which `get_run` returns nothing (presumably runs that
        # disappeared between the two lookups -- confirm against LinkState)
        runs = [
            run for run_id in run_ids if (run := self.linkstate.get_run(run_id=run_id))
        ]
        return Federation(
            name=NOOP_FEDERATION,
            member_aids=[NOOP_FLWR_AID],
            nodes=nodes,
            runs=runs,
        )
|
|
@@ -16,7 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import contextvars
|
|
19
|
-
from
|
|
19
|
+
from collections.abc import Callable
|
|
20
|
+
from typing import Any
|
|
20
21
|
|
|
21
22
|
import grpc
|
|
22
23
|
|
|
@@ -33,23 +34,31 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
|
33
34
|
)
|
|
34
35
|
from flwr.superlink.auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
|
|
35
36
|
|
|
36
|
-
Request =
|
|
37
|
-
StartRunRequest
|
|
38
|
-
|
|
39
|
-
GetLoginDetailsRequest,
|
|
40
|
-
GetAuthTokensRequest,
|
|
41
|
-
]
|
|
37
|
+
Request = (
|
|
38
|
+
StartRunRequest | StreamLogsRequest | GetLoginDetailsRequest | GetAuthTokensRequest
|
|
39
|
+
)
|
|
42
40
|
|
|
43
|
-
Response =
|
|
44
|
-
StartRunResponse
|
|
45
|
-
|
|
41
|
+
Response = (
|
|
42
|
+
StartRunResponse
|
|
43
|
+
| StreamLogsResponse
|
|
44
|
+
| GetLoginDetailsResponse
|
|
45
|
+
| GetAuthTokensResponse
|
|
46
|
+
)
|
|
46
47
|
|
|
47
48
|
|
|
48
|
-
shared_account_info: contextvars.ContextVar[AccountInfo] =
|
|
49
|
-
"account_info", default=
|
|
49
|
+
shared_account_info: contextvars.ContextVar[AccountInfo | None] = (
|
|
50
|
+
contextvars.ContextVar("account_info", default=None)
|
|
50
51
|
)
|
|
51
52
|
|
|
52
53
|
|
|
54
|
+
def get_current_account_info() -> AccountInfo:
    """Return the account info stored in the current request context.

    When ``shared_account_info`` holds ``None`` (its default), a fresh
    ``AccountInfo`` with both ``flwr_aid`` and ``account_name`` set to ``None``
    is returned instead, so callers never have to handle ``None`` themselves.
    """
    current = shared_account_info.get()
    return (
        current
        if current is not None
        else AccountInfo(flwr_aid=None, account_name=None)
    )
|
|
60
|
+
|
|
61
|
+
|
|
53
62
|
class ControlAccountAuthInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
54
63
|
"""Control API interceptor for account authentication."""
|
|
55
64
|
|
|
@@ -93,7 +102,7 @@ class ControlAccountAuthInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
93
102
|
|
|
94
103
|
# Intercept GetLoginDetails and GetAuthTokens requests, and return
|
|
95
104
|
# the response without authentication
|
|
96
|
-
if isinstance(request, (GetLoginDetailsRequest
|
|
105
|
+
if isinstance(request, (GetLoginDetailsRequest | GetAuthTokensRequest)):
|
|
97
106
|
return call(request, context) # type: ignore
|
|
98
107
|
|
|
99
108
|
# For other requests, check if the account is authenticated
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
"""Flower Control API event log interceptor."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Iterator
|
|
19
|
-
from typing import Any,
|
|
18
|
+
from collections.abc import Callable, Iterator
|
|
19
|
+
from typing import Any, cast
|
|
20
20
|
|
|
21
21
|
import grpc
|
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
|
@@ -24,7 +24,7 @@ from google.protobuf.message import Message as GrpcMessage
|
|
|
24
24
|
from flwr.common.event_log_plugin.event_log_plugin import EventLogWriterPlugin
|
|
25
25
|
from flwr.common.typing import LogEntry
|
|
26
26
|
|
|
27
|
-
from .control_account_auth_interceptor import
|
|
27
|
+
from .control_account_auth_interceptor import get_current_account_info
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class ControlEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
@@ -60,13 +60,13 @@ class ControlEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
60
60
|
def _generic_method_handler(
|
|
61
61
|
request: GrpcMessage,
|
|
62
62
|
context: grpc.ServicerContext,
|
|
63
|
-
) ->
|
|
63
|
+
) -> GrpcMessage | Iterator[GrpcMessage] | BaseException:
|
|
64
64
|
log_entry: LogEntry
|
|
65
65
|
# Log before call
|
|
66
66
|
log_entry = self.log_plugin.compose_log_before_event(
|
|
67
67
|
request=request,
|
|
68
68
|
context=context,
|
|
69
|
-
account_info=
|
|
69
|
+
account_info=get_current_account_info(),
|
|
70
70
|
method_name=method_name,
|
|
71
71
|
)
|
|
72
72
|
self.log_plugin.write_log(log_entry)
|
|
@@ -85,7 +85,7 @@ class ControlEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
85
85
|
log_entry = self.log_plugin.compose_log_after_event(
|
|
86
86
|
request=request,
|
|
87
87
|
context=context,
|
|
88
|
-
account_info=
|
|
88
|
+
account_info=get_current_account_info(),
|
|
89
89
|
method_name=method_name,
|
|
90
90
|
response=unary_response or error,
|
|
91
91
|
)
|
|
@@ -115,7 +115,7 @@ class ControlEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
115
115
|
log_entry = self.log_plugin.compose_log_after_event(
|
|
116
116
|
request=request,
|
|
117
117
|
context=context,
|
|
118
|
-
account_info=
|
|
118
|
+
account_info=get_current_account_info(),
|
|
119
119
|
method_name=method_name,
|
|
120
120
|
response=stream_response or error,
|
|
121
121
|
)
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import INFO
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
import grpc
|
|
22
21
|
|
|
@@ -46,7 +45,7 @@ try:
|
|
|
46
45
|
from flwr.ee import get_license_plugin
|
|
47
46
|
except ImportError:
|
|
48
47
|
|
|
49
|
-
def get_license_plugin() ->
|
|
48
|
+
def get_license_plugin() -> LicensePlugin | None:
|
|
50
49
|
"""Return the license plugin."""
|
|
51
50
|
|
|
52
51
|
|
|
@@ -56,15 +55,15 @@ def run_control_api_grpc(
|
|
|
56
55
|
state_factory: LinkStateFactory,
|
|
57
56
|
ffs_factory: FfsFactory,
|
|
58
57
|
objectstore_factory: ObjectStoreFactory,
|
|
59
|
-
certificates:
|
|
58
|
+
certificates: tuple[bytes, bytes, bytes] | None,
|
|
60
59
|
is_simulation: bool,
|
|
61
60
|
authn_plugin: ControlAuthnPlugin,
|
|
62
61
|
authz_plugin: ControlAuthzPlugin,
|
|
63
|
-
event_log_plugin:
|
|
64
|
-
artifact_provider:
|
|
62
|
+
event_log_plugin: EventLogWriterPlugin | None = None,
|
|
63
|
+
artifact_provider: ArtifactProvider | None = None,
|
|
65
64
|
) -> grpc.Server:
|
|
66
65
|
"""Run Control API (gRPC, request-response)."""
|
|
67
|
-
license_plugin:
|
|
66
|
+
license_plugin: LicensePlugin | None = get_license_plugin()
|
|
68
67
|
if license_plugin and not license_plugin.check_license():
|
|
69
68
|
flwr_exit(ExitCode.SUPERLINK_LICENSE_INVALID)
|
|
70
69
|
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
"""Flower Control API license interceptor."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Iterator
|
|
19
|
-
from typing import Any
|
|
18
|
+
from collections.abc import Callable, Iterator
|
|
19
|
+
from typing import Any
|
|
20
20
|
|
|
21
21
|
import grpc
|
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
|
@@ -57,7 +57,7 @@ class ControlLicenseInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
57
57
|
def _generic_method_handler(
|
|
58
58
|
request: GrpcMessage,
|
|
59
59
|
context: grpc.ServicerContext,
|
|
60
|
-
) ->
|
|
60
|
+
) -> GrpcMessage | Iterator[GrpcMessage]:
|
|
61
61
|
"""Handle the method call with license checking."""
|
|
62
62
|
call = method_handler.unary_unary or method_handler.unary_stream
|
|
63
63
|
|
|
@@ -19,7 +19,7 @@ import hashlib
|
|
|
19
19
|
import time
|
|
20
20
|
from collections.abc import Generator, Sequence
|
|
21
21
|
from logging import ERROR, INFO
|
|
22
|
-
from typing import Any,
|
|
22
|
+
from typing import Any, cast
|
|
23
23
|
|
|
24
24
|
import grpc
|
|
25
25
|
|
|
@@ -52,6 +52,8 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
|
52
52
|
GetAuthTokensResponse,
|
|
53
53
|
GetLoginDetailsRequest,
|
|
54
54
|
GetLoginDetailsResponse,
|
|
55
|
+
ListFederationsRequest,
|
|
56
|
+
ListFederationsResponse,
|
|
55
57
|
ListNodesRequest,
|
|
56
58
|
ListNodesResponse,
|
|
57
59
|
ListRunsRequest,
|
|
@@ -60,6 +62,8 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
|
60
62
|
PullArtifactsResponse,
|
|
61
63
|
RegisterNodeRequest,
|
|
62
64
|
RegisterNodeResponse,
|
|
65
|
+
ShowFederationRequest,
|
|
66
|
+
ShowFederationResponse,
|
|
63
67
|
StartRunRequest,
|
|
64
68
|
StartRunResponse,
|
|
65
69
|
StopRunRequest,
|
|
@@ -69,6 +73,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
|
69
73
|
UnregisterNodeRequest,
|
|
70
74
|
UnregisterNodeResponse,
|
|
71
75
|
)
|
|
76
|
+
from flwr.proto.federation_pb2 import Federation # pylint: disable=E0611
|
|
72
77
|
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
|
73
78
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
74
79
|
from flwr.supercore.ffs import FfsFactory
|
|
@@ -77,7 +82,7 @@ from flwr.supercore.primitives.asymmetric import bytes_to_public_key, uses_nist_
|
|
|
77
82
|
from flwr.superlink.artifact_provider import ArtifactProvider
|
|
78
83
|
from flwr.superlink.auth_plugin import ControlAuthnPlugin
|
|
79
84
|
|
|
80
|
-
from .control_account_auth_interceptor import
|
|
85
|
+
from .control_account_auth_interceptor import get_current_account_info
|
|
81
86
|
|
|
82
87
|
|
|
83
88
|
class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
@@ -90,7 +95,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
90
95
|
objectstore_factory: ObjectStoreFactory,
|
|
91
96
|
is_simulation: bool,
|
|
92
97
|
authn_plugin: ControlAuthnPlugin,
|
|
93
|
-
artifact_provider:
|
|
98
|
+
artifact_provider: ArtifactProvider | None = None,
|
|
94
99
|
) -> None:
|
|
95
100
|
self.linkstate_factory = linkstate_factory
|
|
96
101
|
self.ffs_factory = ffs_factory
|
|
@@ -115,8 +120,8 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
115
120
|
)
|
|
116
121
|
return StartRunResponse()
|
|
117
122
|
|
|
118
|
-
flwr_aid =
|
|
119
|
-
_check_flwr_aid_exists(flwr_aid, context)
|
|
123
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
124
|
+
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
|
|
120
125
|
override_config = user_config_from_proto(request.override_config)
|
|
121
126
|
federation_options = config_record_from_proto(request.federation_options)
|
|
122
127
|
fab_file = request.fab.content
|
|
@@ -128,6 +133,19 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
128
133
|
"Federation options doesn't contain key `num-supernodes`."
|
|
129
134
|
)
|
|
130
135
|
|
|
136
|
+
# Check (1) federation exists and (2) the flwr_aid is a member
|
|
137
|
+
federation = request.federation
|
|
138
|
+
|
|
139
|
+
if not state.federation_manager.exists(federation):
|
|
140
|
+
raise ValueError(f"Federation '{federation}' does not exist.")
|
|
141
|
+
|
|
142
|
+
if not state.federation_manager.has_member(flwr_aid, federation):
|
|
143
|
+
raise ValueError(
|
|
144
|
+
f"Account with ID '{flwr_aid}' is not a member of the "
|
|
145
|
+
f"federation '{federation}'. Please log in with another account "
|
|
146
|
+
"or request access to this federation."
|
|
147
|
+
)
|
|
148
|
+
|
|
131
149
|
# Create run
|
|
132
150
|
fab = Fab(
|
|
133
151
|
hashlib.sha256(fab_file).hexdigest(),
|
|
@@ -146,6 +164,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
146
164
|
fab_version,
|
|
147
165
|
fab_hash,
|
|
148
166
|
override_config,
|
|
167
|
+
request.federation,
|
|
149
168
|
federation_options,
|
|
150
169
|
flwr_aid,
|
|
151
170
|
)
|
|
@@ -174,7 +193,10 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
174
193
|
# pylint: disable-next=broad-except
|
|
175
194
|
except Exception as e:
|
|
176
195
|
log(ERROR, "Could not start run: %s", str(e))
|
|
177
|
-
|
|
196
|
+
context.abort(
|
|
197
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
198
|
+
str(e),
|
|
199
|
+
)
|
|
178
200
|
|
|
179
201
|
log(INFO, "Created run %s", str(run_id))
|
|
180
202
|
return StartRunResponse(run_id=run_id)
|
|
@@ -195,7 +217,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
195
217
|
context.abort(grpc.StatusCode.NOT_FOUND, RUN_ID_NOT_FOUND_MESSAGE)
|
|
196
218
|
|
|
197
219
|
# Check if `flwr_aid` matches the run's `flwr_aid`
|
|
198
|
-
flwr_aid =
|
|
220
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
199
221
|
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=cast(Run, run), context=context)
|
|
200
222
|
|
|
201
223
|
after_timestamp = request.after_timestamp + 1e-6
|
|
@@ -234,7 +256,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
234
256
|
if not request.HasField("run_id"):
|
|
235
257
|
# If no `run_id` is specified and account auth is enabled,
|
|
236
258
|
# return run IDs for the authenticated account
|
|
237
|
-
flwr_aid =
|
|
259
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
238
260
|
_check_flwr_aid_exists(flwr_aid, context)
|
|
239
261
|
run_ids = state.get_run_ids(flwr_aid=flwr_aid)
|
|
240
262
|
# Build a set of run IDs for `flwr ls --run-id <run_id>`
|
|
@@ -249,7 +271,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
249
271
|
raise grpc.RpcError() # This line is unreachable
|
|
250
272
|
|
|
251
273
|
# Check if `flwr_aid` matches the run's `flwr_aid`
|
|
252
|
-
flwr_aid =
|
|
274
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
253
275
|
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=run, context=context)
|
|
254
276
|
|
|
255
277
|
run_ids = {run_id}
|
|
@@ -275,7 +297,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
275
297
|
raise grpc.RpcError() # This line is unreachable
|
|
276
298
|
|
|
277
299
|
# Check if `flwr_aid` matches the run's `flwr_aid`
|
|
278
|
-
flwr_aid =
|
|
300
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
279
301
|
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=run, context=context)
|
|
280
302
|
|
|
281
303
|
run_status = state.get_run_status({run_id})[run_id]
|
|
@@ -285,11 +307,15 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
285
307
|
f"Run ID {run_id} is already finished",
|
|
286
308
|
)
|
|
287
309
|
|
|
310
|
+
# Update run status to finished:stopped
|
|
288
311
|
update_success = state.update_run_status(
|
|
289
312
|
run_id=run_id,
|
|
290
313
|
new_status=RunStatus(Status.FINISHED, SubStatus.STOPPED, ""),
|
|
291
314
|
)
|
|
292
315
|
|
|
316
|
+
# Delete the token associated with the run to stop further operations
|
|
317
|
+
state.delete_token(run_id)
|
|
318
|
+
|
|
293
319
|
if update_success:
|
|
294
320
|
message_ids: set[str] = state.get_message_ids_from_run_id(run_id)
|
|
295
321
|
|
|
@@ -385,7 +411,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
385
411
|
)
|
|
386
412
|
|
|
387
413
|
# Check if `flwr_aid` matches the run's `flwr_aid`
|
|
388
|
-
flwr_aid =
|
|
414
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
389
415
|
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=run, context=context)
|
|
390
416
|
|
|
391
417
|
# Call artifact provider
|
|
@@ -415,11 +441,14 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
415
441
|
state = self.linkstate_factory.state()
|
|
416
442
|
node_id = 0
|
|
417
443
|
|
|
418
|
-
flwr_aid =
|
|
444
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
419
445
|
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
|
|
446
|
+
# Account name exists if `flwr_aid` exists
|
|
447
|
+
account_name = cast(str, get_current_account_info().account_name)
|
|
420
448
|
try:
|
|
421
449
|
node_id = state.create_node(
|
|
422
450
|
owner_aid=flwr_aid,
|
|
451
|
+
owner_name=account_name,
|
|
423
452
|
public_key=request.public_key,
|
|
424
453
|
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
|
425
454
|
)
|
|
@@ -443,7 +472,7 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
443
472
|
# Init link state
|
|
444
473
|
state = self.linkstate_factory.state()
|
|
445
474
|
|
|
446
|
-
flwr_aid =
|
|
475
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
447
476
|
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
|
|
448
477
|
try:
|
|
449
478
|
state.delete_node(owner_aid=flwr_aid, node_id=request.node_id)
|
|
@@ -471,13 +500,70 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
|
471
500
|
# Init link state
|
|
472
501
|
state = self.linkstate_factory.state()
|
|
473
502
|
|
|
474
|
-
flwr_aid =
|
|
503
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
475
504
|
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
|
|
476
505
|
# Retrieve all nodes for the account
|
|
477
506
|
nodes_info = state.get_node_info(owner_aids=[flwr_aid])
|
|
478
507
|
|
|
479
508
|
return ListNodesResponse(nodes_info=nodes_info, now=now().isoformat())
|
|
480
509
|
|
|
510
|
+
def ListFederations(
|
|
511
|
+
self, request: ListFederationsRequest, context: grpc.ServicerContext
|
|
512
|
+
) -> ListFederationsResponse:
|
|
513
|
+
"""List all SuperNodes."""
|
|
514
|
+
log(INFO, "ControlServicer.ListFederations")
|
|
515
|
+
|
|
516
|
+
# Init link state
|
|
517
|
+
state = self.linkstate_factory.state()
|
|
518
|
+
|
|
519
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
520
|
+
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
|
|
521
|
+
|
|
522
|
+
# Get federations the account is a member of
|
|
523
|
+
federations = state.federation_manager.get_federations(flwr_aid=flwr_aid)
|
|
524
|
+
|
|
525
|
+
return ListFederationsResponse(
|
|
526
|
+
federations=[Federation(name=fed) for fed in federations]
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
def ShowFederation(
|
|
530
|
+
self, request: ShowFederationRequest, context: grpc.ServicerContext
|
|
531
|
+
) -> ShowFederationResponse:
|
|
532
|
+
"""Show details of a specific Federation."""
|
|
533
|
+
log(INFO, "ControlServicer.ShowFederation")
|
|
534
|
+
|
|
535
|
+
# Init link state
|
|
536
|
+
state = self.linkstate_factory.state()
|
|
537
|
+
|
|
538
|
+
flwr_aid = get_current_account_info().flwr_aid
|
|
539
|
+
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
|
|
540
|
+
|
|
541
|
+
# Get federations the account is a member of
|
|
542
|
+
federations = state.federation_manager.get_federations(flwr_aid=flwr_aid)
|
|
543
|
+
|
|
544
|
+
# Ensure flwr_aid is a member of the requested federation
|
|
545
|
+
federation = request.federation_name
|
|
546
|
+
if federation not in federations:
|
|
547
|
+
context.abort(
|
|
548
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
549
|
+
f"Federation '{federation}' does not exist or you are "
|
|
550
|
+
"not a member of it.",
|
|
551
|
+
)
|
|
552
|
+
|
|
553
|
+
# Fetch federation details
|
|
554
|
+
details = state.federation_manager.get_details(federation)
|
|
555
|
+
|
|
556
|
+
# Build Federation proto object
|
|
557
|
+
federation_proto = Federation(
|
|
558
|
+
name=federation,
|
|
559
|
+
member_aids=details.member_aids,
|
|
560
|
+
nodes=details.nodes,
|
|
561
|
+
runs=[run_to_proto(run) for run in details.runs],
|
|
562
|
+
)
|
|
563
|
+
return ShowFederationResponse(
|
|
564
|
+
federation=federation_proto, now=now().isoformat()
|
|
565
|
+
)
|
|
566
|
+
|
|
481
567
|
|
|
482
568
|
def _create_list_runs_response(
|
|
483
569
|
run_ids: set[int], state: LinkState, store: ObjectStore
|
|
@@ -496,9 +582,7 @@ def _create_list_runs_response(
|
|
|
496
582
|
)
|
|
497
583
|
|
|
498
584
|
|
|
499
|
-
def _check_flwr_aid_exists(
|
|
500
|
-
flwr_aid: Optional[str], context: grpc.ServicerContext
|
|
501
|
-
) -> str:
|
|
585
|
+
def _check_flwr_aid_exists(flwr_aid: str | None, context: grpc.ServicerContext) -> str:
|
|
502
586
|
"""Guard clause to check if `flwr_aid` exists."""
|
|
503
587
|
if flwr_aid is None:
|
|
504
588
|
context.abort(
|
|
@@ -510,7 +594,7 @@ def _check_flwr_aid_exists(
|
|
|
510
594
|
|
|
511
595
|
|
|
512
596
|
def _check_flwr_aid_in_run(
|
|
513
|
-
flwr_aid:
|
|
597
|
+
flwr_aid: str | None, run: Run, context: grpc.ServicerContext
|
|
514
598
|
) -> None:
|
|
515
599
|
"""Guard clause to check if `flwr_aid` matches the run's `flwr_aid`."""
|
|
516
600
|
_check_flwr_aid_exists(flwr_aid, context)
|
|
@@ -18,11 +18,12 @@
|
|
|
18
18
|
import argparse
|
|
19
19
|
from logging import DEBUG, INFO, WARN
|
|
20
20
|
from pathlib import Path
|
|
21
|
-
from typing import Optional
|
|
22
21
|
|
|
22
|
+
import yaml
|
|
23
23
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
24
|
-
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
from cryptography.hazmat.primitives.asymmetric import ec, ed25519
|
|
25
25
|
from cryptography.hazmat.primitives.serialization import load_ssh_private_key
|
|
26
|
+
from cryptography.hazmat.primitives.serialization.ssh import load_ssh_public_key
|
|
26
27
|
|
|
27
28
|
from flwr.common import EventType, event
|
|
28
29
|
from flwr.common.args import try_obtain_root_certificates
|
|
@@ -58,6 +59,9 @@ def flower_supernode() -> None:
|
|
|
58
59
|
"Ignoring `--flwr-dir`.",
|
|
59
60
|
)
|
|
60
61
|
|
|
62
|
+
trusted_entities = _try_obtain_trusted_entities(args.trusted_entities)
|
|
63
|
+
if trusted_entities:
|
|
64
|
+
_validate_public_keys_ed25519(trusted_entities)
|
|
61
65
|
root_certificates = try_obtain_root_certificates(args, args.superlink)
|
|
62
66
|
authentication_keys = _try_setup_client_authentication(args)
|
|
63
67
|
|
|
@@ -85,6 +89,7 @@ def flower_supernode() -> None:
|
|
|
85
89
|
isolation=args.isolation,
|
|
86
90
|
clientappio_api_address=args.clientappio_api_address,
|
|
87
91
|
health_server_address=args.health_server_address,
|
|
92
|
+
trusted_entities=trusted_entities,
|
|
88
93
|
)
|
|
89
94
|
|
|
90
95
|
|
|
@@ -124,6 +129,18 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
|
124
129
|
help="ClientAppIo API (gRPC) server address (IPv4, IPv6, or a domain name). "
|
|
125
130
|
f"By default, it is set to {CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS}.",
|
|
126
131
|
)
|
|
132
|
+
parser.add_argument(
|
|
133
|
+
"--trusted-entities",
|
|
134
|
+
type=Path,
|
|
135
|
+
default=None,
|
|
136
|
+
metavar="YAML_FILE",
|
|
137
|
+
help=(
|
|
138
|
+
"Path to a YAML file defining trusted entities. "
|
|
139
|
+
"The file must map public key IDs to public keys. "
|
|
140
|
+
"Example: { fpk_UUID1: 'ssh-ed25519 <key1> [comment1]', "
|
|
141
|
+
"fpk_UUID2: 'ssh-ed25519 <key2> [comment2]' }"
|
|
142
|
+
),
|
|
143
|
+
)
|
|
127
144
|
add_args_health(parser)
|
|
128
145
|
|
|
129
146
|
return parser
|
|
@@ -210,7 +227,7 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
210
227
|
|
|
211
228
|
def _try_setup_client_authentication(
|
|
212
229
|
args: argparse.Namespace,
|
|
213
|
-
) ->
|
|
230
|
+
) -> tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] | None:
|
|
214
231
|
if not args.auth_supernode_private_key:
|
|
215
232
|
return None
|
|
216
233
|
|
|
@@ -235,3 +252,41 @@ def _try_setup_client_authentication(
|
|
|
235
252
|
"private key provided by `--auth-supernode-private-key`.",
|
|
236
253
|
)
|
|
237
254
|
return ssh_private_key, ssh_private_key.public_key()
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _try_obtain_trusted_entities(
|
|
258
|
+
trusted_entities_path: Path | None,
|
|
259
|
+
) -> dict[str, str] | None:
|
|
260
|
+
"""Validate and return the trust entities."""
|
|
261
|
+
if not trusted_entities_path:
|
|
262
|
+
return None
|
|
263
|
+
if not trusted_entities_path.is_file():
|
|
264
|
+
flwr_exit(
|
|
265
|
+
ExitCode.SUPERNODE_INVALID_TRUSTED_ENTITIES,
|
|
266
|
+
"Path argument `--trusted-entities` does not point to a file.",
|
|
267
|
+
)
|
|
268
|
+
try:
|
|
269
|
+
with trusted_entities_path.open("r", encoding="utf-8") as f:
|
|
270
|
+
trusted_entities = yaml.safe_load(f)
|
|
271
|
+
if not isinstance(trusted_entities, dict):
|
|
272
|
+
raise ValueError("Invalid trusted entities format.")
|
|
273
|
+
except (yaml.YAMLError, ValueError) as e:
|
|
274
|
+
flwr_exit(
|
|
275
|
+
ExitCode.SUPERNODE_INVALID_TRUSTED_ENTITIES,
|
|
276
|
+
f"Failed to read YAML file '{trusted_entities_path}': {e}",
|
|
277
|
+
)
|
|
278
|
+
return trusted_entities
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _validate_public_keys_ed25519(trusted_entities: dict[str, str]) -> None:
|
|
282
|
+
"""Validate public keys for the trust entities are Ed25519."""
|
|
283
|
+
for public_key_id in trusted_entities.keys():
|
|
284
|
+
verifier_public_key = load_ssh_public_key(
|
|
285
|
+
trusted_entities[public_key_id].encode("utf-8")
|
|
286
|
+
)
|
|
287
|
+
if not isinstance(verifier_public_key, ed25519.Ed25519PublicKey):
|
|
288
|
+
flwr_exit(
|
|
289
|
+
ExitCode.SUPERNODE_INVALID_TRUSTED_ENTITIES,
|
|
290
|
+
"The provided public key associated with "
|
|
291
|
+
f"trusted entity {public_key_id} is not Ed25519.",
|
|
292
|
+
)
|