flwr 1.21.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/new/new.py +9 -7
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  5. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  10. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  11. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  12. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  13. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  14. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  15. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  16. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  17. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  18. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  19. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  20. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  21. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  22. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  23. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  24. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  25. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  26. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  27. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  28. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  29. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  30. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  31. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  32. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  33. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  34. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  35. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  36. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  37. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  38. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  39. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  40. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  41. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  42. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  43. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  46. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  47. flwr/cli/pull.py +100 -0
  48. flwr/cli/utils.py +17 -0
  49. flwr/clientapp/mod/__init__.py +4 -1
  50. flwr/clientapp/mod/centraldp_mods.py +156 -40
  51. flwr/clientapp/mod/localdp_mod.py +169 -0
  52. flwr/clientapp/typing.py +22 -0
  53. flwr/common/constant.py +3 -0
  54. flwr/common/exit/exit_code.py +4 -0
  55. flwr/common/record/typeddict.py +12 -0
  56. flwr/proto/control_pb2.py +7 -3
  57. flwr/proto/control_pb2.pyi +24 -0
  58. flwr/proto/control_pb2_grpc.py +34 -0
  59. flwr/proto/control_pb2_grpc.pyi +13 -0
  60. flwr/server/app.py +13 -0
  61. flwr/serverapp/strategy/__init__.py +26 -0
  62. flwr/serverapp/strategy/bulyan.py +238 -0
  63. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  64. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  65. flwr/serverapp/strategy/fedadagrad.py +0 -3
  66. flwr/serverapp/strategy/fedadam.py +0 -3
  67. flwr/serverapp/strategy/fedavg.py +89 -64
  68. flwr/serverapp/strategy/fedavgm.py +198 -0
  69. flwr/serverapp/strategy/fedmedian.py +105 -0
  70. flwr/serverapp/strategy/fedprox.py +174 -0
  71. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  72. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  73. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  74. flwr/serverapp/strategy/fedyogi.py +0 -3
  75. flwr/serverapp/strategy/krum.py +112 -0
  76. flwr/serverapp/strategy/multikrum.py +247 -0
  77. flwr/serverapp/strategy/qfedavg.py +252 -0
  78. flwr/serverapp/strategy/strategy_utils.py +48 -0
  79. flwr/simulation/app.py +1 -1
  80. flwr/simulation/run_simulation.py +25 -30
  81. flwr/supercore/cli/flower_superexec.py +26 -1
  82. flwr/supercore/constant.py +19 -0
  83. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  84. flwr/supercore/superexec/run_superexec.py +16 -2
  85. flwr/superlink/artifact_provider/__init__.py +22 -0
  86. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  87. flwr/superlink/servicer/control/control_grpc.py +3 -0
  88. flwr/superlink/servicer/control/control_servicer.py +59 -2
  89. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/METADATA +6 -16
  90. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/RECORD +93 -74
  91. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  92. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  93. flwr/serverapp/dp_fixed_clipping.py +0 -352
  94. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  95. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  96. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
  97. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +0 -0
@@ -2,15 +2,12 @@
2
2
 
3
3
  import os
4
4
  import warnings
5
- from typing import Dict, Tuple
6
5
 
7
- import torch
8
- from flwr.client import ClientApp, NumPyClient
9
- from flwr.common import Context
6
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
7
+ from flwr.clientapp import ClientApp
10
8
  from flwr.common.config import unflatten_dict
11
- from flwr.common.typing import NDArrays, Scalar
12
9
  from omegaconf import DictConfig
13
-
10
+ from peft import get_peft_model_state_dict, set_peft_model_state_dict
14
11
  from transformers import TrainingArguments
15
12
  from trl import SFTTrainer
16
13
 
@@ -19,12 +16,7 @@ from $import_name.dataset import (
19
16
  load_data,
20
17
  replace_keys,
21
18
  )
22
- from $import_name.models import (
23
- cosine_annealing,
24
- get_model,
25
- set_parameters,
26
- get_parameters,
27
- )
19
+ from $import_name.models import cosine_annealing, get_model
28
20
 
29
21
  # Avoid warnings
30
22
  os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -32,95 +24,69 @@ os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
32
24
  warnings.filterwarnings("ignore", category=UserWarning)
33
25
 
34
26
 
35
- # pylint: disable=too-many-arguments
36
- # pylint: disable=too-many-instance-attributes
37
- class FlowerClient(NumPyClient):
38
- """Flower client for LLM fine-tuning."""
27
+ # Avoid warnings
28
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
29
+ os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
30
+ warnings.filterwarnings("ignore", category=UserWarning)
39
31
 
40
- def __init__(
41
- self,
42
- model_cfg: DictConfig,
43
- train_cfg: DictConfig,
44
- trainset,
45
- tokenizer,
46
- formatting_prompts_func,
47
- data_collator,
48
- num_rounds,
49
- ): # pylint: disable=too-many-arguments
50
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
51
- self.train_cfg = train_cfg
52
- self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
53
- self.tokenizer = tokenizer
54
- self.formatting_prompts_func = formatting_prompts_func
55
- self.data_collator = data_collator
56
- self.num_rounds = num_rounds
57
- self.trainset = trainset
58
-
59
- # instantiate model
60
- self.model = get_model(model_cfg)
61
-
62
- def fit(
63
- self, parameters: NDArrays, config: Dict[str, Scalar]
64
- ) -> Tuple[NDArrays, int, Dict]:
65
- """Implement distributed fit function for a given client."""
66
- set_parameters(self.model, parameters)
67
-
68
- new_lr = cosine_annealing(
69
- int(config["current_round"]),
70
- self.num_rounds,
71
- self.train_cfg.learning_rate_max,
72
- self.train_cfg.learning_rate_min,
73
- )
74
-
75
- self.training_arguments.learning_rate = new_lr
76
- self.training_arguments.output_dir = config["save_path"]
77
-
78
- # Construct trainer
79
- trainer = SFTTrainer(
80
- model=self.model,
81
- tokenizer=self.tokenizer,
82
- args=self.training_arguments,
83
- max_seq_length=self.train_cfg.seq_length,
84
- train_dataset=self.trainset,
85
- formatting_func=self.formatting_prompts_func,
86
- data_collator=self.data_collator,
87
- )
88
-
89
- # Do local training
90
- results = trainer.train()
91
-
92
- return (
93
- get_parameters(self.model),
94
- len(self.trainset),
95
- {"train_loss": results.training_loss},
96
- )
97
-
98
-
99
- def client_fn(context: Context) -> FlowerClient:
100
- """Create a Flower client representing a single organization."""
32
+
33
+ # Flower ClientApp
34
+ app = ClientApp()
35
+
36
+
37
+ @app.train()
38
+ def train(msg: Message, context: Context):
39
+ """Train the model on local data."""
40
+ # Parse config
101
41
  partition_id = context.node_config["partition-id"]
102
42
  num_partitions = context.node_config["num-partitions"]
103
43
  num_rounds = context.run_config["num-server-rounds"]
104
44
  cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
45
+ training_arguments = TrainingArguments(**cfg.train.training_arguments)
105
46
 
106
47
  # Let's get the client partition
107
- client_trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
48
+ trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
108
49
  (
109
50
  tokenizer,
110
51
  data_collator,
111
52
  formatting_prompts_func,
112
53
  ) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
113
54
 
114
- return FlowerClient(
115
- cfg.model,
116
- cfg.train,
117
- client_trainset,
118
- tokenizer,
119
- formatting_prompts_func,
120
- data_collator,
121
- num_rounds,
122
- ).to_client()
123
-
55
+ # Load the model and initialize it with the received weights
56
+ model = get_model(cfg.model)
57
+ set_peft_model_state_dict(model, msg.content["arrays"].to_torch_state_dict())
124
58
 
125
- # Flower ClientApp
126
- app = ClientApp(client_fn)
59
+ # Set learning rate for current round
60
+ new_lr = cosine_annealing(
61
+ msg.content["config"]["server-round"],
62
+ num_rounds,
63
+ cfg.train.learning_rate_max,
64
+ cfg.train.learning_rate_min,
65
+ )
66
+
67
+ training_arguments.learning_rate = new_lr
68
+ training_arguments.output_dir = msg.content["config"]["save_path"]
69
+
70
+ # Construct trainer
71
+ trainer = SFTTrainer(
72
+ model=model,
73
+ tokenizer=tokenizer,
74
+ args=training_arguments,
75
+ max_seq_length=cfg.train.seq_length,
76
+ train_dataset=trainset,
77
+ formatting_func=formatting_prompts_func,
78
+ data_collator=data_collator,
79
+ )
80
+
81
+ # Do local training
82
+ results = trainer.train()
83
+
84
+ # Construct and return reply Message
85
+ model_record = ArrayRecord(get_peft_model_state_dict(model))
86
+ metrics = {
87
+ "train_loss": results.training_loss,
88
+ "num-examples": len(trainset),
89
+ }
90
+ metric_record = MetricRecord(metrics)
91
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
92
+ return Message(content=content, reply_to=msg)
@@ -4,18 +4,10 @@ import math
4
4
 
5
5
  import torch
6
6
  from omegaconf import DictConfig
7
- from collections import OrderedDict
8
- from peft import (
9
- LoraConfig,
10
- get_peft_model,
11
- get_peft_model_state_dict,
12
- set_peft_model_state_dict,
13
- )
7
+ from peft import LoraConfig, get_peft_model
14
8
  from peft.utils import prepare_model_for_kbit_training
15
9
  from transformers import AutoModelForCausalLM, BitsAndBytesConfig
16
10
 
17
- from flwr.common.typing import NDArrays
18
-
19
11
 
20
12
  def cosine_annealing(
21
13
  current_round: int,
@@ -62,17 +54,3 @@ def get_model(model_cfg: DictConfig):
62
54
  model.config.use_cache = False
63
55
 
64
56
  return get_peft_model(model, peft_config)
65
-
66
-
67
- def set_parameters(model, parameters: NDArrays) -> None:
68
- """Change the parameters of the model using the given ones."""
69
- peft_state_dict_keys = get_peft_model_state_dict(model).keys()
70
- params_dict = zip(peft_state_dict_keys, parameters)
71
- state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
72
- set_peft_model_state_dict(model, state_dict)
73
-
74
-
75
- def get_parameters(model) -> NDArrays:
76
- """Return the parameters of the current net."""
77
- state_dict = get_peft_model_state_dict(model)
78
- return [val.cpu().numpy() for _, val in state_dict.items()]
@@ -3,62 +3,23 @@
3
3
  import os
4
4
  from datetime import datetime
5
5
 
6
- from flwr.common import Context, ndarrays_to_parameters
6
+ from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
7
7
  from flwr.common.config import unflatten_dict
8
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
8
+ from flwr.serverapp import Grid, ServerApp
9
9
  from omegaconf import DictConfig
10
+ from peft import get_peft_model_state_dict, set_peft_model_state_dict
10
11
 
11
- from $import_name.models import get_model, get_parameters, set_parameters
12
12
  from $import_name.dataset import replace_keys
13
+ from $import_name.models import get_model
13
14
  from $import_name.strategy import FlowerTuneLlm
14
15
 
15
-
16
- # Get function that will be executed by the strategy's evaluate() method
17
- # Here we use it to save global model checkpoints
18
- def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
19
- """Return an evaluation function for saving global model."""
20
-
21
- def evaluate(server_round: int, parameters, config):
22
- # Save model
23
- if server_round != 0 and (
24
- server_round == total_round or server_round % save_every_round == 0
25
- ):
26
- # Init model
27
- model = get_model(model_cfg)
28
- set_parameters(model, parameters)
29
-
30
- model.save_pretrained(f"{save_path}/peft_{server_round}")
31
-
32
- return 0.0, {}
33
-
34
- return evaluate
35
-
36
-
37
- def get_on_fit_config(save_path):
38
- """Return a function that will be used to construct the config that the
39
- client's fit() method will receive."""
40
-
41
- def fit_config_fn(server_round: int):
42
- fit_config = {}
43
- fit_config["current_round"] = server_round
44
- fit_config["save_path"] = save_path
45
- return fit_config
46
-
47
- return fit_config_fn
16
+ # Create ServerApp
17
+ app = ServerApp()
48
18
 
49
19
 
50
- def fit_weighted_average(metrics):
51
- """Aggregate (federated) evaluation metrics."""
52
- # Multiply accuracy of each client by number of examples used
53
- losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
54
- examples = [num_examples for num_examples, _ in metrics]
55
-
56
- # Aggregate and return custom metric (weighted average)
57
- return {"train_loss": sum(losses) / sum(examples)}
58
-
59
-
60
- def server_fn(context: Context):
61
- """Construct components that set the ServerApp behaviour."""
20
+ @app.main()
21
+ def main(grid: Grid, context: Context) -> None:
22
+ """Main entry point for the ServerApp."""
62
23
  # Create output directory given current timestamp
63
24
  current_time = datetime.now()
64
25
  folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
@@ -71,24 +32,42 @@ def server_fn(context: Context):
71
32
 
72
33
  # Get initial model weights
73
34
  init_model = get_model(cfg.model)
74
- init_model_parameters = get_parameters(init_model)
75
- init_model_parameters = ndarrays_to_parameters(init_model_parameters)
35
+ arrays = ArrayRecord(get_peft_model_state_dict(init_model))
76
36
 
77
37
  # Define strategy
78
38
  strategy = FlowerTuneLlm(
79
- fraction_fit=cfg.strategy.fraction_fit,
39
+ fraction_train=cfg.strategy.fraction_train,
80
40
  fraction_evaluate=cfg.strategy.fraction_evaluate,
81
- on_fit_config_fn=get_on_fit_config(save_path),
82
- fit_metrics_aggregation_fn=fit_weighted_average,
83
- initial_parameters=init_model_parameters,
41
+ )
42
+
43
+ # Start strategy, run FedAvg for `num_rounds`
44
+ strategy.start(
45
+ grid=grid,
46
+ initial_arrays=arrays,
47
+ train_config=ConfigRecord({"save_path": save_path}),
48
+ num_rounds=num_rounds,
84
49
  evaluate_fn=get_evaluate_fn(
85
50
  cfg.model, cfg.train.save_every_round, num_rounds, save_path
86
51
  ),
87
52
  )
88
- config = ServerConfig(num_rounds=num_rounds)
89
53
 
90
- return ServerAppComponents(strategy=strategy, config=config)
91
54
 
55
+ # Get function that will be executed by the strategy
56
+ # Here we use it to save global model checkpoints
57
+ def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
58
+ """Return an evaluation function for saving global model."""
92
59
 
93
- # Flower ServerApp
94
- app = ServerApp(server_fn=server_fn)
60
+ def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
61
+ # Save model
62
+ if server_round != 0 and (
63
+ server_round == total_round or server_round % save_every_round == 0
64
+ ):
65
+ # Init model
66
+ model = get_model(model_cfg)
67
+ set_peft_model_state_dict(model, arrays.to_torch_state_dict())
68
+
69
+ model.save_pretrained(f"{save_path}/peft_{server_round}")
70
+
71
+ return MetricRecord()
72
+
73
+ return evaluate
@@ -1,53 +1,48 @@
1
1
  """$project_name: A Flower / FlowerTune app."""
2
2
 
3
- from io import BytesIO
3
+ from collections.abc import Iterable
4
4
  from logging import INFO, WARN
5
- from typing import List, Tuple, Union
5
+ from typing import Optional
6
6
 
7
- from flwr.common import FitIns, FitRes, Parameters, log, parameters_to_ndarrays
8
- from flwr.server.client_manager import ClientManager
9
- from flwr.server.client_proxy import ClientProxy
10
- from flwr.server.strategy import FedAvg
7
+ from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
8
+ from flwr.common import log
9
+ from flwr.serverapp import Grid
10
+ from flwr.serverapp.strategy import FedAvg
11
11
 
12
12
 
13
13
  class FlowerTuneLlm(FedAvg):
14
14
  """Customised FedAvg strategy implementation.
15
-
15
+
16
16
  This class behaves just like FedAvg but also tracks the communication
17
- costs associated with `fit` over FL rounds.
17
+ costs associated with `train` over FL rounds.
18
18
  """
19
19
  def __init__(self, **kwargs):
20
20
  super().__init__(**kwargs)
21
21
  self.comm_tracker = CommunicationTracker()
22
22
 
23
- def configure_fit(
24
- self, server_round: int, parameters: Parameters, client_manager: ClientManager
25
- ):
23
+ def configure_train(
24
+ self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
25
+ ) -> Iterable[Message]:
26
26
  """Configure the next round of training."""
27
- return_clients = super().configure_fit(server_round, parameters, client_manager)
28
-
29
- # Test communication costs
30
- fit_ins_list = [fit_ins for _, fit_ins in return_clients]
31
- self.comm_tracker.track(fit_ins_list)
32
-
33
- return return_clients
34
-
35
- def aggregate_fit(
36
- self,
37
- server_round: int,
38
- results: List[Tuple[ClientProxy, FitRes]],
39
- failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
40
- ):
41
- """Aggregate fit results using weighted average."""
42
- # Test communication costs
43
- fit_res_list = [fit_res for _, fit_res in results]
44
- self.comm_tracker.track(fit_res_list)
45
-
46
- parameters_aggregated, metrics_aggregated = super().aggregate_fit(
47
- server_round, results, failures
48
- )
27
+ messages = super().configure_train(server_round, arrays, config, grid)
28
+
29
+ # Track communication costs
30
+ self.comm_tracker.track(messages)
31
+
32
+ return messages
33
+
34
+ def aggregate_train(
35
+ self,
36
+ server_round: int,
37
+ replies: Iterable[Message],
38
+ ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
39
+ """Aggregate ArrayRecords and MetricRecords in the received Messages."""
40
+ # Track communication costs
41
+ self.comm_tracker.track(replies)
49
42
 
50
- return parameters_aggregated, metrics_aggregated
43
+ arrays, metrics = super().aggregate_train(server_round, replies)
44
+
45
+ return arrays, metrics
51
46
 
52
47
 
53
48
  class CommunicationTracker:
@@ -55,16 +50,16 @@ class CommunicationTracker:
55
50
  def __init__(self):
56
51
  self.curr_comm_cost = 0.0
57
52
 
58
- @staticmethod
59
- def _compute_bytes(parameters):
60
- return sum([BytesIO(t).getbuffer().nbytes for t in parameters.tensors])
61
-
62
- def track(self, fit_list: List[Union[FitIns, FitRes]]):
63
- size_bytes_list = [
64
- self._compute_bytes(fit_ele.parameters)
65
- for fit_ele in fit_list
66
- ]
67
- comm_cost = sum(size_bytes_list) / 1024**2
53
+ def track(self, messages: Iterable[Message]):
54
+ comm_cost = (
55
+ sum(
56
+ record.count_bytes()
57
+ for msg in messages
58
+ if msg.has_content()
59
+ for record in msg.content.array_records.values()
60
+ )
61
+ / 1024**2
62
+ )
68
63
 
69
64
  self.curr_comm_cost += comm_cost
70
65
  log(
@@ -1,7 +1,5 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
- from collections import OrderedDict
4
-
5
3
  import torch
6
4
  import torch.nn.functional as F
7
5
  from torch import nn
@@ -66,15 +64,3 @@ def test(net, testloader, device):
66
64
  accuracy = correct / len(testloader.dataset)
67
65
  loss = loss / len(testloader)
68
66
  return loss, accuracy
69
-
70
-
71
- def get_weights(net):
72
- """Extract model parameters as numpy arrays from state_dict."""
73
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
74
-
75
-
76
- def set_weights(net, parameters):
77
- """Apply parameters to an existing model."""
78
- params_dict = zip(net.state_dict().keys(), parameters)
79
- state_dict = OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
80
- net.load_state_dict(state_dict, strict=True)
@@ -1,45 +1,43 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
- from flwr.common import Context, Metrics, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import torch
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
6
7
 
7
- from $import_name.model import Net, get_weights
8
+ from $import_name.model import Net
8
9
 
10
+ # Create ServerApp
11
+ app = ServerApp()
9
12
 
10
- # Define metric aggregation function
11
- def weighted_average(metrics: list[tuple[int, Metrics]]) -> Metrics:
12
- """Do weighted average of accuracy metric."""
13
- # Multiply accuracy of each client by number of examples used
14
- accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
15
- examples = [num_examples for num_examples, _ in metrics]
16
-
17
- # Aggregate and return custom metric (weighted average)
18
- return {"accuracy": sum(accuracies) / sum(examples)}
19
13
 
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
20
17
 
21
- def server_fn(context: Context):
22
- """Construct components that set the ServerApp behaviour."""
23
18
  # Read from config
24
19
  num_rounds = context.run_config["num-server-rounds"]
25
- fraction_fit = context.run_config["fraction-fit"]
20
+ fraction_train = context.run_config["fraction-train"]
26
21
 
27
- # Initialize model parameters
28
- ndarrays = get_weights(Net())
29
- parameters = ndarrays_to_parameters(ndarrays)
22
+ # Load global model
23
+ global_model = Net()
24
+ arrays = ArrayRecord(global_model.state_dict())
30
25
 
31
- # Define strategy
26
+ # Initialize FedAvg strategy
32
27
  strategy = FedAvg(
33
- fraction_fit=float(fraction_fit),
28
+ fraction_train=fraction_train,
34
29
  fraction_evaluate=1.0,
35
- min_available_clients=2,
36
- initial_parameters=parameters,
37
- evaluate_metrics_aggregation_fn=weighted_average,
30
+ min_available_nodes=2,
38
31
  )
39
- config = ServerConfig(num_rounds=int(num_rounds))
40
-
41
- return ServerAppComponents(strategy=strategy, config=config)
42
32
 
33
+ # Start strategy, run FedAvg for `num_rounds`
34
+ result = strategy.start(
35
+ grid=grid,
36
+ initial_arrays=arrays,
37
+ num_rounds=num_rounds,
38
+ )
43
39
 
44
- # Create ServerApp
45
- app = ServerApp(server_fn=server_fn)
40
+ # Save final model to disk
41
+ print("\nSaving final model to disk...")
42
+ state_dict = result.arrays.to_torch_state_dict()
43
+ torch.save(state_dict, "final_model.pt")
@@ -1,17 +1,22 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import torch
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
6
7
  from transformers import AutoModelForSequenceClassification
7
8
 
8
- from $import_name.task import get_weights
9
+ # Create ServerApp
10
+ app = ServerApp()
11
+
9
12
 
13
+ @app.main()
14
+ def main(grid: Grid, context: Context) -> None:
15
+ """Main entry point for the ServerApp."""
10
16
 
11
- def server_fn(context: Context):
12
17
  # Read from config
13
18
  num_rounds = context.run_config["num-server-rounds"]
14
- fraction_fit = context.run_config["fraction-fit"]
19
+ fraction_train = context.run_config["fraction-train"]
15
20
 
16
21
  # Initialize global model
17
22
  model_name = context.run_config["model-name"]
@@ -19,20 +24,19 @@ def server_fn(context: Context):
19
24
  net = AutoModelForSequenceClassification.from_pretrained(
20
25
  model_name, num_labels=num_labels
21
26
  )
27
+ arrays = ArrayRecord(net.state_dict())
22
28
 
23
- weights = get_weights(net)
24
- initial_parameters = ndarrays_to_parameters(weights)
29
+ # Initialize FedAvg strategy
30
+ strategy = FedAvg(fraction_train=fraction_train)
25
31
 
26
- # Define strategy
27
- strategy = FedAvg(
28
- fraction_fit=fraction_fit,
29
- fraction_evaluate=1.0,
30
- initial_parameters=initial_parameters,
32
+ # Start strategy, run FedAvg for `num_rounds`
33
+ result = strategy.start(
34
+ grid=grid,
35
+ initial_arrays=arrays,
36
+ num_rounds=num_rounds,
31
37
  )
32
- config = ServerConfig(num_rounds=num_rounds)
33
-
34
- return ServerAppComponents(strategy=strategy, config=config)
35
38
 
36
-
37
- # Create ServerApp
38
- app = ServerApp(server_fn=server_fn)
39
+ # Save final model to disk
40
+ print("\nSaving final model to disk...")
41
+ state_dict = result.arrays.to_torch_state_dict()
42
+ torch.save(state_dict, "final_model.pt")
@@ -1,26 +1,39 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import numpy as np
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
7
+
6
8
  from $import_name.task import get_params, load_model
7
9
 
10
+ # Create ServerApp
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
8
17
 
9
- def server_fn(context: Context):
10
18
  # Read from config
11
19
  num_rounds = context.run_config["num-server-rounds"]
12
20
  input_dim = context.run_config["input-dim"]
13
21
 
14
- # Initialize global model
15
- params = get_params(load_model((input_dim,)))
16
- initial_parameters = ndarrays_to_parameters(params)
17
-
18
- # Define strategy
19
- strategy = FedAvg(initial_parameters=initial_parameters)
20
- config = ServerConfig(num_rounds=num_rounds)
22
+ # Load global model
23
+ model = load_model((input_dim,))
24
+ arrays = ArrayRecord(get_params(model))
21
25
 
22
- return ServerAppComponents(strategy=strategy, config=config)
26
+ # Initialize FedAvg strategy
27
+ strategy = FedAvg()
23
28
 
29
+ # Start strategy, run FedAvg for `num_rounds`
30
+ result = strategy.start(
31
+ grid=grid,
32
+ initial_arrays=arrays,
33
+ num_rounds=num_rounds,
34
+ )
24
35
 
25
- # Create ServerApp
26
- app = ServerApp(server_fn=server_fn)
36
+ # Save final model to disk
37
+ print("\nSaving final model to disk...")
38
+ ndarrays = result.arrays.to_numpy_ndarrays()
39
+ np.savez("final_model.npz", *ndarrays)