flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +47 -27
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +32 -21
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +133 -54
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +39 -39
- flwr/client/typing.py +2 -2
- flwr/common/config.py +92 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/object_ref.py +84 -21
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +20 -11
- flwr/proto/exec_pb2.pyi +41 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -18
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +269 -70
- flwr/superexec/app.py +17 -11
- flwr/superexec/deployment.py +111 -35
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +6 -1
- flwr/superexec/executor.py +21 -0
- flwr/superexec/simulation.py +181 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
- flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
|
@@ -9,13 +9,13 @@ from hydra import compose, initialize
|
|
|
9
9
|
from hydra.utils import instantiate
|
|
10
10
|
|
|
11
11
|
from flwr.client import ClientApp
|
|
12
|
-
from flwr.common import ndarrays_to_parameters
|
|
13
|
-
from flwr.server import ServerApp, ServerConfig
|
|
12
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
13
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
14
14
|
|
|
15
|
-
from $import_name.
|
|
15
|
+
from $import_name.client_app import gen_client_fn, get_parameters
|
|
16
16
|
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
|
|
17
17
|
from $import_name.models import get_model
|
|
18
|
-
from $import_name.
|
|
18
|
+
from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config
|
|
19
19
|
|
|
20
20
|
# Avoid warnings
|
|
21
21
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
@@ -67,20 +67,23 @@ init_model = get_model(cfg.model)
|
|
|
67
67
|
init_model_parameters = get_parameters(init_model)
|
|
68
68
|
init_model_parameters = ndarrays_to_parameters(init_model_parameters)
|
|
69
69
|
|
|
70
|
-
|
|
71
|
-
#
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
)
|
|
70
|
+
def server_fn(context: Context):
|
|
71
|
+
# Instantiate strategy according to config. Here we pass other arguments
|
|
72
|
+
# that are only defined at runtime.
|
|
73
|
+
strategy = instantiate(
|
|
74
|
+
cfg.strategy,
|
|
75
|
+
on_fit_config_fn=get_on_fit_config(),
|
|
76
|
+
fit_metrics_aggregation_fn=fit_weighted_average,
|
|
77
|
+
initial_parameters=init_model_parameters,
|
|
78
|
+
evaluate_fn=get_evaluate_fn(
|
|
79
|
+
cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
|
|
80
|
+
),
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
config = ServerConfig(num_rounds=cfg_static.num_rounds)
|
|
84
|
+
|
|
85
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
86
|
+
|
|
81
87
|
|
|
82
88
|
# ServerApp for Flower Next
|
|
83
|
-
server = ServerApp(
|
|
84
|
-
config=ServerConfig(num_rounds=cfg_static.num_rounds),
|
|
85
|
-
strategy=strategy,
|
|
86
|
-
)
|
|
89
|
+
server = ServerApp(server_fn=server_fn)
|
|
@@ -10,6 +10,7 @@ from transformers import TrainingArguments
|
|
|
10
10
|
from trl import SFTTrainer
|
|
11
11
|
|
|
12
12
|
from flwr.client import NumPyClient
|
|
13
|
+
from flwr.common import Context
|
|
13
14
|
from flwr.common.typing import NDArrays, Scalar
|
|
14
15
|
from $import_name.dataset import reformat
|
|
15
16
|
from $import_name.models import cosine_annealing, get_model
|
|
@@ -102,13 +103,14 @@ def gen_client_fn(
|
|
|
102
103
|
model_cfg: DictConfig,
|
|
103
104
|
train_cfg: DictConfig,
|
|
104
105
|
save_path: str,
|
|
105
|
-
) -> Callable[[
|
|
106
|
+
) -> Callable[[Context], FlowerClient]: # pylint: disable=too-many-arguments
|
|
106
107
|
"""Generate the client function that creates the Flower Clients."""
|
|
107
108
|
|
|
108
|
-
def client_fn(
|
|
109
|
+
def client_fn(context: Context) -> FlowerClient:
|
|
109
110
|
"""Create a Flower client representing a single organization."""
|
|
110
111
|
# Let's get the partition corresponding to the i-th client
|
|
111
|
-
|
|
112
|
+
partition_id = context.node_config["partition-id"]
|
|
113
|
+
client_trainset = fds.load_partition(partition_id, "train")
|
|
112
114
|
client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
|
|
113
115
|
|
|
114
116
|
return FlowerClient(
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""$project_name: A Flower / HuggingFace Transformers app."""
|
|
2
|
+
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server.strategy import FedAvg
|
|
5
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def server_fn(context: Context):
|
|
9
|
+
# Read from config
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
|
+
|
|
12
|
+
# Define strategy
|
|
13
|
+
strategy = FedAvg(
|
|
14
|
+
fraction_fit=1.0,
|
|
15
|
+
fraction_evaluate=1.0,
|
|
16
|
+
)
|
|
17
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
18
|
+
|
|
19
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Create ServerApp
|
|
23
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,12 +1,20 @@
|
|
|
1
1
|
"""$project_name: A Flower / JAX app."""
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server.strategy import FedAvg
|
|
5
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
4
6
|
|
|
5
|
-
# Configure the strategy
|
|
6
|
-
strategy = fl.server.strategy.FedAvg()
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
8
|
+
def server_fn(context: Context):
|
|
9
|
+
# Read from config
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
|
+
|
|
12
|
+
# Define strategy
|
|
13
|
+
strategy = FedAvg()
|
|
14
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
15
|
+
|
|
16
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Create ServerApp
|
|
20
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,15 +1,20 @@
|
|
|
1
1
|
"""$project_name: A Flower / MLX app."""
|
|
2
2
|
|
|
3
|
-
from flwr.
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
4
5
|
from flwr.server.strategy import FedAvg
|
|
5
6
|
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
8
|
+
def server_fn(context: Context):
|
|
9
|
+
# Read from config
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
|
+
|
|
12
|
+
# Define strategy
|
|
13
|
+
strategy = FedAvg()
|
|
14
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
15
|
+
|
|
16
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
9
17
|
|
|
10
18
|
|
|
11
19
|
# Create ServerApp
|
|
12
|
-
app = ServerApp(
|
|
13
|
-
config=ServerConfig(num_rounds=3),
|
|
14
|
-
strategy=strategy,
|
|
15
|
-
)
|
|
20
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,12 +1,20 @@
|
|
|
1
1
|
"""$project_name: A Flower / NumPy app."""
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
from flwr.server.strategy import FedAvg
|
|
4
6
|
|
|
5
|
-
# Configure the strategy
|
|
6
|
-
strategy = fl.server.strategy.FedAvg()
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
8
|
+
def server_fn(context: Context):
|
|
9
|
+
# Read from config
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
|
+
|
|
12
|
+
# Define strategy
|
|
13
|
+
strategy = FedAvg()
|
|
14
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
15
|
+
|
|
16
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Create ServerApp
|
|
20
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / PyTorch app."""
|
|
2
2
|
|
|
3
|
-
from flwr.common import ndarrays_to_parameters
|
|
4
|
-
from flwr.server import ServerApp, ServerConfig
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
5
|
from flwr.server.strategy import FedAvg
|
|
6
6
|
|
|
7
7
|
from $import_name.task import Net, get_weights
|
|
@@ -11,18 +11,20 @@ from $import_name.task import Net, get_weights
|
|
|
11
11
|
ndarrays = get_weights(Net())
|
|
12
12
|
parameters = ndarrays_to_parameters(ndarrays)
|
|
13
13
|
|
|
14
|
+
def server_fn(context: Context):
|
|
15
|
+
# Read from config
|
|
16
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
14
17
|
|
|
15
|
-
# Define strategy
|
|
16
|
-
strategy = FedAvg(
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
)
|
|
18
|
+
# Define strategy
|
|
19
|
+
strategy = FedAvg(
|
|
20
|
+
fraction_fit=1.0,
|
|
21
|
+
fraction_evaluate=1.0,
|
|
22
|
+
min_available_clients=2,
|
|
23
|
+
initial_parameters=parameters,
|
|
24
|
+
)
|
|
25
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
22
26
|
|
|
27
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
23
28
|
|
|
24
29
|
# Create ServerApp
|
|
25
|
-
app = ServerApp(
|
|
26
|
-
config=ServerConfig(num_rounds=3),
|
|
27
|
-
strategy=strategy,
|
|
28
|
-
)
|
|
30
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,17 +1,24 @@
|
|
|
1
1
|
"""$project_name: A Flower / Scikit-Learn app."""
|
|
2
2
|
|
|
3
|
-
from flwr.
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
4
5
|
from flwr.server.strategy import FedAvg
|
|
5
6
|
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
8
|
+
def server_fn(context: Context):
|
|
9
|
+
# Read from config
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
|
+
|
|
12
|
+
# Define strategy
|
|
13
|
+
strategy = FedAvg(
|
|
14
|
+
fraction_fit=1.0,
|
|
15
|
+
fraction_evaluate=1.0,
|
|
16
|
+
min_available_clients=2,
|
|
17
|
+
)
|
|
18
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
19
|
+
|
|
20
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
21
|
+
|
|
12
22
|
|
|
13
23
|
# Create ServerApp
|
|
14
|
-
app = ServerApp(
|
|
15
|
-
config=ServerConfig(num_rounds=3),
|
|
16
|
-
strategy=strategy,
|
|
17
|
-
)
|
|
24
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / TensorFlow app."""
|
|
2
2
|
|
|
3
|
-
from flwr.common import ndarrays_to_parameters
|
|
4
|
-
from flwr.server import ServerApp, ServerConfig
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
5
|
from flwr.server.strategy import FedAvg
|
|
6
6
|
|
|
7
7
|
from $import_name.task import load_model
|
|
@@ -11,17 +11,20 @@ config = ServerConfig(num_rounds=3)
|
|
|
11
11
|
|
|
12
12
|
parameters = ndarrays_to_parameters(load_model().get_weights())
|
|
13
13
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
fraction_evaluate=1.0,
|
|
18
|
-
min_available_clients=2,
|
|
19
|
-
initial_parameters=parameters,
|
|
20
|
-
)
|
|
14
|
+
def server_fn(context: Context):
|
|
15
|
+
# Read from config
|
|
16
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
21
17
|
|
|
18
|
+
# Define strategy
|
|
19
|
+
strategy = strategy = FedAvg(
|
|
20
|
+
fraction_fit=1.0,
|
|
21
|
+
fraction_evaluate=1.0,
|
|
22
|
+
min_available_clients=2,
|
|
23
|
+
initial_parameters=parameters,
|
|
24
|
+
)
|
|
25
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
26
|
+
|
|
27
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
22
28
|
|
|
23
29
|
# Create ServerApp
|
|
24
|
-
app = ServerApp(
|
|
25
|
-
config=config,
|
|
26
|
-
strategy=strategy,
|
|
27
|
-
)
|
|
30
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -10,15 +10,27 @@ from torch.utils.data import DataLoader
|
|
|
10
10
|
from transformers import AutoTokenizer, DataCollatorWithPadding
|
|
11
11
|
|
|
12
12
|
from flwr_datasets import FederatedDataset
|
|
13
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
14
|
+
|
|
13
15
|
|
|
14
16
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
15
17
|
DEVICE = torch.device("cpu")
|
|
16
18
|
CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
|
|
21
|
+
fds = None # Cache FederatedDataset
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load_data(partition_id: int, num_partitions: int):
|
|
20
25
|
"""Load IMDB data (training and eval)"""
|
|
21
|
-
|
|
26
|
+
# Only initialize `FederatedDataset` once
|
|
27
|
+
global fds
|
|
28
|
+
if fds is None:
|
|
29
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
30
|
+
fds = FederatedDataset(
|
|
31
|
+
dataset="stanfordnlp/imdb",
|
|
32
|
+
partitioners={"train": partitioner},
|
|
33
|
+
)
|
|
22
34
|
partition = fds.load_partition(partition_id)
|
|
23
35
|
# Divide data: 80% train, 20% test
|
|
24
36
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
@@ -5,10 +5,12 @@ import mlx.nn as nn
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from datasets.utils.logging import disable_progress_bar
|
|
7
7
|
from flwr_datasets import FederatedDataset
|
|
8
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
disable_progress_bar()
|
|
11
12
|
|
|
13
|
+
|
|
12
14
|
class MLP(nn.Module):
|
|
13
15
|
"""A simple MLP."""
|
|
14
16
|
|
|
@@ -43,8 +45,18 @@ def batch_iterate(batch_size, X, y):
|
|
|
43
45
|
yield X[ids], y[ids]
|
|
44
46
|
|
|
45
47
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
+
fds = None # Cache FederatedDataset
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def load_data(partition_id: int, num_partitions: int):
|
|
52
|
+
# Only initialize `FederatedDataset` once
|
|
53
|
+
global fds
|
|
54
|
+
if fds is None:
|
|
55
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
56
|
+
fds = FederatedDataset(
|
|
57
|
+
dataset="ylecun/mnist",
|
|
58
|
+
partitioners={"train": partitioner},
|
|
59
|
+
)
|
|
48
60
|
partition = fds.load_partition(partition_id)
|
|
49
61
|
partition_splits = partition.train_test_split(test_size=0.2, seed=42)
|
|
50
62
|
|
|
@@ -6,9 +6,10 @@ import torch
|
|
|
6
6
|
import torch.nn as nn
|
|
7
7
|
import torch.nn.functional as F
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
|
-
from torchvision.datasets import CIFAR10
|
|
10
9
|
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
11
10
|
from flwr_datasets import FederatedDataset
|
|
11
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
12
|
+
|
|
12
13
|
|
|
13
14
|
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
14
15
|
|
|
@@ -34,9 +35,19 @@ class Net(nn.Module):
|
|
|
34
35
|
return self.fc3(x)
|
|
35
36
|
|
|
36
37
|
|
|
37
|
-
|
|
38
|
+
fds = None # Cache FederatedDataset
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def load_data(partition_id: int, num_partitions: int):
|
|
38
42
|
"""Load partition CIFAR10 data."""
|
|
39
|
-
|
|
43
|
+
# Only initialize `FederatedDataset` once
|
|
44
|
+
global fds
|
|
45
|
+
if fds is None:
|
|
46
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
47
|
+
fds = FederatedDataset(
|
|
48
|
+
dataset="uoft-cs/cifar10",
|
|
49
|
+
partitioners={"train": partitioner},
|
|
50
|
+
)
|
|
40
51
|
partition = fds.load_partition(partition_id)
|
|
41
52
|
# Divide data on each node: 80% train, 20% test
|
|
42
53
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
@@ -4,11 +4,13 @@ import os
|
|
|
4
4
|
|
|
5
5
|
import tensorflow as tf
|
|
6
6
|
from flwr_datasets import FederatedDataset
|
|
7
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
# Make TensorFlow log less verbose
|
|
10
11
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
11
12
|
|
|
13
|
+
|
|
12
14
|
def load_model():
|
|
13
15
|
# Load model and data (MobileNetV2, CIFAR-10)
|
|
14
16
|
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
|
|
@@ -16,9 +18,19 @@ def load_model():
|
|
|
16
18
|
return model
|
|
17
19
|
|
|
18
20
|
|
|
21
|
+
fds = None # Cache FederatedDataset
|
|
22
|
+
|
|
23
|
+
|
|
19
24
|
def load_data(partition_id, num_partitions):
|
|
20
25
|
# Download and partition dataset
|
|
21
|
-
|
|
26
|
+
# Only initialize `FederatedDataset` once
|
|
27
|
+
global fds
|
|
28
|
+
if fds is None:
|
|
29
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
30
|
+
fds = FederatedDataset(
|
|
31
|
+
dataset="uoft-cs/cifar10",
|
|
32
|
+
partitioners={"train": partitioner},
|
|
33
|
+
)
|
|
22
34
|
partition = fds.load_partition(partition_id, "train")
|
|
23
35
|
partition.set_format("numpy")
|
|
24
36
|
|
|
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
|
|
|
6
6
|
name = "$package_name"
|
|
7
7
|
version = "1.0.0"
|
|
8
8
|
description = ""
|
|
9
|
-
|
|
10
|
-
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
-
]
|
|
12
|
-
license = { text = "Apache License (2.0)" }
|
|
9
|
+
license = "Apache-2.0"
|
|
13
10
|
dependencies = [
|
|
14
11
|
"flwr[simulation]>=1.9.0,<2.0",
|
|
15
12
|
"flwr-datasets>=0.1.0,<1.0.0",
|
|
@@ -25,18 +22,18 @@ dependencies = [
|
|
|
25
22
|
[tool.hatch.build.targets.wheel]
|
|
26
23
|
packages = ["."]
|
|
27
24
|
|
|
28
|
-
[
|
|
25
|
+
[tool.flwr.app]
|
|
29
26
|
publisher = "$username"
|
|
30
27
|
|
|
31
|
-
[
|
|
28
|
+
[tool.flwr.app.components]
|
|
32
29
|
serverapp = "$import_name.app:server"
|
|
33
30
|
clientapp = "$import_name.app:client"
|
|
34
31
|
|
|
35
|
-
[
|
|
36
|
-
|
|
32
|
+
[tool.flwr.app.config]
|
|
33
|
+
num-server-rounds = 3
|
|
37
34
|
|
|
38
|
-
[
|
|
39
|
-
|
|
35
|
+
[tool.flwr.federations]
|
|
36
|
+
default = "local-simulation"
|
|
40
37
|
|
|
41
|
-
[
|
|
42
|
-
|
|
38
|
+
[tool.flwr.federations.local-simulation]
|
|
39
|
+
options.num-supernodes = 10
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$package_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
license = "Apache-2.0"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"flwr[simulation]>=1.9.0,<2.0",
|
|
12
|
+
"flwr-datasets>=0.0.2,<1.0.0",
|
|
13
|
+
"torch==2.2.1",
|
|
14
|
+
"transformers>=4.30.0,<5.0",
|
|
15
|
+
"evaluate>=0.4.0,<1.0",
|
|
16
|
+
"datasets>=2.0.0, <3.0",
|
|
17
|
+
"scikit-learn>=1.3.1, <2.0",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[tool.hatch.build.targets.wheel]
|
|
21
|
+
packages = ["."]
|
|
22
|
+
|
|
23
|
+
[tool.flwr.app]
|
|
24
|
+
publisher = "$username"
|
|
25
|
+
|
|
26
|
+
[tool.flwr.app.components]
|
|
27
|
+
serverapp = "$import_name.server_app:app"
|
|
28
|
+
clientapp = "$import_name.client_app:app"
|
|
29
|
+
|
|
30
|
+
[tool.flwr.app.config]
|
|
31
|
+
num-server-rounds = 3
|
|
32
|
+
local-epochs = 1
|
|
33
|
+
|
|
34
|
+
[tool.flwr.federations]
|
|
35
|
+
default = "localhost"
|
|
36
|
+
|
|
37
|
+
[tool.flwr.federations.localhost]
|
|
38
|
+
options.num-supernodes = 10
|
|
@@ -6,23 +6,29 @@ build-backend = "hatchling.build"
|
|
|
6
6
|
name = "$package_name"
|
|
7
7
|
version = "1.0.0"
|
|
8
8
|
description = ""
|
|
9
|
-
|
|
10
|
-
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
-
]
|
|
12
|
-
license = {text = "Apache License (2.0)"}
|
|
9
|
+
license = "Apache-2.0"
|
|
13
10
|
dependencies = [
|
|
14
11
|
"flwr[simulation]>=1.9.0,<2.0",
|
|
15
|
-
"jax==0.4.
|
|
16
|
-
"jaxlib==0.4.
|
|
17
|
-
"scikit-learn==1.
|
|
12
|
+
"jax==0.4.13",
|
|
13
|
+
"jaxlib==0.4.13",
|
|
14
|
+
"scikit-learn==1.3.2",
|
|
18
15
|
]
|
|
19
16
|
|
|
20
17
|
[tool.hatch.build.targets.wheel]
|
|
21
18
|
packages = ["."]
|
|
22
19
|
|
|
23
|
-
[
|
|
20
|
+
[tool.flwr.app]
|
|
24
21
|
publisher = "$username"
|
|
25
22
|
|
|
26
|
-
[
|
|
27
|
-
serverapp = "$import_name.
|
|
28
|
-
clientapp = "$import_name.
|
|
23
|
+
[tool.flwr.app.components]
|
|
24
|
+
serverapp = "$import_name.server_app:app"
|
|
25
|
+
clientapp = "$import_name.client_app:app"
|
|
26
|
+
|
|
27
|
+
[tool.flwr.app.config]
|
|
28
|
+
num-server-rounds = 3
|
|
29
|
+
|
|
30
|
+
[tool.flwr.federations]
|
|
31
|
+
default = "local-simulation"
|
|
32
|
+
|
|
33
|
+
[tool.flwr.federations.local-simulation]
|
|
34
|
+
options.num-supernodes = 10
|
|
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
|
|
|
6
6
|
name = "$package_name"
|
|
7
7
|
version = "1.0.0"
|
|
8
8
|
description = ""
|
|
9
|
-
|
|
10
|
-
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
-
]
|
|
12
|
-
license = { text = "Apache License (2.0)" }
|
|
9
|
+
license = "Apache-2.0"
|
|
13
10
|
dependencies = [
|
|
14
11
|
"flwr[simulation]>=1.9.0,<2.0",
|
|
15
12
|
"flwr-datasets[vision]>=0.0.2,<1.0.0",
|
|
@@ -20,15 +17,23 @@ dependencies = [
|
|
|
20
17
|
[tool.hatch.build.targets.wheel]
|
|
21
18
|
packages = ["."]
|
|
22
19
|
|
|
23
|
-
[
|
|
20
|
+
[tool.flwr.app]
|
|
24
21
|
publisher = "$username"
|
|
25
22
|
|
|
26
|
-
[
|
|
27
|
-
serverapp = "$import_name.
|
|
28
|
-
clientapp = "$import_name.
|
|
23
|
+
[tool.flwr.app.components]
|
|
24
|
+
serverapp = "$import_name.server_app:app"
|
|
25
|
+
clientapp = "$import_name.client_app:app"
|
|
26
|
+
|
|
27
|
+
[tool.flwr.app.config]
|
|
28
|
+
num-server-rounds = 3
|
|
29
|
+
local-epochs = 1
|
|
30
|
+
num-layers = 2
|
|
31
|
+
hidden-dim = 32
|
|
32
|
+
batch-size = 256
|
|
33
|
+
lr = 0.1
|
|
29
34
|
|
|
30
|
-
[
|
|
31
|
-
|
|
35
|
+
[tool.flwr.federations]
|
|
36
|
+
default = "local-simulation"
|
|
32
37
|
|
|
33
|
-
[
|
|
34
|
-
num =
|
|
38
|
+
[tool.flwr.federations.local-simulation]
|
|
39
|
+
options.num-supernodes = 10
|
|
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
|
|
|
6
6
|
name = "$package_name"
|
|
7
7
|
version = "1.0.0"
|
|
8
8
|
description = ""
|
|
9
|
-
|
|
10
|
-
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
-
]
|
|
12
|
-
license = { text = "Apache License (2.0)" }
|
|
9
|
+
license = "Apache-2.0"
|
|
13
10
|
dependencies = [
|
|
14
11
|
"flwr[simulation]>=1.9.0,<2.0",
|
|
15
12
|
"numpy>=1.21.0",
|
|
@@ -18,15 +15,18 @@ dependencies = [
|
|
|
18
15
|
[tool.hatch.build.targets.wheel]
|
|
19
16
|
packages = ["."]
|
|
20
17
|
|
|
21
|
-
[
|
|
18
|
+
[tool.flwr.app]
|
|
22
19
|
publisher = "$username"
|
|
23
20
|
|
|
24
|
-
[
|
|
25
|
-
serverapp = "$import_name.
|
|
26
|
-
clientapp = "$import_name.
|
|
21
|
+
[tool.flwr.app.components]
|
|
22
|
+
serverapp = "$import_name.server_app:app"
|
|
23
|
+
clientapp = "$import_name.client_app:app"
|
|
24
|
+
|
|
25
|
+
[tool.flwr.app.config]
|
|
26
|
+
num-server-rounds = 3
|
|
27
27
|
|
|
28
|
-
[
|
|
29
|
-
|
|
28
|
+
[tool.flwr.federations]
|
|
29
|
+
default = "local-simulation"
|
|
30
30
|
|
|
31
|
-
[
|
|
32
|
-
num =
|
|
31
|
+
[tool.flwr.federations.local-simulation]
|
|
32
|
+
options.num-supernodes = 10
|