flwr-nightly 1.9.0.dev20240506-py3-none-any.whl → 1.9.0.dev20240508-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

flwr/cli/new/new.py CHANGED
@@ -37,6 +37,8 @@ class MlFramework(str, Enum):
     NUMPY = "NumPy"
     PYTORCH = "PyTorch"
     TENSORFLOW = "TensorFlow"
+    HUGGINGFACE = "HF"
+    MLX = "MLX"
     SKLEARN = "sklearn"


@@ -111,7 +113,7 @@ def new(
     else:
         framework_value = prompt_options(
             "Please select ML framework by typing in the number",
-            [mlf.value for mlf in MlFramework],
+            sorted([mlf.value for mlf in MlFramework]),
         )
         selected_value = [
             name
@@ -153,6 +155,8 @@ def new(
     # Depending on the framework, generate task.py file
     frameworks_with_tasks = [
         MlFramework.PYTORCH.value.lower(),
+        MlFramework.HUGGINGFACE.value.lower(),
+        MlFramework.MLX.value.lower(),
         MlFramework.TENSORFLOW.value.lower(),
     ]
     if framework_str in frameworks_with_tasks:
flwr/cli/new/templates/app/code/client.hf.py.tpl ADDED
@@ -0,0 +1,55 @@
+"""$project_name: A Flower / HuggingFace Transformers app."""
+
+from flwr.client import ClientApp, NumPyClient
+from transformers import AutoModelForSequenceClassification
+
+from $import_name.task import (
+    get_weights,
+    load_data,
+    set_weights,
+    train,
+    test,
+    CHECKPOINT,
+    DEVICE,
+)
+
+
+# Flower client
+class FlowerClient(NumPyClient):
+    def __init__(self, net, trainloader, testloader):
+        self.net = net
+        self.trainloader = trainloader
+        self.testloader = testloader
+
+    def get_parameters(self, config):
+        return get_weights(self.net)
+
+    def set_parameters(self, parameters):
+        set_weights(self.net, parameters)
+
+    def fit(self, parameters, config):
+        self.set_parameters(parameters)
+        train(self.net, self.trainloader, epochs=1)
+        return self.get_parameters(config={}), len(self.trainloader), {}
+
+    def evaluate(self, parameters, config):
+        self.set_parameters(parameters)
+        loss, accuracy = test(self.net, self.testloader)
+        return float(loss), len(self.testloader), {"accuracy": accuracy}
+
+
+def client_fn(cid):
+    # Load model and data
+    net = AutoModelForSequenceClassification.from_pretrained(
+        CHECKPOINT, num_labels=2
+    ).to(DEVICE)
+    trainloader, valloader = load_data(int(cid), 2)
+
+    # Return Client instance
+    return FlowerClient(net, trainloader, valloader).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(
+    client_fn,
+)
flwr/cli/new/templates/app/code/client.mlx.py.tpl ADDED
@@ -0,0 +1,70 @@
+"""$project_name: A Flower / MLX app."""
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+from flwr.client import NumPyClient, ClientApp
+
+from $import_name.task import (
+    batch_iterate,
+    eval_fn,
+    get_params,
+    load_data,
+    loss_fn,
+    set_params,
+    MLP,
+)
+
+
+# Define Flower Client and client_fn
+class FlowerClient(NumPyClient):
+    def __init__(self, data):
+        num_layers = 2
+        hidden_dim = 32
+        num_classes = 10
+        batch_size = 256
+        num_epochs = 1
+        learning_rate = 1e-1
+
+        self.train_images, self.train_labels, self.test_images, self.test_labels = data
+        self.model = MLP(num_layers, self.train_images.shape[-1], hidden_dim, num_classes)
+        self.optimizer = optim.SGD(learning_rate=learning_rate)
+        self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
+
+    def get_parameters(self, config):
+        return get_params(self.model)
+
+    def set_parameters(self, parameters):
+        set_params(self.model, parameters)
+
+    def fit(self, parameters, config):
+        self.set_parameters(parameters)
+        for _ in range(self.num_epochs):
+            for X, y in batch_iterate(
+                self.batch_size, self.train_images, self.train_labels
+            ):
+                _, grads = self.loss_and_grad_fn(self.model, X, y)
+                self.optimizer.update(self.model, grads)
+                mx.eval(self.model.parameters(), self.optimizer.state)
+        return self.get_parameters(config={}), len(self.train_images), {}
+
+    def evaluate(self, parameters, config):
+        self.set_parameters(parameters)
+        accuracy = eval_fn(self.model, self.test_images, self.test_labels)
+        loss = loss_fn(self.model, self.test_images, self.test_labels)
+        return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
+
+
+def client_fn(cid):
+    data = load_data(int(cid), 2)
+
+    # Return Client instance
+    return FlowerClient(data).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(
+    client_fn,
+)
flwr/cli/new/templates/app/code/server.hf.py.tpl ADDED
@@ -0,0 +1,17 @@
+"""$project_name: A Flower / HuggingFace Transformers app."""
+
+from flwr.server.strategy import FedAvg
+from flwr.server import ServerApp, ServerConfig
+
+
+# Define strategy
+strategy = FedAvg(
+    fraction_fit=1.0,
+    fraction_evaluate=1.0,
+)
+
+# Start server
+app = ServerApp(
+    config=ServerConfig(num_rounds=3),
+    strategy=strategy,
+)
flwr/cli/new/templates/app/code/server.mlx.py.tpl ADDED
@@ -0,0 +1,15 @@
+"""$project_name: A Flower / MLX app."""
+
+from flwr.server import ServerApp, ServerConfig
+from flwr.server.strategy import FedAvg
+
+
+# Define strategy
+strategy = FedAvg()
+
+
+# Create ServerApp
+app = ServerApp(
+    config=ServerConfig(num_rounds=3),
+    strategy=strategy,
+)
flwr/cli/new/templates/app/code/task.hf.py.tpl ADDED
@@ -0,0 +1,87 @@
+"""$project_name: A Flower / HuggingFace Transformers app."""
+
+import warnings
+from collections import OrderedDict
+
+import torch
+from evaluate import load as load_metric
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer, DataCollatorWithPadding
+
+from flwr_datasets import FederatedDataset
+
+warnings.filterwarnings("ignore", category=UserWarning)
+DEVICE = torch.device("cpu")
+CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint
+
+
+def load_data(partition_id, num_clients):
+    """Load IMDB data (training and eval)"""
+    fds = FederatedDataset(dataset="imdb", partitioners={"train": num_clients})
+    partition = fds.load_partition(partition_id)
+    # Divide data: 80% train, 20% test
+    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
+
+    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
+
+    def tokenize_function(examples):
+        return tokenizer(examples["text"], truncation=True)
+
+    partition_train_test = partition_train_test.map(tokenize_function, batched=True)
+    partition_train_test = partition_train_test.remove_columns("text")
+    partition_train_test = partition_train_test.rename_column("label", "labels")
+
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    trainloader = DataLoader(
+        partition_train_test["train"],
+        shuffle=True,
+        batch_size=32,
+        collate_fn=data_collator,
+    )
+
+    testloader = DataLoader(
+        partition_train_test["test"], batch_size=32, collate_fn=data_collator
+    )
+
+    return trainloader, testloader
+
+
+def train(net, trainloader, epochs):
+    optimizer = AdamW(net.parameters(), lr=5e-5)
+    net.train()
+    for _ in range(epochs):
+        for batch in trainloader:
+            batch = {k: v.to(DEVICE) for k, v in batch.items()}
+            outputs = net(**batch)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+
+def test(net, testloader):
+    metric = load_metric("accuracy")
+    loss = 0
+    net.eval()
+    for batch in testloader:
+        batch = {k: v.to(DEVICE) for k, v in batch.items()}
+        with torch.no_grad():
+            outputs = net(**batch)
+        logits = outputs.logits
+        loss += outputs.loss.item()
+        predictions = torch.argmax(logits, dim=-1)
+        metric.add_batch(predictions=predictions, references=batch["labels"])
+    loss /= len(testloader.dataset)
+    accuracy = metric.compute()["accuracy"]
+    return loss, accuracy
+
+
+def get_weights(net):
+    return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_weights(net, parameters):
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+    net.load_state_dict(state_dict, strict=True)
flwr/cli/new/templates/app/code/task.mlx.py.tpl ADDED
@@ -0,0 +1,89 @@
+"""$project_name: A Flower / MLX app."""
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from datasets.utils.logging import disable_progress_bar
+from flwr_datasets import FederatedDataset
+
+
+disable_progress_bar()
+
+class MLP(nn.Module):
+    """A simple MLP."""
+
+    def __init__(
+        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
+    ):
+        super().__init__()
+        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
+        self.layers = [
+            nn.Linear(idim, odim)
+            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
+        ]
+
+    def __call__(self, x):
+        for l in self.layers[:-1]:
+            x = mx.maximum(l(x), 0.0)
+        return self.layers[-1](x)
+
+
+def loss_fn(model, X, y):
+    return mx.mean(nn.losses.cross_entropy(model(X), y))
+
+
+def eval_fn(model, X, y):
+    return mx.mean(mx.argmax(model(X), axis=1) == y)
+
+
+def batch_iterate(batch_size, X, y):
+    perm = mx.array(np.random.permutation(y.size))
+    for s in range(0, y.size, batch_size):
+        ids = perm[s : s + batch_size]
+        yield X[ids], y[ids]
+
+
+def load_data(partition_id, num_clients):
+    fds = FederatedDataset(dataset="mnist", partitioners={"train": num_clients})
+    partition = fds.load_partition(partition_id)
+    partition_splits = partition.train_test_split(test_size=0.2, seed=42)
+
+    partition_splits["train"].set_format("numpy")
+    partition_splits["test"].set_format("numpy")
+
+    train_partition = partition_splits["train"].map(
+        lambda img: {
+            "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
+        },
+        input_columns="image",
+    )
+    test_partition = partition_splits["test"].map(
+        lambda img: {
+            "img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
+        },
+        input_columns="image",
+    )
+
+    data = (
+        train_partition["img"],
+        train_partition["label"].astype(np.uint32),
+        test_partition["img"],
+        test_partition["label"].astype(np.uint32),
+    )
+
+    train_images, train_labels, test_images, test_labels = map(mx.array, data)
+    return train_images, train_labels, test_images, test_labels
+
+
+def get_params(model):
+    layers = model.parameters()["layers"]
+    return [np.array(val) for layer in layers for _, val in layer.items()]
+
+
+def set_params(model, parameters):
+    new_params = {}
+    new_params["layers"] = [
+        {"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])}
+        for i in range(0, len(parameters), 2)
+    ]
+    model.update(new_params)
flwr/cli/new/templates/app/pyproject.hf.toml.tpl ADDED
@@ -0,0 +1,31 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "$package_name"
+version = "1.0.0"
+description = ""
+authors = [
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = { text = "Apache License (2.0)" }
+dependencies = [
+    "flwr[simulation]>=1.8.0,<2.0",
+    "flwr-datasets>=0.0.2,<1.0.0",
+    "torch==2.2.1",
+    "transformers>=4.30.0,<5.0",
+    "evaluate>=0.4.0,<1.0",
+    "datasets>=2.0.0, <3.0",
+    "scikit-learn>=1.3.1, <2.0",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[flower]
+publisher = "$username"
+
+[flower.components]
+serverapp = "$import_name.server:app"
+clientapp = "$import_name.client:app"
flwr/cli/new/templates/app/pyproject.mlx.toml.tpl ADDED
@@ -0,0 +1,28 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "$package_name"
+version = "1.0.0"
+description = ""
+authors = [
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = { text = "Apache License (2.0)" }
+dependencies = [
+    "flwr[simulation]>=1.8.0,<2.0",
+    "flwr-datasets[vision]>=0.0.2,<1.0.0",
+    "mlx==0.10.0",
+    "numpy==1.24.4",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[flower]
+publisher = "$username"
+
+[flower.components]
+serverapp = "$import_name.server:app"
+clientapp = "$import_name.client:app"
flwr/server/__init__.py CHANGED
@@ -24,7 +24,6 @@ from .app import start_server as start_server
 from .client_manager import ClientManager as ClientManager
 from .client_manager import SimpleClientManager as SimpleClientManager
 from .compat import LegacyContext as LegacyContext
-from .compat import start_driver as start_driver
 from .driver import Driver as Driver
 from .history import History as History
 from .run_serverapp import run_server_app as run_server_app
@@ -45,7 +44,6 @@ __all__ = [
     "ServerApp",
     "ServerConfig",
     "SimpleClientManager",
-    "start_driver",
     "start_server",
     "strategy",
     "workflow",
flwr/server/app.py CHANGED
@@ -41,7 +41,7 @@ from flwr.common.constant import (
     TRANSPORT_TYPE_VCE,
 )
 from flwr.common.exit_handlers import register_exit_handlers
-from flwr.common.logger import log
+from flwr.common.logger import log, warn_deprecated_feature
 from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
     private_key_to_bytes,
     public_key_to_bytes,
@@ -196,6 +196,9 @@ def start_server(  # pylint: disable=too-many-arguments,too-many-locals
 def run_driver_api() -> None:
     """Run Flower server (Driver API)."""
     log(INFO, "Starting Flower server (Driver API)")
+    # Running `flower-driver-api` is deprecated
+    warn_deprecated_feature("flower-driver-api")
+    log(WARN, "Use `flower-superlink` instead")
     event(EventType.RUN_DRIVER_API_ENTER)
     args = _parse_args_run_driver_api().parse_args()

@@ -233,6 +236,9 @@ def run_driver_api() -> None:
 def run_fleet_api() -> None:
     """Run Flower server (Fleet API)."""
     log(INFO, "Starting Flower server (Fleet API)")
+    # Running `flower-fleet-api` is deprecated
+    warn_deprecated_feature("flower-fleet-api")
+    log(WARN, "Use `flower-superlink` instead")
     event(EventType.RUN_FLEET_API_ENTER)
     args = _parse_args_run_fleet_api().parse_args()

flwr/server/compat/app.py CHANGED
@@ -15,50 +15,35 @@
 """Flower driver app."""


-import sys
 from logging import INFO
-from pathlib import Path
-from typing import Optional, Union
+from typing import Optional

 from flwr.common import EventType, event
-from flwr.common.address import parse_address
-from flwr.common.logger import log, warn_deprecated_feature
+from flwr.common.logger import log
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History
 from flwr.server.server import Server, init_defaults, run_fl
 from flwr.server.server_config import ServerConfig
 from flwr.server.strategy import Strategy

-from ..driver import Driver, GrpcDriver
+from ..driver import Driver
 from .app_utils import start_update_client_manager_thread

-DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
-
-ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
-[Driver] Error: Not connected.
-
-Call `connect()` on the `Driver` instance before calling any of the other `Driver`
-methods.
-"""
-

 def start_driver(  # pylint: disable=too-many-arguments, too-many-locals
     *,
-    server_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
+    driver: Driver,
     server: Optional[Server] = None,
     config: Optional[ServerConfig] = None,
     strategy: Optional[Strategy] = None,
     client_manager: Optional[ClientManager] = None,
-    root_certificates: Optional[Union[bytes, str]] = None,
-    driver: Optional[Driver] = None,
 ) -> History:
     """Start a Flower Driver API server.

     Parameters
     ----------
-    server_address : Optional[str]
-        The IPv4 or IPv6 address of the Driver API server.
-        Defaults to `"[::]:8080"`.
+    driver : Driver
+        The Driver object to use.
     server : Optional[flwr.server.Server] (default: None)
         A server implementation, either `flwr.server.Server` or a subclass
         thereof. If no instance is provided, then `start_driver` will create
@@ -74,50 +59,14 @@ def start_driver(  # pylint: disable=too-many-arguments, too-many-locals
         An implementation of the class `flwr.server.ClientManager`. If no
         implementation is provided, then `start_driver` will use
         `flwr.server.SimpleClientManager`.
-    root_certificates : Optional[Union[bytes, str]] (default: None)
-        The PEM-encoded root certificates as a byte string or a path string.
-        If provided, a secure connection using the certificates will be
-        established to an SSL-enabled Flower server.
-    driver : Optional[Driver] (default: None)
-        The Driver object to use.

     Returns
     -------
     hist : flwr.server.history.History
         Object containing training and evaluation metrics.
-
-    Examples
-    --------
-    Starting a driver that connects to an insecure server:
-
-    >>> start_driver()
-
-    Starting a driver that connects to an SSL-enabled server:
-
-    >>> start_driver(
-    >>>     root_certificates=Path("/crts/root.pem").read_bytes()
-    >>> )
     """
     event(EventType.START_DRIVER_ENTER)

-    if driver is None:
-        # Not passing a `Driver` object is deprecated
-        warn_deprecated_feature("start_driver")
-
-        # Parse IP address
-        parsed_address = parse_address(server_address)
-        if not parsed_address:
-            sys.exit(f"Server IP address ({server_address}) cannot be parsed.")
-        host, port, is_v6 = parsed_address
-        address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
-
-        # Create the Driver
-        if isinstance(root_certificates, str):
-            root_certificates = Path(root_certificates).read_bytes()
-        driver = GrpcDriver(
-            driver_service_address=address, root_certificates=root_certificates
-        )
-
     # Initialize the Driver API server and config
     initialized_server, initialized_config = init_defaults(
         server=server,
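
Note on the change above: `start_driver` now requires an explicit `Driver` instance instead of building a `GrpcDriver` from a `server_address`. The following is a minimal sketch of the new calling convention, assembled from the imports and the `GrpcDriver(driver_service_address=...)` call shown in the removed code; the address is illustrative only, and a running SuperLink is assumed.

from flwr.server import ServerConfig
from flwr.server.compat import start_driver
from flwr.server.driver import GrpcDriver

# Construct the Driver explicitly (previously start_driver built one internally
# from server_address); point it at the Driver API of a running SuperLink.
driver = GrpcDriver(driver_service_address="127.0.0.1:9091")

# start_driver now takes the driver keyword argument; the other keyword
# arguments keep their previous meaning.
hist = start_driver(
    driver=driver,
    config=ServerConfig(num_rounds=3),
)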
flwr/server/history.py CHANGED
@@ -91,32 +91,32 @@ class History:
         """
         rep = ""
         if self.losses_distributed:
-            rep += "History (loss, distributed):\n" + pprint.pformat(
-                reduce(
-                    lambda a, b: a + b,
-                    [
-                        f"\tround {server_round}: {loss}\n"
-                        for server_round, loss in self.losses_distributed
-                    ],
-                )
+            rep += "History (loss, distributed):\n" + reduce(
+                lambda a, b: a + b,
+                [
+                    f"\tround {server_round}: {loss}\n"
+                    for server_round, loss in self.losses_distributed
+                ],
             )
         if self.losses_centralized:
-            rep += "History (loss, centralized):\n" + pprint.pformat(
-                reduce(
-                    lambda a, b: a + b,
-                    [
-                        f"\tround {server_round}: {loss}\n"
-                        for server_round, loss in self.losses_centralized
-                    ],
-                )
+            rep += "History (loss, centralized):\n" + reduce(
+                lambda a, b: a + b,
+                [
+                    f"\tround {server_round}: {loss}\n"
+                    for server_round, loss in self.losses_centralized
+                ],
             )
         if self.metrics_distributed_fit:
-            rep += "History (metrics, distributed, fit):\n" + pprint.pformat(
-                self.metrics_distributed_fit
+            rep += (
+                "History (metrics, distributed, fit):\n"
+                + pprint.pformat(self.metrics_distributed_fit)
+                + "\n"
             )
         if self.metrics_distributed:
-            rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat(
-                self.metrics_distributed
+            rep += (
+                "History (metrics, distributed, evaluate):\n"
+                + pprint.pformat(self.metrics_distributed)
+                + "\n"
             )
         if self.metrics_centralized:
             rep += "History (metrics, centralized):\n" + pprint.pformat(
flwr/server/server.py CHANGED
@@ -487,11 +487,8 @@ def run_fl(
     log(INFO, "")
     log(INFO, "[SUMMARY]")
     log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
-    for idx, line in enumerate(io.StringIO(str(hist))):
-        if idx == 0:
-            log(INFO, "%s", line.strip("\n"))
-        else:
-            log(INFO, "\t%s", line.strip("\n"))
+    for line in io.StringIO(str(hist)):
+        log(INFO, "\t%s", line.strip("\n"))
     log(INFO, "")

     # Graceful shutdown
flwr/server/superlink/fleet/vce/backend/raybackend.py CHANGED
@@ -15,7 +15,7 @@
 """Ray backend for the Fleet API using the Simulation Engine."""

 import pathlib
-from logging import ERROR, INFO
+from logging import DEBUG, ERROR, INFO, WARNING
 from typing import Callable, Dict, List, Tuple, Union

 import ray
@@ -46,7 +46,7 @@ class RayBackend(Backend):
     ) -> None:
         """Prepare RayBackend by initialising Ray and creating the ActorPool."""
         log(INFO, "Initialising: %s", self.__class__.__name__)
-        log(INFO, "Backend config: %s", backend_config)
+        log(DEBUG, "Backend config: %s", backend_config)

         if not pathlib.Path(work_dir).exists():
             raise ValueError(f"Specified work_dir {work_dir} does not exist.")
@@ -55,7 +55,10 @@ class RayBackend(Backend):
         runtime_env = (
             self._configure_runtime_env(work_dir=work_dir) if work_dir else None
         )
-        init_ray(runtime_env=runtime_env)
+        if backend_config.get("silent", False):
+            init_ray(logging_level=WARNING, log_to_driver=True, runtime_env=runtime_env)
+        else:
+            init_ray(runtime_env=runtime_env)

         # Validate client resources
         self.client_resources_key = "client_resources"
@@ -109,7 +112,7 @@ class RayBackend(Backend):
         else:
             client_resources = {"num_cpus": 2, "num_gpus": 0.0}
             log(
-                INFO,
+                DEBUG,
                 "`%s` not specified in backend config. Applying default setting: %s",
                 self.client_resources_key,
                 client_resources,
@@ -129,7 +132,7 @@ class RayBackend(Backend):
     async def build(self) -> None:
         """Build pool of Ray actors that this backend will submit jobs to."""
         await self.pool.add_actors_to_pool(self.pool.actors_capacity)
-        log(INFO, "Constructed ActorPool with: %i actors", self.pool.num_actors)
+        log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)

     async def process_message(
         self,
@@ -173,4 +176,4 @@ class RayBackend(Backend):
         """Terminate all actors in actor pool."""
         await self.pool.terminate_all_actors()
         ray.shutdown()
-        log(INFO, "Terminated %s", self.__class__.__name__)
+        log(DEBUG, "Terminated %s", self.__class__.__name__)
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -293,7 +293,7 @@ def start_vce(
         node_states[node_id] = NodeState()

     # Load backend config
-    log(INFO, "Supported backends: %s", list(supported_backends.keys()))
+    log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
     backend_config = json.loads(backend_config_json_stream)

     try:
flwr/simulation/run_simulation.py CHANGED
@@ -155,7 +155,7 @@ def run_serverapp_th(
         # Upon completion, trigger stop event if one was passed
         if stop_event is not None:
             stop_event.set()
-            log(WARNING, "Triggered stop event for Simulation Engine.")
+            log(DEBUG, "Triggered stop event for Simulation Engine.")

     serverapp_th = threading.Thread(
         target=server_th_with_start_checks,
@@ -249,7 +249,7 @@ def _main_loop(
         if serverapp_th:
             serverapp_th.join()

-    log(INFO, "Stopping Simulation Engine now.")
+    log(DEBUG, "Stopping Simulation Engine now.")


 # pylint: disable=too-many-arguments,too-many-locals
@@ -317,13 +317,15 @@ def _run_simulation(
         When diabled, only INFO, WARNING and ERROR log messages will be shown. If
         enabled, DEBUG-level logs will be displayed.
     """
+    if backend_config is None:
+        backend_config = {}
+
     # Set logging level
     logger = logging.getLogger("flwr")
     if verbose_logging:
         update_console_handler(level=DEBUG, timestamps=True, colored=True)
-
-    if backend_config is None:
-        backend_config = {}
+    else:
+        backend_config["silent"] = True

     if enable_tf_gpu_growth:
         # Check that Backend config has also enabled using GPU growth
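
The `_run_simulation` change above wires `verbose_logging=False` (the default) to the new `silent` backend flag. A minimal end-to-end sketch, assuming the public `flwr.simulation.run_simulation` entry point shipped in this nightly; `TrivialClient` is purely illustrative and not part of the package:

import numpy as np
from flwr.client import ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation


class TrivialClient(NumPyClient):
    """Illustrative no-op client used only to exercise the Simulation Engine."""

    def get_parameters(self, config):
        return [np.zeros(3)]

    def fit(self, parameters, config):
        return [np.zeros(3)], 1, {}

    def evaluate(self, parameters, config):
        return 0.0, 1, {}


client_app = ClientApp(client_fn=lambda cid: TrivialClient().to_client())
server_app = ServerApp(config=ServerConfig(num_rounds=1), strategy=FedAvg())

# With verbose_logging=False (the default), _run_simulation now also sets
# {"silent": True} in the backend config, so Ray starts with reduced verbosity
# (see the RayBackend hunk earlier in this diff).
run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=2,
    verbose_logging=False,
)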
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flwr-nightly
-Version: 1.9.0.dev20240506
+Version: 1.9.0.dev20240508
 Summary: Flower: A Friendly Federated Learning Framework
 Home-page: https://flower.ai
 License: Apache-2.0
@@ -194,7 +194,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
 - [PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated)
 - [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl)
 - [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning)
-- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/fedllm-finetune)
+- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/llm-flowertune)
 - [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/vit-finetune)
 - [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow)
 - [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)
@@ -5,23 +5,31 @@ flwr/cli/build.py,sha256=W30wnPSgFuHRnGB9G_vKO14rsaibWk7m-jv9r8rDqo4,5106
 flwr/cli/config_utils.py,sha256=Hql5A5hbSpJ51hgpwaTkKqfPoaZN4Zq7FZfBuQYLMcQ,4899
 flwr/cli/example.py,sha256=1bGDYll3BXQY2kRqSN-oICqS5n1b9m0g0RvXTopXHl4,2215
 flwr/cli/new/__init__.py,sha256=cQzK1WH4JP2awef1t2UQ2xjl1agVEz9rwutV18SWV1k,789
-flwr/cli/new/new.py,sha256=x0cYNCYTCwbWiM7K58y4ViJl-Hd_pZ7jUmgaCNSP9v8,6035
+flwr/cli/new/new.py,sha256=whQvNN-r_opeAEpB8i7X21u53FMUKOKWbdY8gJVY-L8,6168
 flwr/cli/new/templates/__init__.py,sha256=4luU8RL-CK8JJCstQ_ON809W9bNTkY1l9zSaPKBkgwY,725
 flwr/cli/new/templates/app/.gitignore.tpl,sha256=XixnHdyeMB2vwkGtGnwHqoWpH-9WChdyG0GXe57duhc,3078
 flwr/cli/new/templates/app/README.md.tpl,sha256=_qGtgpKYKoCJVjQnvlBMKvFs_1gzTcL908I3KJg0oAM,668
 flwr/cli/new/templates/app/__init__.py,sha256=DU7QMY7IhMQyuwm_tja66xU0KXTWQFqzfTqwg-_NJdE,729
 flwr/cli/new/templates/app/code/__init__.py,sha256=EM6vfvgAILKPaPn7H1wMV1Wi01WyZCP_Eg6NxD6oWg8,736
 flwr/cli/new/templates/app/code/__init__.py.tpl,sha256=olwrBeJemHNBWvjc6gJURloFRqW40dAy7FRQA5pDqHU,21
+flwr/cli/new/templates/app/code/client.hf.py.tpl,sha256=RaN89A8HgKp6kjhzH8tgtDSWW8BwwcvJdqRLcvG04zw,1450
+flwr/cli/new/templates/app/code/client.mlx.py.tpl,sha256=53wJy6s3zk4CZwob_qPmMoOqJ-LZNKbdDe_hw5LwOXE,2113
 flwr/cli/new/templates/app/code/client.numpy.py.tpl,sha256=mTh7Y_jOJrPUvDYHVJy4wJCnjXZV_q-jlDkB07U5GSk,521
 flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=MgCtMSv1Th16Faod11HubVaARkLYt7vS9RYH962-2pk,1172
 flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=S71SZiHaRXtKqUk3m5Elc_c6HhKAIKLalrKOQ3p20No,2801
 flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=dxrTO9JwYrDBjLsmCiRLetN9KxbnWRTeGA0BQbnOu_A,1280
+flwr/cli/new/templates/app/code/server.hf.py.tpl,sha256=Mld452y3SUkejlFzac5hpCjT7_mbA0ZEEMJIUyHtSTI,338
+flwr/cli/new/templates/app/code/server.mlx.py.tpl,sha256=Cqk3PvM0e7hzohXPqD5hG_cthXoxCfc30bpEThqMy7M,272
 flwr/cli/new/templates/app/code/server.numpy.py.tpl,sha256=fRxrDXV7pB1aDhQUXMBmrCsC1zp0uKwsBxZBx1JzbHA,248
 flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=ltdsnFSvFGPcycVmRL4ITlr-TV0CmmXcperZe7Vamow,593
 flwr/cli/new/templates/app/code/server.sklearn.py.tpl,sha256=cLzOpQzGIUzEazuFsjBpXAQUNPy6in6zR33SCqhix6o,341
 flwr/cli/new/templates/app/code/server.tensorflow.py.tpl,sha256=gsNrWCKTU77_65_gw9nlp1LSQojgP5QQIWILvqdjx2s,579
+flwr/cli/new/templates/app/code/task.hf.py.tpl,sha256=Rw8cnds4Ym8o8TOq6kMkwlBJfIfvsfnb02jwyulOgF8,2857
+flwr/cli/new/templates/app/code/task.mlx.py.tpl,sha256=y7aVj3F_98-wBnDcbPsCNnFs9BOHTn0y6XIYkByzv7Y,2598
 flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=NvajdZN-eTyfdqKK0v2MrvWITXw9BjJ3Ri5c1haPJDs,3684
 flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=cPOUUS07QbblT9PGFucwu9lY1clRA4-W4DQGA7cpcao,1044
+flwr/cli/new/templates/app/pyproject.hf.toml.tpl,sha256=PNGBNTfWmNJ23aVnW5f1TMMJ0uEwIljevpOsI-mqX08,676
+flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=JCEsuHZffO1KKkN65rSp6N-A9-OW8-kl6EQp5Z2H3uE,585
 flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=m276SKsjOZ4awGdXasUKvLim66agrpAsPNP9-PN6q4I,523
 flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=QikP3u5ht6qr2BkgcnvB3rCYK7jt1cS0nAm7V8g_zFc,592
 flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=IO5iIlyKSBxZCCf48iqEyRWeG1jmVx2tO_s2iE7FpHo,572
@@ -125,12 +133,12 @@ flwr/proto/transport_pb2.pyi,sha256=CZvJRWTU3QWFWLXNFtyLSrSKFatIyMcy-ohzLbQ-G9c,
 flwr/proto/transport_pb2_grpc.py,sha256=vLN3EHtx2aEEMCO4f1Upu-l27BPzd3-5pV-u8wPcosk,2598
 flwr/proto/transport_pb2_grpc.pyi,sha256=AGXf8RiIiW2J5IKMlm_3qT3AzcDa4F3P5IqUjve_esA,766
 flwr/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-flwr/server/__init__.py,sha256=dNLbXIERZ6X9aA_Bit3R9AARwcaZZzEfDuFmEx8VVOE,1785
-flwr/server/app.py,sha256=RUSgmhMm-U5FVZo3jk59t4o6v0JU-Y8avs_yVDHKWJM,28600
+flwr/server/__init__.py,sha256=PWyHKu-_KFxGI7oFWSWwqMfTiG_phWECT80iv0saouA,1716
+flwr/server/app.py,sha256=95U1IO07ngLy2lkDTvsoI5XesXBN7mpcT1wNhPSgXTI,28913
 flwr/server/client_manager.py,sha256=T8UDSRJBVD3fyIDI7NTAA-NA7GPrMNNgH2OAF54RRxE,6127
 flwr/server/client_proxy.py,sha256=4G-oTwhb45sfWLx2uZdcXD98IZwdTS6F88xe3akCdUg,2399
 flwr/server/compat/__init__.py,sha256=VxnJtJyOjNFQXMNi9hIuzNlZM5n0Hj1p3aq_Pm2udw4,892
-flwr/server/compat/app.py,sha256=BhF3DySbvKkOIyNXnB1rwZhw8cC8yK_w91Fku8HmC_w,5287
+flwr/server/compat/app.py,sha256=0jajWbEiU_B5FGBcoyss_3FTfCmljAhJXM2dGyVrKuI,3421
 flwr/server/compat/app_utils.py,sha256=06NHrPRPrjMjz5FglSPicJ9lAWZ-rIZ1cKQFs4nD6WI,3468
 flwr/server/compat/driver_client_proxy.py,sha256=Wc6jyyHY4OrJzeiy8tdXtkF8IdGREdxUPnom7VvvWPI,5444
 flwr/server/compat/legacy_context.py,sha256=D2s7PvQoDnTexuRmf1uG9Von7GUj4Qqyr7qLklSlKAM,1766
@@ -138,9 +146,9 @@ flwr/server/criterion.py,sha256=ypbAexbztzGUxNen9RCHF91QeqiEQix4t4Ih3E-42MM,1061
 flwr/server/driver/__init__.py,sha256=bbVL5pyA0Y2HcUK4s5U0B4epI-BuUFyEJbchew_8tJY,862
 flwr/server/driver/driver.py,sha256=t9SSSDlo9wT_y2Nl7waGYMTm2VlkvK3_bOb7ggPPlho,5090
 flwr/server/driver/grpc_driver.py,sha256=rdjkcAmtRWKeqJw4xDFqULuwVf0G2nLhfbOTrNUvPeY,11832
-flwr/server/history.py,sha256=hDsoBaA4kUa6d1yvDVXuLluBqOBKSm0_fVDtUtYJkmg,5121
+flwr/server/history.py,sha256=bBOHKyX1eQONIsUx4EUU-UnAk1i0EbEl8ioyMq_UWQ8,5063
 flwr/server/run_serverapp.py,sha256=avLi_yRNE5jD2ql95gzh04BTUbHvzH-N848_mdnnkVk,5972
-flwr/server/server.py,sha256=UnBRlI6AGTj0nKeRtEQ3IalM3TJmggMKXhDyn8yKZNk,17664
+flwr/server/server.py,sha256=0QJ0gZ1bjxOpiWQPxXCXVFT5DcGOBc-57Omd8uq4YMM,17563
 flwr/server/server_app.py,sha256=KgAT_HqsfseTLNnfX2ph42PBbVqQ0lFzvYrT90V34y0,4402
 flwr/server/server_config.py,sha256=CZaHVAsMvGLjpWVcLPkiYxgJN4xfIyAiUrCI3fETKY4,1349
 flwr/server/strategy/__init__.py,sha256=7eVZ3hQEg2BgA_usAeL6tsLp9T6XI1VYYoFy08Xn-ew,2836
@@ -187,8 +195,8 @@ flwr/server/superlink/fleet/rest_rere/rest_api.py,sha256=8gNziOjBA8ygTzfVPYiNkg_
 flwr/server/superlink/fleet/vce/__init__.py,sha256=36MHKiefnJeyjwMQzVUK4m06Ojon3WDcwZGQsAcyVhQ,783
 flwr/server/superlink/fleet/vce/backend/__init__.py,sha256=oBIzmnrSSRvH_H0vRGEGWhWzQQwqe3zn6e13RsNwlIY,1466
 flwr/server/superlink/fleet/vce/backend/backend.py,sha256=LJsKl7oixVvptcG98Rd9ejJycNWcEVB0ODvSreLGp-A,2260
-flwr/server/superlink/fleet/vce/backend/raybackend.py,sha256=TaT2EpbVEsIY0EDzF8obadyZaSXjD38TFGdDPI-ytD0,6375
-flwr/server/superlink/fleet/vce/vce_api.py,sha256=c2J2m6v1jDyuAhiBArdZNIk4cbiZNFJkpKlBJFEQq-c,12454
+flwr/server/superlink/fleet/vce/backend/raybackend.py,sha256=SWygSDQXLL1DsxqF9PqvzVj6t9SA2R3P9jhKYyH_v4I,6550
+flwr/server/superlink/fleet/vce/vce_api.py,sha256=ntIZdIISVdXMOKG8ZNDcstSMaQZ9bRpIPSJVfDFYpP4,12455
 flwr/server/superlink/state/__init__.py,sha256=ij-7Ms-hyordQdRmGQxY1-nVa4OhixJ0jr7_YDkys0s,1003
 flwr/server/superlink/state/in_memory_state.py,sha256=WoIOwgayuCu1DLRkkV6KgBsc28SKzSDxtXwO2a9Phuw,12750
 flwr/server/superlink/state/sqlite_state.py,sha256=8xvJgufEbl_ZRAz9VWXykKP3viUZjQNVS7yDY5dildw,28528
@@ -211,9 +219,9 @@ flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACk
 flwr/simulation/ray_transport/ray_actor.py,sha256=_wv2eP7qxkCZ-6rMyYWnjLrGPBZRxjvTPjaVk8zIaQ4,19367
 flwr/simulation/ray_transport/ray_client_proxy.py,sha256=oDu4sEPIOu39vrNi-fqDAe10xtNUXMO49bM2RWfRcyw,6738
 flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
-flwr/simulation/run_simulation.py,sha256=LszcnkCLM9YE-kgezB_H7b_NdDrK_Q0yN24mqYtZdfI,15957
-flwr_nightly-1.9.0.dev20240506.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-flwr_nightly-1.9.0.dev20240506.dist-info/METADATA,sha256=rNLF7cbRZK0x1EA7vIO_aAAhWWf2L1J8aQKkimtj9EQ,15303
-flwr_nightly-1.9.0.dev20240506.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
-flwr_nightly-1.9.0.dev20240506.dist-info/entry_points.txt,sha256=8JJPfpqMnXz9c5V_FSt07Xwd-wCWbAO3MFUDXQ5ZGsI,378
-flwr_nightly-1.9.0.dev20240506.dist-info/RECORD,,
+flwr/simulation/run_simulation.py,sha256=NdqplXCBRd9_VSqQuFWU1qG6r-KPDIpeHgQbaOCeutQ,16006
+flwr_nightly-1.9.0.dev20240508.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+flwr_nightly-1.9.0.dev20240508.dist-info/METADATA,sha256=5D-o5asJc3QDNHfwXPjyDLrlJ86P8cSBt_RNFRgZn-A,15302
+flwr_nightly-1.9.0.dev20240508.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
+flwr_nightly-1.9.0.dev20240508.dist-info/entry_points.txt,sha256=8JJPfpqMnXz9c5V_FSt07Xwd-wCWbAO3MFUDXQ5ZGsI,378
+flwr_nightly-1.9.0.dev20240508.dist-info/RECORD,,