flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. (The original page linked to more details here.)

Files changed (99)
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +47 -27
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +32 -21
  5. flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
  13. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
  14. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  16. flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
  17. flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
  18. flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
  19. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  20. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
  21. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  22. flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
  23. flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
  24. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
  25. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
  26. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  27. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  28. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  29. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  30. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  34. flwr/cli/run/run.py +133 -54
  35. flwr/client/app.py +56 -24
  36. flwr/client/client_app.py +28 -8
  37. flwr/client/grpc_adapter_client/connection.py +3 -2
  38. flwr/client/grpc_client/connection.py +3 -2
  39. flwr/client/grpc_rere_client/connection.py +17 -6
  40. flwr/client/message_handler/message_handler.py +1 -1
  41. flwr/client/node_state.py +59 -12
  42. flwr/client/node_state_tests.py +4 -3
  43. flwr/client/rest_client/connection.py +19 -8
  44. flwr/client/supernode/app.py +39 -39
  45. flwr/client/typing.py +2 -2
  46. flwr/common/config.py +92 -2
  47. flwr/common/constant.py +3 -0
  48. flwr/common/context.py +24 -9
  49. flwr/common/logger.py +25 -0
  50. flwr/common/object_ref.py +84 -21
  51. flwr/common/serde.py +45 -0
  52. flwr/common/telemetry.py +17 -0
  53. flwr/common/typing.py +5 -0
  54. flwr/proto/common_pb2.py +36 -0
  55. flwr/proto/common_pb2.pyi +121 -0
  56. flwr/proto/common_pb2_grpc.py +4 -0
  57. flwr/proto/common_pb2_grpc.pyi +4 -0
  58. flwr/proto/driver_pb2.py +24 -19
  59. flwr/proto/driver_pb2.pyi +21 -1
  60. flwr/proto/exec_pb2.py +20 -11
  61. flwr/proto/exec_pb2.pyi +41 -1
  62. flwr/proto/run_pb2.py +12 -7
  63. flwr/proto/run_pb2.pyi +22 -1
  64. flwr/proto/task_pb2.py +7 -8
  65. flwr/server/__init__.py +2 -0
  66. flwr/server/compat/legacy_context.py +5 -4
  67. flwr/server/driver/grpc_driver.py +82 -140
  68. flwr/server/run_serverapp.py +40 -18
  69. flwr/server/server_app.py +56 -10
  70. flwr/server/serverapp_components.py +52 -0
  71. flwr/server/superlink/driver/driver_servicer.py +18 -3
  72. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  73. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  74. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  75. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  76. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  77. flwr/server/superlink/state/in_memory_state.py +11 -3
  78. flwr/server/superlink/state/sqlite_state.py +23 -8
  79. flwr/server/superlink/state/state.py +7 -2
  80. flwr/server/typing.py +2 -0
  81. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  82. flwr/simulation/__init__.py +1 -1
  83. flwr/simulation/app.py +4 -3
  84. flwr/simulation/ray_transport/ray_actor.py +15 -19
  85. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  86. flwr/simulation/run_simulation.py +269 -70
  87. flwr/superexec/app.py +17 -11
  88. flwr/superexec/deployment.py +111 -35
  89. flwr/superexec/exec_grpc.py +5 -1
  90. flwr/superexec/exec_servicer.py +6 -1
  91. flwr/superexec/executor.py +21 -0
  92. flwr/superexec/simulation.py +181 -0
  93. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
  94. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
  95. flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
  96. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
  97. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
  98. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
  99. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
@@ -9,13 +9,13 @@ from hydra import compose, initialize
9
9
  from hydra.utils import instantiate
10
10
 
11
11
  from flwr.client import ClientApp
12
- from flwr.common import ndarrays_to_parameters
13
- from flwr.server import ServerApp, ServerConfig
12
+ from flwr.common import Context, ndarrays_to_parameters
13
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
14
14
 
15
- from $import_name.client import gen_client_fn, get_parameters
15
+ from $import_name.client_app import gen_client_fn, get_parameters
16
16
  from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
17
17
  from $import_name.models import get_model
18
- from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config
18
+ from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config
19
19
 
20
20
  # Avoid warnings
21
21
  warnings.filterwarnings("ignore", category=UserWarning)
@@ -67,20 +67,23 @@ init_model = get_model(cfg.model)
67
67
  init_model_parameters = get_parameters(init_model)
68
68
  init_model_parameters = ndarrays_to_parameters(init_model_parameters)
69
69
 
70
- # Instantiate strategy according to config. Here we pass other arguments
71
- # that are only defined at runtime.
72
- strategy = instantiate(
73
- cfg.strategy,
74
- on_fit_config_fn=get_on_fit_config(),
75
- fit_metrics_aggregation_fn=fit_weighted_average,
76
- initial_parameters=init_model_parameters,
77
- evaluate_fn=get_evaluate_fn(
78
- cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
79
- ),
80
- )
70
+ def server_fn(context: Context):
71
+ # Instantiate strategy according to config. Here we pass other arguments
72
+ # that are only defined at runtime.
73
+ strategy = instantiate(
74
+ cfg.strategy,
75
+ on_fit_config_fn=get_on_fit_config(),
76
+ fit_metrics_aggregation_fn=fit_weighted_average,
77
+ initial_parameters=init_model_parameters,
78
+ evaluate_fn=get_evaluate_fn(
79
+ cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
80
+ ),
81
+ )
82
+
83
+ config = ServerConfig(num_rounds=cfg_static.num_rounds)
84
+
85
+ return ServerAppComponents(strategy=strategy, config=config)
86
+
81
87
 
82
88
  # ServerApp for Flower Next
83
- server = ServerApp(
84
- config=ServerConfig(num_rounds=cfg_static.num_rounds),
85
- strategy=strategy,
86
- )
89
+ server = ServerApp(server_fn=server_fn)
@@ -10,6 +10,7 @@ from transformers import TrainingArguments
10
10
  from trl import SFTTrainer
11
11
 
12
12
  from flwr.client import NumPyClient
13
+ from flwr.common import Context
13
14
  from flwr.common.typing import NDArrays, Scalar
14
15
  from $import_name.dataset import reformat
15
16
  from $import_name.models import cosine_annealing, get_model
@@ -102,13 +103,14 @@ def gen_client_fn(
102
103
  model_cfg: DictConfig,
103
104
  train_cfg: DictConfig,
104
105
  save_path: str,
105
- ) -> Callable[[str], FlowerClient]: # pylint: disable=too-many-arguments
106
+ ) -> Callable[[Context], FlowerClient]: # pylint: disable=too-many-arguments
106
107
  """Generate the client function that creates the Flower Clients."""
107
108
 
108
- def client_fn(cid: str) -> FlowerClient:
109
+ def client_fn(context: Context) -> FlowerClient:
109
110
  """Create a Flower client representing a single organization."""
110
111
  # Let's get the partition corresponding to the i-th client
111
- client_trainset = fds.load_partition(int(cid), "train")
112
+ partition_id = context.node_config["partition-id"]
113
+ client_trainset = fds.load_partition(partition_id, "train")
112
114
  client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
113
115
 
114
116
  return FlowerClient(
@@ -1,6 +1,6 @@
1
1
  """$project_name: A Flower / FlowerTune app."""
2
2
 
3
- from $import_name.client import set_parameters
3
+ from $import_name.client_app import set_parameters
4
4
  from $import_name.models import get_model
5
5
 
6
6
 
@@ -0,0 +1,23 @@
1
+ """$project_name: A Flower / HuggingFace Transformers app."""
2
+
3
+ from flwr.common import Context
4
+ from flwr.server.strategy import FedAvg
5
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
6
+
7
+
8
+ def server_fn(context: Context):
9
+ # Read from config
10
+ num_rounds = context.run_config["num-server-rounds"]
11
+
12
+ # Define strategy
13
+ strategy = FedAvg(
14
+ fraction_fit=1.0,
15
+ fraction_evaluate=1.0,
16
+ )
17
+ config = ServerConfig(num_rounds=num_rounds)
18
+
19
+ return ServerAppComponents(strategy=strategy, config=config)
20
+
21
+
22
+ # Create ServerApp
23
+ app = ServerApp(server_fn=server_fn)
@@ -1,12 +1,20 @@
1
1
  """$project_name: A Flower / JAX app."""
2
2
 
3
- import flwr as fl
3
+ from flwr.common import Context
4
+ from flwr.server.strategy import FedAvg
5
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
4
6
 
5
- # Configure the strategy
6
- strategy = fl.server.strategy.FedAvg()
7
7
 
8
- # Flower ServerApp
9
- app = fl.server.ServerApp(
10
- config=fl.server.ServerConfig(num_rounds=3),
11
- strategy=strategy,
12
- )
8
+ def server_fn(context: Context):
9
+ # Read from config
10
+ num_rounds = context.run_config["num-server-rounds"]
11
+
12
+ # Define strategy
13
+ strategy = FedAvg()
14
+ config = ServerConfig(num_rounds=num_rounds)
15
+
16
+ return ServerAppComponents(strategy=strategy, config=config)
17
+
18
+
19
+ # Create ServerApp
20
+ app = ServerApp(server_fn=server_fn)
@@ -1,15 +1,20 @@
1
1
  """$project_name: A Flower / MLX app."""
2
2
 
3
- from flwr.server import ServerApp, ServerConfig
3
+ from flwr.common import Context
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
4
5
  from flwr.server.strategy import FedAvg
5
6
 
6
7
 
7
- # Define strategy
8
- strategy = FedAvg()
8
+ def server_fn(context: Context):
9
+ # Read from config
10
+ num_rounds = context.run_config["num-server-rounds"]
11
+
12
+ # Define strategy
13
+ strategy = FedAvg()
14
+ config = ServerConfig(num_rounds=num_rounds)
15
+
16
+ return ServerAppComponents(strategy=strategy, config=config)
9
17
 
10
18
 
11
19
  # Create ServerApp
12
- app = ServerApp(
13
- config=ServerConfig(num_rounds=3),
14
- strategy=strategy,
15
- )
20
+ app = ServerApp(server_fn=server_fn)
@@ -1,12 +1,20 @@
1
1
  """$project_name: A Flower / NumPy app."""
2
2
 
3
- import flwr as fl
3
+ from flwr.common import Context
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
4
6
 
5
- # Configure the strategy
6
- strategy = fl.server.strategy.FedAvg()
7
7
 
8
- # Flower ServerApp
9
- app = fl.server.ServerApp(
10
- config=fl.server.ServerConfig(num_rounds=1),
11
- strategy=strategy,
12
- )
8
+ def server_fn(context: Context):
9
+ # Read from config
10
+ num_rounds = context.run_config["num-server-rounds"]
11
+
12
+ # Define strategy
13
+ strategy = FedAvg()
14
+ config = ServerConfig(num_rounds=num_rounds)
15
+
16
+ return ServerAppComponents(strategy=strategy, config=config)
17
+
18
+
19
+ # Create ServerApp
20
+ app = ServerApp(server_fn=server_fn)
@@ -1,7 +1,7 @@
1
1
  """$project_name: A Flower / PyTorch app."""
2
2
 
3
- from flwr.common import ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerConfig
3
+ from flwr.common import Context, ndarrays_to_parameters
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
5
  from flwr.server.strategy import FedAvg
6
6
 
7
7
  from $import_name.task import Net, get_weights
@@ -11,18 +11,20 @@ from $import_name.task import Net, get_weights
11
11
  ndarrays = get_weights(Net())
12
12
  parameters = ndarrays_to_parameters(ndarrays)
13
13
 
14
+ def server_fn(context: Context):
15
+ # Read from config
16
+ num_rounds = context.run_config["num-server-rounds"]
14
17
 
15
- # Define strategy
16
- strategy = FedAvg(
17
- fraction_fit=1.0,
18
- fraction_evaluate=1.0,
19
- min_available_clients=2,
20
- initial_parameters=parameters,
21
- )
18
+ # Define strategy
19
+ strategy = FedAvg(
20
+ fraction_fit=1.0,
21
+ fraction_evaluate=1.0,
22
+ min_available_clients=2,
23
+ initial_parameters=parameters,
24
+ )
25
+ config = ServerConfig(num_rounds=num_rounds)
22
26
 
27
+ return ServerAppComponents(strategy=strategy, config=config)
23
28
 
24
29
  # Create ServerApp
25
- app = ServerApp(
26
- config=ServerConfig(num_rounds=3),
27
- strategy=strategy,
28
- )
30
+ app = ServerApp(server_fn=server_fn)
@@ -1,17 +1,24 @@
1
1
  """$project_name: A Flower / Scikit-Learn app."""
2
2
 
3
- from flwr.server import ServerApp, ServerConfig
3
+ from flwr.common import Context
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
4
5
  from flwr.server.strategy import FedAvg
5
6
 
6
7
 
7
- strategy = FedAvg(
8
- fraction_fit=1.0,
9
- fraction_evaluate=1.0,
10
- min_available_clients=2,
11
- )
8
+ def server_fn(context: Context):
9
+ # Read from config
10
+ num_rounds = context.run_config["num-server-rounds"]
11
+
12
+ # Define strategy
13
+ strategy = FedAvg(
14
+ fraction_fit=1.0,
15
+ fraction_evaluate=1.0,
16
+ min_available_clients=2,
17
+ )
18
+ config = ServerConfig(num_rounds=num_rounds)
19
+
20
+ return ServerAppComponents(strategy=strategy, config=config)
21
+
12
22
 
13
23
  # Create ServerApp
14
- app = ServerApp(
15
- config=ServerConfig(num_rounds=3),
16
- strategy=strategy,
17
- )
24
+ app = ServerApp(server_fn=server_fn)
@@ -1,7 +1,7 @@
1
1
  """$project_name: A Flower / TensorFlow app."""
2
2
 
3
- from flwr.common import ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerConfig
3
+ from flwr.common import Context, ndarrays_to_parameters
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
5
  from flwr.server.strategy import FedAvg
6
6
 
7
7
  from $import_name.task import load_model
@@ -11,17 +11,20 @@ config = ServerConfig(num_rounds=3)
11
11
 
12
12
  parameters = ndarrays_to_parameters(load_model().get_weights())
13
13
 
14
- # Define strategy
15
- strategy = FedAvg(
16
- fraction_fit=1.0,
17
- fraction_evaluate=1.0,
18
- min_available_clients=2,
19
- initial_parameters=parameters,
20
- )
14
+ def server_fn(context: Context):
15
+ # Read from config
16
+ num_rounds = context.run_config["num-server-rounds"]
21
17
 
18
+ # Define strategy
19
+ strategy = strategy = FedAvg(
20
+ fraction_fit=1.0,
21
+ fraction_evaluate=1.0,
22
+ min_available_clients=2,
23
+ initial_parameters=parameters,
24
+ )
25
+ config = ServerConfig(num_rounds=num_rounds)
26
+
27
+ return ServerAppComponents(strategy=strategy, config=config)
22
28
 
23
29
  # Create ServerApp
24
- app = ServerApp(
25
- config=config,
26
- strategy=strategy,
27
- )
30
+ app = ServerApp(server_fn=server_fn)
@@ -10,15 +10,27 @@ from torch.utils.data import DataLoader
10
10
  from transformers import AutoTokenizer, DataCollatorWithPadding
11
11
 
12
12
  from flwr_datasets import FederatedDataset
13
+ from flwr_datasets.partitioner import IidPartitioner
14
+
13
15
 
14
16
  warnings.filterwarnings("ignore", category=UserWarning)
15
17
  DEVICE = torch.device("cpu")
16
18
  CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
17
19
 
18
20
 
19
- def load_data(partition_id, num_clients):
21
+ fds = None # Cache FederatedDataset
22
+
23
+
24
+ def load_data(partition_id: int, num_partitions: int):
20
25
  """Load IMDB data (training and eval)"""
21
- fds = FederatedDataset(dataset="imdb", partitioners={"train": num_clients})
26
+ # Only initialize `FederatedDataset` once
27
+ global fds
28
+ if fds is None:
29
+ partitioner = IidPartitioner(num_partitions=num_partitions)
30
+ fds = FederatedDataset(
31
+ dataset="stanfordnlp/imdb",
32
+ partitioners={"train": partitioner},
33
+ )
22
34
  partition = fds.load_partition(partition_id)
23
35
  # Divide data: 80% train, 20% test
24
36
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
@@ -5,10 +5,12 @@ import mlx.nn as nn
5
5
  import numpy as np
6
6
  from datasets.utils.logging import disable_progress_bar
7
7
  from flwr_datasets import FederatedDataset
8
+ from flwr_datasets.partitioner import IidPartitioner
8
9
 
9
10
 
10
11
  disable_progress_bar()
11
12
 
13
+
12
14
  class MLP(nn.Module):
13
15
  """A simple MLP."""
14
16
 
@@ -43,8 +45,18 @@ def batch_iterate(batch_size, X, y):
43
45
  yield X[ids], y[ids]
44
46
 
45
47
 
46
- def load_data(partition_id, num_clients):
47
- fds = FederatedDataset(dataset="mnist", partitioners={"train": num_clients})
48
+ fds = None # Cache FederatedDataset
49
+
50
+
51
+ def load_data(partition_id: int, num_partitions: int):
52
+ # Only initialize `FederatedDataset` once
53
+ global fds
54
+ if fds is None:
55
+ partitioner = IidPartitioner(num_partitions=num_partitions)
56
+ fds = FederatedDataset(
57
+ dataset="ylecun/mnist",
58
+ partitioners={"train": partitioner},
59
+ )
48
60
  partition = fds.load_partition(partition_id)
49
61
  partition_splits = partition.train_test_split(test_size=0.2, seed=42)
50
62
 
@@ -6,9 +6,10 @@ import torch
6
6
  import torch.nn as nn
7
7
  import torch.nn.functional as F
8
8
  from torch.utils.data import DataLoader
9
- from torchvision.datasets import CIFAR10
10
9
  from torchvision.transforms import Compose, Normalize, ToTensor
11
10
  from flwr_datasets import FederatedDataset
11
+ from flwr_datasets.partitioner import IidPartitioner
12
+
12
13
 
13
14
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
15
 
@@ -34,9 +35,19 @@ class Net(nn.Module):
34
35
  return self.fc3(x)
35
36
 
36
37
 
37
- def load_data(partition_id, num_partitions):
38
+ fds = None # Cache FederatedDataset
39
+
40
+
41
+ def load_data(partition_id: int, num_partitions: int):
38
42
  """Load partition CIFAR10 data."""
39
- fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
43
+ # Only initialize `FederatedDataset` once
44
+ global fds
45
+ if fds is None:
46
+ partitioner = IidPartitioner(num_partitions=num_partitions)
47
+ fds = FederatedDataset(
48
+ dataset="uoft-cs/cifar10",
49
+ partitioners={"train": partitioner},
50
+ )
40
51
  partition = fds.load_partition(partition_id)
41
52
  # Divide data on each node: 80% train, 20% test
42
53
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
@@ -4,11 +4,13 @@ import os
4
4
 
5
5
  import tensorflow as tf
6
6
  from flwr_datasets import FederatedDataset
7
+ from flwr_datasets.partitioner import IidPartitioner
7
8
 
8
9
 
9
10
  # Make TensorFlow log less verbose
10
11
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
11
12
 
13
+
12
14
  def load_model():
13
15
  # Load model and data (MobileNetV2, CIFAR-10)
14
16
  model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
@@ -16,9 +18,19 @@ def load_model():
16
18
  return model
17
19
 
18
20
 
21
+ fds = None # Cache FederatedDataset
22
+
23
+
19
24
  def load_data(partition_id, num_partitions):
20
25
  # Download and partition dataset
21
- fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
26
+ # Only initialize `FederatedDataset` once
27
+ global fds
28
+ if fds is None:
29
+ partitioner = IidPartitioner(num_partitions=num_partitions)
30
+ fds = FederatedDataset(
31
+ dataset="uoft-cs/cifar10",
32
+ partitioners={"train": partitioner},
33
+ )
22
34
  partition = fds.load_partition(partition_id, "train")
23
35
  partition.set_format("numpy")
24
36
 
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
6
6
  name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
- authors = [
10
- { name = "The Flower Authors", email = "hello@flower.ai" },
11
- ]
12
- license = { text = "Apache License (2.0)" }
9
+ license = "Apache-2.0"
13
10
  dependencies = [
14
11
  "flwr[simulation]>=1.9.0,<2.0",
15
12
  "flwr-datasets>=0.1.0,<1.0.0",
@@ -25,18 +22,18 @@ dependencies = [
25
22
  [tool.hatch.build.targets.wheel]
26
23
  packages = ["."]
27
24
 
28
- [flower]
25
+ [tool.flwr.app]
29
26
  publisher = "$username"
30
27
 
31
- [flower.components]
28
+ [tool.flwr.app.components]
32
29
  serverapp = "$import_name.app:server"
33
30
  clientapp = "$import_name.app:client"
34
31
 
35
- [flower.engine]
36
- name = "simulation"
32
+ [tool.flwr.app.config]
33
+ num-server-rounds = 3
37
34
 
38
- [flower.engine.simulation.supernode]
39
- num = $num_clients
35
+ [tool.flwr.federations]
36
+ default = "local-simulation"
40
37
 
41
- [flower.engine.simulation]
42
- backend_config = { client_resources = { num_cpus = 8, num_gpus = 1.0 } }
38
+ [tool.flwr.federations.local-simulation]
39
+ options.num-supernodes = 10
@@ -0,0 +1,38 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "$package_name"
7
+ version = "1.0.0"
8
+ description = ""
9
+ license = "Apache-2.0"
10
+ dependencies = [
11
+ "flwr[simulation]>=1.9.0,<2.0",
12
+ "flwr-datasets>=0.0.2,<1.0.0",
13
+ "torch==2.2.1",
14
+ "transformers>=4.30.0,<5.0",
15
+ "evaluate>=0.4.0,<1.0",
16
+ "datasets>=2.0.0, <3.0",
17
+ "scikit-learn>=1.3.1, <2.0",
18
+ ]
19
+
20
+ [tool.hatch.build.targets.wheel]
21
+ packages = ["."]
22
+
23
+ [tool.flwr.app]
24
+ publisher = "$username"
25
+
26
+ [tool.flwr.app.components]
27
+ serverapp = "$import_name.server_app:app"
28
+ clientapp = "$import_name.client_app:app"
29
+
30
+ [tool.flwr.app.config]
31
+ num-server-rounds = 3
32
+ local-epochs = 1
33
+
34
+ [tool.flwr.federations]
35
+ default = "localhost"
36
+
37
+ [tool.flwr.federations.localhost]
38
+ options.num-supernodes = 10
@@ -6,23 +6,29 @@ build-backend = "hatchling.build"
6
6
  name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
- authors = [
10
- { name = "The Flower Authors", email = "hello@flower.ai" },
11
- ]
12
- license = {text = "Apache License (2.0)"}
9
+ license = "Apache-2.0"
13
10
  dependencies = [
14
11
  "flwr[simulation]>=1.9.0,<2.0",
15
- "jax==0.4.26",
16
- "jaxlib==0.4.26",
17
- "scikit-learn==1.4.2",
12
+ "jax==0.4.13",
13
+ "jaxlib==0.4.13",
14
+ "scikit-learn==1.3.2",
18
15
  ]
19
16
 
20
17
  [tool.hatch.build.targets.wheel]
21
18
  packages = ["."]
22
19
 
23
- [flower]
20
+ [tool.flwr.app]
24
21
  publisher = "$username"
25
22
 
26
- [flower.components]
27
- serverapp = "$import_name.server:app"
28
- clientapp = "$import_name.client:app"
23
+ [tool.flwr.app.components]
24
+ serverapp = "$import_name.server_app:app"
25
+ clientapp = "$import_name.client_app:app"
26
+
27
+ [tool.flwr.app.config]
28
+ num-server-rounds = 3
29
+
30
+ [tool.flwr.federations]
31
+ default = "local-simulation"
32
+
33
+ [tool.flwr.federations.local-simulation]
34
+ options.num-supernodes = 10
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
6
6
  name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
- authors = [
10
- { name = "The Flower Authors", email = "hello@flower.ai" },
11
- ]
12
- license = { text = "Apache License (2.0)" }
9
+ license = "Apache-2.0"
13
10
  dependencies = [
14
11
  "flwr[simulation]>=1.9.0,<2.0",
15
12
  "flwr-datasets[vision]>=0.0.2,<1.0.0",
@@ -20,15 +17,23 @@ dependencies = [
20
17
  [tool.hatch.build.targets.wheel]
21
18
  packages = ["."]
22
19
 
23
- [flower]
20
+ [tool.flwr.app]
24
21
  publisher = "$username"
25
22
 
26
- [flower.components]
27
- serverapp = "$import_name.server:app"
28
- clientapp = "$import_name.client:app"
23
+ [tool.flwr.app.components]
24
+ serverapp = "$import_name.server_app:app"
25
+ clientapp = "$import_name.client_app:app"
26
+
27
+ [tool.flwr.app.config]
28
+ num-server-rounds = 3
29
+ local-epochs = 1
30
+ num-layers = 2
31
+ hidden-dim = 32
32
+ batch-size = 256
33
+ lr = 0.1
29
34
 
30
- [flower.engine]
31
- name = "simulation"
35
+ [tool.flwr.federations]
36
+ default = "local-simulation"
32
37
 
33
- [flower.engine.simulation.supernode]
34
- num = 2
38
+ [tool.flwr.federations.local-simulation]
39
+ options.num-supernodes = 10
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
6
6
  name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
- authors = [
10
- { name = "The Flower Authors", email = "hello@flower.ai" },
11
- ]
12
- license = { text = "Apache License (2.0)" }
9
+ license = "Apache-2.0"
13
10
  dependencies = [
14
11
  "flwr[simulation]>=1.9.0,<2.0",
15
12
  "numpy>=1.21.0",
@@ -18,15 +15,18 @@ dependencies = [
18
15
  [tool.hatch.build.targets.wheel]
19
16
  packages = ["."]
20
17
 
21
- [flower]
18
+ [tool.flwr.app]
22
19
  publisher = "$username"
23
20
 
24
- [flower.components]
25
- serverapp = "$import_name.server:app"
26
- clientapp = "$import_name.client:app"
21
+ [tool.flwr.app.components]
22
+ serverapp = "$import_name.server_app:app"
23
+ clientapp = "$import_name.client_app:app"
24
+
25
+ [tool.flwr.app.config]
26
+ num-server-rounds = 3
27
27
 
28
- [flower.engine]
29
- name = "simulation"
28
+ [tool.flwr.federations]
29
+ default = "local-simulation"
30
30
 
31
- [flower.engine.simulation.supernode]
32
- num = 2
31
+ [tool.flwr.federations.local-simulation]
32
+ options.num-supernodes = 10