flwr 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/build.py +15 -5
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +23 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
- flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +53 -50
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +29 -11
- flwr/client/grpc_adapter_client/connection.py +11 -4
- flwr/client/grpc_rere_client/connection.py +93 -129
- flwr/client/rest_client/connection.py +134 -164
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +26 -0
- flwr/clientapp/mod/centraldp_mods.py +132 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -5
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +42 -8
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +1 -1
- flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
- flwr/common/inflatable_utils.py +191 -24
- flwr/common/logger.py +1 -1
- flwr/common/record/array.py +101 -22
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/serde.py +0 -28
- flwr/common/telemetry.py +4 -0
- flwr/compat/client/app.py +14 -31
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +51 -0
- flwr/proto/appio_pb2.pyi +195 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +4 -19
- flwr/proto/clientappio_pb2.pyi +0 -125
- flwr/proto/clientappio_pb2_grpc.py +269 -29
- flwr/proto/clientappio_pb2_grpc.pyi +114 -21
- flwr/proto/control_pb2.py +62 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
- flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
- flwr/proto/fleet_pb2.py +12 -20
- flwr/proto/fleet_pb2.pyi +6 -36
- flwr/proto/serverappio_pb2.py +8 -31
- flwr/proto/serverappio_pb2.pyi +0 -152
- flwr/proto/serverappio_pb2_grpc.py +107 -38
- flwr/proto/serverappio_pb2_grpc.pyi +47 -20
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +130 -153
- flwr/server/fleet_event_log_interceptor.py +4 -0
- flwr/server/grid/grpc_grid.py +94 -54
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +165 -144
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +2 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +95 -48
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +98 -22
- flwr/server/superlink/utils.py +0 -35
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/dp_fixed_clipping.py +352 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +38 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
- flwr/serverapp/strategy/fedadagrad.py +162 -0
- flwr/serverapp/strategy/fedadam.py +181 -0
- flwr/serverapp/strategy/fedavg.py +295 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedyogi.py +173 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +251 -0
- flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
- flwr/simulation/app.py +159 -154
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/app_utils.py +58 -0
- flwr/supercore/cli/__init__.py +22 -0
- flwr/supercore/cli/flower_superexec.py +141 -0
- flwr/supercore/corestate/__init__.py +22 -0
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +25 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/in_memory_object_store.py +31 -31
- flwr/supercore/object_store/object_store.py +20 -42
- flwr/supercore/object_store/utils.py +43 -0
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +53 -0
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +71 -0
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +185 -0
- flwr/supercore/utils.py +32 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +9 -5
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +39 -28
- flwr/superlink/servicer/control/control_license_interceptor.py +82 -0
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +79 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +18 -10
- flwr/supernode/cli/flower_supernode.py +3 -7
- flwr/supernode/cli/flwr_clientapp.py +20 -16
- flwr/supernode/nodestate/in_memory_nodestate.py +13 -4
- flwr/supernode/nodestate/nodestate.py +3 -44
- flwr/supernode/runtime/run_clientapp.py +129 -115
- flwr/supernode/servicer/clientappio/__init__.py +1 -3
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +217 -165
- flwr/supernode/start_client_internal.py +205 -148
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/METADATA +5 -3
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/RECORD +161 -117
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
- flwr/common/inflatable_rest_utils.py +0 -99
- flwr/proto/exec_pb2.py +0 -62
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -192
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -130
- /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
- /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
|
@@ -81,12 +81,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
81
81
|
metadata sent by the node. Continue RPC call if node is authenticated, else,
|
|
82
82
|
terminate RPC call by setting context to abort.
|
|
83
83
|
"""
|
|
84
|
-
#
|
|
84
|
+
# Only apply to Fleet service
|
|
85
85
|
if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
|
|
86
|
-
return
|
|
87
|
-
"This request should be sent to a different service.",
|
|
88
|
-
grpc.StatusCode.FAILED_PRECONDITION,
|
|
89
|
-
)
|
|
86
|
+
return continuation(handler_call_details)
|
|
90
87
|
|
|
91
88
|
state = self.state_factory.state()
|
|
92
89
|
metadata_dict = dict(handler_call_details.invocation_metadata)
|
|
@@ -46,7 +46,6 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
|
46
46
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
47
47
|
ConfirmMessageReceivedRequest,
|
|
48
48
|
ConfirmMessageReceivedResponse,
|
|
49
|
-
ObjectIDs,
|
|
50
49
|
PullObjectRequest,
|
|
51
50
|
PullObjectResponse,
|
|
52
51
|
PushObjectRequest,
|
|
@@ -58,12 +57,11 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
58
57
|
GetRunResponse,
|
|
59
58
|
Run,
|
|
60
59
|
)
|
|
61
|
-
from flwr.server.superlink.ffs.ffs import Ffs
|
|
62
60
|
from flwr.server.superlink.linkstate import LinkState
|
|
63
61
|
from flwr.server.superlink.utils import check_abort
|
|
62
|
+
from flwr.supercore.ffs import Ffs
|
|
64
63
|
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
|
|
65
|
-
|
|
66
|
-
from ...utils import store_mapping_and_register_objects
|
|
64
|
+
from flwr.supercore.object_store.utils import store_mapping_and_register_objects
|
|
67
65
|
|
|
68
66
|
|
|
69
67
|
def create_node(
|
|
@@ -113,25 +111,22 @@ def pull_messages(
|
|
|
113
111
|
|
|
114
112
|
# Convert to Messages
|
|
115
113
|
msg_proto = []
|
|
116
|
-
|
|
114
|
+
trees = []
|
|
117
115
|
for msg in message_list:
|
|
118
116
|
try:
|
|
119
|
-
|
|
120
|
-
|
|
117
|
+
# Retrieve Message object tree from ObjectStore
|
|
121
118
|
msg_object_id = msg.metadata.message_id
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
)
|
|
119
|
+
obj_tree = store.get_object_tree(msg_object_id)
|
|
120
|
+
|
|
121
|
+
# Add Message and its object tree to the response
|
|
122
|
+
msg_proto.append(message_to_proto(msg))
|
|
123
|
+
trees.append(obj_tree)
|
|
127
124
|
except NoObjectInStoreError as e:
|
|
128
125
|
log(ERROR, e.message)
|
|
129
126
|
# Delete message ins from state
|
|
130
127
|
state.delete_messages(message_ins_ids={msg_object_id})
|
|
131
128
|
|
|
132
|
-
return PullMessagesResponse(
|
|
133
|
-
messages_list=msg_proto, objects_to_pull=objects_to_pull
|
|
134
|
-
)
|
|
129
|
+
return PullMessagesResponse(messages_list=msg_proto, message_object_trees=trees)
|
|
135
130
|
|
|
136
131
|
|
|
137
132
|
def push_messages(
|
|
@@ -287,6 +282,5 @@ def confirm_message_received(
|
|
|
287
282
|
|
|
288
283
|
# Delete the message object
|
|
289
284
|
store.delete(request.message_object_id)
|
|
290
|
-
store.delete_message_descendant_ids(request.message_object_id)
|
|
291
285
|
|
|
292
286
|
return ConfirmMessageReceivedResponse()
|
|
@@ -47,10 +47,9 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
|
47
47
|
PushObjectResponse,
|
|
48
48
|
)
|
|
49
49
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
50
|
-
from flwr.server.superlink.ffs.ffs import Ffs
|
|
51
|
-
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
52
50
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
53
51
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
52
|
+
from flwr.supercore.ffs import Ffs, FfsFactory
|
|
54
53
|
from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
|
|
55
54
|
|
|
56
55
|
try:
|
|
@@ -161,6 +161,7 @@ class RayBackend(Backend):
|
|
|
161
161
|
"Call the backend's `build()` method before processing messages."
|
|
162
162
|
)
|
|
163
163
|
|
|
164
|
+
future = None
|
|
164
165
|
try:
|
|
165
166
|
# Submit a task to the pool
|
|
166
167
|
future = self.pool.submit(
|
|
@@ -183,7 +184,8 @@ class RayBackend(Backend):
|
|
|
183
184
|
self.__class__.__name__,
|
|
184
185
|
)
|
|
185
186
|
# add actor back into pool
|
|
186
|
-
|
|
187
|
+
if future is not None:
|
|
188
|
+
self.pool.add_actor_back_to_pool(future)
|
|
187
189
|
raise ex
|
|
188
190
|
|
|
189
191
|
def terminate(self) -> None:
|
|
@@ -23,7 +23,6 @@ from concurrent.futures import ThreadPoolExecutor
|
|
|
23
23
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
24
24
|
from pathlib import Path
|
|
25
25
|
from queue import Empty, Queue
|
|
26
|
-
from time import sleep
|
|
27
26
|
from typing import Callable, Optional
|
|
28
27
|
from uuid import uuid4
|
|
29
28
|
|
|
@@ -153,7 +152,7 @@ def add_messages_to_queue(
|
|
|
153
152
|
message_ins_list = state.get_message_ins(node_id=node_id, limit=1)
|
|
154
153
|
for msg in message_ins_list:
|
|
155
154
|
queue.put(msg)
|
|
156
|
-
|
|
155
|
+
f_stop.wait(0.1)
|
|
157
156
|
|
|
158
157
|
|
|
159
158
|
def put_message_into_state(
|
|
@@ -182,6 +181,7 @@ def run_api(
|
|
|
182
181
|
messageins_queue: Queue[Message] = Queue()
|
|
183
182
|
messageres_queue: Queue[Message] = Queue()
|
|
184
183
|
|
|
184
|
+
backend = None
|
|
185
185
|
try:
|
|
186
186
|
|
|
187
187
|
# Instantiate backend
|
|
@@ -236,16 +236,16 @@ def run_api(
|
|
|
236
236
|
log(ERROR, traceback.format_exc())
|
|
237
237
|
log(WARN, "Stopping Simulation Engine.")
|
|
238
238
|
|
|
239
|
-
# Manually trigger stopping event
|
|
240
|
-
f_stop.set()
|
|
241
|
-
|
|
242
239
|
# Raise exception
|
|
243
240
|
raise RuntimeError("Simulation Engine crashed.") from ex
|
|
244
241
|
|
|
245
242
|
finally:
|
|
243
|
+
# Manually trigger stopping event
|
|
244
|
+
f_stop.set()
|
|
246
245
|
|
|
247
246
|
# Terminate backend
|
|
248
|
-
backend
|
|
247
|
+
if backend is not None:
|
|
248
|
+
backend.terminate()
|
|
249
249
|
|
|
250
250
|
|
|
251
251
|
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""In-memory LinkState implementation."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import secrets
|
|
18
19
|
import threading
|
|
19
20
|
import time
|
|
20
21
|
from bisect import bisect_right
|
|
@@ -25,6 +26,7 @@ from typing import Optional
|
|
|
25
26
|
|
|
26
27
|
from flwr.common import Context, Message, log, now
|
|
27
28
|
from flwr.common.constant import (
|
|
29
|
+
FLWR_APP_TOKEN_LENGTH,
|
|
28
30
|
HEARTBEAT_MAX_INTERVAL,
|
|
29
31
|
HEARTBEAT_PATIENCE,
|
|
30
32
|
MESSAGE_TTL_TOLERANCE,
|
|
@@ -80,6 +82,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
80
82
|
self.message_res_store: dict[str, Message] = {}
|
|
81
83
|
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
|
82
84
|
|
|
85
|
+
# Store run ID to token mapping and token to run ID mapping
|
|
86
|
+
self.token_store: dict[int, str] = {}
|
|
87
|
+
self.token_to_run_id: dict[str, int] = {}
|
|
88
|
+
self.lock_token_store = threading.Lock()
|
|
89
|
+
|
|
83
90
|
# Map flwr_aid to run_ids for O(1) reverse index lookup
|
|
84
91
|
self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
|
|
85
92
|
|
|
@@ -678,3 +685,30 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
678
685
|
index = bisect_right(run.logs, (after_timestamp, ""))
|
|
679
686
|
latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
|
|
680
687
|
return "".join(log for _, log in run.logs[index:]), latest_timestamp
|
|
688
|
+
|
|
689
|
+
def create_token(self, run_id: int) -> Optional[str]:
|
|
690
|
+
"""Create a token for the given run ID."""
|
|
691
|
+
token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
|
|
692
|
+
with self.lock_token_store:
|
|
693
|
+
if run_id in self.token_store:
|
|
694
|
+
return None # Token already created for this run ID
|
|
695
|
+
self.token_store[run_id] = token
|
|
696
|
+
self.token_to_run_id[token] = run_id
|
|
697
|
+
return token
|
|
698
|
+
|
|
699
|
+
def verify_token(self, run_id: int, token: str) -> bool:
|
|
700
|
+
"""Verify a token for the given run ID."""
|
|
701
|
+
with self.lock_token_store:
|
|
702
|
+
return self.token_store.get(run_id) == token
|
|
703
|
+
|
|
704
|
+
def delete_token(self, run_id: int) -> None:
|
|
705
|
+
"""Delete the token for the given run ID."""
|
|
706
|
+
with self.lock_token_store:
|
|
707
|
+
token = self.token_store.pop(run_id, None)
|
|
708
|
+
if token is not None:
|
|
709
|
+
self.token_to_run_id.pop(token, None)
|
|
710
|
+
|
|
711
|
+
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
712
|
+
"""Get the run ID associated with a given token."""
|
|
713
|
+
with self.lock_token_store:
|
|
714
|
+
return self.token_to_run_id.get(token)
|
|
@@ -21,9 +21,10 @@ from typing import Optional
|
|
|
21
21
|
from flwr.common import Context, Message
|
|
22
22
|
from flwr.common.record import ConfigRecord
|
|
23
23
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
24
|
+
from flwr.supercore.corestate import CoreState
|
|
24
25
|
|
|
25
26
|
|
|
26
|
-
class LinkState(
|
|
27
|
+
class LinkState(CoreState): # pylint: disable=R0904
|
|
27
28
|
"""Abstract LinkState."""
|
|
28
29
|
|
|
29
30
|
@abc.abstractmethod
|
|
@@ -19,6 +19,7 @@
|
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
21
|
import re
|
|
22
|
+
import secrets
|
|
22
23
|
import sqlite3
|
|
23
24
|
import time
|
|
24
25
|
from collections.abc import Sequence
|
|
@@ -27,6 +28,7 @@ from typing import Any, Optional, Union, cast
|
|
|
27
28
|
|
|
28
29
|
from flwr.common import Context, Message, Metadata, log, now
|
|
29
30
|
from flwr.common.constant import (
|
|
31
|
+
FLWR_APP_TOKEN_LENGTH,
|
|
30
32
|
HEARTBEAT_MAX_INTERVAL,
|
|
31
33
|
HEARTBEAT_PATIENCE,
|
|
32
34
|
MESSAGE_TTL_TOLERANCE,
|
|
@@ -163,6 +165,13 @@ CREATE TABLE IF NOT EXISTS message_res(
|
|
|
163
165
|
);
|
|
164
166
|
"""
|
|
165
167
|
|
|
168
|
+
SQL_CREATE_TABLE_TOKEN_STORE = """
|
|
169
|
+
CREATE TABLE IF NOT EXISTS token_store (
|
|
170
|
+
run_id INTEGER PRIMARY KEY,
|
|
171
|
+
token TEXT UNIQUE NOT NULL
|
|
172
|
+
);
|
|
173
|
+
"""
|
|
174
|
+
|
|
166
175
|
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
|
|
167
176
|
|
|
168
177
|
|
|
@@ -212,6 +221,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
212
221
|
cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
|
|
213
222
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
214
223
|
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
224
|
+
cur.execute(SQL_CREATE_TABLE_TOKEN_STORE)
|
|
215
225
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
216
226
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
217
227
|
return res.fetchall()
|
|
@@ -1138,6 +1148,41 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1138
1148
|
|
|
1139
1149
|
return message_ins
|
|
1140
1150
|
|
|
1151
|
+
def create_token(self, run_id: int) -> Optional[str]:
|
|
1152
|
+
"""Create a token for the given run ID."""
|
|
1153
|
+
token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
|
|
1154
|
+
query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
|
|
1155
|
+
data = {"run_id": convert_uint64_to_sint64(run_id), "token": token}
|
|
1156
|
+
try:
|
|
1157
|
+
self.query(query, data)
|
|
1158
|
+
except sqlite3.IntegrityError:
|
|
1159
|
+
return None # Token already created for this run ID
|
|
1160
|
+
return token
|
|
1161
|
+
|
|
1162
|
+
def verify_token(self, run_id: int, token: str) -> bool:
|
|
1163
|
+
"""Verify a token for the given run ID."""
|
|
1164
|
+
query = "SELECT token FROM token_store WHERE run_id = :run_id;"
|
|
1165
|
+
data = {"run_id": convert_uint64_to_sint64(run_id)}
|
|
1166
|
+
rows = self.query(query, data)
|
|
1167
|
+
if not rows:
|
|
1168
|
+
return False
|
|
1169
|
+
return cast(str, rows[0]["token"]) == token
|
|
1170
|
+
|
|
1171
|
+
def delete_token(self, run_id: int) -> None:
|
|
1172
|
+
"""Delete the token for the given run ID."""
|
|
1173
|
+
query = "DELETE FROM token_store WHERE run_id = :run_id;"
|
|
1174
|
+
data = {"run_id": convert_uint64_to_sint64(run_id)}
|
|
1175
|
+
self.query(query, data)
|
|
1176
|
+
|
|
1177
|
+
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
1178
|
+
"""Get the run ID associated with a given token."""
|
|
1179
|
+
query = "SELECT run_id FROM token_store WHERE token = :token;"
|
|
1180
|
+
data = {"token": token}
|
|
1181
|
+
rows = self.query(query, data)
|
|
1182
|
+
if not rows:
|
|
1183
|
+
return None
|
|
1184
|
+
return convert_sint64_to_uint64(rows[0]["run_id"])
|
|
1185
|
+
|
|
1141
1186
|
|
|
1142
1187
|
def dict_factory(
|
|
1143
1188
|
cursor: sqlite3.Cursor,
|
|
@@ -26,8 +26,8 @@ from flwr.common.logger import log
|
|
|
26
26
|
from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
|
|
27
27
|
add_ServerAppIoServicer_to_server,
|
|
28
28
|
)
|
|
29
|
-
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
30
29
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
30
|
+
from flwr.supercore.ffs import FfsFactory
|
|
31
31
|
from flwr.supercore.object_store import ObjectStoreFactory
|
|
32
32
|
|
|
33
33
|
from .serverappio_servicer import ServerAppIoServicer
|
|
@@ -58,7 +58,7 @@ def run_serverappio_api_grpc(
|
|
|
58
58
|
certificates=certificates,
|
|
59
59
|
)
|
|
60
60
|
|
|
61
|
-
log(INFO, "Flower
|
|
61
|
+
log(INFO, "Flower Deployment Runtime: Starting ServerAppIo API on %s", address)
|
|
62
62
|
serverappio_grpc_server.start()
|
|
63
63
|
|
|
64
64
|
return serverappio_grpc_server
|
|
@@ -42,6 +42,20 @@ from flwr.common.serde import (
|
|
|
42
42
|
)
|
|
43
43
|
from flwr.common.typing import Fab, RunStatus
|
|
44
44
|
from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
|
|
45
|
+
from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
|
|
46
|
+
ListAppsToLaunchRequest,
|
|
47
|
+
ListAppsToLaunchResponse,
|
|
48
|
+
PullAppInputsRequest,
|
|
49
|
+
PullAppInputsResponse,
|
|
50
|
+
PullAppMessagesRequest,
|
|
51
|
+
PullAppMessagesResponse,
|
|
52
|
+
PushAppMessagesRequest,
|
|
53
|
+
PushAppMessagesResponse,
|
|
54
|
+
PushAppOutputsRequest,
|
|
55
|
+
PushAppOutputsResponse,
|
|
56
|
+
RequestTokenRequest,
|
|
57
|
+
RequestTokenResponse,
|
|
58
|
+
)
|
|
45
59
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
46
60
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
47
61
|
SendAppHeartbeatRequest,
|
|
@@ -54,7 +68,6 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
|
54
68
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
55
69
|
ConfirmMessageReceivedRequest,
|
|
56
70
|
ConfirmMessageReceivedResponse,
|
|
57
|
-
ObjectIDs,
|
|
58
71
|
PullObjectRequest,
|
|
59
72
|
PullObjectResponse,
|
|
60
73
|
PushObjectRequest,
|
|
@@ -72,23 +85,13 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
72
85
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
73
86
|
GetNodesRequest,
|
|
74
87
|
GetNodesResponse,
|
|
75
|
-
PullResMessagesRequest,
|
|
76
|
-
PullResMessagesResponse,
|
|
77
|
-
PullServerAppInputsRequest,
|
|
78
|
-
PullServerAppInputsResponse,
|
|
79
|
-
PushInsMessagesRequest,
|
|
80
|
-
PushInsMessagesResponse,
|
|
81
|
-
PushServerAppOutputsRequest,
|
|
82
|
-
PushServerAppOutputsResponse,
|
|
83
88
|
)
|
|
84
|
-
from flwr.server.superlink.ffs.ffs import Ffs
|
|
85
|
-
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
86
89
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
87
90
|
from flwr.server.superlink.utils import abort_if
|
|
88
91
|
from flwr.server.utils.validator import validate_message
|
|
92
|
+
from flwr.supercore.ffs import Ffs, FfsFactory
|
|
89
93
|
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
|
|
90
|
-
|
|
91
|
-
from ..utils import store_mapping_and_register_objects
|
|
94
|
+
from flwr.supercore.object_store.utils import store_mapping_and_register_objects
|
|
92
95
|
|
|
93
96
|
|
|
94
97
|
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
@@ -105,6 +108,42 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
105
108
|
self.objectstore_factory = objectstore_factory
|
|
106
109
|
self.lock = threading.RLock()
|
|
107
110
|
|
|
111
|
+
def ListAppsToLaunch(
|
|
112
|
+
self,
|
|
113
|
+
request: ListAppsToLaunchRequest,
|
|
114
|
+
context: grpc.ServicerContext,
|
|
115
|
+
) -> ListAppsToLaunchResponse:
|
|
116
|
+
"""Get run IDs with pending messages."""
|
|
117
|
+
log(DEBUG, "ServerAppIoServicer.ListAppsToLaunch")
|
|
118
|
+
|
|
119
|
+
# Initialize state connection
|
|
120
|
+
state = self.state_factory.state()
|
|
121
|
+
|
|
122
|
+
# Get IDs of runs in pending status
|
|
123
|
+
run_ids = state.get_run_ids(flwr_aid=None)
|
|
124
|
+
pending_run_ids = []
|
|
125
|
+
for run_id, status in state.get_run_status(run_ids).items():
|
|
126
|
+
if status.status == Status.PENDING:
|
|
127
|
+
pending_run_ids.append(run_id)
|
|
128
|
+
|
|
129
|
+
# Return run IDs
|
|
130
|
+
return ListAppsToLaunchResponse(run_ids=pending_run_ids)
|
|
131
|
+
|
|
132
|
+
def RequestToken(
|
|
133
|
+
self, request: RequestTokenRequest, context: grpc.ServicerContext
|
|
134
|
+
) -> RequestTokenResponse:
|
|
135
|
+
"""Request token."""
|
|
136
|
+
log(DEBUG, "ServerAppIoServicer.RequestToken")
|
|
137
|
+
|
|
138
|
+
# Initialize state connection
|
|
139
|
+
state = self.state_factory.state()
|
|
140
|
+
|
|
141
|
+
# Attempt to create a token for the provided run ID
|
|
142
|
+
token = state.create_token(request.run_id)
|
|
143
|
+
|
|
144
|
+
# Return the token
|
|
145
|
+
return RequestTokenResponse(token=token or "")
|
|
146
|
+
|
|
108
147
|
def GetNodes(
|
|
109
148
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
|
110
149
|
) -> GetNodesResponse:
|
|
@@ -129,8 +168,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
129
168
|
return GetNodesResponse(nodes=nodes)
|
|
130
169
|
|
|
131
170
|
def PushMessages(
|
|
132
|
-
self, request:
|
|
133
|
-
) ->
|
|
171
|
+
self, request: PushAppMessagesRequest, context: grpc.ServicerContext
|
|
172
|
+
) -> PushAppMessagesResponse:
|
|
134
173
|
"""Push a set of Messages."""
|
|
135
174
|
log(DEBUG, "ServerAppIoServicer.PushMessages")
|
|
136
175
|
|
|
@@ -174,7 +213,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
174
213
|
# Store Message object to descendants mapping and preregister objects
|
|
175
214
|
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
|
176
215
|
|
|
177
|
-
return
|
|
216
|
+
return PushAppMessagesResponse(
|
|
178
217
|
message_ids=[
|
|
179
218
|
str(message_id) if message_id else "" for message_id in message_ids
|
|
180
219
|
],
|
|
@@ -182,8 +221,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
182
221
|
)
|
|
183
222
|
|
|
184
223
|
def PullMessages( # pylint: disable=R0914
|
|
185
|
-
self, request:
|
|
186
|
-
) ->
|
|
224
|
+
self, request: PullAppMessagesRequest, context: grpc.ServicerContext
|
|
225
|
+
) -> PullAppMessagesResponse:
|
|
187
226
|
"""Pull a set of Messages."""
|
|
188
227
|
log(DEBUG, "ServerAppIoServicer.PullMessages")
|
|
189
228
|
|
|
@@ -210,12 +249,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
210
249
|
if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
|
|
211
250
|
with no_object_id_recompute():
|
|
212
251
|
all_objects = get_all_nested_objects(msg_res)
|
|
213
|
-
descendants = list(all_objects.keys())[:-1]
|
|
214
|
-
message_obj_id = msg_res.metadata.message_id
|
|
215
|
-
# Store mapping
|
|
216
|
-
store.set_message_descendant_ids(
|
|
217
|
-
msg_object_id=message_obj_id, descendant_ids=descendants
|
|
218
|
-
)
|
|
219
252
|
# Preregister
|
|
220
253
|
store.preregister(request.run_id, get_object_tree(msg_res))
|
|
221
254
|
# Store objects
|
|
@@ -231,7 +264,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
231
264
|
|
|
232
265
|
# Convert Messages to proto
|
|
233
266
|
messages_list = []
|
|
234
|
-
|
|
267
|
+
trees = []
|
|
235
268
|
while messages_res:
|
|
236
269
|
msg = messages_res.pop(0)
|
|
237
270
|
|
|
@@ -242,20 +275,20 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
242
275
|
request_name="PullMessages",
|
|
243
276
|
detail="`message.metadata` has mismatched `run_id`",
|
|
244
277
|
)
|
|
245
|
-
messages_list.append(message_to_proto(msg))
|
|
246
278
|
|
|
247
279
|
try:
|
|
248
280
|
msg_object_id = msg.metadata.message_id
|
|
249
|
-
|
|
250
|
-
# Add
|
|
251
|
-
|
|
281
|
+
obj_tree = store.get_object_tree(msg_object_id)
|
|
282
|
+
# Add message and object tree to the response
|
|
283
|
+
messages_list.append(message_to_proto(msg))
|
|
284
|
+
trees.append(obj_tree)
|
|
252
285
|
except NoObjectInStoreError as e:
|
|
253
286
|
log(ERROR, e.message)
|
|
254
287
|
# Delete message ins from state
|
|
255
288
|
state.delete_messages(message_ins_ids={msg_object_id})
|
|
256
289
|
|
|
257
|
-
return
|
|
258
|
-
messages_list=messages_list,
|
|
290
|
+
return PullAppMessagesResponse(
|
|
291
|
+
messages_list=messages_list, message_object_trees=trees
|
|
259
292
|
)
|
|
260
293
|
|
|
261
294
|
def GetRun(
|
|
@@ -288,22 +321,19 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
288
321
|
|
|
289
322
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
290
323
|
|
|
291
|
-
def
|
|
292
|
-
self, request:
|
|
293
|
-
) ->
|
|
324
|
+
def PullAppInputs(
|
|
325
|
+
self, request: PullAppInputsRequest, context: grpc.ServicerContext
|
|
326
|
+
) -> PullAppInputsResponse:
|
|
294
327
|
"""Pull ServerApp process inputs."""
|
|
295
|
-
log(DEBUG, "ServerAppIoServicer.
|
|
328
|
+
log(DEBUG, "ServerAppIoServicer.PullAppInputs")
|
|
296
329
|
# Init access to LinkState
|
|
297
330
|
state = self.state_factory.state()
|
|
298
331
|
|
|
332
|
+
# Validate the token
|
|
333
|
+
run_id = self._verify_token(request.token, context)
|
|
334
|
+
|
|
299
335
|
# Lock access to LinkState, preventing obtaining the same pending run_id
|
|
300
336
|
with self.lock:
|
|
301
|
-
# Attempt getting the run_id of a pending run
|
|
302
|
-
run_id = state.get_pending_run_id()
|
|
303
|
-
# If there's no pending run, return an empty response
|
|
304
|
-
if run_id is None:
|
|
305
|
-
return PullServerAppInputsResponse()
|
|
306
|
-
|
|
307
337
|
# Init access to Ffs
|
|
308
338
|
ffs = self.ffs_factory.ffs()
|
|
309
339
|
|
|
@@ -318,7 +348,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
318
348
|
# Update run status to STARTING
|
|
319
349
|
if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
|
|
320
350
|
log(INFO, "Starting run %d", run_id)
|
|
321
|
-
return
|
|
351
|
+
return PullAppInputsResponse(
|
|
322
352
|
context=context_to_proto(serverapp_ctxt),
|
|
323
353
|
run=run_to_proto(run),
|
|
324
354
|
fab=fab_to_proto(fab),
|
|
@@ -328,11 +358,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
328
358
|
# or if the status cannot be updated to STARTING
|
|
329
359
|
raise RuntimeError(f"Failed to start run {run_id}")
|
|
330
360
|
|
|
331
|
-
def
|
|
332
|
-
self, request:
|
|
333
|
-
) ->
|
|
361
|
+
def PushAppOutputs(
|
|
362
|
+
self, request: PushAppOutputsRequest, context: grpc.ServicerContext
|
|
363
|
+
) -> PushAppOutputsResponse:
|
|
334
364
|
"""Push ServerApp process outputs."""
|
|
335
|
-
log(DEBUG, "ServerAppIoServicer.
|
|
365
|
+
log(DEBUG, "ServerAppIoServicer.PushAppOutputs")
|
|
366
|
+
|
|
367
|
+
# Validate the token
|
|
368
|
+
run_id = self._verify_token(request.token, context)
|
|
336
369
|
|
|
337
370
|
# Init state and store
|
|
338
371
|
state = self.state_factory.state()
|
|
@@ -348,7 +381,10 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
348
381
|
)
|
|
349
382
|
|
|
350
383
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
351
|
-
|
|
384
|
+
|
|
385
|
+
# Remove the token
|
|
386
|
+
state.delete_token(run_id)
|
|
387
|
+
return PushAppOutputsResponse()
|
|
352
388
|
|
|
353
389
|
def UpdateRunStatus(
|
|
354
390
|
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
|
|
@@ -512,10 +548,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
512
548
|
|
|
513
549
|
# Delete the message object
|
|
514
550
|
store.delete(request.message_object_id)
|
|
515
|
-
store.delete_message_descendant_ids(request.message_object_id)
|
|
516
551
|
|
|
517
552
|
return ConfirmMessageReceivedResponse()
|
|
518
553
|
|
|
554
|
+
def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
|
|
555
|
+
"""Verify the token and return the associated run ID."""
|
|
556
|
+
state = self.state_factory.state()
|
|
557
|
+
run_id = state.get_run_id_by_token(token)
|
|
558
|
+
if run_id is None or not state.verify_token(run_id, token):
|
|
559
|
+
context.abort(
|
|
560
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
|
561
|
+
"Invalid token.",
|
|
562
|
+
)
|
|
563
|
+
raise RuntimeError("This line should never be reached.")
|
|
564
|
+
return run_id
|
|
565
|
+
|
|
519
566
|
|
|
520
567
|
def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
|
|
521
568
|
"""Raise a `ValueError` with a detailed message if a validation error occurs."""
|
|
@@ -26,8 +26,8 @@ from flwr.common.logger import log
|
|
|
26
26
|
from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
|
|
27
27
|
add_SimulationIoServicer_to_server,
|
|
28
28
|
)
|
|
29
|
-
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
30
29
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
30
|
+
from flwr.supercore.ffs import FfsFactory
|
|
31
31
|
|
|
32
32
|
from .simulationio_servicer import SimulationIoServicer
|
|
33
33
|
|