flwr-nightly 1.10.0.dev20240714__py3-none-any.whl → 1.10.0.dev20240716__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +23 -15
- flwr/cli/install.py +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +4 -1
- flwr/cli/new/templates/app/code/server.jax.py.tpl +4 -1
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +4 -1
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +4 -1
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +4 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +4 -1
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +4 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +10 -10
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +12 -6
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +10 -10
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +10 -10
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +10 -10
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +10 -10
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +10 -10
- flwr/cli/run/run.py +110 -57
- flwr/client/app.py +3 -3
- flwr/client/node_state.py +17 -3
- flwr/client/supernode/app.py +26 -15
- flwr/common/config.py +13 -4
- flwr/server/run_serverapp.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +52 -28
- flwr/simulation/run_simulation.py +184 -33
- flwr/superexec/simulation.py +157 -0
- {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/RECORD +32 -31
- {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py
CHANGED
|
@@ -14,50 +14,33 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `run` command."""
|
|
16
16
|
|
|
17
|
+
import subprocess
|
|
17
18
|
import sys
|
|
18
|
-
from enum import Enum
|
|
19
19
|
from logging import DEBUG
|
|
20
20
|
from pathlib import Path
|
|
21
|
-
from typing import Dict, Optional
|
|
21
|
+
from typing import Any, Dict, Optional
|
|
22
22
|
|
|
23
23
|
import typer
|
|
24
24
|
from typing_extensions import Annotated
|
|
25
25
|
|
|
26
|
-
from flwr.cli import config_utils
|
|
27
26
|
from flwr.cli.build import build
|
|
27
|
+
from flwr.cli.config_utils import load_and_validate
|
|
28
28
|
from flwr.common.config import parse_config_args
|
|
29
|
-
from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
|
|
30
29
|
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
31
30
|
from flwr.common.logger import log
|
|
32
31
|
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
|
|
33
32
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
34
|
-
from flwr.simulation.run_simulation import _run_simulation
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
class Engine(str, Enum):
|
|
38
|
-
"""Enum defining the engine to run on."""
|
|
39
|
-
|
|
40
|
-
SIMULATION = "simulation"
|
|
41
33
|
|
|
42
34
|
|
|
43
35
|
# pylint: disable-next=too-many-locals
|
|
44
36
|
def run(
|
|
45
|
-
engine: Annotated[
|
|
46
|
-
Optional[Engine],
|
|
47
|
-
typer.Option(
|
|
48
|
-
case_sensitive=False,
|
|
49
|
-
help="The engine to run FL with (currently only simulation is supported).",
|
|
50
|
-
),
|
|
51
|
-
] = None,
|
|
52
|
-
use_superexec: Annotated[
|
|
53
|
-
bool,
|
|
54
|
-
typer.Option(
|
|
55
|
-
case_sensitive=False, help="Use this flag to use the new SuperExec API"
|
|
56
|
-
),
|
|
57
|
-
] = False,
|
|
58
37
|
directory: Annotated[
|
|
59
|
-
|
|
60
|
-
typer.
|
|
38
|
+
Path,
|
|
39
|
+
typer.Argument(help="Path of the Flower project to run"),
|
|
40
|
+
] = Path("."),
|
|
41
|
+
federation_name: Annotated[
|
|
42
|
+
Optional[str],
|
|
43
|
+
typer.Argument(help="Name of the federation to run the app on"),
|
|
61
44
|
] = None,
|
|
62
45
|
config_overrides: Annotated[
|
|
63
46
|
Optional[str],
|
|
@@ -72,7 +55,7 @@ def run(
|
|
|
72
55
|
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
|
73
56
|
|
|
74
57
|
pyproject_path = directory / "pyproject.toml" if directory else None
|
|
75
|
-
config, errors, warnings =
|
|
58
|
+
config, errors, warnings = load_and_validate(path=pyproject_path)
|
|
76
59
|
|
|
77
60
|
if config is None:
|
|
78
61
|
typer.secho(
|
|
@@ -94,50 +77,81 @@ def run(
|
|
|
94
77
|
|
|
95
78
|
typer.secho("Success", fg=typer.colors.GREEN)
|
|
96
79
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
)
|
|
101
|
-
return
|
|
102
|
-
|
|
103
|
-
server_app_ref = config["flower"]["components"]["serverapp"]
|
|
104
|
-
client_app_ref = config["flower"]["components"]["clientapp"]
|
|
105
|
-
|
|
106
|
-
if engine is None:
|
|
107
|
-
engine = config["flower"]["engine"]["name"]
|
|
80
|
+
federation_name = federation_name or config["tool"]["flwr"]["federations"].get(
|
|
81
|
+
"default"
|
|
82
|
+
)
|
|
108
83
|
|
|
109
|
-
if
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
"
|
|
84
|
+
if federation_name is None:
|
|
85
|
+
typer.secho(
|
|
86
|
+
"❌ No federation name was provided and the project's `pyproject.toml` "
|
|
87
|
+
"doesn't declare a default federation (with a SuperExec address or an "
|
|
88
|
+
"`options.num-supernodes` value).",
|
|
89
|
+
fg=typer.colors.RED,
|
|
90
|
+
bold=True,
|
|
113
91
|
)
|
|
92
|
+
raise typer.Exit(code=1)
|
|
114
93
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
num_supernodes=num_supernodes,
|
|
120
|
-
backend_config=backend_config,
|
|
121
|
-
)
|
|
122
|
-
else:
|
|
94
|
+
# Validate the federation exists in the configuration
|
|
95
|
+
federation = config["tool"]["flwr"]["federations"].get(federation_name)
|
|
96
|
+
if federation is None:
|
|
97
|
+
available_feds = list(config["tool"]["flwr"]["federations"])
|
|
123
98
|
typer.secho(
|
|
124
|
-
f"
|
|
99
|
+
f"❌ There is no `{federation_name}` federation declared in the "
|
|
100
|
+
"`pyproject.toml`.\n The following federations were found:\n\n"
|
|
101
|
+
"\n".join(available_feds) + "\n\n",
|
|
125
102
|
fg=typer.colors.RED,
|
|
126
103
|
bold=True,
|
|
127
104
|
)
|
|
105
|
+
raise typer.Exit(code=1)
|
|
106
|
+
|
|
107
|
+
if "address" in federation:
|
|
108
|
+
_run_with_superexec(federation, directory, config_overrides)
|
|
109
|
+
else:
|
|
110
|
+
_run_without_superexec(directory, federation, federation_name, config_overrides)
|
|
128
111
|
|
|
129
112
|
|
|
130
|
-
def
|
|
131
|
-
|
|
113
|
+
def _run_with_superexec(
|
|
114
|
+
federation: Dict[str, str],
|
|
115
|
+
directory: Optional[Path],
|
|
116
|
+
config_overrides: Optional[str],
|
|
132
117
|
) -> None:
|
|
118
|
+
|
|
133
119
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
134
120
|
"""Log channel connectivity."""
|
|
135
121
|
log(DEBUG, channel_connectivity)
|
|
136
122
|
|
|
123
|
+
insecure_str = federation.get("insecure")
|
|
124
|
+
if root_certificates := federation.get("root-certificates"):
|
|
125
|
+
root_certificates_bytes = Path(root_certificates).read_bytes()
|
|
126
|
+
if insecure := bool(insecure_str):
|
|
127
|
+
typer.secho(
|
|
128
|
+
"❌ `root_certificates` were provided but the `insecure` parameter"
|
|
129
|
+
"is set to `True`.",
|
|
130
|
+
fg=typer.colors.RED,
|
|
131
|
+
bold=True,
|
|
132
|
+
)
|
|
133
|
+
raise typer.Exit(code=1)
|
|
134
|
+
else:
|
|
135
|
+
root_certificates_bytes = None
|
|
136
|
+
if insecure_str is None:
|
|
137
|
+
typer.secho(
|
|
138
|
+
"❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
|
|
139
|
+
fg=typer.colors.RED,
|
|
140
|
+
bold=True,
|
|
141
|
+
)
|
|
142
|
+
raise typer.Exit(code=1)
|
|
143
|
+
if not (insecure := bool(insecure_str)):
|
|
144
|
+
typer.secho(
|
|
145
|
+
"❌ No certificate were given yet `insecure` is set to `False`.",
|
|
146
|
+
fg=typer.colors.RED,
|
|
147
|
+
bold=True,
|
|
148
|
+
)
|
|
149
|
+
raise typer.Exit(code=1)
|
|
150
|
+
|
|
137
151
|
channel = create_channel(
|
|
138
|
-
server_address=
|
|
139
|
-
insecure=
|
|
140
|
-
root_certificates=
|
|
152
|
+
server_address=federation["address"],
|
|
153
|
+
insecure=insecure,
|
|
154
|
+
root_certificates=root_certificates_bytes,
|
|
141
155
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
142
156
|
interceptors=None,
|
|
143
157
|
)
|
|
@@ -148,7 +162,46 @@ def _start_superexec_run(
|
|
|
148
162
|
|
|
149
163
|
req = StartRunRequest(
|
|
150
164
|
fab_file=Path(fab_path).read_bytes(),
|
|
151
|
-
override_config=
|
|
165
|
+
override_config=parse_config_args(config_overrides, separator=","),
|
|
152
166
|
)
|
|
153
167
|
res = stub.StartRun(req)
|
|
154
168
|
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _run_without_superexec(
|
|
172
|
+
app_path: Optional[Path],
|
|
173
|
+
federation: Dict[str, Any],
|
|
174
|
+
federation_name: str,
|
|
175
|
+
config_overrides: Optional[str],
|
|
176
|
+
) -> None:
|
|
177
|
+
try:
|
|
178
|
+
num_supernodes = federation["options"]["num-supernodes"]
|
|
179
|
+
except KeyError as err:
|
|
180
|
+
typer.secho(
|
|
181
|
+
"❌ The project's `pyproject.toml` needs to declare the number of"
|
|
182
|
+
" SuperNodes in the simulation. To simulate 10 SuperNodes,"
|
|
183
|
+
" use the following notation:\n\n"
|
|
184
|
+
f"[tool.flwr.federations.{federation_name}]\n"
|
|
185
|
+
"options.num-supernodes = 10\n",
|
|
186
|
+
fg=typer.colors.RED,
|
|
187
|
+
bold=True,
|
|
188
|
+
)
|
|
189
|
+
raise typer.Exit(code=1) from err
|
|
190
|
+
|
|
191
|
+
command = [
|
|
192
|
+
"flower-simulation",
|
|
193
|
+
"--app",
|
|
194
|
+
f"{app_path}",
|
|
195
|
+
"--num-supernodes",
|
|
196
|
+
f"{num_supernodes}",
|
|
197
|
+
]
|
|
198
|
+
|
|
199
|
+
if config_overrides:
|
|
200
|
+
command.extend(["--run-config", f"{config_overrides}"])
|
|
201
|
+
|
|
202
|
+
# Run the simulation
|
|
203
|
+
subprocess.run(
|
|
204
|
+
command,
|
|
205
|
+
check=True,
|
|
206
|
+
text=True,
|
|
207
|
+
)
|
flwr/client/app.py
CHANGED
|
@@ -195,7 +195,7 @@ def _start_client_internal(
|
|
|
195
195
|
] = None,
|
|
196
196
|
max_retries: Optional[int] = None,
|
|
197
197
|
max_wait_time: Optional[float] = None,
|
|
198
|
-
|
|
198
|
+
flwr_path: Optional[Path] = None,
|
|
199
199
|
) -> None:
|
|
200
200
|
"""Start a Flower client node which connects to a Flower server.
|
|
201
201
|
|
|
@@ -241,7 +241,7 @@ def _start_client_internal(
|
|
|
241
241
|
The maximum duration before the client stops trying to
|
|
242
242
|
connect to the server in case of connection error.
|
|
243
243
|
If set to None, there is no limit to the total time.
|
|
244
|
-
|
|
244
|
+
flwr_path: Optional[Path] (default: None)
|
|
245
245
|
The fully resolved path containing installed Flower Apps.
|
|
246
246
|
"""
|
|
247
247
|
if insecure is None:
|
|
@@ -402,7 +402,7 @@ def _start_client_internal(
|
|
|
402
402
|
|
|
403
403
|
# Register context for this run
|
|
404
404
|
node_state.register_context(
|
|
405
|
-
run_id=run_id, run=runs[run_id],
|
|
405
|
+
run_id=run_id, run=runs[run_id], flwr_path=flwr_path
|
|
406
406
|
)
|
|
407
407
|
|
|
408
408
|
# Retrieve context for this run
|
flwr/client/node_state.py
CHANGED
|
@@ -20,7 +20,7 @@ from pathlib import Path
|
|
|
20
20
|
from typing import Dict, Optional
|
|
21
21
|
|
|
22
22
|
from flwr.common import Context, RecordSet
|
|
23
|
-
from flwr.common.config import get_fused_config
|
|
23
|
+
from flwr.common.config import get_fused_config, get_fused_config_from_dir
|
|
24
24
|
from flwr.common.typing import Run
|
|
25
25
|
|
|
26
26
|
|
|
@@ -48,11 +48,25 @@ class NodeState:
|
|
|
48
48
|
self,
|
|
49
49
|
run_id: int,
|
|
50
50
|
run: Optional[Run] = None,
|
|
51
|
-
|
|
51
|
+
flwr_path: Optional[Path] = None,
|
|
52
|
+
app_dir: Optional[str] = None,
|
|
52
53
|
) -> None:
|
|
53
54
|
"""Register new run context for this node."""
|
|
54
55
|
if run_id not in self.run_infos:
|
|
55
|
-
initial_run_config =
|
|
56
|
+
initial_run_config = {}
|
|
57
|
+
if app_dir:
|
|
58
|
+
# Load from app directory
|
|
59
|
+
app_path = Path(app_dir)
|
|
60
|
+
if app_path.is_dir():
|
|
61
|
+
override_config = run.override_config if run else {}
|
|
62
|
+
initial_run_config = get_fused_config_from_dir(
|
|
63
|
+
app_path, override_config
|
|
64
|
+
)
|
|
65
|
+
else:
|
|
66
|
+
raise ValueError("The specified `app_dir` must be a directory.")
|
|
67
|
+
else:
|
|
68
|
+
# Load from .fab
|
|
69
|
+
initial_run_config = get_fused_config(run, flwr_path) if run else {}
|
|
56
70
|
self.run_infos[run_id] = RunInfo(
|
|
57
71
|
initial_run_config=initial_run_config,
|
|
58
72
|
context=Context(
|
flwr/client/supernode/app.py
CHANGED
|
@@ -60,7 +60,12 @@ def run_supernode() -> None:
|
|
|
60
60
|
_warn_deprecated_server_arg(args)
|
|
61
61
|
|
|
62
62
|
root_certificates = _get_certificates(args)
|
|
63
|
-
load_fn = _get_load_client_app_fn(
|
|
63
|
+
load_fn = _get_load_client_app_fn(
|
|
64
|
+
default_app_ref=getattr(args, "client-app"),
|
|
65
|
+
dir_arg=args.dir,
|
|
66
|
+
flwr_dir_arg=args.flwr_dir,
|
|
67
|
+
multi_app=True,
|
|
68
|
+
)
|
|
64
69
|
authentication_keys = _try_setup_client_authentication(args)
|
|
65
70
|
|
|
66
71
|
_start_client_internal(
|
|
@@ -73,7 +78,7 @@ def run_supernode() -> None:
|
|
|
73
78
|
max_retries=args.max_retries,
|
|
74
79
|
max_wait_time=args.max_wait_time,
|
|
75
80
|
node_config=parse_config_args(args.node_config),
|
|
76
|
-
|
|
81
|
+
flwr_path=get_flwr_dir(args.flwr_dir),
|
|
77
82
|
)
|
|
78
83
|
|
|
79
84
|
# Graceful shutdown
|
|
@@ -93,7 +98,11 @@ def run_client_app() -> None:
|
|
|
93
98
|
_warn_deprecated_server_arg(args)
|
|
94
99
|
|
|
95
100
|
root_certificates = _get_certificates(args)
|
|
96
|
-
load_fn = _get_load_client_app_fn(
|
|
101
|
+
load_fn = _get_load_client_app_fn(
|
|
102
|
+
default_app_ref=getattr(args, "client-app"),
|
|
103
|
+
dir_arg=args.dir,
|
|
104
|
+
multi_app=False,
|
|
105
|
+
)
|
|
97
106
|
authentication_keys = _try_setup_client_authentication(args)
|
|
98
107
|
|
|
99
108
|
_start_client_internal(
|
|
@@ -166,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
166
175
|
|
|
167
176
|
|
|
168
177
|
def _get_load_client_app_fn(
|
|
169
|
-
|
|
178
|
+
default_app_ref: str,
|
|
179
|
+
dir_arg: str,
|
|
180
|
+
multi_app: bool,
|
|
181
|
+
flwr_dir_arg: Optional[str] = None,
|
|
170
182
|
) -> Callable[[str, str], ClientApp]:
|
|
171
183
|
"""Get the load_client_app_fn function.
|
|
172
184
|
|
|
@@ -178,25 +190,24 @@ def _get_load_client_app_fn(
|
|
|
178
190
|
loads a default ClientApp.
|
|
179
191
|
"""
|
|
180
192
|
# Find the Flower directory containing Flower Apps (only for multi-app)
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
193
|
+
if not multi_app:
|
|
194
|
+
flwr_dir = Path("")
|
|
195
|
+
else:
|
|
196
|
+
if flwr_dir_arg is None:
|
|
184
197
|
flwr_dir = get_flwr_dir()
|
|
185
198
|
else:
|
|
186
|
-
flwr_dir = Path(
|
|
199
|
+
flwr_dir = Path(flwr_dir_arg).absolute()
|
|
187
200
|
|
|
188
201
|
inserted_path = None
|
|
189
202
|
|
|
190
|
-
default_app_ref: str = getattr(args, "client-app")
|
|
191
|
-
|
|
192
203
|
if not multi_app:
|
|
193
204
|
log(
|
|
194
205
|
DEBUG,
|
|
195
206
|
"Flower SuperNode will load and validate ClientApp `%s`",
|
|
196
|
-
|
|
207
|
+
default_app_ref,
|
|
197
208
|
)
|
|
198
209
|
# Insert sys.path
|
|
199
|
-
dir_path = Path(
|
|
210
|
+
dir_path = Path(dir_arg).absolute()
|
|
200
211
|
sys.path.insert(0, str(dir_path))
|
|
201
212
|
inserted_path = str(dir_path)
|
|
202
213
|
|
|
@@ -208,7 +219,7 @@ def _get_load_client_app_fn(
|
|
|
208
219
|
# If multi-app feature is disabled
|
|
209
220
|
if not multi_app:
|
|
210
221
|
# Get sys path to be inserted
|
|
211
|
-
dir_path = Path(
|
|
222
|
+
dir_path = Path(dir_arg).absolute()
|
|
212
223
|
|
|
213
224
|
# Set app reference
|
|
214
225
|
client_app_ref = default_app_ref
|
|
@@ -221,7 +232,7 @@ def _get_load_client_app_fn(
|
|
|
221
232
|
|
|
222
233
|
log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
|
|
223
234
|
# Get sys path to be inserted
|
|
224
|
-
dir_path = Path(
|
|
235
|
+
dir_path = Path(dir_arg).absolute()
|
|
225
236
|
|
|
226
237
|
# Set app reference
|
|
227
238
|
client_app_ref = default_app_ref
|
|
@@ -237,7 +248,7 @@ def _get_load_client_app_fn(
|
|
|
237
248
|
dir_path = Path(project_dir).absolute()
|
|
238
249
|
|
|
239
250
|
# Set app reference
|
|
240
|
-
client_app_ref = config["
|
|
251
|
+
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
241
252
|
|
|
242
253
|
# Set sys.path
|
|
243
254
|
nonlocal inserted_path
|
flwr/common/config.py
CHANGED
|
@@ -86,6 +86,18 @@ def _fuse_dicts(
|
|
|
86
86
|
return fused_dict
|
|
87
87
|
|
|
88
88
|
|
|
89
|
+
def get_fused_config_from_dir(
|
|
90
|
+
project_dir: Path, override_config: Dict[str, str]
|
|
91
|
+
) -> Dict[str, str]:
|
|
92
|
+
"""Merge the overrides from a given dict with the config from a Flower App."""
|
|
93
|
+
default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
|
|
94
|
+
"config", {}
|
|
95
|
+
)
|
|
96
|
+
flat_default_config = flatten_dict(default_config)
|
|
97
|
+
|
|
98
|
+
return _fuse_dicts(flat_default_config, override_config)
|
|
99
|
+
|
|
100
|
+
|
|
89
101
|
def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]:
|
|
90
102
|
"""Merge the overrides from a `Run` with the config from a FAB.
|
|
91
103
|
|
|
@@ -97,10 +109,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]:
|
|
|
97
109
|
|
|
98
110
|
project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
|
|
99
111
|
|
|
100
|
-
|
|
101
|
-
flat_default_config = flatten_dict(default_config)
|
|
102
|
-
|
|
103
|
-
return _fuse_dicts(flat_default_config, run.override_config)
|
|
112
|
+
return get_fused_config_from_dir(project_dir, run.override_config)
|
|
104
113
|
|
|
105
114
|
|
|
106
115
|
def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]:
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -186,7 +186,7 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
186
186
|
run_ = driver.run
|
|
187
187
|
server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
|
|
188
188
|
config = get_project_config(server_app_dir)
|
|
189
|
-
server_app_attr = config["
|
|
189
|
+
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
190
190
|
server_app_run_config = get_fused_config(run_, flwr_dir)
|
|
191
191
|
else:
|
|
192
192
|
# User provided `server-app`, but not `--run-id`
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import json
|
|
19
|
-
import sys
|
|
20
19
|
import threading
|
|
21
20
|
import time
|
|
22
21
|
import traceback
|
|
@@ -29,6 +28,7 @@ from typing import Callable, Dict, Optional
|
|
|
29
28
|
|
|
30
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
31
30
|
from flwr.client.node_state import NodeState
|
|
31
|
+
from flwr.client.supernode.app import _get_load_client_app_fn
|
|
32
32
|
from flwr.common.constant import (
|
|
33
33
|
NUM_PARTITIONS_KEY,
|
|
34
34
|
PARTITION_ID_KEY,
|
|
@@ -37,8 +37,8 @@ from flwr.common.constant import (
|
|
|
37
37
|
)
|
|
38
38
|
from flwr.common.logger import log
|
|
39
39
|
from flwr.common.message import Error
|
|
40
|
-
from flwr.common.object_ref import load_app
|
|
41
40
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
41
|
+
from flwr.common.typing import Run
|
|
42
42
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
43
43
|
from flwr.server.superlink.state import State, StateFactory
|
|
44
44
|
|
|
@@ -60,6 +60,31 @@ def _register_nodes(
|
|
|
60
60
|
return nodes_mapping
|
|
61
61
|
|
|
62
62
|
|
|
63
|
+
def _register_node_states(
|
|
64
|
+
nodes_mapping: NodeToPartitionMapping,
|
|
65
|
+
run: Run,
|
|
66
|
+
app_dir: Optional[str] = None,
|
|
67
|
+
) -> Dict[int, NodeState]:
|
|
68
|
+
"""Create NodeState objects and pre-register the context for the run."""
|
|
69
|
+
node_states: Dict[int, NodeState] = {}
|
|
70
|
+
num_partitions = len(set(nodes_mapping.values()))
|
|
71
|
+
for node_id, partition_id in nodes_mapping.items():
|
|
72
|
+
node_states[node_id] = NodeState(
|
|
73
|
+
node_id=node_id,
|
|
74
|
+
node_config={
|
|
75
|
+
PARTITION_ID_KEY: str(partition_id),
|
|
76
|
+
NUM_PARTITIONS_KEY: str(num_partitions),
|
|
77
|
+
},
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
# Pre-register Context objects
|
|
81
|
+
node_states[node_id].register_context(
|
|
82
|
+
run_id=run.run_id, run=run, app_dir=app_dir
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
return node_states
|
|
86
|
+
|
|
87
|
+
|
|
63
88
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
64
89
|
def worker(
|
|
65
90
|
app_fn: Callable[[], ClientApp],
|
|
@@ -78,8 +103,7 @@ def worker(
|
|
|
78
103
|
task_ins: TaskIns = taskins_queue.get(timeout=1.0)
|
|
79
104
|
node_id = task_ins.task.consumer.node_id
|
|
80
105
|
|
|
81
|
-
#
|
|
82
|
-
node_states[node_id].register_context(run_id=task_ins.run_id)
|
|
106
|
+
# Retrieve context
|
|
83
107
|
context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
|
|
84
108
|
|
|
85
109
|
# Convert TaskIns to Message
|
|
@@ -151,7 +175,7 @@ def put_taskres_into_state(
|
|
|
151
175
|
pass
|
|
152
176
|
|
|
153
177
|
|
|
154
|
-
def
|
|
178
|
+
def run_api(
|
|
155
179
|
app_fn: Callable[[], ClientApp],
|
|
156
180
|
backend_fn: Callable[[], Backend],
|
|
157
181
|
nodes_mapping: NodeToPartitionMapping,
|
|
@@ -236,7 +260,10 @@ def start_vce(
|
|
|
236
260
|
backend_name: str,
|
|
237
261
|
backend_config_json_stream: str,
|
|
238
262
|
app_dir: str,
|
|
263
|
+
is_app: bool,
|
|
239
264
|
f_stop: threading.Event,
|
|
265
|
+
run: Run,
|
|
266
|
+
flwr_dir: Optional[str] = None,
|
|
240
267
|
client_app: Optional[ClientApp] = None,
|
|
241
268
|
client_app_attr: Optional[str] = None,
|
|
242
269
|
num_supernodes: Optional[int] = None,
|
|
@@ -287,17 +314,9 @@ def start_vce(
|
|
|
287
314
|
)
|
|
288
315
|
|
|
289
316
|
# Construct mapping of NodeStates
|
|
290
|
-
node_states
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
for node_id, partition_id in nodes_mapping.items():
|
|
294
|
-
node_states[node_id] = NodeState(
|
|
295
|
-
node_id=node_id,
|
|
296
|
-
node_config={
|
|
297
|
-
PARTITION_ID_KEY: str(partition_id),
|
|
298
|
-
NUM_PARTITIONS_KEY: str(num_partitions),
|
|
299
|
-
},
|
|
300
|
-
)
|
|
317
|
+
node_states = _register_node_states(
|
|
318
|
+
nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
|
|
319
|
+
)
|
|
301
320
|
|
|
302
321
|
# Load backend config
|
|
303
322
|
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
|
|
@@ -326,16 +345,12 @@ def start_vce(
|
|
|
326
345
|
def _load() -> ClientApp:
|
|
327
346
|
|
|
328
347
|
if client_app_attr:
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
if not isinstance(app, ClientApp):
|
|
336
|
-
raise LoadClientAppError(
|
|
337
|
-
f"Attribute {client_app_attr} is not of type {ClientApp}",
|
|
338
|
-
) from None
|
|
348
|
+
app = _get_load_client_app_fn(
|
|
349
|
+
default_app_ref=client_app_attr,
|
|
350
|
+
dir_arg=app_dir,
|
|
351
|
+
flwr_dir_arg=flwr_dir,
|
|
352
|
+
multi_app=True,
|
|
353
|
+
)(run.fab_id, run.fab_version)
|
|
339
354
|
|
|
340
355
|
if client_app:
|
|
341
356
|
app = client_app
|
|
@@ -345,10 +360,19 @@ def start_vce(
|
|
|
345
360
|
|
|
346
361
|
try:
|
|
347
362
|
# Test if ClientApp can be loaded
|
|
348
|
-
|
|
363
|
+
client_app = app_fn()
|
|
364
|
+
|
|
365
|
+
# Cache `ClientApp`
|
|
366
|
+
if client_app_attr:
|
|
367
|
+
# Now wrap the loaded ClientApp in a dummy function
|
|
368
|
+
# this prevent unnecesary low-level loading of ClientApp
|
|
369
|
+
def _load_client_app() -> ClientApp:
|
|
370
|
+
return client_app
|
|
371
|
+
|
|
372
|
+
app_fn = _load_client_app
|
|
349
373
|
|
|
350
374
|
# Run main simulation loop
|
|
351
|
-
|
|
375
|
+
run_api(
|
|
352
376
|
app_fn,
|
|
353
377
|
backend_fn,
|
|
354
378
|
nodes_mapping,
|