flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +34 -1
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +94 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
- flwr/cli/build.py +166 -53
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +54 -11
- flwr/cli/login/login.py +41 -27
- flwr/cli/ls.py +177 -133
- flwr/cli/new/new.py +175 -40
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +12 -7
- flwr/cli/run/run.py +82 -31
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +27 -9
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +268 -0
- flwr/cli/supernode/register.py +190 -0
- flwr/cli/supernode/unregister.py +140 -0
- flwr/cli/utils.py +464 -81
- flwr/client/__init__.py +2 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +12 -15
- flwr/client/grpc_rere_client/connection.py +68 -41
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +94 -51
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +1 -2
- flwr/{client → clientapp}/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/{client/clientapp → clientapp}/utils.py +4 -4
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +56 -13
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +39 -10
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +48 -31
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +6 -6
- flwr/common/record/arrayrecord.py +18 -21
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +9 -6
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +59 -43
- flwr/compat/client/app.py +39 -38
- flwr/compat/client/grpc_client/connection.py +13 -13
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +72 -40
- flwr/proto/control_pb2.pyi +319 -87
- flwr/proto/control_pb2_grpc.py +339 -28
- flwr/proto/control_pb2_grpc.pyi +209 -37
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +24 -10
- flwr/proto/fab_pb2.pyi +68 -20
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +45 -27
- flwr/proto/fleet_pb2.pyi +186 -70
- flwr/proto/fleet_pb2_grpc.py +277 -66
- flwr/proto/fleet_pb2_grpc.pyi +201 -55
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +16 -4
- flwr/proto/node_pb2.pyi +77 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +173 -127
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
- flwr/server/superlink/fleet/vce/vce_api.py +32 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
- flwr/server/superlink/linkstate/linkstate.py +161 -62
- flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
- flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
- flwr/server/superlink/linkstate/utils.py +9 -60
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +11 -12
- flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
- flwr/simulation/run_simulation.py +44 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +52 -0
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -6
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +27 -8
- flwr/supercore/object_store/sqlite_object_store.py +253 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
- flwr/supercore/sqlite_mixin.py +159 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +88 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +18 -17
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +239 -63
- flwr/supernode/cli/flower_supernode.py +74 -26
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +43 -24
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +175 -51
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.22.0.dist-info/RECORD +0 -428
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -15,25 +15,31 @@
|
|
|
15
15
|
"""Fleet API gRPC request-response servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
|
|
18
|
+
import threading
|
|
19
|
+
from logging import DEBUG, ERROR, INFO
|
|
19
20
|
|
|
20
21
|
import grpc
|
|
21
22
|
from google.protobuf.json_format import MessageToDict
|
|
22
23
|
|
|
24
|
+
from flwr.common.constant import PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
|
|
23
25
|
from flwr.common.inflatable import UnexpectedObjectContentError
|
|
24
26
|
from flwr.common.logger import log
|
|
25
27
|
from flwr.common.typing import InvalidRunStatusException
|
|
26
28
|
from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
|
|
27
29
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
28
30
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
31
|
+
ActivateNodeRequest,
|
|
32
|
+
ActivateNodeResponse,
|
|
33
|
+
DeactivateNodeRequest,
|
|
34
|
+
DeactivateNodeResponse,
|
|
33
35
|
PullMessagesRequest,
|
|
34
36
|
PullMessagesResponse,
|
|
35
37
|
PushMessagesRequest,
|
|
36
38
|
PushMessagesResponse,
|
|
39
|
+
RegisterNodeFleetRequest,
|
|
40
|
+
RegisterNodeFleetResponse,
|
|
41
|
+
UnregisterNodeFleetRequest,
|
|
42
|
+
UnregisterNodeFleetResponse,
|
|
37
43
|
)
|
|
38
44
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
39
45
|
SendNodeHeartbeatRequest,
|
|
@@ -63,49 +69,137 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
63
69
|
state_factory: LinkStateFactory,
|
|
64
70
|
ffs_factory: FfsFactory,
|
|
65
71
|
objectstore_factory: ObjectStoreFactory,
|
|
72
|
+
enable_supernode_auth: bool,
|
|
66
73
|
) -> None:
|
|
67
74
|
self.state_factory = state_factory
|
|
68
75
|
self.ffs_factory = ffs_factory
|
|
69
76
|
self.objectstore_factory = objectstore_factory
|
|
77
|
+
self.enable_supernode_auth = enable_supernode_auth
|
|
78
|
+
self.lock = threading.Lock()
|
|
70
79
|
|
|
71
|
-
def
|
|
72
|
-
self, request:
|
|
73
|
-
) ->
|
|
74
|
-
"""."""
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
"
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
state=self.state_factory.state(),
|
|
84
|
-
)
|
|
85
|
-
log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
|
|
86
|
-
log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
|
|
87
|
-
return response
|
|
80
|
+
def RegisterNode(
|
|
81
|
+
self, request: RegisterNodeFleetRequest, context: grpc.ServicerContext
|
|
82
|
+
) -> RegisterNodeFleetResponse:
|
|
83
|
+
"""Register a node."""
|
|
84
|
+
# Prevent registration when SuperNode authentication is enabled
|
|
85
|
+
if self.enable_supernode_auth:
|
|
86
|
+
log(ERROR, "SuperNode registration is disabled through Fleet API.")
|
|
87
|
+
context.abort(
|
|
88
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
89
|
+
"SuperNode authentication is enabled. "
|
|
90
|
+
"All SuperNodes must be registered via the CLI.",
|
|
91
|
+
)
|
|
88
92
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
93
|
+
try:
|
|
94
|
+
response = message_handler.register_node(
|
|
95
|
+
request=request,
|
|
96
|
+
state=self.state_factory.state(),
|
|
97
|
+
)
|
|
98
|
+
log(DEBUG, "[Fleet.RegisterNode] Registered node_id=%s", response.node_id)
|
|
99
|
+
return response
|
|
100
|
+
except ValueError:
|
|
101
|
+
# Public key already in use
|
|
102
|
+
# This should NEVER happen due to the public keys should be automatically
|
|
103
|
+
# generated and unique for each SuperNode instance.
|
|
104
|
+
log(
|
|
105
|
+
ERROR,
|
|
106
|
+
"[Fleet.RegisterNode] Registration failed: %s",
|
|
107
|
+
PUBLIC_KEY_ALREADY_IN_USE_MESSAGE,
|
|
108
|
+
)
|
|
109
|
+
context.abort(
|
|
110
|
+
grpc.StatusCode.FAILED_PRECONDITION, PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
raise RuntimeError # Make mypy happy
|
|
114
|
+
|
|
115
|
+
def ActivateNode(
|
|
116
|
+
self, request: ActivateNodeRequest, context: grpc.ServicerContext
|
|
117
|
+
) -> ActivateNodeResponse:
|
|
118
|
+
"""Activate a node."""
|
|
119
|
+
try:
|
|
120
|
+
response = message_handler.activate_node(
|
|
121
|
+
request=request,
|
|
122
|
+
state=self.state_factory.state(),
|
|
123
|
+
)
|
|
124
|
+
log(INFO, "[Fleet.ActivateNode] Activated node_id=%s", response.node_id)
|
|
125
|
+
return response
|
|
126
|
+
except message_handler.InvalidHeartbeatIntervalError:
|
|
127
|
+
# Heartbeat interval is invalid
|
|
128
|
+
log(ERROR, "[Fleet.ActivateNode] Invalid heartbeat interval")
|
|
129
|
+
context.abort(
|
|
130
|
+
grpc.StatusCode.INVALID_ARGUMENT, "Invalid heartbeat interval"
|
|
131
|
+
)
|
|
132
|
+
except ValueError as e:
|
|
133
|
+
log(ERROR, "[Fleet.ActivateNode] Activation failed: %s", str(e))
|
|
134
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
|
|
135
|
+
|
|
136
|
+
raise RuntimeError # Make mypy happy
|
|
137
|
+
|
|
138
|
+
def DeactivateNode(
|
|
139
|
+
self, request: DeactivateNodeRequest, context: grpc.ServicerContext
|
|
140
|
+
) -> DeactivateNodeResponse:
|
|
141
|
+
"""Deactivate a node."""
|
|
142
|
+
try:
|
|
143
|
+
response = message_handler.deactivate_node(
|
|
144
|
+
request=request,
|
|
145
|
+
state=self.state_factory.state(),
|
|
146
|
+
)
|
|
147
|
+
log(INFO, "[Fleet.DeactivateNode] Deactivated node_id=%s", request.node_id)
|
|
148
|
+
return response
|
|
149
|
+
except ValueError as e:
|
|
150
|
+
log(ERROR, "[Fleet.DeactivateNode] Deactivation failed: %s", str(e))
|
|
151
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
|
|
152
|
+
|
|
153
|
+
raise RuntimeError # Make mypy happy
|
|
154
|
+
|
|
155
|
+
def UnregisterNode(
|
|
156
|
+
self, request: UnregisterNodeFleetRequest, context: grpc.ServicerContext
|
|
157
|
+
) -> UnregisterNodeFleetResponse:
|
|
158
|
+
"""Unregister a node."""
|
|
159
|
+
# Prevent unregistration when SuperNode authentication is enabled
|
|
160
|
+
if self.enable_supernode_auth:
|
|
161
|
+
log(ERROR, "SuperNode unregistration is disabled through Fleet API.")
|
|
162
|
+
context.abort(
|
|
163
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
164
|
+
"SuperNode authentication is enabled. "
|
|
165
|
+
"All SuperNodes must be unregistered via the CLI.",
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
try:
|
|
169
|
+
response = message_handler.unregister_node(
|
|
170
|
+
request=request,
|
|
171
|
+
state=self.state_factory.state(),
|
|
172
|
+
)
|
|
173
|
+
log(
|
|
174
|
+
DEBUG, "[Fleet.UnregisterNode] Unregistered node_id=%s", request.node_id
|
|
175
|
+
)
|
|
176
|
+
return response
|
|
177
|
+
except ValueError as e:
|
|
178
|
+
log(
|
|
179
|
+
ERROR,
|
|
180
|
+
"[Fleet.UnregisterNode] Unregistration failed: %s",
|
|
181
|
+
str(e),
|
|
182
|
+
)
|
|
183
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
|
|
184
|
+
raise RuntimeError from None # Make mypy happy
|
|
99
185
|
|
|
100
186
|
def SendNodeHeartbeat(
|
|
101
187
|
self, request: SendNodeHeartbeatRequest, context: grpc.ServicerContext
|
|
102
188
|
) -> SendNodeHeartbeatResponse:
|
|
103
189
|
"""."""
|
|
104
190
|
log(DEBUG, "[Fleet.SendNodeHeartbeat] Request: %s", MessageToDict(request))
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
191
|
+
try:
|
|
192
|
+
return message_handler.send_node_heartbeat(
|
|
193
|
+
request=request,
|
|
194
|
+
state=self.state_factory.state(),
|
|
195
|
+
)
|
|
196
|
+
except message_handler.InvalidHeartbeatIntervalError:
|
|
197
|
+
# Heartbeat interval is invalid
|
|
198
|
+
log(ERROR, "[Fleet.SendNodeHeartbeat] Invalid heartbeat interval")
|
|
199
|
+
context.abort(
|
|
200
|
+
grpc.StatusCode.INVALID_ARGUMENT, "Invalid heartbeat interval"
|
|
201
|
+
)
|
|
202
|
+
raise RuntimeError # Make mypy happy
|
|
109
203
|
|
|
110
204
|
def PullMessages(
|
|
111
205
|
self, request: PullMessagesRequest, context: grpc.ServicerContext
|
|
@@ -155,8 +249,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
155
249
|
state=self.state_factory.state(),
|
|
156
250
|
store=self.objectstore_factory.store(),
|
|
157
251
|
)
|
|
158
|
-
except InvalidRunStatusException as e:
|
|
159
|
-
abort_grpc_context(e
|
|
252
|
+
except (InvalidRunStatusException, ValueError) as e:
|
|
253
|
+
abort_grpc_context(str(e), context)
|
|
160
254
|
|
|
161
255
|
return res
|
|
162
256
|
|
|
@@ -172,8 +266,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
172
266
|
state=self.state_factory.state(),
|
|
173
267
|
store=self.objectstore_factory.store(),
|
|
174
268
|
)
|
|
175
|
-
except InvalidRunStatusException as e:
|
|
176
|
-
abort_grpc_context(e
|
|
269
|
+
except (InvalidRunStatusException, ValueError) as e:
|
|
270
|
+
abort_grpc_context(str(e), context)
|
|
177
271
|
|
|
178
272
|
return res
|
|
179
273
|
|
|
@@ -183,7 +277,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
183
277
|
"""Push an object to the ObjectStore."""
|
|
184
278
|
log(
|
|
185
279
|
DEBUG,
|
|
186
|
-
"[
|
|
280
|
+
"[Fleet.PushObject] Push Object with object_id=%s",
|
|
187
281
|
request.object_id,
|
|
188
282
|
)
|
|
189
283
|
|
|
@@ -208,7 +302,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
208
302
|
"""Pull an object from the ObjectStore."""
|
|
209
303
|
log(
|
|
210
304
|
DEBUG,
|
|
211
|
-
"[
|
|
305
|
+
"[Fleet.PullObject] Pull Object with object_id=%s",
|
|
212
306
|
request.object_id,
|
|
213
307
|
)
|
|
214
308
|
|
flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py}
RENAMED
|
@@ -16,7 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import datetime
|
|
19
|
-
from
|
|
19
|
+
from collections.abc import Callable
|
|
20
|
+
from typing import Any, cast
|
|
20
21
|
|
|
21
22
|
import grpc
|
|
22
23
|
from google.protobuf.message import Message as GrpcMessage
|
|
@@ -29,15 +30,12 @@ from flwr.common.constant import (
|
|
|
29
30
|
TIMESTAMP_HEADER,
|
|
30
31
|
TIMESTAMP_TOLERANCE,
|
|
31
32
|
)
|
|
32
|
-
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
33
|
-
bytes_to_public_key,
|
|
34
|
-
verify_signature,
|
|
35
|
-
)
|
|
36
33
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
37
|
-
|
|
38
|
-
|
|
34
|
+
ActivateNodeRequest,
|
|
35
|
+
RegisterNodeFleetRequest,
|
|
39
36
|
)
|
|
40
37
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
38
|
+
from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
|
|
41
39
|
|
|
42
40
|
MIN_TIMESTAMP_DIFF = -SYSTEM_TIME_TOLERANCE
|
|
43
41
|
MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE
|
|
@@ -53,22 +51,17 @@ def _unary_unary_rpc_terminator(
|
|
|
53
51
|
return grpc.unary_unary_rpc_method_handler(terminate)
|
|
54
52
|
|
|
55
53
|
|
|
56
|
-
class
|
|
54
|
+
class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
57
55
|
"""Server interceptor for node authentication.
|
|
58
56
|
|
|
59
57
|
Parameters
|
|
60
58
|
----------
|
|
61
59
|
state_factory : LinkStateFactory
|
|
62
60
|
A factory for creating new instances of LinkState.
|
|
63
|
-
auto_auth : bool (default: False)
|
|
64
|
-
If True, nodes are authenticated without requiring their public keys to be
|
|
65
|
-
pre-stored in the LinkState. If False, only nodes with pre-stored public keys
|
|
66
|
-
can be authenticated.
|
|
67
61
|
"""
|
|
68
62
|
|
|
69
|
-
def __init__(self, state_factory: LinkStateFactory
|
|
63
|
+
def __init__(self, state_factory: LinkStateFactory):
|
|
70
64
|
self.state_factory = state_factory
|
|
71
|
-
self.auto_auth = auto_auth
|
|
72
65
|
|
|
73
66
|
def intercept_service( # pylint: disable=too-many-return-statements
|
|
74
67
|
self,
|
|
@@ -85,7 +78,6 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
85
78
|
if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
|
|
86
79
|
return continuation(handler_call_details)
|
|
87
80
|
|
|
88
|
-
state = self.state_factory.state()
|
|
89
81
|
metadata_dict = dict(handler_call_details.invocation_metadata)
|
|
90
82
|
|
|
91
83
|
# Retrieve info from the metadata
|
|
@@ -96,11 +88,6 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
96
88
|
except KeyError:
|
|
97
89
|
return _unary_unary_rpc_terminator("Missing authentication metadata")
|
|
98
90
|
|
|
99
|
-
if not self.auto_auth:
|
|
100
|
-
# Abort the RPC call if the node public key is not found
|
|
101
|
-
if node_pk_bytes not in state.get_node_public_keys():
|
|
102
|
-
return _unary_unary_rpc_terminator("Public key not recognized")
|
|
103
|
-
|
|
104
91
|
# Verify the signature
|
|
105
92
|
node_pk = bytes_to_public_key(node_pk_bytes)
|
|
106
93
|
if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
|
|
@@ -113,50 +100,40 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
113
100
|
if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
|
|
114
101
|
return _unary_unary_rpc_terminator("Invalid timestamp")
|
|
115
102
|
|
|
116
|
-
# Continue the RPC call
|
|
117
|
-
expected_node_id = state.get_node_id(node_pk_bytes)
|
|
118
|
-
if not handler_call_details.method.endswith("CreateNode"):
|
|
119
|
-
# All calls, except for `CreateNode`, must provide a public key that is
|
|
120
|
-
# already mapped to a `node_id` (in `LinkState`)
|
|
121
|
-
if expected_node_id is None:
|
|
122
|
-
return _unary_unary_rpc_terminator("Invalid node ID")
|
|
123
|
-
# One of the method handlers in
|
|
103
|
+
# Continue the RPC call: One of the method handlers in
|
|
124
104
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
|
125
105
|
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
|
126
|
-
return self._wrap_method_handler(
|
|
127
|
-
method_handler, expected_node_id, node_pk_bytes
|
|
128
|
-
)
|
|
106
|
+
return self._wrap_method_handler(method_handler, node_pk_bytes)
|
|
129
107
|
|
|
130
108
|
def _wrap_method_handler(
|
|
131
109
|
self,
|
|
132
110
|
method_handler: grpc.RpcMethodHandler,
|
|
133
|
-
|
|
134
|
-
node_public_key: bytes,
|
|
111
|
+
expected_public_key: bytes,
|
|
135
112
|
) -> grpc.RpcMethodHandler:
|
|
136
113
|
def _generic_method_handler(
|
|
137
114
|
request: GrpcMessage,
|
|
138
115
|
context: grpc.ServicerContext,
|
|
139
116
|
) -> GrpcMessage:
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
117
|
+
# Note: This function runs in a different thread
|
|
118
|
+
# than the `intercept_service` function.
|
|
119
|
+
|
|
120
|
+
# Retrieve the public key
|
|
121
|
+
if isinstance(request, (RegisterNodeFleetRequest | ActivateNodeRequest)):
|
|
122
|
+
actual_public_key = request.public_key
|
|
123
|
+
else:
|
|
124
|
+
if hasattr(request, "node"):
|
|
125
|
+
node_id = request.node.node_id
|
|
126
|
+
else:
|
|
127
|
+
node_id = request.node_id # type: ignore[attr-defined]
|
|
128
|
+
actual_public_key = self.state_factory.state().get_node_public_key(
|
|
129
|
+
node_id
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
# Verify the public key
|
|
133
|
+
if actual_public_key != expected_public_key:
|
|
134
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
|
|
147
135
|
|
|
148
136
|
response: GrpcMessage = method_handler.unary_unary(request, context)
|
|
149
|
-
|
|
150
|
-
# Set the public key after a successful CreateNode request
|
|
151
|
-
if isinstance(response, CreateNodeResponse):
|
|
152
|
-
state = self.state_factory.state()
|
|
153
|
-
try:
|
|
154
|
-
state.set_node_public_key(response.node.node_id, node_public_key)
|
|
155
|
-
except ValueError as e:
|
|
156
|
-
# Remove newly created node if setting the public key fails
|
|
157
|
-
state.delete_node(response.node.node_id)
|
|
158
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
|
|
159
|
-
|
|
160
137
|
return response
|
|
161
138
|
|
|
162
139
|
return grpc.unary_unary_rpc_method_handler(
|
|
@@ -15,29 +15,38 @@
|
|
|
15
15
|
"""Fleet API message handlers."""
|
|
16
16
|
|
|
17
17
|
from logging import ERROR
|
|
18
|
-
from typing import Optional
|
|
19
18
|
|
|
20
19
|
from flwr.common import Message, log
|
|
21
|
-
from flwr.common.constant import
|
|
20
|
+
from flwr.common.constant import (
|
|
21
|
+
HEARTBEAT_MAX_INTERVAL,
|
|
22
|
+
HEARTBEAT_MIN_INTERVAL,
|
|
23
|
+
NOOP_ACCOUNT_NAME,
|
|
24
|
+
NOOP_FLWR_AID,
|
|
25
|
+
Status,
|
|
26
|
+
)
|
|
22
27
|
from flwr.common.inflatable import UnexpectedObjectContentError
|
|
23
28
|
from flwr.common.serde import (
|
|
24
29
|
fab_to_proto,
|
|
25
30
|
message_from_proto,
|
|
26
31
|
message_to_proto,
|
|
27
|
-
|
|
32
|
+
run_to_proto,
|
|
28
33
|
)
|
|
29
|
-
from flwr.common.typing import Fab, InvalidRunStatusException
|
|
34
|
+
from flwr.common.typing import Fab, InvalidRunStatusException, Run
|
|
30
35
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
31
36
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
37
|
+
ActivateNodeRequest,
|
|
38
|
+
ActivateNodeResponse,
|
|
39
|
+
DeactivateNodeRequest,
|
|
40
|
+
DeactivateNodeResponse,
|
|
36
41
|
PullMessagesRequest,
|
|
37
42
|
PullMessagesResponse,
|
|
38
43
|
PushMessagesRequest,
|
|
39
44
|
PushMessagesResponse,
|
|
40
45
|
Reconnect,
|
|
46
|
+
RegisterNodeFleetRequest,
|
|
47
|
+
RegisterNodeFleetResponse,
|
|
48
|
+
UnregisterNodeFleetRequest,
|
|
49
|
+
UnregisterNodeFleetResponse,
|
|
41
50
|
)
|
|
42
51
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
43
52
|
SendNodeHeartbeatRequest,
|
|
@@ -51,38 +60,59 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
|
51
60
|
PushObjectRequest,
|
|
52
61
|
PushObjectResponse,
|
|
53
62
|
)
|
|
54
|
-
from flwr.proto.
|
|
55
|
-
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
56
|
-
GetRunRequest,
|
|
57
|
-
GetRunResponse,
|
|
58
|
-
Run,
|
|
59
|
-
)
|
|
63
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
60
64
|
from flwr.server.superlink.linkstate import LinkState
|
|
61
65
|
from flwr.server.superlink.utils import check_abort
|
|
62
66
|
from flwr.supercore.ffs import Ffs
|
|
63
67
|
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
|
|
64
|
-
from flwr.supercore.object_store.utils import store_mapping_and_register_objects
|
|
65
68
|
|
|
66
69
|
|
|
67
|
-
|
|
68
|
-
|
|
70
|
+
class InvalidHeartbeatIntervalError(Exception):
|
|
71
|
+
"""Invalid heartbeat interval exception."""
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def register_node(
|
|
75
|
+
request: RegisterNodeFleetRequest,
|
|
69
76
|
state: LinkState,
|
|
70
|
-
) ->
|
|
71
|
-
"""."""
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
return CreateNodeResponse(node=Node(node_id=node_id))
|
|
77
|
+
) -> RegisterNodeFleetResponse:
|
|
78
|
+
"""Register a node (Fleet API only)."""
|
|
79
|
+
node_id = state.create_node(NOOP_FLWR_AID, NOOP_ACCOUNT_NAME, request.public_key, 0)
|
|
80
|
+
return RegisterNodeFleetResponse(node_id=node_id)
|
|
75
81
|
|
|
76
82
|
|
|
77
|
-
def
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
83
|
+
def activate_node(
|
|
84
|
+
request: ActivateNodeRequest,
|
|
85
|
+
state: LinkState,
|
|
86
|
+
) -> ActivateNodeResponse:
|
|
87
|
+
"""Activate a node."""
|
|
88
|
+
node_id = state.get_node_id_by_public_key(request.public_key)
|
|
89
|
+
if node_id is None:
|
|
90
|
+
raise ValueError("No SuperNode found with the given public key.")
|
|
91
|
+
_validate_heartbeat_interval(request.heartbeat_interval)
|
|
92
|
+
if not state.activate_node(node_id, request.heartbeat_interval):
|
|
93
|
+
raise ValueError(f"SuperNode with node ID {node_id} could not be activated.")
|
|
94
|
+
return ActivateNodeResponse(node_id=node_id)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def deactivate_node(
|
|
98
|
+
request: DeactivateNodeRequest,
|
|
99
|
+
state: LinkState,
|
|
100
|
+
) -> DeactivateNodeResponse:
|
|
101
|
+
"""Deactivate a node."""
|
|
102
|
+
if not state.deactivate_node(request.node_id):
|
|
103
|
+
raise ValueError(
|
|
104
|
+
f"SuperNode with node ID {request.node_id} could not be deactivated."
|
|
105
|
+
)
|
|
106
|
+
return DeactivateNodeResponse()
|
|
82
107
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
108
|
+
|
|
109
|
+
def unregister_node(
|
|
110
|
+
request: UnregisterNodeFleetRequest,
|
|
111
|
+
state: LinkState,
|
|
112
|
+
) -> UnregisterNodeFleetResponse:
|
|
113
|
+
"""Unregister a node (Fleet API only)."""
|
|
114
|
+
state.delete_node(NOOP_FLWR_AID, request.node_id)
|
|
115
|
+
return UnregisterNodeFleetResponse()
|
|
86
116
|
|
|
87
117
|
|
|
88
118
|
def send_node_heartbeat(
|
|
@@ -90,6 +120,7 @@ def send_node_heartbeat(
|
|
|
90
120
|
state: LinkState, # pylint: disable=unused-argument
|
|
91
121
|
) -> SendNodeHeartbeatResponse:
|
|
92
122
|
"""."""
|
|
123
|
+
_validate_heartbeat_interval(request.heartbeat_interval)
|
|
93
124
|
res = state.acknowledge_node_heartbeat(
|
|
94
125
|
request.node.node_id, request.heartbeat_interval
|
|
95
126
|
)
|
|
@@ -137,10 +168,11 @@ def push_messages(
|
|
|
137
168
|
"""Push Messages handler."""
|
|
138
169
|
# Convert Message from proto
|
|
139
170
|
msg = message_from_proto(message_proto=request.messages_list[0])
|
|
171
|
+
run_id = msg.metadata.run_id
|
|
140
172
|
|
|
141
173
|
# Abort if the run is not running
|
|
142
174
|
abort_msg = check_abort(
|
|
143
|
-
|
|
175
|
+
run_id,
|
|
144
176
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
145
177
|
state,
|
|
146
178
|
store,
|
|
@@ -148,11 +180,12 @@ def push_messages(
|
|
|
148
180
|
if abort_msg:
|
|
149
181
|
raise InvalidRunStatusException(abort_msg)
|
|
150
182
|
|
|
151
|
-
# Store Message in State
|
|
152
|
-
message_id: Optional[str] = state.store_message_res(message=msg)
|
|
153
|
-
|
|
154
183
|
# Store Message object to descendants mapping and preregister objects
|
|
155
|
-
objects_to_push =
|
|
184
|
+
objects_to_push: set[str] = set()
|
|
185
|
+
for object_tree in request.message_object_trees:
|
|
186
|
+
objects_to_push |= set(store.preregister(run_id, object_tree))
|
|
187
|
+
# Store Message in State
|
|
188
|
+
message_id: str | None = state.store_message_res(message=msg)
|
|
156
189
|
|
|
157
190
|
# Build response
|
|
158
191
|
response = PushMessagesResponse(
|
|
@@ -167,10 +200,8 @@ def get_run(
|
|
|
167
200
|
request: GetRunRequest, state: LinkState, store: ObjectStore
|
|
168
201
|
) -> GetRunResponse:
|
|
169
202
|
"""Get run information."""
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
if run is None:
|
|
173
|
-
return GetRunResponse()
|
|
203
|
+
# Validate that the requesting SuperNode is part of the federation
|
|
204
|
+
run = _validate_node_in_federation(state, request.node.node_id, request.run_id)
|
|
174
205
|
|
|
175
206
|
# Abort if the run is not running
|
|
176
207
|
abort_msg = check_abort(
|
|
@@ -182,21 +213,16 @@ def get_run(
|
|
|
182
213
|
if abort_msg:
|
|
183
214
|
raise InvalidRunStatusException(abort_msg)
|
|
184
215
|
|
|
185
|
-
return GetRunResponse(
|
|
186
|
-
run=Run(
|
|
187
|
-
run_id=run.run_id,
|
|
188
|
-
fab_id=run.fab_id,
|
|
189
|
-
fab_version=run.fab_version,
|
|
190
|
-
override_config=user_config_to_proto(run.override_config),
|
|
191
|
-
fab_hash=run.fab_hash,
|
|
192
|
-
)
|
|
193
|
-
)
|
|
216
|
+
return GetRunResponse(run=run_to_proto(run))
|
|
194
217
|
|
|
195
218
|
|
|
196
219
|
def get_fab(
|
|
197
220
|
request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
|
|
198
221
|
) -> GetFabResponse:
|
|
199
222
|
"""Get FAB."""
|
|
223
|
+
# Validate that the requesting SuperNode is part of the federation
|
|
224
|
+
_validate_node_in_federation(state, request.node.node_id, request.run_id)
|
|
225
|
+
|
|
200
226
|
# Abort if the run is not running
|
|
201
227
|
abort_msg = check_abort(
|
|
202
228
|
request.run_id,
|
|
@@ -208,7 +234,7 @@ def get_fab(
|
|
|
208
234
|
raise InvalidRunStatusException(abort_msg)
|
|
209
235
|
|
|
210
236
|
if result := ffs.get(request.hash_str):
|
|
211
|
-
fab = Fab(request.hash_str, result[0])
|
|
237
|
+
fab = Fab(request.hash_str, result[0], result[1])
|
|
212
238
|
return GetFabResponse(fab=fab_to_proto(fab))
|
|
213
239
|
|
|
214
240
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
@@ -284,3 +310,28 @@ def confirm_message_received(
|
|
|
284
310
|
store.delete(request.message_object_id)
|
|
285
311
|
|
|
286
312
|
return ConfirmMessageReceivedResponse()
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _validate_heartbeat_interval(interval: float) -> None:
|
|
316
|
+
"""Raise if heartbeat interval is out of bounds."""
|
|
317
|
+
if not HEARTBEAT_MIN_INTERVAL <= interval <= HEARTBEAT_MAX_INTERVAL:
|
|
318
|
+
raise InvalidHeartbeatIntervalError(
|
|
319
|
+
f"Heartbeat interval {interval} is out of bounds "
|
|
320
|
+
f"[{HEARTBEAT_MIN_INTERVAL}, {HEARTBEAT_MAX_INTERVAL}]."
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _validate_node_in_federation(
|
|
325
|
+
state: LinkState,
|
|
326
|
+
node_id: int,
|
|
327
|
+
run_id: int,
|
|
328
|
+
) -> Run:
|
|
329
|
+
"""Raise if the requesting SuperNode is not part of the federation the run belongs
|
|
330
|
+
to."""
|
|
331
|
+
run = state.get_run(run_id)
|
|
332
|
+
if not run:
|
|
333
|
+
raise ValueError(f"Run ID not found: {run_id}")
|
|
334
|
+
|
|
335
|
+
if not state.federation_manager.has_node(node_id, run.federation):
|
|
336
|
+
raise ValueError(f"SuperNode is not part of the federation '{run.federation}'.")
|
|
337
|
+
return run
|