flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240707__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +3 -0
- flwr/cli/build.py +5 -9
- flwr/cli/new/new.py +104 -28
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
- flwr/cli/run/run.py +21 -5
- flwr/client/__init__.py +2 -0
- flwr/client/app.py +15 -10
- flwr/client/client_app.py +30 -5
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +4 -5
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +6 -3
- flwr/client/node_state_tests.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +12 -4
- flwr/client/typing.py +2 -1
- flwr/common/address.py +1 -1
- flwr/common/config.py +8 -6
- flwr/common/constant.py +4 -1
- flwr/common/context.py +11 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +13 -0
- flwr/common/message.py +0 -17
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/version.py +14 -0
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/driver.py +6 -0
- flwr/server/driver/grpc_driver.py +85 -63
- flwr/server/driver/inmemory_driver.py +28 -26
- flwr/server/run_serverapp.py +61 -18
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +15 -3
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +45 -26
- flwr/server/superlink/fleet/vce/vce_api.py +3 -8
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +5 -5
- flwr/server/superlink/state/sqlite_state.py +5 -5
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +6 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +52 -37
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +0 -6
- flwr/simulation/ray_transport/ray_client_proxy.py +17 -10
- flwr/simulation/run_simulation.py +47 -28
- flwr/superexec/deployment.py +109 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/RECORD +109 -98
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/entry_points.txt +0 -0
flwr/server/run_serverapp.py
CHANGED
|
@@ -22,10 +22,13 @@ from pathlib import Path
|
|
|
22
22
|
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
|
+
from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
|
|
25
26
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
26
27
|
from flwr.common.object_ref import load_app
|
|
28
|
+
from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
|
|
27
29
|
|
|
28
|
-
from .driver import Driver
|
|
30
|
+
from .driver import Driver
|
|
31
|
+
from .driver.grpc_driver import GrpcDriver, GrpcDriverStub
|
|
29
32
|
from .server_app import LoadServerAppError, ServerApp
|
|
30
33
|
|
|
31
34
|
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
@@ -41,7 +44,7 @@ def run(
|
|
|
41
44
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
|
42
45
|
raise ValueError(
|
|
43
46
|
"Either `server_app_attr` or `loaded_server_app` should be set "
|
|
44
|
-
"but not both.
|
|
47
|
+
"but not both."
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
if server_app_dir is not None:
|
|
@@ -74,7 +77,7 @@ def run(
|
|
|
74
77
|
log(DEBUG, "ServerApp finished running.")
|
|
75
78
|
|
|
76
79
|
|
|
77
|
-
def run_server_app() -> None:
|
|
80
|
+
def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
78
81
|
"""Run Flower server app."""
|
|
79
82
|
event(EventType.RUN_SERVER_APP_ENTER)
|
|
80
83
|
|
|
@@ -134,11 +137,43 @@ def run_server_app() -> None:
|
|
|
134
137
|
cert_path,
|
|
135
138
|
)
|
|
136
139
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
140
|
+
server_app_attr: Optional[str] = getattr(args, "server-app")
|
|
141
|
+
if not (server_app_attr is None) ^ (args.run_id is None):
|
|
142
|
+
raise sys.exit(
|
|
143
|
+
"Please provide either a ServerApp reference or a Run ID, but not both. "
|
|
144
|
+
"For more details, use: ``flower-server-app -h``"
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
stub = GrpcDriverStub(
|
|
148
|
+
driver_service_address=args.superlink, root_certificates=root_certificates
|
|
141
149
|
)
|
|
150
|
+
if args.run_id is not None:
|
|
151
|
+
# User provided `--run-id`, but not `server-app`
|
|
152
|
+
run_id = args.run_id
|
|
153
|
+
else:
|
|
154
|
+
# User provided `server-app`, but not `--run-id`
|
|
155
|
+
# Create run if run_id is not provided
|
|
156
|
+
stub.connect()
|
|
157
|
+
req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
|
|
158
|
+
res = stub.create_run(req)
|
|
159
|
+
run_id = res.run_id
|
|
160
|
+
|
|
161
|
+
# Initialize GrpcDriver
|
|
162
|
+
driver = GrpcDriver(run_id=run_id, stub=stub)
|
|
163
|
+
|
|
164
|
+
# Dynamically obtain ServerApp path based on run_id
|
|
165
|
+
if args.run_id is not None:
|
|
166
|
+
# User provided `--run-id`, but not `server-app`
|
|
167
|
+
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
168
|
+
run_ = driver.run
|
|
169
|
+
server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
|
|
170
|
+
config = get_project_config(server_app_dir)
|
|
171
|
+
server_app_attr = config["flower"]["components"]["serverapp"]
|
|
172
|
+
else:
|
|
173
|
+
# User provided `server-app`, but not `--run-id`
|
|
174
|
+
server_app_dir = str(Path(args.dir).absolute())
|
|
175
|
+
|
|
176
|
+
log(DEBUG, "Flower will load ServerApp `%s` in %s", server_app_attr, server_app_dir)
|
|
142
177
|
|
|
143
178
|
log(
|
|
144
179
|
DEBUG,
|
|
@@ -146,17 +181,6 @@ def run_server_app() -> None:
|
|
|
146
181
|
root_certificates,
|
|
147
182
|
)
|
|
148
183
|
|
|
149
|
-
server_app_dir = args.dir
|
|
150
|
-
server_app_attr = getattr(args, "server-app")
|
|
151
|
-
|
|
152
|
-
# Initialize GrpcDriver
|
|
153
|
-
driver = GrpcDriver(
|
|
154
|
-
driver_service_address=args.superlink,
|
|
155
|
-
root_certificates=root_certificates,
|
|
156
|
-
fab_id=args.fab_id,
|
|
157
|
-
fab_version=args.fab_version,
|
|
158
|
-
)
|
|
159
|
-
|
|
160
184
|
# Run the ServerApp with the Driver
|
|
161
185
|
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
|
162
186
|
|
|
@@ -174,6 +198,8 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
174
198
|
|
|
175
199
|
parser.add_argument(
|
|
176
200
|
"server-app",
|
|
201
|
+
nargs="?",
|
|
202
|
+
default=None,
|
|
177
203
|
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
|
178
204
|
)
|
|
179
205
|
parser.add_argument(
|
|
@@ -223,5 +249,22 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
223
249
|
type=str,
|
|
224
250
|
help="The version of the FAB used in the run.",
|
|
225
251
|
)
|
|
252
|
+
parser.add_argument(
|
|
253
|
+
"--run-id",
|
|
254
|
+
default=None,
|
|
255
|
+
type=int,
|
|
256
|
+
help="The identifier of the run.",
|
|
257
|
+
)
|
|
258
|
+
parser.add_argument(
|
|
259
|
+
"--flwr-dir",
|
|
260
|
+
default=None,
|
|
261
|
+
help="""The path containing installed Flower Apps.
|
|
262
|
+
By default, this value is equal to:
|
|
263
|
+
|
|
264
|
+
- `$FLWR_HOME/` if `$FLWR_HOME` is defined
|
|
265
|
+
- `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
|
|
266
|
+
- `$HOME/.flwr/` in all other cases
|
|
267
|
+
""",
|
|
268
|
+
)
|
|
226
269
|
|
|
227
270
|
return parser
|
flwr/server/strategy/bulyan.py
CHANGED
flwr/server/strategy/fedadam.py
CHANGED
flwr/server/strategy/fedavgm.py
CHANGED
flwr/server/strategy/fedopt.py
CHANGED
flwr/server/strategy/fedprox.py
CHANGED
flwr/server/strategy/fedyogi.py
CHANGED
flwr/server/strategy/krum.py
CHANGED
flwr/server/strategy/qfedavg.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -35,7 +35,11 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
35
35
|
PushTaskInsResponse,
|
|
36
36
|
)
|
|
37
37
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
38
|
-
from flwr.proto.run_pb2 import
|
|
38
|
+
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
39
|
+
GetRunRequest,
|
|
40
|
+
GetRunResponse,
|
|
41
|
+
Run,
|
|
42
|
+
)
|
|
39
43
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
40
44
|
from flwr.server.superlink.state import State, StateFactory
|
|
41
45
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
@@ -134,7 +138,15 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
134
138
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
135
139
|
) -> GetRunResponse:
|
|
136
140
|
"""Get run information."""
|
|
137
|
-
|
|
141
|
+
log(DEBUG, "DriverServicer.GetRun")
|
|
142
|
+
|
|
143
|
+
# Init state
|
|
144
|
+
state: State = self.state_factory.state()
|
|
145
|
+
|
|
146
|
+
# Retrieve run information
|
|
147
|
+
run = state.get_run(request.run_id)
|
|
148
|
+
run_proto = None if run is None else Run(**vars(run))
|
|
149
|
+
return GetRunResponse(run=run_proto)
|
|
138
150
|
|
|
139
151
|
|
|
140
152
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Ray backend for the Fleet API using the Simulation Engine."""
|
|
16
16
|
|
|
17
17
|
import pathlib
|
|
18
|
-
from logging import DEBUG, ERROR
|
|
18
|
+
from logging import DEBUG, ERROR
|
|
19
19
|
from typing import Callable, Dict, List, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import ray
|
|
@@ -24,16 +24,15 @@ from flwr.client.client_app import ClientApp
|
|
|
24
24
|
from flwr.common.context import Context
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.message import Message
|
|
27
|
-
from flwr.
|
|
28
|
-
|
|
29
|
-
ClientAppActor,
|
|
30
|
-
init_ray,
|
|
31
|
-
)
|
|
27
|
+
from flwr.common.typing import ConfigsRecordValues
|
|
28
|
+
from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
|
|
32
29
|
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
|
33
30
|
|
|
34
31
|
from .backend import Backend, BackendConfig
|
|
35
32
|
|
|
36
33
|
ClientResourcesDict = Dict[str, Union[int, float]]
|
|
34
|
+
ActorArgsDict = Dict[str, Union[int, float, Callable[[], None]]]
|
|
35
|
+
RunTimeEnvDict = Dict[str, Union[str, List[str]]]
|
|
37
36
|
|
|
38
37
|
|
|
39
38
|
class RayBackend(Backend):
|
|
@@ -51,40 +50,29 @@ class RayBackend(Backend):
|
|
|
51
50
|
if not pathlib.Path(work_dir).exists():
|
|
52
51
|
raise ValueError(f"Specified work_dir {work_dir} does not exist.")
|
|
53
52
|
|
|
54
|
-
#
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
)
|
|
58
|
-
|
|
59
|
-
if backend_config.get("mute_logging", False):
|
|
60
|
-
init_ray(
|
|
61
|
-
logging_level=WARNING, log_to_driver=False, runtime_env=runtime_env
|
|
62
|
-
)
|
|
63
|
-
elif backend_config.get("silent", False):
|
|
64
|
-
init_ray(logging_level=WARNING, log_to_driver=True, runtime_env=runtime_env)
|
|
65
|
-
else:
|
|
66
|
-
init_ray(runtime_env=runtime_env)
|
|
53
|
+
# Initialise ray
|
|
54
|
+
self.init_args_key = "init_args"
|
|
55
|
+
self.init_ray(backend_config, work_dir)
|
|
67
56
|
|
|
68
57
|
# Validate client resources
|
|
69
58
|
self.client_resources_key = "client_resources"
|
|
59
|
+
client_resources = self._validate_client_resources(config=backend_config)
|
|
70
60
|
|
|
71
61
|
# Create actor pool
|
|
72
|
-
|
|
73
|
-
actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {}
|
|
62
|
+
actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
|
74
63
|
|
|
75
|
-
client_resources = self._validate_client_resources(config=backend_config)
|
|
76
64
|
self.pool = BasicActorPool(
|
|
77
65
|
actor_type=ClientAppActor,
|
|
78
66
|
client_resources=client_resources,
|
|
79
67
|
actor_kwargs=actor_kwargs,
|
|
80
68
|
)
|
|
81
69
|
|
|
82
|
-
def _configure_runtime_env(self, work_dir: str) ->
|
|
70
|
+
def _configure_runtime_env(self, work_dir: str) -> RunTimeEnvDict:
|
|
83
71
|
"""Return list of files/subdirectories to exclude relative to work_dir.
|
|
84
72
|
|
|
85
73
|
Without this, Ray will push everything to the Ray Cluster.
|
|
86
74
|
"""
|
|
87
|
-
runtime_env:
|
|
75
|
+
runtime_env: RunTimeEnvDict = {"working_dir": work_dir}
|
|
88
76
|
|
|
89
77
|
excludes = []
|
|
90
78
|
path = pathlib.Path(work_dir)
|
|
@@ -125,6 +113,37 @@ class RayBackend(Backend):
|
|
|
125
113
|
|
|
126
114
|
return client_resources
|
|
127
115
|
|
|
116
|
+
def _validate_actor_arguments(self, config: BackendConfig) -> ActorArgsDict:
|
|
117
|
+
actor_args_config = config.get("actor", False)
|
|
118
|
+
actor_args: ActorArgsDict = {}
|
|
119
|
+
if actor_args_config:
|
|
120
|
+
use_tf = actor_args.get("tensorflow", False)
|
|
121
|
+
if use_tf:
|
|
122
|
+
actor_args["on_actor_init_fn"] = enable_tf_gpu_growth
|
|
123
|
+
return actor_args
|
|
124
|
+
|
|
125
|
+
def init_ray(self, backend_config: BackendConfig, work_dir: str) -> None:
|
|
126
|
+
"""Intialises Ray if not already initialised."""
|
|
127
|
+
if not ray.is_initialized():
|
|
128
|
+
# Init ray and append working dir if needed
|
|
129
|
+
runtime_env = (
|
|
130
|
+
self._configure_runtime_env(work_dir=work_dir) if work_dir else None
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
ray_init_args: Dict[
|
|
134
|
+
str,
|
|
135
|
+
Union[ConfigsRecordValues, RunTimeEnvDict],
|
|
136
|
+
] = {}
|
|
137
|
+
|
|
138
|
+
if backend_config.get(self.init_args_key):
|
|
139
|
+
for k, v in backend_config[self.init_args_key].items():
|
|
140
|
+
ray_init_args[k] = v
|
|
141
|
+
|
|
142
|
+
if runtime_env is not None:
|
|
143
|
+
ray_init_args["runtime_env"] = runtime_env
|
|
144
|
+
|
|
145
|
+
ray.init(**ray_init_args)
|
|
146
|
+
|
|
128
147
|
@property
|
|
129
148
|
def num_workers(self) -> int:
|
|
130
149
|
"""Return number of actors in pool."""
|
|
@@ -149,10 +168,10 @@ class RayBackend(Backend):
|
|
|
149
168
|
|
|
150
169
|
Return output message and updated context.
|
|
151
170
|
"""
|
|
152
|
-
partition_id =
|
|
171
|
+
partition_id = context.partition_id
|
|
153
172
|
|
|
154
173
|
try:
|
|
155
|
-
#
|
|
174
|
+
# Submit a task to the pool
|
|
156
175
|
future = await self.pool.submit(
|
|
157
176
|
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
|
158
177
|
(app, message, str(partition_id), context),
|
|
@@ -57,7 +57,6 @@ async def worker(
|
|
|
57
57
|
queue: "asyncio.Queue[TaskIns]",
|
|
58
58
|
node_states: Dict[int, NodeState],
|
|
59
59
|
state_factory: StateFactory,
|
|
60
|
-
nodes_mapping: NodeToPartitionMapping,
|
|
61
60
|
backend: Backend,
|
|
62
61
|
) -> None:
|
|
63
62
|
"""Get TaskIns from queue and pass it to an actor in the pool to execute it."""
|
|
@@ -74,8 +73,6 @@ async def worker(
|
|
|
74
73
|
|
|
75
74
|
# Convert TaskIns to Message
|
|
76
75
|
message = message_from_taskins(task_ins)
|
|
77
|
-
# Set partition_id
|
|
78
|
-
message.metadata.partition_id = nodes_mapping[node_id]
|
|
79
76
|
|
|
80
77
|
# Let backend process message
|
|
81
78
|
out_mssg, updated_context = await backend.process_message(
|
|
@@ -187,9 +184,7 @@ async def run(
|
|
|
187
184
|
# Add workers (they submit Messages to Backend)
|
|
188
185
|
worker_tasks = [
|
|
189
186
|
asyncio.create_task(
|
|
190
|
-
worker(
|
|
191
|
-
app_fn, queue, node_states, state_factory, nodes_mapping, backend
|
|
192
|
-
)
|
|
187
|
+
worker(app_fn, queue, node_states, state_factory, backend)
|
|
193
188
|
)
|
|
194
189
|
for _ in range(backend.num_workers)
|
|
195
190
|
]
|
|
@@ -291,8 +286,8 @@ def start_vce(
|
|
|
291
286
|
|
|
292
287
|
# Construct mapping of NodeStates
|
|
293
288
|
node_states: Dict[int, NodeState] = {}
|
|
294
|
-
for node_id in nodes_mapping:
|
|
295
|
-
node_states[node_id] = NodeState()
|
|
289
|
+
for node_id, partition_id in nodes_mapping.items():
|
|
290
|
+
node_states[node_id] = NodeState(partition_id=partition_id)
|
|
296
291
|
|
|
297
292
|
# Load backend config
|
|
298
293
|
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
"""In-memory State implementation."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import os
|
|
19
18
|
import threading
|
|
20
19
|
import time
|
|
21
20
|
from logging import ERROR
|
|
@@ -23,12 +22,13 @@ from typing import Dict, List, Optional, Set, Tuple
|
|
|
23
22
|
from uuid import UUID, uuid4
|
|
24
23
|
|
|
25
24
|
from flwr.common import log, now
|
|
25
|
+
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
|
|
26
26
|
from flwr.common.typing import Run
|
|
27
27
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
28
28
|
from flwr.server.superlink.state.state import State
|
|
29
29
|
from flwr.server.utils import validate_task_ins_or_res
|
|
30
30
|
|
|
31
|
-
from .utils import make_node_unavailable_taskres
|
|
31
|
+
from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
@@ -216,7 +216,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
216
216
|
) -> int:
|
|
217
217
|
"""Create, store in state, and return `node_id`."""
|
|
218
218
|
# Sample a random int64 as node_id
|
|
219
|
-
node_id
|
|
219
|
+
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
220
220
|
|
|
221
221
|
with self.lock:
|
|
222
222
|
if node_id in self.node_ids:
|
|
@@ -279,7 +279,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
279
279
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
280
280
|
# Sample a random int64 as run_id
|
|
281
281
|
with self.lock:
|
|
282
|
-
run_id
|
|
282
|
+
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
283
283
|
|
|
284
284
|
if run_id not in self.run_ids:
|
|
285
285
|
self.run_ids[run_id] = Run(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
"""SQLite based implemenation of server state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import os
|
|
19
18
|
import re
|
|
20
19
|
import sqlite3
|
|
21
20
|
import time
|
|
@@ -24,6 +23,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
|
|
|
24
23
|
from uuid import UUID, uuid4
|
|
25
24
|
|
|
26
25
|
from flwr.common import log, now
|
|
26
|
+
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
|
|
27
27
|
from flwr.common.typing import Run
|
|
28
28
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
29
29
|
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
|
|
@@ -31,7 +31,7 @@ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
|
31
31
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
32
32
|
|
|
33
33
|
from .state import State
|
|
34
|
-
from .utils import make_node_unavailable_taskres
|
|
34
|
+
from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
|
|
35
35
|
|
|
36
36
|
SQL_CREATE_TABLE_NODE = """
|
|
37
37
|
CREATE TABLE IF NOT EXISTS node(
|
|
@@ -541,7 +541,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
541
541
|
) -> int:
|
|
542
542
|
"""Create, store in state, and return `node_id`."""
|
|
543
543
|
# Sample a random int64 as node_id
|
|
544
|
-
node_id
|
|
544
|
+
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
545
545
|
|
|
546
546
|
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
|
547
547
|
row = self.query(query, {"public_key": public_key})
|
|
@@ -616,7 +616,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
616
616
|
def create_run(self, fab_id: str, fab_version: str) -> int:
|
|
617
617
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
618
618
|
# Sample a random int64 as run_id
|
|
619
|
-
run_id
|
|
619
|
+
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
620
620
|
|
|
621
621
|
# Check conflicts
|
|
622
622
|
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|