flwr-nightly 1.22.0.dev20250910__py3-none-any.whl → 1.22.0.dev20250912__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/serverapp/strategy/fedavg.py +66 -62
- {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250912.dist-info}/METADATA +1 -1
- {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250912.dist-info}/RECORD +23 -23
- {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250912.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250912.dist-info}/entry_points.txt +0 -0
@@ -1,57 +1,82 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
from flwr.
|
4
|
-
from flwr.
|
3
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
4
|
+
from flwr.clientapp import ClientApp
|
5
5
|
|
6
6
|
from $import_name.task import load_data, load_model
|
7
7
|
|
8
|
+
# Flower ClientApp
|
9
|
+
app = ClientApp()
|
8
10
|
|
9
|
-
# Define Flower Client and client_fn
|
10
|
-
class FlowerClient(NumPyClient):
|
11
|
-
def __init__(
|
12
|
-
self, model, data, epochs, batch_size, verbose
|
13
|
-
):
|
14
|
-
self.model = model
|
15
|
-
self.x_train, self.y_train, self.x_test, self.y_test = data
|
16
|
-
self.epochs = epochs
|
17
|
-
self.batch_size = batch_size
|
18
|
-
self.verbose = verbose
|
19
|
-
|
20
|
-
def fit(self, parameters, config):
|
21
|
-
self.model.set_weights(parameters)
|
22
|
-
self.model.fit(
|
23
|
-
self.x_train,
|
24
|
-
self.y_train,
|
25
|
-
epochs=self.epochs,
|
26
|
-
batch_size=self.batch_size,
|
27
|
-
verbose=self.verbose,
|
28
|
-
)
|
29
|
-
return self.model.get_weights(), len(self.x_train), {}
|
30
|
-
|
31
|
-
def evaluate(self, parameters, config):
|
32
|
-
self.model.set_weights(parameters)
|
33
|
-
loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
|
34
|
-
return loss, len(self.x_test), {"accuracy": accuracy}
|
35
|
-
|
36
|
-
|
37
|
-
def client_fn(context: Context):
|
38
|
-
# Load model and data
|
39
|
-
net = load_model()
|
40
11
|
|
41
|
-
|
42
|
-
|
43
|
-
|
12
|
+
@app.train()
|
13
|
+
def train(msg: Message, context: Context):
|
14
|
+
"""Train the model on local data."""
|
15
|
+
|
16
|
+
# Load the model and initialize it with the received weights
|
17
|
+
model = load_model()
|
18
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
19
|
+
model.set_weights(ndarrays)
|
20
|
+
|
21
|
+
# Read from config
|
44
22
|
epochs = context.run_config["local-epochs"]
|
45
23
|
batch_size = context.run_config["batch-size"]
|
46
24
|
verbose = context.run_config.get("verbose")
|
47
25
|
|
48
|
-
#
|
49
|
-
|
50
|
-
|
51
|
-
|
26
|
+
# Load the data
|
27
|
+
partition_id = context.node_config["partition-id"]
|
28
|
+
num_partitions = context.node_config["num-partitions"]
|
29
|
+
x_train, y_train, _, _ = load_data(partition_id, num_partitions)
|
52
30
|
|
31
|
+
# Train the model on local data
|
32
|
+
history = model.fit(
|
33
|
+
x_train,
|
34
|
+
y_train,
|
35
|
+
epochs=epochs,
|
36
|
+
batch_size=batch_size,
|
37
|
+
verbose=verbose,
|
38
|
+
)
|
53
39
|
|
54
|
-
#
|
55
|
-
|
56
|
-
|
57
|
-
|
40
|
+
# Get final training loss and accuracy
|
41
|
+
train_loss = history.history["loss"][-1] if "loss" in history.history else None
|
42
|
+
train_acc = history.history.get("accuracy")
|
43
|
+
train_acc = train_acc[-1] if train_acc is not None else None
|
44
|
+
|
45
|
+
# Construct and return reply Message
|
46
|
+
model_record = ArrayRecord(model.get_weights())
|
47
|
+
metrics = {"num-examples": len(x_train)}
|
48
|
+
if train_loss is not None:
|
49
|
+
metrics["train_loss"] = train_loss
|
50
|
+
if train_acc is not None:
|
51
|
+
metrics["train_acc"] = train_acc
|
52
|
+
metric_record = MetricRecord(metrics)
|
53
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
54
|
+
return Message(content=content, reply_to=msg)
|
55
|
+
|
56
|
+
|
57
|
+
@app.evaluate()
|
58
|
+
def evaluate(msg: Message, context: Context):
|
59
|
+
"""Evaluate the model on local data."""
|
60
|
+
|
61
|
+
# Load the model and initialize it with the received weights
|
62
|
+
model = load_model()
|
63
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
64
|
+
model.set_weights(ndarrays)
|
65
|
+
|
66
|
+
# Load the data
|
67
|
+
partition_id = context.node_config["partition-id"]
|
68
|
+
num_partitions = context.node_config["num-partitions"]
|
69
|
+
_, _, x_test, y_test = load_data(partition_id, num_partitions)
|
70
|
+
|
71
|
+
# Evaluate the model on local data
|
72
|
+
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
|
73
|
+
|
74
|
+
# Construct and return reply Message
|
75
|
+
metrics = {
|
76
|
+
"eval_loss": loss,
|
77
|
+
"eval_acc": accuracy,
|
78
|
+
"num-examples": len(x_test),
|
79
|
+
}
|
80
|
+
metric_record = MetricRecord(metrics)
|
81
|
+
content = RecordDict({"metrics": metric_record})
|
82
|
+
return Message(content=content, reply_to=msg)
|
@@ -1,17 +1,22 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
3
|
+
import torch
|
4
|
+
from flwr.app import ArrayRecord, Context
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
6
7
|
from transformers import AutoModelForSequenceClassification
|
7
8
|
|
8
|
-
|
9
|
+
# Create ServerApp
|
10
|
+
app = ServerApp()
|
11
|
+
|
9
12
|
|
13
|
+
@app.main()
|
14
|
+
def main(grid: Grid, context: Context) -> None:
|
15
|
+
"""Main entry point for the ServerApp."""
|
10
16
|
|
11
|
-
def server_fn(context: Context):
|
12
17
|
# Read from config
|
13
18
|
num_rounds = context.run_config["num-server-rounds"]
|
14
|
-
|
19
|
+
fraction_train = context.run_config["fraction-train"]
|
15
20
|
|
16
21
|
# Initialize global model
|
17
22
|
model_name = context.run_config["model-name"]
|
@@ -19,20 +24,19 @@ def server_fn(context: Context):
|
|
19
24
|
net = AutoModelForSequenceClassification.from_pretrained(
|
20
25
|
model_name, num_labels=num_labels
|
21
26
|
)
|
27
|
+
arrays = ArrayRecord(net.state_dict())
|
22
28
|
|
23
|
-
|
24
|
-
|
29
|
+
# Initialize FedAvg strategy
|
30
|
+
strategy = FedAvg(fraction_train=fraction_train)
|
25
31
|
|
26
|
-
#
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
32
|
+
# Start strategy, run FedAvg for `num_rounds`
|
33
|
+
result = strategy.start(
|
34
|
+
grid=grid,
|
35
|
+
initial_arrays=arrays,
|
36
|
+
num_rounds=num_rounds,
|
31
37
|
)
|
32
|
-
config = ServerConfig(num_rounds=num_rounds)
|
33
|
-
|
34
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
35
38
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
+
# Save final model to disk
|
40
|
+
print("\nSaving final model to disk...")
|
41
|
+
state_dict = result.arrays.to_torch_state_dict()
|
42
|
+
torch.save(state_dict, "final_model.pt")
|
@@ -1,26 +1,39 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
3
|
+
import numpy as np
|
4
|
+
from flwr.app import ArrayRecord, Context
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
7
|
+
|
6
8
|
from $import_name.task import get_params, load_model
|
7
9
|
|
10
|
+
# Create ServerApp
|
11
|
+
app = ServerApp()
|
12
|
+
|
13
|
+
|
14
|
+
@app.main()
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
16
|
+
"""Main entry point for the ServerApp."""
|
8
17
|
|
9
|
-
def server_fn(context: Context):
|
10
18
|
# Read from config
|
11
19
|
num_rounds = context.run_config["num-server-rounds"]
|
12
20
|
input_dim = context.run_config["input-dim"]
|
13
21
|
|
14
|
-
#
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
# Define strategy
|
19
|
-
strategy = FedAvg(initial_parameters=initial_parameters)
|
20
|
-
config = ServerConfig(num_rounds=num_rounds)
|
22
|
+
# Load global model
|
23
|
+
model = load_model((input_dim,))
|
24
|
+
arrays = ArrayRecord(get_params(model))
|
21
25
|
|
22
|
-
|
26
|
+
# Initialize FedAvg strategy
|
27
|
+
strategy = FedAvg()
|
23
28
|
|
29
|
+
# Start strategy, run FedAvg for `num_rounds`
|
30
|
+
result = strategy.start(
|
31
|
+
grid=grid,
|
32
|
+
initial_arrays=arrays,
|
33
|
+
num_rounds=num_rounds,
|
34
|
+
)
|
24
35
|
|
25
|
-
#
|
26
|
-
|
36
|
+
# Save final model to disk
|
37
|
+
print("\nSaving final model to disk...")
|
38
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
39
|
+
np.savez("final_model.npz", *ndarrays)
|
@@ -1,31 +1,41 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
from flwr.
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
6
|
-
from $import_name.task import MLP, get_params
|
3
|
+
from flwr.app import ArrayRecord, Context
|
4
|
+
from flwr.serverapp import Grid, ServerApp
|
5
|
+
from flwr.serverapp.strategy import FedAvg
|
7
6
|
|
7
|
+
from $import_name.task import MLP, get_params, set_params
|
8
8
|
|
9
|
-
|
9
|
+
# Create ServerApp
|
10
|
+
app = ServerApp()
|
11
|
+
|
12
|
+
|
13
|
+
@app.main()
|
14
|
+
def main(grid: Grid, context: Context) -> None:
|
15
|
+
"""Main entry point for the ServerApp."""
|
10
16
|
# Read from config
|
11
17
|
num_rounds = context.run_config["num-server-rounds"]
|
12
|
-
|
13
|
-
num_classes = 10
|
14
18
|
num_layers = context.run_config["num-layers"]
|
15
19
|
input_dim = context.run_config["input-dim"]
|
16
20
|
hidden_dim = context.run_config["hidden-dim"]
|
17
21
|
|
18
22
|
# Initialize global model
|
19
|
-
model = MLP(num_layers, input_dim, hidden_dim,
|
23
|
+
model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
|
20
24
|
params = get_params(model)
|
21
|
-
|
22
|
-
|
23
|
-
#
|
24
|
-
strategy = FedAvg(
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
25
|
+
arrays = ArrayRecord(params)
|
26
|
+
|
27
|
+
# Initialize FedAvg strategy
|
28
|
+
strategy = FedAvg()
|
29
|
+
|
30
|
+
# Start strategy, run FedAvg for `num_rounds`
|
31
|
+
result = strategy.start(
|
32
|
+
grid=grid,
|
33
|
+
initial_arrays=arrays,
|
34
|
+
num_rounds=num_rounds,
|
35
|
+
)
|
36
|
+
|
37
|
+
# Save final model to disk
|
38
|
+
print("\nSaving final model to disk...")
|
39
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
40
|
+
set_params(model, ndarrays)
|
41
|
+
model.save_weights("final_model.npz")
|
@@ -1,25 +1,38 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
6
|
-
from
|
7
|
-
|
3
|
+
import numpy as np
|
4
|
+
from flwr.app import ArrayRecord, Context
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
8
7
|
|
9
|
-
|
10
|
-
# Read from config
|
11
|
-
num_rounds = context.run_config["num-server-rounds"]
|
8
|
+
from $import_name.task import get_dummy_model
|
12
9
|
|
13
|
-
|
14
|
-
|
15
|
-
dummy_parameters = ndarrays_to_parameters([model])
|
10
|
+
# Create ServerApp
|
11
|
+
app = ServerApp()
|
16
12
|
|
17
|
-
# Define strategy
|
18
|
-
strategy = FedAvg(initial_parameters=dummy_parameters)
|
19
|
-
config = ServerConfig(num_rounds=num_rounds)
|
20
13
|
|
21
|
-
|
14
|
+
@app.main()
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
16
|
+
"""Main entry point for the ServerApp."""
|
22
17
|
|
18
|
+
# Read run config
|
19
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
23
20
|
|
24
|
-
#
|
25
|
-
|
21
|
+
# Load global model
|
22
|
+
model = get_dummy_model()
|
23
|
+
arrays = ArrayRecord(model)
|
24
|
+
|
25
|
+
# Initialize FedAvg strategy
|
26
|
+
strategy = FedAvg()
|
27
|
+
|
28
|
+
# Start strategy, run FedAvg for `num_rounds`
|
29
|
+
result = strategy.start(
|
30
|
+
grid=grid,
|
31
|
+
initial_arrays=arrays,
|
32
|
+
num_rounds=num_rounds,
|
33
|
+
)
|
34
|
+
|
35
|
+
# Save final model to disk
|
36
|
+
print("\nSaving final model to disk...")
|
37
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
38
|
+
np.savez("final_model", *ndarrays)
|
@@ -1,36 +1,44 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
6
|
-
from
|
3
|
+
import joblib
|
4
|
+
from flwr.app import ArrayRecord, Context
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
7
7
|
|
8
|
+
from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
|
8
9
|
|
9
|
-
|
10
|
-
|
11
|
-
|
10
|
+
# Create ServerApp
|
11
|
+
app = ServerApp()
|
12
|
+
|
13
|
+
|
14
|
+
@app.main()
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
16
|
+
"""Main entry point for the ServerApp."""
|
17
|
+
|
18
|
+
# Read run config
|
19
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
12
20
|
|
13
21
|
# Create LogisticRegression Model
|
14
22
|
penalty = context.run_config["penalty"]
|
15
23
|
local_epochs = context.run_config["local-epochs"]
|
16
24
|
model = get_model(penalty, local_epochs)
|
17
|
-
|
18
25
|
# Setting initial parameters, akin to model.compile for keras models
|
19
26
|
set_initial_params(model)
|
27
|
+
# Construct ArrayRecord representation
|
28
|
+
arrays = ArrayRecord(get_model_params(model))
|
20
29
|
|
21
|
-
|
30
|
+
# Initialize FedAvg strategy
|
31
|
+
strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
|
22
32
|
|
23
|
-
#
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
initial_parameters=initial_parameters,
|
33
|
+
# Start strategy, run FedAvg for `num_rounds`
|
34
|
+
result = strategy.start(
|
35
|
+
grid=grid,
|
36
|
+
initial_arrays=arrays,
|
37
|
+
num_rounds=num_rounds,
|
29
38
|
)
|
30
|
-
config = ServerConfig(num_rounds=num_rounds)
|
31
|
-
|
32
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
33
39
|
|
34
|
-
|
35
|
-
|
36
|
-
|
40
|
+
# Save final model parameters
|
41
|
+
print("\nSaving final model to disk...")
|
42
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
43
|
+
set_model_params(model, ndarrays)
|
44
|
+
joblib.dump(model, "logreg_model.pkl")
|
@@ -1,29 +1,38 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
from flwr.
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
3
|
+
from flwr.app import ArrayRecord, Context
|
4
|
+
from flwr.serverapp import Grid, ServerApp
|
5
|
+
from flwr.serverapp.strategy import FedAvg
|
6
6
|
|
7
7
|
from $import_name.task import load_model
|
8
8
|
|
9
|
+
# Create ServerApp
|
10
|
+
app = ServerApp()
|
9
11
|
|
10
|
-
def server_fn(context: Context):
|
11
|
-
# Read from config
|
12
|
-
num_rounds = context.run_config["num-server-rounds"]
|
13
12
|
|
14
|
-
|
15
|
-
|
13
|
+
@app.main()
|
14
|
+
def main(grid: Grid, context: Context) -> None:
|
15
|
+
"""Main entry point for the ServerApp."""
|
16
16
|
|
17
|
-
#
|
18
|
-
|
19
|
-
fraction_fit=1.0,
|
20
|
-
fraction_evaluate=1.0,
|
21
|
-
min_available_clients=2,
|
22
|
-
initial_parameters=parameters,
|
23
|
-
)
|
24
|
-
config = ServerConfig(num_rounds=num_rounds)
|
17
|
+
# Read run config
|
18
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
25
19
|
|
26
|
-
|
20
|
+
# Load global model
|
21
|
+
model = load_model()
|
22
|
+
arrays = ArrayRecord(model.get_weights())
|
27
23
|
|
28
|
-
#
|
29
|
-
|
24
|
+
# Initialize FedAvg strategy
|
25
|
+
strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
|
26
|
+
|
27
|
+
# Start strategy, run FedAvg for `num_rounds`
|
28
|
+
result = strategy.start(
|
29
|
+
grid=grid,
|
30
|
+
initial_arrays=arrays,
|
31
|
+
num_rounds=num_rounds,
|
32
|
+
)
|
33
|
+
|
34
|
+
# Save final model to disk
|
35
|
+
print("\nSaving final model to disk...")
|
36
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
37
|
+
model.set_weights(ndarrays)
|
38
|
+
model.save("final_model.keras")
|
@@ -1,7 +1,6 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
3
|
import warnings
|
4
|
-
from collections import OrderedDict
|
5
4
|
|
6
5
|
import torch
|
7
6
|
import transformers
|
@@ -62,17 +61,24 @@ def load_data(partition_id: int, num_partitions: int, model_name: str):
|
|
62
61
|
return trainloader, testloader
|
63
62
|
|
64
63
|
|
65
|
-
def train(net, trainloader,
|
64
|
+
def train(net, trainloader, num_steps, device):
|
66
65
|
optimizer = AdamW(net.parameters(), lr=5e-5)
|
67
66
|
net.train()
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
67
|
+
running_loss = 0.0
|
68
|
+
step_cnt = 0
|
69
|
+
for batch in trainloader:
|
70
|
+
batch = {k: v.to(device) for k, v in batch.items()}
|
71
|
+
outputs = net(**batch)
|
72
|
+
loss = outputs.loss
|
73
|
+
loss.backward()
|
74
|
+
optimizer.step()
|
75
|
+
optimizer.zero_grad()
|
76
|
+
running_loss += loss.item()
|
77
|
+
step_cnt += 1
|
78
|
+
if step_cnt >= num_steps:
|
79
|
+
break
|
80
|
+
avg_trainloss = running_loss / step_cnt
|
81
|
+
return avg_trainloss
|
76
82
|
|
77
83
|
|
78
84
|
def test(net, testloader, device):
|
@@ -90,13 +96,3 @@ def test(net, testloader, device):
|
|
90
96
|
loss /= len(testloader.dataset)
|
91
97
|
accuracy = metric.compute()["accuracy"]
|
92
98
|
return loss, accuracy
|
93
|
-
|
94
|
-
|
95
|
-
def get_weights(net):
|
96
|
-
return [val.cpu().numpy() for _, val in net.state_dict().items()]
|
97
|
-
|
98
|
-
|
99
|
-
def set_weights(net, parameters):
|
100
|
-
params_dict = zip(net.state_dict().keys(), parameters)
|
101
|
-
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
|
102
|
-
net.load_state_dict(state_dict, strict=True)
|
@@ -31,7 +31,7 @@ def loss_fn(params, X, y):
|
|
31
31
|
def train(params, grad_fn, X, y):
|
32
32
|
loss = 1_000_000
|
33
33
|
num_examples = X.shape[0]
|
34
|
-
for
|
34
|
+
for _ in range(50):
|
35
35
|
grads = grad_fn(params, X, y)
|
36
36
|
params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
|
37
37
|
loss = loss_fn(params, X, y)
|
@@ -3,10 +3,9 @@
|
|
3
3
|
import os
|
4
4
|
|
5
5
|
import keras
|
6
|
-
from keras import layers
|
7
6
|
from flwr_datasets import FederatedDataset
|
8
7
|
from flwr_datasets.partitioner import IidPartitioner
|
9
|
-
|
8
|
+
from keras import layers
|
10
9
|
|
11
10
|
# Make TensorFlow log less verbose
|
12
11
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
@@ -16,7 +16,7 @@ license = "Apache-2.0"
|
|
16
16
|
dependencies = [
|
17
17
|
"flwr[simulation]>=1.22.0",
|
18
18
|
"flwr-datasets>=0.5.0",
|
19
|
-
"torch
|
19
|
+
"torch>=2.7.1",
|
20
20
|
"transformers>=4.30.0,<5.0",
|
21
21
|
"evaluate>=0.4.0,<1.0",
|
22
22
|
"datasets>=2.0.0, <3.0",
|
@@ -38,8 +38,8 @@ clientapp = "$import_name.client_app:app"
|
|
38
38
|
# Custom config values accessible via `context.run_config`
|
39
39
|
[tool.flwr.app.config]
|
40
40
|
num-server-rounds = 3
|
41
|
-
fraction-
|
42
|
-
local-
|
41
|
+
fraction-train = 0.5
|
42
|
+
local-steps = 5
|
43
43
|
model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
|
44
44
|
num-labels = 2
|
45
45
|
|