flwr-nightly 1.10.0.dev20240624__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +18 -4
- flwr/cli/config_utils.py +36 -14
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +31 -20
- flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
- flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +135 -51
- flwr/client/__init__.py +2 -0
- flwr/client/app.py +63 -26
- flwr/client/client_app.py +49 -4
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +3 -4
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +60 -21
- flwr/client/typing.py +1 -0
- flwr/common/config.py +87 -2
- flwr/common/constant.py +6 -0
- flwr/common/context.py +26 -1
- flwr/common/logger.py +38 -0
- flwr/common/message.py +0 -17
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +16 -11
- flwr/proto/exec_pb2.pyi +22 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -15
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -122
- flwr/server/superlink/state/in_memory_state.py +15 -7
- flwr/server/superlink/state/sqlite_state.py +27 -12
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/superlink/state/utils.py +6 -0
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/app.py +52 -36
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +33 -13
- flwr/simulation/run_simulation.py +237 -66
- flwr/superexec/app.py +14 -7
- flwr/superexec/deployment.py +186 -0
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +18 -0
- flwr/superexec/simulation.py +151 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +95 -88
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py
CHANGED
|
@@ -14,53 +14,49 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `run` command."""
|
|
16
16
|
|
|
17
|
+
import subprocess
|
|
17
18
|
import sys
|
|
18
|
-
from enum import Enum
|
|
19
19
|
from logging import DEBUG
|
|
20
|
-
from
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any, Dict, List, Optional
|
|
21
22
|
|
|
22
23
|
import typer
|
|
23
24
|
from typing_extensions import Annotated
|
|
24
25
|
|
|
25
|
-
from flwr.cli import
|
|
26
|
-
from flwr.
|
|
26
|
+
from flwr.cli.build import build
|
|
27
|
+
from flwr.cli.config_utils import load_and_validate
|
|
28
|
+
from flwr.common.config import parse_config_args
|
|
27
29
|
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
28
30
|
from flwr.common.logger import log
|
|
31
|
+
from flwr.common.serde import user_config_to_proto
|
|
29
32
|
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
|
|
30
33
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
31
|
-
from flwr.simulation.run_simulation import _run_simulation
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
class Engine(str, Enum):
|
|
35
|
-
"""Enum defining the engine to run on."""
|
|
36
|
-
|
|
37
|
-
SIMULATION = "simulation"
|
|
38
34
|
|
|
39
35
|
|
|
40
36
|
# pylint: disable-next=too-many-locals
|
|
41
37
|
def run(
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
typer.
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
38
|
+
directory: Annotated[
|
|
39
|
+
Path,
|
|
40
|
+
typer.Argument(help="Path of the Flower project to run"),
|
|
41
|
+
] = Path("."),
|
|
42
|
+
federation_name: Annotated[
|
|
43
|
+
Optional[str],
|
|
44
|
+
typer.Argument(help="Name of the federation to run the app on"),
|
|
48
45
|
] = None,
|
|
49
|
-
|
|
50
|
-
|
|
46
|
+
config_overrides: Annotated[
|
|
47
|
+
Optional[List[str]],
|
|
51
48
|
typer.Option(
|
|
52
|
-
|
|
49
|
+
"--run-config",
|
|
50
|
+
"-c",
|
|
51
|
+
help="Override configuration key-value pairs",
|
|
53
52
|
),
|
|
54
|
-
] =
|
|
53
|
+
] = None,
|
|
55
54
|
) -> None:
|
|
56
55
|
"""Run Flower project."""
|
|
57
|
-
if use_superexec:
|
|
58
|
-
_start_superexec_run()
|
|
59
|
-
return
|
|
60
|
-
|
|
61
56
|
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
|
62
57
|
|
|
63
|
-
|
|
58
|
+
pyproject_path = directory / "pyproject.toml" if directory else None
|
|
59
|
+
config, errors, warnings = load_and_validate(path=pyproject_path)
|
|
64
60
|
|
|
65
61
|
if config is None:
|
|
66
62
|
typer.secho(
|
|
@@ -82,47 +78,135 @@ def run(
|
|
|
82
78
|
|
|
83
79
|
typer.secho("Success", fg=typer.colors.GREEN)
|
|
84
80
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
if engine is None:
|
|
89
|
-
engine = config["flower"]["engine"]["name"]
|
|
90
|
-
|
|
91
|
-
if engine == Engine.SIMULATION:
|
|
92
|
-
num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
|
|
93
|
-
backend_config = config["flower"]["engine"]["simulation"].get(
|
|
94
|
-
"backend_config", None
|
|
95
|
-
)
|
|
81
|
+
federation_name = federation_name or config["tool"]["flwr"]["federations"].get(
|
|
82
|
+
"default"
|
|
83
|
+
)
|
|
96
84
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
85
|
+
if federation_name is None:
|
|
86
|
+
typer.secho(
|
|
87
|
+
"❌ No federation name was provided and the project's `pyproject.toml` "
|
|
88
|
+
"doesn't declare a default federation (with a SuperExec address or an "
|
|
89
|
+
"`options.num-supernodes` value).",
|
|
90
|
+
fg=typer.colors.RED,
|
|
91
|
+
bold=True,
|
|
103
92
|
)
|
|
104
|
-
|
|
93
|
+
raise typer.Exit(code=1)
|
|
94
|
+
|
|
95
|
+
# Validate the federation exists in the configuration
|
|
96
|
+
federation = config["tool"]["flwr"]["federations"].get(federation_name)
|
|
97
|
+
if federation is None:
|
|
98
|
+
available_feds = {
|
|
99
|
+
fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
|
|
100
|
+
}
|
|
105
101
|
typer.secho(
|
|
106
|
-
f"
|
|
102
|
+
f"❌ There is no `{federation_name}` federation declared in the "
|
|
103
|
+
"`pyproject.toml`.\n The following federations were found:\n\n"
|
|
104
|
+
+ "\n".join(available_feds),
|
|
107
105
|
fg=typer.colors.RED,
|
|
108
106
|
bold=True,
|
|
109
107
|
)
|
|
108
|
+
raise typer.Exit(code=1)
|
|
109
|
+
|
|
110
|
+
if "address" in federation:
|
|
111
|
+
_run_with_superexec(federation, directory, config_overrides)
|
|
112
|
+
else:
|
|
113
|
+
_run_without_superexec(directory, federation, federation_name, config_overrides)
|
|
114
|
+
|
|
110
115
|
|
|
116
|
+
def _run_with_superexec(
|
|
117
|
+
federation: Dict[str, str],
|
|
118
|
+
directory: Optional[Path],
|
|
119
|
+
config_overrides: Optional[List[str]],
|
|
120
|
+
) -> None:
|
|
111
121
|
|
|
112
|
-
def _start_superexec_run() -> None:
|
|
113
122
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
114
123
|
"""Log channel connectivity."""
|
|
115
124
|
log(DEBUG, channel_connectivity)
|
|
116
125
|
|
|
126
|
+
insecure_str = federation.get("insecure")
|
|
127
|
+
if root_certificates := federation.get("root-certificates"):
|
|
128
|
+
root_certificates_bytes = Path(root_certificates).read_bytes()
|
|
129
|
+
if insecure := bool(insecure_str):
|
|
130
|
+
typer.secho(
|
|
131
|
+
"❌ `root_certificates` were provided but the `insecure` parameter"
|
|
132
|
+
"is set to `True`.",
|
|
133
|
+
fg=typer.colors.RED,
|
|
134
|
+
bold=True,
|
|
135
|
+
)
|
|
136
|
+
raise typer.Exit(code=1)
|
|
137
|
+
else:
|
|
138
|
+
root_certificates_bytes = None
|
|
139
|
+
if insecure_str is None:
|
|
140
|
+
typer.secho(
|
|
141
|
+
"❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
|
|
142
|
+
fg=typer.colors.RED,
|
|
143
|
+
bold=True,
|
|
144
|
+
)
|
|
145
|
+
raise typer.Exit(code=1)
|
|
146
|
+
if not (insecure := bool(insecure_str)):
|
|
147
|
+
typer.secho(
|
|
148
|
+
"❌ No certificate were given yet `insecure` is set to `False`.",
|
|
149
|
+
fg=typer.colors.RED,
|
|
150
|
+
bold=True,
|
|
151
|
+
)
|
|
152
|
+
raise typer.Exit(code=1)
|
|
153
|
+
|
|
117
154
|
channel = create_channel(
|
|
118
|
-
server_address=
|
|
119
|
-
insecure=
|
|
120
|
-
root_certificates=
|
|
155
|
+
server_address=federation["address"],
|
|
156
|
+
insecure=insecure,
|
|
157
|
+
root_certificates=root_certificates_bytes,
|
|
121
158
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
122
159
|
interceptors=None,
|
|
123
160
|
)
|
|
124
161
|
channel.subscribe(on_channel_state_change)
|
|
125
162
|
stub = ExecStub(channel)
|
|
126
163
|
|
|
127
|
-
|
|
128
|
-
|
|
164
|
+
fab_path = build(directory)
|
|
165
|
+
|
|
166
|
+
req = StartRunRequest(
|
|
167
|
+
fab_file=Path(fab_path).read_bytes(),
|
|
168
|
+
override_config=user_config_to_proto(
|
|
169
|
+
parse_config_args(config_overrides, separator=",")
|
|
170
|
+
),
|
|
171
|
+
)
|
|
172
|
+
res = stub.StartRun(req)
|
|
173
|
+
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _run_without_superexec(
|
|
177
|
+
app_path: Optional[Path],
|
|
178
|
+
federation: Dict[str, Any],
|
|
179
|
+
federation_name: str,
|
|
180
|
+
config_overrides: Optional[List[str]],
|
|
181
|
+
) -> None:
|
|
182
|
+
try:
|
|
183
|
+
num_supernodes = federation["options"]["num-supernodes"]
|
|
184
|
+
except KeyError as err:
|
|
185
|
+
typer.secho(
|
|
186
|
+
"❌ The project's `pyproject.toml` needs to declare the number of"
|
|
187
|
+
" SuperNodes in the simulation. To simulate 10 SuperNodes,"
|
|
188
|
+
" use the following notation:\n\n"
|
|
189
|
+
f"[tool.flwr.federations.{federation_name}]\n"
|
|
190
|
+
"options.num-supernodes = 10\n",
|
|
191
|
+
fg=typer.colors.RED,
|
|
192
|
+
bold=True,
|
|
193
|
+
)
|
|
194
|
+
raise typer.Exit(code=1) from err
|
|
195
|
+
|
|
196
|
+
command = [
|
|
197
|
+
"flower-simulation",
|
|
198
|
+
"--app",
|
|
199
|
+
f"{app_path}",
|
|
200
|
+
"--num-supernodes",
|
|
201
|
+
f"{num_supernodes}",
|
|
202
|
+
]
|
|
203
|
+
|
|
204
|
+
if config_overrides:
|
|
205
|
+
command.extend(["--run-config", f"{','.join(config_overrides)}"])
|
|
206
|
+
|
|
207
|
+
# Run the simulation
|
|
208
|
+
subprocess.run(
|
|
209
|
+
command,
|
|
210
|
+
check=True,
|
|
211
|
+
text=True,
|
|
212
|
+
)
|
flwr/client/__init__.py
CHANGED
|
@@ -23,11 +23,13 @@ from .numpy_client import NumPyClient as NumPyClient
|
|
|
23
23
|
from .supernode import run_client_app as run_client_app
|
|
24
24
|
from .supernode import run_supernode as run_supernode
|
|
25
25
|
from .typing import ClientFn as ClientFn
|
|
26
|
+
from .typing import ClientFnExt as ClientFnExt
|
|
26
27
|
|
|
27
28
|
__all__ = [
|
|
28
29
|
"Client",
|
|
29
30
|
"ClientApp",
|
|
30
31
|
"ClientFn",
|
|
32
|
+
"ClientFnExt",
|
|
31
33
|
"NumPyClient",
|
|
32
34
|
"mod",
|
|
33
35
|
"run_client_app",
|
flwr/client/app.py
CHANGED
|
@@ -18,7 +18,8 @@ import signal
|
|
|
18
18
|
import sys
|
|
19
19
|
import time
|
|
20
20
|
from dataclasses import dataclass
|
|
21
|
-
from logging import
|
|
21
|
+
from logging import ERROR, INFO, WARN
|
|
22
|
+
from pathlib import Path
|
|
22
23
|
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
|
|
23
24
|
|
|
24
25
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -26,8 +27,8 @@ from grpc import RpcError
|
|
|
26
27
|
|
|
27
28
|
from flwr.client.client import Client
|
|
28
29
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
29
|
-
from flwr.client.typing import
|
|
30
|
-
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
|
|
30
|
+
from flwr.client.typing import ClientFnExt
|
|
31
|
+
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
|
|
31
32
|
from flwr.common.address import parse_address
|
|
32
33
|
from flwr.common.constant import (
|
|
33
34
|
MISSING_EXTRA_REST,
|
|
@@ -41,6 +42,7 @@ from flwr.common.constant import (
|
|
|
41
42
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
42
43
|
from flwr.common.message import Error
|
|
43
44
|
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
45
|
+
from flwr.common.typing import Run, UserConfig
|
|
44
46
|
|
|
45
47
|
from .grpc_adapter_client.connection import grpc_adapter
|
|
46
48
|
from .grpc_client.connection import grpc_connection
|
|
@@ -51,7 +53,7 @@ from .numpy_client import NumPyClient
|
|
|
51
53
|
|
|
52
54
|
|
|
53
55
|
def _check_actionable_client(
|
|
54
|
-
client: Optional[Client], client_fn: Optional[
|
|
56
|
+
client: Optional[Client], client_fn: Optional[ClientFnExt]
|
|
55
57
|
) -> None:
|
|
56
58
|
if client_fn is None and client is None:
|
|
57
59
|
raise ValueError(
|
|
@@ -72,7 +74,7 @@ def _check_actionable_client(
|
|
|
72
74
|
def start_client(
|
|
73
75
|
*,
|
|
74
76
|
server_address: str,
|
|
75
|
-
client_fn: Optional[
|
|
77
|
+
client_fn: Optional[ClientFnExt] = None,
|
|
76
78
|
client: Optional[Client] = None,
|
|
77
79
|
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
78
80
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
@@ -92,7 +94,7 @@ def start_client(
|
|
|
92
94
|
The IPv4 or IPv6 address of the server. If the Flower
|
|
93
95
|
server runs on the same machine on port 8080, then `server_address`
|
|
94
96
|
would be `"[::]:8080"`.
|
|
95
|
-
client_fn : Optional[
|
|
97
|
+
client_fn : Optional[ClientFnExt]
|
|
96
98
|
A callable that instantiates a Client. (default: None)
|
|
97
99
|
client : Optional[flwr.client.Client]
|
|
98
100
|
An implementation of the abstract base
|
|
@@ -136,8 +138,8 @@ def start_client(
|
|
|
136
138
|
|
|
137
139
|
Starting an SSL-enabled gRPC client using system certificates:
|
|
138
140
|
|
|
139
|
-
>>> def client_fn(
|
|
140
|
-
>>> return FlowerClient()
|
|
141
|
+
>>> def client_fn(context: Context):
|
|
142
|
+
>>> return FlowerClient().to_client()
|
|
141
143
|
>>>
|
|
142
144
|
>>> start_client(
|
|
143
145
|
>>> server_address=localhost:8080,
|
|
@@ -158,6 +160,7 @@ def start_client(
|
|
|
158
160
|
event(EventType.START_CLIENT_ENTER)
|
|
159
161
|
_start_client_internal(
|
|
160
162
|
server_address=server_address,
|
|
163
|
+
node_config={},
|
|
161
164
|
load_client_app_fn=None,
|
|
162
165
|
client_fn=client_fn,
|
|
163
166
|
client=client,
|
|
@@ -179,8 +182,9 @@ def start_client(
|
|
|
179
182
|
def _start_client_internal(
|
|
180
183
|
*,
|
|
181
184
|
server_address: str,
|
|
185
|
+
node_config: UserConfig,
|
|
182
186
|
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
|
|
183
|
-
client_fn: Optional[
|
|
187
|
+
client_fn: Optional[ClientFnExt] = None,
|
|
184
188
|
client: Optional[Client] = None,
|
|
185
189
|
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
186
190
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
@@ -191,6 +195,7 @@ def _start_client_internal(
|
|
|
191
195
|
] = None,
|
|
192
196
|
max_retries: Optional[int] = None,
|
|
193
197
|
max_wait_time: Optional[float] = None,
|
|
198
|
+
flwr_path: Optional[Path] = None,
|
|
194
199
|
) -> None:
|
|
195
200
|
"""Start a Flower client node which connects to a Flower server.
|
|
196
201
|
|
|
@@ -200,9 +205,11 @@ def _start_client_internal(
|
|
|
200
205
|
The IPv4 or IPv6 address of the server. If the Flower
|
|
201
206
|
server runs on the same machine on port 8080, then `server_address`
|
|
202
207
|
would be `"[::]:8080"`.
|
|
208
|
+
node_config: UserConfig
|
|
209
|
+
The configuration of the node.
|
|
203
210
|
load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
|
|
204
211
|
A function that can be used to load a `ClientApp` instance.
|
|
205
|
-
client_fn : Optional[
|
|
212
|
+
client_fn : Optional[ClientFnExt]
|
|
206
213
|
A callable that instantiates a Client. (default: None)
|
|
207
214
|
client : Optional[flwr.client.Client]
|
|
208
215
|
An implementation of the abstract base
|
|
@@ -234,6 +241,8 @@ def _start_client_internal(
|
|
|
234
241
|
The maximum duration before the client stops trying to
|
|
235
242
|
connect to the server in case of connection error.
|
|
236
243
|
If set to None, there is no limit to the total time.
|
|
244
|
+
flwr_path: Optional[Path] (default: None)
|
|
245
|
+
The fully resolved path containing installed Flower Apps.
|
|
237
246
|
"""
|
|
238
247
|
if insecure is None:
|
|
239
248
|
insecure = root_certificates is None
|
|
@@ -244,7 +253,7 @@ def _start_client_internal(
|
|
|
244
253
|
if client_fn is None:
|
|
245
254
|
# Wrap `Client` instance in `client_fn`
|
|
246
255
|
def single_client_factory(
|
|
247
|
-
|
|
256
|
+
context: Context, # pylint: disable=unused-argument
|
|
248
257
|
) -> Client:
|
|
249
258
|
if client is None: # Added this to keep mypy happy
|
|
250
259
|
raise ValueError(
|
|
@@ -285,7 +294,7 @@ def _start_client_internal(
|
|
|
285
294
|
log(WARN, "Connection attempt failed, retrying...")
|
|
286
295
|
else:
|
|
287
296
|
log(
|
|
288
|
-
|
|
297
|
+
WARN,
|
|
289
298
|
"Connection attempt failed, retrying in %.2f seconds",
|
|
290
299
|
retry_state.actual_wait,
|
|
291
300
|
)
|
|
@@ -293,7 +302,7 @@ def _start_client_internal(
|
|
|
293
302
|
retry_invoker = RetryInvoker(
|
|
294
303
|
wait_gen_factory=exponential,
|
|
295
304
|
recoverable_exceptions=connection_error_type,
|
|
296
|
-
max_tries=max_retries,
|
|
305
|
+
max_tries=max_retries + 1 if max_retries is not None else None,
|
|
297
306
|
max_time=max_wait_time,
|
|
298
307
|
on_giveup=lambda retry_state: (
|
|
299
308
|
log(
|
|
@@ -309,9 +318,10 @@ def _start_client_internal(
|
|
|
309
318
|
on_backoff=_on_backoff,
|
|
310
319
|
)
|
|
311
320
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
321
|
+
# NodeState gets initialized when the first connection is established
|
|
322
|
+
node_state: Optional[NodeState] = None
|
|
323
|
+
|
|
324
|
+
runs: Dict[int, Run] = {}
|
|
315
325
|
|
|
316
326
|
while not app_state_tracker.interrupt:
|
|
317
327
|
sleep_duration: int = 0
|
|
@@ -325,9 +335,31 @@ def _start_client_internal(
|
|
|
325
335
|
) as conn:
|
|
326
336
|
receive, send, create_node, delete_node, get_run = conn
|
|
327
337
|
|
|
328
|
-
# Register node
|
|
329
|
-
if
|
|
330
|
-
create_node
|
|
338
|
+
# Register node when connecting the first time
|
|
339
|
+
if node_state is None:
|
|
340
|
+
if create_node is None:
|
|
341
|
+
if transport not in ["grpc-bidi", None]:
|
|
342
|
+
raise NotImplementedError(
|
|
343
|
+
"All transports except `grpc-bidi` require "
|
|
344
|
+
"an implementation for `create_node()`.'"
|
|
345
|
+
)
|
|
346
|
+
# gRPC-bidi doesn't have the concept of node_id,
|
|
347
|
+
# so we set it to -1
|
|
348
|
+
node_state = NodeState(
|
|
349
|
+
node_id=-1,
|
|
350
|
+
node_config={},
|
|
351
|
+
)
|
|
352
|
+
else:
|
|
353
|
+
# Call create_node fn to register node
|
|
354
|
+
node_id: Optional[int] = ( # pylint: disable=assignment-from-none
|
|
355
|
+
create_node()
|
|
356
|
+
) # pylint: disable=not-callable
|
|
357
|
+
if node_id is None:
|
|
358
|
+
raise ValueError("Node registration failed")
|
|
359
|
+
node_state = NodeState(
|
|
360
|
+
node_id=node_id,
|
|
361
|
+
node_config=node_config,
|
|
362
|
+
)
|
|
331
363
|
|
|
332
364
|
app_state_tracker.register_signal_handler()
|
|
333
365
|
while not app_state_tracker.interrupt:
|
|
@@ -361,15 +393,17 @@ def _start_client_internal(
|
|
|
361
393
|
|
|
362
394
|
# Get run info
|
|
363
395
|
run_id = message.metadata.run_id
|
|
364
|
-
if run_id not in
|
|
396
|
+
if run_id not in runs:
|
|
365
397
|
if get_run is not None:
|
|
366
|
-
|
|
398
|
+
runs[run_id] = get_run(run_id)
|
|
367
399
|
# If get_run is None, i.e., in grpc-bidi mode
|
|
368
400
|
else:
|
|
369
|
-
|
|
401
|
+
runs[run_id] = Run(run_id, "", "", {})
|
|
370
402
|
|
|
371
403
|
# Register context for this run
|
|
372
|
-
node_state.register_context(
|
|
404
|
+
node_state.register_context(
|
|
405
|
+
run_id=run_id, run=runs[run_id], flwr_path=flwr_path
|
|
406
|
+
)
|
|
373
407
|
|
|
374
408
|
# Retrieve context for this run
|
|
375
409
|
context = node_state.retrieve_context(run_id=run_id)
|
|
@@ -383,7 +417,10 @@ def _start_client_internal(
|
|
|
383
417
|
# Handle app loading and task message
|
|
384
418
|
try:
|
|
385
419
|
# Load ClientApp instance
|
|
386
|
-
|
|
420
|
+
run: Run = runs[run_id]
|
|
421
|
+
client_app: ClientApp = load_client_app_fn(
|
|
422
|
+
run.fab_id, run.fab_version
|
|
423
|
+
)
|
|
387
424
|
|
|
388
425
|
# Execute ClientApp
|
|
389
426
|
reply_message = client_app(message=message, context=context)
|
|
@@ -566,9 +603,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
|
566
603
|
Tuple[
|
|
567
604
|
Callable[[], Optional[Message]],
|
|
568
605
|
Callable[[Message], None],
|
|
606
|
+
Optional[Callable[[], Optional[int]]],
|
|
569
607
|
Optional[Callable[[], None]],
|
|
570
|
-
Optional[Callable[[],
|
|
571
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
608
|
+
Optional[Callable[[int], Run]],
|
|
572
609
|
]
|
|
573
610
|
],
|
|
574
611
|
],
|
flwr/client/client_app.py
CHANGED
|
@@ -15,19 +15,62 @@
|
|
|
15
15
|
"""Flower ClientApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import inspect
|
|
18
19
|
from typing import Callable, List, Optional
|
|
19
20
|
|
|
21
|
+
from flwr.client.client import Client
|
|
20
22
|
from flwr.client.message_handler.message_handler import (
|
|
21
23
|
handle_legacy_message_from_msgtype,
|
|
22
24
|
)
|
|
23
25
|
from flwr.client.mod.utils import make_ffn
|
|
24
|
-
from flwr.client.typing import
|
|
26
|
+
from flwr.client.typing import ClientFnExt, Mod
|
|
25
27
|
from flwr.common import Context, Message, MessageType
|
|
26
|
-
from flwr.common.logger import warn_preview_feature
|
|
28
|
+
from flwr.common.logger import warn_deprecated_feature, warn_preview_feature
|
|
27
29
|
|
|
28
30
|
from .typing import ClientAppCallable
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
def _alert_erroneous_client_fn() -> None:
|
|
34
|
+
raise ValueError(
|
|
35
|
+
"A `ClientApp` cannot make use of a `client_fn` that does "
|
|
36
|
+
"not have a signature in the form: `def client_fn(context: "
|
|
37
|
+
"Context)`. You can import the `Context` like this: "
|
|
38
|
+
"`from flwr.common import Context`"
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
|
|
43
|
+
client_fn_args = inspect.signature(client_fn).parameters
|
|
44
|
+
first_arg = list(client_fn_args.keys())[0]
|
|
45
|
+
|
|
46
|
+
if len(client_fn_args) != 1:
|
|
47
|
+
_alert_erroneous_client_fn()
|
|
48
|
+
|
|
49
|
+
first_arg_type = client_fn_args[first_arg].annotation
|
|
50
|
+
|
|
51
|
+
if first_arg_type is str or first_arg == "cid":
|
|
52
|
+
# Warn previous signature for `client_fn` seems to be used
|
|
53
|
+
warn_deprecated_feature(
|
|
54
|
+
"`client_fn` now expects a signature `def client_fn(context: Context)`."
|
|
55
|
+
"The provided `client_fn` has signature: "
|
|
56
|
+
f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
|
|
57
|
+
" `from flwr.common import Context`"
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
# Wrap depcreated client_fn inside a function with the expected signature
|
|
61
|
+
def adaptor_fn(
|
|
62
|
+
context: Context,
|
|
63
|
+
) -> Client: # pylint: disable=unused-argument
|
|
64
|
+
# if patition-id is defined, pass it. Else pass node_id that should
|
|
65
|
+
# always be defined during Context init.
|
|
66
|
+
cid = context.node_config.get("partition-id", context.node_id)
|
|
67
|
+
return client_fn(str(cid)) # type: ignore
|
|
68
|
+
|
|
69
|
+
return adaptor_fn
|
|
70
|
+
|
|
71
|
+
return client_fn
|
|
72
|
+
|
|
73
|
+
|
|
31
74
|
class ClientAppException(Exception):
|
|
32
75
|
"""Exception raised when an exception is raised while executing a ClientApp."""
|
|
33
76
|
|
|
@@ -48,7 +91,7 @@ class ClientApp:
|
|
|
48
91
|
>>> class FlowerClient(NumPyClient):
|
|
49
92
|
>>> # ...
|
|
50
93
|
>>>
|
|
51
|
-
>>> def client_fn(
|
|
94
|
+
>>> def client_fn(context: Context):
|
|
52
95
|
>>> return FlowerClient().to_client()
|
|
53
96
|
>>>
|
|
54
97
|
>>> app = ClientApp(client_fn)
|
|
@@ -65,7 +108,7 @@ class ClientApp:
|
|
|
65
108
|
|
|
66
109
|
def __init__(
|
|
67
110
|
self,
|
|
68
|
-
client_fn: Optional[
|
|
111
|
+
client_fn: Optional[ClientFnExt] = None, # Only for backward compatibility
|
|
69
112
|
mods: Optional[List[Mod]] = None,
|
|
70
113
|
) -> None:
|
|
71
114
|
self._mods: List[Mod] = mods if mods is not None else []
|
|
@@ -74,6 +117,8 @@ class ClientApp:
|
|
|
74
117
|
self._call: Optional[ClientAppCallable] = None
|
|
75
118
|
if client_fn is not None:
|
|
76
119
|
|
|
120
|
+
client_fn = _inspect_maybe_adapt_client_fn_signature(client_fn)
|
|
121
|
+
|
|
77
122
|
def ffn(
|
|
78
123
|
message: Message,
|
|
79
124
|
context: Context,
|
|
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.message import Message
|
|
29
29
|
from flwr.common.retry_invoker import RetryInvoker
|
|
30
|
+
from flwr.common.typing import Run
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
@contextmanager
|
|
@@ -43,9 +44,9 @@ def grpc_adapter( # pylint: disable=R0913
|
|
|
43
44
|
Tuple[
|
|
44
45
|
Callable[[], Optional[Message]],
|
|
45
46
|
Callable[[Message], None],
|
|
47
|
+
Optional[Callable[[], Optional[int]]],
|
|
46
48
|
Optional[Callable[[], None]],
|
|
47
|
-
Optional[Callable[[],
|
|
48
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
49
|
+
Optional[Callable[[int], Run]],
|
|
49
50
|
]
|
|
50
51
|
]:
|
|
51
52
|
"""Primitives for request/response-based interaction with a server via GrpcAdapter.
|
|
@@ -38,6 +38,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
|
38
38
|
from flwr.common.grpc import create_channel
|
|
39
39
|
from flwr.common.logger import log
|
|
40
40
|
from flwr.common.retry_invoker import RetryInvoker
|
|
41
|
+
from flwr.common.typing import Run
|
|
41
42
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
42
43
|
ClientMessage,
|
|
43
44
|
Reason,
|
|
@@ -71,9 +72,9 @@ def grpc_connection( # pylint: disable=R0913, R0915
|
|
|
71
72
|
Tuple[
|
|
72
73
|
Callable[[], Optional[Message]],
|
|
73
74
|
Callable[[Message], None],
|
|
75
|
+
Optional[Callable[[], Optional[int]]],
|
|
74
76
|
Optional[Callable[[], None]],
|
|
75
|
-
Optional[Callable[[],
|
|
76
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
77
|
+
Optional[Callable[[int], Run]],
|
|
77
78
|
]
|
|
78
79
|
]:
|
|
79
80
|
"""Establish a gRPC connection to a gRPC server.
|
|
@@ -40,7 +40,12 @@ from flwr.common.grpc import create_channel
|
|
|
40
40
|
from flwr.common.logger import log
|
|
41
41
|
from flwr.common.message import Message, Metadata
|
|
42
42
|
from flwr.common.retry_invoker import RetryInvoker
|
|
43
|
-
from flwr.common.serde import
|
|
43
|
+
from flwr.common.serde import (
|
|
44
|
+
message_from_taskins,
|
|
45
|
+
message_to_taskres,
|
|
46
|
+
user_config_from_proto,
|
|
47
|
+
)
|
|
48
|
+
from flwr.common.typing import Run
|
|
44
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
45
50
|
CreateNodeRequest,
|
|
46
51
|
DeleteNodeRequest,
|
|
@@ -78,9 +83,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
78
83
|
Tuple[
|
|
79
84
|
Callable[[], Optional[Message]],
|
|
80
85
|
Callable[[Message], None],
|
|
86
|
+
Optional[Callable[[], Optional[int]]],
|
|
81
87
|
Optional[Callable[[], None]],
|
|
82
|
-
Optional[Callable[[],
|
|
83
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
88
|
+
Optional[Callable[[int], Run]],
|
|
84
89
|
]
|
|
85
90
|
]:
|
|
86
91
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -175,7 +180,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
175
180
|
if not ping_stop_event.is_set():
|
|
176
181
|
ping_stop_event.wait(next_interval)
|
|
177
182
|
|
|
178
|
-
def create_node() ->
|
|
183
|
+
def create_node() -> Optional[int]:
|
|
179
184
|
"""Set create_node."""
|
|
180
185
|
# Call FleetAPI
|
|
181
186
|
create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
|
|
@@ -188,6 +193,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
188
193
|
nonlocal node, ping_thread
|
|
189
194
|
node = cast(Node, create_node_response.node)
|
|
190
195
|
ping_thread = start_ping_loop(ping, ping_stop_event)
|
|
196
|
+
return node.node_id
|
|
191
197
|
|
|
192
198
|
def delete_node() -> None:
|
|
193
199
|
"""Set delete_node."""
|
|
@@ -266,7 +272,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
266
272
|
# Cleanup
|
|
267
273
|
metadata = None
|
|
268
274
|
|
|
269
|
-
def get_run(run_id: int) ->
|
|
275
|
+
def get_run(run_id: int) -> Run:
|
|
270
276
|
# Call FleetAPI
|
|
271
277
|
get_run_request = GetRunRequest(run_id=run_id)
|
|
272
278
|
get_run_response: GetRunResponse = retry_invoker.invoke(
|
|
@@ -275,7 +281,12 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
275
281
|
)
|
|
276
282
|
|
|
277
283
|
# Return fab_id and fab_version
|
|
278
|
-
return
|
|
284
|
+
return Run(
|
|
285
|
+
run_id,
|
|
286
|
+
get_run_response.run.fab_id,
|
|
287
|
+
get_run_response.run.fab_version,
|
|
288
|
+
user_config_from_proto(get_run_response.run.override_config),
|
|
289
|
+
)
|
|
279
290
|
|
|
280
291
|
try:
|
|
281
292
|
# Yield methods
|