MEDfl 0.2.1__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/LearningManager/__init__.py +13 -13
- MEDfl/LearningManager/client.py +150 -181
- MEDfl/LearningManager/dynamicModal.py +287 -287
- MEDfl/LearningManager/federated_dataset.py +60 -60
- MEDfl/LearningManager/flpipeline.py +192 -192
- MEDfl/LearningManager/model.py +223 -223
- MEDfl/LearningManager/params.yaml +14 -14
- MEDfl/LearningManager/params_optimiser.py +442 -442
- MEDfl/LearningManager/plot.py +229 -229
- MEDfl/LearningManager/server.py +181 -189
- MEDfl/LearningManager/strategy.py +82 -138
- MEDfl/LearningManager/utils.py +331 -331
- MEDfl/NetManager/__init__.py +10 -10
- MEDfl/NetManager/database_connector.py +43 -43
- MEDfl/NetManager/dataset.py +92 -92
- MEDfl/NetManager/flsetup.py +320 -320
- MEDfl/NetManager/net_helper.py +254 -254
- MEDfl/NetManager/net_manager_queries.py +142 -142
- MEDfl/NetManager/network.py +194 -194
- MEDfl/NetManager/node.py +184 -184
- MEDfl/__init__.py +4 -3
- MEDfl/scripts/__init__.py +1 -1
- MEDfl/scripts/base.py +29 -29
- MEDfl/scripts/create_db.py +126 -126
- Medfl/LearningManager/__init__.py +13 -0
- Medfl/LearningManager/client.py +150 -0
- Medfl/LearningManager/dynamicModal.py +287 -0
- Medfl/LearningManager/federated_dataset.py +60 -0
- Medfl/LearningManager/flpipeline.py +192 -0
- Medfl/LearningManager/model.py +223 -0
- Medfl/LearningManager/params.yaml +14 -0
- Medfl/LearningManager/params_optimiser.py +442 -0
- Medfl/LearningManager/plot.py +229 -0
- Medfl/LearningManager/server.py +181 -0
- Medfl/LearningManager/strategy.py +82 -0
- Medfl/LearningManager/utils.py +331 -0
- Medfl/NetManager/__init__.py +10 -0
- Medfl/NetManager/database_connector.py +43 -0
- Medfl/NetManager/dataset.py +92 -0
- Medfl/NetManager/flsetup.py +320 -0
- Medfl/NetManager/net_helper.py +254 -0
- Medfl/NetManager/net_manager_queries.py +142 -0
- Medfl/NetManager/network.py +194 -0
- Medfl/NetManager/node.py +184 -0
- Medfl/__init__.py +3 -0
- Medfl/scripts/__init__.py +2 -0
- Medfl/scripts/base.py +30 -0
- Medfl/scripts/create_db.py +126 -0
- alembic/env.py +61 -61
- {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/METADATA +120 -108
- medfl-2.0.1.dist-info/RECORD +55 -0
- {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/WHEEL +1 -1
- {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info/licenses}/LICENSE +674 -674
- MEDfl-0.2.1.dist-info/RECORD +0 -31
- {MEDfl-0.2.1.dist-info → medfl-2.0.1.dist-info}/top_level.txt +0 -0
MEDfl/LearningManager/server.py
CHANGED
@@ -1,189 +1,181 @@
#!/usr/bin/env python3

import copy
from typing import Dict, Optional, Tuple

import flwr as fl
import torch

from .client import FlowerClient
from .federated_dataset import FederatedDataset
from .model import Model
from .strategy import Strategy


class FlowerServer:
    """
    A class representing the central server for Federated Learning using Flower.

    Attributes:
        global_model (Model): The global model that will be federated among clients.
        strategy (Strategy): The strategy used for federated learning, specifying communication and aggregation methods.
        num_rounds (int): The number of federated learning rounds to perform.
        num_clients (int): The number of clients participating in the federated learning process.
        fed_dataset (FederatedDataset): The federated dataset used for training and evaluation.
        diff_priv (bool): Whether differential privacy is used during the federated learning process.
        accuracies (List[float]): A list to store the accuracy of the global model during each round.
        losses (List[float]): A list to store the loss of the global model during each round.
        flower_clients (List[FlowerClient]): A list to store the FlowerClient objects representing individual clients.

    """

    def __init__(
        self,
        global_model: Model,
        strategy: Strategy,
        num_rounds: int,
        num_clients: int,
        fed_dataset: FederatedDataset,
        diff_privacy: bool = False,
        client_resources: Optional[Dict[str, float]] = {'num_cpus': 1, 'num_gpus': 0.0}
    ) -> None:
        """
        Initialize a FlowerServer object with the specified parameters.

        Args:
            global_model (Model): The global model that will be federated among clients.
            strategy (Strategy): The strategy used for federated learning, specifying communication and aggregation methods.
            num_rounds (int): The number of federated learning rounds to perform.
            num_clients (int): The number of clients participating in the federated learning process.
            fed_dataset (FederatedDataset): The federated dataset used for training and evaluation.
            diff_privacy (bool, optional): Whether differential privacy is used during the federated learning process.
                Default is False.
        """
        self.device = torch.device(
            f"cuda" if torch.cuda.is_available() else "cpu"
        )
        self.global_model = global_model
        self.params = global_model.get_parameters()
        self.global_model.model = global_model.model.to(self.device)
        self.num_rounds = num_rounds
        self.num_clients = num_clients
        self.fed_dataset = fed_dataset
        self.strategy = strategy
        self.client_resources = client_resources
        setattr(
            self.strategy.strategy_object,
            "min_available_clients",
            self.num_clients,
        )
        setattr(
            self.strategy.strategy_object,
            "initial_parameters",
            fl.common.ndarrays_to_parameters(self.params),
        )
        setattr(self.strategy.strategy_object, "evaluate_fn", self.evaluate)
        self.fed_dataset = fed_dataset
        self.diff_priv = diff_privacy
        self.accuracies = []
        self.losses = []
        self.auc = []
        self.flower_clients = []
        self.validate()

    def validate(self) -> None:
        """Validate global_model, strategy, num_clients, num_rounds, fed_dataset, diff_privacy"""
        if not isinstance(self.global_model, Model):
            raise TypeError("global_model argument must be a Model instance")

        # if not isinstance(self.strategy, Strategy):
        #     print(self.strategy)
        #     print(isinstance(self.strategy, Strategy))
        #     raise TypeError("strategy argument must be a Strategy instance")

        if not isinstance(self.num_clients, int):
            raise TypeError("num_clients argument must be an int")

        if not isinstance(self.num_rounds, int):
            raise TypeError("num_rounds argument must be an int")

        if not isinstance(self.diff_priv, bool):
            raise TypeError("diff_priv argument must be a bool")

    def client_fn(self, cid) -> FlowerClient:
        """
        Return a FlowerClient object for a specific client ID.

        Args:
            cid: The client ID.

        Returns:
            FlowerClient: A FlowerClient object representing the individual client.
        """

        device = torch.device(
            f"cuda:{int(cid) % 4}" if torch.cuda.is_available() else "cpu"
        )
        client_model = copy.deepcopy(self.global_model)

        trainloader = self.fed_dataset.trainloaders[int(cid)]
        valloader = self.fed_dataset.valloaders[int(cid)]
        # this helps in making plots

        client = FlowerClient(
            cid, client_model, trainloader, valloader, self.diff_priv
        )
        self.flower_clients.append(client)
        return client

    def evaluate(
        self,
        server_round: int,
        parameters: fl.common.NDArrays,
        config: Dict[str, fl.common.Scalar],
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        """
        Evaluate the global model on the validation dataset and update the accuracies and losses.

        Args:
            server_round (int): The current round of the federated learning process.
            parameters (fl.common.NDArrays): The global model parameters.
            config (Dict[str, fl.common.Scalar]): Configuration dictionary.

        Returns:
            Optional[Tuple[float, Dict[str, fl.common.Scalar]]]: The evaluation loss and accuracy.
        """
        testloader = self.fed_dataset.valloaders[0]

        self.global_model.set_parameters(
            parameters
        ) # Update model with the latest parameters
        loss, accuracy ,auc = self.global_model.evaluate(testloader, self.device)
        self.auc.append(auc)
        self.losses.append(loss)
        self.accuracies.append(accuracy)

+        return loss, {"accuracy": accuracy}
+
+    def run(self) -> None:
+        """
+        Run the federated learning process using Flower simulation.
+
+        Returns:
+            History: The history of the accuracies and losses during the training of each node
+        """
+        # Increase the object store memory to the minimum allowed value or higher
+        ray_init_args = {"include_dashboard": False
+                         , "object_store_memory": 78643200
+                         }
+        self.fed_dataset.eng = None
+
+        history = fl.simulation.start_simulation(
+            client_fn=self.client_fn,
+            num_clients=self.num_clients,
            config=fl.server.ServerConfig(self.num_rounds),
            strategy=self.strategy.strategy_object,
            ray_init_args=ray_init_args,
            client_resources = self.client_resources
        )

        return history
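Read together with strategy.py below, the hunk above covers the whole 2.0.1 server-side flow: build a Strategy, hand it to FlowerServer, and call run(). The sketch below is only an orientation aid; the FlowerServer keyword arguments, run(), and the losses/accuracies/auc attributes come from the diff, while global_model and fed_dataset are placeholders for objects built with the Model and FederatedDataset classes defined in other modules of the package.

# Sketch, not taken from the package: typical driver for the FlowerServer shown above.
# global_model and fed_dataset are placeholders; their constructors are not part of this hunk.
from MEDfl.LearningManager.strategy import Strategy
from MEDfl.LearningManager.server import FlowerServer

strategy = Strategy(name="FedAvg", min_available_clients=3)
strategy.create_strategy()          # FlowerServer expects strategy.strategy_object to exist

global_model = ...                  # placeholder: a MEDfl Model wrapping a torch module
fed_dataset = ...                   # placeholder: a FederatedDataset with trainloaders/valloaders

server = FlowerServer(
    global_model=global_model,
    strategy=strategy,
    num_rounds=5,
    num_clients=3,
    fed_dataset=fed_dataset,
    diff_privacy=False,
    client_resources={"num_cpus": 1, "num_gpus": 0.0},
)
history = server.run()              # wraps fl.simulation.start_simulation(...)
print(server.losses, server.accuracies, server.auc)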
MEDfl/LearningManager/strategy.py
CHANGED
@@ -1,138 +1,82 @@

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import flwr as fl
import numpy as np

import optuna




class Strategy:
    """
    A class representing a strategy for Federated Learning.

    Attributes:
        name (str): The name of the strategy. Default is "FedAvg".
        fraction_fit (float): Fraction of clients to use for training during each round. Default is 1.0.
        fraction_evaluate (float): Fraction of clients to use for evaluation during each round. Default is 1.0.
        min_fit_clients (int): Minimum number of clients to use for training during each round. Default is 2.
        min_evaluate_clients (int): Minimum number of clients to use for evaluation during each round. Default is 2.
        min_available_clients (int): Minimum number of available clients required to start a round. Default is 2.
        initial_parameters (Optional[]): The initial parameters of the server model
    Methods:

    """

    def __init__(
        self,
        name: str = "FedAvg",
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters = [],
        evaluation_methode = "centralized"
-        if(self.server_round < server_round):
-            self.server_round = server_round
-            self.trail = []
-
-
-
-        print('================= this is the server trails')
-        print(self.trail)
-
-        trail = self.study.ask()
-        self.trail.append(trail)
-        learning_rate = trail.suggest_float('learning_rate', 1e-5, 1e-1)
-        print(self.study.trials)
-        print(trail.number)
-        config = {
-            "trail" : trail ,
-            "server_rounds": 5 ,
-            "server_round" : server_round ,
-            "HPO_factor" : 0.5 ,
-            "study" : self.study ,
-            "HPO_RATE" : self.hpo_rate ,
-            "params" : {
-                "learning_rate" : learning_rate
-            }
-        }
-
-        return config
-
-    def optuna_fed_optimization(self, direction:str , hpo_rate:int , hpo_factor , params_config , sampler="TPESampler" , metric='AUC'):
-        self.study = optuna.create_study(direction=direction , sampler=self.get_sampler_by_name(sampler)())
-        self.hpo_rate = hpo_rate
-        self.HPO_factor = hpo_factor
-        self.config = params_config
-        self.opt_metric = metric
-
-
-    def create_strategy(self):
-        self.strategy_object = self.get_strategy_by_name()(
-            fraction_fit=self.fraction_fit,
-            fraction_evaluate=self.fraction_evaluate,
-            min_fit_clients=self.min_fit_clients,
-            min_evaluate_clients=self.min_evaluate_clients,
-            min_available_clients=self.min_available_clients,
-            initial_parameters=fl.common.ndarrays_to_parameters(self.initial_parameters),
-            evaluate_fn=self.evaluate_fn ,
-            on_fit_config_fn = self.fit_config ,
-            on_evaluate_config_fn = self.fit_config
-        )
-    def get_strategy_by_name(self):
-        return eval(f"fl.server.strategy.{self.name}")
-
-    def get_sampler_by_name(self , name) :
-        return eval(f"optuna.samplers.{name}")
+    ) -> None:
+        """
+        Initialize a Strategy object with the specified parameters.
+
+        Args:
+            name (str): The name of the strategy. Default is "FedAvg".
+            fraction_fit (float): Fraction of clients to use for training during each round. Default is 1.0.
+            fraction_evaluate (float): Fraction of clients to use for evaluation during each round. Default is 1.0.
+            min_fit_clients (int): Minimum number of clients to use for training during each round. Default is 2.
+            min_evaluate_clients (int): Minimum number of clients to use for evaluation during each round. Default is 2.
+            min_available_clients (int): Minimum number of available clients required to start a round. Default is 2.
+            initial_parameters (Optional[]): The initial parametres of the server model
+            evaluation_methode ( "centralized" | "distributed")
+        """
+        self.fraction_fit = fraction_fit
+        self.fraction_evaluate = fraction_evaluate
+        self.min_fit_clients = min_fit_clients
+        self.min_evaluate_clients = min_evaluate_clients
+        self.min_available_clients = min_available_clients
+        self.initial_parameters = initial_parameters
+        self.evaluate_fn = None
+        self.name = name
+
+    def optuna_fed_optimization(self, direction:str , hpo_rate:int , params_config):
+        self.study = optuna.create_study(direction=direction)
+        self.hpo_rate = hpo_rate
+        self.params_config = params_config
+
+
+    def create_strategy(self):
+        self.strategy_object = self.get_strategy_by_name()(
+            fraction_fit=self.fraction_fit,
+            fraction_evaluate=self.fraction_evaluate,
+            min_fit_clients=self.min_fit_clients,
+            min_evaluate_clients=self.min_evaluate_clients,
+            min_available_clients=self.min_available_clients,
+            initial_parameters=fl.common.ndarrays_to_parameters(self.initial_parameters),
+            evaluate_fn=self.evaluate_fn
+        )
+    def get_strategy_by_name(self):
+        return eval(f"fl.server.strategy.{self.name}")
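Relative to 0.2.1, the removed lines above show that fit_config and get_sampler_by_name are gone and that optuna_fed_optimization lost its hpo_factor, sampler and metric arguments. A minimal sketch of the trimmed 2.0.1 Strategy API follows; the argument values are illustrative, only the class, method and attribute names come from the diff.

# Sketch with illustrative values: exercising the 2.0.1 Strategy shown above.
from MEDfl.LearningManager.strategy import Strategy

strategy = Strategy(
    name="FedAvg",                  # resolved to fl.server.strategy.FedAvg by get_strategy_by_name()
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_fit_clients=2,
    min_evaluate_clients=2,
    min_available_clients=2,
)

# Optional Optuna bookkeeping; 2.0.1 keeps only direction, hpo_rate and params_config.
strategy.optuna_fed_optimization(
    direction="maximize",
    hpo_rate=2,
    params_config={"learning_rate": (1e-5, 1e-1)},   # assumed shape for params_config
)

strategy.create_strategy()          # builds strategy.strategy_object, later consumed by FlowerServer
print(type(strategy.strategy_object))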