flwr-nightly 1.10.0.dev20240721__py3-none-any.whl → 1.10.0.dev20240723__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flwr-nightly might be problematic.

Files changed (74)
  1. flwr/cli/config_utils.py +20 -18
  2. flwr/cli/new/new.py +1 -1
  3. flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +7 -5
  4. flwr/cli/new/templates/app/code/client.mlx.py.tpl +28 -10
  5. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +7 -5
  6. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +2 -2
  7. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +17 -7
  8. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +20 -17
  9. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
  10. flwr/cli/new/templates/app/code/{server.hf.py.tpl → server.huggingface.py.tpl} +2 -1
  11. flwr/cli/new/templates/app/code/server.jax.py.tpl +2 -1
  12. flwr/cli/new/templates/app/code/server.mlx.py.tpl +2 -1
  13. flwr/cli/new/templates/app/code/server.numpy.py.tpl +2 -1
  14. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +2 -1
  16. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +1 -1
  17. flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +13 -1
  18. flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -1
  19. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +13 -2
  20. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
  21. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/{pyproject.hf.toml.tpl → pyproject.huggingface.toml.tpl} +2 -2
  23. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +6 -6
  25. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +2 -2
  27. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +4 -4
  29. flwr/cli/run/run.py +35 -28
  30. flwr/client/app.py +3 -3
  31. flwr/client/grpc_rere_client/connection.py +6 -2
  32. flwr/client/node_state.py +3 -3
  33. flwr/client/rest_client/connection.py +6 -2
  34. flwr/client/supernode/app.py +12 -43
  35. flwr/common/config.py +23 -17
  36. flwr/common/context.py +7 -7
  37. flwr/common/object_ref.py +84 -21
  38. flwr/common/serde.py +45 -0
  39. flwr/common/telemetry.py +17 -0
  40. flwr/common/typing.py +5 -1
  41. flwr/proto/common_pb2.py +13 -1
  42. flwr/proto/common_pb2.pyi +114 -0
  43. flwr/proto/driver_pb2.py +22 -21
  44. flwr/proto/driver_pb2.pyi +7 -4
  45. flwr/proto/exec_pb2.py +18 -13
  46. flwr/proto/exec_pb2.pyi +27 -5
  47. flwr/proto/run_pb2.py +10 -9
  48. flwr/proto/run_pb2.pyi +7 -4
  49. flwr/proto/task_pb2.py +7 -8
  50. flwr/server/compat/legacy_context.py +5 -4
  51. flwr/server/driver/grpc_driver.py +6 -2
  52. flwr/server/run_serverapp.py +3 -5
  53. flwr/server/superlink/driver/driver_servicer.py +14 -3
  54. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  55. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  56. flwr/server/superlink/fleet/vce/vce_api.py +4 -4
  57. flwr/server/superlink/state/in_memory_state.py +2 -2
  58. flwr/server/superlink/state/sqlite_state.py +2 -2
  59. flwr/server/superlink/state/state.py +3 -3
  60. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  61. flwr/simulation/__init__.py +1 -1
  62. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  63. flwr/simulation/run_simulation.py +39 -11
  64. flwr/superexec/app.py +4 -5
  65. flwr/superexec/deployment.py +19 -8
  66. flwr/superexec/exec_grpc.py +3 -2
  67. flwr/superexec/exec_servicer.py +3 -1
  68. flwr/superexec/executor.py +10 -5
  69. flwr/superexec/simulation.py +41 -15
  70. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/METADATA +1 -1
  71. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/RECORD +74 -74
  72. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/LICENSE +0 -0
  73. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/WHEEL +0 -0
  74. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/entry_points.txt +0 -0
flwr/cli/config_utils.py CHANGED
@@ -17,11 +17,12 @@
  import zipfile
  from io import BytesIO
  from pathlib import Path
- from typing import IO, Any, Dict, List, Optional, Tuple, Union
+ from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args

  import tomli

  from flwr.common import object_ref
+ from flwr.common.typing import UserConfigValue


  def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
@@ -76,6 +77,9 @@ def load_and_validate(
          A tuple with the optional config in case it exists and is valid
          and associated errors and warnings.
      """
+     if path is None:
+         path = Path.cwd() / "pyproject.toml"
+
      config = load(path)

      if config is None:
@@ -85,7 +89,7 @@ def load_and_validate(
          ]
          return (None, errors, [])

-     is_valid, errors, warnings = validate(config, check_module)
+     is_valid, errors, warnings = validate(config, check_module, path.parent)

      if not is_valid:
          return (None, errors, warnings)
@@ -93,14 +97,8 @@ def load_and_validate(
      return (config, errors, warnings)


- def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
+ def load(toml_path: Path) -> Optional[Dict[str, Any]]:
      """Load pyproject.toml and return as dict."""
-     if path is None:
-         cur_dir = Path.cwd()
-         toml_path = cur_dir / "pyproject.toml"
-     else:
-         toml_path = path
-
      if not toml_path.is_file():
          return None

@@ -112,8 +110,11 @@ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None
      for key, value in config_dict.items():
          if isinstance(value, dict):
              _validate_run_config(config_dict[key], errors)
-         elif not isinstance(value, str):
-             errors.append(f"Config value of key {key} is not of type `str`.")
+         elif not isinstance(value, get_args(UserConfigValue)):
+             raise ValueError(
+                 f"The value for key {key} needs to be of type `int`, `float`, "
+                 "`bool, `str`, or a `dict` of those.",
+             )


  # pylint: disable=too-many-branches
@@ -163,7 +164,9 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]


  def validate(
-     config: Dict[str, Any], check_module: bool = True
+     config: Dict[str, Any],
+     check_module: bool = True,
+     project_dir: Optional[Union[str, Path]] = None,
  ) -> Tuple[bool, List[str], List[str]]:
      """Validate pyproject.toml."""
      is_valid, errors, warnings = validate_fields(config)
@@ -172,16 +175,15 @@ def validate(
          return False, errors, warnings

      # Validate serverapp
-     is_valid, reason = object_ref.validate(
-         config["tool"]["flwr"]["app"]["components"]["serverapp"], check_module
-     )
+     serverapp_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"]
+     is_valid, reason = object_ref.validate(serverapp_ref, check_module, project_dir)
+
      if not is_valid and isinstance(reason, str):
          return False, [reason], []

      # Validate clientapp
-     is_valid, reason = object_ref.validate(
-         config["tool"]["flwr"]["app"]["components"]["clientapp"], check_module
-     )
+     clientapp_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
+     is_valid, reason = object_ref.validate(clientapp_ref, check_module, project_dir)

      if not is_valid and isinstance(reason, str):
          return False, [reason], []
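
Note on the validation change above: `_validate_run_config` now checks values against `get_args(UserConfigValue)` instead of requiring plain strings. A minimal sketch of the mechanism, assuming `UserConfigValue` is a `Union[int, float, bool, str]` alias (the real alias lives in `flwr/common/typing.py`, which also changes in this release):

    # Sketch only: shows how typing.get_args() turns a Union alias into a
    # tuple of runtime types that isinstance() can check against.
    # UserConfigValue below is an assumed stand-in for the flwr.common.typing alias.
    from typing import Union, get_args

    UserConfigValue = Union[int, float, bool, str]  # assumed definition

    def is_valid_config_value(value) -> bool:
        # get_args(Union[int, float, bool, str]) == (int, float, bool, str)
        return isinstance(value, get_args(UserConfigValue))

    assert is_valid_config_value(3)
    assert is_valid_config_value(0.1)
    assert is_valid_config_value("hello")
    assert not is_valid_config_value([1, 2, 3])

This is also why the quoted numbers in `[tool.flwr.app.config]` are replaced by native TOML values in the pyproject templates further down.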
flwr/cli/new/new.py CHANGED
@@ -38,7 +38,7 @@ class MlFramework(str, Enum):
      PYTORCH = "PyTorch"
      TENSORFLOW = "TensorFlow"
      JAX = "JAX"
-     HUGGINGFACE = "HF"
+     HUGGINGFACE = "HuggingFace"
      MLX = "MLX"
      SKLEARN = "sklearn"
      FLOWERTUNE = "FlowerTune"
@@ -17,10 +17,11 @@ from $import_name.task import (

  # Flower client
  class FlowerClient(NumPyClient):
-     def __init__(self, net, trainloader, testloader):
+     def __init__(self, net, trainloader, testloader, local_epochs):
          self.net = net
          self.trainloader = trainloader
          self.testloader = testloader
+         self.local_epochs = local_epochs

      def get_parameters(self, config):
          return get_weights(self.net)
@@ -33,7 +34,7 @@ class FlowerClient(NumPyClient):
          train(
              self.net,
              self.trainloader,
-             epochs=int(self.context.run_config["local-epochs"]),
+             epochs=self.local_epochs,
          )
          return self.get_parameters(config={}), len(self.trainloader), {}

@@ -49,12 +50,13 @@ def client_fn(context: Context):
          CHECKPOINT, num_labels=2
      ).to(DEVICE)

-     partition_id = int(context.node_config["partition-id"])
-     num_partitions = int(context.node_config["num-partitions"])
+     partition_id = context.node_config["partition-id"]
+     num_partitions = context.node_config["num-partitions"]
      trainloader, valloader = load_data(partition_id, num_partitions)
+     local_epochs = context.run_config["local-epochs"]

      # Return Client instance
-     return FlowerClient(net, trainloader, valloader).to_client()
+     return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


  # Flower ClientApp
@@ -19,13 +19,22 @@ from $import_name.task import (

  # Define Flower Client and client_fn
  class FlowerClient(NumPyClient):
-     def __init__(self, data):
-         num_layers = int(self.context.run_config["num-layers"])
-         hidden_dim = int(self.context.run_config["hidden-dim"])
-         num_classes = 10
-         batch_size = int(self.context.run_config["batch-size"])
-         learning_rate = float(self.context.run_config["lr"])
-         num_epochs = int(self.context.run_config["local-epochs"])
+     def __init__(
+         self,
+         data,
+         num_layers,
+         hidden_dim,
+         num_classes,
+         batch_size,
+         learning_rate,
+         num_epochs,
+     ):
+         self.num_layers = num_layers
+         self.hidden_dim = hidden_dim
+         self.num_classes = num_classes
+         self.batch_size = batch_size
+         self.learning_rate = learning_rate
+         self.num_epochs = num_epochs

          self.train_images, self.train_labels, self.test_images, self.test_labels = data
          self.model = MLP(
@@ -61,12 +70,21 @@ class FlowerClient(NumPyClient):


  def client_fn(context: Context):
-     partition_id = int(context.node_config["partition-id"])
-     num_partitions = int(context.node_config["num-partitions"])
+     partition_id = context.node_config["partition-id"]
+     num_partitions = context.node_config["num-partitions"]
      data = load_data(partition_id, num_partitions)

+     num_layers = context.run_config["num-layers"]
+     hidden_dim = context.run_config["hidden-dim"]
+     num_classes = 10
+     batch_size = context.run_config["batch-size"]
+     learning_rate = context.run_config["lr"]
+     num_epochs = context.run_config["local-epochs"]
+
      # Return Client instance
-     return FlowerClient(data).to_client()
+     return FlowerClient(
+         data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
+     ).to_client()


  # Flower ClientApp
@@ -16,10 +16,11 @@ from $import_name.task import (

  # Define Flower Client and client_fn
  class FlowerClient(NumPyClient):
-     def __init__(self, net, trainloader, valloader):
+     def __init__(self, net, trainloader, valloader, local_epochs):
          self.net = net
          self.trainloader = trainloader
          self.valloader = valloader
+         self.local_epochs = local_epochs

      def fit(self, parameters, config):
          set_weights(self.net, parameters)
@@ -27,7 +28,7 @@ class FlowerClient(NumPyClient):
              self.net,
              self.trainloader,
              self.valloader,
-             int(self.context.run_config["local-epochs"]),
+             self.local_epochs,
              DEVICE,
          )
          return get_weights(self.net), len(self.trainloader.dataset), results
@@ -41,12 +42,13 @@ class FlowerClient(NumPyClient):
  def client_fn(context: Context):
      # Load model and data
      net = Net().to(DEVICE)
-     partition_id = int(context.node_config["partition-id"])
-     num_partitions = int(context.node_config["num-partitions"])
+     partition_id = context.node_config["partition-id"]
+     num_partitions = context.node_config["num-partitions"]
      trainloader, valloader = load_data(partition_id, num_partitions)
+     local_epochs = context.run_config["local-epochs"]

      # Return Client instance
-     return FlowerClient(net, trainloader, valloader).to_client()
+     return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


  # Flower ClientApp
@@ -69,8 +69,8 @@ class FlowerClient(NumPyClient):


  def client_fn(context: Context):
-     partition_id = int(context.node_config["partition-id"])
-     num_partitions = int(context.node_config["num-partitions"])
+     partition_id = context.node_config["partition-id"]
+     num_partitions = context.node_config["num-partitions"]
      fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
      dataset = fds.load_partition(partition_id, "train").with_format("numpy")

@@ -8,12 +8,17 @@ from $import_name.task import load_data, load_model

  # Define Flower Client and client_fn
  class FlowerClient(NumPyClient):
-     def __init__(self, model, x_train, y_train, x_test, y_test):
+     def __init__(
+         self, model, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
+     ):
          self.model = model
          self.x_train = x_train
          self.y_train = y_train
          self.x_test = x_test
          self.y_test = y_test
+         self.epochs = epochs
+         self.batch_size = batch_size
+         self.verbose = verbose

      def get_parameters(self, config):
          return self.model.get_weights()
@@ -23,9 +28,9 @@ class FlowerClient(NumPyClient):
          self.model.fit(
              self.x_train,
              self.y_train,
-             epochs=int(self.context.run_config["local-epochs"]),
-             batch_size=int(self.context.run_config["batch-size"]),
-             verbose=bool(self.context.run_config.get("verbose")),
+             epochs=self.epochs,
+             batch_size=self.batch_size,
+             verbose=self.verbose,
          )
          return self.model.get_weights(), len(self.x_train), {}

@@ -39,12 +44,17 @@ def client_fn(context: Context):
      # Load model and data
      net = load_model()

-     partition_id = int(context.node_config["partition-id"])
-     num_partitions = int(context.node_config["num-partitions"])
+     partition_id = context.node_config["partition-id"]
+     num_partitions = context.node_config["num-partitions"]
      x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
+     epochs = context.run_config["local-epochs"]
+     batch_size = context.run_config["batch-size"]
+     verbose = context.run_config.get("verbose")

      # Return Client instance
-     return FlowerClient(net, x_train, y_train, x_test, y_test).to_client()
+     return FlowerClient(
+         net, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
+     ).to_client()


  # Flower ClientApp
@@ -9,8 +9,8 @@ from hydra import compose, initialize
  from hydra.utils import instantiate

  from flwr.client import ClientApp
- from flwr.common import ndarrays_to_parameters
- from flwr.server import ServerApp, ServerConfig
+ from flwr.common import Context, ndarrays_to_parameters
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig

  from $import_name.client_app import gen_client_fn, get_parameters
  from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
@@ -67,20 +67,23 @@ init_model = get_model(cfg.model)
  init_model_parameters = get_parameters(init_model)
  init_model_parameters = ndarrays_to_parameters(init_model_parameters)

- # Instantiate strategy according to config. Here we pass other arguments
- # that are only defined at runtime.
- strategy = instantiate(
-     cfg.strategy,
-     on_fit_config_fn=get_on_fit_config(),
-     fit_metrics_aggregation_fn=fit_weighted_average,
-     initial_parameters=init_model_parameters,
-     evaluate_fn=get_evaluate_fn(
-         cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
-     ),
- )
+ def server_fn(context: Context):
+     # Instantiate strategy according to config. Here we pass other arguments
+     # that are only defined at runtime.
+     strategy = instantiate(
+         cfg.strategy,
+         on_fit_config_fn=get_on_fit_config(),
+         fit_metrics_aggregation_fn=fit_weighted_average,
+         initial_parameters=init_model_parameters,
+         evaluate_fn=get_evaluate_fn(
+             cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
+         ),
+     )
+
+     config = ServerConfig(num_rounds=cfg_static.num_rounds)
+
+     return ServerAppComponents(strategy=strategy, config=config)
+

  # ServerApp for Flower Next
- server = ServerApp(
-     config=ServerConfig(num_rounds=cfg_static.num_rounds),
-     strategy=strategy,
- )
+ server = ServerApp(server_fn=server_fn)
@@ -10,6 +10,7 @@ from transformers import TrainingArguments
  from trl import SFTTrainer

  from flwr.client import NumPyClient
+ from flwr.common import Context
  from flwr.common.typing import NDArrays, Scalar
  from $import_name.dataset import reformat
  from $import_name.models import cosine_annealing, get_model
@@ -102,13 +103,14 @@ def gen_client_fn(
      model_cfg: DictConfig,
      train_cfg: DictConfig,
      save_path: str,
- ) -> Callable[[str], FlowerClient]:  # pylint: disable=too-many-arguments
+ ) -> Callable[[Context], FlowerClient]:  # pylint: disable=too-many-arguments
      """Generate the client function that creates the Flower Clients."""

-     def client_fn(cid: str) -> FlowerClient:
+     def client_fn(context: Context) -> FlowerClient:
          """Create a Flower client representing a single organization."""
          # Let's get the partition corresponding to the i-th client
-         client_trainset = fds.load_partition(int(cid), "train")
+         partition_id = context.node_config["partition-id"]
+         client_trainset = fds.load_partition(partition_id, "train")
          client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")

          return FlowerClient(
@@ -7,7 +7,7 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig

  def server_fn(context: Context):
      # Read from config
-     num_rounds = int(context.run_config["num-server-rounds"])
+     num_rounds = context.run_config["num-server-rounds"]

      # Define strategy
      strategy = FedAvg(
@@ -18,5 +18,6 @@ def server_fn(context: Context):

      return ServerAppComponents(strategy=strategy, config=config)

+
  # Create ServerApp
  app = ServerApp(server_fn=server_fn)
@@ -7,7 +7,7 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig

  def server_fn(context: Context):
      # Read from config
-     num_rounds = int(context.run_config["num-server-rounds"])
+     num_rounds = context.run_config["num-server-rounds"]

      # Define strategy
      strategy = FedAvg()
@@ -15,5 +15,6 @@ def server_fn(context: Context):

      return ServerAppComponents(strategy=strategy, config=config)

+
  # Create ServerApp
  app = ServerApp(server_fn=server_fn)
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg

  def server_fn(context: Context):
      # Read from config
-     num_rounds = int(context.run_config["num-server-rounds"])
+     num_rounds = context.run_config["num-server-rounds"]

      # Define strategy
      strategy = FedAvg()
@@ -15,5 +15,6 @@ def server_fn(context: Context):

      return ServerAppComponents(strategy=strategy, config=config)

+
  # Create ServerApp
  app = ServerApp(server_fn=server_fn)
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg

  def server_fn(context: Context):
      # Read from config
-     num_rounds = int(context.run_config["num-server-rounds"])
+     num_rounds = context.run_config["num-server-rounds"]

      # Define strategy
      strategy = FedAvg()
@@ -15,5 +15,6 @@ def server_fn(context: Context):

      return ServerAppComponents(strategy=strategy, config=config)

+
  # Create ServerApp
  app = ServerApp(server_fn=server_fn)
@@ -13,7 +13,7 @@ parameters = ndarrays_to_parameters(ndarrays)

  def server_fn(context: Context):
      # Read from config
-     num_rounds = int(context.run_config["num-server-rounds"])
+     num_rounds = context.run_config["num-server-rounds"]

      # Define strategy
      strategy = FedAvg(
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg

  def server_fn(context: Context):
      # Read from config
-     num_rounds = int(context.run_config["num-server-rounds"])
+     num_rounds = context.run_config["num-server-rounds"]

      # Define strategy
      strategy = FedAvg(
@@ -19,5 +19,6 @@ def server_fn(context: Context):

      return ServerAppComponents(strategy=strategy, config=config)

+
  # Create ServerApp
  app = ServerApp(server_fn=server_fn)
@@ -13,7 +13,7 @@ parameters = ndarrays_to_parameters(load_model().get_weights())

  def server_fn(context: Context):
      # Read from config
-     num_rounds = int(context.run_config["num-server-rounds"])
+     num_rounds = context.run_config["num-server-rounds"]

      # Define strategy
      strategy = strategy = FedAvg(
@@ -10,15 +10,27 @@ from torch.utils.data import DataLoader
  from transformers import AutoTokenizer, DataCollatorWithPadding

  from flwr_datasets import FederatedDataset
+ from flwr_datasets.partitioner import IidPartitioner
+

  warnings.filterwarnings("ignore", category=UserWarning)
  DEVICE = torch.device("cpu")
  CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint


+ fds = None  # Cache FederatedDataset
+
+
  def load_data(partition_id: int, num_partitions: int):
      """Load IMDB data (training and eval)"""
-     fds = FederatedDataset(dataset="imdb", partitioners={"train": num_partitions})
+     # Only initialize `FederatedDataset` once
+     global fds
+     if fds is None:
+         partitioner = IidPartitioner(num_partitions=num_partitions)
+         fds = FederatedDataset(
+             dataset="stanfordnlp/imdb",
+             partitioners={"train": partitioner},
+         )
      partition = fds.load_partition(partition_id)
      # Divide data: 80% train, 20% test
      partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
@@ -5,10 +5,12 @@ import mlx.nn as nn
  import numpy as np
  from datasets.utils.logging import disable_progress_bar
  from flwr_datasets import FederatedDataset
+ from flwr_datasets.partitioner import IidPartitioner


  disable_progress_bar()

+
  class MLP(nn.Module):
      """A simple MLP."""

@@ -43,8 +45,19 @@ def batch_iterate(batch_size, X, y):
          yield X[ids], y[ids]


+ fds = None  # Cache FederatedDataset
+
+
  def load_data(partition_id: int, num_partitions: int):
-     fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
+     # Only initialize `FederatedDataset` once
+     global fds
+     if fds is None:
+         partitioner = IidPartitioner(num_partitions=num_partitions)
+         fds = FederatedDataset(
+             dataset="ylecun/mnist",
+             partitioners={"train": partitioner},
+             trust_remote_code=True,
+         )
      partition = fds.load_partition(partition_id)
      partition_splits = partition.train_test_split(test_size=0.2, seed=42)

@@ -6,9 +6,10 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from torch.utils.data import DataLoader
- from torchvision.datasets import CIFAR10
  from torchvision.transforms import Compose, Normalize, ToTensor
  from flwr_datasets import FederatedDataset
+ from flwr_datasets.partitioner import IidPartitioner
+

  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -34,9 +35,19 @@ class Net(nn.Module):
          return self.fc3(x)


+ fds = None  # Cache FederatedDataset
+
+
  def load_data(partition_id: int, num_partitions: int):
      """Load partition CIFAR10 data."""
-     fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
+     # Only initialize `FederatedDataset` once
+     global fds
+     if fds is None:
+         partitioner = IidPartitioner(num_partitions=num_partitions)
+         fds = FederatedDataset(
+             dataset="uoft-cs/cifar10",
+             partitioners={"train": partitioner},
+         )
      partition = fds.load_partition(partition_id)
      # Divide data on each node: 80% train, 20% test
      partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
@@ -4,11 +4,13 @@ import os

  import tensorflow as tf
  from flwr_datasets import FederatedDataset
+ from flwr_datasets.partitioner import IidPartitioner


  # Make TensorFlow log less verbose
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

+
  def load_model():
      # Load model and data (MobileNetV2, CIFAR-10)
      model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
@@ -16,9 +18,19 @@ def load_model():
      return model


+ fds = None  # Cache FederatedDataset
+
+
  def load_data(partition_id, num_partitions):
      # Download and partition dataset
-     fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
+     # Only initialize `FederatedDataset` once
+     global fds
+     if fds is None:
+         partitioner = IidPartitioner(num_partitions=num_partitions)
+         fds = FederatedDataset(
+             dataset="uoft-cs/cifar10",
+             partitioners={"train": partitioner},
+         )
      partition = fds.load_partition(partition_id, "train")
      partition.set_format("numpy")

@@ -30,7 +30,7 @@ serverapp = "$import_name.app:server"
  clientapp = "$import_name.app:client"

  [tool.flwr.app.config]
- num-server-rounds = "3"
+ num-server-rounds = 3

  [tool.flwr.federations]
  default = "localhost"
@@ -28,8 +28,8 @@ serverapp = "$import_name.server_app:app"
  clientapp = "$import_name.client_app:app"

  [tool.flwr.app.config]
- num-server-rounds = "3"
- local-epochs = "1"
+ num-server-rounds = 3
+ local-epochs = 1

  [tool.flwr.federations]
  default = "localhost"
@@ -25,7 +25,7 @@ serverapp = "$import_name.server_app:app"
  clientapp = "$import_name.client_app:app"

  [tool.flwr.app.config]
- num-server-rounds = "3"
+ num-server-rounds = 3

  [tool.flwr.federations]
  default = "localhost"
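
The pyproject template changes above complete the picture: `[tool.flwr.app.config]` values are now written as native TOML types (`num-server-rounds = 3` rather than `"3"`), which is why the server and client templates earlier in this diff drop their `int(...)`/`float(...)`/`bool(...)` casts when reading `context.run_config`. A rough sketch of the resulting pattern, assuming an app whose pyproject.toml declares `num-server-rounds` as an integer like the templates do:

    # Sketch only: mirrors the updated server templates. Assumes the app's
    # pyproject.toml declares num-server-rounds = 3 under [tool.flwr.app.config],
    # so the value arrives in run_config already typed as int.
    from flwr.common import Context
    from flwr.server import ServerApp, ServerAppComponents, ServerConfig
    from flwr.server.strategy import FedAvg

    def server_fn(context: Context):
        num_rounds = context.run_config["num-server-rounds"]  # already an int
        strategy = FedAvg()
        config = ServerConfig(num_rounds=num_rounds)
        return ServerAppComponents(strategy=strategy, config=config)

    app = ServerApp(server_fn=server_fn)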