FedModelKit 0.5.0__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/PKG-INFO +2 -2
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/README.md +1 -1
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/simulation_example.ipynb +7 -7
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/pyproject.toml +1 -1
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/cli.py +1 -1
- fedmodelkit-0.6.0/src/FedModelKit/templates/client_app_template.py +49 -0
- fedmodelkit-0.6.0/src/FedModelKit/templates/server_app_template.py +38 -0
- fedmodelkit-0.6.0/src/FedModelKit/templates/task_template.py +243 -0
- fedmodelkit-0.5.0/src/FedModelKit/templates/client_app_template.py +0 -118
- fedmodelkit-0.5.0/src/FedModelKit/templates/server_app_template.py +0 -204
- fedmodelkit-0.5.0/src/FedModelKit/templates/task_template.py +0 -140
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/.gitignore +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/.python-version +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/LICENSE +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/.gitignore +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/AML_preprocessed_dataset.xlsx +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/README.md +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/images/aggregator.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/images/federated_learning_model.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/images/local_learner.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/README.md +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/__init__.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/aggregator.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/default_create_functions.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/interface.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/local_learner.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/py.typed +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/src/utils.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/__init__template.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/ds_template.ipynb +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/extern_pyproject_template.toml +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/doSendModels.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/doWaitsForJobs.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/dsAggregateModels.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/dsDoneSubmittingJobs.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/dsSendsJobs.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/overview.png +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/main_template.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/pyproject_template.toml +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/readme_template.md +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/uv_template.lock +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates.py +0 -0
- {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/uv.lock +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: FedModelKit
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: This package contains the core components and protocols for creating, managing, and registering federated learning models using MLflow. It provides utilities for defining local learners, aggregation strategies, and integrating them with MLflow for tracking and deployment.
|
|
5
5
|
Author-email: ceresale <alessandro.ceresi@upm.es>
|
|
6
6
|
License-File: LICENSE
|
|
@@ -101,7 +101,7 @@ Follow the steps below to install and set up the uploading environment:
|
|
|
101
101
|
2. Go to the directory from where you want to upload the model and run this command to initialize it (pay attention to files that might get overwritten):
|
|
102
102
|
|
|
103
103
|
```bash
|
|
104
|
-
|
|
104
|
+
fmk init
|
|
105
105
|
```
|
|
106
106
|
A `src` folder will be created where the dependency files must be stored, an `example.py` script will be created in the main directory with an example of how to upload the model, and a `README.md` file will be created with a description of how to use the package functionalities. These files include clear documentation on how to define your local learner and aggregation strategy, and how to log the model to the Platform Model Catalogue.
|
|
107
107
|
|
|
@@ -81,7 +81,7 @@ Follow the steps below to install and set up the uploading environment:
|
|
|
81
81
|
2. Go to the directory from where you want to upload the model and run this command to initialize it (pay attention to files that might get overwritten):
|
|
82
82
|
|
|
83
83
|
```bash
|
|
84
|
-
|
|
84
|
+
fmk init
|
|
85
85
|
```
|
|
86
86
|
A `src` folder will be created where the dependency files must be stored, an `example.py` script will be created in the main directory with an example of how to upload the model, and a `README.md` file will be created with a description of how to use the package functionalities. These files include clear documentation on how to define your local learner and aggregation strategy, and how to log the model to the Platform Model Catalogue.
|
|
87
87
|
|
{fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/simulation_example.ipynb
RENAMED
|
@@ -58,7 +58,7 @@
|
|
|
58
58
|
"metadata": {},
|
|
59
59
|
"outputs": [],
|
|
60
60
|
"source": [
|
|
61
|
-
"!
|
|
61
|
+
"!fmk init app"
|
|
62
62
|
]
|
|
63
63
|
},
|
|
64
64
|
{
|
|
@@ -106,7 +106,7 @@
|
|
|
106
106
|
"outputs": [],
|
|
107
107
|
"source": [
|
|
108
108
|
"%%writefile model_example.py\n",
|
|
109
|
-
"import FedModelKit as
|
|
109
|
+
"import FedModelKit as fmk\n",
|
|
110
110
|
"\n",
|
|
111
111
|
"# Function defining the local learner\n",
|
|
112
112
|
"def create_local_learner():\n",
|
|
@@ -313,11 +313,11 @@
|
|
|
313
313
|
"metadata": {},
|
|
314
314
|
"outputs": [],
|
|
315
315
|
"source": [
|
|
316
|
-
"import FedModelKit as
|
|
316
|
+
"import FedModelKit as fmk\n",
|
|
317
317
|
"from model_example import create_local_learner as create_ll_check # Use alias here to avoid name conflict when uploading the model\n",
|
|
318
318
|
"from model_example import create_aggregator as create_agg_check # Use alias here to avoid name conflict when uploading the model\n",
|
|
319
319
|
"\n",
|
|
320
|
-
"
|
|
320
|
+
"fmk.FederatedModel(create_local_learner=create_ll_check, \n",
|
|
321
321
|
" model_name='simple_lr',\n",
|
|
322
322
|
" create_aggregator=create_agg_check,\n",
|
|
323
323
|
" aggregator_name='custom_aggregator')"
|
|
@@ -574,13 +574,13 @@
|
|
|
574
574
|
"outputs": [],
|
|
575
575
|
"source": [
|
|
576
576
|
"%run model_example.py # Run the model_example.py file to directly get the functions defining the local learner and aggregator \n",
|
|
577
|
-
"import FedModelKit as
|
|
577
|
+
"import FedModelKit as fmk\n",
|
|
578
578
|
"\n",
|
|
579
579
|
"# Create an instance of the FederatedModel class from the FedModelKit module.\n",
|
|
580
580
|
"# This class handles the federated learning process, including model creation, training, and aggregation.\n",
|
|
581
581
|
"# The create_local_learner and create_aggregator functions is passed as an argument to define the local learner (model)\n",
|
|
582
582
|
"# for each client. Here the functions used are the ones defined in the model_example.py file.\n",
|
|
583
|
-
"federated_model =
|
|
583
|
+
"federated_model = fmk.FederatedModel(create_local_learner=create_local_learner, # type: ignore # Make sure you already type-checked the function\n",
|
|
584
584
|
" model_name='simple_lr',\n",
|
|
585
585
|
" create_aggregator=create_aggregator, # type: ignore # Make sure you already type-checked the function\n",
|
|
586
586
|
" aggregator_name='custom_aggregator')"
|
|
@@ -631,7 +631,7 @@
|
|
|
631
631
|
"# HERE you should produce the information about the whole dataset needed by your model and store it in\n",
|
|
632
632
|
"# the src directory, UNLESS you performed the simulation and the clients already stored it.\n",
|
|
633
633
|
"\n",
|
|
634
|
-
"
|
|
634
|
+
"fmk.submit_fl_model(model=federated_model,\n",
|
|
635
635
|
" platform_url='http://localhost:5000', # URL of the local MLflow server\n",
|
|
636
636
|
" username='username', # Username for authentication (If no username is required , put a mock username)\n",
|
|
637
637
|
" password='password', # Password for authentication (if no password is required, put a mock password)\n",
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "FedModelKit"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.6.0"
|
|
4
4
|
description = "This package contains the core components and protocols for creating, managing, and registering federated learning models using MLflow. It provides utilities for defining local learners, aggregation strategies, and integrating them with MLflow for tracking and deployment."
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
authors = [
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from flwr.client import ClientApp, NumPyClient
|
|
2
|
+
from flwr.common import Context
|
|
3
|
+
from loguru import logger
|
|
4
|
+
|
|
5
|
+
from EXPERIMENT_NAME.task import (
|
|
6
|
+
Net,
|
|
7
|
+
evaluate,
|
|
8
|
+
get_weights,
|
|
9
|
+
load_flwr_data,
|
|
10
|
+
set_weights,
|
|
11
|
+
train,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains and evaluates a local PyTorch model.

    Args:
        net: the model; its weights are overwritten by the server's parameters
            at the start of every ``fit``/``evaluate`` call.
        trainloader: DataLoader consumed by ``fit``.
        testloader: DataLoader consumed by ``evaluate``.
    """

    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def fit(self, parameters, config):
        """Load the global weights, train locally, and return the updated weights.

        Returns:
            ``(weights, num_examples, metrics)``. ``num_examples`` is the size of
            the training dataset rather than the number of batches, so that
            FedAvg weights each client by how much data it holds.
        """
        set_weights(self.net, parameters)
        train(self.net, self.trainloader)
        return get_weights(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the global weights and score them on the local test data.

        Returns:
            ``(loss, num_examples, {"accuracy": accuracy})``. ``num_examples`` is
            the test dataset size. The test loader holds the whole set in one
            batch, so ``len(self.testloader)`` would always be 1 and would make
            the server's example-weighted accuracy an unweighted mean.
        """
        set_weights(self.net, parameters)
        # Inside a method body, the bare name ``evaluate`` resolves to the
        # module-level function imported from ``task``, not to this method.
        loss, accuracy = evaluate(self.net, self.testloader)
        return loss, len(self.testloader.dataset), {"accuracy": accuracy}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def client_fn(context: Context):
    """Build the Flower client for the current node.

    When ``run_syft_flwr()`` reports a SyftBox run, data comes from
    ``load_syftbox_dataset``. Otherwise the local Flower partition selected by
    ``node_config["partition-id"]`` / ``node_config["num-partitions"]`` is loaded.
    A fresh ``Net`` is wrapped in a ``FlowerClient`` and converted with
    ``to_client()``.
    """
    from EXPERIMENT_NAME.task import load_syftbox_dataset
    from syft_flwr.utils import run_syft_flwr

    if run_syft_flwr():
        logger.info("Running with syft_flwr")
        train_loader, test_loader = load_syftbox_dataset()
    else:
        logger.info("Running flwr locally")
        node_cfg = context.node_config
        train_loader, test_loader = load_flwr_data(
            partition_id=node_cfg["partition-id"],
            num_partitions=node_cfg["num-partitions"],
        )

    return FlowerClient(Net(), train_loader, test_loader).to_client()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Flower ClientApp entry point; each client node is constructed by client_fn.
app = ClientApp(client_fn=client_fn)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Flower ServerApp template: builds the FedAvg-with-model-saving strategy and the round configuration."""
|
|
2
|
+
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
|
|
6
|
+
from EXPERIMENT_NAME.task import Net, get_weights
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def weighted_average(metrics):
    """Aggregate client accuracies into a single example-weighted mean.

    Args:
        metrics: sequence of ``(num_examples, metrics_dict)`` pairs; each
            ``metrics_dict`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``. Returns an empty dict when there are
        no results or the example counts sum to zero, instead of raising
        ZeroDivisionError inside the server's aggregation step.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if not total_examples:
        return {}
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def server_fn(context: Context) -> ServerAppComponents:
    """Assemble the server-side components for a Flower run.

    The initial global parameters are the weights of a freshly built ``Net``.
    The strategy is ``FedAvgWithModelSaving``. It samples every available
    client for both fit and evaluate, and waits for at least two clients. It
    aggregates evaluation accuracy with ``weighted_average`` and saves weights
    under a ``weights`` directory three levels above this file.

    Args:
        context: Flower run context; ``run_config["num-server-rounds"]`` sets
            the number of federated rounds.

    Returns:
        ServerAppComponents bundling the ServerConfig and the strategy.
    """
    initial_parameters = ndarrays_to_parameters(get_weights(Net()))

    from pathlib import Path

    from syft_flwr.strategy import FedAvgWithModelSaving

    weights_dir = Path(__file__).parent.parent.parent / "weights"
    strategy = FedAvgWithModelSaving(
        save_path=weights_dir,
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=initial_parameters,
        evaluate_metrics_aggregation_fn=weighted_average,
    )

    server_config = ServerConfig(num_rounds=context.run_config["num-server-rounds"])
    return ServerAppComponents(config=server_config, strategy=strategy)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Flower ServerApp entry point; strategy and config are built by server_fn.
app = ServerApp(server_fn=server_fn)
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.optim as optim
|
|
6
|
+
from flwr_datasets import FederatedDataset
|
|
7
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
8
|
+
from imblearn.over_sampling import SMOTE
|
|
9
|
+
from loguru import logger
|
|
10
|
+
from pandas import DataFrame
|
|
11
|
+
from sklearn.model_selection import train_test_split
|
|
12
|
+
from sklearn.preprocessing import StandardScaler
|
|
13
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_device():
    """Return the preferred available torch device.

    Preference order: CUDA, Apple MPS, Intel XPU, then CPU. The MPS and XPU
    backends are looked up with ``getattr``. Torch builds that predate them
    then fall back to the next option instead of raising AttributeError when
    this module is imported.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    xpu_backend = getattr(torch, "xpu", None)
    if xpu_backend is not None and xpu_backend.is_available():
        return torch.device("xpu")

    return torch.device("cpu")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Device used by train() and evaluate(); selected once when this module is imported.
DEVICE = get_device()
|
|
28
|
+
logger.info(f"Using device: {DEVICE}")  # emitted once, at import time
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Net(nn.Module):
    """Binary-classification MLP: hidden widths 32 -> 24 -> 16, sigmoid output.

    Args:
        input_dim: number of features per sample (default 6).

    ``forward`` returns probabilities shaped ``(batch, 1)``. Every hidden block
    uses ``BatchNorm1d``, so training-mode batches need more than one sample.
    Attribute names (``layer1``..``layer3``, ``output_layer``) define the
    state_dict keys exchanged with the server, so keep them stable.
    """

    def __init__(self, input_dim=6):
        super().__init__()

        # Block 1: widest layer, batch-normalized, LeakyReLU, 20% dropout.
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.BatchNorm1d(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
        )

        # Block 2: narrows to 24 units with slightly stronger (25%) dropout.
        self.layer2 = nn.Sequential(
            nn.Linear(32, 24),
            nn.BatchNorm1d(24),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.25),
        )

        # Block 3: 16 units, no dropout.
        self.layer3 = nn.Sequential(
            nn.Linear(24, 16),
            nn.BatchNorm1d(16),
            nn.LeakyReLU(0.1),
        )

        # Head: a single logit squashed to a probability.
        self.output_layer = nn.Sequential(
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.output_layer):
            x = stage(x)
        return x
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def dataset_processing(
    train_df: DataFrame, test_df: DataFrame, batch_size: int = 10
) -> tuple[DataLoader, DataLoader]:
    """Clean, rebalance, scale, and wrap train/test frames into DataLoaders.

    Steps:
      1. Drop ``SkinThickness`` and ``Insulin``. Replace zeros in ``Glucose``
         (with the mean), ``BMI`` and ``BloodPressure`` (with the median),
         computed from the non-zero rows of the same frame.
      2. Split columns positionally: the first 6 remaining columns are
         features, the rest is the label. This assumes the label is the last
         column after the drop — confirm against the dataset schema.
      3. Oversample the training set with SMOTE when that is possible (at
         least two classes and a minority class with >= 2 samples). Otherwise
         keep it unchanged and log a warning instead of letting SMOTE raise.
      4. Standardize features with a StandardScaler fit on the training set.
      5. Build a shuffled training DataLoader and a single-batch test DataLoader.

    Args:
        train_df: raw training rows.
        test_df: raw test rows.
        batch_size: training mini-batch size (default 10, as before).

    Returns:
        ``(train_loader, test_loader)``; labels are float tensors of shape (n, 1).
    """
    from collections import Counter

    def preprocess_df(df: DataFrame) -> DataFrame:
        """Drop unused columns and impute zero readings from non-zero rows."""
        columns_to_drop = ["SkinThickness", "Insulin"]
        df_new: DataFrame = df.drop(columns_to_drop, axis=1)

        # Statistics are taken over the non-zero rows only.
        mean_glucose = df_new[df_new["Glucose"] != 0]["Glucose"].mean()
        median_bmi = df_new[df_new["BMI"] != 0]["BMI"].median()
        median_bp = df_new[df_new["BloodPressure"] != 0]["BloodPressure"].median()

        # Treat 0 as missing and substitute the statistic computed above.
        df_new.replace(
            {
                "Glucose": {0: mean_glucose},
                "BMI": {0: median_bmi},
                "BloodPressure": {0: median_bp},
            },
            inplace=True,
        )
        return df_new

    # NOTE(review): each frame is imputed with its own statistics, so the test
    # set does not reuse training means/medians — confirm this is intended.
    train_processed = preprocess_df(train_df)
    test_processed = preprocess_df(test_df)

    X_train = train_processed.values[:, :6]
    # Flatten to 1-D; SMOTE expects a 1-D target (a column vector is only
    # accepted after a conversion warning).
    y_train = train_processed.values[:, 6:].ravel()
    X_test = test_processed.values[:, :6]
    y_test = test_processed.values[:, 6:]

    # SMOTE needs at least two classes and k_neighbors < minority class size,
    # so a minority class of one sample (or a single-class split) cannot be
    # oversampled.
    class_counts = Counter(y_train.tolist())
    minority_count = min(class_counts.values(), default=0)
    if len(class_counts) > 1 and minority_count > 1:
        smote = SMOTE(random_state=42, k_neighbors=min(5, minority_count - 1))
        X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)
    else:
        logger.warning(
            f"Skipping SMOTE: class counts {dict(class_counts)} are too small to oversample"
        )
        X_train_resampled, y_train_resampled = X_train, y_train

    # Scale to zero mean / unit variance using training statistics only.
    scaler = StandardScaler()
    X_train_resampled = scaler.fit_transform(X_train_resampled)
    X_test = scaler.transform(X_test)

    X_train_tensor = torch.FloatTensor(X_train_resampled)
    y_train_tensor = torch.FloatTensor(y_train_resampled).reshape(-1, 1)
    X_test_tensor = torch.FloatTensor(X_test)
    y_test_tensor = torch.FloatTensor(y_test).reshape(-1, 1)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

    n_train = len(train_dataset)
    # BatchNorm layers (as used by Net) raise on a training batch of a single
    # sample, so drop the last batch only when it would contain exactly one.
    drop_last = n_train > batch_size and n_train % batch_size == 1
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last,
    )
    # Whole test set in one batch; max(1, ...) keeps the DataLoader valid when
    # the test set is empty (batch_size must be positive).
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=max(1, len(test_dataset)), shuffle=False
    )

    return train_loader, test_loader
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def load_syftbox_dataset() -> tuple[DataLoader, DataLoader]:
    """Read ``train.csv`` and ``test.csv`` from the SyftBox dataset directory.

    The directory comes from ``syft_flwr.utils.get_syftbox_dataset_path()``.
    Both CSVs are passed through ``dataset_processing``.

    Returns:
        ``(train_loader, test_loader)`` produced by ``dataset_processing``.
    """
    import pandas as pd

    from syft_flwr.utils import get_syftbox_dataset_path

    dataset_dir = get_syftbox_dataset_path()
    logger.info(f"Loading dataset from {dataset_dir}")

    train_frame, test_frame = (
        pd.read_csv(dataset_dir / csv_name) for csv_name in ("train.csv", "test.csv")
    )
    return dataset_processing(train_frame, test_frame)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
fds = None  # Cached FederatedDataset: built on the first load_flwr_data call, then reused
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def load_flwr_data(
    partition_id: int, num_partitions: int
) -> tuple[DataLoader, DataLoader]:
    """Load one IID partition of ``khoaguin/pima-indians-diabetes-database``.

    The ``train`` split is partitioned with ``IidPartitioner``. Partition
    ``partition_id`` is taken as a pandas frame and split 80/20 into
    train/test (``random_state=95``). Both parts then go through
    ``dataset_processing``.

    NOTE: the ``FederatedDataset`` is cached in the module-level ``fds`` on the
    first call. Later calls reuse it even if ``num_partitions`` differs.

    Args:
        partition_id: index of the partition to load.
        num_partitions: total number of IID partitions (used only on the first call).

    Returns:
        ``(train_loader, test_loader)`` produced by ``dataset_processing``.
    """
    global fds
    if fds is None:
        fds = FederatedDataset(
            dataset="khoaguin/pima-indians-diabetes-database",
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
        )

    client_partition = fds.load_partition(partition_id, "train")
    partition_df: DataFrame = client_partition.with_format("pandas")[:]
    local_train, local_test = train_test_split(
        partition_df, test_size=0.2, random_state=95
    )
    return dataset_processing(local_train, local_test)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def train(model, train_loader, local_epochs=1):
    """Train ``model`` in place with BCE loss and Adam (lr=1e-3, weight_decay=5e-4).

    A new Adam optimizer is created on every call, so optimizer state does
    not carry over between calls. The model is moved to ``DEVICE`` and put in
    training mode; predictions are thresholded at 0.5 for accuracy.

    Args:
        model: module whose outputs are probabilities shaped like the labels.
        train_loader: iterable of ``(inputs, labels)`` batches.
        local_epochs: number of passes over ``train_loader``.

    Returns:
        ``{"train_loss": [...], "train_acc": [...]}`` with one entry per epoch.
        The loss is the mean per-sample BCE over the samples actually seen.
        An epoch that yields no batches records 0.0 for both metrics instead
        of raising ZeroDivisionError.
    """
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
    history = {"train_loss": [], "train_acc": []}
    model.to(DEVICE)

    for _ in range(local_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # BCELoss averages over the batch; re-weight by batch size so the
            # epoch mean is per sample even when batches differ in size.
            running_loss += loss.item() * inputs.size(0)
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Divide by the samples actually processed: this can be fewer than
        # len(train_loader.dataset) when the loader drops its last batch.
        epoch_loss = running_loss / total if total else 0.0
        epoch_acc = correct / total if total else 0.0
        history["train_loss"].append(epoch_loss)
        history["train_acc"].append(epoch_acc)

    return history
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def evaluate(model, data_loader):
    """Compute the mean BCE loss and accuracy of ``model`` over ``data_loader``.

    The model is moved to ``DEVICE`` and switched to eval mode, and gradients
    are disabled. Predictions are thresholded at 0.5.

    Args:
        model: module whose outputs are probabilities shaped like the labels.
        data_loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(loss, accuracy)``: mean per-sample loss and fraction correct. Both
        are 0.0 when the loader yields no samples, instead of raising
        ZeroDivisionError.
    """
    model.to(DEVICE)
    model.eval()
    criterion = nn.BCELoss()

    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Re-weight the batch-mean loss by batch size for a per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if not total:
        return 0.0, 0.0
    return running_loss / total, correct / total
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def set_weights(model, parameters):
    """Load a list of arrays into ``model``, matched positionally to its state_dict keys.

    Args:
        model: torch module to update in place.
        parameters: arrays in the same order as ``model.state_dict()``, as
            produced by ``get_weights``.

    ``load_state_dict`` is called with ``strict=True``. If fewer arrays than
    keys are supplied, the missing keys make it raise.
    """
    state_keys = model.state_dict().keys()
    new_state = OrderedDict(
        (name, torch.tensor(array)) for name, array in zip(state_keys, parameters)
    )
    model.load_state_dict(new_state, strict=True)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def get_weights(model):
    """Return every state_dict tensor of ``model`` as a CPU NumPy array, in key order."""
    return [tensor.cpu().numpy() for tensor in model.state_dict().values()]
|
|
@@ -1,118 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
import numpy as np
|
|
3
|
-
import pandas as pd
|
|
4
|
-
import pickle
|
|
5
|
-
from pathlib import Path
|
|
6
|
-
from flwr.client import ClientApp
|
|
7
|
-
from flwr.common import Message, Context
|
|
8
|
-
from flwr.common.record import RecordSet, MetricsRecord, ConfigsRecord
|
|
9
|
-
from sklearn.preprocessing import OneHotEncoder
|
|
10
|
-
import FedModelKit as msi
|
|
11
|
-
|
|
12
|
-
from model_example import create_local_learner #type: ignore[import]
|
|
13
|
-
from load_data import load_data #type: ignore[import]
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
# Initialize the Flower ClientApp
|
|
17
|
-
app = ClientApp()  # handlers register on it via @app.query / @app.train / @app.evaluate
|
|
18
|
-
|
|
19
|
-
@app.query()
def query(msg: Message, ctx: Context) -> Message:
    """Handle the server's initial QUERY message.

    Reads ``num_clients`` and ``client_id`` from the ``fancy_config`` configs
    record, loads this client's data split, and creates the local learner.
    Both are kept in ``ctx.state`` for the train/evaluate handlers. Replies
    with an empty RecordSet.
    """
    cfg = msg.content.configs_records['fancy_config']
    client_data = load_data(cfg['num_clients'], cfg['client_id'])

    fed_model = msi.FederatedModel(create_local_learner=create_local_learner, model_name='simple_lr')

    state = ctx.state
    state.local_learner = fed_model.create_local_learner()
    state.data = client_data
    # Debug marker; the train handler asserts on its value.
    state.undestand = "che succede"

    return msg.create_reply(RecordSet())
|
|
42
|
-
|
|
43
|
-
@app.train()
def train(msg: Message, ctx: Context):
    """Handle a TRAIN message: train the local learner on this client's split.

    Uses ``ctx.state.local_learner`` and ``ctx.state.data`` set by the query
    handler. The server's parameters arrive in the ``fancy_model`` record.
    The reply carries the updated parameters under ``fancy_model_returned``
    and the training metrics under ``train_metrics``. The metrics and the
    learner are also written back to ``ctx.state``.
    """
    local_learner = ctx.state.local_learner
    data = ctx.state.data

    # Start from the global parameters sent by the server.
    local_learner.set_parameters(msg.content.parameters_records['fancy_model'])
    local_learner.prepare_data(data)

    train_metrics = local_learner.train_round()
    new_parameters_records = local_learner.get_parameters()
    # The former debug-only `assert ctx.state.undestand...` check was removed:
    # it validated a scratch attribute, not training state, and `assert` is
    # stripped under `python -O` anyway.

    reply_content = RecordSet()
    reply_content.parameters_records['fancy_model_returned'] = new_parameters_records
    reply_content.metrics_records['train_metrics'] = train_metrics

    ctx.state.metrics_records['prev'] = train_metrics
    ctx.state.local_learner = local_learner

    return msg.create_reply(reply_content)
|
|
83
|
-
|
|
84
|
-
@app.evaluate()
def eval(msg: Message, ctx: Context):
    """Handle an EVALUATE message: score the global parameters on this client's split.

    Uses ``ctx.state.local_learner`` and ``ctx.state.data`` set by the query
    handler. The server's parameters arrive in the ``fancy_model`` record.
    The reply carries the metrics under ``eval_metrics``. The metrics and the
    learner are also written back to ``ctx.state``.
    """
    learner = ctx.state.local_learner
    client_data = ctx.state.data

    learner.set_parameters(msg.content.parameters_records['fancy_model'])
    learner.prepare_data(client_data)
    metrics = learner.evaluate()

    reply = RecordSet()
    reply.metrics_records['eval_metrics'] = metrics

    ctx.state.metrics_records['prev'] = metrics
    ctx.state.local_learner = learner

    return msg.create_reply(reply)
|
|
@@ -1,204 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
from typing import List
|
|
3
|
-
import time
|
|
4
|
-
|
|
5
|
-
import flwr as fl
|
|
6
|
-
from flwr.common import (
|
|
7
|
-
Context,
|
|
8
|
-
NDArrays,
|
|
9
|
-
Message,
|
|
10
|
-
MessageType,
|
|
11
|
-
Metrics,
|
|
12
|
-
RecordSet,
|
|
13
|
-
ConfigsRecord,
|
|
14
|
-
DEFAULT_TTL,
|
|
15
|
-
)
|
|
16
|
-
from flwr.server import Driver
|
|
17
|
-
|
|
18
|
-
import FedModelKit as msi
|
|
19
|
-
|
|
20
|
-
from model_example import create_local_learner #type: ignore[import]
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
# Run via `flower-server-app server:app`
|
|
24
|
-
app = fl.server.ServerApp()  # the server loop is registered on it via @app.main()
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
@app.main()
|
|
30
|
-
def main(driver: Driver, context: Context) -> None:
|
|
31
|
-
"""
|
|
32
|
-
Main function to run the federated learning server.
|
|
33
|
-
|
|
34
|
-
Structure:
|
|
35
|
-
- Send a query message to clients for creating the local learner and loading the data
|
|
36
|
-
- Start global epochs loop for training and evaluation
|
|
37
|
-
- Send training messages to clients
|
|
38
|
-
- Aggregate parameters received from clients
|
|
39
|
-
- Send evaluation messages to clients
|
|
40
|
-
- Aggregate evaluation metrics
|
|
41
|
-
"""
|
|
42
|
-
print("Starting test run")
|
|
43
|
-
|
|
44
|
-
# Get node IDs of connected clients
|
|
45
|
-
node_ids = driver.get_node_ids()
|
|
46
|
-
|
|
47
|
-
# Initialize the federated model
|
|
48
|
-
federated_model = msi.FederatedModel(create_local_learner=create_local_learner,
|
|
49
|
-
model_name='simple_lr')
|
|
50
|
-
global_model = federated_model.create_local_learner()
|
|
51
|
-
aggregation_strategy = federated_model.create_aggregator()
|
|
52
|
-
|
|
53
|
-
# Send a query message to clients for creating the local learner and loading the data
|
|
54
|
-
messages = []
|
|
55
|
-
for idx, node_id in enumerate(node_ids):
|
|
56
|
-
# Create messages to send to clients
|
|
57
|
-
recordset = RecordSet()
|
|
58
|
-
|
|
59
|
-
# Add a config with information to send the client for the query
|
|
60
|
-
recordset.configs_records["fancy_config"] = ConfigsRecord({"num_clients": len(node_ids), "client_id": idx})
|
|
61
|
-
|
|
62
|
-
# Create a query message for each client
|
|
63
|
-
message = driver.create_message(
|
|
64
|
-
content=recordset,
|
|
65
|
-
message_type=MessageType.QUERY,
|
|
66
|
-
dst_node_id=node_id,
|
|
67
|
-
group_id=str(1),
|
|
68
|
-
ttl=DEFAULT_TTL,
|
|
69
|
-
)
|
|
70
|
-
messages.append(message)
|
|
71
|
-
|
|
72
|
-
# Send training messages to clients
|
|
73
|
-
message_ids = driver.push_messages(messages)
|
|
74
|
-
print(f"Pushed {len(list(message_ids))} messages: {message_ids}")
|
|
75
|
-
|
|
76
|
-
# Wait for results from clients
|
|
77
|
-
message_ids = [message_id for message_id in message_ids if message_id != ""]
|
|
78
|
-
all_replies: List[Message] = []
|
|
79
|
-
while True:
|
|
80
|
-
replies = driver.pull_messages(message_ids=message_ids)
|
|
81
|
-
print(f"Got {len(list(replies))} results")
|
|
82
|
-
all_replies += replies
|
|
83
|
-
if len(all_replies) == len(message_ids):
|
|
84
|
-
break
|
|
85
|
-
time.sleep(12)
|
|
86
|
-
|
|
87
|
-
# Filter out messages with errors
|
|
88
|
-
all_replies = [
|
|
89
|
-
msg
|
|
90
|
-
for msg in all_replies
|
|
91
|
-
if msg.has_content()
|
|
92
|
-
]
|
|
93
|
-
print(f"Received {len(all_replies)} answers")
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
# Run federated training and evaluation for a fixed number of rounds
|
|
97
|
-
for server_round in range(3):
|
|
98
|
-
print(f"Commencing server train and evaluation round {server_round + 1}")
|
|
99
|
-
|
|
100
|
-
messages = []
|
|
101
|
-
for idx, node_id in enumerate(node_ids):
|
|
102
|
-
# Create messages to send to clients
|
|
103
|
-
recordset = RecordSet()
|
|
104
|
-
|
|
105
|
-
# Add model parameters to record
|
|
106
|
-
recordset.parameters_records["fancy_model"] = global_model.get_parameters()
|
|
107
|
-
# Add a config with information to send the client for training
|
|
108
|
-
recordset.configs_records["fancy_config"] = ConfigsRecord({"local_epochs": 3})
|
|
109
|
-
|
|
110
|
-
# Create a training message for each client
|
|
111
|
-
message = driver.create_message(
|
|
112
|
-
content=recordset,
|
|
113
|
-
message_type=MessageType.TRAIN,
|
|
114
|
-
dst_node_id=node_id,
|
|
115
|
-
group_id=str(server_round),
|
|
116
|
-
ttl=DEFAULT_TTL,
|
|
117
|
-
)
|
|
118
|
-
messages.append(message)
|
|
119
|
-
|
|
120
|
-
# Send training messages to clients
|
|
121
|
-
message_ids = driver.push_messages(messages)
|
|
122
|
-
print(f"Pushed {len(list(message_ids))} messages: {message_ids}")
|
|
123
|
-
|
|
124
|
-
# Wait for results from clients
|
|
125
|
-
message_ids = [message_id for message_id in message_ids if message_id != ""]
|
|
126
|
-
all_replies: List[Message] = []
|
|
127
|
-
while True:
|
|
128
|
-
replies = driver.pull_messages(message_ids=message_ids)
|
|
129
|
-
print(f"Got {len(list(replies))} results")
|
|
130
|
-
all_replies += replies
|
|
131
|
-
if len(all_replies) == len(message_ids):
|
|
132
|
-
break
|
|
133
|
-
time.sleep(12)
|
|
134
|
-
|
|
135
|
-
# Filter out messages with errors
|
|
136
|
-
all_replies = [
|
|
137
|
-
msg
|
|
138
|
-
for msg in all_replies
|
|
139
|
-
if msg.has_content()
|
|
140
|
-
]
|
|
141
|
-
print(f"Received {len(all_replies)} results")
|
|
142
|
-
|
|
143
|
-
# Print metrics received from clients
|
|
144
|
-
for reply in all_replies:
|
|
145
|
-
print(reply.content.metrics_records)
|
|
146
|
-
|
|
147
|
-
# Aggregate parameters received from clients
|
|
148
|
-
parameter_records_list = [reply.content.parameters_records["fancy_model_returned"] for reply in all_replies]
|
|
149
|
-
new_parameter_record = aggregation_strategy.aggregate_parameters(parameter_records_list)
|
|
150
|
-
global_model.set_parameters(new_parameter_record)
|
|
151
|
-
|
|
152
|
-
# Evaluate the updated global model
|
|
153
|
-
messages = []
|
|
154
|
-
for idx, node_id in enumerate(node_ids):
|
|
155
|
-
# Create evaluation messages for clients
|
|
156
|
-
recordset = RecordSet()
|
|
157
|
-
|
|
158
|
-
# Add updated model parameters to record
|
|
159
|
-
recordset.parameters_records["fancy_model"] = new_parameter_record
|
|
160
|
-
# Add a config with information to send the client for evaluation
|
|
161
|
-
recordset.configs_records["fancy_config"] = ConfigsRecord({"local_epochs": 3})
|
|
162
|
-
|
|
163
|
-
# Create an evaluation message for each client
|
|
164
|
-
message = driver.create_message(
|
|
165
|
-
content=recordset,
|
|
166
|
-
message_type=MessageType.EVALUATE,
|
|
167
|
-
dst_node_id=node_id,
|
|
168
|
-
group_id=str(server_round),
|
|
169
|
-
ttl=DEFAULT_TTL,
|
|
170
|
-
)
|
|
171
|
-
messages.append(message)
|
|
172
|
-
|
|
173
|
-
# Send evaluation messages to clients
|
|
174
|
-
message_ids = driver.push_messages(messages)
|
|
175
|
-
print(f"Pushed {len(list(message_ids))} messages: {message_ids}")
|
|
176
|
-
|
|
177
|
-
# Wait for evaluation results from clients
|
|
178
|
-
message_ids = [message_id for message_id in message_ids if message_id != ""]
|
|
179
|
-
all_replies: List[Message] = []
|
|
180
|
-
while True:
|
|
181
|
-
replies = driver.pull_messages(message_ids=message_ids)
|
|
182
|
-
print(f"Got {len(list(replies))} results")
|
|
183
|
-
all_replies += replies
|
|
184
|
-
if len(all_replies) == len(message_ids):
|
|
185
|
-
break
|
|
186
|
-
time.sleep(3)
|
|
187
|
-
|
|
188
|
-
# Filter out messages with errors
|
|
189
|
-
all_replies = [
|
|
190
|
-
msg
|
|
191
|
-
for msg in all_replies
|
|
192
|
-
if msg.has_content()
|
|
193
|
-
]
|
|
194
|
-
print(f"Received {len(all_replies)} results")
|
|
195
|
-
|
|
196
|
-
# Print evaluation metrics received from clients
|
|
197
|
-
metrics_records_list = [reply.content.metrics_records['eval_metrics'] for reply in all_replies]
|
|
198
|
-
for i, reply in enumerate(all_replies):
|
|
199
|
-
print(f"Client {i+1} metrics: ", reply.content.metrics_records['eval_metrics'])
|
|
200
|
-
|
|
201
|
-
# Aggregate evaluation metrics
|
|
202
|
-
print("Aggregated metrics result: ", aggregation_strategy.aggregate_metrics(metrics_records_list))
|
|
203
|
-
|
|
204
|
-
print("🎉🎉🎉 Successfully completed federated learning run! 🎉🎉🎉")
|
|
@@ -1,140 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Model upload example script.
|
|
3
|
-
|
|
4
|
-
To run this code you need to install pytorch library, you can find it at this link: https://pytorch.org/get-started/locally/
|
|
5
|
-
"""
|
|
6
|
-
from FedModelKit import FederatedModel, submit_fl_model
|
|
7
|
-
|
|
8
|
-
# Create your custom function defining a PyTorch-based model and returning its instance
|
|
9
|
-
def create_local_learner():
|
|
10
|
-
import pandas as pd
|
|
11
|
-
from torch import nn, optim
|
|
12
|
-
from torch.utils.data import DataLoader, Dataset
|
|
13
|
-
import torch
|
|
14
|
-
import flwr
|
|
15
|
-
|
|
16
|
-
from src.utils import Utils # Dependency script stored inside the src directory
|
|
17
|
-
|
|
18
|
-
class CustomLocalLearner(nn.Module):
|
|
19
|
-
def __init__(self, input_size: int) -> None:
|
|
20
|
-
super(CustomLocalLearner, self).__init__()
|
|
21
|
-
self.linear = nn.Linear(input_size, 3)
|
|
22
|
-
self.softmax = nn.Softmax(dim=1)
|
|
23
|
-
self.loss_fn = nn.CrossEntropyLoss()
|
|
24
|
-
self.optimizer = optim.Adam(self.parameters(), lr=0.001)
|
|
25
|
-
|
|
26
|
-
def _forward(self, x):
|
|
27
|
-
x = self.linear(x)
|
|
28
|
-
return self.softmax(x)
|
|
29
|
-
|
|
30
|
-
def get_parameters(self) -> flwr.common.ParametersRecord:
|
|
31
|
-
return Utils.pytorch_to_parameter_record(self.state_dict())
|
|
32
|
-
|
|
33
|
-
def set_parameters(self, parameters: flwr.common.ParametersRecord) -> None:
|
|
34
|
-
self.load_state_dict(Utils.parameters_to_pytorch_state_dict(parameters))
|
|
35
|
-
|
|
36
|
-
def prepare_data(self, data: pd.DataFrame) -> None:
|
|
37
|
-
class IrisDataset(Dataset):
|
|
38
|
-
def __init__(self, dataframe: pd.DataFrame) -> None:
|
|
39
|
-
self.dataframe = dataframe
|
|
40
|
-
self.dataframe.loc[:, "class"] = dataframe.loc[:, "class"].replace(
|
|
41
|
-
{"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}
|
|
42
|
-
)
|
|
43
|
-
|
|
44
|
-
def __getitem__(self, idx: int):
|
|
45
|
-
x = self.dataframe.iloc[idx, :-1].to_numpy("float32")
|
|
46
|
-
y = torch.tensor(self.dataframe.iloc[idx, -1], dtype=torch.long)
|
|
47
|
-
return x, y
|
|
48
|
-
|
|
49
|
-
def __len__(self) -> int:
|
|
50
|
-
return len(self.dataframe)
|
|
51
|
-
|
|
52
|
-
dataset = IrisDataset(data)
|
|
53
|
-
dataloader = DataLoader(dataset, batch_size=32)
|
|
54
|
-
self.dataloader = dataloader
|
|
55
|
-
|
|
56
|
-
def train_round(self) -> flwr.common.MetricsRecord:
|
|
57
|
-
loss = torch.tensor(0.0)
|
|
58
|
-
for batch in self.dataloader:
|
|
59
|
-
x, y = batch
|
|
60
|
-
self.optimizer.zero_grad()
|
|
61
|
-
y_hat = self._forward(x)
|
|
62
|
-
loss = self.loss_fn(y_hat, y)
|
|
63
|
-
loss.backward()
|
|
64
|
-
self.optimizer.step()
|
|
65
|
-
|
|
66
|
-
return flwr.common.MetricsRecord({"loss": loss.item()})
|
|
67
|
-
|
|
68
|
-
def evaluate(self) -> flwr.common.MetricsRecord:
|
|
69
|
-
correct = 0
|
|
70
|
-
total = 0
|
|
71
|
-
with torch.no_grad():
|
|
72
|
-
for batch in self.dataloader:
|
|
73
|
-
x, y = batch
|
|
74
|
-
y_hat = self._forward(x)
|
|
75
|
-
_, predicted = torch.max(y_hat, 1)
|
|
76
|
-
total += y.size(0)
|
|
77
|
-
correct += (predicted == y).sum().item()
|
|
78
|
-
return flwr.common.MetricsRecord({"accuracy": correct / total})
|
|
79
|
-
|
|
80
|
-
return CustomLocalLearner(4)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
# Create your custom function defining an aggregation strategy and returning its instance
|
|
84
|
-
def create_aggregator():
|
|
85
|
-
from collections import OrderedDict
|
|
86
|
-
import numpy as np
|
|
87
|
-
import flwr
|
|
88
|
-
from typing import Optional
|
|
89
|
-
|
|
90
|
-
from src.utils import Utils # Dependency script stored inside the src directory #type: ignore[import]
|
|
91
|
-
|
|
92
|
-
class CustomAggregator:
|
|
93
|
-
|
|
94
|
-
def aggregate_parameters(self, results: list[flwr.common.ParametersRecord], config: Optional[flwr.common.ConfigsRecord]=None
|
|
95
|
-
) -> flwr.common.ParametersRecord:
|
|
96
|
-
parameters = [Utils.parameters_to_dict(param) for param in results]
|
|
97
|
-
keys = parameters[0].keys()
|
|
98
|
-
result = OrderedDict()
|
|
99
|
-
for key in keys:
|
|
100
|
-
# Init array
|
|
101
|
-
this_array: np.ndarray = np.zeros_like(parameters[0][key])
|
|
102
|
-
for p in parameters:
|
|
103
|
-
this_array += p[key]
|
|
104
|
-
result[key] = this_array / len(results)
|
|
105
|
-
return Utils.dict_to_parameter_record(result)
|
|
106
|
-
|
|
107
|
-
def aggregate_metrics(self, results: list[flwr.common.MetricsRecord], config: Optional[flwr.common.ConfigsRecord]=None) -> flwr.common.MetricsRecord:
|
|
108
|
-
keys = results[0].keys()
|
|
109
|
-
result = OrderedDict()
|
|
110
|
-
for key in keys:
|
|
111
|
-
# Init array
|
|
112
|
-
cumsum = 0.0
|
|
113
|
-
for m in results:
|
|
114
|
-
if not isinstance(m[key], (int, float)):
|
|
115
|
-
raise ValueError(
|
|
116
|
-
f"flwr.common.MetricsRecord value type not supported: {type(m[key])}"
|
|
117
|
-
)
|
|
118
|
-
cumsum += m[key] # type: ignore
|
|
119
|
-
result[key] = cumsum / len(results)
|
|
120
|
-
return flwr.common.MetricsRecord(result)
|
|
121
|
-
|
|
122
|
-
return CustomAggregator()
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
if __name__ == "__main__":
|
|
126
|
-
# Create FederatedModel with custom aggregator and local learner
|
|
127
|
-
federated_model = FederatedModel(create_local_learner=create_local_learner,
|
|
128
|
-
model_name="your_model_name",
|
|
129
|
-
create_aggregator=create_aggregator,
|
|
130
|
-
aggregator_name="custom_aggregator")
|
|
131
|
-
|
|
132
|
-
# Submit the FederatedModel to the Federated Platform Model Catalogue with credentials and experiment name
|
|
133
|
-
submit_fl_model(federated_model,
|
|
134
|
-
platform_url="your_platform_url",
|
|
135
|
-
username="your_username",
|
|
136
|
-
password="your_password",
|
|
137
|
-
experiment_name="your_experiment_name",
|
|
138
|
-
disease="AML",
|
|
139
|
-
trained=False
|
|
140
|
-
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/extern_pyproject_template.toml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/dsAggregateModels.png
RENAMED
|
File without changes
|
{fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/dsDoneSubmittingJobs.png
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|