MEDfl 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,6 +7,8 @@ from .model import Model
7
7
  from .utils import params
8
8
  import torch
9
9
 
10
+ import optuna
11
+
10
12
  class FlowerClient(fl.client.NumPyClient):
11
13
  """
12
14
  FlowerClient class for creating MEDfl clients.
@@ -107,6 +109,7 @@ class FlowerClient(fl.client.NumPyClient):
107
109
  Returns:
108
110
  Tuple: Parameters of the local model, number of training examples, and privacy information.
109
111
  """
112
+
110
113
  print('\n -------------------------------- \n this is the config of the client')
111
114
  print(f"[Client {self.cid}] fit, config: {config}")
112
115
  # print(config['epochs'])
@@ -147,4 +150,32 @@ class FlowerClient(fl.client.NumPyClient):
147
150
  self.losses.append(loss)
148
151
  self.accuracies.append(accuracy)
149
152
 
153
+ print(f"[ ============== From evaluate ==== Client {self.cid}] fit, config: {config}")
154
+
155
+
156
+ # if('study' in config):
157
+ # if 0 < config['server_round'] <= config['HPO_factor']*config['server_rounds'] and (config['server_round'] -1) %(config['HPO_RATE'] )==0:
158
+ # print("==================== this is th optimisations info ===================")
159
+ # print(auc)
160
+ # print(config['trail'])
161
+ # print('---------')
162
+ # print(config['study'].trials[0].state)
163
+ # try:
164
+ # # Call tell() method to report the result
165
+ # config['study'].tell(config['trail'], auc)
166
+
167
+ # # Fetch the updated trial from the study
168
+ # updated_trial = config['trail']
169
+
170
+ # # Check the state of the trial
171
+ # if updated_trial.state == optuna.trial.TrialState.COMPLETE:
172
+ # print(f"Trial {updated_trial.number} completed successfully with value {updated_trial.value}")
173
+ # else:
174
+ # print(f"Trial {updated_trial.number} is in state {updated_trial.state}")
175
+
176
+ # except Exception as e:
177
+ # # Handle and log any errors
178
+ # print(f"Error during tell(): {e}")
179
+
180
+
150
181
  return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
@@ -152,6 +152,14 @@ class FlowerServer:
152
152
  self.auc.append(auc)
153
153
  self.losses.append(loss)
154
154
  self.accuracies.append(accuracy)
155
+
156
+ if hasattr(self.strategy, 'study') and 0 < server_round <= self.strategy.HPO_factor*self.num_rounds and (server_round -1) %(self.strategy.hpo_rate )==0:
157
+
158
+ print("================ Weeeee aaaaaareee hre ========================")
159
+
160
+ for trail in self.strategy.trail:
161
+
162
+ self.strategy.study.tell(trail , auc)
155
163
 
156
164
  return loss, {"accuracy": accuracy}
157
165
 
@@ -35,7 +35,8 @@ class Strategy:
35
35
  min_evaluate_clients: int = 2,
36
36
  min_available_clients: int = 2,
37
37
  initial_parameters = [],
38
- evaluation_methode = "centralized"
38
+ evaluation_methode = "centralized" ,
39
+ config = None
39
40
  ) -> None:
40
41
  """
41
42
  Initialize a Strategy object with the specified parameters.
@@ -57,12 +58,62 @@ class Strategy:
57
58
  self.min_available_clients = min_available_clients
58
59
  self.initial_parameters = initial_parameters
59
60
  self.evaluate_fn = None
60
- self.name = name
61
-
62
- def optuna_fed_optimization(self, direction:str , hpo_rate:int , params_config):
63
- self.study = optuna.create_study(direction=direction)
61
+ self.name = name
62
+ self.config = config
63
+ self.server_round = 0
64
+
65
+ def get_trial(self, trial_number):
66
+ # Retrieve the trial from the study
67
+ trial = next((t for t in self.study.trials if t.number == trial_number), None)
68
+ if trial:
69
+ return trial
70
+ else:
71
+ return "Trial not found"
72
+
73
+ def fit_config(self , server_round: int):
74
+ """Return training configuration dict for each round.
75
+
76
+ Perform two rounds of training with one local epoch, increase to two local
77
+ epochs afterwards.
78
+ """
79
+ config = self.config
80
+
81
+ if hasattr(self, 'study'):
82
+ if 0 < server_round <= 0.7*10 and (server_round - 1 ) % self.hpo_rate == 0 :
83
+ if(self.server_round < server_round):
84
+ self.server_round = server_round
85
+ self.trail = []
86
+
87
+
88
+
89
+ print('================= this is the server trails')
90
+ print(self.trail)
91
+
92
+ trail = self.study.ask()
93
+ self.trail.append(trail)
94
+ learning_rate = trail.suggest_float('learning_rate', 1e-5, 1e-1)
95
+ print(self.study.trials)
96
+ print(trail.number)
97
+ config = {
98
+ "trail" : trail ,
99
+ "server_rounds": 5 ,
100
+ "server_round" : server_round ,
101
+ "HPO_factor" : 0.5 ,
102
+ "study" : self.study ,
103
+ "HPO_RATE" : self.hpo_rate ,
104
+ "params" : {
105
+ "learning_rate" : learning_rate
106
+ }
107
+ }
108
+
109
+ return config
110
+
111
+ def optuna_fed_optimization(self, direction:str , hpo_rate:int , hpo_factor , params_config , sampler="TPESampler" , metric='AUC'):
112
+ self.study = optuna.create_study(direction=direction , sampler=self.get_sampler_by_name(sampler)())
64
113
  self.hpo_rate = hpo_rate
65
- self.params_config = params_config
114
+ self.HPO_factor = hpo_factor
115
+ self.config = params_config
116
+ self.opt_metric = metric
66
117
 
67
118
 
68
119
  def create_strategy(self):
@@ -73,10 +124,15 @@ class Strategy:
73
124
  min_evaluate_clients=self.min_evaluate_clients,
74
125
  min_available_clients=self.min_available_clients,
75
126
  initial_parameters=fl.common.ndarrays_to_parameters(self.initial_parameters),
76
- evaluate_fn=self.evaluate_fn
127
+ evaluate_fn=self.evaluate_fn ,
128
+ on_fit_config_fn = self.fit_config ,
129
+ on_evaluate_config_fn = self.fit_config
77
130
  )
78
131
  def get_strategy_by_name(self):
79
132
  return eval(f"fl.server.strategy.{self.name}")
80
133
 
134
+ def get_sampler_by_name(self , name) :
135
+ return eval(f"optuna.samplers.{name}")
136
+
81
137
 
82
138
 
@@ -34,7 +34,7 @@ class DatabaseManager:
34
34
  create_db_script_path = os.path.join(current_directory, '..', 'scripts', 'create_db.py')
35
35
 
36
36
  # Execute the create_db.py script
37
- subprocess.run(['python3', create_db_script_path, path_to_csv], check=True)
37
+ subprocess.run(['python', create_db_script_path, path_to_csv], check=True)
38
38
 
39
39
  return
40
40
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: MEDfl
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Python Open-source package for simulating federated learning and differential privacy
5
5
  Home-page: https://github.com/MEDomics-UdeS/MEDfl
6
6
  Author: MEDomics consortium
@@ -1,6 +1,6 @@
1
1
  MEDfl/__init__.py,sha256=pMyTkws4slDkLQgPpPdKoWdQf1rZGHXAlNxfPFFZM-I,81
2
2
  MEDfl/LearningManager/__init__.py,sha256=L66e9oq7hTWIo8KpRrG6Yxx1UHmEIOoOfbaBcX9BeOk,354
3
- MEDfl/LearningManager/client.py,sha256=9WyLYCsI9JuHjneLbbzDf7HtzjYINuLfqwkbxOsrBrE,6083
3
+ MEDfl/LearningManager/client.py,sha256=rFAVDJ3KGB61A8ztvYgE0rjY16fo7XpxcbTiSkDPIlo,7540
4
4
  MEDfl/LearningManager/dynamicModal.py,sha256=0mTvDJlss0uSJ3_EXOuL_d-zRmFyXaKB4W4ql-uEX8Y,10821
5
5
  MEDfl/LearningManager/federated_dataset.py,sha256=fQqIbhO6LSk16Ob9z6RohaZ8X71Ff-yueynjulrl4M0,2141
6
6
  MEDfl/LearningManager/flpipeline.py,sha256=M4-OL4nlogv08J_YsyDsGHXR6xe8BWx4HIsuL1QyUvY,7303
@@ -8,11 +8,11 @@ MEDfl/LearningManager/model.py,sha256=DA7HP34Eq1Ra65OlkBmjH9d2MD7OEbsOhfxD48l4QO
8
8
  MEDfl/LearningManager/params.yaml,sha256=6UcmIgYcufbCn_H6vxerlinVlQE8fF6fI6CrRaTSVWE,450
9
9
  MEDfl/LearningManager/params_optimiser.py,sha256=pjhDskhSPuca-jnarYoJcFVBvRkdD9tD3992q_eMPSE,18060
10
10
  MEDfl/LearningManager/plot.py,sha256=iPqMV9rVd7hquoFixDL20OzXI5bMpBW41bkVmTKIWtE,7927
11
- MEDfl/LearningManager/server.py,sha256=7edxPkZ9Ju3Mep_BSHQpUNgW9HKfCui3_l996buJVlU,7258
12
- MEDfl/LearningManager/strategy.py,sha256=n0asQajkHfGLDX3QbbV5qntQA-xuJZU8Z92XccZENsA,3517
11
+ MEDfl/LearningManager/server.py,sha256=scnjiWC6vMIPzAwEvaeQfKfCBqYtKdqZLhOwvcwRo7w,7631
12
+ MEDfl/LearningManager/strategy.py,sha256=ouCdAJA8CdxMmXBVVZ0BapmZIBZfACOxVPQLEXPmHIs,5673
13
13
  MEDfl/LearningManager/utils.py,sha256=vEhkpyC7iLsn4wp1wDh7GzAn5MCJ7T69jkS9lfmKA1Y,9936
14
14
  MEDfl/NetManager/__init__.py,sha256=UaHMFtzo90k7rQW45ZUX7aW0-EG1d3OXDkqc8cVgp6U,283
15
- MEDfl/NetManager/database_connector.py,sha256=vOhbvCbzKVN3vk5HxfcJP7iher7rsGpLtCXB0-4wwRw,1517
15
+ MEDfl/NetManager/database_connector.py,sha256=JKfFLom7I4zuykb8m7aY4cUYZy4j0-i9w3R_jkPzjXY,1516
16
16
  MEDfl/NetManager/dataset.py,sha256=eEuVzCp5dGD4tvDVKq6jlSReecge7T20ByG4d7_cnXU,2869
17
17
  MEDfl/NetManager/flsetup.py,sha256=CS7531I08eLm6txMIDWFMCIrPP-dNpOLBTaR2BR6X0c,11754
18
18
  MEDfl/NetManager/net_helper.py,sha256=T5Y-03SVskK8oXXIpiXzASyEDPQJbcuGVpIS9FnmzI8,7066
@@ -24,8 +24,8 @@ MEDfl/scripts/base.py,sha256=pR7StIt3PpX30aoh53gMkpeNJMHytAPhdc7N09tCITA,781
24
24
  MEDfl/scripts/create_db.py,sha256=PgA6N68iTSfnrt6zy7FDZX2lLjQQ7Ual1Y0efve8gf4,3943
25
25
  alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  alembic/env.py,sha256=a4zJAzPNLHnIrUlXCqf_8vuAlFu0pceFJJKM1PQaOI4,1649
27
- MEDfl-0.2.0.dist-info/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
28
- MEDfl-0.2.0.dist-info/METADATA,sha256=TdK27hGked7Mdy6WVnyhoG3uX_ohfh9UDI7-DAHZY7k,4428
29
- MEDfl-0.2.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
30
- MEDfl-0.2.0.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
31
- MEDfl-0.2.0.dist-info/RECORD,,
27
+ MEDfl-0.2.1.dist-info/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
28
+ MEDfl-0.2.1.dist-info/METADATA,sha256=RIvrAIdTlpIdHPQfGeopzdSKXhaDuhlldOc8sjqF4Uk,4428
29
+ MEDfl-0.2.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
30
+ MEDfl-0.2.1.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
31
+ MEDfl-0.2.1.dist-info/RECORD,,
File without changes
File without changes