flwr 1.21.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- flwr/cli/app.py +2 -0
- flwr/cli/new/new.py +9 -7
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/utils.py +17 -0
- flwr/clientapp/mod/__init__.py +4 -1
- flwr/clientapp/mod/centraldp_mods.py +156 -40
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/common/constant.py +3 -0
- flwr/common/exit/exit_code.py +4 -0
- flwr/common/record/typeddict.py +12 -0
- flwr/proto/control_pb2.py +7 -3
- flwr/proto/control_pb2.pyi +24 -0
- flwr/proto/control_pb2_grpc.py +34 -0
- flwr/proto/control_pb2_grpc.pyi +13 -0
- flwr/server/app.py +13 -0
- flwr/serverapp/strategy/__init__.py +26 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
- flwr/serverapp/strategy/fedadagrad.py +0 -3
- flwr/serverapp/strategy/fedadam.py +0 -3
- flwr/serverapp/strategy/fedavg.py +89 -64
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +0 -3
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +19 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/control/control_grpc.py +3 -0
- flwr/superlink/servicer/control/control_servicer.py +59 -2
- {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/METADATA +6 -16
- {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/RECORD +93 -74
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- flwr/serverapp/dp_fixed_clipping.py +0 -352
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
- /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
- {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +0 -0

flwr/cli/new/templates/app/code/client.pytorch.py.tpl

@@ -1,55 +1,80 @@
 """$project_name: A Flower / $framework_str app."""

 import torch
+from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+from flwr.clientapp import ClientApp

-from flwr.client import ClientApp, NumPyClient
-from flwr.common import Context
-from $import_name.task import Net, get_weights, load_data, set_weights, test, train
-
-
-# Define Flower Client and client_fn
-class FlowerClient(NumPyClient):
-    def __init__(self, net, trainloader, valloader, local_epochs):
-        self.net = net
-        self.trainloader = trainloader
-        self.valloader = valloader
-        self.local_epochs = local_epochs
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.net.to(self.device)
-
-    def fit(self, parameters, config):
-        set_weights(self.net, parameters)
-        train_loss = train(
-            self.net,
-            self.trainloader,
-            self.local_epochs,
-            self.device,
-        )
-        return (
-            get_weights(self.net),
-            len(self.trainloader.dataset),
-            {"train_loss": train_loss},
-        )
-
-    def evaluate(self, parameters, config):
-        set_weights(self.net, parameters)
-        loss, accuracy = test(self.net, self.valloader, self.device)
-        return loss, len(self.valloader.dataset), {"accuracy": accuracy}
-
-
-def client_fn(context: Context):
-    # Load model and data
-    net = Net()
+from $import_name.task import Net, load_data
+from $import_name.task import test as test_fn
+from $import_name.task import train as train_fn
+
+# Flower ClientApp
+app = ClientApp()
+
+
+@app.train()
+def train(msg: Message, context: Context):
+    """Train the model on local data."""
+
+    # Load the model and initialize it with the received weights
+    model = Net()
+    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    # Load the data
     partition_id = context.node_config["partition-id"]
     num_partitions = context.node_config["num-partitions"]
-    trainloader, valloader = load_data(partition_id, num_partitions)
-    local_epochs = context.run_config["local-epochs"]
+    trainloader, _ = load_data(partition_id, num_partitions)

-    # Return Client instance
-    return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
+    # Call the training function
+    train_loss = train_fn(
+        model,
+        trainloader,
+        context.run_config["local-epochs"],
+        msg.content["config"]["lr"],
+        device,
+    )

+    # Construct and return reply Message
+    model_record = ArrayRecord(model.state_dict())
+    metrics = {
+        "train_loss": train_loss,
+        "num-examples": len(trainloader.dataset),
+    }
+    metric_record = MetricRecord(metrics)
+    content = RecordDict({"arrays": model_record, "metrics": metric_record})
+    return Message(content=content, reply_to=msg)

-# Flower ClientApp
-app = ClientApp(
-    client_fn,
-)
+
+@app.evaluate()
+def evaluate(msg: Message, context: Context):
+    """Evaluate the model on local data."""
+
+    # Load the model and initialize it with the received weights
+    model = Net()
+    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    # Load the data
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    _, valloader = load_data(partition_id, num_partitions)
+
+    # Call the evaluation function
+    eval_loss, eval_acc = test_fn(
+        model,
+        valloader,
+        device,
+    )
+
+    # Construct and return reply Message
+    metrics = {
+        "eval_loss": eval_loss,
+        "eval_acc": eval_acc,
+        "num-examples": len(valloader.dataset),
+    }
+    metric_record = MetricRecord(metrics)
+    content = RecordDict({"metrics": metric_record})
+    return Message(content=content, reply_to=msg)
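
The rewritten PyTorch template above calls task-module helpers with new signatures: train_fn(model, trainloader, epochs, lr, device) returning a training loss, and test_fn(model, valloader, device) returning loss and accuracy. The shipped helpers live in task.pytorch.py.tpl, which this release also touches (+14 -27) but which is not shown here; the sketch below only illustrates the shapes the template assumes and is not the packaged code.

# Hypothetical sketch of the task-module helpers called above; only the
# signatures are taken from the diff, the bodies are illustrative.
import torch
import torch.nn as nn


def train(model, trainloader, epochs, lr, device):
    """Run local training and return the average training loss."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    running_loss, num_batches = 0.0, 0
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
    return running_loss / max(num_batches, 1)


def test(model, valloader, device):
    """Return (average loss, accuracy) on the validation partition."""
    criterion = nn.CrossEntropyLoss()
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in valloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / max(len(valloader), 1), correct / max(total, 1)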

flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl (new file)

@@ -0,0 +1,55 @@
+"""$project_name: A Flower / $framework_str app."""
+
+import torch
+
+from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
+from $import_name.task import Net, get_weights, load_data, set_weights, test, train
+
+
+# Define Flower Client and client_fn
+class FlowerClient(NumPyClient):
+    def __init__(self, net, trainloader, valloader, local_epochs):
+        self.net = net
+        self.trainloader = trainloader
+        self.valloader = valloader
+        self.local_epochs = local_epochs
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.net.to(self.device)
+
+    def fit(self, parameters, config):
+        set_weights(self.net, parameters)
+        train_loss = train(
+            self.net,
+            self.trainloader,
+            self.local_epochs,
+            self.device,
+        )
+        return (
+            get_weights(self.net),
+            len(self.trainloader.dataset),
+            {"train_loss": train_loss},
+        )
+
+    def evaluate(self, parameters, config):
+        set_weights(self.net, parameters)
+        loss, accuracy = test(self.net, self.valloader, self.device)
+        return loss, len(self.valloader.dataset), {"accuracy": accuracy}
+
+
+def client_fn(context: Context):
+    # Load model and data
+    net = Net()
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    trainloader, valloader = load_data(partition_id, num_partitions)
+    local_epochs = context.run_config["local-epochs"]
+
+    # Return Client instance
+    return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(
+    client_fn,
+)
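
The legacy template keeps the NumPyClient flow, so it still relies on get_weights/set_weights helpers from the task module to convert between a PyTorch state_dict and the list of NumPy ndarrays that NumPyClient exchanges. Those helpers are not part of this diff; the following is a minimal sketch of the usual pattern, assuming standard PyTorch tensors, not the packaged code.

# Hypothetical sketch of the get_weights/set_weights helpers imported above
from collections import OrderedDict

import torch


def get_weights(net):
    """Return the model parameters as a list of NumPy ndarrays."""
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net, parameters):
    """Load a list of NumPy ndarrays back into the model."""
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)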

flwr/cli/new/templates/app/code/client.sklearn.py.tpl

@@ -2,10 +2,16 @@

 import warnings

-from …
+from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+from flwr.clientapp import ClientApp
+from sklearn.metrics import (
+    accuracy_score,
+    f1_score,
+    log_loss,
+    precision_score,
+    recall_score,
+)

-from flwr.client import ClientApp, NumPyClient
-from flwr.common import Context
 from $import_name.task import (
     get_model,
     get_model_params,

@@ -14,39 +20,52 @@ from $import_name.task import (
     set_model_params,
 )

+# Flower ClientApp
+app = ClientApp()

-class FlowerClient(NumPyClient):
-    def __init__(self, model, X_train, X_test, y_train, y_test):
-        self.model = model
-        self.X_train = X_train
-        self.X_test = X_test
-        self.y_train = y_train
-        self.y_test = y_test

-…
+@app.train()
+def train(msg: Message, context: Context):
+    """Train the model on local data."""

-…
+    # Create LogisticRegression Model
+    penalty = context.run_config["penalty"]
+    local_epochs = context.run_config["local-epochs"]
+    model = get_model(penalty, local_epochs)
+    # Setting initial parameters, akin to model.compile for keras models
+    set_initial_params(model)

-…
+    # Apply received pararameters
+    ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+    set_model_params(model, ndarrays)

-…
+    # Load the data
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    X_train, _, y_train, _ = load_data(partition_id, num_partitions)

-…
+    # Ignore convergence failure due to low local epochs
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        # Train the model on local data
+        model.fit(X_train, y_train)

-…
+    # Let's compute train loss
+    y_train_pred_proba = model.predict_proba(X_train)
+    train_logloss = log_loss(y_train, y_train_pred_proba)

+    # Construct and return reply Message
+    ndarrays = get_model_params(model)
+    model_record = ArrayRecord(ndarrays)
+    metrics = {"num-examples": len(X_train), "train_logloss": train_logloss}
+    metric_record = MetricRecord(metrics)
+    content = RecordDict({"arrays": model_record, "metrics": metric_record})
+    return Message(content=content, reply_to=msg)

-def client_fn(context: Context):
-    partition_id = context.node_config["partition-id"]
-    num_partitions = context.node_config["num-partitions"]

-…
+@app.evaluate()
+def evaluate(msg: Message, context: Context):
+    """Evaluate the model on test data."""

     # Create LogisticRegression Model
     penalty = context.run_config["penalty"]

@@ -56,8 +75,34 @@ def client_fn(context: Context):
     # Setting initial parameters, akin to model.compile for keras models
     set_initial_params(model)

-…
+    # Apply received pararameters
+    ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+    set_model_params(model, ndarrays)

-…
+    # Load the data
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    _, X_test, _, y_test = load_data(partition_id, num_partitions)
+
+    # Evaluate the model on local data
+    y_train_pred = model.predict(X_test)
+    y_train_pred_proba = model.predict_proba(X_test)
+
+    accuracy = accuracy_score(y_test, y_train_pred)
+    loss = log_loss(y_test, y_train_pred_proba)
+    precision = precision_score(y_test, y_train_pred, average="macro", zero_division=0)
+    recall = recall_score(y_test, y_train_pred, average="macro", zero_division=0)
+    f1 = f1_score(y_test, y_train_pred, average="macro", zero_division=0)
+
+    # Construct and return reply Message
+    metrics = {
+        "num-examples": len(X_test),
+        "test_logloss": loss,
+        "accuracy": accuracy,
+        "precision": precision,
+        "recall": recall,
+        "f1": f1,
+    }
+    metric_record = MetricRecord(metrics)
+    content = RecordDict({"metrics": metric_record})
+    return Message(content=content, reply_to=msg)
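
The sklearn template leans on get_model, get_model_params, set_model_params, and set_initial_params from the task module, which this diff does not include. A rough sketch of such helpers for a LogisticRegression model follows; the class and feature counts are placeholders (an MNIST-like 10-class, 784-feature problem is assumed), not values taken from the package.

# Hypothetical sketch of the task-module helpers referenced above
import numpy as np
from sklearn.linear_model import LogisticRegression

# Placeholder problem size (assumption, not from the package)
N_CLASSES = 10
N_FEATURES = 784


def get_model(penalty: str, local_epochs: int) -> LogisticRegression:
    # warm_start keeps learned coefficients across repeated fit() calls
    return LogisticRegression(penalty=penalty, max_iter=local_epochs, warm_start=True)


def get_model_params(model: LogisticRegression):
    # Serialize the coefficients (and intercept) as plain ndarrays
    if model.fit_intercept:
        return [model.coef_, model.intercept_]
    return [model.coef_]


def set_model_params(model: LogisticRegression, params):
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model


def set_initial_params(model: LogisticRegression):
    # Zero-initialize so the unfitted model can be serialized in round 1
    model.classes_ = np.arange(N_CLASSES)
    model.coef_ = np.zeros((N_CLASSES, N_FEATURES))
    if model.fit_intercept:
        model.intercept_ = np.zeros((N_CLASSES,))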

flwr/cli/new/templates/app/code/client.tensorflow.py.tpl

@@ -1,57 +1,82 @@
 """$project_name: A Flower / $framework_str app."""

-from flwr.client import ClientApp, NumPyClient
-from flwr.common import Context
+from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+from flwr.clientapp import ClientApp

 from $import_name.task import load_data, load_model

+# Flower ClientApp
+app = ClientApp()

-# Define Flower Client and client_fn
-class FlowerClient(NumPyClient):
-    def __init__(
-        self, model, data, epochs, batch_size, verbose
-    ):
-        self.model = model
-        self.x_train, self.y_train, self.x_test, self.y_test = data
-        self.epochs = epochs
-        self.batch_size = batch_size
-        self.verbose = verbose
-
-    def fit(self, parameters, config):
-        self.model.set_weights(parameters)
-        self.model.fit(
-            self.x_train,
-            self.y_train,
-            epochs=self.epochs,
-            batch_size=self.batch_size,
-            verbose=self.verbose,
-        )
-        return self.model.get_weights(), len(self.x_train), {}
-
-    def evaluate(self, parameters, config):
-        self.model.set_weights(parameters)
-        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
-        return loss, len(self.x_test), {"accuracy": accuracy}
-
-
-def client_fn(context: Context):
-    # Load model and data
-    net = load_model()

-…
+@app.train()
+def train(msg: Message, context: Context):
+    """Train the model on local data."""
+
+    # Load the model and initialize it with the received weights
+    model = load_model()
+    ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+    model.set_weights(ndarrays)
+
+    # Read from config
     epochs = context.run_config["local-epochs"]
     batch_size = context.run_config["batch-size"]
     verbose = context.run_config.get("verbose")

-# …
-…
+    # Load the data
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    x_train, y_train, _, _ = load_data(partition_id, num_partitions)

+    # Train the model on local data
+    history = model.fit(
+        x_train,
+        y_train,
+        epochs=epochs,
+        batch_size=batch_size,
+        verbose=verbose,
+    )

-# …
-…
+    # Get final training loss and accuracy
+    train_loss = history.history["loss"][-1] if "loss" in history.history else None
+    train_acc = history.history.get("accuracy")
+    train_acc = train_acc[-1] if train_acc is not None else None
+
+    # Construct and return reply Message
+    model_record = ArrayRecord(model.get_weights())
+    metrics = {"num-examples": len(x_train)}
+    if train_loss is not None:
+        metrics["train_loss"] = train_loss
+    if train_acc is not None:
+        metrics["train_acc"] = train_acc
+    metric_record = MetricRecord(metrics)
+    content = RecordDict({"arrays": model_record, "metrics": metric_record})
+    return Message(content=content, reply_to=msg)
+
+
+@app.evaluate()
+def evaluate(msg: Message, context: Context):
+    """Evaluate the model on local data."""
+
+    # Load the model and initialize it with the received weights
+    model = load_model()
+    ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+    model.set_weights(ndarrays)
+
+    # Load the data
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    _, _, x_test, y_test = load_data(partition_id, num_partitions)
+
+    # Evaluate the model on local data
+    loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
+
+    # Construct and return reply Message
+    metrics = {
+        "eval_loss": loss,
+        "eval_acc": accuracy,
+        "num-examples": len(x_test),
+    }
+    metric_record = MetricRecord(metrics)
+    content = RecordDict({"metrics": metric_record})
+    return Message(content=content, reply_to=msg)
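
The TensorFlow template expects load_model() to return a compiled Keras model whose weights can be swapped via get_weights/set_weights. The packaged task.tensorflow.py.tpl is not shown in this diff; the sketch below is one plausible implementation, assuming 32x32x3 inputs and 10 output classes, and is not the shipped code.

# Hypothetical sketch of the load_model() helper referenced above
import keras
from keras import layers


def load_model(learning_rate: float = 0.001):
    # A small CNN for 32x32 RGB inputs (CIFAR-10-like data is assumed)
    model = keras.Sequential(
        [
            keras.Input(shape=(32, 32, 3)),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(10, activation="softmax"),
        ]
    )
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model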

flwr/cli/new/templates/app/code/client.xgboost.py.tpl (new file)

@@ -0,0 +1,110 @@
+"""$project_name: A Flower / $framework_str app."""
+
+import warnings
+
+import numpy as np
+import xgboost as xgb
+from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+from flwr.clientapp import ClientApp
+from flwr.common.config import unflatten_dict
+
+from $import_name.task import load_data, replace_keys
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+# Flower ClientApp
+app = ClientApp()
+
+
+def _local_boost(bst_input, num_local_round, train_dmatrix):
+    # Update trees based on local training data.
+    for i in range(num_local_round):
+        bst_input.update(train_dmatrix, bst_input.num_boosted_rounds())
+
+    # Bagging: extract the last N=num_local_round trees for sever aggregation
+    bst = bst_input[
+        bst_input.num_boosted_rounds()
+        - num_local_round : bst_input.num_boosted_rounds()
+    ]
+    return bst
+
+
+@app.train()
+def train(msg: Message, context: Context) -> Message:
+    # Load model and data
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    train_dmatrix, _, num_train, _ = load_data(partition_id, num_partitions)
+
+    # Read from run config
+    num_local_round = context.run_config["local-epochs"]
+    # Flatted config dict and replace "-" with "_"
+    cfg = replace_keys(unflatten_dict(context.run_config))
+    params = cfg["params"]
+
+    global_round = msg.content["config"]["server-round"]
+    if global_round == 1:
+        # First round local training
+        bst = xgb.train(
+            params,
+            train_dmatrix,
+            num_boost_round=num_local_round,
+        )
+    else:
+        bst = xgb.Booster(params=params)
+        global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
+
+        # Load global model into booster
+        bst.load_model(global_model)
+
+        # Local training
+        bst = _local_boost(bst, num_local_round, train_dmatrix)
+
+    # Save model
+    local_model = bst.save_raw("json")
+    model_np = np.frombuffer(local_model, dtype=np.uint8)
+
+    # Construct reply message
+    # Note: we store the model as the first item in a list into ArrayRecord,
+    # which can be accessed using index ["0"].
+    model_record = ArrayRecord([model_np])
+    metrics = {
+        "num-examples": num_train,
+    }
+    metric_record = MetricRecord(metrics)
+    content = RecordDict({"arrays": model_record, "metrics": metric_record})
+    return Message(content=content, reply_to=msg)
+
+
+@app.evaluate()
+def evaluate(msg: Message, context: Context) -> Message:
+    # Load model and data
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    _, valid_dmatrix, _, num_val = load_data(partition_id, num_partitions)
+
+    # Load config
+    cfg = replace_keys(unflatten_dict(context.run_config))
+    params = cfg["params"]
+
+    # Load global model
+    bst = xgb.Booster(params=params)
+    global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
+    bst.load_model(global_model)
+
+    # Run evaluation
+    eval_results = bst.eval_set(
+        evals=[(valid_dmatrix, "valid")],
+        iteration=bst.num_boosted_rounds() - 1,
+    )
+    auc = float(eval_results.split("\t")[1].split(":")[1])
+
+    # Construct and return reply Message
+    metrics = {
+        "auc": auc,
+        "num-examples": num_val,
+    }
+    metric_record = MetricRecord(metrics)
+    content = RecordDict({"metrics": metric_record})
+    return Message(content=content, reply_to=msg)
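
Both XGBoost handlers rebuild their training parameters with replace_keys(unflatten_dict(context.run_config)). unflatten_dict is Flower's own helper (flwr.common.config), while replace_keys comes from the project's task module and is not shown in this diff; a minimal sketch consistent with the comment in the template (replace "-" with "_") is given below as an illustration, not the packaged code.

# Hypothetical sketch of the replace_keys helper imported from the task module.
# It recursively rewrites "-" to "_" in config keys so that entries from
# pyproject.toml can be passed to xgb.train as Python-friendly keys.
def replace_keys(input_dict, match="-", target="_"):
    new_dict = {}
    for key, value in input_dict.items():
        new_key = key.replace(match, target)
        if isinstance(value, dict):
            new_dict[new_key] = replace_keys(value, match, target)
        else:
            new_dict[new_key] = value
    return new_dict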