flwr-nightly 1.11.0.dev20240823__py3-none-any.whl → 1.12.0.dev20240906__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flwr-nightly might be problematic (see the registry page for details).
Files changed (48)
  1. flwr/cli/app.py +0 -2
  2. flwr/cli/new/new.py +24 -10
  3. flwr/cli/new/templates/app/LICENSE.tpl +202 -0
  4. flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
  5. flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
  6. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
  7. flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
  8. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
  9. flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
  10. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
  11. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
  13. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
  14. flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
  15. flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
  16. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
  17. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
  18. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
  19. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
  20. flwr/cli/run/run.py +2 -2
  21. flwr/client/__init__.py +0 -4
  22. flwr/client/grpc_rere_client/client_interceptor.py +13 -4
  23. flwr/client/supernode/app.py +3 -1
  24. flwr/common/config.py +14 -11
  25. flwr/common/telemetry.py +36 -30
  26. flwr/server/__init__.py +0 -4
  27. flwr/server/app.py +13 -13
  28. flwr/server/compat/app.py +0 -5
  29. flwr/server/driver/grpc_driver.py +1 -3
  30. flwr/server/run_serverapp.py +15 -1
  31. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
  32. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +11 -11
  33. flwr/server/superlink/state/in_memory_state.py +15 -15
  34. flwr/server/superlink/state/sqlite_state.py +10 -10
  35. flwr/server/superlink/state/state.py +8 -8
  36. flwr/simulation/run_simulation.py +23 -6
  37. flwr/superexec/__init__.py +0 -6
  38. flwr/superexec/app.py +3 -1
  39. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/METADATA +3 -3
  40. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/RECORD +43 -35
  41. flwr_nightly-1.12.0.dev20240906.dist-info/entry_points.txt +10 -0
  42. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
  43. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
  44. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
  45. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
  46. flwr_nightly-1.11.0.dev20240823.dist-info/entry_points.txt +0 -10
  47. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/LICENSE +0 -0
  48. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/WHEEL +0 -0
flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl}

@@ -1,20 +1,32 @@
 """$project_name: A Flower / FlowerTune app."""
 
+import os
+import warnings
 from collections import OrderedDict
-from typing import Callable, Dict, Tuple
+from typing import Dict, Tuple
 
 import torch
+from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
+from flwr.common.config import unflatten_dict
+from flwr.common.typing import NDArrays, Scalar
 from omegaconf import DictConfig
 from peft import get_peft_model_state_dict, set_peft_model_state_dict
 from transformers import TrainingArguments
 from trl import SFTTrainer
 
-from flwr.client import NumPyClient
-from flwr.common import Context
-from flwr.common.typing import NDArrays, Scalar
-from $import_name.dataset import reformat
+from $import_name.dataset import (
+    get_tokenizer_and_data_collator_and_propt_formatting,
+    load_data,
+    replace_keys,
+)
 from $import_name.models import cosine_annealing, get_model
 
+# Avoid warnings
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
+warnings.filterwarnings("ignore", category=UserWarning)
+
 
 # pylint: disable=too-many-arguments
 # pylint: disable=too-many-instance-attributes
@@ -29,7 +41,7 @@ class FlowerClient(NumPyClient):
         tokenizer,
         formatting_prompts_func,
         data_collator,
-        save_path,
+        num_rounds,
     ):  # pylint: disable=too-many-arguments
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.train_cfg = train_cfg
@@ -37,13 +49,12 @@ class FlowerClient(NumPyClient):
         self.tokenizer = tokenizer
         self.formatting_prompts_func = formatting_prompts_func
         self.data_collator = data_collator
-        self.save_path = save_path
+        self.num_rounds = num_rounds
+        self.trainset = trainset
 
         # instantiate model
         self.model = get_model(model_cfg)
 
-        self.trainset = trainset
-
     def fit(
         self, parameters: NDArrays, config: Dict[str, Scalar]
     ) -> Tuple[NDArrays, int, Dict]:
@@ -52,13 +63,13 @@ class FlowerClient(NumPyClient):
 
         new_lr = cosine_annealing(
            int(config["current_round"]),
-            self.train_cfg.num_rounds,
+            self.num_rounds,
            self.train_cfg.learning_rate_max,
            self.train_cfg.learning_rate_min,
        )
 
        self.training_argumnets.learning_rate = new_lr
-        self.training_argumnets.output_dir = self.save_path
+        self.training_argumnets.output_dir = config["save_path"]
 
        # Construct trainer
        trainer = SFTTrainer(
@@ -95,32 +106,31 @@ def get_parameters(model) -> NDArrays:
     return [val.cpu().numpy() for _, val in state_dict.items()]
 
 
-def gen_client_fn(
-    fds,
-    tokenizer,
-    formatting_prompts_func,
-    data_collator,
-    model_cfg: DictConfig,
-    train_cfg: DictConfig,
-    save_path: str,
-) -> Callable[[Context], FlowerClient]:  # pylint: disable=too-many-arguments
-    """Generate the client function that creates the Flower Clients."""
-
-    def client_fn(context: Context) -> FlowerClient:
-        """Create a Flower client representing a single organization."""
-        # Let's get the partition corresponding to the i-th client
-        partition_id = context.node_config["partition-id"]
-        client_trainset = fds.load_partition(partition_id, "train")
-        client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
-
-        return FlowerClient(
-            model_cfg,
-            train_cfg,
-            client_trainset,
-            tokenizer,
-            formatting_prompts_func,
-            data_collator,
-            save_path,
-        ).to_client()
-
-    return client_fn
+def client_fn(context: Context) -> FlowerClient:
+    """Create a Flower client representing a single organization."""
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    num_rounds = context.run_config["num-server-rounds"]
+    cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
+
+    # Let's get the client partition
+    client_trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
+    (
+        tokenizer,
+        data_collator,
+        formatting_prompts_func,
+    ) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
+
+    return FlowerClient(
+        cfg.model,
+        cfg.train,
+        client_trainset,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        num_rounds,
+    ).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(client_fn)
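Note on the new client_fn: it builds its entire configuration from context.run_config via unflatten_dict (which nests dotted keys) and replace_keys (defined in dataset.py.tpl below). A minimal sketch of that pipeline, with illustrative config keys rather than the template's actual ones:

    from flwr.common.config import unflatten_dict

    run_config = {"model.name": "some/base-model", "train.learning-rate-max": 5e-5}
    nested = unflatten_dict(run_config)
    # {'model': {'name': 'some/base-model'}, 'train': {'learning-rate-max': 5e-05}}
    # replace_keys then rewrites 'learning-rate-max' to 'learning_rate_max',
    # so DictConfig attribute access like cfg.train.learning_rate_max works.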
flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl

@@ -1,8 +1,12 @@
 """$project_name: A Flower / FlowerTune app."""
 
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
 from transformers import AutoTokenizer
 from trl import DataCollatorForCompletionOnlyLM
 
+FDS = None  # Cache FederatedDataset
+
 
 def formatting_prompts_func(example):
     """Construct prompts."""
@@ -24,7 +28,6 @@ def formatting_prompts_func(example):
 
 def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
     """Get tokenizer, data_collator and prompt formatting."""
-    # From: https://huggingface.co/docs/trl/en/sft_trainer
     tokenizer = AutoTokenizer.from_pretrained(
         model_name, use_fast=True, padding_side="right"
     )
@@ -49,9 +52,36 @@ def formatting(dataset):
 def reformat(dataset, llm_task):
     """Reformat datasets."""
     dataset = dataset.rename_column("output", "response")
-    if llm_task == "finance" or llm_task == "code":
+    if llm_task in ["finance", "code"]:
         dataset = dataset.map(formatting, remove_columns=["input"])
     if llm_task == "medical":
         dataset = dataset.remove_columns(["instruction"])
         dataset = dataset.rename_column("input", "instruction")
     return dataset
+
+
+def load_data(partition_id: int, num_partitions: int, dataset_name: str):
+    """Load partition data."""
+    # Only initialize `FederatedDataset` once
+    global FDS
+    if FDS is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        FDS = FederatedDataset(
+            dataset=dataset_name,
+            partitioners={"train": partitioner},
+        )
+    client_trainset = FDS.load_partition(partition_id, "train")
+    client_trainset = reformat(client_trainset, llm_task="generalnlp")
+    return client_trainset
+
+
+def replace_keys(input_dict, match="-", target="_"):
+    """Recursively replace match string with target string in dictionary keys."""
+    new_dict = {}
+    for key, value in input_dict.items():
+        new_key = key.replace(match, target)
+        if isinstance(value, dict):
+            new_dict[new_key] = replace_keys(value, match, target)
+        else:
+            new_dict[new_key] = value
+    return new_dict
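As a quick check of replace_keys above, hyphens are rewritten at every nesting level (keys here are illustrative, not the template's actual config):

    cfg = {"train": {"learning-rate-max": 5e-5}, "model": {"lora": {"peft-lora-r": 8}}}
    replace_keys(cfg)
    # {'train': {'learning_rate_max': 5e-05}, 'model': {'lora': {'peft_lora_r': 8}}}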
flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl

@@ -22,9 +22,6 @@ def cosine_annealing(
 
 def get_model(model_cfg: DictConfig):
     """Load model with appropriate quantization config and other optimizations.
-
-    Please refer to this example for `peft + BitsAndBytes`:
-    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
     """
     if model_cfg.quantization == 4:
         quantization_config = BitsAndBytesConfig(load_in_4bit=True)
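The cosine_annealing helper named in the hunk header is not shown in this diff; for orientation, a minimal sketch of the standard cosine schedule it is assumed to implement (not the template's verbatim code):

    import math

    def cosine_annealing(current_round, total_rounds, lr_max=0.001, lr_min=0.0):
        # Decay smoothly from lr_max at round 0 to lr_min at the final round
        cos_inner = math.pi * current_round / total_rounds
        return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(cos_inner))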
flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl (new file)

@@ -0,0 +1,95 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+import os
+from datetime import datetime
+
+from flwr.common import Context, ndarrays_to_parameters
+from flwr.common.config import unflatten_dict
+from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+from omegaconf import DictConfig
+
+from $import_name.client_app import get_parameters, set_parameters
+from $import_name.models import get_model
+from $import_name.dataset import replace_keys
+from $import_name.strategy import FlowerTuneLlm
+
+
+# Get function that will be executed by the strategy's evaluate() method
+# Here we use it to save global model checkpoints
+def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
+    """Return an evaluation function for saving global model."""
+
+    def evaluate(server_round: int, parameters, config):
+        # Save model
+        if server_round != 0 and (
+            server_round == total_round or server_round % save_every_round == 0
+        ):
+            # Init model
+            model = get_model(model_cfg)
+            set_parameters(model, parameters)
+
+            model.save_pretrained(f"{save_path}/peft_{server_round}")
+
+        return 0.0, {}
+
+    return evaluate
+
+
+def get_on_fit_config(save_path):
+    """Return a function that will be used to construct the config that the
+    client's fit() method will receive."""
+
+    def fit_config_fn(server_round: int):
+        fit_config = {}
+        fit_config["current_round"] = server_round
+        fit_config["save_path"] = save_path
+        return fit_config
+
+    return fit_config_fn
+
+
+def fit_weighted_average(metrics):
+    """Aggregate (federated) fit metrics."""
+    # Multiply train loss of each client by number of examples used
+    losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {"train_loss": sum(losses) / sum(examples)}
+
+
+def server_fn(context: Context):
+    """Construct components that set the ServerApp behaviour."""
+    # Create output directory given current timestamp
+    current_time = datetime.now()
+    folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
+    save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
+    os.makedirs(save_path, exist_ok=True)
+
+    # Read from config
+    num_rounds = context.run_config["num-server-rounds"]
+    cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
+
+    # Get initial model weights
+    init_model = get_model(cfg.model)
+    init_model_parameters = get_parameters(init_model)
+    init_model_parameters = ndarrays_to_parameters(init_model_parameters)
+
+    # Define strategy
+    strategy = FlowerTuneLlm(
+        fraction_fit=cfg.strategy.fraction_fit,
+        fraction_evaluate=cfg.strategy.fraction_evaluate,
+        on_fit_config_fn=get_on_fit_config(save_path),
+        fit_metrics_aggregation_fn=fit_weighted_average,
+        initial_parameters=init_model_parameters,
+        evaluate_fn=get_evaluate_fn(
+            cfg.model, cfg.train.save_every_round, num_rounds, save_path
+        ),
+    )
+    config = ServerConfig(num_rounds=num_rounds)
+
+    return ServerAppComponents(strategy=strategy, config=config)
+
+
+# Flower ServerApp
+app = ServerApp(server_fn=server_fn)
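The save_path created in server_fn reaches clients through get_on_fit_config; each round's fit config is a plain dict, e.g. (the path shown is hypothetical):

    fit_config_fn = get_on_fit_config("results/2024-09-06_12-00-00")
    fit_config_fn(3)
    # {'current_round': 3, 'save_path': 'results/2024-09-06_12-00-00'}

which is how FlowerClient.fit above can read config["save_path"] instead of taking a constructor argument.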
flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl (new file)

@@ -0,0 +1,83 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from io import BytesIO
+from logging import INFO, WARN
+from typing import List, Tuple, Union
+
+from flwr.common import FitIns, FitRes, Parameters, log, parameters_to_ndarrays
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy import FedAvg
+
+
+class FlowerTuneLlm(FedAvg):
+    """Customised FedAvg strategy implementation.
+
+    This class behaves just like FedAvg but also tracks the communication
+    costs associated with `fit` over FL rounds.
+    """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.comm_tracker = CommunicationTracker()
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ):
+        """Configure the next round of training."""
+        return_clients = super().configure_fit(server_round, parameters, client_manager)
+
+        # Track communication costs
+        fit_ins_list = [fit_ins for _, fit_ins in return_clients]
+        self.comm_tracker.track(fit_ins_list)
+
+        return return_clients
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ):
+        """Aggregate fit results using weighted average."""
+        # Track communication costs
+        fit_res_list = [fit_res for _, fit_res in results]
+        self.comm_tracker.track(fit_res_list)
+
+        parameters_aggregated, metrics_aggregated = super().aggregate_fit(
+            server_round, results, failures
+        )
+
+        return parameters_aggregated, metrics_aggregated
+
+
+class CommunicationTracker:
+    """Communication costs tracker over FL rounds."""
+    def __init__(self):
+        self.curr_comm_cost = 0.0
+
+    @staticmethod
+    def _compute_bytes(parameters):
+        return sum([BytesIO(t).getbuffer().nbytes for t in parameters.tensors])
+
+    def track(self, fit_list: List[Union[FitIns, FitRes]]):
+        size_bytes_list = [
+            self._compute_bytes(fit_ele.parameters)
+            for fit_ele in fit_list
+        ]
+        comm_cost = sum(size_bytes_list) / 1024**2
+
+        self.curr_comm_cost += comm_cost
+        log(
+            INFO,
+            "Communication budget: used %.2f MB (+%.2f MB this round) / 200,000 MB",
+            self.curr_comm_cost,
+            comm_cost,
+        )
+
+        if self.curr_comm_cost > 2e5:
+            log(
+                WARN,
+                "The accumulated communication cost has exceeded 200,000 MB. "
+                "Please consider reducing it if you plan to participate in the "
+                "FlowerTune LLM Leaderboard.",
+            )
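A sketch of what the tracker measures, using a synthetic Parameters payload of two 1 MiB tensors (values are illustrative):

    from flwr.common import Parameters

    params = Parameters(
        tensors=[bytes(1024 * 1024), bytes(1024 * 1024)],
        tensor_type="numpy.ndarray",
    )
    CommunicationTracker._compute_bytes(params)  # 2097152 bytes
    # track() would then add 2097152 / 1024**2 = 2.00 MB to curr_comm_cost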
flwr/cli/new/templates/app/code/model.baseline.py.tpl (new file)

@@ -0,0 +1,80 @@
+"""$project_name: A Flower Baseline."""
+
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Net(nn.Module):
+    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')."""
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        """Do forward."""
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+def train(net, trainloader, epochs, device):
+    """Train the model on the training set."""
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss()
+    criterion.to(device)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
+    net.train()
+    running_loss = 0.0
+    for _ in range(epochs):
+        for batch in trainloader:
+            images = batch["img"]
+            labels = batch["label"]
+            optimizer.zero_grad()
+            loss = criterion(net(images.to(device)), labels.to(device))
+            loss.backward()
+            optimizer.step()
+            running_loss += loss.item()
+
+    avg_trainloss = running_loss / len(trainloader)
+    return avg_trainloss
+
+
+def test(net, testloader, device):
+    """Validate the model on the test set."""
+    net.to(device)
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for batch in testloader:
+            images = batch["img"].to(device)
+            labels = batch["label"].to(device)
+            outputs = net(images)
+            loss += criterion(outputs, labels).item()
+            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+    accuracy = correct / len(testloader.dataset)
+    loss = loss / len(testloader)
+    return loss, accuracy
+
+
+def get_weights(net):
+    """Extract model parameters as numpy arrays from state_dict."""
+    return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_weights(net, parameters):
+    """Apply parameters to an existing model."""
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+    net.load_state_dict(state_dict, strict=True)
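get_weights and set_weights are each other's inverse; a minimal round trip, run in the context of the module above (so Net, get_weights, set_weights, and torch are in scope):

    net_a, net_b = Net(), Net()
    set_weights(net_b, get_weights(net_a))  # net_b now mirrors net_a
    assert all(
        torch.equal(a, b)
        for a, b in zip(net_a.state_dict().values(), net_b.state_dict().values())
    )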
flwr/cli/new/templates/app/code/server.baseline.py.tpl (new file)

@@ -0,0 +1,46 @@
+"""$project_name: A Flower Baseline."""
+
+from typing import List, Tuple
+
+from flwr.common import Context, Metrics, ndarrays_to_parameters
+from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+from flwr.server.strategy import FedAvg
+from $import_name.model import Net, get_weights
+
+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    """Do weighted average of accuracy metric."""
+    # Multiply accuracy of each client by number of examples used
+    accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {"accuracy": sum(accuracies) / sum(examples)}
+
+
+def server_fn(context: Context):
+    """Construct components that set the ServerApp behaviour."""
+    # Read from config
+    num_rounds = context.run_config["num-server-rounds"]
+    fraction_fit = context.run_config["fraction-fit"]
+
+    # Initialize model parameters
+    ndarrays = get_weights(Net())
+    parameters = ndarrays_to_parameters(ndarrays)
+
+    # Define strategy
+    strategy = FedAvg(
+        fraction_fit=float(fraction_fit),
+        fraction_evaluate=1.0,
+        min_available_clients=2,
+        initial_parameters=parameters,
+        evaluate_metrics_aggregation_fn=weighted_average,
+    )
+    config = ServerConfig(num_rounds=int(num_rounds))
+
+    return ServerAppComponents(strategy=strategy, config=config)
+
+
+# Create ServerApp
+app = ServerApp(server_fn=server_fn)
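To make weighted_average concrete, a worked example with two hypothetical clients holding 100 and 300 evaluation examples:

    metrics = [(100, {"accuracy": 0.9}), (300, {"accuracy": 0.5})]
    weighted_average(metrics)
    # (100 * 0.9 + 300 * 0.5) / (100 + 300) = 240 / 400
    # -> {'accuracy': 0.6}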
flwr/cli/new/templates/app/code/strategy.baseline.py.tpl (new file)

@@ -0,0 +1 @@
+"""$project_name: A Flower Baseline."""
flwr/cli/new/templates/app/code/utils.baseline.py.tpl (new file)

@@ -0,0 +1 @@
+"""$project_name: A Flower Baseline."""
flwr/cli/new/templates/app/pyproject.baseline.toml.tpl (new file)

@@ -0,0 +1,138 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "$package_name"
+version = "1.0.0"
+description = ""
+license = "Apache-2.0"
+dependencies = [
+    "flwr[simulation]>=1.11.0",
+    "flwr-datasets[vision]>=0.3.0",
+    "torch==2.2.1",
+    "torchvision==0.17.1",
+]
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[project.optional-dependencies]
+dev = [
+    "isort==5.13.2",
+    "black==24.2.0",
+    "docformatter==1.7.5",
+    "mypy==1.8.0",
+    "pylint==3.2.6",
+    "flake8==5.0.4",
+    "pytest==6.2.4",
+    "pytest-watch==4.2.0",
+    "ruff==0.1.9",
+    "types-requests==2.31.0.20240125",
+]
+
+[tool.isort]
+profile = "black"
+known_first_party = ["flwr"]
+
+[tool.black]
+line-length = 88
+target-version = ["py38", "py39", "py310", "py311"]
+
+[tool.pytest.ini_options]
+minversion = "6.2"
+addopts = "-qq"
+testpaths = [
+    "flwr_baselines",
+]
+
+[tool.mypy]
+ignore_missing_imports = true
+strict = false
+plugins = "numpy.typing.mypy_plugin"
+
+[tool.pylint."MESSAGES CONTROL"]
+disable = "duplicate-code,too-few-public-methods,useless-import-alias"
+good-names = "i,j,k,_,x,y,X,Y,K,N"
+max-args = 10
+max-attributes = 15
+max-locals = 36
+max-branches = 20
+max-statements = 55
+
+[tool.pylint.typecheck]
+generated-members = "numpy.*, torch.*, tensorflow.*"
+
+[[tool.mypy.overrides]]
+module = [
+    "importlib.metadata.*",
+    "importlib_metadata.*",
+]
+follow_imports = "skip"
+follow_imports_for_stubs = true
+disallow_untyped_calls = false
+
+[[tool.mypy.overrides]]
+module = "torch.*"
+follow_imports = "skip"
+follow_imports_for_stubs = true
+
+[tool.docformatter]
+wrap-summaries = 88
+wrap-descriptions = 88
+
+[tool.ruff]
+target-version = "py38"
+line-length = 88
+select = ["D", "E", "F", "W", "B", "ISC", "C4"]
+fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
+ignore = ["B024", "B027"]
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".hg",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "venv",
+    "proto",
+]
+
+[tool.ruff.pydocstyle]
+convention = "numpy"
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "$username"
+
+[tool.flwr.app.components]
+serverapp = "$import_name.server_app:app"
+clientapp = "$import_name.client_app:app"
+
+[tool.flwr.app.config]
+num-server-rounds = 3
+fraction-fit = 0.5
+local-epochs = 1
+
+[tool.flwr.federations]
+default = "local-simulation"
+
+[tool.flwr.federations.local-simulation]
+options.num-supernodes = 10
+options.backend.client-resources.num-cpus = 2
+options.backend.client-resources.num-gpus = 0.0
+ options.backend.client-resources.num-gpus = 0.0