flwr 1.13.1__py3-none-any.whl → 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +5 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +19 -2
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +1 -0
- flwr/cli/log.py +18 -36
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +81 -0
- flwr/cli/ls.py +205 -106
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/.gitignore.tpl +3 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +89 -39
- flwr/cli/stop.py +130 -0
- flwr/cli/utils.py +172 -8
- flwr/client/app.py +14 -3
- flwr/client/client.py +1 -32
- flwr/client/clientapp/app.py +4 -1
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +13 -7
- flwr/client/message_handler/message_handler.py +1 -2
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/numpy_client.py +0 -44
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +2 -2
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +3 -1
- flwr/common/constant.py +17 -1
- flwr/common/logger.py +40 -0
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +15 -4
- flwr/common/typing.py +12 -0
- flwr/common/version.py +1 -0
- flwr/proto/exec_pb2.py +38 -14
- flwr/proto/exec_pb2.pyi +107 -2
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +54 -2
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +11 -63
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/run_serverapp.py +8 -9
- flwr/server/serverapp/app.py +25 -3
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +82 -23
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +12 -11
- flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +21 -30
- flwr/server/superlink/linkstate/linkstate.py +17 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -49
- flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +16 -4
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +36 -22
- flwr/simulation/simulationio_connection.py +3 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +19 -1
- flwr/superexec/exec_servicer.py +76 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/METADATA +8 -7
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/RECORD +112 -112
- flwr/proto/common_pb2.py +0 -36
- flwr/proto/common_pb2.pyi +0 -121
- flwr/proto/common_pb2_grpc.py +0 -4
- flwr/proto/common_pb2_grpc.pyi +0 -4
- flwr/proto/control_pb2.py +0 -27
- flwr/proto/control_pb2.pyi +0 -7
- flwr/proto/control_pb2_grpc.py +0 -135
- flwr/proto/control_pb2_grpc.pyi +0 -53
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/LICENSE +0 -0
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/WHEEL +0 -0
- {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/entry_points.txt +0 -0
|
@@ -139,8 +139,19 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
139
139
|
"""
|
|
140
140
|
|
|
141
141
|
@abc.abstractmethod
|
|
142
|
-
def delete_tasks(self,
|
|
143
|
-
"""Delete
|
|
142
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
143
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
|
|
144
|
+
|
|
145
|
+
Parameters
|
|
146
|
+
----------
|
|
147
|
+
task_ins_ids : set[UUID]
|
|
148
|
+
A set of TaskIns IDs. For each ID in the set, the corresponding
|
|
149
|
+
TaskIns and its associated TaskRes will be deleted.
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
@abc.abstractmethod
|
|
153
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
154
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
144
155
|
|
|
145
156
|
@abc.abstractmethod
|
|
146
157
|
def create_node(
|
|
@@ -273,6 +284,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
273
284
|
def get_server_public_key(self) -> Optional[bytes]:
|
|
274
285
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
275
286
|
|
|
287
|
+
@abc.abstractmethod
|
|
288
|
+
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
|
289
|
+
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
|
290
|
+
|
|
276
291
|
@abc.abstractmethod
|
|
277
292
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
278
293
|
"""Store a set of `node_public_keys` in the link state."""
|
|
@@ -14,12 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SQLite based implemenation of the link state."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
# pylint: disable=too-many-lines
|
|
18
19
|
|
|
19
20
|
import json
|
|
20
21
|
import re
|
|
21
22
|
import sqlite3
|
|
22
|
-
import threading
|
|
23
23
|
import time
|
|
24
24
|
from collections.abc import Sequence
|
|
25
25
|
from logging import DEBUG, ERROR, WARNING
|
|
@@ -183,7 +183,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
183
183
|
"""
|
|
184
184
|
self.database_path = database_path
|
|
185
185
|
self.conn: Optional[sqlite3.Connection] = None
|
|
186
|
-
self.lock = threading.RLock()
|
|
187
186
|
|
|
188
187
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
189
188
|
"""Create tables if they don't exist yet.
|
|
@@ -216,7 +215,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
216
215
|
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
217
216
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
218
217
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
219
|
-
|
|
220
218
|
return res.fetchall()
|
|
221
219
|
|
|
222
220
|
def query(
|
|
@@ -569,9 +567,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
569
567
|
data: list[Any] = [delivered_at] + task_res_ids
|
|
570
568
|
self.query(query, data)
|
|
571
569
|
|
|
572
|
-
# Cleanup
|
|
573
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
574
|
-
|
|
575
570
|
return list(ret.values())
|
|
576
571
|
|
|
577
572
|
def num_task_ins(self) -> int:
|
|
@@ -595,68 +590,50 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
595
590
|
result: dict[str, int] = rows[0]
|
|
596
591
|
return result["num"]
|
|
597
592
|
|
|
598
|
-
def delete_tasks(self,
|
|
599
|
-
"""Delete
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
593
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
594
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
595
|
+
if not task_ins_ids:
|
|
596
|
+
return
|
|
597
|
+
if self.conn is None:
|
|
598
|
+
raise AttributeError("LinkState not initialized")
|
|
603
599
|
|
|
604
|
-
placeholders = ",".join([
|
|
605
|
-
data =
|
|
600
|
+
placeholders = ",".join(["?"] * len(task_ins_ids))
|
|
601
|
+
data = tuple(str(task_id) for task_id in task_ins_ids)
|
|
606
602
|
|
|
607
|
-
#
|
|
603
|
+
# Delete task_ins
|
|
608
604
|
query_1 = f"""
|
|
609
605
|
DELETE FROM task_ins
|
|
610
|
-
WHERE
|
|
611
|
-
AND task_id IN (
|
|
612
|
-
SELECT ancestry
|
|
613
|
-
FROM task_res
|
|
614
|
-
WHERE ancestry IN ({placeholders})
|
|
615
|
-
AND delivered_at != ''
|
|
616
|
-
);
|
|
606
|
+
WHERE task_id IN ({placeholders});
|
|
617
607
|
"""
|
|
618
608
|
|
|
619
|
-
#
|
|
609
|
+
# Delete task_res
|
|
620
610
|
query_2 = f"""
|
|
621
611
|
DELETE FROM task_res
|
|
622
|
-
WHERE ancestry IN ({placeholders})
|
|
623
|
-
AND delivered_at != '';
|
|
612
|
+
WHERE ancestry IN ({placeholders});
|
|
624
613
|
"""
|
|
625
614
|
|
|
626
|
-
if self.conn is None:
|
|
627
|
-
raise AttributeError("LinkState not intitialized")
|
|
628
|
-
|
|
629
615
|
with self.conn:
|
|
630
616
|
self.conn.execute(query_1, data)
|
|
631
617
|
self.conn.execute(query_2, data)
|
|
632
618
|
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
636
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
637
|
-
if not task_ids:
|
|
638
|
-
return
|
|
619
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
620
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
639
621
|
if self.conn is None:
|
|
640
622
|
raise AttributeError("LinkState not initialized")
|
|
641
623
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
query_1 = f"""
|
|
647
|
-
DELETE FROM task_ins
|
|
648
|
-
WHERE task_id IN ({placeholders});
|
|
624
|
+
query = """
|
|
625
|
+
SELECT task_id
|
|
626
|
+
FROM task_ins
|
|
627
|
+
WHERE run_id = :run_id;
|
|
649
628
|
"""
|
|
650
629
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
DELETE FROM task_res
|
|
654
|
-
WHERE ancestry IN ({placeholders});
|
|
655
|
-
"""
|
|
630
|
+
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
631
|
+
data = {"run_id": sint64_run_id}
|
|
656
632
|
|
|
657
633
|
with self.conn:
|
|
658
|
-
self.conn.execute(
|
|
659
|
-
|
|
634
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
635
|
+
|
|
636
|
+
return {UUID(row["task_id"]) for row in rows}
|
|
660
637
|
|
|
661
638
|
def create_node(
|
|
662
639
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
@@ -784,8 +761,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
784
761
|
"federation_options, pending_at, starting_at, running_at, finished_at, "
|
|
785
762
|
"sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
786
763
|
)
|
|
787
|
-
if fab_hash:
|
|
788
|
-
fab_id, fab_version = "", ""
|
|
789
764
|
override_config_json = json.dumps(override_config)
|
|
790
765
|
data = [
|
|
791
766
|
sint64_run_id,
|
|
@@ -843,6 +818,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
843
818
|
public_key = None
|
|
844
819
|
return public_key
|
|
845
820
|
|
|
821
|
+
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
|
822
|
+
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
|
823
|
+
queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
|
|
824
|
+
for query in queries:
|
|
825
|
+
self.query(query)
|
|
826
|
+
|
|
846
827
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
847
828
|
"""Store a set of `node_public_keys` in the link state."""
|
|
848
829
|
query = "INSERT INTO public_key (public_key) VALUES (?)"
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SimulationIo API servicer."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import threading
|
|
18
19
|
from logging import DEBUG, INFO
|
|
19
20
|
|
|
@@ -28,6 +29,7 @@ from flwr.common.serde import (
|
|
|
28
29
|
context_to_proto,
|
|
29
30
|
fab_to_proto,
|
|
30
31
|
run_status_from_proto,
|
|
32
|
+
run_status_to_proto,
|
|
31
33
|
run_to_proto,
|
|
32
34
|
)
|
|
33
35
|
from flwr.common.typing import Fab, RunStatus
|
|
@@ -39,6 +41,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
|
39
41
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
40
42
|
GetFederationOptionsRequest,
|
|
41
43
|
GetFederationOptionsResponse,
|
|
44
|
+
GetRunStatusRequest,
|
|
45
|
+
GetRunStatusResponse,
|
|
42
46
|
UpdateRunStatusRequest,
|
|
43
47
|
UpdateRunStatusResponse,
|
|
44
48
|
)
|
|
@@ -50,6 +54,7 @@ from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611
|
|
|
50
54
|
)
|
|
51
55
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
52
56
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
57
|
+
from flwr.server.superlink.utils import abort_if
|
|
53
58
|
|
|
54
59
|
|
|
55
60
|
class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
@@ -106,6 +111,15 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
106
111
|
"""Push Simulation process outputs."""
|
|
107
112
|
log(DEBUG, "SimultionIoServicer.PushSimulationOutputs")
|
|
108
113
|
state = self.state_factory.state()
|
|
114
|
+
|
|
115
|
+
# Abort if the run is not running
|
|
116
|
+
abort_if(
|
|
117
|
+
request.run_id,
|
|
118
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
119
|
+
state,
|
|
120
|
+
context,
|
|
121
|
+
)
|
|
122
|
+
|
|
109
123
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
110
124
|
return PushSimulationOutputsResponse()
|
|
111
125
|
|
|
@@ -116,12 +130,31 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
116
130
|
log(DEBUG, "SimultionIoServicer.UpdateRunStatus")
|
|
117
131
|
state = self.state_factory.state()
|
|
118
132
|
|
|
133
|
+
# Abort if the run is finished
|
|
134
|
+
abort_if(request.run_id, [Status.FINISHED], state, context)
|
|
135
|
+
|
|
119
136
|
# Update the run status
|
|
120
137
|
state.update_run_status(
|
|
121
138
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
122
139
|
)
|
|
123
140
|
return UpdateRunStatusResponse()
|
|
124
141
|
|
|
142
|
+
def GetRunStatus(
|
|
143
|
+
self, request: GetRunStatusRequest, context: ServicerContext
|
|
144
|
+
) -> GetRunStatusResponse:
|
|
145
|
+
"""Get status of requested runs."""
|
|
146
|
+
log(DEBUG, "SimultionIoServicer.GetRunStatus")
|
|
147
|
+
state = self.state_factory.state()
|
|
148
|
+
|
|
149
|
+
statuses = state.get_run_status(set(request.run_ids))
|
|
150
|
+
|
|
151
|
+
return GetRunStatusResponse(
|
|
152
|
+
run_status_dict={
|
|
153
|
+
run_id: run_status_to_proto(status)
|
|
154
|
+
for run_id, status in statuses.items()
|
|
155
|
+
}
|
|
156
|
+
)
|
|
157
|
+
|
|
125
158
|
def PushLogs(
|
|
126
159
|
self, request: PushLogsRequest, context: grpc.ServicerContext
|
|
127
160
|
) -> PushLogsResponse:
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""SuperLink utilities."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from typing import Union
|
|
19
|
+
|
|
20
|
+
import grpc
|
|
21
|
+
|
|
22
|
+
from flwr.common.constant import Status, SubStatus
|
|
23
|
+
from flwr.common.typing import RunStatus
|
|
24
|
+
from flwr.server.superlink.linkstate import LinkState
|
|
25
|
+
|
|
26
|
+
_STATUS_TO_MSG = {
|
|
27
|
+
Status.PENDING: "Run is pending.",
|
|
28
|
+
Status.STARTING: "Run is starting.",
|
|
29
|
+
Status.RUNNING: "Run is running.",
|
|
30
|
+
Status.FINISHED: "Run is finished.",
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def check_abort(
|
|
35
|
+
run_id: int,
|
|
36
|
+
abort_status_list: list[str],
|
|
37
|
+
state: LinkState,
|
|
38
|
+
) -> Union[str, None]:
|
|
39
|
+
"""Check if the status of the provided `run_id` is in `abort_status_list`."""
|
|
40
|
+
run_status: RunStatus = state.get_run_status({run_id})[run_id]
|
|
41
|
+
|
|
42
|
+
if run_status.status in abort_status_list:
|
|
43
|
+
msg = _STATUS_TO_MSG[run_status.status]
|
|
44
|
+
if run_status.sub_status == SubStatus.STOPPED:
|
|
45
|
+
msg += " Stopped by user."
|
|
46
|
+
return msg
|
|
47
|
+
|
|
48
|
+
return None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def abort_grpc_context(msg: Union[str, None], context: grpc.ServicerContext) -> None:
|
|
52
|
+
"""Abort context with statuscode PERMISSION_DENIED if `msg` is not None."""
|
|
53
|
+
if msg is not None:
|
|
54
|
+
context.abort(grpc.StatusCode.PERMISSION_DENIED, msg)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def abort_if(
|
|
58
|
+
run_id: int,
|
|
59
|
+
abort_status_list: list[str],
|
|
60
|
+
state: LinkState,
|
|
61
|
+
context: grpc.ServicerContext,
|
|
62
|
+
) -> None:
|
|
63
|
+
"""Abort context if status of the provided `run_id` is in `abort_status_list`."""
|
|
64
|
+
msg = check_abort(run_id, abort_status_list, state)
|
|
65
|
+
abort_grpc_context(msg, context)
|
flwr/simulation/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower Simulation process."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import sys
|
|
19
20
|
from logging import DEBUG, ERROR, INFO
|
|
@@ -23,7 +24,8 @@ from typing import Optional
|
|
|
23
24
|
|
|
24
25
|
from flwr.cli.config_utils import get_fab_metadata
|
|
25
26
|
from flwr.cli.install import install_from_fab
|
|
26
|
-
from flwr.
|
|
27
|
+
from flwr.cli.utils import get_sha256_hash
|
|
28
|
+
from flwr.common import EventType, event
|
|
27
29
|
from flwr.common.args import add_args_flwr_app_common
|
|
28
30
|
from flwr.common.config import (
|
|
29
31
|
get_flwr_dir,
|
|
@@ -47,6 +49,7 @@ from flwr.common.logger import (
|
|
|
47
49
|
from flwr.common.serde import (
|
|
48
50
|
configs_record_from_proto,
|
|
49
51
|
context_from_proto,
|
|
52
|
+
context_to_proto,
|
|
50
53
|
fab_from_proto,
|
|
51
54
|
run_from_proto,
|
|
52
55
|
run_status_to_proto,
|
|
@@ -200,8 +203,17 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
|
|
|
200
203
|
verbose: bool = fed_opt.get("verbose", False)
|
|
201
204
|
enable_tf_gpu_growth: bool = fed_opt.get("enable_tf_gpu_growth", False)
|
|
202
205
|
|
|
206
|
+
event(
|
|
207
|
+
EventType.FLWR_SIMULATION_RUN_ENTER,
|
|
208
|
+
event_details={
|
|
209
|
+
"backend": "ray",
|
|
210
|
+
"num-supernodes": num_supernodes,
|
|
211
|
+
"run-id-hash": get_sha256_hash(run.run_id),
|
|
212
|
+
},
|
|
213
|
+
)
|
|
214
|
+
|
|
203
215
|
# Launch the simulation
|
|
204
|
-
_run_simulation(
|
|
216
|
+
updated_context = _run_simulation(
|
|
205
217
|
server_app_attr=server_app_attr,
|
|
206
218
|
client_app_attr=client_app_attr,
|
|
207
219
|
num_supernodes=num_supernodes,
|
|
@@ -212,11 +224,11 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
|
|
|
212
224
|
verbose_logging=verbose,
|
|
213
225
|
server_app_run_config=fused_config,
|
|
214
226
|
is_app=True,
|
|
215
|
-
exit_event=EventType.
|
|
227
|
+
exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
|
|
216
228
|
)
|
|
217
229
|
|
|
218
230
|
# Send resulting context
|
|
219
|
-
context_proto =
|
|
231
|
+
context_proto = context_to_proto(updated_context)
|
|
220
232
|
out_req = PushSimulationOutputsRequest(
|
|
221
233
|
run_id=run.run_id, context=context_proto
|
|
222
234
|
)
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower Simulation."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import asyncio
|
|
19
20
|
import json
|
|
@@ -23,10 +24,11 @@ import threading
|
|
|
23
24
|
import traceback
|
|
24
25
|
from logging import DEBUG, ERROR, INFO, WARNING
|
|
25
26
|
from pathlib import Path
|
|
26
|
-
from
|
|
27
|
+
from queue import Empty, Queue
|
|
27
28
|
from typing import Any, Optional
|
|
28
29
|
|
|
29
30
|
from flwr.cli.config_utils import load_and_validate
|
|
31
|
+
from flwr.cli.utils import get_sha256_hash
|
|
30
32
|
from flwr.client import ClientApp
|
|
31
33
|
from flwr.common import Context, EventType, RecordSet, event, log, now
|
|
32
34
|
from flwr.common.config import get_fused_config_from_dir, parse_config_args
|
|
@@ -126,7 +128,7 @@ def run_simulation_from_cli() -> None:
|
|
|
126
128
|
run = Run.create_empty(run_id)
|
|
127
129
|
run.override_config = override_config
|
|
128
130
|
|
|
129
|
-
_run_simulation(
|
|
131
|
+
_ = _run_simulation(
|
|
130
132
|
server_app_attr=server_app_attr,
|
|
131
133
|
client_app_attr=client_app_attr,
|
|
132
134
|
num_supernodes=args.num_supernodes,
|
|
@@ -135,7 +137,6 @@ def run_simulation_from_cli() -> None:
|
|
|
135
137
|
app_dir=args.app,
|
|
136
138
|
run=run,
|
|
137
139
|
enable_tf_gpu_growth=args.enable_tf_gpu_growth,
|
|
138
|
-
delay_start=args.delay_start,
|
|
139
140
|
verbose_logging=args.verbose,
|
|
140
141
|
server_app_run_config=fused_config,
|
|
141
142
|
is_app=True,
|
|
@@ -207,7 +208,7 @@ def run_simulation(
|
|
|
207
208
|
"\n\tflwr.simulation.run_simulationt(...)",
|
|
208
209
|
)
|
|
209
210
|
|
|
210
|
-
_run_simulation(
|
|
211
|
+
_ = _run_simulation(
|
|
211
212
|
num_supernodes=num_supernodes,
|
|
212
213
|
client_app=client_app,
|
|
213
214
|
server_app=server_app,
|
|
@@ -230,6 +231,7 @@ def run_serverapp_th(
|
|
|
230
231
|
has_exception: threading.Event,
|
|
231
232
|
enable_tf_gpu_growth: bool,
|
|
232
233
|
run_id: int,
|
|
234
|
+
ctx_queue: "Queue[Context]",
|
|
233
235
|
) -> threading.Thread:
|
|
234
236
|
"""Run SeverApp in a thread."""
|
|
235
237
|
|
|
@@ -242,6 +244,7 @@ def run_serverapp_th(
|
|
|
242
244
|
_server_app_run_config: UserConfig,
|
|
243
245
|
_server_app_attr: Optional[str],
|
|
244
246
|
_server_app: Optional[ServerApp],
|
|
247
|
+
_ctx_queue: "Queue[Context]",
|
|
245
248
|
) -> None:
|
|
246
249
|
"""Run SeverApp, after check if GPU memory growth has to be set.
|
|
247
250
|
|
|
@@ -262,13 +265,14 @@ def run_serverapp_th(
|
|
|
262
265
|
)
|
|
263
266
|
|
|
264
267
|
# Run ServerApp
|
|
265
|
-
_run(
|
|
268
|
+
updated_context = _run(
|
|
266
269
|
driver=_driver,
|
|
267
270
|
context=context,
|
|
268
271
|
server_app_dir=_server_app_dir,
|
|
269
272
|
server_app_attr=_server_app_attr,
|
|
270
273
|
loaded_server_app=_server_app,
|
|
271
274
|
)
|
|
275
|
+
_ctx_queue.put(updated_context)
|
|
272
276
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
273
277
|
log(ERROR, "ServerApp thread raised an exception: %s", ex)
|
|
274
278
|
log(ERROR, traceback.format_exc())
|
|
@@ -292,6 +296,7 @@ def run_serverapp_th(
|
|
|
292
296
|
server_app_run_config,
|
|
293
297
|
server_app_attr,
|
|
294
298
|
server_app,
|
|
299
|
+
ctx_queue,
|
|
295
300
|
),
|
|
296
301
|
)
|
|
297
302
|
serverapp_th.start()
|
|
@@ -308,14 +313,13 @@ def _main_loop(
|
|
|
308
313
|
enable_tf_gpu_growth: bool,
|
|
309
314
|
run: Run,
|
|
310
315
|
exit_event: EventType,
|
|
311
|
-
delay_start: int,
|
|
312
316
|
flwr_dir: Optional[str] = None,
|
|
313
317
|
client_app: Optional[ClientApp] = None,
|
|
314
318
|
client_app_attr: Optional[str] = None,
|
|
315
319
|
server_app: Optional[ServerApp] = None,
|
|
316
320
|
server_app_attr: Optional[str] = None,
|
|
317
321
|
server_app_run_config: Optional[UserConfig] = None,
|
|
318
|
-
) ->
|
|
322
|
+
) -> Context:
|
|
319
323
|
"""Start ServerApp on a separate thread, then launch Simulation Engine."""
|
|
320
324
|
# Initialize StateFactory
|
|
321
325
|
state_factory = LinkStateFactory(":flwr-in-memory-state:")
|
|
@@ -325,6 +329,13 @@ def _main_loop(
|
|
|
325
329
|
server_app_thread_has_exception = threading.Event()
|
|
326
330
|
serverapp_th = None
|
|
327
331
|
success = True
|
|
332
|
+
updated_context = Context(
|
|
333
|
+
run_id=run.run_id,
|
|
334
|
+
node_id=0,
|
|
335
|
+
node_config=UserConfig(),
|
|
336
|
+
state=RecordSet(),
|
|
337
|
+
run_config=UserConfig(),
|
|
338
|
+
)
|
|
328
339
|
try:
|
|
329
340
|
# Register run
|
|
330
341
|
log(DEBUG, "Pre-registering run with id %s", run.run_id)
|
|
@@ -339,6 +350,7 @@ def _main_loop(
|
|
|
339
350
|
# Initialize Driver
|
|
340
351
|
driver = InMemoryDriver(state_factory=state_factory)
|
|
341
352
|
driver.set_run(run_id=run.run_id)
|
|
353
|
+
output_context_queue: "Queue[Context]" = Queue()
|
|
342
354
|
|
|
343
355
|
# Get and run ServerApp thread
|
|
344
356
|
serverapp_th = run_serverapp_th(
|
|
@@ -351,11 +363,9 @@ def _main_loop(
|
|
|
351
363
|
has_exception=server_app_thread_has_exception,
|
|
352
364
|
enable_tf_gpu_growth=enable_tf_gpu_growth,
|
|
353
365
|
run_id=run.run_id,
|
|
366
|
+
ctx_queue=output_context_queue,
|
|
354
367
|
)
|
|
355
368
|
|
|
356
|
-
# Buffer time so the `ServerApp` in separate thread is ready
|
|
357
|
-
log(DEBUG, "Buffer time delay: %ds", delay_start)
|
|
358
|
-
sleep(delay_start)
|
|
359
369
|
# Start Simulation Engine
|
|
360
370
|
vce.start_vce(
|
|
361
371
|
num_supernodes=num_supernodes,
|
|
@@ -371,6 +381,11 @@ def _main_loop(
|
|
|
371
381
|
flwr_dir=flwr_dir,
|
|
372
382
|
)
|
|
373
383
|
|
|
384
|
+
updated_context = output_context_queue.get(timeout=3)
|
|
385
|
+
|
|
386
|
+
except Empty:
|
|
387
|
+
log(DEBUG, "Queue timeout. No context received.")
|
|
388
|
+
|
|
374
389
|
except Exception as ex:
|
|
375
390
|
log(ERROR, "An exception occurred !! %s", ex)
|
|
376
391
|
log(ERROR, traceback.format_exc())
|
|
@@ -380,13 +395,20 @@ def _main_loop(
|
|
|
380
395
|
finally:
|
|
381
396
|
# Trigger stop event
|
|
382
397
|
f_stop.set()
|
|
383
|
-
event(
|
|
398
|
+
event(
|
|
399
|
+
exit_event,
|
|
400
|
+
event_details={
|
|
401
|
+
"run-id-hash": get_sha256_hash(run.run_id),
|
|
402
|
+
"success": success,
|
|
403
|
+
},
|
|
404
|
+
)
|
|
384
405
|
if serverapp_th:
|
|
385
406
|
serverapp_th.join()
|
|
386
407
|
if server_app_thread_has_exception.is_set():
|
|
387
408
|
raise RuntimeError("Exception in ServerApp thread")
|
|
388
409
|
|
|
389
410
|
log(DEBUG, "Stopping Simulation Engine now.")
|
|
411
|
+
return updated_context
|
|
390
412
|
|
|
391
413
|
|
|
392
414
|
# pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
|
|
@@ -404,10 +426,9 @@ def _run_simulation(
|
|
|
404
426
|
flwr_dir: Optional[str] = None,
|
|
405
427
|
run: Optional[Run] = None,
|
|
406
428
|
enable_tf_gpu_growth: bool = False,
|
|
407
|
-
delay_start: int = 5,
|
|
408
429
|
verbose_logging: bool = False,
|
|
409
430
|
is_app: bool = False,
|
|
410
|
-
) ->
|
|
431
|
+
) -> Context:
|
|
411
432
|
"""Launch the Simulation Engine."""
|
|
412
433
|
if backend_config is None:
|
|
413
434
|
backend_config = {}
|
|
@@ -459,7 +480,6 @@ def _run_simulation(
|
|
|
459
480
|
enable_tf_gpu_growth,
|
|
460
481
|
run,
|
|
461
482
|
exit_event,
|
|
462
|
-
delay_start,
|
|
463
483
|
flwr_dir,
|
|
464
484
|
client_app,
|
|
465
485
|
client_app_attr,
|
|
@@ -487,7 +507,8 @@ def _run_simulation(
|
|
|
487
507
|
# Set logger propagation to False to prevent duplicated log output in Colab.
|
|
488
508
|
logger = set_logger_propagation(logger, False)
|
|
489
509
|
|
|
490
|
-
_main_loop(*args)
|
|
510
|
+
updated_context = _main_loop(*args)
|
|
511
|
+
return updated_context
|
|
491
512
|
|
|
492
513
|
|
|
493
514
|
def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
@@ -537,13 +558,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
|
537
558
|
"Read more about how `tf.config.experimental.set_memory_growth()` works in "
|
|
538
559
|
"the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
|
|
539
560
|
)
|
|
540
|
-
parser.add_argument(
|
|
541
|
-
"--delay-start",
|
|
542
|
-
type=int,
|
|
543
|
-
default=3,
|
|
544
|
-
help="Buffer time (in seconds) to delay the start the simulation engine after "
|
|
545
|
-
"the `ServerApp`, which runs in a separate thread, has been launched.",
|
|
546
|
-
)
|
|
547
561
|
parser.add_argument(
|
|
548
562
|
"--verbose",
|
|
549
563
|
action="store_true",
|
|
@@ -23,6 +23,7 @@ import grpc
|
|
|
23
23
|
from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS
|
|
24
24
|
from flwr.common.grpc import create_channel
|
|
25
25
|
from flwr.common.logger import log
|
|
26
|
+
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
26
27
|
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
|
|
27
28
|
|
|
28
29
|
|
|
@@ -48,6 +49,7 @@ class SimulationIoConnection:
|
|
|
48
49
|
self._cert = root_certificates
|
|
49
50
|
self._grpc_stub: Optional[SimulationIoStub] = None
|
|
50
51
|
self._channel: Optional[grpc.Channel] = None
|
|
52
|
+
self._retry_invoker = _make_simple_grpc_retry_invoker()
|
|
51
53
|
|
|
52
54
|
@property
|
|
53
55
|
def _is_connected(self) -> bool:
|
|
@@ -72,6 +74,7 @@ class SimulationIoConnection:
|
|
|
72
74
|
root_certificates=self._cert,
|
|
73
75
|
)
|
|
74
76
|
self._grpc_stub = SimulationIoStub(self._channel)
|
|
77
|
+
_wrap_stub(self._grpc_stub, self._retry_invoker)
|
|
75
78
|
log(DEBUG, "[SimulationIO] Connected to %s", self._addr)
|
|
76
79
|
|
|
77
80
|
def _disconnect(self) -> None:
|
flwr/superexec/app.py
CHANGED