MEDfl 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/LearningManager/client.py +31 -0
- MEDfl/LearningManager/server.py +8 -0
- MEDfl/LearningManager/strategy.py +63 -7
- MEDfl/NetManager/database_connector.py +1 -1
- {MEDfl-0.2.0.dist-info → MEDfl-0.2.1.dist-info}/METADATA +1 -1
- {MEDfl-0.2.0.dist-info → MEDfl-0.2.1.dist-info}/RECORD +9 -9
- {MEDfl-0.2.0.dist-info → MEDfl-0.2.1.dist-info}/LICENSE +0 -0
- {MEDfl-0.2.0.dist-info → MEDfl-0.2.1.dist-info}/WHEEL +0 -0
- {MEDfl-0.2.0.dist-info → MEDfl-0.2.1.dist-info}/top_level.txt +0 -0
MEDfl/LearningManager/client.py
CHANGED
@@ -7,6 +7,8 @@ from .model import Model
|
|
7
7
|
from .utils import params
|
8
8
|
import torch
|
9
9
|
|
10
|
+
import optuna
|
11
|
+
|
10
12
|
class FlowerClient(fl.client.NumPyClient):
|
11
13
|
"""
|
12
14
|
FlowerClient class for creating MEDfl clients.
|
@@ -107,6 +109,7 @@ class FlowerClient(fl.client.NumPyClient):
|
|
107
109
|
Returns:
|
108
110
|
Tuple: Parameters of the local model, number of training examples, and privacy information.
|
109
111
|
"""
|
112
|
+
|
110
113
|
print('\n -------------------------------- \n this is the config of the client')
|
111
114
|
print(f"[Client {self.cid}] fit, config: {config}")
|
112
115
|
# print(config['epochs'])
|
@@ -147,4 +150,32 @@ class FlowerClient(fl.client.NumPyClient):
|
|
147
150
|
self.losses.append(loss)
|
148
151
|
self.accuracies.append(accuracy)
|
149
152
|
|
153
|
+
print(f"[ ============== From evaluate ==== Client {self.cid}] fit, config: {config}")
|
154
|
+
|
155
|
+
|
156
|
+
# if('study' in config):
|
157
|
+
# if 0 < config['server_round'] <= config['HPO_factor']*config['server_rounds'] and (config['server_round'] -1) %(config['HPO_RATE'] )==0:
|
158
|
+
# print("==================== this is th optimisations info ===================")
|
159
|
+
# print(auc)
|
160
|
+
# print(config['trail'])
|
161
|
+
# print('---------')
|
162
|
+
# print(config['study'].trials[0].state)
|
163
|
+
# try:
|
164
|
+
# # Call tell() method to report the result
|
165
|
+
# config['study'].tell(config['trail'], auc)
|
166
|
+
|
167
|
+
# # Fetch the updated trial from the study
|
168
|
+
# updated_trial = config['trail']
|
169
|
+
|
170
|
+
# # Check the state of the trial
|
171
|
+
# if updated_trial.state == optuna.trial.TrialState.COMPLETE:
|
172
|
+
# print(f"Trial {updated_trial.number} completed successfully with value {updated_trial.value}")
|
173
|
+
# else:
|
174
|
+
# print(f"Trial {updated_trial.number} is in state {updated_trial.state}")
|
175
|
+
|
176
|
+
# except Exception as e:
|
177
|
+
# # Handle and log any errors
|
178
|
+
# print(f"Error during tell(): {e}")
|
179
|
+
|
180
|
+
|
150
181
|
return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
|
MEDfl/LearningManager/server.py
CHANGED
@@ -152,6 +152,14 @@ class FlowerServer:
|
|
152
152
|
self.auc.append(auc)
|
153
153
|
self.losses.append(loss)
|
154
154
|
self.accuracies.append(accuracy)
|
155
|
+
|
156
|
+
if hasattr(self.strategy, 'study') and 0 < server_round <= self.strategy.HPO_factor*self.num_rounds and (server_round -1) %(self.strategy.hpo_rate )==0:
|
157
|
+
|
158
|
+
print("================ Weeeee aaaaaareee hre ========================")
|
159
|
+
|
160
|
+
for trail in self.strategy.trail:
|
161
|
+
|
162
|
+
self.strategy.study.tell(trail , auc)
|
155
163
|
|
156
164
|
return loss, {"accuracy": accuracy}
|
157
165
|
|
@@ -35,7 +35,8 @@ class Strategy:
|
|
35
35
|
min_evaluate_clients: int = 2,
|
36
36
|
min_available_clients: int = 2,
|
37
37
|
initial_parameters = [],
|
38
|
-
evaluation_methode = "centralized"
|
38
|
+
evaluation_methode = "centralized" ,
|
39
|
+
config = None
|
39
40
|
) -> None:
|
40
41
|
"""
|
41
42
|
Initialize a Strategy object with the specified parameters.
|
@@ -57,12 +58,62 @@ class Strategy:
|
|
57
58
|
self.min_available_clients = min_available_clients
|
58
59
|
self.initial_parameters = initial_parameters
|
59
60
|
self.evaluate_fn = None
|
60
|
-
self.name = name
|
61
|
-
|
62
|
-
|
63
|
-
|
61
|
+
self.name = name
|
62
|
+
self.config = config
|
63
|
+
self.server_round = 0
|
64
|
+
|
65
|
+
def get_trial(self, trial_number):
|
66
|
+
# Retrieve the trial from the study
|
67
|
+
trial = next((t for t in self.study.trials if t.number == trial_number), None)
|
68
|
+
if trial:
|
69
|
+
return trial
|
70
|
+
else:
|
71
|
+
return "Trial not found"
|
72
|
+
|
73
|
+
def fit_config(self , server_round: int):
|
74
|
+
"""Return training configuration dict for each round.
|
75
|
+
|
76
|
+
Perform two rounds of training with one local epoch, increase to two local
|
77
|
+
epochs afterwards.
|
78
|
+
"""
|
79
|
+
config = self.config
|
80
|
+
|
81
|
+
if hasattr(self, 'study'):
|
82
|
+
if 0 < server_round <= 0.7*10 and (server_round - 1 ) % self.hpo_rate == 0 :
|
83
|
+
if(self.server_round < server_round):
|
84
|
+
self.server_round = server_round
|
85
|
+
self.trail = []
|
86
|
+
|
87
|
+
|
88
|
+
|
89
|
+
print('================= this is the server trails')
|
90
|
+
print(self.trail)
|
91
|
+
|
92
|
+
trail = self.study.ask()
|
93
|
+
self.trail.append(trail)
|
94
|
+
learning_rate = trail.suggest_float('learning_rate', 1e-5, 1e-1)
|
95
|
+
print(self.study.trials)
|
96
|
+
print(trail.number)
|
97
|
+
config = {
|
98
|
+
"trail" : trail ,
|
99
|
+
"server_rounds": 5 ,
|
100
|
+
"server_round" : server_round ,
|
101
|
+
"HPO_factor" : 0.5 ,
|
102
|
+
"study" : self.study ,
|
103
|
+
"HPO_RATE" : self.hpo_rate ,
|
104
|
+
"params" : {
|
105
|
+
"learning_rate" : learning_rate
|
106
|
+
}
|
107
|
+
}
|
108
|
+
|
109
|
+
return config
|
110
|
+
|
111
|
+
def optuna_fed_optimization(self, direction:str , hpo_rate:int , hpo_factor , params_config , sampler="TPESampler" , metric='AUC'):
|
112
|
+
self.study = optuna.create_study(direction=direction , sampler=self.get_sampler_by_name(sampler)())
|
64
113
|
self.hpo_rate = hpo_rate
|
65
|
-
self.
|
114
|
+
self.HPO_factor = hpo_factor
|
115
|
+
self.config = params_config
|
116
|
+
self.opt_metric = metric
|
66
117
|
|
67
118
|
|
68
119
|
def create_strategy(self):
|
@@ -73,10 +124,15 @@ class Strategy:
|
|
73
124
|
min_evaluate_clients=self.min_evaluate_clients,
|
74
125
|
min_available_clients=self.min_available_clients,
|
75
126
|
initial_parameters=fl.common.ndarrays_to_parameters(self.initial_parameters),
|
76
|
-
evaluate_fn=self.evaluate_fn
|
127
|
+
evaluate_fn=self.evaluate_fn ,
|
128
|
+
on_fit_config_fn = self.fit_config ,
|
129
|
+
on_evaluate_config_fn = self.fit_config
|
77
130
|
)
|
78
131
|
def get_strategy_by_name(self):
|
79
132
|
return eval(f"fl.server.strategy.{self.name}")
|
80
133
|
|
134
|
+
def get_sampler_by_name(self , name) :
|
135
|
+
return eval(f"optuna.samplers.{name}")
|
136
|
+
|
81
137
|
|
82
138
|
|
@@ -34,7 +34,7 @@ class DatabaseManager:
|
|
34
34
|
create_db_script_path = os.path.join(current_directory, '..', 'scripts', 'create_db.py')
|
35
35
|
|
36
36
|
# Execute the create_db.py script
|
37
|
-
subprocess.run(['
|
37
|
+
subprocess.run(['python', create_db_script_path, path_to_csv], check=True)
|
38
38
|
|
39
39
|
return
|
40
40
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
MEDfl/__init__.py,sha256=pMyTkws4slDkLQgPpPdKoWdQf1rZGHXAlNxfPFFZM-I,81
|
2
2
|
MEDfl/LearningManager/__init__.py,sha256=L66e9oq7hTWIo8KpRrG6Yxx1UHmEIOoOfbaBcX9BeOk,354
|
3
|
-
MEDfl/LearningManager/client.py,sha256=
|
3
|
+
MEDfl/LearningManager/client.py,sha256=rFAVDJ3KGB61A8ztvYgE0rjY16fo7XpxcbTiSkDPIlo,7540
|
4
4
|
MEDfl/LearningManager/dynamicModal.py,sha256=0mTvDJlss0uSJ3_EXOuL_d-zRmFyXaKB4W4ql-uEX8Y,10821
|
5
5
|
MEDfl/LearningManager/federated_dataset.py,sha256=fQqIbhO6LSk16Ob9z6RohaZ8X71Ff-yueynjulrl4M0,2141
|
6
6
|
MEDfl/LearningManager/flpipeline.py,sha256=M4-OL4nlogv08J_YsyDsGHXR6xe8BWx4HIsuL1QyUvY,7303
|
@@ -8,11 +8,11 @@ MEDfl/LearningManager/model.py,sha256=DA7HP34Eq1Ra65OlkBmjH9d2MD7OEbsOhfxD48l4QO
|
|
8
8
|
MEDfl/LearningManager/params.yaml,sha256=6UcmIgYcufbCn_H6vxerlinVlQE8fF6fI6CrRaTSVWE,450
|
9
9
|
MEDfl/LearningManager/params_optimiser.py,sha256=pjhDskhSPuca-jnarYoJcFVBvRkdD9tD3992q_eMPSE,18060
|
10
10
|
MEDfl/LearningManager/plot.py,sha256=iPqMV9rVd7hquoFixDL20OzXI5bMpBW41bkVmTKIWtE,7927
|
11
|
-
MEDfl/LearningManager/server.py,sha256=
|
12
|
-
MEDfl/LearningManager/strategy.py,sha256=
|
11
|
+
MEDfl/LearningManager/server.py,sha256=scnjiWC6vMIPzAwEvaeQfKfCBqYtKdqZLhOwvcwRo7w,7631
|
12
|
+
MEDfl/LearningManager/strategy.py,sha256=ouCdAJA8CdxMmXBVVZ0BapmZIBZfACOxVPQLEXPmHIs,5673
|
13
13
|
MEDfl/LearningManager/utils.py,sha256=vEhkpyC7iLsn4wp1wDh7GzAn5MCJ7T69jkS9lfmKA1Y,9936
|
14
14
|
MEDfl/NetManager/__init__.py,sha256=UaHMFtzo90k7rQW45ZUX7aW0-EG1d3OXDkqc8cVgp6U,283
|
15
|
-
MEDfl/NetManager/database_connector.py,sha256=
|
15
|
+
MEDfl/NetManager/database_connector.py,sha256=JKfFLom7I4zuykb8m7aY4cUYZy4j0-i9w3R_jkPzjXY,1516
|
16
16
|
MEDfl/NetManager/dataset.py,sha256=eEuVzCp5dGD4tvDVKq6jlSReecge7T20ByG4d7_cnXU,2869
|
17
17
|
MEDfl/NetManager/flsetup.py,sha256=CS7531I08eLm6txMIDWFMCIrPP-dNpOLBTaR2BR6X0c,11754
|
18
18
|
MEDfl/NetManager/net_helper.py,sha256=T5Y-03SVskK8oXXIpiXzASyEDPQJbcuGVpIS9FnmzI8,7066
|
@@ -24,8 +24,8 @@ MEDfl/scripts/base.py,sha256=pR7StIt3PpX30aoh53gMkpeNJMHytAPhdc7N09tCITA,781
|
|
24
24
|
MEDfl/scripts/create_db.py,sha256=PgA6N68iTSfnrt6zy7FDZX2lLjQQ7Ual1Y0efve8gf4,3943
|
25
25
|
alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
26
|
alembic/env.py,sha256=a4zJAzPNLHnIrUlXCqf_8vuAlFu0pceFJJKM1PQaOI4,1649
|
27
|
-
MEDfl-0.2.
|
28
|
-
MEDfl-0.2.
|
29
|
-
MEDfl-0.2.
|
30
|
-
MEDfl-0.2.
|
31
|
-
MEDfl-0.2.
|
27
|
+
MEDfl-0.2.1.dist-info/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
|
28
|
+
MEDfl-0.2.1.dist-info/METADATA,sha256=RIvrAIdTlpIdHPQfGeopzdSKXhaDuhlldOc8sjqF4Uk,4428
|
29
|
+
MEDfl-0.2.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
30
|
+
MEDfl-0.2.1.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
|
31
|
+
MEDfl-0.2.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|