flwr-nightly 1.10.0.dev20240709__py3-none-any.whl → 1.10.0.dev20240711__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +10 -0
- flwr/cli/new/new.py +3 -1
- flwr/cli/run/run.py +25 -8
- flwr/client/app.py +12 -6
- flwr/client/node_state.py +36 -8
- flwr/client/node_state_tests.py +3 -2
- flwr/client/supernode/app.py +20 -6
- flwr/common/logger.py +25 -0
- flwr/server/__init__.py +2 -0
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +8 -9
- flwr/server/superlink/fleet/vce/vce_api.py +93 -98
- flwr/server/typing.py +2 -0
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/run_simulation.py +49 -33
- {flwr_nightly-1.10.0.dev20240709.dist-info → flwr_nightly-1.10.0.dev20240711.dist-info}/METADATA +2 -2
- {flwr_nightly-1.10.0.dev20240709.dist-info → flwr_nightly-1.10.0.dev20240711.dist-info}/RECORD +22 -21
- {flwr_nightly-1.10.0.dev20240709.dist-info → flwr_nightly-1.10.0.dev20240711.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240709.dist-info → flwr_nightly-1.10.0.dev20240711.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240709.dist-info → flwr_nightly-1.10.0.dev20240711.dist-info}/entry_points.txt +0 -0
flwr/cli/config_utils.py
CHANGED
|
@@ -108,6 +108,14 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
|
|
|
108
108
|
return load_from_string(toml_file.read())
|
|
109
109
|
|
|
110
110
|
|
|
111
|
+
def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None:
|
|
112
|
+
for key, value in config_dict.items():
|
|
113
|
+
if isinstance(value, dict):
|
|
114
|
+
_validate_run_config(config_dict[key], errors)
|
|
115
|
+
elif not isinstance(value, str):
|
|
116
|
+
errors.append(f"Config value of key {key} is not of type `str`.")
|
|
117
|
+
|
|
118
|
+
|
|
111
119
|
# pylint: disable=too-many-branches
|
|
112
120
|
def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
|
|
113
121
|
"""Validate pyproject.toml fields."""
|
|
@@ -133,6 +141,8 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
|
|
|
133
141
|
else:
|
|
134
142
|
if "publisher" not in config["flower"]:
|
|
135
143
|
errors.append('Property "publisher" missing in [flower]')
|
|
144
|
+
if "config" in config["flower"]:
|
|
145
|
+
_validate_run_config(config["flower"]["config"], errors)
|
|
136
146
|
if "components" not in config["flower"]:
|
|
137
147
|
errors.append("Missing [flower.components] section")
|
|
138
148
|
else:
|
flwr/cli/new/new.py
CHANGED
|
@@ -264,9 +264,11 @@ def new(
|
|
|
264
264
|
bold=True,
|
|
265
265
|
)
|
|
266
266
|
)
|
|
267
|
+
|
|
268
|
+
_add = " huggingface-cli login\n" if framework_str == "flowertune" else ""
|
|
267
269
|
print(
|
|
268
270
|
typer.style(
|
|
269
|
-
f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
|
|
271
|
+
f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
|
|
270
272
|
fg=typer.colors.BRIGHT_CYAN,
|
|
271
273
|
bold=True,
|
|
272
274
|
)
|
flwr/cli/run/run.py
CHANGED
|
@@ -18,13 +18,14 @@ import sys
|
|
|
18
18
|
from enum import Enum
|
|
19
19
|
from logging import DEBUG
|
|
20
20
|
from pathlib import Path
|
|
21
|
-
from typing import Optional
|
|
21
|
+
from typing import Dict, Optional
|
|
22
22
|
|
|
23
23
|
import typer
|
|
24
24
|
from typing_extensions import Annotated
|
|
25
25
|
|
|
26
26
|
from flwr.cli import config_utils
|
|
27
27
|
from flwr.cli.build import build
|
|
28
|
+
from flwr.common.config import parse_config_args
|
|
28
29
|
from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
|
|
29
30
|
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
30
31
|
from flwr.common.logger import log
|
|
@@ -58,15 +59,20 @@ def run(
|
|
|
58
59
|
Optional[Path],
|
|
59
60
|
typer.Option(help="Path of the Flower project to run"),
|
|
60
61
|
] = None,
|
|
62
|
+
config_overrides: Annotated[
|
|
63
|
+
Optional[str],
|
|
64
|
+
typer.Option(
|
|
65
|
+
"--config",
|
|
66
|
+
"-c",
|
|
67
|
+
help="Override configuration key-value pairs",
|
|
68
|
+
),
|
|
69
|
+
] = None,
|
|
61
70
|
) -> None:
|
|
62
71
|
"""Run Flower project."""
|
|
63
|
-
if use_superexec:
|
|
64
|
-
_start_superexec_run(directory)
|
|
65
|
-
return
|
|
66
|
-
|
|
67
72
|
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
|
68
73
|
|
|
69
|
-
|
|
74
|
+
pyproject_path = directory / "pyproject.toml" if directory else None
|
|
75
|
+
config, errors, warnings = config_utils.load_and_validate(path=pyproject_path)
|
|
70
76
|
|
|
71
77
|
if config is None:
|
|
72
78
|
typer.secho(
|
|
@@ -88,6 +94,12 @@ def run(
|
|
|
88
94
|
|
|
89
95
|
typer.secho("Success", fg=typer.colors.GREEN)
|
|
90
96
|
|
|
97
|
+
if use_superexec:
|
|
98
|
+
_start_superexec_run(
|
|
99
|
+
parse_config_args(config_overrides, separator=","), directory
|
|
100
|
+
)
|
|
101
|
+
return
|
|
102
|
+
|
|
91
103
|
server_app_ref = config["flower"]["components"]["serverapp"]
|
|
92
104
|
client_app_ref = config["flower"]["components"]["clientapp"]
|
|
93
105
|
|
|
@@ -115,7 +127,9 @@ def run(
|
|
|
115
127
|
)
|
|
116
128
|
|
|
117
129
|
|
|
118
|
-
def _start_superexec_run(
|
|
130
|
+
def _start_superexec_run(
|
|
131
|
+
override_config: Dict[str, str], directory: Optional[Path]
|
|
132
|
+
) -> None:
|
|
119
133
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
120
134
|
"""Log channel connectivity."""
|
|
121
135
|
log(DEBUG, channel_connectivity)
|
|
@@ -132,6 +146,9 @@ def _start_superexec_run(directory: Optional[Path]) -> None:
|
|
|
132
146
|
|
|
133
147
|
fab_path = build(directory)
|
|
134
148
|
|
|
135
|
-
req = StartRunRequest(
|
|
149
|
+
req = StartRunRequest(
|
|
150
|
+
fab_file=Path(fab_path).read_bytes(),
|
|
151
|
+
override_config=override_config,
|
|
152
|
+
)
|
|
136
153
|
res = stub.StartRun(req)
|
|
137
154
|
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
flwr/client/app.py
CHANGED
|
@@ -19,6 +19,7 @@ import sys
|
|
|
19
19
|
import time
|
|
20
20
|
from dataclasses import dataclass
|
|
21
21
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
22
|
+
from pathlib import Path
|
|
22
23
|
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
|
|
23
24
|
|
|
24
25
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -193,6 +194,7 @@ def _start_client_internal(
|
|
|
193
194
|
max_retries: Optional[int] = None,
|
|
194
195
|
max_wait_time: Optional[float] = None,
|
|
195
196
|
partition_id: Optional[int] = None,
|
|
197
|
+
flwr_dir: Optional[Path] = None,
|
|
196
198
|
) -> None:
|
|
197
199
|
"""Start a Flower client node which connects to a Flower server.
|
|
198
200
|
|
|
@@ -239,6 +241,8 @@ def _start_client_internal(
|
|
|
239
241
|
partition_id: Optional[int] (default: None)
|
|
240
242
|
The data partition index associated with this node. Better suited for
|
|
241
243
|
prototyping purposes.
|
|
244
|
+
flwr_dir: Optional[Path] (default: None)
|
|
245
|
+
The fully resolved path containing installed Flower Apps.
|
|
242
246
|
"""
|
|
243
247
|
if insecure is None:
|
|
244
248
|
insecure = root_certificates is None
|
|
@@ -316,7 +320,7 @@ def _start_client_internal(
|
|
|
316
320
|
)
|
|
317
321
|
|
|
318
322
|
node_state = NodeState(partition_id=partition_id)
|
|
319
|
-
|
|
323
|
+
runs: Dict[int, Run] = {}
|
|
320
324
|
|
|
321
325
|
while not app_state_tracker.interrupt:
|
|
322
326
|
sleep_duration: int = 0
|
|
@@ -366,15 +370,17 @@ def _start_client_internal(
|
|
|
366
370
|
|
|
367
371
|
# Get run info
|
|
368
372
|
run_id = message.metadata.run_id
|
|
369
|
-
if run_id not in
|
|
373
|
+
if run_id not in runs:
|
|
370
374
|
if get_run is not None:
|
|
371
|
-
|
|
375
|
+
runs[run_id] = get_run(run_id)
|
|
372
376
|
# If get_run is None, i.e., in grpc-bidi mode
|
|
373
377
|
else:
|
|
374
|
-
|
|
378
|
+
runs[run_id] = Run(run_id, "", "", {})
|
|
375
379
|
|
|
376
380
|
# Register context for this run
|
|
377
|
-
node_state.register_context(
|
|
381
|
+
node_state.register_context(
|
|
382
|
+
run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir
|
|
383
|
+
)
|
|
378
384
|
|
|
379
385
|
# Retrieve context for this run
|
|
380
386
|
context = node_state.retrieve_context(run_id=run_id)
|
|
@@ -388,7 +394,7 @@ def _start_client_internal(
|
|
|
388
394
|
# Handle app loading and task message
|
|
389
395
|
try:
|
|
390
396
|
# Load ClientApp instance
|
|
391
|
-
run: Run =
|
|
397
|
+
run: Run = runs[run_id]
|
|
392
398
|
client_app: ClientApp = load_client_app_fn(
|
|
393
399
|
run.fab_id, run.fab_version
|
|
394
400
|
)
|
flwr/client/node_state.py
CHANGED
|
@@ -15,9 +15,21 @@
|
|
|
15
15
|
"""Node state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from pathlib import Path
|
|
18
20
|
from typing import Any, Dict, Optional
|
|
19
21
|
|
|
20
22
|
from flwr.common import Context, RecordSet
|
|
23
|
+
from flwr.common.config import get_fused_config
|
|
24
|
+
from flwr.common.typing import Run
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass()
|
|
28
|
+
class RunInfo:
|
|
29
|
+
"""Contains the Context and initial run_config of a Run."""
|
|
30
|
+
|
|
31
|
+
context: Context
|
|
32
|
+
initial_run_config: Dict[str, str]
|
|
21
33
|
|
|
22
34
|
|
|
23
35
|
class NodeState:
|
|
@@ -25,20 +37,31 @@ class NodeState:
|
|
|
25
37
|
|
|
26
38
|
def __init__(self, partition_id: Optional[int]) -> None:
|
|
27
39
|
self._meta: Dict[str, Any] = {} # holds metadata about the node
|
|
28
|
-
self.
|
|
40
|
+
self.run_infos: Dict[int, RunInfo] = {}
|
|
29
41
|
self._partition_id = partition_id
|
|
30
42
|
|
|
31
|
-
def register_context(
|
|
43
|
+
def register_context(
|
|
44
|
+
self,
|
|
45
|
+
run_id: int,
|
|
46
|
+
run: Optional[Run] = None,
|
|
47
|
+
flwr_dir: Optional[Path] = None,
|
|
48
|
+
) -> None:
|
|
32
49
|
"""Register new run context for this node."""
|
|
33
|
-
if run_id not in self.
|
|
34
|
-
|
|
35
|
-
|
|
50
|
+
if run_id not in self.run_infos:
|
|
51
|
+
initial_run_config = get_fused_config(run, flwr_dir) if run else {}
|
|
52
|
+
self.run_infos[run_id] = RunInfo(
|
|
53
|
+
initial_run_config=initial_run_config,
|
|
54
|
+
context=Context(
|
|
55
|
+
state=RecordSet(),
|
|
56
|
+
run_config=initial_run_config.copy(),
|
|
57
|
+
partition_id=self._partition_id,
|
|
58
|
+
),
|
|
36
59
|
)
|
|
37
60
|
|
|
38
61
|
def retrieve_context(self, run_id: int) -> Context:
|
|
39
62
|
"""Get run context given a run_id."""
|
|
40
|
-
if run_id in self.
|
|
41
|
-
return self.
|
|
63
|
+
if run_id in self.run_infos:
|
|
64
|
+
return self.run_infos[run_id].context
|
|
42
65
|
|
|
43
66
|
raise RuntimeError(
|
|
44
67
|
f"Context for run_id={run_id} doesn't exist."
|
|
@@ -48,4 +71,9 @@ class NodeState:
|
|
|
48
71
|
|
|
49
72
|
def update_context(self, run_id: int, context: Context) -> None:
|
|
50
73
|
"""Update run context."""
|
|
51
|
-
self.
|
|
74
|
+
if context.run_config != self.run_infos[run_id].initial_run_config:
|
|
75
|
+
raise ValueError(
|
|
76
|
+
"The `run_config` field of the `Context` object cannot be "
|
|
77
|
+
f"modified (run_id: {run_id})."
|
|
78
|
+
)
|
|
79
|
+
self.run_infos[run_id].context = context
|
flwr/client/node_state_tests.py
CHANGED
|
@@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
|
|
|
59
59
|
node_state.update_context(run_id=run_id, context=updated_state)
|
|
60
60
|
|
|
61
61
|
# Verify values
|
|
62
|
-
for run_id,
|
|
62
|
+
for run_id, run_info in node_state.run_infos.items():
|
|
63
63
|
assert (
|
|
64
|
-
context.state.configs_records["counter"]["count"]
|
|
64
|
+
run_info.context.state.configs_records["counter"]["count"]
|
|
65
|
+
== expected_values[run_id]
|
|
65
66
|
)
|
flwr/client/supernode/app.py
CHANGED
|
@@ -68,6 +68,7 @@ def run_supernode() -> None:
|
|
|
68
68
|
max_retries=args.max_retries,
|
|
69
69
|
max_wait_time=args.max_wait_time,
|
|
70
70
|
partition_id=args.partition_id,
|
|
71
|
+
flwr_dir=get_flwr_dir(args.flwr_dir),
|
|
71
72
|
)
|
|
72
73
|
|
|
73
74
|
# Graceful shutdown
|
|
@@ -178,7 +179,7 @@ def _get_load_client_app_fn(
|
|
|
178
179
|
else:
|
|
179
180
|
flwr_dir = Path(args.flwr_dir).absolute()
|
|
180
181
|
|
|
181
|
-
|
|
182
|
+
inserted_path = None
|
|
182
183
|
|
|
183
184
|
default_app_ref: str = getattr(args, "client-app")
|
|
184
185
|
|
|
@@ -188,6 +189,11 @@ def _get_load_client_app_fn(
|
|
|
188
189
|
"Flower SuperNode will load and validate ClientApp `%s`",
|
|
189
190
|
getattr(args, "client-app"),
|
|
190
191
|
)
|
|
192
|
+
# Insert sys.path
|
|
193
|
+
dir_path = Path(args.dir).absolute()
|
|
194
|
+
sys.path.insert(0, str(dir_path))
|
|
195
|
+
inserted_path = str(dir_path)
|
|
196
|
+
|
|
191
197
|
valid, error_msg = validate(default_app_ref)
|
|
192
198
|
if not valid and error_msg:
|
|
193
199
|
raise LoadClientAppError(error_msg) from None
|
|
@@ -196,7 +202,7 @@ def _get_load_client_app_fn(
|
|
|
196
202
|
# If multi-app feature is disabled
|
|
197
203
|
if not multi_app:
|
|
198
204
|
# Get sys path to be inserted
|
|
199
|
-
|
|
205
|
+
dir_path = Path(args.dir).absolute()
|
|
200
206
|
|
|
201
207
|
# Set app reference
|
|
202
208
|
client_app_ref = default_app_ref
|
|
@@ -209,7 +215,7 @@ def _get_load_client_app_fn(
|
|
|
209
215
|
|
|
210
216
|
log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
|
|
211
217
|
# Get sys path to be inserted
|
|
212
|
-
|
|
218
|
+
dir_path = Path(args.dir).absolute()
|
|
213
219
|
|
|
214
220
|
# Set app reference
|
|
215
221
|
client_app_ref = default_app_ref
|
|
@@ -222,13 +228,21 @@ def _get_load_client_app_fn(
|
|
|
222
228
|
raise LoadClientAppError("Failed to load ClientApp") from e
|
|
223
229
|
|
|
224
230
|
# Get sys path to be inserted
|
|
225
|
-
|
|
231
|
+
dir_path = Path(project_dir).absolute()
|
|
226
232
|
|
|
227
233
|
# Set app reference
|
|
228
234
|
client_app_ref = config["flower"]["components"]["clientapp"]
|
|
229
235
|
|
|
230
236
|
# Set sys.path
|
|
231
|
-
|
|
237
|
+
nonlocal inserted_path
|
|
238
|
+
if inserted_path != str(dir_path):
|
|
239
|
+
# Remove the previously inserted path
|
|
240
|
+
if inserted_path is not None:
|
|
241
|
+
sys.path.remove(inserted_path)
|
|
242
|
+
# Insert the new path
|
|
243
|
+
sys.path.insert(0, str(dir_path))
|
|
244
|
+
|
|
245
|
+
inserted_path = str(dir_path)
|
|
232
246
|
|
|
233
247
|
# Load ClientApp
|
|
234
248
|
log(
|
|
@@ -236,7 +250,7 @@ def _get_load_client_app_fn(
|
|
|
236
250
|
"Loading ClientApp `%s`",
|
|
237
251
|
client_app_ref,
|
|
238
252
|
)
|
|
239
|
-
client_app = load_app(client_app_ref, LoadClientAppError,
|
|
253
|
+
client_app = load_app(client_app_ref, LoadClientAppError, dir_path)
|
|
240
254
|
|
|
241
255
|
if not isinstance(client_app, ClientApp):
|
|
242
256
|
raise LoadClientAppError(
|
flwr/common/logger.py
CHANGED
|
@@ -197,6 +197,31 @@ def warn_deprecated_feature(name: str) -> None:
|
|
|
197
197
|
)
|
|
198
198
|
|
|
199
199
|
|
|
200
|
+
def warn_deprecated_feature_with_example(
|
|
201
|
+
deprecation_message: str, example_message: str, code_example: str
|
|
202
|
+
) -> None:
|
|
203
|
+
"""Warn if a feature is deprecated and show code example."""
|
|
204
|
+
log(
|
|
205
|
+
WARN,
|
|
206
|
+
"""DEPRECATED FEATURE: %s
|
|
207
|
+
|
|
208
|
+
Check the following `FEATURE UPDATE` warning message for the preferred
|
|
209
|
+
new mechanism to use this feature in Flower.
|
|
210
|
+
""",
|
|
211
|
+
deprecation_message,
|
|
212
|
+
)
|
|
213
|
+
log(
|
|
214
|
+
WARN,
|
|
215
|
+
"""FEATURE UPDATE: %s
|
|
216
|
+
------------------------------------------------------------
|
|
217
|
+
%s
|
|
218
|
+
------------------------------------------------------------
|
|
219
|
+
""",
|
|
220
|
+
example_message,
|
|
221
|
+
code_example,
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
200
225
|
def warn_unsupported_feature(name: str) -> None:
|
|
201
226
|
"""Warn the user when they use an unsupported feature."""
|
|
202
227
|
log(
|
flwr/server/__init__.py
CHANGED
|
@@ -28,6 +28,7 @@ from .run_serverapp import run_server_app as run_server_app
|
|
|
28
28
|
from .server import Server as Server
|
|
29
29
|
from .server_app import ServerApp as ServerApp
|
|
30
30
|
from .server_config import ServerConfig as ServerConfig
|
|
31
|
+
from .serverapp_components import ServerAppComponents as ServerAppComponents
|
|
31
32
|
|
|
32
33
|
__all__ = [
|
|
33
34
|
"ClientManager",
|
|
@@ -36,6 +37,7 @@ __all__ = [
|
|
|
36
37
|
"LegacyContext",
|
|
37
38
|
"Server",
|
|
38
39
|
"ServerApp",
|
|
40
|
+
"ServerAppComponents",
|
|
39
41
|
"ServerConfig",
|
|
40
42
|
"SimpleClientManager",
|
|
41
43
|
"run_server_app",
|
flwr/server/server_app.py
CHANGED
|
@@ -17,8 +17,11 @@
|
|
|
17
17
|
|
|
18
18
|
from typing import Callable, Optional
|
|
19
19
|
|
|
20
|
-
from flwr.common import Context
|
|
21
|
-
from flwr.common.logger import
|
|
20
|
+
from flwr.common import Context
|
|
21
|
+
from flwr.common.logger import (
|
|
22
|
+
warn_deprecated_feature_with_example,
|
|
23
|
+
warn_preview_feature,
|
|
24
|
+
)
|
|
22
25
|
from flwr.server.strategy import Strategy
|
|
23
26
|
|
|
24
27
|
from .client_manager import ClientManager
|
|
@@ -26,7 +29,20 @@ from .compat import start_driver
|
|
|
26
29
|
from .driver import Driver
|
|
27
30
|
from .server import Server
|
|
28
31
|
from .server_config import ServerConfig
|
|
29
|
-
from .typing import ServerAppCallable
|
|
32
|
+
from .typing import ServerAppCallable, ServerFn
|
|
33
|
+
|
|
34
|
+
SERVER_FN_USAGE_EXAMPLE = """
|
|
35
|
+
|
|
36
|
+
def server_fn(context: Context):
|
|
37
|
+
server_config = ServerConfig(num_rounds=3)
|
|
38
|
+
strategy = FedAvg()
|
|
39
|
+
return ServerAppComponents(
|
|
40
|
+
strategy=strategy,
|
|
41
|
+
server_config=server_config,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
app = ServerApp(server_fn=server_fn)
|
|
45
|
+
"""
|
|
30
46
|
|
|
31
47
|
|
|
32
48
|
class ServerApp:
|
|
@@ -36,13 +52,15 @@ class ServerApp:
|
|
|
36
52
|
--------
|
|
37
53
|
Use the `ServerApp` with an existing `Strategy`:
|
|
38
54
|
|
|
39
|
-
>>>
|
|
40
|
-
>>>
|
|
55
|
+
>>> def server_fn(context: Context):
|
|
56
|
+
>>> server_config = ServerConfig(num_rounds=3)
|
|
57
|
+
>>> strategy = FedAvg()
|
|
58
|
+
>>> return ServerAppComponents(
|
|
59
|
+
>>> strategy=strategy,
|
|
60
|
+
>>> server_config=server_config,
|
|
61
|
+
>>> )
|
|
41
62
|
>>>
|
|
42
|
-
>>> app = ServerApp(
|
|
43
|
-
>>> server_config=server_config,
|
|
44
|
-
>>> strategy=strategy,
|
|
45
|
-
>>> )
|
|
63
|
+
>>> app = ServerApp(server_fn=server_fn)
|
|
46
64
|
|
|
47
65
|
Use the `ServerApp` with a custom main function:
|
|
48
66
|
|
|
@@ -53,23 +71,52 @@ class ServerApp:
|
|
|
53
71
|
>>> print("ServerApp running")
|
|
54
72
|
"""
|
|
55
73
|
|
|
74
|
+
# pylint: disable=too-many-arguments
|
|
56
75
|
def __init__(
|
|
57
76
|
self,
|
|
58
77
|
server: Optional[Server] = None,
|
|
59
78
|
config: Optional[ServerConfig] = None,
|
|
60
79
|
strategy: Optional[Strategy] = None,
|
|
61
80
|
client_manager: Optional[ClientManager] = None,
|
|
81
|
+
server_fn: Optional[ServerFn] = None,
|
|
62
82
|
) -> None:
|
|
83
|
+
if any([server, config, strategy, client_manager]):
|
|
84
|
+
warn_deprecated_feature_with_example(
|
|
85
|
+
deprecation_message="Passing either `server`, `config`, `strategy` or "
|
|
86
|
+
"`client_manager` directly to the ServerApp "
|
|
87
|
+
"constructor is deprecated.",
|
|
88
|
+
example_message="Pass `ServerApp` arguments wrapped "
|
|
89
|
+
"in a `flwr.server.ServerAppComponents` object that gets "
|
|
90
|
+
"returned by a function passed as the `server_fn` argument "
|
|
91
|
+
"to the `ServerApp` constructor. For example: ",
|
|
92
|
+
code_example=SERVER_FN_USAGE_EXAMPLE,
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
if server_fn:
|
|
96
|
+
raise ValueError(
|
|
97
|
+
"Passing `server_fn` is incompatible with passing the "
|
|
98
|
+
"other arguments (now deprecated) to ServerApp. "
|
|
99
|
+
"Use `server_fn` exclusively."
|
|
100
|
+
)
|
|
101
|
+
|
|
63
102
|
self._server = server
|
|
64
103
|
self._config = config
|
|
65
104
|
self._strategy = strategy
|
|
66
105
|
self._client_manager = client_manager
|
|
106
|
+
self._server_fn = server_fn
|
|
67
107
|
self._main: Optional[ServerAppCallable] = None
|
|
68
108
|
|
|
69
109
|
def __call__(self, driver: Driver, context: Context) -> None:
|
|
70
110
|
"""Execute `ServerApp`."""
|
|
71
111
|
# Compatibility mode
|
|
72
112
|
if not self._main:
|
|
113
|
+
if self._server_fn:
|
|
114
|
+
# Execute server_fn()
|
|
115
|
+
components = self._server_fn(context)
|
|
116
|
+
self._server = components.server
|
|
117
|
+
self._config = components.config
|
|
118
|
+
self._strategy = components.strategy
|
|
119
|
+
self._client_manager = components.client_manager
|
|
73
120
|
start_driver(
|
|
74
121
|
server=self._server,
|
|
75
122
|
config=self._config,
|
|
@@ -80,7 +127,6 @@ class ServerApp:
|
|
|
80
127
|
return
|
|
81
128
|
|
|
82
129
|
# New execution mode
|
|
83
|
-
context = Context(state=RecordSet(), run_config={})
|
|
84
130
|
self._main(driver, context)
|
|
85
131
|
|
|
86
132
|
def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""ServerAppComponents for the ServerApp."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
from .client_manager import ClientManager
|
|
22
|
+
from .server import Server
|
|
23
|
+
from .server_config import ServerConfig
|
|
24
|
+
from .strategy import Strategy
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
|
|
28
|
+
class ServerAppComponents: # pylint: disable=too-many-instance-attributes
|
|
29
|
+
"""Components to construct a ServerApp.
|
|
30
|
+
|
|
31
|
+
Parameters
|
|
32
|
+
----------
|
|
33
|
+
server : Optional[Server] (default: None)
|
|
34
|
+
A server implementation, either `flwr.server.Server` or a subclass
|
|
35
|
+
thereof. If no instance is provided, one will be created internally.
|
|
36
|
+
config : Optional[ServerConfig] (default: None)
|
|
37
|
+
Currently supported values are `num_rounds` (int, default: 1) and
|
|
38
|
+
`round_timeout` in seconds (float, default: None).
|
|
39
|
+
strategy : Optional[Strategy] (default: None)
|
|
40
|
+
An implementation of the abstract base class
|
|
41
|
+
`flwr.server.strategy.Strategy`. If no strategy is provided, then
|
|
42
|
+
`flwr.server.strategy.FedAvg` will be used.
|
|
43
|
+
client_manager : Optional[ClientManager] (default: None)
|
|
44
|
+
An implementation of the class `flwr.server.ClientManager`. If no
|
|
45
|
+
implementation is provided, then `flwr.server.SimpleClientManager`
|
|
46
|
+
will be used.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
server: Optional[Server] = None
|
|
50
|
+
config: Optional[ServerConfig] = None
|
|
51
|
+
strategy: Optional[Strategy] = None
|
|
52
|
+
client_manager: Optional[ClientManager] = None
|
|
@@ -33,8 +33,8 @@ class Backend(ABC):
|
|
|
33
33
|
"""Construct a backend."""
|
|
34
34
|
|
|
35
35
|
@abstractmethod
|
|
36
|
-
|
|
37
|
-
"""Build backend
|
|
36
|
+
def build(self) -> None:
|
|
37
|
+
"""Build backend.
|
|
38
38
|
|
|
39
39
|
Different components need to be in place before workers in a backend are ready
|
|
40
40
|
to accept jobs. When this method finishes executing, the backend should be fully
|
|
@@ -54,11 +54,11 @@ class Backend(ABC):
|
|
|
54
54
|
"""Report whether a backend worker is idle and can therefore run a ClientApp."""
|
|
55
55
|
|
|
56
56
|
@abstractmethod
|
|
57
|
-
|
|
57
|
+
def terminate(self) -> None:
|
|
58
58
|
"""Terminate backend."""
|
|
59
59
|
|
|
60
60
|
@abstractmethod
|
|
61
|
-
|
|
61
|
+
def process_message(
|
|
62
62
|
self,
|
|
63
63
|
app: Callable[[], ClientApp],
|
|
64
64
|
message: Message,
|
|
@@ -153,12 +153,12 @@ class RayBackend(Backend):
|
|
|
153
153
|
"""Report whether the pool has idle actors."""
|
|
154
154
|
return self.pool.is_actor_available()
|
|
155
155
|
|
|
156
|
-
|
|
156
|
+
def build(self) -> None:
|
|
157
157
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
|
158
|
-
|
|
158
|
+
self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
|
159
159
|
log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
|
|
160
160
|
|
|
161
|
-
|
|
161
|
+
def process_message(
|
|
162
162
|
self,
|
|
163
163
|
app: Callable[[], ClientApp],
|
|
164
164
|
message: Message,
|
|
@@ -172,17 +172,16 @@ class RayBackend(Backend):
|
|
|
172
172
|
|
|
173
173
|
try:
|
|
174
174
|
# Submit a task to the pool
|
|
175
|
-
future =
|
|
175
|
+
future = self.pool.submit(
|
|
176
176
|
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
|
177
177
|
(app, message, str(partition_id), context),
|
|
178
178
|
)
|
|
179
179
|
|
|
180
|
-
await future
|
|
181
180
|
# Fetch result
|
|
182
181
|
(
|
|
183
182
|
out_mssg,
|
|
184
183
|
updated_context,
|
|
185
|
-
) =
|
|
184
|
+
) = self.pool.fetch_result_and_return_actor_to_pool(future)
|
|
186
185
|
|
|
187
186
|
return out_mssg, updated_context
|
|
188
187
|
|
|
@@ -193,11 +192,11 @@ class RayBackend(Backend):
|
|
|
193
192
|
self.__class__.__name__,
|
|
194
193
|
)
|
|
195
194
|
# add actor back into pool
|
|
196
|
-
|
|
195
|
+
self.pool.add_actor_back_to_pool(future)
|
|
197
196
|
raise ex
|
|
198
197
|
|
|
199
|
-
|
|
198
|
+
def terminate(self) -> None:
|
|
200
199
|
"""Terminate all actors in actor pool."""
|
|
201
|
-
|
|
200
|
+
self.pool.terminate_all_actors()
|
|
202
201
|
ray.shutdown()
|
|
203
202
|
log(DEBUG, "Terminated %s", self.__class__.__name__)
|