flwr-nightly 1.22.0.dev20250910__py3-none-any.whl → 1.22.0.dev20250911__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/serverapp/strategy/fedavg.py +66 -62
- {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250911.dist-info}/METADATA +1 -1
- {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250911.dist-info}/RECORD +23 -23
- {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250911.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250911.dist-info}/entry_points.txt +0 -0
@@ -1,41 +1,67 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
3
|
import torch
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
4
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
5
|
+
from flwr.clientapp import ClientApp
|
6
6
|
from transformers import AutoModelForSequenceClassification
|
7
7
|
|
8
|
-
from $import_name.task import
|
8
|
+
from $import_name.task import load_data
|
9
|
+
from $import_name.task import test as test_fn
|
10
|
+
from $import_name.task import train as train_fn
|
9
11
|
|
12
|
+
# Flower ClientApp
|
13
|
+
app = ClientApp()
|
14
|
+
|
15
|
+
|
16
|
+
@app.train()
|
17
|
+
def train(msg: Message, context: Context):
|
18
|
+
"""Train the model on local data."""
|
19
|
+
|
20
|
+
# Get this client's dataset partition
|
21
|
+
partition_id = context.node_config["partition-id"]
|
22
|
+
num_partitions = context.node_config["num-partitions"]
|
23
|
+
model_name = context.run_config["model-name"]
|
24
|
+
trainloader, _ = load_data(partition_id, num_partitions, model_name)
|
25
|
+
|
26
|
+
# Load model
|
27
|
+
num_labels = context.run_config["num-labels"]
|
28
|
+
net = AutoModelForSequenceClassification.from_pretrained(
|
29
|
+
model_name, num_labels=num_labels
|
30
|
+
)
|
10
31
|
|
11
|
-
#
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
self.trainloader = trainloader
|
16
|
-
self.testloader = testloader
|
17
|
-
self.local_epochs = local_epochs
|
18
|
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
19
|
-
self.net.to(self.device)
|
32
|
+
# Initialize it with the received weights
|
33
|
+
net.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
34
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
35
|
+
net.to(device)
|
20
36
|
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
37
|
+
# Train the model on local data
|
38
|
+
train_loss = train_fn(
|
39
|
+
net,
|
40
|
+
trainloader,
|
41
|
+
context.run_config["local-steps"],
|
42
|
+
device,
|
43
|
+
)
|
25
44
|
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
45
|
+
# Construct and return reply Message
|
46
|
+
model_record = ArrayRecord(net.state_dict())
|
47
|
+
metrics = {
|
48
|
+
"train_loss": train_loss,
|
49
|
+
"num-examples": len(trainloader.dataset),
|
50
|
+
}
|
51
|
+
metric_record = MetricRecord(metrics)
|
52
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
53
|
+
return Message(content=content, reply_to=msg)
|
30
54
|
|
31
55
|
|
32
|
-
|
56
|
+
@app.evaluate()
|
57
|
+
def evaluate(msg: Message, context: Context):
|
58
|
+
"""Evaluate the model on local data."""
|
33
59
|
|
34
60
|
# Get this client's dataset partition
|
35
61
|
partition_id = context.node_config["partition-id"]
|
36
62
|
num_partitions = context.node_config["num-partitions"]
|
37
63
|
model_name = context.run_config["model-name"]
|
38
|
-
|
64
|
+
_, valloader = load_data(partition_id, num_partitions, model_name)
|
39
65
|
|
40
66
|
# Load model
|
41
67
|
num_labels = context.run_config["num-labels"]
|
@@ -43,13 +69,25 @@ def client_fn(context: Context):
|
|
43
69
|
model_name, num_labels=num_labels
|
44
70
|
)
|
45
71
|
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
72
|
+
# Initialize it with the received weights
|
73
|
+
net.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
74
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
75
|
+
net.to(device)
|
50
76
|
|
77
|
+
# Evaluate the model on local data
|
78
|
+
val_loss, val_accuracy = test_fn(
|
79
|
+
net,
|
80
|
+
valloader,
|
81
|
+
device,
|
82
|
+
)
|
51
83
|
|
52
|
-
#
|
53
|
-
|
54
|
-
|
55
|
-
|
84
|
+
# Construct and return reply Message
|
85
|
+
model_record = ArrayRecord(net.state_dict())
|
86
|
+
metrics = {
|
87
|
+
"val_loss": val_loss,
|
88
|
+
"val_accuracy": val_accuracy,
|
89
|
+
"num-examples": len(valloader.dataset),
|
90
|
+
}
|
91
|
+
metric_record = MetricRecord(metrics)
|
92
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
93
|
+
return Message(content=content, reply_to=msg)
|
@@ -1,50 +1,71 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
3
|
import jax
|
4
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
5
|
+
from flwr.clientapp import ClientApp
|
4
6
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from $import_name.task import
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
# Define Flower Client and client_fn
|
19
|
-
class FlowerClient(NumPyClient):
|
20
|
-
def __init__(self, input_dim):
|
21
|
-
self.train_x, self.train_y, self.test_x, self.test_y = load_data()
|
22
|
-
self.grad_fn = jax.grad(loss_fn)
|
23
|
-
self.params = load_model((input_dim,))
|
24
|
-
|
25
|
-
def fit(self, parameters, config):
|
26
|
-
set_params(self.params, parameters)
|
27
|
-
self.params, loss, num_examples = train(
|
28
|
-
self.params, self.grad_fn, self.train_x, self.train_y
|
29
|
-
)
|
30
|
-
return get_params(self.params), num_examples, {"loss": float(loss)}
|
31
|
-
|
32
|
-
def evaluate(self, parameters, config):
|
33
|
-
set_params(self.params, parameters)
|
34
|
-
loss, num_examples = evaluation(
|
35
|
-
self.params, self.grad_fn, self.test_x, self.test_y
|
36
|
-
)
|
37
|
-
return float(loss), num_examples, {"loss": float(loss)}
|
38
|
-
|
39
|
-
|
40
|
-
def client_fn(context: Context):
|
7
|
+
from $import_name.task import evaluation as evaluation_fn
|
8
|
+
from $import_name.task import get_params, load_data, load_model, loss_fn, set_params
|
9
|
+
from $import_name.task import train as train_fn
|
10
|
+
|
11
|
+
# Flower ClientApp
|
12
|
+
app = ClientApp()
|
13
|
+
|
14
|
+
|
15
|
+
@app.train()
|
16
|
+
def train(msg: Message, context: Context):
|
17
|
+
"""Train the model on local data."""
|
18
|
+
|
19
|
+
# Read from config
|
41
20
|
input_dim = context.run_config["input-dim"]
|
42
21
|
|
43
|
-
#
|
44
|
-
|
22
|
+
# Load data and model
|
23
|
+
train_x, train_y, _, _ = load_data()
|
24
|
+
model = load_model((input_dim,))
|
25
|
+
grad_fn = jax.grad(loss_fn)
|
45
26
|
|
27
|
+
# Set model parameters
|
28
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
29
|
+
set_params(model, ndarrays)
|
46
30
|
|
47
|
-
#
|
48
|
-
|
49
|
-
|
50
|
-
|
31
|
+
# Train the model on local data
|
32
|
+
model, loss, num_examples = train_fn(model, grad_fn, train_x, train_y)
|
33
|
+
|
34
|
+
# Construct and return reply Message
|
35
|
+
model_record = ArrayRecord(get_params(model))
|
36
|
+
metrics = {
|
37
|
+
"train_loss": float(loss),
|
38
|
+
"num-examples": num_examples,
|
39
|
+
}
|
40
|
+
metric_record = MetricRecord(metrics)
|
41
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
42
|
+
return Message(content=content, reply_to=msg)
|
43
|
+
|
44
|
+
|
45
|
+
@app.evaluate()
|
46
|
+
def evaluate(msg: Message, context: Context):
|
47
|
+
"""Evaluate the model on local data."""
|
48
|
+
|
49
|
+
# Read from config
|
50
|
+
input_dim = context.run_config["input-dim"]
|
51
|
+
|
52
|
+
# Load data and model
|
53
|
+
_, _, test_x, test_y = load_data()
|
54
|
+
model = load_model((input_dim,))
|
55
|
+
grad_fn = jax.grad(loss_fn)
|
56
|
+
|
57
|
+
# Set model parameters
|
58
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
59
|
+
set_params(model, ndarrays)
|
60
|
+
|
61
|
+
# Evaluate the model on local data
|
62
|
+
loss, num_examples = evaluation_fn(model, grad_fn, test_x, test_y)
|
63
|
+
|
64
|
+
# Construct and return reply Message
|
65
|
+
metrics = {
|
66
|
+
"test_loss": float(loss),
|
67
|
+
"num-examples": num_examples,
|
68
|
+
}
|
69
|
+
metric_record = MetricRecord(metrics)
|
70
|
+
content = RecordDict({"metrics": metric_record})
|
71
|
+
return Message(content=content, reply_to=msg)
|
@@ -3,10 +3,9 @@
|
|
3
3
|
import mlx.core as mx
|
4
4
|
import mlx.nn as nn
|
5
5
|
import mlx.optimizers as optim
|
6
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
7
|
+
from flwr.clientapp import ClientApp
|
6
8
|
|
7
|
-
from flwr.client import ClientApp, NumPyClient
|
8
|
-
from flwr.common import Context
|
9
|
-
from flwr.common.config import UserConfig
|
10
9
|
from $import_name.task import (
|
11
10
|
MLP,
|
12
11
|
batch_iterate,
|
@@ -17,57 +16,87 @@ from $import_name.task import (
|
|
17
16
|
set_params,
|
18
17
|
)
|
19
18
|
|
19
|
+
# Flower ClientApp
|
20
|
+
app = ClientApp()
|
21
|
+
|
22
|
+
|
23
|
+
@app.train()
|
24
|
+
def train(msg: Message, context: Context):
|
25
|
+
"""Train the model on local data."""
|
26
|
+
|
27
|
+
# Read config
|
28
|
+
num_layers = context.run_config["num-layers"]
|
29
|
+
input_dim = context.run_config["input-dim"]
|
30
|
+
hidden_dim = context.run_config["hidden-dim"]
|
31
|
+
batch_size = context.run_config["batch-size"]
|
32
|
+
learning_rate = context.run_config["lr"]
|
33
|
+
num_epochs = context.run_config["local-epochs"]
|
20
34
|
|
21
|
-
#
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
)
|
29
|
-
|
30
|
-
|
31
|
-
input_dim = run_config["input-dim"]
|
32
|
-
batch_size = run_config["batch-size"]
|
33
|
-
learning_rate = run_config["lr"]
|
34
|
-
self.num_epochs = run_config["local-epochs"]
|
35
|
-
|
36
|
-
self.train_images, self.train_labels, self.test_images, self.test_labels = data
|
37
|
-
self.model = MLP(num_layers, input_dim, hidden_dim, num_classes)
|
38
|
-
self.optimizer = optim.SGD(learning_rate=learning_rate)
|
39
|
-
self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
|
40
|
-
self.batch_size = batch_size
|
41
|
-
|
42
|
-
def fit(self, parameters, config):
|
43
|
-
set_params(self.model, parameters)
|
44
|
-
for _ in range(self.num_epochs):
|
45
|
-
for X, y in batch_iterate(
|
46
|
-
self.batch_size, self.train_images, self.train_labels
|
47
|
-
):
|
48
|
-
_, grads = self.loss_and_grad_fn(self.model, X, y)
|
49
|
-
self.optimizer.update(self.model, grads)
|
50
|
-
mx.eval(self.model.parameters(), self.optimizer.state)
|
51
|
-
return get_params(self.model), len(self.train_images), {}
|
52
|
-
|
53
|
-
def evaluate(self, parameters, config):
|
54
|
-
set_params(self.model, parameters)
|
55
|
-
accuracy = eval_fn(self.model, self.test_images, self.test_labels)
|
56
|
-
loss = loss_fn(self.model, self.test_images, self.test_labels)
|
57
|
-
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
|
58
|
-
|
59
|
-
|
60
|
-
def client_fn(context: Context):
|
35
|
+
# Instantiate model and apply global parameters
|
36
|
+
model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
|
37
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
38
|
+
set_params(model, ndarrays)
|
39
|
+
|
40
|
+
# Define optimizer and loss function
|
41
|
+
optimizer = optim.SGD(learning_rate=learning_rate)
|
42
|
+
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
43
|
+
|
44
|
+
# Load data
|
61
45
|
partition_id = context.node_config["partition-id"]
|
62
46
|
num_partitions = context.node_config["num-partitions"]
|
63
|
-
|
64
|
-
num_classes = 10
|
47
|
+
train_images, train_labels, _, _ = load_data(partition_id, num_partitions)
|
65
48
|
|
66
|
-
#
|
67
|
-
|
49
|
+
# Train the model on local data
|
50
|
+
for _ in range(num_epochs):
|
51
|
+
for X, y in batch_iterate(batch_size, train_images, train_labels):
|
52
|
+
_, grads = loss_and_grad_fn(model, X, y)
|
53
|
+
optimizer.update(model, grads)
|
54
|
+
mx.eval(model.parameters(), optimizer.state)
|
68
55
|
|
56
|
+
# Compute train accuracy and loss
|
57
|
+
accuracy = eval_fn(model, train_images, train_labels)
|
58
|
+
loss = loss_fn(model, train_images, train_labels)
|
59
|
+
# Construct and return reply Message
|
60
|
+
model_record = ArrayRecord(get_params(model))
|
61
|
+
metrics = {
|
62
|
+
"num-examples": len(train_images),
|
63
|
+
"accuracy": float(accuracy.item()),
|
64
|
+
"loss": float(loss.item()),
|
65
|
+
}
|
66
|
+
metric_record = MetricRecord(metrics)
|
67
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
68
|
+
return Message(content=content, reply_to=msg)
|
69
69
|
|
70
|
-
|
71
|
-
app
|
72
|
-
|
73
|
-
|
70
|
+
|
71
|
+
@app.evaluate()
|
72
|
+
def evaluate(msg: Message, context: Context):
|
73
|
+
"""Evaluate the model on local data."""
|
74
|
+
|
75
|
+
# Read config
|
76
|
+
num_layers = context.run_config["num-layers"]
|
77
|
+
input_dim = context.run_config["input-dim"]
|
78
|
+
hidden_dim = context.run_config["hidden-dim"]
|
79
|
+
|
80
|
+
# Instantiate model and apply global parameters
|
81
|
+
model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
|
82
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
83
|
+
set_params(model, ndarrays)
|
84
|
+
|
85
|
+
# Load data
|
86
|
+
partition_id = context.node_config["partition-id"]
|
87
|
+
num_partitions = context.node_config["num-partitions"]
|
88
|
+
_, _, test_images, test_labels = load_data(partition_id, num_partitions)
|
89
|
+
|
90
|
+
# Evaluate the model on local data
|
91
|
+
accuracy = eval_fn(model, test_images, test_labels)
|
92
|
+
loss = loss_fn(model, test_images, test_labels)
|
93
|
+
|
94
|
+
# Construct and return reply Message
|
95
|
+
metrics = {
|
96
|
+
"num-examples": len(test_images),
|
97
|
+
"accuracy": float(accuracy.item()),
|
98
|
+
"loss": float(loss.item()),
|
99
|
+
}
|
100
|
+
metric_record = MetricRecord(metrics)
|
101
|
+
content = RecordDict({"metrics": metric_record})
|
102
|
+
return Message(content=content, reply_to=msg)
|
@@ -1,23 +1,46 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from
|
3
|
+
import numpy as np
|
4
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
5
|
+
from flwr.clientapp import ClientApp
|
6
6
|
|
7
|
+
# Flower ClientApp
|
8
|
+
app = ClientApp()
|
7
9
|
|
8
|
-
class FlowerClient(NumPyClient):
|
9
10
|
|
10
|
-
|
11
|
-
|
12
|
-
|
11
|
+
@app.train()
|
12
|
+
def train(msg: Message, context: Context):
|
13
|
+
"""Train the model on local data."""
|
13
14
|
|
14
|
-
|
15
|
-
|
15
|
+
# The model is the global arrays
|
16
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
16
17
|
|
18
|
+
# Simulate local training (here we just add random noise to model parameters)
|
19
|
+
model = [m + np.random.rand(*m.shape) for m in ndarrays]
|
17
20
|
|
18
|
-
|
19
|
-
|
21
|
+
# Construct and return reply Message
|
22
|
+
model_record = ArrayRecord(model)
|
23
|
+
metrics = {
|
24
|
+
"random_metric": np.random.rand(),
|
25
|
+
"num-examples": 1,
|
26
|
+
}
|
27
|
+
metric_record = MetricRecord(metrics)
|
28
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
29
|
+
return Message(content=content, reply_to=msg)
|
20
30
|
|
21
31
|
|
22
|
-
|
23
|
-
|
32
|
+
@app.evaluate()
|
33
|
+
def evaluate(msg: Message, context: Context):
|
34
|
+
"""Evaluate the model on local data."""
|
35
|
+
|
36
|
+
# The model is the global arrays
|
37
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
38
|
+
|
39
|
+
# Return reply Message
|
40
|
+
metrics = {
|
41
|
+
"random_metric": np.random.rand(3).tolist(),
|
42
|
+
"num-examples": 1,
|
43
|
+
}
|
44
|
+
metric_record = MetricRecord(metrics)
|
45
|
+
content = RecordDict({"metrics": metric_record})
|
46
|
+
return Message(content=content, reply_to=msg)
|
@@ -2,10 +2,16 @@
|
|
2
2
|
|
3
3
|
import warnings
|
4
4
|
|
5
|
-
from
|
5
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
6
|
+
from flwr.clientapp import ClientApp
|
7
|
+
from sklearn.metrics import (
|
8
|
+
accuracy_score,
|
9
|
+
f1_score,
|
10
|
+
log_loss,
|
11
|
+
precision_score,
|
12
|
+
recall_score,
|
13
|
+
)
|
6
14
|
|
7
|
-
from flwr.client import ClientApp, NumPyClient
|
8
|
-
from flwr.common import Context
|
9
15
|
from $import_name.task import (
|
10
16
|
get_model,
|
11
17
|
get_model_params,
|
@@ -14,39 +20,52 @@ from $import_name.task import (
|
|
14
20
|
set_model_params,
|
15
21
|
)
|
16
22
|
|
23
|
+
# Flower ClientApp
|
24
|
+
app = ClientApp()
|
17
25
|
|
18
|
-
class FlowerClient(NumPyClient):
|
19
|
-
def __init__(self, model, X_train, X_test, y_train, y_test):
|
20
|
-
self.model = model
|
21
|
-
self.X_train = X_train
|
22
|
-
self.X_test = X_test
|
23
|
-
self.y_train = y_train
|
24
|
-
self.y_test = y_test
|
25
26
|
|
26
|
-
|
27
|
-
|
27
|
+
@app.train()
|
28
|
+
def train(msg: Message, context: Context):
|
29
|
+
"""Train the model on local data."""
|
28
30
|
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
31
|
+
# Create LogisticRegression Model
|
32
|
+
penalty = context.run_config["penalty"]
|
33
|
+
local_epochs = context.run_config["local-epochs"]
|
34
|
+
model = get_model(penalty, local_epochs)
|
35
|
+
# Setting initial parameters, akin to model.compile for keras models
|
36
|
+
set_initial_params(model)
|
33
37
|
|
34
|
-
|
38
|
+
# Apply received pararameters
|
39
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
40
|
+
set_model_params(model, ndarrays)
|
35
41
|
|
36
|
-
|
37
|
-
|
42
|
+
# Load the data
|
43
|
+
partition_id = context.node_config["partition-id"]
|
44
|
+
num_partitions = context.node_config["num-partitions"]
|
45
|
+
X_train, _, y_train, _ = load_data(partition_id, num_partitions)
|
38
46
|
|
39
|
-
|
40
|
-
|
47
|
+
# Ignore convergence failure due to low local epochs
|
48
|
+
with warnings.catch_warnings():
|
49
|
+
warnings.simplefilter("ignore")
|
50
|
+
# Train the model on local data
|
51
|
+
model.fit(X_train, y_train)
|
41
52
|
|
42
|
-
|
53
|
+
# Let's compute train loss
|
54
|
+
y_train_pred_proba = model.predict_proba(X_train)
|
55
|
+
train_logloss = log_loss(y_train, y_train_pred_proba)
|
43
56
|
|
57
|
+
# Construct and return reply Message
|
58
|
+
ndarrays = get_model_params(model)
|
59
|
+
model_record = ArrayRecord(ndarrays)
|
60
|
+
metrics = {"num-examples": len(X_train), "train_logloss": train_logloss}
|
61
|
+
metric_record = MetricRecord(metrics)
|
62
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
63
|
+
return Message(content=content, reply_to=msg)
|
44
64
|
|
45
|
-
def client_fn(context: Context):
|
46
|
-
partition_id = context.node_config["partition-id"]
|
47
|
-
num_partitions = context.node_config["num-partitions"]
|
48
65
|
|
49
|
-
|
66
|
+
@app.evaluate()
|
67
|
+
def evaluate(msg: Message, context: Context):
|
68
|
+
"""Evaluate the model on test data."""
|
50
69
|
|
51
70
|
# Create LogisticRegression Model
|
52
71
|
penalty = context.run_config["penalty"]
|
@@ -56,8 +75,34 @@ def client_fn(context: Context):
|
|
56
75
|
# Setting initial parameters, akin to model.compile for keras models
|
57
76
|
set_initial_params(model)
|
58
77
|
|
59
|
-
|
78
|
+
# Apply received pararameters
|
79
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
80
|
+
set_model_params(model, ndarrays)
|
60
81
|
|
61
|
-
|
62
|
-
|
63
|
-
|
82
|
+
# Load the data
|
83
|
+
partition_id = context.node_config["partition-id"]
|
84
|
+
num_partitions = context.node_config["num-partitions"]
|
85
|
+
_, X_test, _, y_test = load_data(partition_id, num_partitions)
|
86
|
+
|
87
|
+
# Evaluate the model on local data
|
88
|
+
y_train_pred = model.predict(X_test)
|
89
|
+
y_train_pred_proba = model.predict_proba(X_test)
|
90
|
+
|
91
|
+
accuracy = accuracy_score(y_test, y_train_pred)
|
92
|
+
loss = log_loss(y_test, y_train_pred_proba)
|
93
|
+
precision = precision_score(y_test, y_train_pred, average="macro", zero_division=0)
|
94
|
+
recall = recall_score(y_test, y_train_pred, average="macro", zero_division=0)
|
95
|
+
f1 = f1_score(y_test, y_train_pred, average="macro", zero_division=0)
|
96
|
+
|
97
|
+
# Construct and return reply Message
|
98
|
+
metrics = {
|
99
|
+
"num-examples": len(X_test),
|
100
|
+
"test_logloss": loss,
|
101
|
+
"accuracy": accuracy,
|
102
|
+
"precision": precision,
|
103
|
+
"recall": recall,
|
104
|
+
"f1": f1,
|
105
|
+
}
|
106
|
+
metric_record = MetricRecord(metrics)
|
107
|
+
content = RecordDict({"metrics": metric_record})
|
108
|
+
return Message(content=content, reply_to=msg)
|