MEDfl 0.2.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. MEDfl/LearningManager/__init__.py +13 -13
  2. MEDfl/LearningManager/client.py +150 -181
  3. MEDfl/LearningManager/dynamicModal.py +287 -287
  4. MEDfl/LearningManager/federated_dataset.py +60 -60
  5. MEDfl/LearningManager/flpipeline.py +192 -192
  6. MEDfl/LearningManager/model.py +223 -223
  7. MEDfl/LearningManager/params.yaml +14 -14
  8. MEDfl/LearningManager/params_optimiser.py +442 -442
  9. MEDfl/LearningManager/plot.py +229 -229
  10. MEDfl/LearningManager/server.py +181 -189
  11. MEDfl/LearningManager/strategy.py +82 -138
  12. MEDfl/LearningManager/utils.py +331 -331
  13. MEDfl/NetManager/__init__.py +10 -10
  14. MEDfl/NetManager/database_connector.py +43 -43
  15. MEDfl/NetManager/dataset.py +92 -92
  16. MEDfl/NetManager/flsetup.py +320 -320
  17. MEDfl/NetManager/net_helper.py +254 -254
  18. MEDfl/NetManager/net_manager_queries.py +142 -142
  19. MEDfl/NetManager/network.py +194 -194
  20. MEDfl/NetManager/node.py +184 -184
  21. MEDfl/__init__.py +2 -2
  22. MEDfl/scripts/__init__.py +1 -1
  23. MEDfl/scripts/base.py +29 -29
  24. MEDfl/scripts/create_db.py +126 -126
  25. Medfl/LearningManager/__init__.py +13 -0
  26. Medfl/LearningManager/client.py +150 -0
  27. Medfl/LearningManager/dynamicModal.py +287 -0
  28. Medfl/LearningManager/federated_dataset.py +60 -0
  29. Medfl/LearningManager/flpipeline.py +192 -0
  30. Medfl/LearningManager/model.py +223 -0
  31. Medfl/LearningManager/params.yaml +14 -0
  32. Medfl/LearningManager/params_optimiser.py +442 -0
  33. Medfl/LearningManager/plot.py +229 -0
  34. Medfl/LearningManager/server.py +181 -0
  35. Medfl/LearningManager/strategy.py +82 -0
  36. Medfl/LearningManager/utils.py +331 -0
  37. Medfl/NetManager/__init__.py +10 -0
  38. Medfl/NetManager/database_connector.py +43 -0
  39. Medfl/NetManager/dataset.py +92 -0
  40. Medfl/NetManager/flsetup.py +320 -0
  41. Medfl/NetManager/net_helper.py +254 -0
  42. Medfl/NetManager/net_manager_queries.py +142 -0
  43. Medfl/NetManager/network.py +194 -0
  44. Medfl/NetManager/node.py +184 -0
  45. Medfl/__init__.py +3 -0
  46. Medfl/scripts/__init__.py +2 -0
  47. Medfl/scripts/base.py +30 -0
  48. Medfl/scripts/create_db.py +126 -0
  49. alembic/env.py +61 -61
  50. {MEDfl-0.2.1.dist-info → medfl-2.0.0.dist-info}/METADATA +120 -108
  51. medfl-2.0.0.dist-info/RECORD +55 -0
  52. {MEDfl-0.2.1.dist-info → medfl-2.0.0.dist-info}/WHEEL +1 -1
  53. {MEDfl-0.2.1.dist-info → medfl-2.0.0.dist-info/licenses}/LICENSE +674 -674
  54. MEDfl-0.2.1.dist-info/RECORD +0 -31
  55. {MEDfl-0.2.1.dist-info → medfl-2.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,229 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import seaborn as sns
4
+
5
+ from .utils import *
6
+
7
+ # Replace this with your actual code for data collection
8
# Placeholder experiment results -- replace with the metrics collected from a
# real run. Each key is a (configuration label, metric name) pair and each
# value is the recorded series. The trailing ``...`` items are literal
# Ellipsis objects standing in for the remaining values.
results_dict = {}

# Adam, learning rate 0.001
results_dict[("LR: 0.001, Optimizer: Adam", "accuracy")] = [
    0.85, 0.89, 0.92, 0.94, ...,
]
results_dict[("LR: 0.001, Optimizer: Adam", "loss")] = [
    0.2, 0.15, 0.1, 0.08, ...,
]

# SGD, learning rate 0.01
results_dict[("LR: 0.01, Optimizer: SGD", "accuracy")] = [
    0.88, 0.91, 0.93, 0.95, ...,
]
results_dict[("LR: 0.01, Optimizer: SGD", "loss")] = [
    0.18, 0.13, 0.09, 0.07, ...,
]

# Adam, learning rate 0.1
results_dict[("LR: 0.1, Optimizer: Adam", "accuracy")] = [
    0.82, 0.87, 0.91, 0.93, ...,
]
results_dict[("LR: 0.1, Optimizer: Adam", "loss")] = [
    0.25, 0.2, 0.15, 0.12, ...,
]
16
+ """
17
+ server should have:
18
+ #len = num of rounds
19
+ self.accuracies
20
+ self.losses
21
+
22
+ Client should have
23
+ # len = num of epochs
24
+ self.accuracies
25
+ self.losses
26
+ self.epsilons
27
+ self.deltas
28
+
29
+ #common things : LR,SGD, Aggregation
30
+
31
+ """
32
+
33
+
34
class AccuracyLossPlotter:
    """
    A utility class for plotting accuracy and loss metrics based on experiment results.

    Args:
        results_dict (dict): Experiment results keyed by
            ``(parameter label, metric name)`` tuples; each value is the
            sequence of values recorded for that pair.

    Attributes:
        results_dict (dict): The results mapping passed to the constructor.
        parameters (list): Unique parameter labels, in first-seen order.
        metrics (list): Unique metric names, in first-seen order.
        iterations (range): 1-based iteration numbers (rounds or epochs)
            spanning the first series in ``results_dict``; empty when there
            are no results.
    """

    # (axis label, key read from the pipeline results) for every metric shown
    # by ``plot_classification_report``. The keys are reproduced exactly as
    # the results are indexed (including the ``sensivity`` spelling).
    _REPORT_METRICS = (
        ('Accuracy', 'accuracy'),
        ('Sensitivity/Recall', 'sensivity'),
        ('PPV/Precision', 'ppv'),
        ('NPV', 'npv'),
        ('F1-score', 'f1score'),
        ('False positive rate', 'fpr'),
        ('True positive rate', 'tpr'),
    )

    # Bar colours, cycled when there are more nodes than colours.
    _NODE_COLORS = (
        '#FF5733', '#6A5ACD', '#3CB371', '#FFD700', '#FFA500',
        '#8A2BE2', '#00FFFF', '#FF00FF', '#A52A2A', '#00FF00',
    )

    def __init__(self, results_dict):
        """
        Initialize the AccuracyLossPlotter with experiment results.

        Args:
            results_dict (dict): Experiment results keyed by
                ``(parameter label, metric name)`` tuples.
        """
        self.results_dict = results_dict
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # the plotting order (and therefore the legend) is stable across runs.
        self.parameters = list(dict.fromkeys(key[0] for key in results_dict))
        self.metrics = list(dict.fromkeys(key[1] for key in results_dict))
        # An empty mapping yields an empty range instead of an IndexError.
        first_series = next(iter(results_dict.values()), ())
        self.iterations = range(1, len(first_series) + 1)

    def plot_accuracy_loss(self):
        """
        Plot every recorded (parameter, metric) series on a single figure.

        Pairs that have no entry in ``results_dict`` are skipped. Each series
        is drawn against its own 1-based iteration numbers, so series of
        different lengths can share the figure.
        """
        plt.figure(figsize=(8, 6))

        for param in self.parameters:
            for metric in self.metrics:
                values = self.results_dict.get((param, metric))
                if values is None:
                    # Not every parameter needs to report every metric.
                    continue
                plt.plot(
                    range(1, len(values) + 1),
                    values,
                    label=f"{param} ({metric})",
                    marker="o",
                    linestyle="-",
                )

        plt.xlabel("Rounds")
        plt.ylabel("Accuracy / Loss")
        plt.title("Accuracy and Loss by Parameters")
        plt.legend()
        plt.grid(True)
        plt.show()

    @staticmethod
    def _show_confusion_matrix(confusion_matrix, title):
        """
        Draw a 2x2 confusion-matrix heatmap and display it.

        Args:
            confusion_matrix: Mapping providing the ``'TP'``, ``'FP'``,
                ``'FN'`` and ``'TN'`` counts.
            title (str): Title of the figure.
        """
        # Rows are the actual class, columns the predicted class.
        matrix = [
            [confusion_matrix['TN'], confusion_matrix['FP']],
            [confusion_matrix['FN'], confusion_matrix['TP']],
        ]

        plt.figure(figsize=(6, 4))
        sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Predicted Negative', 'Predicted Positive'],
                    yticklabels=['Actual Negative', 'Actual Positive'])
        plt.title(title)
        plt.xlabel('Predicted label')
        plt.ylabel('True label')
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_global_confusion_matrix(pipeline_name: str):
        """
        Plot a global confusion matrix based on pipeline results.

        Args:
            pipeline_name (str): Name of the pipeline.

        Returns:
            None
        """
        # Resolve the pipeline id from its name, then fetch its confusion matrix.
        pipeline_id = get_pipeline_from_name(pipeline_name)
        confusion_matrix = get_pipeline_confusion_matrix(pipeline_id)

        AccuracyLossPlotter._show_confusion_matrix(
            confusion_matrix, 'Global Confusion Matrix')

    @staticmethod
    def plot_confusion_Matrix_by_node(node_name: str, pipeline_name: str):
        """
        Plot a confusion matrix for a specific node in the pipeline.

        Args:
            node_name (str): Name of the node.
            pipeline_name (str): Name of the pipeline.

        Returns:
            None
        """
        # Resolve the pipeline id from its name, then fetch the node's matrix.
        pipeline_id = get_pipeline_from_name(pipeline_name)
        confusion_matrix = get_node_confusion_matrix(
            pipeline_id, node_name=node_name)

        AccuracyLossPlotter._show_confusion_matrix(
            confusion_matrix, f'Confusion Matrix of node: {node_name}')

    @staticmethod
    def plot_classification_report(pipeline_name: str):
        """
        Plot a comparison of classification report metrics between nodes.

        One group of bars is drawn per metric, with one bar per node inside
        each group. When the pipeline has no node results an empty, labelled
        chart is shown.

        Args:
            pipeline_name (str): Name of the pipeline.

        Returns:
            None
        """
        report_metrics = AccuracyLossPlotter._REPORT_METRICS
        colors = AccuracyLossPlotter._NODE_COLORS

        # Resolve the pipeline id from its name, then load its results.
        pipeline_id = get_pipeline_from_name(pipeline_name)
        pipeline_results = get_pipeline_result(pipeline_id)

        node_names = list(pipeline_results['nodename'])
        node_count = len(node_names)

        metric_labels = [label for label, _ in report_metrics]

        # One tick per metric; each group of node bars spans `group_width`.
        x = np.arange(len(metric_labels))
        group_width = 0.35

        plt.figure(figsize=(12, 6))

        for index, node in enumerate(node_names):
            bar_width = group_width / node_count
            heights = [pipeline_results[key][index] for _, key in report_metrics]
            # Offsetting by (node_count - 1) / 2 centres the group on its tick.
            offset = (index - (node_count - 1) / 2) * bar_width
            plt.bar(x + offset, heights, bar_width,
                    label=node, color=colors[index % len(colors)])

        # Adding labels, title, and legend
        plt.xlabel('Metrics')
        plt.ylabel('Values')
        plt.title('Comparison of Classification Report Metrics between Nodes')
        plt.xticks(ticks=x, labels=metric_labels, rotation=45)
        if node_count:
            plt.legend()

        plt.tight_layout()
        plt.show()
@@ -0,0 +1,181 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import copy
4
+ from typing import Dict, Optional, Tuple
5
+
6
+ import flwr as fl
7
+ import torch
8
+
9
+ from .client import FlowerClient
10
+ from .federated_dataset import FederatedDataset
11
+ from .model import Model
12
+ from .strategy import Strategy
13
+
14
+
15
class FlowerServer:
    """
    A class representing the central server for Federated Learning using Flower.

    Attributes:
        global_model (Model): The global model that will be federated among clients.
        strategy (Strategy): The strategy used for federated learning, specifying communication and aggregation methods.
        num_rounds (int): The number of federated learning rounds to perform.
        num_clients (int): The number of clients participating in the federated learning process.
        fed_dataset (FederatedDataset): The federated dataset used for training and evaluation.
        diff_priv (bool): Whether differential privacy is used during the federated learning process.
        device (torch.device): ``cuda`` when available, otherwise ``cpu``; used for server-side evaluation.
        params: The parameters returned by ``global_model.get_parameters()`` at construction time.
        client_resources (Optional[Dict[str, float]]): Per-client resources forwarded to the simulation.
        accuracies (List[float]): The accuracy of the global model after each server-side evaluation.
        losses (List[float]): The loss of the global model after each server-side evaluation.
        auc (List[float]): The AUC of the global model after each server-side evaluation.
        flower_clients (List[FlowerClient]): The FlowerClient objects created by ``client_fn``.

    """

    # Object store size handed to Ray (75 MiB).
    _OBJECT_STORE_MEMORY_BYTES = 75 * 1024 * 1024

    def __init__(
        self,
        global_model: Model,
        strategy: Strategy,
        num_rounds: int,
        num_clients: int,
        fed_dataset: FederatedDataset,
        diff_privacy: bool = False,
        client_resources: Optional[Dict[str, float]] = {'num_cpus': 1, 'num_gpus': 0.0}
    ) -> None:
        """
        Initialize a FlowerServer object with the specified parameters.

        ``strategy.strategy_object`` must already exist: its
        ``min_available_clients``, ``initial_parameters`` and ``evaluate_fn``
        attributes are overwritten here.

        Args:
            global_model (Model): The global model that will be federated among clients.
            strategy (Strategy): The strategy used for federated learning, specifying communication and aggregation methods.
            num_rounds (int): The number of federated learning rounds to perform.
            num_clients (int): The number of clients participating in the federated learning process.
            fed_dataset (FederatedDataset): The federated dataset used for training and evaluation.
            diff_privacy (bool, optional): Whether differential privacy is used during the federated learning process.
                Default is False.
            client_resources (Optional[Dict[str, float]]): Resources given to each simulated client.
                Default is one CPU and no GPU.

        Raises:
            TypeError: If an argument checked by ``validate`` has the wrong type.
        """
        self.global_model = global_model
        self.strategy = strategy
        self.num_rounds = num_rounds
        self.num_clients = num_clients
        self.fed_dataset = fed_dataset
        self.diff_priv = diff_privacy
        # Check the arguments before touching the model or the strategy, so a
        # bad argument is reported as a TypeError with nothing modified yet.
        self.validate()

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.params = global_model.get_parameters()
        self.global_model.model = global_model.model.to(self.device)
        # Copy so that instances never share (or mutate) the default dict.
        self.client_resources = (
            None if client_resources is None else dict(client_resources)
        )

        # Configure the underlying Flower strategy for this server.
        strategy_object = self.strategy.strategy_object
        strategy_object.min_available_clients = self.num_clients
        strategy_object.initial_parameters = fl.common.ndarrays_to_parameters(
            self.params
        )
        strategy_object.evaluate_fn = self.evaluate

        self.accuracies = []
        self.losses = []
        self.auc = []
        self.flower_clients = []

    def validate(self) -> None:
        """
        Validate the types of global_model, num_clients, num_rounds and diff_priv.

        The strategy is intentionally not type-checked.

        Raises:
            TypeError: If one of the checked attributes has the wrong type.
        """
        if not isinstance(self.global_model, Model):
            raise TypeError("global_model argument must be a Model instance")

        if not isinstance(self.num_clients, int):
            raise TypeError("num_clients argument must be an int")

        if not isinstance(self.num_rounds, int):
            raise TypeError("num_rounds argument must be an int")

        if not isinstance(self.diff_priv, bool):
            raise TypeError("diff_priv argument must be a bool")

    def client_fn(self, cid) -> FlowerClient:
        """
        Return a FlowerClient object for a specific client ID.

        The client receives a deep copy of the global model and the train /
        validation loaders stored at index ``int(cid)`` of the federated
        dataset. The created client is also appended to ``flower_clients``.

        Args:
            cid: The client ID.

        Returns:
            FlowerClient: A FlowerClient object representing the individual client.
        """
        client_index = int(cid)
        client_model = copy.deepcopy(self.global_model)

        trainloader = self.fed_dataset.trainloaders[client_index]
        valloader = self.fed_dataset.valloaders[client_index]

        client = FlowerClient(
            cid, client_model, trainloader, valloader, self.diff_priv
        )
        # Keeping a reference to the client helps in making plots.
        self.flower_clients.append(client)
        return client

    def evaluate(
        self,
        server_round: int,
        parameters: fl.common.NDArrays,
        config: Dict[str, fl.common.Scalar],
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        """
        Evaluate the global model on the validation dataset and update the accuracies, losses and AUC.

        Args:
            server_round (int): The current round of the federated learning process.
            parameters (fl.common.NDArrays): The global model parameters.
            config (Dict[str, fl.common.Scalar]): Configuration dictionary.

        Returns:
            Optional[Tuple[float, Dict[str, fl.common.Scalar]]]: The evaluation loss and accuracy.
        """
        # The first validation loader serves as the server-side evaluation set.
        testloader = self.fed_dataset.valloaders[0]

        # Update model with the latest parameters
        self.global_model.set_parameters(parameters)
        loss, accuracy, auc = self.global_model.evaluate(testloader, self.device)
        self.auc.append(auc)
        self.losses.append(loss)
        self.accuracies.append(accuracy)

        return loss, {"accuracy": accuracy}

    def run(self) -> "fl.server.History":
        """
        Run the federated learning process using Flower simulation.

        Returns:
            History: The history of the accuracies and losses during the training of each node
        """
        # Increase the object store memory to the minimum allowed value or higher
        ray_init_args = {
            "include_dashboard": False,
            "object_store_memory": self._OBJECT_STORE_MEMORY_BYTES,
        }
        # NOTE(review): `eng` is cleared before the simulation starts --
        # presumably a database engine that should not travel to the
        # simulated clients; confirm against FederatedDataset.
        self.fed_dataset.eng = None

        history = fl.simulation.start_simulation(
            client_fn=self.client_fn,
            num_clients=self.num_clients,
            config=fl.server.ServerConfig(self.num_rounds),
            strategy=self.strategy.strategy_object,
            ray_init_args=ray_init_args,
            client_resources=self.client_resources,
        )

        return history
181
+
@@ -0,0 +1,82 @@
1
+
2
+ from collections import OrderedDict
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import flwr as fl
6
+ import numpy as np
7
+
8
+ import optuna
9
+
10
+
11
+
12
+
13
class Strategy:
    """
    A class representing a strategy for Federated Learning.

    Attributes:
        name (str): The name of the strategy. Default is "FedAvg".
        fraction_fit (float): Fraction of clients to use for training during each round. Default is 1.0.
        fraction_evaluate (float): Fraction of clients to use for evaluation during each round. Default is 1.0.
        min_fit_clients (int): Minimum number of clients to use for training during each round. Default is 2.
        min_evaluate_clients (int): Minimum number of clients to use for evaluation during each round. Default is 2.
        min_available_clients (int): Minimum number of available clients required to start a round. Default is 2.
        initial_parameters (list): The initial parameters of the server model. Default is an empty list.
        evaluation_methode (str): Either "centralized" or "distributed". Default is "centralized".
        evaluate_fn: Evaluation callback handed to the Flower strategy; None until assigned.
        strategy_object: The Flower strategy instance; only available after ``create_strategy``.

    """

    def __init__(
        self,
        name: str = "FedAvg",
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters=None,
        evaluation_methode="centralized"
    ) -> None:
        """
        Initialize a Strategy object with the specified parameters.

        Args:
            name (str): The name of the strategy. Default is "FedAvg".
            fraction_fit (float): Fraction of clients to use for training during each round. Default is 1.0.
            fraction_evaluate (float): Fraction of clients to use for evaluation during each round. Default is 1.0.
            min_fit_clients (int): Minimum number of clients to use for training during each round. Default is 2.
            min_evaluate_clients (int): Minimum number of clients to use for evaluation during each round. Default is 2.
            min_available_clients (int): Minimum number of available clients required to start a round. Default is 2.
            initial_parameters (Optional[list]): The initial parameters of the server model.
                None (the default) means an empty list.
            evaluation_methode ( "centralized" | "distributed")
        """
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        # A fresh list per instance: a literal [] default would be shared by
        # every Strategy created without explicit parameters.
        self.initial_parameters = (
            [] if initial_parameters is None else initial_parameters
        )
        self.evaluation_methode = evaluation_methode
        self.evaluate_fn = None
        self.name = name

    def optuna_fed_optimization(self, direction: str, hpo_rate: int, params_config):
        """
        Attach an Optuna study and its settings to this strategy.

        Args:
            direction (str): Optimization direction passed to
                ``optuna.create_study`` ("minimize" or "maximize").
            hpo_rate (int): Stored as ``self.hpo_rate``.
            params_config: Stored as ``self.params_config``.
        """
        self.study = optuna.create_study(direction=direction)
        self.hpo_rate = hpo_rate
        self.params_config = params_config

    def create_strategy(self):
        """
        Instantiate the Flower strategy named ``self.name``.

        The instance is built from the settings stored on this object and is
        kept in ``self.strategy_object``.
        """
        strategy_class = self.get_strategy_by_name()
        self.strategy_object = strategy_class(
            fraction_fit=self.fraction_fit,
            fraction_evaluate=self.fraction_evaluate,
            min_fit_clients=self.min_fit_clients,
            min_evaluate_clients=self.min_evaluate_clients,
            min_available_clients=self.min_available_clients,
            initial_parameters=fl.common.ndarrays_to_parameters(self.initial_parameters),
            evaluate_fn=self.evaluate_fn,
        )

    def get_strategy_by_name(self):
        """
        Return the attribute of ``fl.server.strategy`` named ``self.name``.

        Raises:
            AttributeError: If ``fl.server.strategy`` has no such attribute.
        """
        # A plain attribute lookup: unlike eval, it never executes the name
        # as code.
        return getattr(fl.server.strategy, self.name)
80
+
81
+
82
+