flwr 1.13.1__py3-none-any.whl → 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +5 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +19 -2
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +1 -0
- flwr/cli/log.py +18 -36
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +81 -0
- flwr/cli/ls.py +205 -106
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/.gitignore.tpl +3 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +89 -39
- flwr/cli/stop.py +130 -0
- flwr/cli/utils.py +172 -8
- flwr/client/app.py +14 -3
- flwr/client/client.py +1 -32
- flwr/client/clientapp/app.py +4 -1
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +13 -7
- flwr/client/message_handler/message_handler.py +1 -2
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/numpy_client.py +0 -44
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +2 -2
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +3 -1
- flwr/common/constant.py +17 -1
- flwr/common/logger.py +40 -0
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +15 -4
- flwr/common/typing.py +12 -0
- flwr/common/version.py +1 -0
- flwr/proto/exec_pb2.py +38 -14
- flwr/proto/exec_pb2.pyi +107 -2
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +54 -2
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +11 -63
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/run_serverapp.py +8 -9
- flwr/server/serverapp/app.py +25 -3
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +82 -23
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +12 -11
- flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +21 -30
- flwr/server/superlink/linkstate/linkstate.py +17 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -49
- flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +16 -4
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +36 -22
- flwr/simulation/simulationio_connection.py +3 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +19 -1
- flwr/superexec/exec_servicer.py +76 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/METADATA +8 -7
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/RECORD +112 -112
- flwr/proto/common_pb2.py +0 -36
- flwr/proto/common_pb2.pyi +0 -121
- flwr/proto/common_pb2_grpc.py +0 -4
- flwr/proto/common_pb2_grpc.pyi +0 -4
- flwr/proto/control_pb2.py +0 -27
- flwr/proto/control_pb2.pyi +0 -7
- flwr/proto/control_pb2_grpc.py +0 -135
- flwr/proto/control_pb2_grpc.pyi +0 -53
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/LICENSE +0 -0
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/WHEEL +0 -0
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/entry_points.txt +0 -0
flwr/server/serverapp/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower ServerApp process."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import sys
|
|
19
20
|
from logging import DEBUG, ERROR, INFO
|
|
@@ -24,6 +25,7 @@ from typing import Optional
|
|
|
24
25
|
|
|
25
26
|
from flwr.cli.config_utils import get_fab_metadata
|
|
26
27
|
from flwr.cli.install import install_from_fab
|
|
28
|
+
from flwr.cli.utils import get_sha256_hash
|
|
27
29
|
from flwr.common.args import add_args_flwr_app_common
|
|
28
30
|
from flwr.common.config import (
|
|
29
31
|
get_flwr_dir,
|
|
@@ -50,7 +52,8 @@ from flwr.common.serde import (
|
|
|
50
52
|
run_from_proto,
|
|
51
53
|
run_status_to_proto,
|
|
52
54
|
)
|
|
53
|
-
from flwr.common.
|
|
55
|
+
from flwr.common.telemetry import EventType, event
|
|
56
|
+
from flwr.common.typing import RunNotRunningException, RunStatus
|
|
54
57
|
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
|
55
58
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
56
59
|
PullServerAppInputsRequest,
|
|
@@ -96,7 +99,7 @@ def flwr_serverapp() -> None:
|
|
|
96
99
|
restore_output()
|
|
97
100
|
|
|
98
101
|
|
|
99
|
-
def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
102
|
+
def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
100
103
|
serverappio_api_address: str,
|
|
101
104
|
log_queue: Queue[Optional[str]],
|
|
102
105
|
run_once: bool,
|
|
@@ -112,7 +115,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
112
115
|
# Resolve directory where FABs are installed
|
|
113
116
|
flwr_dir_ = get_flwr_dir(flwr_dir)
|
|
114
117
|
log_uploader = None
|
|
115
|
-
|
|
118
|
+
success = True
|
|
119
|
+
hash_run_id = None
|
|
116
120
|
while True:
|
|
117
121
|
|
|
118
122
|
try:
|
|
@@ -128,6 +132,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
128
132
|
run = run_from_proto(res.run)
|
|
129
133
|
fab = fab_from_proto(res.fab)
|
|
130
134
|
|
|
135
|
+
hash_run_id = get_sha256_hash(run.run_id)
|
|
136
|
+
|
|
131
137
|
driver.set_run(run.run_id)
|
|
132
138
|
|
|
133
139
|
# Start log uploader for this run
|
|
@@ -170,6 +176,11 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
170
176
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
171
177
|
)
|
|
172
178
|
|
|
179
|
+
event(
|
|
180
|
+
EventType.FLWR_SERVERAPP_RUN_ENTER,
|
|
181
|
+
event_details={"run-id-hash": hash_run_id},
|
|
182
|
+
)
|
|
183
|
+
|
|
173
184
|
# Load and run the ServerApp with the Driver
|
|
174
185
|
updated_context = run_(
|
|
175
186
|
driver=driver,
|
|
@@ -186,11 +197,18 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
186
197
|
_ = driver._stub.PushServerAppOutputs(out_req)
|
|
187
198
|
|
|
188
199
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
|
200
|
+
except RunNotRunningException:
|
|
201
|
+
log(INFO, "")
|
|
202
|
+
log(INFO, "Run ID %s stopped.", run.run_id)
|
|
203
|
+
log(INFO, "")
|
|
204
|
+
run_status = None
|
|
205
|
+
success = False
|
|
189
206
|
|
|
190
207
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
191
208
|
exc_entity = "ServerApp"
|
|
192
209
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
193
210
|
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
|
|
211
|
+
success = False
|
|
194
212
|
|
|
195
213
|
finally:
|
|
196
214
|
# Stop log uploader for this run and upload final logs
|
|
@@ -206,6 +224,10 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
206
224
|
run_id=run.run_id, run_status=run_status_proto
|
|
207
225
|
)
|
|
208
226
|
)
|
|
227
|
+
event(
|
|
228
|
+
EventType.FLWR_SERVERAPP_RUN_LEAVE,
|
|
229
|
+
event_details={"run-id-hash": hash_run_id, "success": success},
|
|
230
|
+
)
|
|
209
231
|
|
|
210
232
|
# Stop the loop if `flwr-serverapp` is expected to process a single run
|
|
211
233
|
if run_once:
|
|
@@ -32,6 +32,7 @@ from flwr.common.serde import (
|
|
|
32
32
|
fab_from_proto,
|
|
33
33
|
fab_to_proto,
|
|
34
34
|
run_status_from_proto,
|
|
35
|
+
run_status_to_proto,
|
|
35
36
|
run_to_proto,
|
|
36
37
|
user_config_from_proto,
|
|
37
38
|
)
|
|
@@ -48,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
48
49
|
CreateRunResponse,
|
|
49
50
|
GetRunRequest,
|
|
50
51
|
GetRunResponse,
|
|
52
|
+
GetRunStatusRequest,
|
|
53
|
+
GetRunStatusResponse,
|
|
51
54
|
UpdateRunStatusRequest,
|
|
52
55
|
UpdateRunStatusResponse,
|
|
53
56
|
)
|
|
@@ -67,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
|
67
70
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
68
71
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
69
72
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
73
|
+
from flwr.server.superlink.utils import abort_if
|
|
70
74
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
71
75
|
|
|
72
76
|
|
|
@@ -85,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
85
89
|
) -> GetNodesResponse:
|
|
86
90
|
"""Get available nodes."""
|
|
87
91
|
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
|
92
|
+
|
|
93
|
+
# Init state
|
|
88
94
|
state: LinkState = self.state_factory.state()
|
|
95
|
+
|
|
96
|
+
# Abort if the run is not running
|
|
97
|
+
abort_if(
|
|
98
|
+
request.run_id,
|
|
99
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
100
|
+
state,
|
|
101
|
+
context,
|
|
102
|
+
)
|
|
103
|
+
|
|
89
104
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
90
105
|
nodes: list[Node] = [
|
|
91
106
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
@@ -123,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
123
138
|
"""Push a set of TaskIns."""
|
|
124
139
|
log(DEBUG, "ServerAppIoServicer.PushTaskIns")
|
|
125
140
|
|
|
141
|
+
# Init state
|
|
142
|
+
state: LinkState = self.state_factory.state()
|
|
143
|
+
|
|
144
|
+
# Abort if the run is not running
|
|
145
|
+
abort_if(
|
|
146
|
+
request.run_id,
|
|
147
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
148
|
+
state,
|
|
149
|
+
context,
|
|
150
|
+
)
|
|
151
|
+
|
|
126
152
|
# Set pushed_at (timestamp in seconds)
|
|
127
153
|
pushed_at = time.time()
|
|
128
154
|
for task_ins in request.task_ins_list:
|
|
@@ -133,9 +159,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
133
159
|
for task_ins in request.task_ins_list:
|
|
134
160
|
validation_errors = validate_task_ins_or_res(task_ins)
|
|
135
161
|
_raise_if(bool(validation_errors), ", ".join(validation_errors))
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
162
|
+
_raise_if(
|
|
163
|
+
request.run_id != task_ins.run_id, "`task_ins` has mismatched `run_id`"
|
|
164
|
+
)
|
|
139
165
|
|
|
140
166
|
# Store each TaskIns
|
|
141
167
|
task_ids: list[Optional[UUID]] = []
|
|
@@ -153,33 +179,35 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
153
179
|
"""Pull a set of TaskRes."""
|
|
154
180
|
log(DEBUG, "ServerAppIoServicer.PullTaskRes")
|
|
155
181
|
|
|
156
|
-
# Convert each task_id str to UUID
|
|
157
|
-
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
158
|
-
|
|
159
182
|
# Init state
|
|
160
183
|
state: LinkState = self.state_factory.state()
|
|
161
184
|
|
|
162
|
-
#
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
if context.is_active():
|
|
170
|
-
return
|
|
171
|
-
if context.code() != grpc.StatusCode.OK:
|
|
172
|
-
return
|
|
173
|
-
|
|
174
|
-
# Delete delivered TaskIns and TaskRes
|
|
175
|
-
state.delete_tasks(task_ids=task_ids)
|
|
185
|
+
# Abort if the run is not running
|
|
186
|
+
abort_if(
|
|
187
|
+
request.run_id,
|
|
188
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
189
|
+
state,
|
|
190
|
+
context,
|
|
191
|
+
)
|
|
176
192
|
|
|
177
|
-
|
|
193
|
+
# Convert each task_id str to UUID
|
|
194
|
+
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
178
195
|
|
|
179
196
|
# Read from state
|
|
180
197
|
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
|
|
181
198
|
|
|
182
|
-
|
|
199
|
+
# Validate request
|
|
200
|
+
for task_res in task_res_list:
|
|
201
|
+
_raise_if(
|
|
202
|
+
request.run_id != task_res.run_id, "`task_res` has mismatched `run_id`"
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
206
|
+
task_ins_ids_to_delete = {
|
|
207
|
+
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
208
|
+
}
|
|
209
|
+
state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
|
|
210
|
+
|
|
183
211
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
184
212
|
|
|
185
213
|
def GetRun(
|
|
@@ -255,7 +283,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
255
283
|
) -> PushServerAppOutputsResponse:
|
|
256
284
|
"""Push ServerApp process outputs."""
|
|
257
285
|
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
|
|
286
|
+
|
|
287
|
+
# Init state
|
|
258
288
|
state = self.state_factory.state()
|
|
289
|
+
|
|
290
|
+
# Abort if the run is not running
|
|
291
|
+
abort_if(
|
|
292
|
+
request.run_id,
|
|
293
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
294
|
+
state,
|
|
295
|
+
context,
|
|
296
|
+
)
|
|
297
|
+
|
|
259
298
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
260
299
|
return PushServerAppOutputsResponse()
|
|
261
300
|
|
|
@@ -263,9 +302,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
263
302
|
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
|
|
264
303
|
) -> UpdateRunStatusResponse:
|
|
265
304
|
"""Update the status of a run."""
|
|
266
|
-
log(DEBUG, "
|
|
305
|
+
log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
|
|
306
|
+
|
|
307
|
+
# Init state
|
|
267
308
|
state = self.state_factory.state()
|
|
268
309
|
|
|
310
|
+
# Abort if the run is finished
|
|
311
|
+
abort_if(request.run_id, [Status.FINISHED], state, context)
|
|
312
|
+
|
|
269
313
|
# Update the run status
|
|
270
314
|
state.update_run_status(
|
|
271
315
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
@@ -284,6 +328,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
284
328
|
state.add_serverapp_log(request.run_id, merged_logs)
|
|
285
329
|
return PushLogsResponse()
|
|
286
330
|
|
|
331
|
+
def GetRunStatus(
|
|
332
|
+
self, request: GetRunStatusRequest, context: grpc.ServicerContext
|
|
333
|
+
) -> GetRunStatusResponse:
|
|
334
|
+
"""Get the status of a run."""
|
|
335
|
+
log(DEBUG, "ServerAppIoServicer.GetRunStatus")
|
|
336
|
+
state = self.state_factory.state()
|
|
337
|
+
|
|
338
|
+
# Get run status from LinkState
|
|
339
|
+
run_statuses = state.get_run_status(set(request.run_ids))
|
|
340
|
+
run_status_dict = {
|
|
341
|
+
run_id: run_status_to_proto(run_status)
|
|
342
|
+
for run_id, run_status in run_statuses.items()
|
|
343
|
+
}
|
|
344
|
+
return GetRunStatusResponse(run_status_dict=run_status_dict)
|
|
345
|
+
|
|
287
346
|
|
|
288
347
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
289
348
|
if validation_error:
|
|
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
|
|
|
20
20
|
import grpc
|
|
21
21
|
|
|
22
22
|
from flwr.common.logger import log
|
|
23
|
+
from flwr.common.typing import InvalidRunStatusException
|
|
23
24
|
from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
|
|
24
25
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
25
26
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -38,6 +39,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
|
|
|
38
39
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
39
40
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
40
41
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
42
|
+
from flwr.server.superlink.utils import abort_grpc_context
|
|
41
43
|
|
|
42
44
|
|
|
43
45
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
@@ -105,27 +107,45 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
105
107
|
)
|
|
106
108
|
else:
|
|
107
109
|
log(INFO, "[Fleet.PushTaskRes] No task results to push")
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
110
|
+
|
|
111
|
+
try:
|
|
112
|
+
res = message_handler.push_task_res(
|
|
113
|
+
request=request,
|
|
114
|
+
state=self.state_factory.state(),
|
|
115
|
+
)
|
|
116
|
+
except InvalidRunStatusException as e:
|
|
117
|
+
abort_grpc_context(e.message, context)
|
|
118
|
+
|
|
119
|
+
return res
|
|
112
120
|
|
|
113
121
|
def GetRun(
|
|
114
122
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
115
123
|
) -> GetRunResponse:
|
|
116
124
|
"""Get run information."""
|
|
117
125
|
log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
126
|
+
|
|
127
|
+
try:
|
|
128
|
+
res = message_handler.get_run(
|
|
129
|
+
request=request,
|
|
130
|
+
state=self.state_factory.state(),
|
|
131
|
+
)
|
|
132
|
+
except InvalidRunStatusException as e:
|
|
133
|
+
abort_grpc_context(e.message, context)
|
|
134
|
+
|
|
135
|
+
return res
|
|
122
136
|
|
|
123
137
|
def GetFab(
|
|
124
138
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
125
139
|
) -> GetFabResponse:
|
|
126
140
|
"""Get FAB."""
|
|
127
141
|
log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
142
|
+
try:
|
|
143
|
+
res = message_handler.get_fab(
|
|
144
|
+
request=request,
|
|
145
|
+
ffs=self.ffs_factory.ffs(),
|
|
146
|
+
state=self.state_factory.state(),
|
|
147
|
+
)
|
|
148
|
+
except InvalidRunStatusException as e:
|
|
149
|
+
abort_grpc_context(e.message, context)
|
|
150
|
+
|
|
151
|
+
return res
|
|
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
45
45
|
)
|
|
46
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
47
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
48
|
-
from flwr.server.superlink.linkstate import
|
|
48
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
49
49
|
|
|
50
50
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
51
51
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -84,15 +84,16 @@ def _get_value_from_tuples(
|
|
|
84
84
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
85
85
|
"""Server interceptor for node authentication."""
|
|
86
86
|
|
|
87
|
-
def __init__(self,
|
|
88
|
-
self.
|
|
87
|
+
def __init__(self, state_factory: LinkStateFactory):
|
|
88
|
+
self.state_factory = state_factory
|
|
89
|
+
state = self.state_factory.state()
|
|
89
90
|
|
|
90
91
|
self.node_public_keys = state.get_node_public_keys()
|
|
91
92
|
if len(self.node_public_keys) == 0:
|
|
92
93
|
log(WARNING, "Authentication enabled, but no known public keys configured")
|
|
93
94
|
|
|
94
|
-
private_key =
|
|
95
|
-
public_key =
|
|
95
|
+
private_key = state.get_server_private_key()
|
|
96
|
+
public_key = state.get_server_public_key()
|
|
96
97
|
|
|
97
98
|
if private_key is None or public_key is None:
|
|
98
99
|
raise ValueError("Error loading authentication keys")
|
|
@@ -154,7 +155,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
154
155
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
155
156
|
|
|
156
157
|
# Verify node_id
|
|
157
|
-
node_id = self.state.get_node_id(node_public_key_bytes)
|
|
158
|
+
node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
|
|
158
159
|
|
|
159
160
|
if not self._verify_node_id(node_id, request):
|
|
160
161
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
@@ -186,7 +187,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
186
187
|
return False
|
|
187
188
|
return request.task_res_list[0].task.producer.node_id == node_id
|
|
188
189
|
if isinstance(request, GetRunRequest):
|
|
189
|
-
return node_id in self.state.get_nodes(request.run_id)
|
|
190
|
+
return node_id in self.state_factory.state().get_nodes(request.run_id)
|
|
190
191
|
return request.node.node_id == node_id
|
|
191
192
|
|
|
192
193
|
def _verify_hmac(
|
|
@@ -210,17 +211,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
210
211
|
),
|
|
211
212
|
)
|
|
212
213
|
)
|
|
213
|
-
|
|
214
|
-
node_id =
|
|
214
|
+
state = self.state_factory.state()
|
|
215
|
+
node_id = state.get_node_id(public_key_bytes)
|
|
215
216
|
|
|
216
217
|
# Handle `CreateNode` here instead of calling the default method handler
|
|
217
218
|
# Return previously assigned `node_id` for the provided `public_key`
|
|
218
219
|
if node_id is not None:
|
|
219
|
-
|
|
220
|
+
state.acknowledge_ping(node_id, request.ping_interval)
|
|
220
221
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
221
222
|
|
|
222
223
|
# No `node_id` exists for the provided `public_key`
|
|
223
224
|
# Handle `CreateNode` here instead of calling the default method handler
|
|
224
225
|
# Note: the innermost `CreateNode` method will never be called
|
|
225
|
-
node_id =
|
|
226
|
+
node_id = state.create_node(request.ping_interval, public_key_bytes)
|
|
226
227
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
@@ -19,8 +19,9 @@ import time
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
+
from flwr.common.constant import Status
|
|
22
23
|
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
23
|
-
from flwr.common.typing import Fab
|
|
24
|
+
from flwr.common.typing import Fab, InvalidRunStatusException
|
|
24
25
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
25
26
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
26
27
|
CreateNodeRequest,
|
|
@@ -44,6 +45,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
44
45
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
46
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
46
47
|
from flwr.server.superlink.linkstate import LinkState
|
|
48
|
+
from flwr.server.superlink.utils import check_abort
|
|
47
49
|
|
|
48
50
|
|
|
49
51
|
def create_node(
|
|
@@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
|
98
100
|
task_res: TaskRes = request.task_res_list[0]
|
|
99
101
|
# pylint: enable=no-member
|
|
100
102
|
|
|
103
|
+
# Abort if the run is not running
|
|
104
|
+
abort_msg = check_abort(
|
|
105
|
+
task_res.run_id,
|
|
106
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
107
|
+
state,
|
|
108
|
+
)
|
|
109
|
+
if abort_msg:
|
|
110
|
+
raise InvalidRunStatusException(abort_msg)
|
|
111
|
+
|
|
101
112
|
# Set pushed_at (timestamp in seconds)
|
|
102
113
|
task_res.task.pushed_at = time.time()
|
|
103
114
|
|
|
@@ -112,15 +123,22 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
|
112
123
|
return response
|
|
113
124
|
|
|
114
125
|
|
|
115
|
-
def get_run(
|
|
116
|
-
request: GetRunRequest, state: LinkState # pylint: disable=W0613
|
|
117
|
-
) -> GetRunResponse:
|
|
126
|
+
def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
118
127
|
"""Get run information."""
|
|
119
128
|
run = state.get_run(request.run_id)
|
|
120
129
|
|
|
121
130
|
if run is None:
|
|
122
131
|
return GetRunResponse()
|
|
123
132
|
|
|
133
|
+
# Abort if the run is not running
|
|
134
|
+
abort_msg = check_abort(
|
|
135
|
+
request.run_id,
|
|
136
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
137
|
+
state,
|
|
138
|
+
)
|
|
139
|
+
if abort_msg:
|
|
140
|
+
raise InvalidRunStatusException(abort_msg)
|
|
141
|
+
|
|
124
142
|
return GetRunResponse(
|
|
125
143
|
run=Run(
|
|
126
144
|
run_id=run.run_id,
|
|
@@ -133,9 +151,18 @@ def get_run(
|
|
|
133
151
|
|
|
134
152
|
|
|
135
153
|
def get_fab(
|
|
136
|
-
request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
|
|
154
|
+
request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
|
|
137
155
|
) -> GetFabResponse:
|
|
138
156
|
"""Get FAB."""
|
|
157
|
+
# Abort if the run is not running
|
|
158
|
+
abort_msg = check_abort(
|
|
159
|
+
request.run_id,
|
|
160
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
161
|
+
state,
|
|
162
|
+
)
|
|
163
|
+
if abort_msg:
|
|
164
|
+
raise InvalidRunStatusException(abort_msg)
|
|
165
|
+
|
|
139
166
|
if result := ffs.get(request.hash_str):
|
|
140
167
|
fab = Fab(request.hash_str, result[0])
|
|
141
168
|
return GetFabResponse(fab=fab_to_proto(fab))
|
|
@@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
|
154
154
|
# Get ffs from app
|
|
155
155
|
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
|
|
156
156
|
|
|
157
|
+
# Get state from app
|
|
158
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
159
|
+
|
|
157
160
|
# Handle message
|
|
158
|
-
return message_handler.get_fab(request=request, ffs=ffs)
|
|
161
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state)
|
|
159
162
|
|
|
160
163
|
|
|
161
164
|
routes = [
|
|
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
265
265
|
for task_res in task_res_found:
|
|
266
266
|
task_res.task.delivered_at = delivered_at
|
|
267
267
|
|
|
268
|
-
# Cleanup
|
|
269
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
270
|
-
|
|
271
268
|
return list(ret.values())
|
|
272
269
|
|
|
273
|
-
def delete_tasks(self,
|
|
274
|
-
"""Delete
|
|
275
|
-
|
|
276
|
-
task_res_to_be_deleted: set[UUID] = set()
|
|
277
|
-
|
|
278
|
-
with self.lock:
|
|
279
|
-
for task_ins_id in task_ids:
|
|
280
|
-
# Find the task_id of the matching task_res
|
|
281
|
-
for task_res_id, task_res in self.task_res_store.items():
|
|
282
|
-
if UUID(task_res.task.ancestry[0]) != task_ins_id:
|
|
283
|
-
continue
|
|
284
|
-
if task_res.task.delivered_at == "":
|
|
285
|
-
continue
|
|
286
|
-
|
|
287
|
-
task_ins_to_be_deleted.add(task_ins_id)
|
|
288
|
-
task_res_to_be_deleted.add(task_res_id)
|
|
289
|
-
|
|
290
|
-
for task_id in task_ins_to_be_deleted:
|
|
291
|
-
del self.task_ins_store[task_id]
|
|
292
|
-
del self.task_ins_id_to_task_res_id[task_id]
|
|
293
|
-
for task_id in task_res_to_be_deleted:
|
|
294
|
-
del self.task_res_store[task_id]
|
|
295
|
-
|
|
296
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
297
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
298
|
-
if not task_ids:
|
|
270
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
271
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
272
|
+
if not task_ins_ids:
|
|
299
273
|
return
|
|
300
274
|
|
|
301
275
|
with self.lock:
|
|
302
|
-
for task_id in
|
|
276
|
+
for task_id in task_ins_ids:
|
|
303
277
|
# Delete TaskIns
|
|
304
278
|
if task_id in self.task_ins_store:
|
|
305
279
|
del self.task_ins_store[task_id]
|
|
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
308
282
|
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
|
309
283
|
del self.task_res_store[task_res_id]
|
|
310
284
|
|
|
285
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
286
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
287
|
+
task_id_list: set[UUID] = set()
|
|
288
|
+
with self.lock:
|
|
289
|
+
for task_id, task_ins in self.task_ins_store.items():
|
|
290
|
+
if task_ins.run_id == run_id:
|
|
291
|
+
task_id_list.add(task_id)
|
|
292
|
+
|
|
293
|
+
return task_id_list
|
|
294
|
+
|
|
311
295
|
def num_task_ins(self) -> int:
|
|
312
296
|
"""Calculate the number of task_ins in store.
|
|
313
297
|
|
|
@@ -446,6 +430,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
446
430
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
447
431
|
return self.server_public_key
|
|
448
432
|
|
|
433
|
+
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
|
434
|
+
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
|
435
|
+
with self.lock:
|
|
436
|
+
self.server_private_key = None
|
|
437
|
+
self.server_public_key = None
|
|
438
|
+
self.node_public_keys.clear()
|
|
439
|
+
|
|
449
440
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
450
441
|
"""Store a set of `node_public_keys` in the link state."""
|
|
451
442
|
with self.lock:
|