flwr-nightly 1.14.0.dev20241204__py3-none-any.whl → 1.14.0.dev20241216__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic; further details are available on the registry page this diff was taken from.
- flwr/cli/app.py +5 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +19 -2
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +1 -0
- flwr/cli/log.py +11 -31
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +81 -0
- flwr/cli/ls.py +25 -55
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +17 -39
- flwr/cli/stop.py +129 -0
- flwr/cli/utils.py +96 -1
- flwr/client/app.py +14 -3
- flwr/client/client.py +1 -0
- flwr/client/clientapp/app.py +4 -1
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +13 -7
- flwr/client/message_handler/message_handler.py +1 -0
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +1 -0
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +3 -1
- flwr/common/constant.py +6 -1
- flwr/common/logger.py +17 -1
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +2 -1
- flwr/common/typing.py +12 -0
- flwr/common/version.py +1 -0
- flwr/proto/exec_pb2.py +27 -3
- flwr/proto/exec_pb2.pyi +103 -0
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +52 -1
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +11 -63
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/serverapp/app.py +9 -2
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +72 -22
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
- flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
- flwr/server/superlink/linkstate/linkstate.py +13 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
- flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +1 -0
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +1 -15
- flwr/simulation/simulationio_connection.py +3 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +19 -1
- flwr/superexec/exec_servicer.py +76 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/METADATA +8 -7
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/RECORD +101 -93
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241216.dist-info}/entry_points.txt +0 -0
|
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
265
265
|
for task_res in task_res_found:
|
|
266
266
|
task_res.task.delivered_at = delivered_at
|
|
267
267
|
|
|
268
|
-
# Cleanup
|
|
269
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
270
|
-
|
|
271
268
|
return list(ret.values())
|
|
272
269
|
|
|
273
|
-
def delete_tasks(self,
|
|
274
|
-
"""Delete
|
|
275
|
-
|
|
276
|
-
task_res_to_be_deleted: set[UUID] = set()
|
|
277
|
-
|
|
278
|
-
with self.lock:
|
|
279
|
-
for task_ins_id in task_ids:
|
|
280
|
-
# Find the task_id of the matching task_res
|
|
281
|
-
for task_res_id, task_res in self.task_res_store.items():
|
|
282
|
-
if UUID(task_res.task.ancestry[0]) != task_ins_id:
|
|
283
|
-
continue
|
|
284
|
-
if task_res.task.delivered_at == "":
|
|
285
|
-
continue
|
|
286
|
-
|
|
287
|
-
task_ins_to_be_deleted.add(task_ins_id)
|
|
288
|
-
task_res_to_be_deleted.add(task_res_id)
|
|
289
|
-
|
|
290
|
-
for task_id in task_ins_to_be_deleted:
|
|
291
|
-
del self.task_ins_store[task_id]
|
|
292
|
-
del self.task_ins_id_to_task_res_id[task_id]
|
|
293
|
-
for task_id in task_res_to_be_deleted:
|
|
294
|
-
del self.task_res_store[task_id]
|
|
295
|
-
|
|
296
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
297
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
298
|
-
if not task_ids:
|
|
270
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
271
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
272
|
+
if not task_ins_ids:
|
|
299
273
|
return
|
|
300
274
|
|
|
301
275
|
with self.lock:
|
|
302
|
-
for task_id in
|
|
276
|
+
for task_id in task_ins_ids:
|
|
303
277
|
# Delete TaskIns
|
|
304
278
|
if task_id in self.task_ins_store:
|
|
305
279
|
del self.task_ins_store[task_id]
|
|
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
308
282
|
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
|
309
283
|
del self.task_res_store[task_res_id]
|
|
310
284
|
|
|
285
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
286
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
287
|
+
task_id_list: set[UUID] = set()
|
|
288
|
+
with self.lock:
|
|
289
|
+
for task_id, task_ins in self.task_ins_store.items():
|
|
290
|
+
if task_ins.run_id == run_id:
|
|
291
|
+
task_id_list.add(task_id)
|
|
292
|
+
|
|
293
|
+
return task_id_list
|
|
294
|
+
|
|
311
295
|
def num_task_ins(self) -> int:
|
|
312
296
|
"""Calculate the number of task_ins in store.
|
|
313
297
|
|
|
@@ -139,8 +139,19 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
139
139
|
"""
|
|
140
140
|
|
|
141
141
|
@abc.abstractmethod
|
|
142
|
-
def delete_tasks(self,
|
|
143
|
-
"""Delete
|
|
142
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
143
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
|
|
144
|
+
|
|
145
|
+
Parameters
|
|
146
|
+
----------
|
|
147
|
+
task_ins_ids : set[UUID]
|
|
148
|
+
A set of TaskIns IDs. For each ID in the set, the corresponding
|
|
149
|
+
TaskIns and its associated TaskRes will be deleted.
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
@abc.abstractmethod
|
|
153
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
154
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
144
155
|
|
|
145
156
|
@abc.abstractmethod
|
|
146
157
|
def create_node(
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SQLite based implemenation of the link state."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
# pylint: disable=too-many-lines
|
|
18
19
|
|
|
19
20
|
import json
|
|
@@ -566,9 +567,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
566
567
|
data: list[Any] = [delivered_at] + task_res_ids
|
|
567
568
|
self.query(query, data)
|
|
568
569
|
|
|
569
|
-
# Cleanup
|
|
570
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
571
|
-
|
|
572
570
|
return list(ret.values())
|
|
573
571
|
|
|
574
572
|
def num_task_ins(self) -> int:
|
|
@@ -592,68 +590,50 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
592
590
|
result: dict[str, int] = rows[0]
|
|
593
591
|
return result["num"]
|
|
594
592
|
|
|
595
|
-
def delete_tasks(self,
|
|
596
|
-
"""Delete
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
593
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
594
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
595
|
+
if not task_ins_ids:
|
|
596
|
+
return
|
|
597
|
+
if self.conn is None:
|
|
598
|
+
raise AttributeError("LinkState not initialized")
|
|
600
599
|
|
|
601
|
-
placeholders = ",".join([
|
|
602
|
-
data =
|
|
600
|
+
placeholders = ",".join(["?"] * len(task_ins_ids))
|
|
601
|
+
data = tuple(str(task_id) for task_id in task_ins_ids)
|
|
603
602
|
|
|
604
|
-
#
|
|
603
|
+
# Delete task_ins
|
|
605
604
|
query_1 = f"""
|
|
606
605
|
DELETE FROM task_ins
|
|
607
|
-
WHERE
|
|
608
|
-
AND task_id IN (
|
|
609
|
-
SELECT ancestry
|
|
610
|
-
FROM task_res
|
|
611
|
-
WHERE ancestry IN ({placeholders})
|
|
612
|
-
AND delivered_at != ''
|
|
613
|
-
);
|
|
606
|
+
WHERE task_id IN ({placeholders});
|
|
614
607
|
"""
|
|
615
608
|
|
|
616
|
-
#
|
|
609
|
+
# Delete task_res
|
|
617
610
|
query_2 = f"""
|
|
618
611
|
DELETE FROM task_res
|
|
619
|
-
WHERE ancestry IN ({placeholders})
|
|
620
|
-
AND delivered_at != '';
|
|
612
|
+
WHERE ancestry IN ({placeholders});
|
|
621
613
|
"""
|
|
622
614
|
|
|
623
|
-
if self.conn is None:
|
|
624
|
-
raise AttributeError("LinkState not intitialized")
|
|
625
|
-
|
|
626
615
|
with self.conn:
|
|
627
616
|
self.conn.execute(query_1, data)
|
|
628
617
|
self.conn.execute(query_2, data)
|
|
629
618
|
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
633
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
634
|
-
if not task_ids:
|
|
635
|
-
return
|
|
619
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
620
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
636
621
|
if self.conn is None:
|
|
637
622
|
raise AttributeError("LinkState not initialized")
|
|
638
623
|
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
query_1 = f"""
|
|
644
|
-
DELETE FROM task_ins
|
|
645
|
-
WHERE task_id IN ({placeholders});
|
|
624
|
+
query = """
|
|
625
|
+
SELECT task_id
|
|
626
|
+
FROM task_ins
|
|
627
|
+
WHERE run_id = :run_id;
|
|
646
628
|
"""
|
|
647
629
|
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
DELETE FROM task_res
|
|
651
|
-
WHERE ancestry IN ({placeholders});
|
|
652
|
-
"""
|
|
630
|
+
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
631
|
+
data = {"run_id": sint64_run_id}
|
|
653
632
|
|
|
654
633
|
with self.conn:
|
|
655
|
-
self.conn.execute(
|
|
656
|
-
|
|
634
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
635
|
+
|
|
636
|
+
return {UUID(row["task_id"]) for row in rows}
|
|
657
637
|
|
|
658
638
|
def create_node(
|
|
659
639
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SimulationIo API servicer."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import threading
|
|
18
19
|
from logging import DEBUG, INFO
|
|
19
20
|
|
|
@@ -28,6 +29,7 @@ from flwr.common.serde import (
|
|
|
28
29
|
context_to_proto,
|
|
29
30
|
fab_to_proto,
|
|
30
31
|
run_status_from_proto,
|
|
32
|
+
run_status_to_proto,
|
|
31
33
|
run_to_proto,
|
|
32
34
|
)
|
|
33
35
|
from flwr.common.typing import Fab, RunStatus
|
|
@@ -39,6 +41,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
|
39
41
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
40
42
|
GetFederationOptionsRequest,
|
|
41
43
|
GetFederationOptionsResponse,
|
|
44
|
+
GetRunStatusRequest,
|
|
45
|
+
GetRunStatusResponse,
|
|
42
46
|
UpdateRunStatusRequest,
|
|
43
47
|
UpdateRunStatusResponse,
|
|
44
48
|
)
|
|
@@ -122,6 +126,22 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
122
126
|
)
|
|
123
127
|
return UpdateRunStatusResponse()
|
|
124
128
|
|
|
129
|
+
def GetRunStatus(
|
|
130
|
+
self, request: GetRunStatusRequest, context: ServicerContext
|
|
131
|
+
) -> GetRunStatusResponse:
|
|
132
|
+
"""Get status of requested runs."""
|
|
133
|
+
log(DEBUG, "SimultionIoServicer.GetRunStatus")
|
|
134
|
+
state = self.state_factory.state()
|
|
135
|
+
|
|
136
|
+
statuses = state.get_run_status(set(request.run_ids))
|
|
137
|
+
|
|
138
|
+
return GetRunStatusResponse(
|
|
139
|
+
run_status_dict={
|
|
140
|
+
run_id: run_status_to_proto(status)
|
|
141
|
+
for run_id, status in statuses.items()
|
|
142
|
+
}
|
|
143
|
+
)
|
|
144
|
+
|
|
125
145
|
def PushLogs(
|
|
126
146
|
self, request: PushLogsRequest, context: grpc.ServicerContext
|
|
127
147
|
) -> PushLogsResponse:
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""SuperLink utilities."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from typing import Union
|
|
19
|
+
|
|
20
|
+
import grpc
|
|
21
|
+
|
|
22
|
+
from flwr.common.constant import Status, SubStatus
|
|
23
|
+
from flwr.common.typing import RunStatus
|
|
24
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
25
|
+
|
|
26
|
+
_STATUS_TO_MSG = {
|
|
27
|
+
Status.PENDING: "Run is pending.",
|
|
28
|
+
Status.STARTING: "Run is starting.",
|
|
29
|
+
Status.RUNNING: "Run is running.",
|
|
30
|
+
Status.FINISHED: "Run is finished.",
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def check_abort(
|
|
35
|
+
run_id: int,
|
|
36
|
+
abort_status_list: list[str],
|
|
37
|
+
state: LinkState,
|
|
38
|
+
) -> Union[str, None]:
|
|
39
|
+
"""Check if the status of the provided `run_id` is in `abort_status_list`."""
|
|
40
|
+
run_status: RunStatus = state.get_run_status({run_id})[run_id]
|
|
41
|
+
|
|
42
|
+
if run_status.status in abort_status_list:
|
|
43
|
+
msg = _STATUS_TO_MSG[run_status.status]
|
|
44
|
+
if run_status.sub_status == SubStatus.STOPPED:
|
|
45
|
+
msg += " Stopped by user."
|
|
46
|
+
return msg
|
|
47
|
+
|
|
48
|
+
return None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def abort_grpc_context(msg: Union[str, None], context: grpc.ServicerContext) -> None:
|
|
52
|
+
"""Abort context with statuscode PERMISSION_DENIED if `msg` is not None."""
|
|
53
|
+
if msg is not None:
|
|
54
|
+
context.abort(grpc.StatusCode.PERMISSION_DENIED, msg)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def abort_if(
|
|
58
|
+
run_id: int,
|
|
59
|
+
abort_status_list: list[str],
|
|
60
|
+
state: LinkState,
|
|
61
|
+
context: grpc.ServicerContext,
|
|
62
|
+
) -> None:
|
|
63
|
+
"""Abort context if status of the provided `run_id` is in `abort_status_list`."""
|
|
64
|
+
msg = check_abort(run_id, abort_status_list, state)
|
|
65
|
+
abort_grpc_context(msg, context)
|
flwr/simulation/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower Simulation."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import asyncio
|
|
19
20
|
import json
|
|
@@ -23,7 +24,6 @@ import threading
|
|
|
23
24
|
import traceback
|
|
24
25
|
from logging import DEBUG, ERROR, INFO, WARNING
|
|
25
26
|
from pathlib import Path
|
|
26
|
-
from time import sleep
|
|
27
27
|
from typing import Any, Optional
|
|
28
28
|
|
|
29
29
|
from flwr.cli.config_utils import load_and_validate
|
|
@@ -135,7 +135,6 @@ def run_simulation_from_cli() -> None:
|
|
|
135
135
|
app_dir=args.app,
|
|
136
136
|
run=run,
|
|
137
137
|
enable_tf_gpu_growth=args.enable_tf_gpu_growth,
|
|
138
|
-
delay_start=args.delay_start,
|
|
139
138
|
verbose_logging=args.verbose,
|
|
140
139
|
server_app_run_config=fused_config,
|
|
141
140
|
is_app=True,
|
|
@@ -308,7 +307,6 @@ def _main_loop(
|
|
|
308
307
|
enable_tf_gpu_growth: bool,
|
|
309
308
|
run: Run,
|
|
310
309
|
exit_event: EventType,
|
|
311
|
-
delay_start: int,
|
|
312
310
|
flwr_dir: Optional[str] = None,
|
|
313
311
|
client_app: Optional[ClientApp] = None,
|
|
314
312
|
client_app_attr: Optional[str] = None,
|
|
@@ -353,9 +351,6 @@ def _main_loop(
|
|
|
353
351
|
run_id=run.run_id,
|
|
354
352
|
)
|
|
355
353
|
|
|
356
|
-
# Buffer time so the `ServerApp` in separate thread is ready
|
|
357
|
-
log(DEBUG, "Buffer time delay: %ds", delay_start)
|
|
358
|
-
sleep(delay_start)
|
|
359
354
|
# Start Simulation Engine
|
|
360
355
|
vce.start_vce(
|
|
361
356
|
num_supernodes=num_supernodes,
|
|
@@ -404,7 +399,6 @@ def _run_simulation(
|
|
|
404
399
|
flwr_dir: Optional[str] = None,
|
|
405
400
|
run: Optional[Run] = None,
|
|
406
401
|
enable_tf_gpu_growth: bool = False,
|
|
407
|
-
delay_start: int = 5,
|
|
408
402
|
verbose_logging: bool = False,
|
|
409
403
|
is_app: bool = False,
|
|
410
404
|
) -> None:
|
|
@@ -459,7 +453,6 @@ def _run_simulation(
|
|
|
459
453
|
enable_tf_gpu_growth,
|
|
460
454
|
run,
|
|
461
455
|
exit_event,
|
|
462
|
-
delay_start,
|
|
463
456
|
flwr_dir,
|
|
464
457
|
client_app,
|
|
465
458
|
client_app_attr,
|
|
@@ -537,13 +530,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
|
537
530
|
"Read more about how `tf.config.experimental.set_memory_growth()` works in "
|
|
538
531
|
"the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
|
|
539
532
|
)
|
|
540
|
-
parser.add_argument(
|
|
541
|
-
"--delay-start",
|
|
542
|
-
type=int,
|
|
543
|
-
default=3,
|
|
544
|
-
help="Buffer time (in seconds) to delay the start the simulation engine after "
|
|
545
|
-
"the `ServerApp`, which runs in a separate thread, has been launched.",
|
|
546
|
-
)
|
|
547
533
|
parser.add_argument(
|
|
548
534
|
"--verbose",
|
|
549
535
|
action="store_true",
|
|
@@ -23,6 +23,7 @@ import grpc
|
|
|
23
23
|
from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS
|
|
24
24
|
from flwr.common.grpc import create_channel
|
|
25
25
|
from flwr.common.logger import log
|
|
26
|
+
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
26
27
|
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
|
|
27
28
|
|
|
28
29
|
|
|
@@ -48,6 +49,7 @@ class SimulationIoConnection:
|
|
|
48
49
|
self._cert = root_certificates
|
|
49
50
|
self._grpc_stub: Optional[SimulationIoStub] = None
|
|
50
51
|
self._channel: Optional[grpc.Channel] = None
|
|
52
|
+
self._retry_invoker = _make_simple_grpc_retry_invoker()
|
|
51
53
|
|
|
52
54
|
@property
|
|
53
55
|
def _is_connected(self) -> bool:
|
|
@@ -72,6 +74,7 @@ class SimulationIoConnection:
|
|
|
72
74
|
root_certificates=self._cert,
|
|
73
75
|
)
|
|
74
76
|
self._grpc_stub = SimulationIoStub(self._channel)
|
|
77
|
+
_wrap_stub(self._grpc_stub, self._retry_invoker)
|
|
75
78
|
log(DEBUG, "[SimulationIO] Connected to %s", self._addr)
|
|
76
79
|
|
|
77
80
|
def _disconnect(self) -> None:
|
flwr/superexec/app.py
CHANGED
flwr/superexec/deployment.py
CHANGED
flwr/superexec/exec_grpc.py
CHANGED
|
@@ -14,18 +14,22 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SuperExec gRPC API."""
|
|
16
16
|
|
|
17
|
+
|
|
18
|
+
from collections.abc import Sequence
|
|
17
19
|
from logging import INFO
|
|
18
20
|
from typing import Optional
|
|
19
21
|
|
|
20
22
|
import grpc
|
|
21
23
|
|
|
22
24
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
25
|
+
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
23
26
|
from flwr.common.logger import log
|
|
24
27
|
from flwr.common.typing import UserConfig
|
|
25
28
|
from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
|
|
26
29
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
27
30
|
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
28
31
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
32
|
+
from flwr.superexec.exec_user_auth_interceptor import ExecUserAuthInterceptor
|
|
29
33
|
|
|
30
34
|
from .exec_servicer import ExecServicer
|
|
31
35
|
from .executor import Executor
|
|
@@ -39,6 +43,7 @@ def run_exec_api_grpc(
|
|
|
39
43
|
ffs_factory: FfsFactory,
|
|
40
44
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
41
45
|
config: UserConfig,
|
|
46
|
+
auth_plugin: Optional[ExecAuthPlugin] = None,
|
|
42
47
|
) -> grpc.Server:
|
|
43
48
|
"""Run Exec API (gRPC, request-response)."""
|
|
44
49
|
executor.set_config(config)
|
|
@@ -47,16 +52,29 @@ def run_exec_api_grpc(
|
|
|
47
52
|
linkstate_factory=state_factory,
|
|
48
53
|
ffs_factory=ffs_factory,
|
|
49
54
|
executor=executor,
|
|
55
|
+
auth_plugin=auth_plugin,
|
|
50
56
|
)
|
|
57
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
58
|
+
if auth_plugin is not None:
|
|
59
|
+
interceptors = [ExecUserAuthInterceptor(auth_plugin)]
|
|
51
60
|
exec_add_servicer_to_server_fn = add_ExecServicer_to_server
|
|
52
61
|
exec_grpc_server = generic_create_grpc_server(
|
|
53
62
|
servicer_and_add_fn=(exec_servicer, exec_add_servicer_to_server_fn),
|
|
54
63
|
server_address=address,
|
|
55
64
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
56
65
|
certificates=certificates,
|
|
66
|
+
interceptors=interceptors,
|
|
57
67
|
)
|
|
58
68
|
|
|
59
|
-
|
|
69
|
+
if auth_plugin is None:
|
|
70
|
+
log(INFO, "Flower Deployment Engine: Starting Exec API on %s", address)
|
|
71
|
+
else:
|
|
72
|
+
log(
|
|
73
|
+
INFO,
|
|
74
|
+
"Flower Deployment Engine: Starting Exec API with user "
|
|
75
|
+
"authentication on %s",
|
|
76
|
+
address,
|
|
77
|
+
)
|
|
60
78
|
exec_grpc_server.start()
|
|
61
79
|
|
|
62
80
|
return exec_grpc_server
|
flwr/superexec/exec_servicer.py
CHANGED
|
@@ -18,24 +18,33 @@
|
|
|
18
18
|
import time
|
|
19
19
|
from collections.abc import Generator
|
|
20
20
|
from logging import ERROR, INFO
|
|
21
|
-
from typing import Any
|
|
21
|
+
from typing import Any, Optional
|
|
22
|
+
from uuid import UUID
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
25
26
|
from flwr.common import now
|
|
26
|
-
from flwr.common.
|
|
27
|
+
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
28
|
+
from flwr.common.constant import LOG_STREAM_INTERVAL, Status, SubStatus
|
|
27
29
|
from flwr.common.logger import log
|
|
28
30
|
from flwr.common.serde import (
|
|
29
31
|
configs_record_from_proto,
|
|
30
32
|
run_to_proto,
|
|
31
33
|
user_config_from_proto,
|
|
32
34
|
)
|
|
35
|
+
from flwr.common.typing import RunStatus
|
|
33
36
|
from flwr.proto import exec_pb2_grpc # pylint: disable=E0611
|
|
34
37
|
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
|
|
38
|
+
GetAuthTokensRequest,
|
|
39
|
+
GetAuthTokensResponse,
|
|
40
|
+
GetLoginDetailsRequest,
|
|
41
|
+
GetLoginDetailsResponse,
|
|
35
42
|
ListRunsRequest,
|
|
36
43
|
ListRunsResponse,
|
|
37
44
|
StartRunRequest,
|
|
38
45
|
StartRunResponse,
|
|
46
|
+
StopRunRequest,
|
|
47
|
+
StopRunResponse,
|
|
39
48
|
StreamLogsRequest,
|
|
40
49
|
StreamLogsResponse,
|
|
41
50
|
)
|
|
@@ -53,11 +62,13 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
53
62
|
linkstate_factory: LinkStateFactory,
|
|
54
63
|
ffs_factory: FfsFactory,
|
|
55
64
|
executor: Executor,
|
|
65
|
+
auth_plugin: Optional[ExecAuthPlugin] = None,
|
|
56
66
|
) -> None:
|
|
57
67
|
self.linkstate_factory = linkstate_factory
|
|
58
68
|
self.ffs_factory = ffs_factory
|
|
59
69
|
self.executor = executor
|
|
60
70
|
self.executor.initialize(linkstate_factory, ffs_factory)
|
|
71
|
+
self.auth_plugin = auth_plugin
|
|
61
72
|
|
|
62
73
|
def StartRun(
|
|
63
74
|
self, request: StartRunRequest, context: grpc.ServicerContext
|
|
@@ -126,6 +137,69 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
126
137
|
# Handle `flwr ls --run-id <run_id>`
|
|
127
138
|
return _create_list_runs_response({request.run_id}, state)
|
|
128
139
|
|
|
140
|
+
def StopRun(
|
|
141
|
+
self, request: StopRunRequest, context: grpc.ServicerContext
|
|
142
|
+
) -> StopRunResponse:
|
|
143
|
+
"""Stop a given run ID."""
|
|
144
|
+
log(INFO, "ExecServicer.StopRun")
|
|
145
|
+
state = self.linkstate_factory.state()
|
|
146
|
+
|
|
147
|
+
# Exit if `run_id` not found
|
|
148
|
+
if not state.get_run(request.run_id):
|
|
149
|
+
context.abort(
|
|
150
|
+
grpc.StatusCode.NOT_FOUND, f"Run ID {request.run_id} not found"
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
run_status = state.get_run_status({request.run_id})[request.run_id]
|
|
154
|
+
if run_status.status == Status.FINISHED:
|
|
155
|
+
context.abort(
|
|
156
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
157
|
+
f"Run ID {request.run_id} is already finished",
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
update_success = state.update_run_status(
|
|
161
|
+
run_id=request.run_id,
|
|
162
|
+
new_status=RunStatus(Status.FINISHED, SubStatus.STOPPED, ""),
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
if update_success:
|
|
166
|
+
task_ids: set[UUID] = state.get_task_ids_from_run_id(request.run_id)
|
|
167
|
+
|
|
168
|
+
# Delete TaskIns and TaskRes for the `run_id`
|
|
169
|
+
state.delete_tasks(task_ids)
|
|
170
|
+
|
|
171
|
+
return StopRunResponse(success=update_success)
|
|
172
|
+
|
|
173
|
+
def GetLoginDetails(
|
|
174
|
+
self, request: GetLoginDetailsRequest, context: grpc.ServicerContext
|
|
175
|
+
) -> GetLoginDetailsResponse:
|
|
176
|
+
"""Start login."""
|
|
177
|
+
log(INFO, "ExecServicer.GetLoginDetails")
|
|
178
|
+
if self.auth_plugin is None:
|
|
179
|
+
context.abort(
|
|
180
|
+
grpc.StatusCode.UNIMPLEMENTED,
|
|
181
|
+
"ExecServicer initialized without user authentication",
|
|
182
|
+
)
|
|
183
|
+
raise grpc.RpcError() # This line is unreachable
|
|
184
|
+
return GetLoginDetailsResponse(
|
|
185
|
+
login_details=self.auth_plugin.get_login_details()
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
def GetAuthTokens(
|
|
189
|
+
self, request: GetAuthTokensRequest, context: grpc.ServicerContext
|
|
190
|
+
) -> GetAuthTokensResponse:
|
|
191
|
+
"""Get auth token."""
|
|
192
|
+
log(INFO, "ExecServicer.GetAuthTokens")
|
|
193
|
+
if self.auth_plugin is None:
|
|
194
|
+
context.abort(
|
|
195
|
+
grpc.StatusCode.UNIMPLEMENTED,
|
|
196
|
+
"ExecServicer initialized without user authentication",
|
|
197
|
+
)
|
|
198
|
+
raise grpc.RpcError() # This line is unreachable
|
|
199
|
+
return GetAuthTokensResponse(
|
|
200
|
+
auth_tokens=self.auth_plugin.get_auth_tokens(dict(request.auth_details))
|
|
201
|
+
)
|
|
202
|
+
|
|
129
203
|
|
|
130
204
|
def _create_list_runs_response(run_ids: set[int], state: LinkState) -> ListRunsResponse:
|
|
131
205
|
"""Create response for `flwr ls --runs` and `flwr ls --run-id <run_id>`."""
|