FedModelKit 0.5.0__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/PKG-INFO +2 -2
  2. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/README.md +1 -1
  3. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/simulation_example.ipynb +7 -7
  4. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/pyproject.toml +1 -1
  5. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/cli.py +1 -1
  6. fedmodelkit-0.6.0/src/FedModelKit/templates/client_app_template.py +49 -0
  7. fedmodelkit-0.6.0/src/FedModelKit/templates/server_app_template.py +38 -0
  8. fedmodelkit-0.6.0/src/FedModelKit/templates/task_template.py +243 -0
  9. fedmodelkit-0.5.0/src/FedModelKit/templates/client_app_template.py +0 -118
  10. fedmodelkit-0.5.0/src/FedModelKit/templates/server_app_template.py +0 -204
  11. fedmodelkit-0.5.0/src/FedModelKit/templates/task_template.py +0 -140
  12. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/.gitignore +0 -0
  13. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/.python-version +0 -0
  14. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/LICENSE +0 -0
  15. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/.gitignore +0 -0
  16. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/AML_preprocessed_dataset.xlsx +0 -0
  17. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/examples/simulation-scikit-model/README.md +0 -0
  18. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/images/aggregator.png +0 -0
  19. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/images/federated_learning_model.png +0 -0
  20. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/images/local_learner.png +0 -0
  21. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/README.md +0 -0
  22. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/__init__.py +0 -0
  23. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/aggregator.py +0 -0
  24. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/default_create_functions.py +0 -0
  25. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/interface.py +0 -0
  26. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/local_learner.py +0 -0
  27. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/py.typed +0 -0
  28. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/src/utils.py +0 -0
  29. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/__init__template.py +0 -0
  30. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/ds_template.ipynb +0 -0
  31. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/extern_pyproject_template.toml +0 -0
  32. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/doSendModels.png +0 -0
  33. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/doWaitsForJobs.png +0 -0
  34. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/dsAggregateModels.png +0 -0
  35. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/dsDoneSubmittingJobs.png +0 -0
  36. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/dsSendsJobs.png +0 -0
  37. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/images/overview.png +0 -0
  38. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/main_template.py +0 -0
  39. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/pyproject_template.toml +0 -0
  40. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/readme_template.md +0 -0
  41. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates/uv_template.lock +0 -0
  42. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/src/FedModelKit/templates.py +0 -0
  43. {fedmodelkit-0.5.0 → fedmodelkit-0.6.0}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: FedModelKit
3
- Version: 0.5.0
3
+ Version: 0.6.0
4
4
  Summary: This package contains the core components and protocols for creating, managing, and registering federated learning models using MLflow. It provides utilities for defining local learners, aggregation strategies, and integrating them with MLflow for tracking and deployment.
5
5
  Author-email: ceresale <alessandro.ceresi@upm.es>
6
6
  License-File: LICENSE
@@ -101,7 +101,7 @@ Follow the steps below to install and set up the uploading environment:
101
101
  2. Go to the directory from where you want to upload the model and run this command to initialize it (pay attention to files that might get overwritten):
102
102
 
103
103
  ```bash
104
- msi init
104
+ fmk init
105
105
  ```
106
106
  A `src` folder will be created where the dependency files must be stored, an `example.py` script will be created in the main directory with an example of how to upload the model, and a `README.md` file with a description of how to use the package functionalities. They are provided with clear documentation on how to define your local learner and aggregation strategy, and how to log the model to the Platform Model Catalogue.
107
107
 
@@ -81,7 +81,7 @@ Follow the steps below to install and set up the uploading environment:
81
81
  2. Go to the directory from where you want to upload the model and run this command to initialize it (pay attention to files that might get overwritten):
82
82
 
83
83
  ```bash
84
- msi init
84
+ fmk init
85
85
  ```
86
86
  A `src` folder will be created where the dependency files must be stored, an `example.py` script will be created in the main directory with an example of how to upload the model, and a `README.md` file with a description of how to use the package functionalities. They are provided with clear documentation on how to define your local learner and aggregation strategy, and how to log the model to the Platform Model Catalogue.
87
87
 
@@ -58,7 +58,7 @@
58
58
  "metadata": {},
59
59
  "outputs": [],
60
60
  "source": [
61
- "!msi init app"
61
+ "!fmk init app"
62
62
  ]
63
63
  },
64
64
  {
@@ -106,7 +106,7 @@
106
106
  "outputs": [],
107
107
  "source": [
108
108
  "%%writefile model_example.py\n",
109
- "import FedModelKit as msi\n",
109
+ "import FedModelKit as fmk\n",
110
110
  "\n",
111
111
  "# Function defining the local learner\n",
112
112
  "def create_local_learner():\n",
@@ -313,11 +313,11 @@
313
313
  "metadata": {},
314
314
  "outputs": [],
315
315
  "source": [
316
- "import FedModelKit as msi\n",
316
+ "import FedModelKit as fmk\n",
317
317
  "from model_example import create_local_learner as create_ll_check # Use alias here to avoid name conflict when uploading the model\n",
318
318
  "from model_example import create_aggregator as create_agg_check # Use alias here to avoid name conflict when uploading the model\n",
319
319
  "\n",
320
- "msi.FederatedModel(create_local_learner=create_ll_check, \n",
320
+ "fmk.FederatedModel(create_local_learner=create_ll_check, \n",
321
321
  " model_name='simple_lr',\n",
322
322
  " create_aggregator=create_agg_check,\n",
323
323
  " aggregator_name='custom_aggregator')"
@@ -574,13 +574,13 @@
574
574
  "outputs": [],
575
575
  "source": [
576
576
  "%run model_example.py # Run the model_example.py file to directly get the functions defining the local learner and aggregator \n",
577
- "import FedModelKit as msi\n",
577
+ "import FedModelKit as fmk\n",
578
578
  "\n",
579
579
  "# Create an instance of the FederatedModel class from the FedModelKit module.\n",
580
580
  "# This class handles the federated learning process, including model creation, training, and aggregation.\n",
581
581
  "# The create_local_learner and create_aggregator functions is passed as an argument to define the local learner (model)\n",
582
582
  "# for each client. Here the functions used are the ones defined in the model_example.py file.\n",
583
- "federated_model = msi.FederatedModel(create_local_learner=create_local_learner, # type: ignore # Make sure you already type-checked the function\n",
583
+ "federated_model = fmk.FederatedModel(create_local_learner=create_local_learner, # type: ignore # Make sure you already type-checked the function\n",
584
584
  " model_name='simple_lr',\n",
585
585
  " create_aggregator=create_aggregator, # type: ignore # Make sure you already type-checked the function\n",
586
586
  " aggregator_name='custom_aggregator')"
@@ -631,7 +631,7 @@
631
631
  "# HERE you should produce the information about the whole dataset needed by you model and store it in\n",
632
632
  "# the src directory, UNLESS you performed the simulation and the clients already stored it.\n",
633
633
  "\n",
634
- "msi.submit_fl_model(model=federated_model,\n",
634
+ "fmk.submit_fl_model(model=federated_model,\n",
635
635
  " platform_url='http://localhost:5000', # URL of the local MLflow server\n",
636
636
  " username='username', # Username for authentication (If no username is required , put a mock username)\n",
637
637
  " password='password', # Password for authentication (if no password is required, put a mock password)\n",
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "FedModelKit"
3
- version = "0.5.0"
3
+ version = "0.6.0"
4
4
  description = "This package contains the core components and protocols for creating, managing, and registering federated learning models using MLflow. It provides utilities for defining local learners, aggregation strategies, and integrating them with MLflow for tracking and deployment."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -62,7 +62,7 @@ def create_structure(exp_name: str = "new_experiment") -> None:
62
62
 
63
63
  '''def main():
64
64
  if len(sys.argv) < 2:
65
- print("Usage: msi <command> [options]")
65
+ print("Usage: fmk <command> [options]")
66
66
  sys.exit(1)
67
67
 
68
68
  command = sys.argv[1]
@@ -0,0 +1,49 @@
1
+ from flwr.client import ClientApp, NumPyClient
2
+ from flwr.common import Context
3
+ from loguru import logger
4
+
5
+ from EXPERIMENT_NAME.task import (
6
+ Net,
7
+ evaluate,
8
+ get_weights,
9
+ load_flwr_data,
10
+ set_weights,
11
+ train,
12
+ )
13
+
14
+
15
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains and evaluates a model on local data.

    Args:
        net: The PyTorch model to train/evaluate (``Net`` in this template).
        trainloader: DataLoader over the local training split.
        testloader: DataLoader over the local evaluation split.
    """

    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def fit(self, parameters, config):
        """Load the global weights, train locally, and return the new weights.

        Returns:
            (weights, num_training_examples, metrics)
        """
        set_weights(self.net, parameters)
        train(self.net, self.trainloader)
        # Flower weights each client by this count during aggregation, so it
        # must be the number of samples, not len(loader) (number of batches).
        return get_weights(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the global weights and report loss/accuracy on the local test split.

        Returns:
            (loss, num_test_examples, {"accuracy": accuracy})
        """
        set_weights(self.net, parameters)
        loss, accuracy = evaluate(self.net, self.testloader)
        # The test loader built by dataset_processing holds the whole split in
        # one batch, so len(self.testloader) would always be 1; count samples.
        return loss, len(self.testloader.dataset), {"accuracy": accuracy}
30
+
31
+
32
def client_fn(context: Context):
    """Build the Flower client for the current node.

    Under syft_flwr the node's SyftBox dataset is loaded; otherwise the
    partition named by the node config (``partition-id`` /
    ``num-partitions``) is loaded through ``load_flwr_data``.
    """
    from EXPERIMENT_NAME.task import load_syftbox_dataset
    from syft_flwr.utils import run_syft_flwr

    if run_syft_flwr():
        logger.info("Running with syft_flwr")
        loaders = load_syftbox_dataset()
    else:
        logger.info("Running flwr locally")
        node_cfg = context.node_config
        loaders = load_flwr_data(
            partition_id=node_cfg["partition-id"],
            num_partitions=node_cfg["num-partitions"],
        )

    train_loader, test_loader = loaders
    return FlowerClient(Net(), train_loader, test_loader).to_client()
47
+
48
+
49
# Flower entry point: wires client_fn into the ClientApp.
app = ClientApp(client_fn=client_fn)
@@ -0,0 +1,38 @@
1
+ """fltabular: Flower Example on Adult Census Income Tabular Dataset."""
2
+
3
+ from flwr.common import Context, ndarrays_to_parameters
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+
6
+ from EXPERIMENT_NAME.task import Net, get_weights
7
+
8
+
9
def weighted_average(metrics):
    """Aggregate per-client accuracies into one example-weighted accuracy.

    Args:
        metrics: Sequence of ``(num_examples, metrics_dict)`` pairs, one per
            client; each ``metrics_dict`` must contain an ``"accuracy"`` key.

    Returns:
        ``{"accuracy": weighted_mean}``, or ``{}`` when the clients reported
        zero examples in total (the weighted mean is undefined then).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Avoid ZeroDivisionError when no client evaluated any example.
        return {}
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}
14
+
15
+
16
def server_fn(context: Context) -> ServerAppComponents:
    """Assemble the server strategy and run configuration.

    The global model starts from a freshly initialised ``Net``. Aggregation
    uses syft_flwr's ``FedAvgWithModelSaving`` (all clients fit/evaluate each
    round, at least 2 must be available, evaluation accuracy is combined with
    ``weighted_average``). Weights are saved under
    ``Path(__file__).parent.parent.parent / "weights"``. The number of rounds
    is read from the ``num-server-rounds`` run-config key.
    """
    initial_parameters = ndarrays_to_parameters(get_weights(Net()))

    from pathlib import Path

    from syft_flwr.strategy import FedAvgWithModelSaving

    weights_dir = Path(__file__).parent.parent.parent / "weights"
    strategy = FedAvgWithModelSaving(
        save_path=weights_dir,
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=initial_parameters,
        evaluate_metrics_aggregation_fn=weighted_average,
    )

    server_config = ServerConfig(num_rounds=context.run_config["num-server-rounds"])
    return ServerAppComponents(config=server_config, strategy=strategy)
36
+
37
+
38
# Flower entry point: wires server_fn into the ServerApp.
app = ServerApp(server_fn=server_fn)
@@ -0,0 +1,243 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from flwr_datasets import FederatedDataset
7
+ from flwr_datasets.partitioner import IidPartitioner
8
+ from imblearn.over_sampling import SMOTE
9
+ from loguru import logger
10
+ from pandas import DataFrame
11
+ from sklearn.model_selection import train_test_split
12
+ from sklearn.preprocessing import StandardScaler
13
+ from torch.utils.data import DataLoader, TensorDataset
14
+
15
+
16
def get_device():
    """Pick the best available torch device.

    Preference order: CUDA, Apple MPS, Intel XPU, then CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # ``torch.backends.mps`` and ``torch.xpu`` are absent on older PyTorch
    # releases; probe them defensively instead of raising AttributeError at
    # import time (this function runs when the module is imported).
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")
25
+
26
+
27
# Device selected once at import time; train() and evaluate() move the model
# and every batch onto it.
DEVICE = get_device()
logger.info(f"Using device: {DEVICE}")
29
+
30
+
31
class Net(nn.Module):
    """MLP binary classifier: input_dim -> 32 -> 24 -> 16 -> 1 (sigmoid).

    Each hidden stage is Linear + BatchNorm1d + LeakyReLU(0.1); the first two
    stages also apply dropout (0.2 and 0.25). The sigmoid output is a
    probability in [0, 1], as expected by the ``BCELoss`` used in
    ``train``/``evaluate``.

    Args:
        input_dim: Number of input features (default 6).
    """

    def __init__(self, input_dim=6):
        super().__init__()
        # Attribute names and the module order inside each Sequential define
        # the state_dict keys exchanged by get_weights/set_weights.
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.BatchNorm1d(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(32, 24),
            nn.BatchNorm1d(24),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.25),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(24, 16),
            nn.BatchNorm1d(16),
            nn.LeakyReLU(0.1),
        )
        self.output_layer = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.output_layer):
            x = stage(x)
        return x
64
+
65
+
66
def dataset_processing(
    train_df: DataFrame, test_df: DataFrame
) -> tuple[DataLoader, DataLoader]:
    """Turn raw diabetes train/test frames into PyTorch DataLoaders.

    Steps:
      1. Drop ``SkinThickness``/``Insulin`` and replace zeros in ``Glucose``
         (mean), ``BMI`` and ``BloodPressure`` (median); each statistic is
         computed per frame over that column's non-zero values.
      2. Use the first 6 columns as features and the remaining column(s) as
         the label.
      3. Oversample the training set with SMOTE when it has at least two
         classes and at least two minority samples; otherwise train on the
         data as-is (SMOTE would raise in those cases).
      4. Standardise features with a scaler fitted on the (resampled)
         training features.

    Returns:
        (train_loader, test_loader): the training loader shuffles batches of
        10; the test loader yields the whole test split as one batch.
    """
    from collections import Counter

    def preprocess_df(df: DataFrame) -> DataFrame:
        """Drop unused columns and impute zero readings in one frame."""
        columns_to_drop = ["SkinThickness", "Insulin"]
        df_new: DataFrame = df.drop(columns_to_drop, axis=1)

        # Zeros are treated as missing: compute the replacement statistics
        # over the non-zero entries only.
        mean_glucose = df_new[df_new["Glucose"] != 0]["Glucose"].mean()
        median_bmi = df_new[df_new["BMI"] != 0]["BMI"].median()
        median_bp = df_new[df_new["BloodPressure"] != 0]["BloodPressure"].median()

        df_new.replace(
            {
                "Glucose": {0: mean_glucose},
                "BMI": {0: median_bmi},
                "BloodPressure": {0: median_bp},
            },
            inplace=True,
        )
        return df_new

    train_processed = preprocess_df(train_df)
    test_processed = preprocess_df(test_df)

    X_train = train_processed.values[:, :6]
    y_train = train_processed.values[:, 6:]
    X_test = test_processed.values[:, :6]
    y_test = test_processed.values[:, 6:]

    # SMOTE needs at least two classes and k_neighbors < minority-class size.
    # With a single class or a single minority sample it raises, so small or
    # skewed partitions are trained on without resampling instead.
    class_counts = Counter(y_train.flatten())
    minority_count = min(class_counts.values()) if class_counts else 0
    if len(class_counts) >= 2 and minority_count >= 2:
        smote = SMOTE(random_state=42, k_neighbors=min(5, minority_count - 1))
        X_train, y_train = smote.fit_resample(X_train, y_train)
    else:
        logger.warning(
            "Skipping SMOTE: need >= 2 classes with >= 2 minority samples, "
            f"got class counts {dict(class_counts)}"
        )

    # Scale to zero mean / unit variance using training statistics only.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Labels are reshaped to (N, 1) to match the model's output shape.
    train_dataset = TensorDataset(
        torch.FloatTensor(X_train), torch.FloatTensor(y_train).reshape(-1, 1)
    )
    test_dataset = TensorDataset(
        torch.FloatTensor(X_test), torch.FloatTensor(y_test).reshape(-1, 1)
    )

    batch_size = 10
    # Net uses BatchNorm1d, which raises on a one-sample batch in training
    # mode; drop a trailing batch that would contain exactly one sample.
    drop_last = (
        len(train_dataset) > batch_size and len(train_dataset) % batch_size == 1
    )
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last,
    )
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=len(test_dataset), shuffle=False
    )

    return train_loader, test_loader
135
+
136
+
137
def load_syftbox_dataset() -> tuple[DataLoader, DataLoader]:
    """Load this node's SyftBox CSVs and build the DataLoaders.

    Reads ``train.csv`` then ``test.csv`` from the directory returned by
    ``syft_flwr.utils.get_syftbox_dataset_path`` and preprocesses them with
    ``dataset_processing``.
    """
    import pandas as pd

    from syft_flwr.utils import get_syftbox_dataset_path

    data_dir = get_syftbox_dataset_path()
    logger.info(f"Loading dataset from {data_dir}")

    frames = [pd.read_csv(data_dir / name) for name in ("train.csv", "test.csv")]
    return dataset_processing(*frames)
149
+
150
+
151
# Module-level cache of the FederatedDataset: created by load_flwr_data() on
# its first call and reused on later calls.
fds = None  # Cache FederatedDataset
152
+
153
+
154
def load_flwr_data(
    partition_id: int, num_partitions: int
) -> tuple[DataLoader, DataLoader]:
    """Load one IID partition of the Pima Indians diabetes dataset.

    The ``khoaguin/pima-indians-diabetes-database`` dataset is loaded through
    ``flwr_datasets`` on the first call and cached in the module-level
    ``fds``. Later calls reuse that cache, so the partitioner keeps the
    ``num_partitions`` given on the first call. The selected partition is
    split 80/20 into train/test (``random_state=95``) and passed through
    ``dataset_processing``.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Total number of IID partitions (used on first call).

    Returns:
        (train_loader, test_loader) built by ``dataset_processing``.
    """
    global fds
    if fds is None:
        fds = FederatedDataset(
            dataset="khoaguin/pima-indians-diabetes-database",
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
        )

    partition_df: DataFrame = fds.load_partition(partition_id, "train").with_format(
        "pandas"
    )[:]
    train_df, test_df = train_test_split(partition_df, test_size=0.2, random_state=95)
    return dataset_processing(train_df, test_df)
174
+
175
+
176
def train(model, train_loader, local_epochs=1):
    """Train ``model`` in place with Adam + BCE for ``local_epochs`` epochs.

    Args:
        model: Module producing probabilities in [0, 1] (e.g. ``Net``).
        train_loader: Iterable of ``(inputs, labels)`` batches; labels must
            match the model output shape (presumably ``(B, 1)`` float targets
            as built by ``dataset_processing`` — confirm for other callers).
        local_epochs: Number of passes over ``train_loader``.

    Returns:
        dict with per-epoch ``"train_loss"`` (sample-averaged BCE) and
        ``"train_acc"`` (accuracy at a 0.5 threshold). An epoch that sees no
        samples records 0.0 for both instead of raising ZeroDivisionError.
    """
    # Move the model first so the optimizer is built over on-device params.
    model.to(DEVICE)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
    history = {"train_loss": [], "train_acc": []}

    for _ in range(local_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # BCELoss averages over the batch; re-weight by batch size so the
            # epoch loss is a per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Divide by the samples actually seen (equals the dataset size unless
        # the loader drops or subsamples), guarding the empty case.
        epoch_loss = running_loss / total if total else 0.0
        epoch_acc = correct / total if total else 0.0
        history["train_loss"].append(epoch_loss)
        history["train_acc"].append(epoch_acc)

    return history
208
+
209
+
210
def evaluate(model, data_loader):
    """Evaluate ``model`` with BCE loss and 0.5-threshold accuracy.

    Args:
        model: Module producing probabilities in [0, 1] (e.g. ``Net``).
        data_loader: Iterable of ``(inputs, labels)`` batches; labels must
            match the model output shape.

    Returns:
        (loss, accuracy): sample-averaged BCE loss and accuracy. Returns
        ``(0.0, 0.0)`` when the loader yields no samples instead of raising
        ZeroDivisionError.
    """
    model.to(DEVICE)
    model.eval()
    criterion = nn.BCELoss()

    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Re-weight the batch-mean loss by batch size for a per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct / total
233
+
234
+
235
def set_weights(model, parameters):
    """Load a list of arrays into ``model``'s state_dict, matched by position.

    ``parameters`` is paired with ``model.state_dict()`` keys in order (the
    order produced by ``get_weights``) and loaded with ``strict=True``.
    """
    keys = model.state_dict().keys()
    new_state = OrderedDict()
    for key, array in zip(keys, parameters):
        new_state[key] = torch.tensor(array)
    model.load_state_dict(new_state, strict=True)
239
+
240
+
241
def get_weights(model):
    """Return ``model``'s state_dict tensors as CPU NumPy arrays, in key order."""
    return [tensor.cpu().numpy() for tensor in model.state_dict().values()]
@@ -1,118 +0,0 @@
1
-
2
- import numpy as np
3
- import pandas as pd
4
- import pickle
5
- from pathlib import Path
6
- from flwr.client import ClientApp
7
- from flwr.common import Message, Context
8
- from flwr.common.record import RecordSet, MetricsRecord, ConfigsRecord
9
- from sklearn.preprocessing import OneHotEncoder
10
- import FedModelKit as msi
11
-
12
- from model_example import create_local_learner #type: ignore[import]
13
- from load_data import load_data #type: ignore[import]
14
-
15
-
16
- # Initialize the Flower ClientApp
17
# Initialize the Flower ClientApp. Message handlers are registered on it with
# the @app.query(), @app.train() and @app.evaluate() decorators.
app = ClientApp()
18
-
19
@app.query()
def query(msg: Message, ctx: Context) -> Message:
    """Handle the server's initial QUERY message.

    Loads this client's data split (using ``num_clients``/``client_id`` from
    the ``fancy_config`` record), creates the local learner, and stores both
    on ``ctx.state`` for the later train/evaluate handlers. Replies with an
    empty RecordSet.
    """
    cfg = msg.content.configs_records['fancy_config']
    client_data = load_data(cfg['num_clients'], cfg['client_id'])

    model = msi.FederatedModel(create_local_learner=create_local_learner, model_name='simple_lr')

    state = ctx.state
    state.local_learner = model.create_local_learner()
    state.data = client_data
    # Debug sentinel; train() asserts that it survived in the context state.
    state.undestand = "che succede"

    return msg.create_reply(RecordSet())
42
-
43
@app.train()
def train(msg: Message, ctx: Context):
    """Handle a TRAIN message from the server.

    Loads the server's ``fancy_model`` parameters into the cached local
    learner, trains one round on the cached data split, and replies with the
    updated parameters (``fancy_model_returned``) and the training metrics
    (``train_metrics``). The metrics and learner are written back to
    ``ctx.state``.
    """
    state = ctx.state
    learner = state.local_learner
    split = state.data

    # The server also sends msg.content.configs_records['fancy_config']
    # (e.g. local_epochs); it is not consumed here.
    learner.set_parameters(msg.content.parameters_records['fancy_model'])
    learner.prepare_data(split)
    metrics = learner.train_round()
    updated_params = learner.get_parameters()
    # Debug check that query() stored its sentinel on the context state.
    assert state.undestand.startswith("che"), "The context state is not being stored correctly"

    reply = RecordSet()
    reply.parameters_records['fancy_model_returned'] = updated_params
    reply.metrics_records['train_metrics'] = metrics

    state.metrics_records['prev'] = metrics
    state.local_learner = learner
    return msg.create_reply(reply)
83
-
84
@app.evaluate()
def eval(msg: Message, ctx: Context):
    """Handle an EVALUATE message from the server.

    Loads the server's ``fancy_model`` parameters into the cached local
    learner, evaluates it on the cached data split, and replies with the
    metrics under ``eval_metrics``. The metrics and learner are written back
    to ``ctx.state``.
    """
    state = ctx.state
    learner = state.local_learner
    split = state.data

    learner.set_parameters(msg.content.parameters_records['fancy_model'])
    learner.prepare_data(split)
    metrics = learner.evaluate()

    reply = RecordSet()
    reply.metrics_records['eval_metrics'] = metrics

    state.metrics_records['prev'] = metrics
    state.local_learner = learner
    return msg.create_reply(reply)
@@ -1,204 +0,0 @@
1
-
2
- from typing import List
3
- import time
4
-
5
- import flwr as fl
6
- from flwr.common import (
7
- Context,
8
- NDArrays,
9
- Message,
10
- MessageType,
11
- Metrics,
12
- RecordSet,
13
- ConfigsRecord,
14
- DEFAULT_TTL,
15
- )
16
- from flwr.server import Driver
17
-
18
- import FedModelKit as msi
19
-
20
- from model_example import create_local_learner #type: ignore[import]
21
-
22
-
23
# Run via `flower-server-app server:app`
# The run logic is registered on this ServerApp with @app.main() below.
app = fl.server.ServerApp()
25
-
26
-
27
-
28
-
29
@app.main()
def main(driver: Driver, context: Context) -> None:
    """
    Main function to run the federated learning server.

    Structure:
    - Send a query message to clients for creating the local learner and loading the data
    - Run 3 global rounds of training and evaluation:
        - Send training messages to clients
        - Aggregate parameters received from clients
        - Send evaluation messages to clients
        - Aggregate evaluation metrics

    NOTE(review): every wait loop below polls until the reply count equals
    the number of pushed message ids, with no timeout; a client that never
    answers makes the server wait forever.
    """
    print("Starting test run")

    # Get node IDs of connected clients
    node_ids = driver.get_node_ids()

    # Initialize the federated model
    federated_model = msi.FederatedModel(create_local_learner=create_local_learner,
                                         model_name='simple_lr')
    global_model = federated_model.create_local_learner()
    aggregation_strategy = federated_model.create_aggregator()

    # Send a query message to clients for creating the local learner and loading the data
    messages = []
    for idx, node_id in enumerate(node_ids):
        # Create messages to send to clients
        recordset = RecordSet()

        # Add a config with information to send the client for the query;
        # client_id is the node's position in node_ids
        recordset.configs_records["fancy_config"] = ConfigsRecord({"num_clients": len(node_ids), "client_id": idx})

        # Create a query message for each client
        message = driver.create_message(
            content=recordset,
            message_type=MessageType.QUERY,
            dst_node_id=node_id,
            group_id=str(1),
            ttl=DEFAULT_TTL,
        )
        messages.append(message)

    # Send the query messages to clients.
    # NOTE(review): len(list(message_ids)) would exhaust a one-shot iterator
    # before the filtering below — confirm push_messages returns a list.
    message_ids = driver.push_messages(messages)
    print(f"Pushed {len(list(message_ids))} messages: {message_ids}")

    # Wait for results from clients (empty ids are dropped first)
    message_ids = [message_id for message_id in message_ids if message_id != ""]
    all_replies: List[Message] = []
    while True:
        replies = driver.pull_messages(message_ids=message_ids)
        print(f"Got {len(list(replies))} results")
        all_replies += replies
        if len(all_replies) == len(message_ids):
            break
        time.sleep(12)

    # Filter out messages with errors (replies without content)
    all_replies = [
        msg
        for msg in all_replies
        if msg.has_content()
    ]
    print(f"Received {len(all_replies)} answers")


    # Run federated training and evaluation for a fixed number of rounds
    # (hard-coded to 3; the run config in `context` is not consulted)
    for server_round in range(3):
        print(f"Commencing server train and evaluation round {server_round + 1}")

        messages = []
        for idx, node_id in enumerate(node_ids):
            # Create messages to send to clients
            recordset = RecordSet()

            # Add model parameters to record
            recordset.parameters_records["fancy_model"] = global_model.get_parameters()
            # Add a config with information to send the client for training
            recordset.configs_records["fancy_config"] = ConfigsRecord({"local_epochs": 3})

            # Create a training message for each client
            message = driver.create_message(
                content=recordset,
                message_type=MessageType.TRAIN,
                dst_node_id=node_id,
                group_id=str(server_round),
                ttl=DEFAULT_TTL,
            )
            messages.append(message)

        # Send training messages to clients
        message_ids = driver.push_messages(messages)
        print(f"Pushed {len(list(message_ids))} messages: {message_ids}")

        # Wait for results from clients
        message_ids = [message_id for message_id in message_ids if message_id != ""]
        all_replies: List[Message] = []
        while True:
            replies = driver.pull_messages(message_ids=message_ids)
            print(f"Got {len(list(replies))} results")
            all_replies += replies
            if len(all_replies) == len(message_ids):
                break
            time.sleep(12)

        # Filter out messages with errors
        all_replies = [
            msg
            for msg in all_replies
            if msg.has_content()
        ]
        print(f"Received {len(all_replies)} results")

        # Print metrics received from clients
        for reply in all_replies:
            print(reply.content.metrics_records)

        # Aggregate parameters received from clients and update the global model
        parameter_records_list = [reply.content.parameters_records["fancy_model_returned"] for reply in all_replies]
        new_parameter_record = aggregation_strategy.aggregate_parameters(parameter_records_list)
        global_model.set_parameters(new_parameter_record)

        # Evaluate the updated global model
        messages = []
        for idx, node_id in enumerate(node_ids):
            # Create evaluation messages for clients
            recordset = RecordSet()

            # Add updated model parameters to record
            recordset.parameters_records["fancy_model"] = new_parameter_record
            # Add a config with information to send the client for evaluation
            recordset.configs_records["fancy_config"] = ConfigsRecord({"local_epochs": 3})

            # Create an evaluation message for each client
            message = driver.create_message(
                content=recordset,
                message_type=MessageType.EVALUATE,
                dst_node_id=node_id,
                group_id=str(server_round),
                ttl=DEFAULT_TTL,
            )
            messages.append(message)

        # Send evaluation messages to clients
        message_ids = driver.push_messages(messages)
        print(f"Pushed {len(list(message_ids))} messages: {message_ids}")

        # Wait for evaluation results from clients (shorter poll interval)
        message_ids = [message_id for message_id in message_ids if message_id != ""]
        all_replies: List[Message] = []
        while True:
            replies = driver.pull_messages(message_ids=message_ids)
            print(f"Got {len(list(replies))} results")
            all_replies += replies
            if len(all_replies) == len(message_ids):
                break
            time.sleep(3)

        # Filter out messages with errors
        all_replies = [
            msg
            for msg in all_replies
            if msg.has_content()
        ]
        print(f"Received {len(all_replies)} results")

        # Print evaluation metrics received from clients
        metrics_records_list = [reply.content.metrics_records['eval_metrics'] for reply in all_replies]
        for i, reply in enumerate(all_replies):
            print(f"Client {i+1} metrics: ", reply.content.metrics_records['eval_metrics'])

        # Aggregate evaluation metrics for this round
        print("Aggregated metrics result: ", aggregation_strategy.aggregate_metrics(metrics_records_list))

    print("🎉🎉🎉 Successfully completed federated learning run! 🎉🎉🎉")
@@ -1,140 +0,0 @@
1
- """
2
- Model upload example script.
3
-
4
- To run this code you need to install pytorch library, you can find it at this link: https://pytorch.org/get-started/locally/
5
- """
6
- from FedModelKit import FederatedModel, submit_fl_model
7
-
8
- # Create your custom function defining a PyTorch-based model and returning its instance
9
- def create_local_learner():
10
- import pandas as pd
11
- from torch import nn, optim
12
- from torch.utils.data import DataLoader, Dataset
13
- import torch
14
- import flwr
15
-
16
- from src.utils import Utils # Dependency script stored inside the src directory
17
-
18
- class CustomLocalLearner(nn.Module):
19
- def __init__(self, input_size: int) -> None:
20
- super(CustomLocalLearner, self).__init__()
21
- self.linear = nn.Linear(input_size, 3)
22
- self.softmax = nn.Softmax(dim=1)
23
- self.loss_fn = nn.CrossEntropyLoss()
24
- self.optimizer = optim.Adam(self.parameters(), lr=0.001)
25
-
26
- def _forward(self, x):
27
- x = self.linear(x)
28
- return self.softmax(x)
29
-
30
- def get_parameters(self) -> flwr.common.ParametersRecord:
31
- return Utils.pytorch_to_parameter_record(self.state_dict())
32
-
33
- def set_parameters(self, parameters: flwr.common.ParametersRecord) -> None:
34
- self.load_state_dict(Utils.parameters_to_pytorch_state_dict(parameters))
35
-
36
- def prepare_data(self, data: pd.DataFrame) -> None:
37
- class IrisDataset(Dataset):
38
- def __init__(self, dataframe: pd.DataFrame) -> None:
39
- self.dataframe = dataframe
40
- self.dataframe.loc[:, "class"] = dataframe.loc[:, "class"].replace(
41
- {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}
42
- )
43
-
44
- def __getitem__(self, idx: int):
45
- x = self.dataframe.iloc[idx, :-1].to_numpy("float32")
46
- y = torch.tensor(self.dataframe.iloc[idx, -1], dtype=torch.long)
47
- return x, y
48
-
49
- def __len__(self) -> int:
50
- return len(self.dataframe)
51
-
52
- dataset = IrisDataset(data)
53
- dataloader = DataLoader(dataset, batch_size=32)
54
- self.dataloader = dataloader
55
-
56
- def train_round(self) -> flwr.common.MetricsRecord:
57
- loss = torch.tensor(0.0)
58
- for batch in self.dataloader:
59
- x, y = batch
60
- self.optimizer.zero_grad()
61
- y_hat = self._forward(x)
62
- loss = self.loss_fn(y_hat, y)
63
- loss.backward()
64
- self.optimizer.step()
65
-
66
- return flwr.common.MetricsRecord({"loss": loss.item()})
67
-
68
- def evaluate(self) -> flwr.common.MetricsRecord:
69
- correct = 0
70
- total = 0
71
- with torch.no_grad():
72
- for batch in self.dataloader:
73
- x, y = batch
74
- y_hat = self._forward(x)
75
- _, predicted = torch.max(y_hat, 1)
76
- total += y.size(0)
77
- correct += (predicted == y).sum().item()
78
- return flwr.common.MetricsRecord({"accuracy": correct / total})
79
-
80
- return CustomLocalLearner(4)
81
-
82
-
83
- # Create your custom function defining an aggregation strategy and returning its instance
84
- def create_aggregator():
85
- from collections import OrderedDict
86
- import numpy as np
87
- import flwr
88
- from typing import Optional
89
-
90
- from src.utils import Utils # Dependency script stored inside the src directory #type: ignore[import]
91
-
92
- class CustomAggregator:
93
-
94
- def aggregate_parameters(self, results: list[flwr.common.ParametersRecord], config: Optional[flwr.common.ConfigsRecord]=None
95
- ) -> flwr.common.ParametersRecord:
96
- parameters = [Utils.parameters_to_dict(param) for param in results]
97
- keys = parameters[0].keys()
98
- result = OrderedDict()
99
- for key in keys:
100
- # Init array
101
- this_array: np.ndarray = np.zeros_like(parameters[0][key])
102
- for p in parameters:
103
- this_array += p[key]
104
- result[key] = this_array / len(results)
105
- return Utils.dict_to_parameter_record(result)
106
-
107
- def aggregate_metrics(self, results: list[flwr.common.MetricsRecord], config: Optional[flwr.common.ConfigsRecord]=None) -> flwr.common.MetricsRecord:
108
- keys = results[0].keys()
109
- result = OrderedDict()
110
- for key in keys:
111
- # Init array
112
- cumsum = 0.0
113
- for m in results:
114
- if not isinstance(m[key], (int, float)):
115
- raise ValueError(
116
- f"flwr.common.MetricsRecord value type not supported: {type(m[key])}"
117
- )
118
- cumsum += m[key] # type: ignore
119
- result[key] = cumsum / len(results)
120
- return flwr.common.MetricsRecord(result)
121
-
122
- return CustomAggregator()
123
-
124
-
125
- if __name__ == "__main__":
126
- # Create FederatedModel with custom aggregator and local learner
127
- federated_model = FederatedModel(create_local_learner=create_local_learner,
128
- model_name="your_model_name",
129
- create_aggregator=create_aggregator,
130
- aggregator_name="custom_aggregator")
131
-
132
- # Submit the FederatedModel to the Federated Platform Model Catalogue with credentials and experiment name
133
- submit_fl_model(federated_model,
134
- platform_url="your_platform_url",
135
- username="your_username",
136
- password="your_password",
137
- experiment_name="your_experiment_name",
138
- disease="AML",
139
- trained=False
140
- )
File without changes
File without changes
File without changes
File without changes