MEDfl 2.0.4.dev3__py3-none-any.whl → 2.0.4.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
MEDfl/LearningManager/model.py CHANGED
@@ -214,10 +214,11 @@ class Model:
  torch.nn.Module: Loaded PyTorch model.
  """
  # Ensure models are loaded onto the CPU when CUDA is not available
+ torch_kwargs = {"weights_only": False}
  if torch.cuda.is_available():
- loaded_model = torch.load(model_path)
+ loaded_model = torch.load(model_path , **torch_kwargs)
  else:
- loaded_model = torch.load(model_path, map_location=torch.device('cpu'))
+ loaded_model = torch.load(model_path, map_location=torch.device('cpu') , **torch_kwargs)
  return loaded_model
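
For context on this change: starting with PyTorch 2.6, torch.load defaults to weights_only=True and refuses to unpickle arbitrary Python objects, so passing weights_only=False restores the older behaviour of loading a fully pickled nn.Module. That is only appropriate for checkpoints from a trusted source. A minimal sketch of the same pattern (model_path is a placeholder):

    import torch

    def load_model(model_path: str) -> torch.nn.Module:
        # weights_only=False allows full pickled module objects; use only with trusted files.
        torch_kwargs = {"weights_only": False}
        if torch.cuda.is_available():
            return torch.load(model_path, **torch_kwargs)
        # Map tensors onto the CPU when no GPU is available.
        return torch.load(model_path, map_location=torch.device("cpu"), **torch_kwargs)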
MEDfl/LearningManager/utils.py CHANGED
@@ -113,7 +113,7 @@ def custom_classification_report(y_true, y_pred_prob):

  auc = roc_auc_score(y_true, y_pred_prob) # Calculate AUC

- tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred , labels=[0, 1]).ravel()

  # Accuracy
  denominator_acc = tp + tn + fp + fn
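
The labels=[0, 1] argument pins the confusion matrix to a fixed 2x2 shape. Without it, scikit-learn infers the label set from the data, so if y_true and y_pred together contain only one distinct label the matrix is 1x1 and the four-way unpacking of .ravel() raises a ValueError. A small illustration with hypothetical labels:

    from sklearn.metrics import confusion_matrix

    y_true = [0, 0, 0]   # hypothetical fold where only the negative class appears
    y_pred = [0, 0, 0]

    # Inferred labels -> 1x1 matrix -> unpacking four values fails:
    # tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

    # Fixed labels guarantee a 2x2 matrix, so ravel() always yields (tn, fp, fn, tp):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    print(tn, fp, fn, tp)  # 3 0 0 0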
MEDfl/NetManager/database_connector.py CHANGED
@@ -1,5 +1,5 @@
  import os
- import subprocess
+ import subprocess , sys
  from sqlalchemy import create_engine
  from configparser import ConfigParser

@@ -34,7 +34,9 @@ class DatabaseManager:
  create_db_script_path = os.path.join(current_directory, '..', 'scripts', 'create_db.py')

  # Execute the create_db.py script
- subprocess.run(['python3', create_db_script_path, path_to_csv], check=True)
+ print(sys.executable)
+ result = subprocess.run([sys.executable, create_db_script_path, path_to_csv],
+ capture_output=True, text=True)

  return
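
Two things change here: the child script is launched with sys.executable, so it runs under the same interpreter (and virtual environment) as the calling process rather than whatever python3 happens to be on PATH, and check=True is replaced by capture_output=True, so a failing script no longer raises CalledProcessError; its output is simply collected on the returned object. A sketch of how such a result could be inspected (the script path and arguments are placeholders):

    import subprocess, sys

    result = subprocess.run(
        [sys.executable, "scripts/create_db.py", "data.csv"],  # placeholder arguments
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        # Without check=True, failures have to be handled explicitly.
        print(result.stderr)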
MEDfl/rw/client.py CHANGED
@@ -1,9 +1,4 @@
  # client.py
- """
- Federated Learning Client with Optional Differential Privacy.
-
- """
-
  import argparse
  import pandas as pd
  import flwr as fl
@@ -12,22 +7,21 @@ import torch.nn as nn
  import torch.optim as optim
  from torch.utils.data import TensorDataset, DataLoader
  from sklearn.metrics import accuracy_score, roc_auc_score
- from model import Net # Local model definition
+ from model import Net # your model definition in model.py
  import socket
- import platform
-
+ import platform

+ # Differential Privacy configuration class
  class DPConfig:
  """
- Configuration for Differential Privacy (DP) settings.
+ Configuration for differential privacy.

  Attributes:
  noise_multiplier (float): Noise multiplier for DP.
- max_grad_norm (float): Maximum gradient norm (clipping threshold).
+ max_grad_norm (float): Maximum gradient norm for clipping.
  batch_size (int): Batch size for training.
- secure_rng (bool): Whether to use a secure RNG for DP noise.
+ secure_rng (bool): Whether to use secure random generator.
  """
-
  def __init__(
  self,
  noise_multiplier: float = 1.0,
@@ -42,54 +36,44 @@ class DPConfig:


  class FlowerClient(fl.client.NumPyClient):
- """
- FlowerClient: A federated learning client that trains a PyTorch model
- and optionally applies differential privacy.
-
- """
-
  def __init__(
  self,
  server_address: str,
  data_path: str = "data/data.csv",
  dp_config: DPConfig = None,
  ):
- """
- Initialize client by loading data, creating model, optimizer,
- and optionally enabling DP.
-
- Args:
- server_address (str): Flower server address.
- data_path (str): Path to CSV dataset.
- dp_config (DPConfig): Optional DP configuration.
- """
  self.server_address = server_address

- # ---------- Load Data ----------
+ # 1. Load data from CSV
  df = pd.read_csv(data_path)
  X = df.iloc[:, :-1].values
  y = df.iloc[:, -1].values

+ # Convert to tensors and store for metrics
  self.X_tensor = torch.tensor(X, dtype=torch.float32)
  self.y_tensor = torch.tensor(y, dtype=torch.float32)

+ # Create DataLoader
  batch_size = dp_config.batch_size if dp_config else 32
  dataset = TensorDataset(self.X_tensor, self.y_tensor)
  self.train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

- # ---------- Model and Optimizer ----------
+ # 2. Initialize model, loss, optimizer
  input_dim = X.shape[1]
  self.model = Net(input_dim)
+
  self.criterion = nn.BCEWithLogitsLoss()
  self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)

- # ---------- Differential Privacy ----------
+ # 3. Set up differential privacy if requested
  self.privacy_engine = None
  if dp_config is not None:
  try:
  from opacus import PrivacyEngine

+ # Instantiate engine without secure_rng argument
  self.privacy_engine = PrivacyEngine()
+ # Attach privacy engine with secure_rng flag
  self.model, self.optimizer, self.train_loader = self.privacy_engine.make_private(
  module=self.model,
  optimizer=self.optimizer,
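
A note on the Opacus call this hunk feeds into: in the Opacus 1.x API as I understand it, PrivacyEngine.make_private accepts module, optimizer, data_loader, noise_multiplier and max_grad_norm, while secure randomness is configured on the engine itself via the secure_mode flag rather than as a make_private keyword. A minimal, self-contained sketch under that assumption (the toy model and loader are hypothetical):

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    from opacus import PrivacyEngine

    # Tiny hypothetical setup so the call below runs end to end.
    model = nn.Linear(4, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)

    # Assumption: secure randomness is an engine-level flag in Opacus 1.x.
    engine = PrivacyEngine(secure_mode=False)
    model, optimizer, loader = engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=loader,
        noise_multiplier=1.0,
        max_grad_norm=1.0,
    )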
@@ -99,10 +83,149 @@ class FlowerClient(fl.client.NumPyClient):
  secure_rng=dp_config.secure_rng,
  )
  except ImportError:
- print("⚠️ Opacus not installed, running without DP.")
+ print("Opacus not installed, running without DP.")

  def get_parameters(self, config):
+ return [val.cpu().numpy() for val in self.model.state_dict().values()]
+
+ def set_parameters(self, parameters):
+ params_dict = zip(self.model.state_dict().keys(), parameters)
+ state_dict = {k: torch.tensor(v) for k, v in params_dict}
+ self.model.load_state_dict(state_dict, strict=True)
+
+ def fit(self, parameters, config):
+ # Update model parameters
+ self.set_parameters(parameters)
+ self.model.train()
+
+ total_loss = 0.0
+ for X_batch, y_batch in self.train_loader:
+ self.optimizer.zero_grad()
+ outputs = self.model(X_batch)
+ loss = self.criterion(outputs.squeeze(), y_batch)
+ loss.backward()
+ self.optimizer.step()
+ total_loss += loss.item() * X_batch.size(0)
+
+ avg_loss = total_loss / len(self.train_loader.dataset)
+ # Compute training metrics (accuracy, AUC) on full dataset
+ with torch.no_grad():
+ logits = self.model(self.X_tensor).squeeze()
+ probs = torch.sigmoid(logits).cpu().numpy()
+ y_true = self.y_tensor.cpu().numpy()
+ binary_preds = (probs >= 0.5).astype(int)
+ acc = accuracy_score(y_true, binary_preds)
+ auc = roc_auc_score(y_true, probs)
+
+ hostname = socket.gethostname()
+ os_type = platform.system()
+
+ metrics = {"hostname": hostname, "train_loss": avg_loss, "train_accuracy": acc, "train_auc": auc , "os_type": os_type}
+ return self.get_parameters(config), len(self.train_loader.dataset), metrics
+
+ def evaluate(self, parameters, config):
+ # Update model parameters
+ self.set_parameters(parameters)
+ self.model.eval()
+
+ total_loss = 0.0
+ all_probs = []
+ all_true = []
+ with torch.no_grad():
+ for X_batch, y_batch in self.train_loader:
+ outputs = self.model(X_batch)
+ loss = self.criterion(outputs.squeeze(), y_batch)
+ total_loss += loss.item() * X_batch.size(0)
+ probs = torch.sigmoid(outputs.squeeze()).cpu().numpy()
+ all_probs.extend(probs.tolist())
+ all_true.extend(y_batch.cpu().numpy().tolist())
+
+ avg_loss = total_loss / len(self.train_loader.dataset)
+ binary_preds = [1 if p >= 0.5 else 0 for p in all_probs]
+ acc = accuracy_score(all_true, binary_preds)
+ auc = roc_auc_score(all_true, all_probs)
+
+ metrics = {"eval_loss": avg_loss, "eval_accuracy": acc, "eval_auc": auc}
+ print(f"Evaluation metrics: {metrics}")
+ return float(avg_loss), len(self.train_loader.dataset), metrics
+
+ # client.py - Update the get_properties method
+ def get_properties(self, config):
  """
- Get model parameters as a list of NumPy arrays.
+ Return dataset statistics before training starts.
+
+ NOTE: Only scalar values (int, float, str, bool, bytes) are allowed.
+ Lists will cause the “not a 1:1 mapping” TypeError.
  """
- return [val.cpu().numpy() for val in self.mo]()
+ num_samples = len(self.X_tensor)
+ num_features = self.X_tensor.shape[1]
+ # Convert list to comma‑separated string
+ column_names = [f"feature_{i}" for i in range(num_features)]
+ columns_str = ",".join(column_names)
+
+ hostname = socket.gethostname()
+ os_type = platform.system()
+
+ return {
+ "hostname": hostname,
+ "os_type": os_type,
+ "num_samples": num_samples,
+ "num_features": num_features,
+ "columns": columns_str, # now a single str, not a list
+ }
+
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Flower client with DP support and metrics")
+ parser.add_argument(
+ "--server_address",
+ type=str,
+ required=True,
+ help="Address of the Flower server (e.g., 127.0.0.1:8080)",
+ )
+ parser.add_argument(
+ "--data_path",
+ type=str,
+ default="data/data.csv",
+ help="Path to your CSV training data",
+ )
+ parser.add_argument(
+ "--dp",
+ action="store_true",
+ help="Enable differential privacy",
+ )
+ parser.add_argument(
+ "--noise_multiplier",
+ type=float,
+ default=1.0,
+ help="Noise multiplier for DP",
+ )
+ parser.add_argument(
+ "--max_grad_norm",
+ type=float,
+ default=1.0,
+ help="Clip norm for DP",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=32,
+ help="Batch size for training",
+ )
+ args = parser.parse_args()
+
+ dp_config = None
+ if args.dp:
+ dp_config = DPConfig(
+ noise_multiplier=args.noise_multiplier,
+ max_grad_norm=args.max_grad_norm,
+ batch_size=args.batch_size,
+ )
+
+ client = FlowerClient(
+ server_address=args.server_address,
+ data_path=args.data_path,
+ dp_config=dp_config,
+ )
+ fl.client.start_numpy_client(server_address=client.server_address, client=client)
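
The note inside get_properties is worth spelling out: Flower's properties (and metrics) dictionaries must map strings to scalar values (bool, bytes, float, int, str), so a list such as column names has to be flattened, for example into a comma-separated string, and split again wherever it is consumed. A small sketch of that round trip (the names are hypothetical):

    # Client side: only scalars are allowed in the properties dict.
    column_names = ["age", "bmi", "label"]          # hypothetical feature names
    properties = {
        "num_features": len(column_names),
        "columns": ",".join(column_names),          # list flattened to a single str
    }

    # Server side (or any other consumer): recover the list from the scalar.
    recovered = properties["columns"].split(",")
    assert recovered == column_names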
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: MEDfl
- Version: 2.0.4.dev3
+ Version: 2.0.4.dev5
  Summary: Python Open-source package for simulating federated learning and differential privacy
  Home-page: https://github.com/MEDomics-UdeS/MEDfl
  Author: MEDomics consortium
@@ -28,13 +28,14 @@ Requires-Dist: Sphinx~=5.3.0
  Requires-Dist: SQLAlchemy~=1.4.47
  Requires-Dist: torch>=2.0.0
  Requires-Dist: datetime~=5.1
- Requires-Dist: scikit-learn~=1.6.1
+ Requires-Dist: scikit-learn~=1.2.2
  Requires-Dist: sphinx-jsonschema==1.19.1
  Requires-Dist: sphinx-rtd-dark-mode==1.2.4
  Requires-Dist: plotly==5.19.0
  Requires-Dist: optuna==3.5.0
  Requires-Dist: mysql-connector-python~=9.3.0
  Requires-Dist: seaborn~=0.13.2
+ Requires-Dist: flwr[simulation]

  # MEDfl: Federated Learning and Differential Privacy Simulation Tool for Tabular Data
  ![Python Versions](https://img.shields.io/badge/python-3.9-blue)
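
On the dependency changes: ~=1.2.2 is a PEP 440 compatible-release specifier, equivalent to >=1.2.2,<1.3, so this pin moves scikit-learn from the 1.6 series back to the 1.2 series, and the new flwr[simulation] requirement pulls in Flower together with its simulation extra. A quick check of what the specifier admits, using the packaging library:

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet("~=1.2.2")   # compatible release: >=1.2.2, <1.3
    print(spec.contains("1.2.5"))    # True
    print(spec.contains("1.3.0"))    # False
    print(spec.contains("1.6.1"))    # False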
@@ -4,15 +4,15 @@ MEDfl/LearningManager/client.py,sha256=9Y_Zb0yxvCxx3dVCPQ1bXS5mCKasylSBnoVj-RDN2
  MEDfl/LearningManager/dynamicModal.py,sha256=q8u7xPpj_TdZnSr8kYj0Xx7Sdz-diXsKBAfVce8-qSU,10534
  MEDfl/LearningManager/federated_dataset.py,sha256=InsZ5Rys2dgqaPxVyP5G3TrJMwiCNHOoTd3tCpUwUVM,2081
  MEDfl/LearningManager/flpipeline.py,sha256=5lT2uod5EqnkRQ04cgm0gYyZz0djumfIYipCrzX1fdo,7111
- MEDfl/LearningManager/model.py,sha256=vp8FIMxBdz3FTF5wJaea2IO_WGeANLZgBxTKVe3gW3Q,7456
+ MEDfl/LearningManager/model.py,sha256=OZcVzIJ_Q5JMF8TfvqtY34FgeNZOqTocENmDKppQEmA,7537
  MEDfl/LearningManager/params.yaml,sha256=Ix1cNtlWr3vDC0te6pipl5w8iLADO6dZvwm633-VaIA,436
  MEDfl/LearningManager/params_optimiser.py,sha256=8e0gCt4imwQHlNSJ3A2EAuc3wSr6yfSI6JDghohfmZQ,17618
  MEDfl/LearningManager/plot.py,sha256=A6Z8wC8J-H-OmWBPKqwK5eiTB9vzOBGMaFv1SaNA9Js,7698
  MEDfl/LearningManager/server.py,sha256=oTgW3K1UT6m4SQBk23FIf23km_BDq9vvjeC6OgY8DNw,7077
  MEDfl/LearningManager/strategy.py,sha256=BHXpwmt7jx07y45YLUs8FZry2gYQbpiV4vNbHhsksQ4,3435
- MEDfl/LearningManager/utils.py,sha256=B4RULJp-puJr724O6teI0PxnUyPV8NG-uPC6jqaiDKI,9605
+ MEDfl/LearningManager/utils.py,sha256=zk84uKR_iSsoz1uvKD-l4_jJWV3_5pfUWMIeLNcT1mU,9621
  MEDfl/NetManager/__init__.py,sha256=OpgsIiBg7UA6Bfnu_kqGfEPxU8JfpPxSFU98TOeDTP0,273
- MEDfl/NetManager/database_connector.py,sha256=G8DAsD_pAIK1U67x3Q8gmSJGW7iJyxQ_NE5lWpT-P0Q,1474
+ MEDfl/NetManager/database_connector.py,sha256=L46vgTuQIzuk4PVWB7Mo1THX4aerIlIlAU5ZTIVGLs8,1553
  MEDfl/NetManager/dataset.py,sha256=HTV0jrJ4Qlhl2aSJzdFU1lkxGBKtmJ390eBpwfKf_4o,2777
  MEDfl/NetManager/flsetup.py,sha256=CVu_TIU7l3G6DDnwtY6JURbhIZk7gKC3unqWnU-YtlM,11434
  MEDfl/NetManager/net_helper.py,sha256=tyfxmpbleSdfPfo2ezKT0VOvZu660v9nhBuHCpl8pG4,6812
@@ -20,7 +20,7 @@ MEDfl/NetManager/net_manager_queries.py,sha256=j-CLQPjtTLyZuFPhIcwJStD7L7xtZpkmk
  MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6829
  MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
  MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- MEDfl/rw/client.py,sha256=AQnLvM-pWVA2Md0GiRLB9C4s8S3TjcjjvN3U25ftpX4,3463
+ MEDfl/rw/client.py,sha256=a9JpOSH2HItXsJKqLnvmEAbBcCKvgTTU526lff8pf-k,8013
  MEDfl/rw/model.py,sha256=OAoTmOw4zGWPa_ncDqNanLeucwWHmUydKED6zlB5Hps,1510
  MEDfl/rw/server.py,sha256=TmcWUwMJ5BG7owqIVTtTC6w8bR02SESRT9lXh7BqlOg,4986
  MEDfl/rw/strategy.py,sha256=aNlyQhHslmPJdiuJjsK9hu-IUJdrWR1yuGp7pNk4LeA,5974
@@ -29,8 +29,8 @@ MEDfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
  MEDfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
  alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
- MEDfl-2.0.4.dev3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
- MEDfl-2.0.4.dev3.dist-info/METADATA,sha256=s5S0Ztt85jbo0x9kzkGRQDNrI9cMCed-qBWO4EA2q10,4326
- MEDfl-2.0.4.dev3.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- MEDfl-2.0.4.dev3.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
- MEDfl-2.0.4.dev3.dist-info/RECORD,,
+ MEDfl-2.0.4.dev5.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+ MEDfl-2.0.4.dev5.dist-info/METADATA,sha256=rYlywNl8wr5CZiebY_qsA5iDemHvVyH6WOgou_-KAuw,4358
+ MEDfl-2.0.4.dev5.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ MEDfl-2.0.4.dev5.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
+ MEDfl-2.0.4.dev5.dist-info/RECORD,,