flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +36 -14
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +31 -20
- flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
- flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +128 -53
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +55 -24
- flwr/client/typing.py +2 -2
- flwr/common/config.py +87 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +16 -11
- flwr/proto/exec_pb2.pyi +22 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -15
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +237 -66
- flwr/superexec/app.py +14 -7
- flwr/superexec/deployment.py +110 -33
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +18 -0
- flwr/superexec/simulation.py +151 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +92 -86
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py
CHANGED
|
@@ -14,59 +14,49 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `run` command."""
|
|
16
16
|
|
|
17
|
+
import subprocess
|
|
17
18
|
import sys
|
|
18
|
-
from enum import Enum
|
|
19
19
|
from logging import DEBUG
|
|
20
20
|
from pathlib import Path
|
|
21
|
-
from typing import Optional
|
|
21
|
+
from typing import Any, Dict, List, Optional
|
|
22
22
|
|
|
23
23
|
import typer
|
|
24
24
|
from typing_extensions import Annotated
|
|
25
25
|
|
|
26
|
-
from flwr.cli import config_utils
|
|
27
26
|
from flwr.cli.build import build
|
|
28
|
-
from flwr.
|
|
27
|
+
from flwr.cli.config_utils import load_and_validate
|
|
28
|
+
from flwr.common.config import parse_config_args
|
|
29
29
|
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
30
30
|
from flwr.common.logger import log
|
|
31
|
+
from flwr.common.serde import user_config_to_proto
|
|
31
32
|
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
|
|
32
33
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
33
|
-
from flwr.simulation.run_simulation import _run_simulation
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
class Engine(str, Enum):
|
|
37
|
-
"""Enum defining the engine to run on."""
|
|
38
|
-
|
|
39
|
-
SIMULATION = "simulation"
|
|
40
34
|
|
|
41
35
|
|
|
42
36
|
# pylint: disable-next=too-many-locals
|
|
43
37
|
def run(
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
typer.
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
38
|
+
directory: Annotated[
|
|
39
|
+
Path,
|
|
40
|
+
typer.Argument(help="Path of the Flower project to run"),
|
|
41
|
+
] = Path("."),
|
|
42
|
+
federation_name: Annotated[
|
|
43
|
+
Optional[str],
|
|
44
|
+
typer.Argument(help="Name of the federation to run the app on"),
|
|
50
45
|
] = None,
|
|
51
|
-
|
|
52
|
-
|
|
46
|
+
config_overrides: Annotated[
|
|
47
|
+
Optional[List[str]],
|
|
53
48
|
typer.Option(
|
|
54
|
-
|
|
49
|
+
"--run-config",
|
|
50
|
+
"-c",
|
|
51
|
+
help="Override configuration key-value pairs",
|
|
55
52
|
),
|
|
56
|
-
] = False,
|
|
57
|
-
directory: Annotated[
|
|
58
|
-
Optional[Path],
|
|
59
|
-
typer.Option(help="Path of the Flower project to run"),
|
|
60
53
|
] = None,
|
|
61
54
|
) -> None:
|
|
62
55
|
"""Run Flower project."""
|
|
63
|
-
if use_superexec:
|
|
64
|
-
_start_superexec_run(directory)
|
|
65
|
-
return
|
|
66
|
-
|
|
67
56
|
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
|
68
57
|
|
|
69
|
-
|
|
58
|
+
pyproject_path = directory / "pyproject.toml" if directory else None
|
|
59
|
+
config, errors, warnings = load_and_validate(path=pyproject_path)
|
|
70
60
|
|
|
71
61
|
if config is None:
|
|
72
62
|
typer.secho(
|
|
@@ -88,42 +78,83 @@ def run(
|
|
|
88
78
|
|
|
89
79
|
typer.secho("Success", fg=typer.colors.GREEN)
|
|
90
80
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
if engine is None:
|
|
95
|
-
engine = config["flower"]["engine"]["name"]
|
|
96
|
-
|
|
97
|
-
if engine == Engine.SIMULATION:
|
|
98
|
-
num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
|
|
99
|
-
backend_config = config["flower"]["engine"]["simulation"].get(
|
|
100
|
-
"backend_config", None
|
|
101
|
-
)
|
|
81
|
+
federation_name = federation_name or config["tool"]["flwr"]["federations"].get(
|
|
82
|
+
"default"
|
|
83
|
+
)
|
|
102
84
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
85
|
+
if federation_name is None:
|
|
86
|
+
typer.secho(
|
|
87
|
+
"❌ No federation name was provided and the project's `pyproject.toml` "
|
|
88
|
+
"doesn't declare a default federation (with a SuperExec address or an "
|
|
89
|
+
"`options.num-supernodes` value).",
|
|
90
|
+
fg=typer.colors.RED,
|
|
91
|
+
bold=True,
|
|
109
92
|
)
|
|
110
|
-
|
|
93
|
+
raise typer.Exit(code=1)
|
|
94
|
+
|
|
95
|
+
# Validate the federation exists in the configuration
|
|
96
|
+
federation = config["tool"]["flwr"]["federations"].get(federation_name)
|
|
97
|
+
if federation is None:
|
|
98
|
+
available_feds = {
|
|
99
|
+
fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
|
|
100
|
+
}
|
|
111
101
|
typer.secho(
|
|
112
|
-
f"
|
|
102
|
+
f"❌ There is no `{federation_name}` federation declared in the "
|
|
103
|
+
"`pyproject.toml`.\n The following federations were found:\n\n"
|
|
104
|
+
+ "\n".join(available_feds),
|
|
113
105
|
fg=typer.colors.RED,
|
|
114
106
|
bold=True,
|
|
115
107
|
)
|
|
108
|
+
raise typer.Exit(code=1)
|
|
116
109
|
|
|
110
|
+
if "address" in federation:
|
|
111
|
+
_run_with_superexec(federation, directory, config_overrides)
|
|
112
|
+
else:
|
|
113
|
+
_run_without_superexec(directory, federation, federation_name, config_overrides)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _run_with_superexec(
|
|
117
|
+
federation: Dict[str, str],
|
|
118
|
+
directory: Optional[Path],
|
|
119
|
+
config_overrides: Optional[List[str]],
|
|
120
|
+
) -> None:
|
|
117
121
|
|
|
118
|
-
def _start_superexec_run(directory: Optional[Path]) -> None:
|
|
119
122
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
120
123
|
"""Log channel connectivity."""
|
|
121
124
|
log(DEBUG, channel_connectivity)
|
|
122
125
|
|
|
126
|
+
insecure_str = federation.get("insecure")
|
|
127
|
+
if root_certificates := federation.get("root-certificates"):
|
|
128
|
+
root_certificates_bytes = Path(root_certificates).read_bytes()
|
|
129
|
+
if insecure := bool(insecure_str):
|
|
130
|
+
typer.secho(
|
|
131
|
+
"❌ `root_certificates` were provided but the `insecure` parameter"
|
|
132
|
+
"is set to `True`.",
|
|
133
|
+
fg=typer.colors.RED,
|
|
134
|
+
bold=True,
|
|
135
|
+
)
|
|
136
|
+
raise typer.Exit(code=1)
|
|
137
|
+
else:
|
|
138
|
+
root_certificates_bytes = None
|
|
139
|
+
if insecure_str is None:
|
|
140
|
+
typer.secho(
|
|
141
|
+
"❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
|
|
142
|
+
fg=typer.colors.RED,
|
|
143
|
+
bold=True,
|
|
144
|
+
)
|
|
145
|
+
raise typer.Exit(code=1)
|
|
146
|
+
if not (insecure := bool(insecure_str)):
|
|
147
|
+
typer.secho(
|
|
148
|
+
"❌ No certificate were given yet `insecure` is set to `False`.",
|
|
149
|
+
fg=typer.colors.RED,
|
|
150
|
+
bold=True,
|
|
151
|
+
)
|
|
152
|
+
raise typer.Exit(code=1)
|
|
153
|
+
|
|
123
154
|
channel = create_channel(
|
|
124
|
-
server_address=
|
|
125
|
-
insecure=
|
|
126
|
-
root_certificates=
|
|
155
|
+
server_address=federation["address"],
|
|
156
|
+
insecure=insecure,
|
|
157
|
+
root_certificates=root_certificates_bytes,
|
|
127
158
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
128
159
|
interceptors=None,
|
|
129
160
|
)
|
|
@@ -132,6 +163,50 @@ def _start_superexec_run(directory: Optional[Path]) -> None:
|
|
|
132
163
|
|
|
133
164
|
fab_path = build(directory)
|
|
134
165
|
|
|
135
|
-
req = StartRunRequest(
|
|
166
|
+
req = StartRunRequest(
|
|
167
|
+
fab_file=Path(fab_path).read_bytes(),
|
|
168
|
+
override_config=user_config_to_proto(
|
|
169
|
+
parse_config_args(config_overrides, separator=",")
|
|
170
|
+
),
|
|
171
|
+
)
|
|
136
172
|
res = stub.StartRun(req)
|
|
137
173
|
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _run_without_superexec(
|
|
177
|
+
app_path: Optional[Path],
|
|
178
|
+
federation: Dict[str, Any],
|
|
179
|
+
federation_name: str,
|
|
180
|
+
config_overrides: Optional[List[str]],
|
|
181
|
+
) -> None:
|
|
182
|
+
try:
|
|
183
|
+
num_supernodes = federation["options"]["num-supernodes"]
|
|
184
|
+
except KeyError as err:
|
|
185
|
+
typer.secho(
|
|
186
|
+
"❌ The project's `pyproject.toml` needs to declare the number of"
|
|
187
|
+
" SuperNodes in the simulation. To simulate 10 SuperNodes,"
|
|
188
|
+
" use the following notation:\n\n"
|
|
189
|
+
f"[tool.flwr.federations.{federation_name}]\n"
|
|
190
|
+
"options.num-supernodes = 10\n",
|
|
191
|
+
fg=typer.colors.RED,
|
|
192
|
+
bold=True,
|
|
193
|
+
)
|
|
194
|
+
raise typer.Exit(code=1) from err
|
|
195
|
+
|
|
196
|
+
command = [
|
|
197
|
+
"flower-simulation",
|
|
198
|
+
"--app",
|
|
199
|
+
f"{app_path}",
|
|
200
|
+
"--num-supernodes",
|
|
201
|
+
f"{num_supernodes}",
|
|
202
|
+
]
|
|
203
|
+
|
|
204
|
+
if config_overrides:
|
|
205
|
+
command.extend(["--run-config", f"{','.join(config_overrides)}"])
|
|
206
|
+
|
|
207
|
+
# Run the simulation
|
|
208
|
+
subprocess.run(
|
|
209
|
+
command,
|
|
210
|
+
check=True,
|
|
211
|
+
text=True,
|
|
212
|
+
)
|
flwr/client/app.py
CHANGED
|
@@ -18,7 +18,8 @@ import signal
|
|
|
18
18
|
import sys
|
|
19
19
|
import time
|
|
20
20
|
from dataclasses import dataclass
|
|
21
|
-
from logging import
|
|
21
|
+
from logging import ERROR, INFO, WARN
|
|
22
|
+
from pathlib import Path
|
|
22
23
|
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
|
|
23
24
|
|
|
24
25
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -27,7 +28,7 @@ from grpc import RpcError
|
|
|
27
28
|
from flwr.client.client import Client
|
|
28
29
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
29
30
|
from flwr.client.typing import ClientFnExt
|
|
30
|
-
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
|
|
31
|
+
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
|
|
31
32
|
from flwr.common.address import parse_address
|
|
32
33
|
from flwr.common.constant import (
|
|
33
34
|
MISSING_EXTRA_REST,
|
|
@@ -41,6 +42,7 @@ from flwr.common.constant import (
|
|
|
41
42
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
42
43
|
from flwr.common.message import Error
|
|
43
44
|
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
45
|
+
from flwr.common.typing import Run, UserConfig
|
|
44
46
|
|
|
45
47
|
from .grpc_adapter_client.connection import grpc_adapter
|
|
46
48
|
from .grpc_client.connection import grpc_connection
|
|
@@ -136,8 +138,8 @@ def start_client(
|
|
|
136
138
|
|
|
137
139
|
Starting an SSL-enabled gRPC client using system certificates:
|
|
138
140
|
|
|
139
|
-
>>> def client_fn(
|
|
140
|
-
>>> return FlowerClient()
|
|
141
|
+
>>> def client_fn(context: Context):
|
|
142
|
+
>>> return FlowerClient().to_client()
|
|
141
143
|
>>>
|
|
142
144
|
>>> start_client(
|
|
143
145
|
>>> server_address="localhost:8080",
|
|
@@ -158,6 +160,7 @@ def start_client(
|
|
|
158
160
|
event(EventType.START_CLIENT_ENTER)
|
|
159
161
|
_start_client_internal(
|
|
160
162
|
server_address=server_address,
|
|
163
|
+
node_config={},
|
|
161
164
|
load_client_app_fn=None,
|
|
162
165
|
client_fn=client_fn,
|
|
163
166
|
client=client,
|
|
@@ -179,6 +182,7 @@ def start_client(
|
|
|
179
182
|
def _start_client_internal(
|
|
180
183
|
*,
|
|
181
184
|
server_address: str,
|
|
185
|
+
node_config: UserConfig,
|
|
182
186
|
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
|
|
183
187
|
client_fn: Optional[ClientFnExt] = None,
|
|
184
188
|
client: Optional[Client] = None,
|
|
@@ -191,7 +195,7 @@ def _start_client_internal(
|
|
|
191
195
|
] = None,
|
|
192
196
|
max_retries: Optional[int] = None,
|
|
193
197
|
max_wait_time: Optional[float] = None,
|
|
194
|
-
|
|
198
|
+
flwr_path: Optional[Path] = None,
|
|
195
199
|
) -> None:
|
|
196
200
|
"""Start a Flower client node which connects to a Flower server.
|
|
197
201
|
|
|
@@ -201,6 +205,8 @@ def _start_client_internal(
|
|
|
201
205
|
The IPv4 or IPv6 address of the server. If the Flower
|
|
202
206
|
server runs on the same machine on port 8080, then `server_address`
|
|
203
207
|
would be `"[::]:8080"`.
|
|
208
|
+
node_config: UserConfig
|
|
209
|
+
The configuration of the node.
|
|
204
210
|
load_client_app_fn : Optional[Callable[[str, str], ClientApp]] (default: None)
|
|
205
211
|
A function that can be used to load a `ClientApp` instance.
|
|
206
212
|
client_fn : Optional[ClientFnExt]
|
|
@@ -235,9 +241,8 @@ def _start_client_internal(
|
|
|
235
241
|
The maximum duration before the client stops trying to
|
|
236
242
|
connect to the server in case of connection error.
|
|
237
243
|
If set to None, there is no limit to the total time.
|
|
238
|
-
|
|
239
|
-
The
|
|
240
|
-
prototyping purposes.
|
|
244
|
+
flwr_path: Optional[Path] (default: None)
|
|
245
|
+
The fully resolved path containing installed Flower Apps.
|
|
241
246
|
"""
|
|
242
247
|
if insecure is None:
|
|
243
248
|
insecure = root_certificates is None
|
|
@@ -248,8 +253,7 @@ def _start_client_internal(
|
|
|
248
253
|
if client_fn is None:
|
|
249
254
|
# Wrap `Client` instance in `client_fn`
|
|
250
255
|
def single_client_factory(
|
|
251
|
-
|
|
252
|
-
partition_id: Optional[int], # pylint: disable=unused-argument
|
|
256
|
+
context: Context, # pylint: disable=unused-argument
|
|
253
257
|
) -> Client:
|
|
254
258
|
if client is None: # Added this to keep mypy happy
|
|
255
259
|
raise ValueError(
|
|
@@ -290,7 +294,7 @@ def _start_client_internal(
|
|
|
290
294
|
log(WARN, "Connection attempt failed, retrying...")
|
|
291
295
|
else:
|
|
292
296
|
log(
|
|
293
|
-
|
|
297
|
+
WARN,
|
|
294
298
|
"Connection attempt failed, retrying in %.2f seconds",
|
|
295
299
|
retry_state.actual_wait,
|
|
296
300
|
)
|
|
@@ -314,9 +318,10 @@ def _start_client_internal(
|
|
|
314
318
|
on_backoff=_on_backoff,
|
|
315
319
|
)
|
|
316
320
|
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
321
|
+
# NodeState gets initialized when the first connection is established
|
|
322
|
+
node_state: Optional[NodeState] = None
|
|
323
|
+
|
|
324
|
+
runs: Dict[int, Run] = {}
|
|
320
325
|
|
|
321
326
|
while not app_state_tracker.interrupt:
|
|
322
327
|
sleep_duration: int = 0
|
|
@@ -330,9 +335,31 @@ def _start_client_internal(
|
|
|
330
335
|
) as conn:
|
|
331
336
|
receive, send, create_node, delete_node, get_run = conn
|
|
332
337
|
|
|
333
|
-
# Register node
|
|
334
|
-
if
|
|
335
|
-
create_node
|
|
338
|
+
# Register node when connecting the first time
|
|
339
|
+
if node_state is None:
|
|
340
|
+
if create_node is None:
|
|
341
|
+
if transport not in ["grpc-bidi", None]:
|
|
342
|
+
raise NotImplementedError(
|
|
343
|
+
"All transports except `grpc-bidi` require "
|
|
344
|
+
"an implementation for `create_node()`.'"
|
|
345
|
+
)
|
|
346
|
+
# gRPC-bidi doesn't have the concept of node_id,
|
|
347
|
+
# so we set it to -1
|
|
348
|
+
node_state = NodeState(
|
|
349
|
+
node_id=-1,
|
|
350
|
+
node_config={},
|
|
351
|
+
)
|
|
352
|
+
else:
|
|
353
|
+
# Call create_node fn to register node
|
|
354
|
+
node_id: Optional[int] = ( # pylint: disable=assignment-from-none
|
|
355
|
+
create_node()
|
|
356
|
+
) # pylint: disable=not-callable
|
|
357
|
+
if node_id is None:
|
|
358
|
+
raise ValueError("Node registration failed")
|
|
359
|
+
node_state = NodeState(
|
|
360
|
+
node_id=node_id,
|
|
361
|
+
node_config=node_config,
|
|
362
|
+
)
|
|
336
363
|
|
|
337
364
|
app_state_tracker.register_signal_handler()
|
|
338
365
|
while not app_state_tracker.interrupt:
|
|
@@ -366,15 +393,17 @@ def _start_client_internal(
|
|
|
366
393
|
|
|
367
394
|
# Get run info
|
|
368
395
|
run_id = message.metadata.run_id
|
|
369
|
-
if run_id not in
|
|
396
|
+
if run_id not in runs:
|
|
370
397
|
if get_run is not None:
|
|
371
|
-
|
|
398
|
+
runs[run_id] = get_run(run_id)
|
|
372
399
|
# If get_run is None, i.e., in grpc-bidi mode
|
|
373
400
|
else:
|
|
374
|
-
|
|
401
|
+
runs[run_id] = Run(run_id, "", "", {})
|
|
375
402
|
|
|
376
403
|
# Register context for this run
|
|
377
|
-
node_state.register_context(
|
|
404
|
+
node_state.register_context(
|
|
405
|
+
run_id=run_id, run=runs[run_id], flwr_path=flwr_path
|
|
406
|
+
)
|
|
378
407
|
|
|
379
408
|
# Retrieve context for this run
|
|
380
409
|
context = node_state.retrieve_context(run_id=run_id)
|
|
@@ -388,7 +417,10 @@ def _start_client_internal(
|
|
|
388
417
|
# Handle app loading and task message
|
|
389
418
|
try:
|
|
390
419
|
# Load ClientApp instance
|
|
391
|
-
|
|
420
|
+
run: Run = runs[run_id]
|
|
421
|
+
client_app: ClientApp = load_client_app_fn(
|
|
422
|
+
run.fab_id, run.fab_version
|
|
423
|
+
)
|
|
392
424
|
|
|
393
425
|
# Execute ClientApp
|
|
394
426
|
reply_message = client_app(message=message, context=context)
|
|
@@ -571,9 +603,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
|
571
603
|
Tuple[
|
|
572
604
|
Callable[[], Optional[Message]],
|
|
573
605
|
Callable[[Message], None],
|
|
606
|
+
Optional[Callable[[], Optional[int]]],
|
|
574
607
|
Optional[Callable[[], None]],
|
|
575
|
-
Optional[Callable[[],
|
|
576
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
608
|
+
Optional[Callable[[int], Run]],
|
|
577
609
|
]
|
|
578
610
|
],
|
|
579
611
|
],
|
flwr/client/client_app.py
CHANGED
|
@@ -30,21 +30,41 @@ from flwr.common.logger import warn_deprecated_feature, warn_preview_feature
|
|
|
30
30
|
from .typing import ClientAppCallable
|
|
31
31
|
|
|
32
32
|
|
|
33
|
+
def _alert_erroneous_client_fn() -> None:
|
|
34
|
+
raise ValueError(
|
|
35
|
+
"A `ClientApp` cannot make use of a `client_fn` that does "
|
|
36
|
+
"not have a signature in the form: `def client_fn(context: "
|
|
37
|
+
"Context)`. You can import the `Context` like this: "
|
|
38
|
+
"`from flwr.common import Context`"
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
33
42
|
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
|
|
34
43
|
client_fn_args = inspect.signature(client_fn).parameters
|
|
44
|
+
first_arg = list(client_fn_args.keys())[0]
|
|
45
|
+
|
|
46
|
+
if len(client_fn_args) != 1:
|
|
47
|
+
_alert_erroneous_client_fn()
|
|
48
|
+
|
|
49
|
+
first_arg_type = client_fn_args[first_arg].annotation
|
|
35
50
|
|
|
36
|
-
if
|
|
51
|
+
if first_arg_type is str or first_arg == "cid":
|
|
52
|
+
# Warn previous signature for `client_fn` seems to be used
|
|
37
53
|
warn_deprecated_feature(
|
|
38
|
-
"`client_fn` now expects a signature `def client_fn(
|
|
39
|
-
"
|
|
40
|
-
f"{dict(client_fn_args.items())}"
|
|
54
|
+
"`client_fn` now expects a signature `def client_fn(context: Context)`."
|
|
55
|
+
"The provided `client_fn` has signature: "
|
|
56
|
+
f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
|
|
57
|
+
" `from flwr.common import Context`"
|
|
41
58
|
)
|
|
42
59
|
|
|
43
60
|
# Wrap deprecated client_fn inside a function with the expected signature
|
|
44
61
|
def adaptor_fn(
|
|
45
|
-
|
|
46
|
-
) -> Client:
|
|
47
|
-
|
|
62
|
+
context: Context,
|
|
63
|
+
) -> Client: # pylint: disable=unused-argument
|
|
64
|
+
# if partition-id is defined, pass it. Else pass node_id that should
|
|
65
|
+
# always be defined during Context init.
|
|
66
|
+
cid = context.node_config.get("partition-id", context.node_id)
|
|
67
|
+
return client_fn(str(cid)) # type: ignore
|
|
48
68
|
|
|
49
69
|
return adaptor_fn
|
|
50
70
|
|
|
@@ -71,7 +91,7 @@ class ClientApp:
|
|
|
71
91
|
>>> class FlowerClient(NumPyClient):
|
|
72
92
|
>>> # ...
|
|
73
93
|
>>>
|
|
74
|
-
>>> def client_fn(
|
|
94
|
+
>>> def client_fn(context: Context):
|
|
75
95
|
>>> return FlowerClient().to_client()
|
|
76
96
|
>>>
|
|
77
97
|
>>> app = ClientApp(client_fn)
|
|
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.message import Message
|
|
29
29
|
from flwr.common.retry_invoker import RetryInvoker
|
|
30
|
+
from flwr.common.typing import Run
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
@contextmanager
|
|
@@ -43,9 +44,9 @@ def grpc_adapter( # pylint: disable=R0913
|
|
|
43
44
|
Tuple[
|
|
44
45
|
Callable[[], Optional[Message]],
|
|
45
46
|
Callable[[Message], None],
|
|
47
|
+
Optional[Callable[[], Optional[int]]],
|
|
46
48
|
Optional[Callable[[], None]],
|
|
47
|
-
Optional[Callable[[],
|
|
48
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
49
|
+
Optional[Callable[[int], Run]],
|
|
49
50
|
]
|
|
50
51
|
]:
|
|
51
52
|
"""Primitives for request/response-based interaction with a server via GrpcAdapter.
|
|
@@ -38,6 +38,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
|
38
38
|
from flwr.common.grpc import create_channel
|
|
39
39
|
from flwr.common.logger import log
|
|
40
40
|
from flwr.common.retry_invoker import RetryInvoker
|
|
41
|
+
from flwr.common.typing import Run
|
|
41
42
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
42
43
|
ClientMessage,
|
|
43
44
|
Reason,
|
|
@@ -71,9 +72,9 @@ def grpc_connection( # pylint: disable=R0913, R0915
|
|
|
71
72
|
Tuple[
|
|
72
73
|
Callable[[], Optional[Message]],
|
|
73
74
|
Callable[[Message], None],
|
|
75
|
+
Optional[Callable[[], Optional[int]]],
|
|
74
76
|
Optional[Callable[[], None]],
|
|
75
|
-
Optional[Callable[[],
|
|
76
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
77
|
+
Optional[Callable[[int], Run]],
|
|
77
78
|
]
|
|
78
79
|
]:
|
|
79
80
|
"""Establish a gRPC connection to a gRPC server.
|
|
@@ -40,7 +40,12 @@ from flwr.common.grpc import create_channel
|
|
|
40
40
|
from flwr.common.logger import log
|
|
41
41
|
from flwr.common.message import Message, Metadata
|
|
42
42
|
from flwr.common.retry_invoker import RetryInvoker
|
|
43
|
-
from flwr.common.serde import
|
|
43
|
+
from flwr.common.serde import (
|
|
44
|
+
message_from_taskins,
|
|
45
|
+
message_to_taskres,
|
|
46
|
+
user_config_from_proto,
|
|
47
|
+
)
|
|
48
|
+
from flwr.common.typing import Run
|
|
44
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
45
50
|
CreateNodeRequest,
|
|
46
51
|
DeleteNodeRequest,
|
|
@@ -78,9 +83,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
78
83
|
Tuple[
|
|
79
84
|
Callable[[], Optional[Message]],
|
|
80
85
|
Callable[[Message], None],
|
|
86
|
+
Optional[Callable[[], Optional[int]]],
|
|
81
87
|
Optional[Callable[[], None]],
|
|
82
|
-
Optional[Callable[[],
|
|
83
|
-
Optional[Callable[[int], Tuple[str, str]]],
|
|
88
|
+
Optional[Callable[[int], Run]],
|
|
84
89
|
]
|
|
85
90
|
]:
|
|
86
91
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -175,7 +180,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
175
180
|
if not ping_stop_event.is_set():
|
|
176
181
|
ping_stop_event.wait(next_interval)
|
|
177
182
|
|
|
178
|
-
def create_node() ->
|
|
183
|
+
def create_node() -> Optional[int]:
|
|
179
184
|
"""Set create_node."""
|
|
180
185
|
# Call FleetAPI
|
|
181
186
|
create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
|
|
@@ -188,6 +193,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
188
193
|
nonlocal node, ping_thread
|
|
189
194
|
node = cast(Node, create_node_response.node)
|
|
190
195
|
ping_thread = start_ping_loop(ping, ping_stop_event)
|
|
196
|
+
return node.node_id
|
|
191
197
|
|
|
192
198
|
def delete_node() -> None:
|
|
193
199
|
"""Set delete_node."""
|
|
@@ -266,7 +272,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
266
272
|
# Cleanup
|
|
267
273
|
metadata = None
|
|
268
274
|
|
|
269
|
-
def get_run(run_id: int) ->
|
|
275
|
+
def get_run(run_id: int) -> Run:
|
|
270
276
|
# Call FleetAPI
|
|
271
277
|
get_run_request = GetRunRequest(run_id=run_id)
|
|
272
278
|
get_run_response: GetRunResponse = retry_invoker.invoke(
|
|
@@ -275,7 +281,12 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
275
281
|
)
|
|
276
282
|
|
|
277
283
|
# Return the Run (run_id, fab_id, fab_version, override_config)
|
|
278
|
-
return
|
|
284
|
+
return Run(
|
|
285
|
+
run_id,
|
|
286
|
+
get_run_response.run.fab_id,
|
|
287
|
+
get_run_response.run.fab_version,
|
|
288
|
+
user_config_from_proto(get_run_response.run.override_config),
|
|
289
|
+
)
|
|
279
290
|
|
|
280
291
|
try:
|
|
281
292
|
# Yield methods
|
|
@@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype(
|
|
|
92
92
|
client_fn: ClientFnExt, message: Message, context: Context
|
|
93
93
|
) -> Message:
|
|
94
94
|
"""Handle legacy message in the inner most mod."""
|
|
95
|
-
client = client_fn(
|
|
95
|
+
client = client_fn(context)
|
|
96
96
|
|
|
97
97
|
# Check if NumPyClient is returned
|
|
98
98
|
if isinstance(client, NumPyClient):
|