flwr 1.13.0__py3-none-any.whl → 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +5 -0
- flwr/cli/build.py +1 -37
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +19 -2
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +2 -19
- flwr/cli/log.py +18 -36
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +81 -0
- flwr/cli/ls.py +205 -106
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +25 -14
- flwr/cli/new/templates/app/.gitignore.tpl +3 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +89 -39
- flwr/cli/stop.py +130 -0
- flwr/cli/utils.py +172 -8
- flwr/client/app.py +14 -3
- flwr/client/client.py +1 -32
- flwr/client/clientapp/app.py +4 -8
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +13 -7
- flwr/client/message_handler/message_handler.py +1 -2
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/numpy_client.py +0 -44
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +2 -2
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +3 -1
- flwr/common/constant.py +17 -1
- flwr/common/logger.py +40 -0
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +15 -4
- flwr/common/typing.py +12 -0
- flwr/common/version.py +1 -0
- flwr/proto/exec_pb2.py +38 -14
- flwr/proto/exec_pb2.pyi +107 -2
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +62 -7
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +11 -63
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/run_serverapp.py +8 -9
- flwr/server/serverapp/app.py +25 -10
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +82 -23
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +12 -11
- flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +21 -30
- flwr/server/superlink/linkstate/linkstate.py +17 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -49
- flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +59 -52
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +36 -22
- flwr/simulation/simulationio_connection.py +3 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +19 -1
- flwr/superexec/exec_servicer.py +76 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/METADATA +8 -8
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/RECORD +112 -112
- flwr/proto/common_pb2.py +0 -36
- flwr/proto/common_pb2.pyi +0 -121
- flwr/proto/common_pb2_grpc.py +0 -4
- flwr/proto/common_pb2_grpc.pyi +0 -4
- flwr/proto/control_pb2.py +0 -27
- flwr/proto/control_pb2.pyi +0 -7
- flwr/proto/control_pb2_grpc.py +0 -135
- flwr/proto/control_pb2_grpc.pyi +0 -53
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/LICENSE +0 -0
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/WHEEL +0 -0
- {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/entry_points.txt +0 -0
|
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
|
|
|
20
20
|
import grpc
|
|
21
21
|
|
|
22
22
|
from flwr.common.logger import log
|
|
23
|
+
from flwr.common.typing import InvalidRunStatusException
|
|
23
24
|
from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
|
|
24
25
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
25
26
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -38,6 +39,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
|
|
|
38
39
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
39
40
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
40
41
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
42
|
+
from flwr.server.superlink.utils import abort_grpc_context
|
|
41
43
|
|
|
42
44
|
|
|
43
45
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
@@ -105,27 +107,45 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
105
107
|
)
|
|
106
108
|
else:
|
|
107
109
|
log(INFO, "[Fleet.PushTaskRes] No task results to push")
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
110
|
+
|
|
111
|
+
try:
|
|
112
|
+
res = message_handler.push_task_res(
|
|
113
|
+
request=request,
|
|
114
|
+
state=self.state_factory.state(),
|
|
115
|
+
)
|
|
116
|
+
except InvalidRunStatusException as e:
|
|
117
|
+
abort_grpc_context(e.message, context)
|
|
118
|
+
|
|
119
|
+
return res
|
|
112
120
|
|
|
113
121
|
def GetRun(
|
|
114
122
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
115
123
|
) -> GetRunResponse:
|
|
116
124
|
"""Get run information."""
|
|
117
125
|
log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
126
|
+
|
|
127
|
+
try:
|
|
128
|
+
res = message_handler.get_run(
|
|
129
|
+
request=request,
|
|
130
|
+
state=self.state_factory.state(),
|
|
131
|
+
)
|
|
132
|
+
except InvalidRunStatusException as e:
|
|
133
|
+
abort_grpc_context(e.message, context)
|
|
134
|
+
|
|
135
|
+
return res
|
|
122
136
|
|
|
123
137
|
def GetFab(
|
|
124
138
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
125
139
|
) -> GetFabResponse:
|
|
126
140
|
"""Get FAB."""
|
|
127
141
|
log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
142
|
+
try:
|
|
143
|
+
res = message_handler.get_fab(
|
|
144
|
+
request=request,
|
|
145
|
+
ffs=self.ffs_factory.ffs(),
|
|
146
|
+
state=self.state_factory.state(),
|
|
147
|
+
)
|
|
148
|
+
except InvalidRunStatusException as e:
|
|
149
|
+
abort_grpc_context(e.message, context)
|
|
150
|
+
|
|
151
|
+
return res
|
|
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
45
45
|
)
|
|
46
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
47
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
48
|
-
from flwr.server.superlink.linkstate import
|
|
48
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
49
49
|
|
|
50
50
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
51
51
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -84,15 +84,16 @@ def _get_value_from_tuples(
|
|
|
84
84
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
85
85
|
"""Server interceptor for node authentication."""
|
|
86
86
|
|
|
87
|
-
def __init__(self,
|
|
88
|
-
self.
|
|
87
|
+
def __init__(self, state_factory: LinkStateFactory):
|
|
88
|
+
self.state_factory = state_factory
|
|
89
|
+
state = self.state_factory.state()
|
|
89
90
|
|
|
90
91
|
self.node_public_keys = state.get_node_public_keys()
|
|
91
92
|
if len(self.node_public_keys) == 0:
|
|
92
93
|
log(WARNING, "Authentication enabled, but no known public keys configured")
|
|
93
94
|
|
|
94
|
-
private_key =
|
|
95
|
-
public_key =
|
|
95
|
+
private_key = state.get_server_private_key()
|
|
96
|
+
public_key = state.get_server_public_key()
|
|
96
97
|
|
|
97
98
|
if private_key is None or public_key is None:
|
|
98
99
|
raise ValueError("Error loading authentication keys")
|
|
@@ -154,7 +155,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
154
155
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
155
156
|
|
|
156
157
|
# Verify node_id
|
|
157
|
-
node_id = self.state.get_node_id(node_public_key_bytes)
|
|
158
|
+
node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
|
|
158
159
|
|
|
159
160
|
if not self._verify_node_id(node_id, request):
|
|
160
161
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
@@ -186,7 +187,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
186
187
|
return False
|
|
187
188
|
return request.task_res_list[0].task.producer.node_id == node_id
|
|
188
189
|
if isinstance(request, GetRunRequest):
|
|
189
|
-
return node_id in self.state.get_nodes(request.run_id)
|
|
190
|
+
return node_id in self.state_factory.state().get_nodes(request.run_id)
|
|
190
191
|
return request.node.node_id == node_id
|
|
191
192
|
|
|
192
193
|
def _verify_hmac(
|
|
@@ -210,17 +211,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
210
211
|
),
|
|
211
212
|
)
|
|
212
213
|
)
|
|
213
|
-
|
|
214
|
-
node_id =
|
|
214
|
+
state = self.state_factory.state()
|
|
215
|
+
node_id = state.get_node_id(public_key_bytes)
|
|
215
216
|
|
|
216
217
|
# Handle `CreateNode` here instead of calling the default method handler
|
|
217
218
|
# Return previously assigned `node_id` for the provided `public_key`
|
|
218
219
|
if node_id is not None:
|
|
219
|
-
|
|
220
|
+
state.acknowledge_ping(node_id, request.ping_interval)
|
|
220
221
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
221
222
|
|
|
222
223
|
# No `node_id` exists for the provided `public_key`
|
|
223
224
|
# Handle `CreateNode` here instead of calling the default method handler
|
|
224
225
|
# Note: the innermost `CreateNode` method will never be called
|
|
225
|
-
node_id =
|
|
226
|
+
node_id = state.create_node(request.ping_interval, public_key_bytes)
|
|
226
227
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
@@ -19,8 +19,9 @@ import time
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
+
from flwr.common.constant import Status
|
|
22
23
|
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
23
|
-
from flwr.common.typing import Fab
|
|
24
|
+
from flwr.common.typing import Fab, InvalidRunStatusException
|
|
24
25
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
25
26
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
26
27
|
CreateNodeRequest,
|
|
@@ -44,6 +45,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
44
45
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
46
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
46
47
|
from flwr.server.superlink.linkstate import LinkState
|
|
48
|
+
from flwr.server.superlink.utils import check_abort
|
|
47
49
|
|
|
48
50
|
|
|
49
51
|
def create_node(
|
|
@@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
|
98
100
|
task_res: TaskRes = request.task_res_list[0]
|
|
99
101
|
# pylint: enable=no-member
|
|
100
102
|
|
|
103
|
+
# Abort if the run is not running
|
|
104
|
+
abort_msg = check_abort(
|
|
105
|
+
task_res.run_id,
|
|
106
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
107
|
+
state,
|
|
108
|
+
)
|
|
109
|
+
if abort_msg:
|
|
110
|
+
raise InvalidRunStatusException(abort_msg)
|
|
111
|
+
|
|
101
112
|
# Set pushed_at (timestamp in seconds)
|
|
102
113
|
task_res.task.pushed_at = time.time()
|
|
103
114
|
|
|
@@ -112,15 +123,22 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
|
112
123
|
return response
|
|
113
124
|
|
|
114
125
|
|
|
115
|
-
def get_run(
|
|
116
|
-
request: GetRunRequest, state: LinkState # pylint: disable=W0613
|
|
117
|
-
) -> GetRunResponse:
|
|
126
|
+
def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
118
127
|
"""Get run information."""
|
|
119
128
|
run = state.get_run(request.run_id)
|
|
120
129
|
|
|
121
130
|
if run is None:
|
|
122
131
|
return GetRunResponse()
|
|
123
132
|
|
|
133
|
+
# Abort if the run is not running
|
|
134
|
+
abort_msg = check_abort(
|
|
135
|
+
request.run_id,
|
|
136
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
137
|
+
state,
|
|
138
|
+
)
|
|
139
|
+
if abort_msg:
|
|
140
|
+
raise InvalidRunStatusException(abort_msg)
|
|
141
|
+
|
|
124
142
|
return GetRunResponse(
|
|
125
143
|
run=Run(
|
|
126
144
|
run_id=run.run_id,
|
|
@@ -133,9 +151,18 @@ def get_run(
|
|
|
133
151
|
|
|
134
152
|
|
|
135
153
|
def get_fab(
|
|
136
|
-
request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
|
|
154
|
+
request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
|
|
137
155
|
) -> GetFabResponse:
|
|
138
156
|
"""Get FAB."""
|
|
157
|
+
# Abort if the run is not running
|
|
158
|
+
abort_msg = check_abort(
|
|
159
|
+
request.run_id,
|
|
160
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
161
|
+
state,
|
|
162
|
+
)
|
|
163
|
+
if abort_msg:
|
|
164
|
+
raise InvalidRunStatusException(abort_msg)
|
|
165
|
+
|
|
139
166
|
if result := ffs.get(request.hash_str):
|
|
140
167
|
fab = Fab(request.hash_str, result[0])
|
|
141
168
|
return GetFabResponse(fab=fab_to_proto(fab))
|
|
@@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
|
154
154
|
# Get ffs from app
|
|
155
155
|
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
|
|
156
156
|
|
|
157
|
+
# Get state from app
|
|
158
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
159
|
+
|
|
157
160
|
# Handle message
|
|
158
|
-
return message_handler.get_fab(request=request, ffs=ffs)
|
|
161
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state)
|
|
159
162
|
|
|
160
163
|
|
|
161
164
|
routes = [
|
|
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
265
265
|
for task_res in task_res_found:
|
|
266
266
|
task_res.task.delivered_at = delivered_at
|
|
267
267
|
|
|
268
|
-
# Cleanup
|
|
269
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
270
|
-
|
|
271
268
|
return list(ret.values())
|
|
272
269
|
|
|
273
|
-
def delete_tasks(self,
|
|
274
|
-
"""Delete
|
|
275
|
-
|
|
276
|
-
task_res_to_be_deleted: set[UUID] = set()
|
|
277
|
-
|
|
278
|
-
with self.lock:
|
|
279
|
-
for task_ins_id in task_ids:
|
|
280
|
-
# Find the task_id of the matching task_res
|
|
281
|
-
for task_res_id, task_res in self.task_res_store.items():
|
|
282
|
-
if UUID(task_res.task.ancestry[0]) != task_ins_id:
|
|
283
|
-
continue
|
|
284
|
-
if task_res.task.delivered_at == "":
|
|
285
|
-
continue
|
|
286
|
-
|
|
287
|
-
task_ins_to_be_deleted.add(task_ins_id)
|
|
288
|
-
task_res_to_be_deleted.add(task_res_id)
|
|
289
|
-
|
|
290
|
-
for task_id in task_ins_to_be_deleted:
|
|
291
|
-
del self.task_ins_store[task_id]
|
|
292
|
-
del self.task_ins_id_to_task_res_id[task_id]
|
|
293
|
-
for task_id in task_res_to_be_deleted:
|
|
294
|
-
del self.task_res_store[task_id]
|
|
295
|
-
|
|
296
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
297
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
298
|
-
if not task_ids:
|
|
270
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
271
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
272
|
+
if not task_ins_ids:
|
|
299
273
|
return
|
|
300
274
|
|
|
301
275
|
with self.lock:
|
|
302
|
-
for task_id in
|
|
276
|
+
for task_id in task_ins_ids:
|
|
303
277
|
# Delete TaskIns
|
|
304
278
|
if task_id in self.task_ins_store:
|
|
305
279
|
del self.task_ins_store[task_id]
|
|
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
308
282
|
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
|
309
283
|
del self.task_res_store[task_res_id]
|
|
310
284
|
|
|
285
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
286
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
287
|
+
task_id_list: set[UUID] = set()
|
|
288
|
+
with self.lock:
|
|
289
|
+
for task_id, task_ins in self.task_ins_store.items():
|
|
290
|
+
if task_ins.run_id == run_id:
|
|
291
|
+
task_id_list.add(task_id)
|
|
292
|
+
|
|
293
|
+
return task_id_list
|
|
294
|
+
|
|
311
295
|
def num_task_ins(self) -> int:
|
|
312
296
|
"""Calculate the number of task_ins in store.
|
|
313
297
|
|
|
@@ -446,6 +430,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
446
430
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
447
431
|
return self.server_public_key
|
|
448
432
|
|
|
433
|
+
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
|
434
|
+
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
|
435
|
+
with self.lock:
|
|
436
|
+
self.server_private_key = None
|
|
437
|
+
self.server_public_key = None
|
|
438
|
+
self.node_public_keys.clear()
|
|
439
|
+
|
|
449
440
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
450
441
|
"""Store a set of `node_public_keys` in the link state."""
|
|
451
442
|
with self.lock:
|
|
@@ -139,8 +139,19 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
139
139
|
"""
|
|
140
140
|
|
|
141
141
|
@abc.abstractmethod
|
|
142
|
-
def delete_tasks(self,
|
|
143
|
-
"""Delete
|
|
142
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
143
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
|
|
144
|
+
|
|
145
|
+
Parameters
|
|
146
|
+
----------
|
|
147
|
+
task_ins_ids : set[UUID]
|
|
148
|
+
A set of TaskIns IDs. For each ID in the set, the corresponding
|
|
149
|
+
TaskIns and its associated TaskRes will be deleted.
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
@abc.abstractmethod
|
|
153
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
154
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
144
155
|
|
|
145
156
|
@abc.abstractmethod
|
|
146
157
|
def create_node(
|
|
@@ -273,6 +284,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
273
284
|
def get_server_public_key(self) -> Optional[bytes]:
|
|
274
285
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
275
286
|
|
|
287
|
+
@abc.abstractmethod
|
|
288
|
+
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
|
289
|
+
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
|
290
|
+
|
|
276
291
|
@abc.abstractmethod
|
|
277
292
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
278
293
|
"""Store a set of `node_public_keys` in the link state."""
|
|
@@ -14,12 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SQLite based implemenation of the link state."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
# pylint: disable=too-many-lines
|
|
18
19
|
|
|
19
20
|
import json
|
|
20
21
|
import re
|
|
21
22
|
import sqlite3
|
|
22
|
-
import threading
|
|
23
23
|
import time
|
|
24
24
|
from collections.abc import Sequence
|
|
25
25
|
from logging import DEBUG, ERROR, WARNING
|
|
@@ -183,7 +183,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
183
183
|
"""
|
|
184
184
|
self.database_path = database_path
|
|
185
185
|
self.conn: Optional[sqlite3.Connection] = None
|
|
186
|
-
self.lock = threading.RLock()
|
|
187
186
|
|
|
188
187
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
189
188
|
"""Create tables if they don't exist yet.
|
|
@@ -216,7 +215,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
216
215
|
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
217
216
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
218
217
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
219
|
-
|
|
220
218
|
return res.fetchall()
|
|
221
219
|
|
|
222
220
|
def query(
|
|
@@ -569,9 +567,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
569
567
|
data: list[Any] = [delivered_at] + task_res_ids
|
|
570
568
|
self.query(query, data)
|
|
571
569
|
|
|
572
|
-
# Cleanup
|
|
573
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
574
|
-
|
|
575
570
|
return list(ret.values())
|
|
576
571
|
|
|
577
572
|
def num_task_ins(self) -> int:
|
|
@@ -595,68 +590,50 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
595
590
|
result: dict[str, int] = rows[0]
|
|
596
591
|
return result["num"]
|
|
597
592
|
|
|
598
|
-
def delete_tasks(self,
|
|
599
|
-
"""Delete
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
593
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
594
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
595
|
+
if not task_ins_ids:
|
|
596
|
+
return
|
|
597
|
+
if self.conn is None:
|
|
598
|
+
raise AttributeError("LinkState not initialized")
|
|
603
599
|
|
|
604
|
-
placeholders = ",".join([
|
|
605
|
-
data =
|
|
600
|
+
placeholders = ",".join(["?"] * len(task_ins_ids))
|
|
601
|
+
data = tuple(str(task_id) for task_id in task_ins_ids)
|
|
606
602
|
|
|
607
|
-
#
|
|
603
|
+
# Delete task_ins
|
|
608
604
|
query_1 = f"""
|
|
609
605
|
DELETE FROM task_ins
|
|
610
|
-
WHERE
|
|
611
|
-
AND task_id IN (
|
|
612
|
-
SELECT ancestry
|
|
613
|
-
FROM task_res
|
|
614
|
-
WHERE ancestry IN ({placeholders})
|
|
615
|
-
AND delivered_at != ''
|
|
616
|
-
);
|
|
606
|
+
WHERE task_id IN ({placeholders});
|
|
617
607
|
"""
|
|
618
608
|
|
|
619
|
-
#
|
|
609
|
+
# Delete task_res
|
|
620
610
|
query_2 = f"""
|
|
621
611
|
DELETE FROM task_res
|
|
622
|
-
WHERE ancestry IN ({placeholders})
|
|
623
|
-
AND delivered_at != '';
|
|
612
|
+
WHERE ancestry IN ({placeholders});
|
|
624
613
|
"""
|
|
625
614
|
|
|
626
|
-
if self.conn is None:
|
|
627
|
-
raise AttributeError("LinkState not intitialized")
|
|
628
|
-
|
|
629
615
|
with self.conn:
|
|
630
616
|
self.conn.execute(query_1, data)
|
|
631
617
|
self.conn.execute(query_2, data)
|
|
632
618
|
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
636
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
637
|
-
if not task_ids:
|
|
638
|
-
return
|
|
619
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
620
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
639
621
|
if self.conn is None:
|
|
640
622
|
raise AttributeError("LinkState not initialized")
|
|
641
623
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
query_1 = f"""
|
|
647
|
-
DELETE FROM task_ins
|
|
648
|
-
WHERE task_id IN ({placeholders});
|
|
624
|
+
query = """
|
|
625
|
+
SELECT task_id
|
|
626
|
+
FROM task_ins
|
|
627
|
+
WHERE run_id = :run_id;
|
|
649
628
|
"""
|
|
650
629
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
DELETE FROM task_res
|
|
654
|
-
WHERE ancestry IN ({placeholders});
|
|
655
|
-
"""
|
|
630
|
+
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
631
|
+
data = {"run_id": sint64_run_id}
|
|
656
632
|
|
|
657
633
|
with self.conn:
|
|
658
|
-
self.conn.execute(
|
|
659
|
-
|
|
634
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
635
|
+
|
|
636
|
+
return {UUID(row["task_id"]) for row in rows}
|
|
660
637
|
|
|
661
638
|
def create_node(
|
|
662
639
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
@@ -784,8 +761,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
784
761
|
"federation_options, pending_at, starting_at, running_at, finished_at, "
|
|
785
762
|
"sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
786
763
|
)
|
|
787
|
-
if fab_hash:
|
|
788
|
-
fab_id, fab_version = "", ""
|
|
789
764
|
override_config_json = json.dumps(override_config)
|
|
790
765
|
data = [
|
|
791
766
|
sint64_run_id,
|
|
@@ -843,6 +818,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
843
818
|
public_key = None
|
|
844
819
|
return public_key
|
|
845
820
|
|
|
821
|
+
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
|
822
|
+
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
|
823
|
+
queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
|
|
824
|
+
for query in queries:
|
|
825
|
+
self.query(query)
|
|
826
|
+
|
|
846
827
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
847
828
|
"""Store a set of `node_public_keys` in the link state."""
|
|
848
829
|
query = "INSERT INTO public_key (public_key) VALUES (?)"
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SimulationIo API servicer."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import threading
|
|
18
19
|
from logging import DEBUG, INFO
|
|
19
20
|
|
|
@@ -28,6 +29,7 @@ from flwr.common.serde import (
|
|
|
28
29
|
context_to_proto,
|
|
29
30
|
fab_to_proto,
|
|
30
31
|
run_status_from_proto,
|
|
32
|
+
run_status_to_proto,
|
|
31
33
|
run_to_proto,
|
|
32
34
|
)
|
|
33
35
|
from flwr.common.typing import Fab, RunStatus
|
|
@@ -39,6 +41,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
|
39
41
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
40
42
|
GetFederationOptionsRequest,
|
|
41
43
|
GetFederationOptionsResponse,
|
|
44
|
+
GetRunStatusRequest,
|
|
45
|
+
GetRunStatusResponse,
|
|
42
46
|
UpdateRunStatusRequest,
|
|
43
47
|
UpdateRunStatusResponse,
|
|
44
48
|
)
|
|
@@ -50,6 +54,7 @@ from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611
|
|
|
50
54
|
)
|
|
51
55
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
52
56
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
57
|
+
from flwr.server.superlink.utils import abort_if
|
|
53
58
|
|
|
54
59
|
|
|
55
60
|
class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
@@ -106,6 +111,15 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
106
111
|
"""Push Simulation process outputs."""
|
|
107
112
|
log(DEBUG, "SimultionIoServicer.PushSimulationOutputs")
|
|
108
113
|
state = self.state_factory.state()
|
|
114
|
+
|
|
115
|
+
# Abort if the run is not running
|
|
116
|
+
abort_if(
|
|
117
|
+
request.run_id,
|
|
118
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
119
|
+
state,
|
|
120
|
+
context,
|
|
121
|
+
)
|
|
122
|
+
|
|
109
123
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
110
124
|
return PushSimulationOutputsResponse()
|
|
111
125
|
|
|
@@ -116,12 +130,31 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
116
130
|
log(DEBUG, "SimultionIoServicer.UpdateRunStatus")
|
|
117
131
|
state = self.state_factory.state()
|
|
118
132
|
|
|
133
|
+
# Abort if the run is finished
|
|
134
|
+
abort_if(request.run_id, [Status.FINISHED], state, context)
|
|
135
|
+
|
|
119
136
|
# Update the run status
|
|
120
137
|
state.update_run_status(
|
|
121
138
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
122
139
|
)
|
|
123
140
|
return UpdateRunStatusResponse()
|
|
124
141
|
|
|
142
|
+
def GetRunStatus(
|
|
143
|
+
self, request: GetRunStatusRequest, context: ServicerContext
|
|
144
|
+
) -> GetRunStatusResponse:
|
|
145
|
+
"""Get status of requested runs."""
|
|
146
|
+
log(DEBUG, "SimultionIoServicer.GetRunStatus")
|
|
147
|
+
state = self.state_factory.state()
|
|
148
|
+
|
|
149
|
+
statuses = state.get_run_status(set(request.run_ids))
|
|
150
|
+
|
|
151
|
+
return GetRunStatusResponse(
|
|
152
|
+
run_status_dict={
|
|
153
|
+
run_id: run_status_to_proto(status)
|
|
154
|
+
for run_id, status in statuses.items()
|
|
155
|
+
}
|
|
156
|
+
)
|
|
157
|
+
|
|
125
158
|
def PushLogs(
|
|
126
159
|
self, request: PushLogsRequest, context: grpc.ServicerContext
|
|
127
160
|
) -> PushLogsResponse:
|