flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240707__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +3 -0
- flwr/cli/build.py +5 -9
- flwr/cli/new/new.py +104 -28
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
- flwr/cli/run/run.py +21 -5
- flwr/client/__init__.py +2 -0
- flwr/client/app.py +15 -10
- flwr/client/client_app.py +30 -5
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +4 -5
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +6 -3
- flwr/client/node_state_tests.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +12 -4
- flwr/client/typing.py +2 -1
- flwr/common/address.py +1 -1
- flwr/common/config.py +8 -6
- flwr/common/constant.py +4 -1
- flwr/common/context.py +11 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +13 -0
- flwr/common/message.py +0 -17
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/version.py +14 -0
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/driver.py +6 -0
- flwr/server/driver/grpc_driver.py +85 -63
- flwr/server/driver/inmemory_driver.py +28 -26
- flwr/server/run_serverapp.py +61 -18
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +15 -3
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +45 -26
- flwr/server/superlink/fleet/vce/vce_api.py +3 -8
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +5 -5
- flwr/server/superlink/state/sqlite_state.py +5 -5
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +6 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +52 -37
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +0 -6
- flwr/simulation/ray_transport/ray_client_proxy.py +17 -10
- flwr/simulation/run_simulation.py +47 -28
- flwr/superexec/deployment.py +109 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/RECORD +109 -98
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
+
|
|
3
|
+
from $import_name.client import set_parameters
|
|
4
|
+
from $import_name.models import get_model
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Build the callable that the strategy's evaluate() method invokes each round.
# In this app it is only used to periodically checkpoint the global model.
def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
    """Create a server-side evaluate callback that checkpoints the global model.

    The returned callback saves a PEFT checkpoint to
    ``{save_path}/peft_{server_round}`` on the final round (``total_round``)
    and on every round that is a multiple of ``save_every_round``; round 0 is
    never saved. It always reports a loss of ``0.0`` and no metrics.
    """

    def evaluate(server_round: int, parameters, config):
        # Round 0 is the initial evaluation: nothing to checkpoint yet
        if server_round == 0:
            return 0.0, {}

        if server_round == total_round or server_round % save_every_round == 0:
            # Rebuild the model and load the aggregated global parameters into it
            model = get_model(model_cfg)
            set_parameters(model, parameters)
            model.save_pretrained(f"{save_path}/peft_{server_round}")

        return 0.0, {}

    return evaluate
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_on_fit_config():
    """Return a function that builds the config each client's fit() receives.

    The returned function maps a server round number to the dictionary
    ``{"current_round": server_round}``.
    """

    def fit_config_fn(server_round: int):
        return {"current_round": server_round}

    return fit_config_fn
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def fit_weighted_average(metrics):
    """Aggregate the training losses reported by clients after fit().

    Parameters
    ----------
    metrics : list of (int, dict)
        One ``(num_examples, client_metrics)`` pair per client, where
        ``client_metrics`` contains a ``"train_loss"`` entry.

    Returns
    -------
    dict
        ``{"train_loss": <example-weighted mean of client losses>}``, or an
        empty dict when no training examples were reported (avoids dividing
        by zero).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}

    # Weight each client's training loss by the number of examples it used
    weighted_loss = sum(num_examples * m["train_loss"] for num_examples, m in metrics)
    return {"train_loss": weighted_loss / total_examples}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Federated Instruction Tuning (static)
|
|
2
|
+
---
|
|
3
|
+
dataset:
|
|
4
|
+
name: $dataset_name
|
|
5
|
+
|
|
6
|
+
# FL experimental settings
|
|
7
|
+
num_clients: $num_clients # total number of clients
|
|
8
|
+
num_rounds: 200
|
|
9
|
+
partitioner:
|
|
10
|
+
_target_: flwr_datasets.partitioner.IidPartitioner
|
|
11
|
+
num_partitions: $num_clients
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$package_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
authors = [
|
|
10
|
+
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
+
]
|
|
12
|
+
license = { text = "Apache License (2.0)" }
|
|
13
|
+
dependencies = [
|
|
14
|
+
"flwr[simulation]>=1.9.0,<2.0",
|
|
15
|
+
"flwr-datasets>=0.1.0,<1.0.0",
|
|
16
|
+
"hydra-core==1.3.2",
|
|
17
|
+
"trl==0.8.1",
|
|
18
|
+
"bitsandbytes==0.43.0",
|
|
19
|
+
"scipy==1.13.0",
|
|
20
|
+
"peft==0.6.2",
|
|
21
|
+
"transformers==4.39.3",
|
|
22
|
+
"sentencepiece==0.2.0",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
[tool.hatch.build.targets.wheel]
|
|
26
|
+
packages = ["."]
|
|
27
|
+
|
|
28
|
+
[flower]
|
|
29
|
+
publisher = "$username"
|
|
30
|
+
|
|
31
|
+
[flower.components]
|
|
32
|
+
serverapp = "$import_name.app:server"
|
|
33
|
+
clientapp = "$import_name.app:client"
|
|
34
|
+
|
|
35
|
+
[flower.engine]
|
|
36
|
+
name = "simulation"
|
|
37
|
+
|
|
38
|
+
[flower.engine.simulation.supernode]
|
|
39
|
+
num = $num_clients
|
|
40
|
+
|
|
41
|
+
[flower.engine.simulation]
|
|
42
|
+
backend_config = { client_resources = { num_cpus = 8, num_gpus = 1.0 } }
|
flwr/cli/run/run.py
CHANGED
|
@@ -17,12 +17,14 @@
|
|
|
17
17
|
import sys
|
|
18
18
|
from enum import Enum
|
|
19
19
|
from logging import DEBUG
|
|
20
|
+
from pathlib import Path
|
|
20
21
|
from typing import Optional
|
|
21
22
|
|
|
22
23
|
import typer
|
|
23
24
|
from typing_extensions import Annotated
|
|
24
25
|
|
|
25
26
|
from flwr.cli import config_utils
|
|
27
|
+
from flwr.cli.build import build
|
|
26
28
|
from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
|
|
27
29
|
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
28
30
|
from flwr.common.logger import log
|
|
@@ -41,7 +43,10 @@ class Engine(str, Enum):
|
|
|
41
43
|
def run(
|
|
42
44
|
engine: Annotated[
|
|
43
45
|
Optional[Engine],
|
|
44
|
-
typer.Option(
|
|
46
|
+
typer.Option(
|
|
47
|
+
case_sensitive=False,
|
|
48
|
+
help="The engine to run FL with (currently only simulation is supported).",
|
|
49
|
+
),
|
|
45
50
|
] = None,
|
|
46
51
|
use_superexec: Annotated[
|
|
47
52
|
bool,
|
|
@@ -49,10 +54,14 @@ def run(
|
|
|
49
54
|
case_sensitive=False, help="Use this flag to use the new SuperExec API"
|
|
50
55
|
),
|
|
51
56
|
] = False,
|
|
57
|
+
directory: Annotated[
|
|
58
|
+
Optional[Path],
|
|
59
|
+
typer.Option(help="Path of the Flower project to run"),
|
|
60
|
+
] = None,
|
|
52
61
|
) -> None:
|
|
53
62
|
"""Run Flower project."""
|
|
54
63
|
if use_superexec:
|
|
55
|
-
_start_superexec_run()
|
|
64
|
+
_start_superexec_run(directory)
|
|
56
65
|
return
|
|
57
66
|
|
|
58
67
|
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
|
@@ -87,12 +96,16 @@ def run(
|
|
|
87
96
|
|
|
88
97
|
if engine == Engine.SIMULATION:
|
|
89
98
|
num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
|
|
99
|
+
backend_config = config["flower"]["engine"]["simulation"].get(
|
|
100
|
+
"backend_config", None
|
|
101
|
+
)
|
|
90
102
|
|
|
91
103
|
typer.secho("Starting run... ", fg=typer.colors.BLUE)
|
|
92
104
|
_run_simulation(
|
|
93
105
|
server_app_attr=server_app_ref,
|
|
94
106
|
client_app_attr=client_app_ref,
|
|
95
107
|
num_supernodes=num_supernodes,
|
|
108
|
+
backend_config=backend_config,
|
|
96
109
|
)
|
|
97
110
|
else:
|
|
98
111
|
typer.secho(
|
|
@@ -102,7 +115,7 @@ def run(
|
|
|
102
115
|
)
|
|
103
116
|
|
|
104
117
|
|
|
105
|
-
def _start_superexec_run() -> None:
|
|
118
|
+
def _start_superexec_run(directory: Optional[Path]) -> None:
|
|
106
119
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
107
120
|
"""Log channel connectivity."""
|
|
108
121
|
log(DEBUG, channel_connectivity)
|
|
@@ -117,5 +130,8 @@ def _start_superexec_run() -> None:
|
|
|
117
130
|
channel.subscribe(on_channel_state_change)
|
|
118
131
|
stub = ExecStub(channel)
|
|
119
132
|
|
|
120
|
-
|
|
121
|
-
|
|
133
|
+
fab_path = build(directory)
|
|
134
|
+
|
|
135
|
+
req = StartRunRequest(fab_file=Path(fab_path).read_bytes())
|
|
136
|
+
res = stub.StartRun(req)
|
|
137
|
+
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
flwr/client/__init__.py
CHANGED
|
@@ -23,11 +23,13 @@ from .numpy_client import NumPyClient as NumPyClient
|
|
|
23
23
|
from .supernode import run_client_app as run_client_app
|
|
24
24
|
from .supernode import run_supernode as run_supernode
|
|
25
25
|
from .typing import ClientFn as ClientFn
|
|
26
|
+
from .typing import ClientFnExt as ClientFnExt
|
|
26
27
|
|
|
27
28
|
__all__ = [
|
|
28
29
|
"Client",
|
|
29
30
|
"ClientApp",
|
|
30
31
|
"ClientFn",
|
|
32
|
+
"ClientFnExt",
|
|
31
33
|
"NumPyClient",
|
|
32
34
|
"mod",
|
|
33
35
|
"run_client_app",
|
flwr/client/app.py
CHANGED
|
@@ -26,7 +26,7 @@ from grpc import RpcError
|
|
|
26
26
|
|
|
27
27
|
from flwr.client.client import Client
|
|
28
28
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
29
|
-
from flwr.client.typing import
|
|
29
|
+
from flwr.client.typing import ClientFnExt
|
|
30
30
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
|
|
31
31
|
from flwr.common.address import parse_address
|
|
32
32
|
from flwr.common.constant import (
|
|
@@ -51,7 +51,7 @@ from .numpy_client import NumPyClient
|
|
|
51
51
|
|
|
52
52
|
|
|
53
53
|
def _check_actionable_client(
|
|
54
|
-
client: Optional[Client], client_fn: Optional[
|
|
54
|
+
client: Optional[Client], client_fn: Optional[ClientFnExt]
|
|
55
55
|
) -> None:
|
|
56
56
|
if client_fn is None and client is None:
|
|
57
57
|
raise ValueError(
|
|
@@ -72,7 +72,7 @@ def _check_actionable_client(
|
|
|
72
72
|
def start_client(
|
|
73
73
|
*,
|
|
74
74
|
server_address: str,
|
|
75
|
-
client_fn: Optional[
|
|
75
|
+
client_fn: Optional[ClientFnExt] = None,
|
|
76
76
|
client: Optional[Client] = None,
|
|
77
77
|
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
78
78
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
@@ -92,7 +92,7 @@ def start_client(
|
|
|
92
92
|
The IPv4 or IPv6 address of the server. If the Flower
|
|
93
93
|
server runs on the same machine on port 8080, then `server_address`
|
|
94
94
|
would be `"[::]:8080"`.
|
|
95
|
-
client_fn : Optional[
|
|
95
|
+
client_fn : Optional[ClientFnExt]
|
|
96
96
|
A callable that instantiates a Client. (default: None)
|
|
97
97
|
client : Optional[flwr.client.Client]
|
|
98
98
|
An implementation of the abstract base
|
|
@@ -136,7 +136,7 @@ def start_client(
|
|
|
136
136
|
|
|
137
137
|
Starting an SSL-enabled gRPC client using system certificates:
|
|
138
138
|
|
|
139
|
-
>>> def client_fn(
|
|
139
|
+
>>> def client_fn(node_id: int, partition_id: Optional[int]):
|
|
140
140
|
>>> return FlowerClient()
|
|
141
141
|
>>>
|
|
142
142
|
>>> start_client(
|
|
@@ -180,7 +180,7 @@ def _start_client_internal(
|
|
|
180
180
|
*,
|
|
181
181
|
server_address: str,
|
|
182
182
|
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
|
|
183
|
-
client_fn: Optional[
|
|
183
|
+
client_fn: Optional[ClientFnExt] = None,
|
|
184
184
|
client: Optional[Client] = None,
|
|
185
185
|
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
186
186
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
@@ -191,6 +191,7 @@ def _start_client_internal(
|
|
|
191
191
|
] = None,
|
|
192
192
|
max_retries: Optional[int] = None,
|
|
193
193
|
max_wait_time: Optional[float] = None,
|
|
194
|
+
partition_id: Optional[int] = None,
|
|
194
195
|
) -> None:
|
|
195
196
|
"""Start a Flower client node which connects to a Flower server.
|
|
196
197
|
|
|
@@ -202,7 +203,7 @@ def _start_client_internal(
|
|
|
202
203
|
would be `"[::]:8080"`.
|
|
203
204
|
load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
|
|
204
205
|
A function that can be used to load a `ClientApp` instance.
|
|
205
|
-
client_fn : Optional[
|
|
206
|
+
client_fn : Optional[ClientFnExt]
|
|
206
207
|
A callable that instantiates a Client. (default: None)
|
|
207
208
|
client : Optional[flwr.client.Client]
|
|
208
209
|
An implementation of the abstract base
|
|
@@ -234,6 +235,9 @@ def _start_client_internal(
|
|
|
234
235
|
The maximum duration before the client stops trying to
|
|
235
236
|
connect to the server in case of connection error.
|
|
236
237
|
If set to None, there is no limit to the total time.
|
|
238
|
+
partition_id : Optional[int] (default: None)
|
|
239
|
+
The data partition index associated with this node. Better suited for
|
|
240
|
+
prototyping purposes.
|
|
237
241
|
"""
|
|
238
242
|
if insecure is None:
|
|
239
243
|
insecure = root_certificates is None
|
|
@@ -244,7 +248,8 @@ def _start_client_internal(
|
|
|
244
248
|
if client_fn is None:
|
|
245
249
|
# Wrap `Client` instance in `client_fn`
|
|
246
250
|
def single_client_factory(
|
|
247
|
-
|
|
251
|
+
node_id: int, # pylint: disable=unused-argument
|
|
252
|
+
partition_id: Optional[int], # pylint: disable=unused-argument
|
|
248
253
|
) -> Client:
|
|
249
254
|
if client is None: # Added this to keep mypy happy
|
|
250
255
|
raise ValueError(
|
|
@@ -293,7 +298,7 @@ def _start_client_internal(
|
|
|
293
298
|
retry_invoker = RetryInvoker(
|
|
294
299
|
wait_gen_factory=exponential,
|
|
295
300
|
recoverable_exceptions=connection_error_type,
|
|
296
|
-
max_tries=max_retries,
|
|
301
|
+
max_tries=max_retries + 1 if max_retries is not None else None,
|
|
297
302
|
max_time=max_wait_time,
|
|
298
303
|
on_giveup=lambda retry_state: (
|
|
299
304
|
log(
|
|
@@ -309,7 +314,7 @@ def _start_client_internal(
|
|
|
309
314
|
on_backoff=_on_backoff,
|
|
310
315
|
)
|
|
311
316
|
|
|
312
|
-
node_state = NodeState()
|
|
317
|
+
node_state = NodeState(partition_id=partition_id)
|
|
313
318
|
# run_id -> (fab_id, fab_version)
|
|
314
319
|
run_info: Dict[int, Tuple[str, str]] = {}
|
|
315
320
|
|
flwr/client/client_app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,19 +15,42 @@
|
|
|
15
15
|
"""Flower ClientApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import inspect
|
|
18
19
|
from typing import Callable, List, Optional
|
|
19
20
|
|
|
21
|
+
from flwr.client.client import Client
|
|
20
22
|
from flwr.client.message_handler.message_handler import (
|
|
21
23
|
handle_legacy_message_from_msgtype,
|
|
22
24
|
)
|
|
23
25
|
from flwr.client.mod.utils import make_ffn
|
|
24
|
-
from flwr.client.typing import
|
|
26
|
+
from flwr.client.typing import ClientFnExt, Mod
|
|
25
27
|
from flwr.common import Context, Message, MessageType
|
|
26
|
-
from flwr.common.logger import warn_preview_feature
|
|
28
|
+
from flwr.common.logger import warn_deprecated_feature, warn_preview_feature
|
|
27
29
|
|
|
28
30
|
from .typing import ClientAppCallable
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
    """Return a `client_fn` callable as `client_fn(node_id, partition_id)`.

    If `client_fn` declares parameters named both `node_id` and `partition_id`
    it is returned unchanged. Otherwise it is treated as a legacy single-argument
    `client_fn`: a deprecation warning is emitted and the function is wrapped in
    an adaptor that calls it with `str(partition_id)` as its only argument.

    Parameters
    ----------
    client_fn : ClientFnExt
        The user-provided client factory.

    Returns
    -------
    ClientFnExt
        Either `client_fn` itself or an adaptor with the expected signature.
    """
    params = inspect.signature(client_fn).parameters

    # Already uses the new signature: nothing to adapt
    if all(name in params for name in ("node_id", "partition_id")):
        return client_fn

    warn_deprecated_feature(
        "`client_fn` now expects a signature `def client_fn(node_id: int, "
        "partition_id: Optional[int])`.\nYou provided `client_fn` with signature: "
        f"{dict(params.items())}"
    )

    # Wrap the deprecated client_fn inside a function with the expected signature.
    # NOTE: when `partition_id` is None the legacy function receives "None".
    def adaptor_fn(
        node_id: int, partition_id: Optional[int]  # pylint: disable=unused-argument
    ) -> Client:
        return client_fn(str(partition_id))  # type: ignore

    return adaptor_fn
|
|
52
|
+
|
|
53
|
+
|
|
31
54
|
class ClientAppException(Exception):
|
|
32
55
|
"""Exception raised when an exception is raised while executing a ClientApp."""
|
|
33
56
|
|
|
@@ -48,7 +71,7 @@ class ClientApp:
|
|
|
48
71
|
>>> class FlowerClient(NumPyClient):
|
|
49
72
|
>>> # ...
|
|
50
73
|
>>>
|
|
51
|
-
>>> def client_fn(
|
|
74
|
+
>>> def client_fn(node_id: int, partition_id: Optional[int]):
|
|
52
75
|
>>> return FlowerClient().to_client()
|
|
53
76
|
>>>
|
|
54
77
|
>>> app = ClientApp(client_fn)
|
|
@@ -65,7 +88,7 @@ class ClientApp:
|
|
|
65
88
|
|
|
66
89
|
def __init__(
|
|
67
90
|
self,
|
|
68
|
-
client_fn: Optional[
|
|
91
|
+
client_fn: Optional[ClientFnExt] = None, # Only for backward compatibility
|
|
69
92
|
mods: Optional[List[Mod]] = None,
|
|
70
93
|
) -> None:
|
|
71
94
|
self._mods: List[Mod] = mods if mods is not None else []
|
|
@@ -74,6 +97,8 @@ class ClientApp:
|
|
|
74
97
|
self._call: Optional[ClientAppCallable] = None
|
|
75
98
|
if client_fn is not None:
|
|
76
99
|
|
|
100
|
+
client_fn = _inspect_maybe_adapt_client_fn_signature(client_fn)
|
|
101
|
+
|
|
77
102
|
def ffn(
|
|
78
103
|
message: Message,
|
|
79
104
|
context: Context,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2022 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Client-side message handler."""
|
|
16
16
|
|
|
17
|
-
|
|
18
17
|
from logging import WARN
|
|
19
18
|
from typing import Optional, Tuple, cast
|
|
20
19
|
|
|
@@ -25,7 +24,7 @@ from flwr.client.client import (
|
|
|
25
24
|
maybe_call_get_properties,
|
|
26
25
|
)
|
|
27
26
|
from flwr.client.numpy_client import NumPyClient
|
|
28
|
-
from flwr.client.typing import
|
|
27
|
+
from flwr.client.typing import ClientFnExt
|
|
29
28
|
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
|
|
30
29
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
31
30
|
from flwr.common.recordset_compat import (
|
|
@@ -90,10 +89,10 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
|
|
|
90
89
|
|
|
91
90
|
|
|
92
91
|
def handle_legacy_message_from_msgtype(
|
|
93
|
-
client_fn:
|
|
92
|
+
client_fn: ClientFnExt, message: Message, context: Context
|
|
94
93
|
) -> Message:
|
|
95
94
|
"""Handle legacy message in the inner most mod."""
|
|
96
|
-
client = client_fn(
|
|
95
|
+
client = client_fn(message.metadata.dst_node_id, context.partition_id)
|
|
97
96
|
|
|
98
97
|
# Check if NumPyClient is returned
|
|
99
98
|
if isinstance(client, NumPyClient):
|
flwr/client/mod/__init__.py
CHANGED
flwr/client/mod/utils.py
CHANGED
flwr/client/node_state.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Node state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Any, Dict
|
|
18
|
+
from typing import Any, Dict, Optional
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context, RecordSet
|
|
21
21
|
|
|
@@ -23,14 +23,17 @@ from flwr.common import Context, RecordSet
|
|
|
23
23
|
class NodeState:
|
|
24
24
|
"""State of a node where client nodes execute runs."""
|
|
25
25
|
|
|
26
|
-
def __init__(self) -> None:
|
|
26
|
+
def __init__(self, partition_id: Optional[int]) -> None:
|
|
27
27
|
self._meta: Dict[str, Any] = {} # holds metadata about the node
|
|
28
28
|
self.run_contexts: Dict[int, Context] = {}
|
|
29
|
+
self._partition_id = partition_id
|
|
29
30
|
|
|
30
31
|
def register_context(self, run_id: int) -> None:
|
|
31
32
|
"""Register new run context for this node."""
|
|
32
33
|
if run_id not in self.run_contexts:
|
|
33
|
-
self.run_contexts[run_id] = Context(
|
|
34
|
+
self.run_contexts[run_id] = Context(
|
|
35
|
+
state=RecordSet(), partition_id=self._partition_id
|
|
36
|
+
)
|
|
34
37
|
|
|
35
38
|
def retrieve_context(self, run_id: int) -> Context:
|
|
36
39
|
"""Get run context given a run_id."""
|
flwr/client/node_state_tests.py
CHANGED
flwr/client/supernode/app.py
CHANGED
|
@@ -67,6 +67,7 @@ def run_supernode() -> None:
|
|
|
67
67
|
authentication_keys=authentication_keys,
|
|
68
68
|
max_retries=args.max_retries,
|
|
69
69
|
max_wait_time=args.max_wait_time,
|
|
70
|
+
partition_id=args.partition_id,
|
|
70
71
|
)
|
|
71
72
|
|
|
72
73
|
# Graceful shutdown
|
|
@@ -267,7 +268,7 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
|
267
268
|
"--flwr-dir",
|
|
268
269
|
default=None,
|
|
269
270
|
help="""The path containing installed Flower Apps.
|
|
270
|
-
By default, this value
|
|
271
|
+
By default, this value is equal to:
|
|
271
272
|
|
|
272
273
|
- `$FLWR_HOME/` if `$FLWR_HOME` is defined
|
|
273
274
|
- `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
|
|
@@ -344,8 +345,8 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
344
345
|
"--max-retries",
|
|
345
346
|
type=int,
|
|
346
347
|
default=None,
|
|
347
|
-
help="The maximum number of times the client will try to
|
|
348
|
-
"
|
|
348
|
+
help="The maximum number of times the client will try to reconnect to the"
|
|
349
|
+
"SuperLink before giving up in case of a connection error. By default,"
|
|
349
350
|
"it is set to None, meaning there is no limit to the number of tries.",
|
|
350
351
|
)
|
|
351
352
|
parser.add_argument(
|
|
@@ -353,7 +354,7 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
353
354
|
type=float,
|
|
354
355
|
default=None,
|
|
355
356
|
help="The maximum duration before the client stops trying to"
|
|
356
|
-
"connect to the
|
|
357
|
+
"connect to the SuperLink in case of connection error. By default, it"
|
|
357
358
|
"is set to None, meaning there is no limit to the total time.",
|
|
358
359
|
)
|
|
359
360
|
parser.add_argument(
|
|
@@ -373,6 +374,13 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
373
374
|
type=str,
|
|
374
375
|
help="The SuperNode's public key (as a path str) to enable authentication.",
|
|
375
376
|
)
|
|
377
|
+
parser.add_argument(
|
|
378
|
+
"--partition-id",
|
|
379
|
+
type=int,
|
|
380
|
+
help="The data partition index associated with this SuperNode. Better suited "
|
|
381
|
+
"for prototyping purposes where a SuperNode might only load a fraction of an "
|
|
382
|
+
"artificially partitioned dataset (e.g. using `flwr-datasets`)",
|
|
383
|
+
)
|
|
376
384
|
|
|
377
385
|
|
|
378
386
|
def _try_setup_client_authentication(
|
flwr/client/typing.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Custom types for Flower clients."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Callable
|
|
18
|
+
from typing import Callable, Optional
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context, Message
|
|
21
21
|
|
|
@@ -23,6 +23,7 @@ from .client import Client as Client
|
|
|
23
23
|
|
|
24
24
|
# Compatibility
|
|
25
25
|
ClientFn = Callable[[str], Client]
|
|
26
|
+
ClientFnExt = Callable[[int, Optional[int]], Client]
|
|
26
27
|
|
|
27
28
|
ClientAppCallable = Callable[[Message, Context], Message]
|
|
28
29
|
Mod = Callable[[Message, Context, ClientAppCallable], Message]
|
flwr/common/address.py
CHANGED
flwr/common/config.py
CHANGED
|
@@ -24,14 +24,16 @@ from flwr.cli.config_utils import validate_fields
|
|
|
24
24
|
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
def get_flwr_dir() -> Path:
|
|
27
|
+
def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
    """Return the Flower home directory.

    Parameters
    ----------
    provided_path : Optional[str] (default: None)
        An explicit directory to use. It is honoured only if it points to an
        existing directory; otherwise it is ignored and the location is derived
        from environment variables.

    Returns
    -------
    Path
        ``provided_path`` made absolute when it is an existing directory.
        Otherwise the value of the environment variable named by the
        ``FLWR_HOME`` constant, if set; else ``$XDG_DATA_HOME/.flwr``; else
        ``$HOME/.flwr``; and, when neither of those variables is set,
        ``<user home directory>/.flwr``.
    """
    if provided_path is not None and Path(provided_path).is_dir():
        return Path(provided_path).absolute()

    flwr_home = os.getenv(FLWR_HOME)
    if flwr_home is not None:
        return Path(flwr_home)

    base_dir = os.getenv("XDG_DATA_HOME", os.getenv("HOME"))
    if base_dir is None:
        # Without this fallback an unset HOME would yield the bogus relative
        # path "None/.flwr".
        base_dir = str(Path.home())
    return Path(f"{base_dir}/.flwr")
|
|
35
37
|
|
|
36
38
|
|
|
37
39
|
def get_project_dir(
|
flwr/common/constant.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -46,6 +46,9 @@ PING_BASE_MULTIPLIER = 0.8
|
|
|
46
46
|
PING_RANDOM_RANGE = (-0.1, 0.1)
|
|
47
47
|
PING_MAX_INTERVAL = 1e300
|
|
48
48
|
|
|
49
|
+
# IDs
|
|
50
|
+
RUN_ID_NUM_BYTES = 8
|
|
51
|
+
NODE_ID_NUM_BYTES = 8
|
|
49
52
|
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
|
|
50
53
|
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
|
|
51
54
|
|
flwr/common/context.py
CHANGED
|
@@ -16,13 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
|
+
from typing import Optional
|
|
19
20
|
|
|
20
21
|
from .record import RecordSet
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
@dataclass
|
|
24
25
|
class Context:
|
|
25
|
-
"""
|
|
26
|
+
"""Context of your run.
|
|
26
27
|
|
|
27
28
|
Parameters
|
|
28
29
|
----------
|
|
@@ -33,6 +34,15 @@ class Context:
|
|
|
33
34
|
executing mods. It can also be used as a memory to access
|
|
34
35
|
at different points during the lifecycle of this entity (e.g. across
|
|
35
36
|
multiple rounds)
|
|
37
|
+
partition_id : Optional[int] (default: None)
|
|
38
|
+
An index that specifies the data partition that the ClientApp using this Context
|
|
39
|
+
object should make use of. Setting this attribute is better suited for
|
|
40
|
+
simulation or prototyping setups.
|
|
36
41
|
"""
|
|
37
42
|
|
|
38
43
|
state: RecordSet
|
|
44
|
+
partition_id: Optional[int]
|
|
45
|
+
|
|
46
|
+
def __init__(self, state: RecordSet, partition_id: Optional[int] = None) -> None:
|
|
47
|
+
self.state = state
|
|
48
|
+
self.partition_id = partition_id
|