MEDfl 2.0.4.dev3__py3-none-any.whl → 2.0.4.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  import os
2
- import subprocess
2
+ import subprocess , sys
3
3
  from sqlalchemy import create_engine
4
4
  from configparser import ConfigParser
5
5
 
@@ -34,7 +34,9 @@ class DatabaseManager:
34
34
  create_db_script_path = os.path.join(current_directory, '..', 'scripts', 'create_db.py')
35
35
 
36
36
  # Execute the create_db.py script
37
- subprocess.run(['python3', create_db_script_path, path_to_csv], check=True)
37
+ print(sys.executable)
38
+ result = subprocess.run([sys.executable, create_db_script_path, path_to_csv],
39
+ capture_output=True, text=True)
38
40
 
39
41
  return
40
42
 
MEDfl/rw/client.py CHANGED
@@ -1,9 +1,4 @@
1
1
  # client.py
2
- """
3
- Federated Learning Client with Optional Differential Privacy.
4
-
5
- """
6
-
7
2
  import argparse
8
3
  import pandas as pd
9
4
  import flwr as fl
@@ -12,22 +7,21 @@ import torch.nn as nn
12
7
  import torch.optim as optim
13
8
  from torch.utils.data import TensorDataset, DataLoader
14
9
  from sklearn.metrics import accuracy_score, roc_auc_score
15
- from model import Net # Local model definition
10
+ from model import Net # your model definition in model.py
16
11
  import socket
17
- import platform
18
-
12
+ import platform
19
13
 
14
+ # Differential Privacy configuration class
20
15
  class DPConfig:
21
16
  """
22
- Configuration for Differential Privacy (DP) settings.
17
+ Configuration for differential privacy.
23
18
 
24
19
  Attributes:
25
20
  noise_multiplier (float): Noise multiplier for DP.
26
- max_grad_norm (float): Maximum gradient norm (clipping threshold).
21
+ max_grad_norm (float): Maximum gradient norm for clipping.
27
22
  batch_size (int): Batch size for training.
28
- secure_rng (bool): Whether to use a secure RNG for DP noise.
23
+ secure_rng (bool): Whether to use secure random generator.
29
24
  """
30
-
31
25
  def __init__(
32
26
  self,
33
27
  noise_multiplier: float = 1.0,
@@ -42,54 +36,44 @@ class DPConfig:
42
36
 
43
37
 
44
38
  class FlowerClient(fl.client.NumPyClient):
45
- """
46
- FlowerClient: A federated learning client that trains a PyTorch model
47
- and optionally applies differential privacy.
48
-
49
- """
50
-
51
39
  def __init__(
52
40
  self,
53
41
  server_address: str,
54
42
  data_path: str = "data/data.csv",
55
43
  dp_config: DPConfig = None,
56
44
  ):
57
- """
58
- Initialize client by loading data, creating model, optimizer,
59
- and optionally enabling DP.
60
-
61
- Args:
62
- server_address (str): Flower server address.
63
- data_path (str): Path to CSV dataset.
64
- dp_config (DPConfig): Optional DP configuration.
65
- """
66
45
  self.server_address = server_address
67
46
 
68
- # ---------- Load Data ----------
47
+ # 1. Load data from CSV
69
48
  df = pd.read_csv(data_path)
70
49
  X = df.iloc[:, :-1].values
71
50
  y = df.iloc[:, -1].values
72
51
 
52
+ # Convert to tensors and store for metrics
73
53
  self.X_tensor = torch.tensor(X, dtype=torch.float32)
74
54
  self.y_tensor = torch.tensor(y, dtype=torch.float32)
75
55
 
56
+ # Create DataLoader
76
57
  batch_size = dp_config.batch_size if dp_config else 32
77
58
  dataset = TensorDataset(self.X_tensor, self.y_tensor)
78
59
  self.train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
79
60
 
80
- # ---------- Model and Optimizer ----------
61
+ # 2. Initialize model, loss, optimizer
81
62
  input_dim = X.shape[1]
82
63
  self.model = Net(input_dim)
64
+
83
65
  self.criterion = nn.BCEWithLogitsLoss()
84
66
  self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
85
67
 
86
- # ---------- Differential Privacy ----------
68
+ # 3. Set up differential privacy if requested
87
69
  self.privacy_engine = None
88
70
  if dp_config is not None:
89
71
  try:
90
72
  from opacus import PrivacyEngine
91
73
 
74
+ # Instantiate engine without secure_rng argument
92
75
  self.privacy_engine = PrivacyEngine()
76
+ # Attach privacy engine with secure_rng flag
93
77
  self.model, self.optimizer, self.train_loader = self.privacy_engine.make_private(
94
78
  module=self.model,
95
79
  optimizer=self.optimizer,
@@ -99,10 +83,149 @@ class FlowerClient(fl.client.NumPyClient):
99
83
  secure_rng=dp_config.secure_rng,
100
84
  )
101
85
  except ImportError:
102
- print("⚠️ Opacus not installed, running without DP.")
86
+ print("Opacus not installed, running without DP.")
103
87
 
104
88
  def get_parameters(self, config):
89
+ return [val.cpu().numpy() for val in self.model.state_dict().values()]
90
+
91
+ def set_parameters(self, parameters):
92
+ params_dict = zip(self.model.state_dict().keys(), parameters)
93
+ state_dict = {k: torch.tensor(v) for k, v in params_dict}
94
+ self.model.load_state_dict(state_dict, strict=True)
95
+
96
+ def fit(self, parameters, config):
97
+ # Update model parameters
98
+ self.set_parameters(parameters)
99
+ self.model.train()
100
+
101
+ total_loss = 0.0
102
+ for X_batch, y_batch in self.train_loader:
103
+ self.optimizer.zero_grad()
104
+ outputs = self.model(X_batch)
105
+ loss = self.criterion(outputs.squeeze(), y_batch)
106
+ loss.backward()
107
+ self.optimizer.step()
108
+ total_loss += loss.item() * X_batch.size(0)
109
+
110
+ avg_loss = total_loss / len(self.train_loader.dataset)
111
+ # Compute training metrics (accuracy, AUC) on full dataset
112
+ with torch.no_grad():
113
+ logits = self.model(self.X_tensor).squeeze()
114
+ probs = torch.sigmoid(logits).cpu().numpy()
115
+ y_true = self.y_tensor.cpu().numpy()
116
+ binary_preds = (probs >= 0.5).astype(int)
117
+ acc = accuracy_score(y_true, binary_preds)
118
+ auc = roc_auc_score(y_true, probs)
119
+
120
+ hostname = socket.gethostname()
121
+ os_type = platform.system()
122
+
123
+ metrics = {"hostname": hostname, "train_loss": avg_loss, "train_accuracy": acc, "train_auc": auc , "os_type": os_type}
124
+ return self.get_parameters(config), len(self.train_loader.dataset), metrics
125
+
126
+ def evaluate(self, parameters, config):
127
+ # Update model parameters
128
+ self.set_parameters(parameters)
129
+ self.model.eval()
130
+
131
+ total_loss = 0.0
132
+ all_probs = []
133
+ all_true = []
134
+ with torch.no_grad():
135
+ for X_batch, y_batch in self.train_loader:
136
+ outputs = self.model(X_batch)
137
+ loss = self.criterion(outputs.squeeze(), y_batch)
138
+ total_loss += loss.item() * X_batch.size(0)
139
+ probs = torch.sigmoid(outputs.squeeze()).cpu().numpy()
140
+ all_probs.extend(probs.tolist())
141
+ all_true.extend(y_batch.cpu().numpy().tolist())
142
+
143
+ avg_loss = total_loss / len(self.train_loader.dataset)
144
+ binary_preds = [1 if p >= 0.5 else 0 for p in all_probs]
145
+ acc = accuracy_score(all_true, binary_preds)
146
+ auc = roc_auc_score(all_true, all_probs)
147
+
148
+ metrics = {"eval_loss": avg_loss, "eval_accuracy": acc, "eval_auc": auc}
149
+ print(f"Evaluation metrics: {metrics}")
150
+ return float(avg_loss), len(self.train_loader.dataset), metrics
151
+
152
+ # client.py - Update the get_properties method
153
+ def get_properties(self, config):
105
154
  """
106
- Get model parameters as a list of NumPy arrays.
155
+ Return dataset statistics before training starts.
156
+
157
+ NOTE: Only scalar values (int, float, str, bool, bytes) are allowed.
158
+ Lists will cause the “not a 1:1 mapping” TypeError.
107
159
  """
108
- return [val.cpu().numpy() for val in self.mo]()
160
+ num_samples = len(self.X_tensor)
161
+ num_features = self.X_tensor.shape[1]
162
+ # Convert list to comma‑separated string
163
+ column_names = [f"feature_{i}" for i in range(num_features)]
164
+ columns_str = ",".join(column_names)
165
+
166
+ hostname = socket.gethostname()
167
+ os_type = platform.system()
168
+
169
+ return {
170
+ "hostname": hostname,
171
+ "os_type": os_type,
172
+ "num_samples": num_samples,
173
+ "num_features": num_features,
174
+ "columns": columns_str, # now a single str, not a list
175
+ }
176
+
177
+
178
+
179
+ if __name__ == "__main__":
180
+ parser = argparse.ArgumentParser(description="Flower client with DP support and metrics")
181
+ parser.add_argument(
182
+ "--server_address",
183
+ type=str,
184
+ required=True,
185
+ help="Address of the Flower server (e.g., 127.0.0.1:8080)",
186
+ )
187
+ parser.add_argument(
188
+ "--data_path",
189
+ type=str,
190
+ default="data/data.csv",
191
+ help="Path to your CSV training data",
192
+ )
193
+ parser.add_argument(
194
+ "--dp",
195
+ action="store_true",
196
+ help="Enable differential privacy",
197
+ )
198
+ parser.add_argument(
199
+ "--noise_multiplier",
200
+ type=float,
201
+ default=1.0,
202
+ help="Noise multiplier for DP",
203
+ )
204
+ parser.add_argument(
205
+ "--max_grad_norm",
206
+ type=float,
207
+ default=1.0,
208
+ help="Clip norm for DP",
209
+ )
210
+ parser.add_argument(
211
+ "--batch_size",
212
+ type=int,
213
+ default=32,
214
+ help="Batch size for training",
215
+ )
216
+ args = parser.parse_args()
217
+
218
+ dp_config = None
219
+ if args.dp:
220
+ dp_config = DPConfig(
221
+ noise_multiplier=args.noise_multiplier,
222
+ max_grad_norm=args.max_grad_norm,
223
+ batch_size=args.batch_size,
224
+ )
225
+
226
+ client = FlowerClient(
227
+ server_address=args.server_address,
228
+ data_path=args.data_path,
229
+ dp_config=dp_config,
230
+ )
231
+ fl.client.start_numpy_client(server_address=client.server_address, client=client)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: MEDfl
3
- Version: 2.0.4.dev3
3
+ Version: 2.0.4.dev4
4
4
  Summary: Python Open-source package for simulating federated learning and differential privacy
5
5
  Home-page: https://github.com/MEDomics-UdeS/MEDfl
6
6
  Author: MEDomics consortium
@@ -12,7 +12,7 @@ MEDfl/LearningManager/server.py,sha256=oTgW3K1UT6m4SQBk23FIf23km_BDq9vvjeC6OgY8D
12
12
  MEDfl/LearningManager/strategy.py,sha256=BHXpwmt7jx07y45YLUs8FZry2gYQbpiV4vNbHhsksQ4,3435
13
13
  MEDfl/LearningManager/utils.py,sha256=B4RULJp-puJr724O6teI0PxnUyPV8NG-uPC6jqaiDKI,9605
14
14
  MEDfl/NetManager/__init__.py,sha256=OpgsIiBg7UA6Bfnu_kqGfEPxU8JfpPxSFU98TOeDTP0,273
15
- MEDfl/NetManager/database_connector.py,sha256=G8DAsD_pAIK1U67x3Q8gmSJGW7iJyxQ_NE5lWpT-P0Q,1474
15
+ MEDfl/NetManager/database_connector.py,sha256=L46vgTuQIzuk4PVWB7Mo1THX4aerIlIlAU5ZTIVGLs8,1553
16
16
  MEDfl/NetManager/dataset.py,sha256=HTV0jrJ4Qlhl2aSJzdFU1lkxGBKtmJ390eBpwfKf_4o,2777
17
17
  MEDfl/NetManager/flsetup.py,sha256=CVu_TIU7l3G6DDnwtY6JURbhIZk7gKC3unqWnU-YtlM,11434
18
18
  MEDfl/NetManager/net_helper.py,sha256=tyfxmpbleSdfPfo2ezKT0VOvZu660v9nhBuHCpl8pG4,6812
@@ -20,7 +20,7 @@ MEDfl/NetManager/net_manager_queries.py,sha256=j-CLQPjtTLyZuFPhIcwJStD7L7xtZpkmk
20
20
  MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6829
21
21
  MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
22
22
  MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- MEDfl/rw/client.py,sha256=AQnLvM-pWVA2Md0GiRLB9C4s8S3TjcjjvN3U25ftpX4,3463
23
+ MEDfl/rw/client.py,sha256=a9JpOSH2HItXsJKqLnvmEAbBcCKvgTTU526lff8pf-k,8013
24
24
  MEDfl/rw/model.py,sha256=OAoTmOw4zGWPa_ncDqNanLeucwWHmUydKED6zlB5Hps,1510
25
25
  MEDfl/rw/server.py,sha256=TmcWUwMJ5BG7owqIVTtTC6w8bR02SESRT9lXh7BqlOg,4986
26
26
  MEDfl/rw/strategy.py,sha256=aNlyQhHslmPJdiuJjsK9hu-IUJdrWR1yuGp7pNk4LeA,5974
@@ -29,8 +29,8 @@ MEDfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
29
29
  MEDfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
30
30
  alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
31
  alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
32
- MEDfl-2.0.4.dev3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
33
- MEDfl-2.0.4.dev3.dist-info/METADATA,sha256=s5S0Ztt85jbo0x9kzkGRQDNrI9cMCed-qBWO4EA2q10,4326
34
- MEDfl-2.0.4.dev3.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
35
- MEDfl-2.0.4.dev3.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
36
- MEDfl-2.0.4.dev3.dist-info/RECORD,,
32
+ MEDfl-2.0.4.dev4.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
33
+ MEDfl-2.0.4.dev4.dist-info/METADATA,sha256=0nPl8L3hF9N_bUQ7SD8I6lDo0ZJcR1DrYU1RIdFbKOg,4326
34
+ MEDfl-2.0.4.dev4.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
35
+ MEDfl-2.0.4.dev4.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
36
+ MEDfl-2.0.4.dev4.dist-info/RECORD,,