MEDfl 2.0.4.dev3__py3-none-any.whl → 2.0.4.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/NetManager/database_connector.py +4 -2
- MEDfl/rw/client.py +156 -33
- {MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/METADATA +1 -1
- {MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/RECORD +7 -7
- {MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/LICENSE +0 -0
- {MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/WHEEL +0 -0
- {MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/top_level.txt +0 -0
MEDfl/NetManager/database_connector.py CHANGED
@@ -1,5 +1,5 @@
 import os
-import subprocess
+import subprocess , sys
 from sqlalchemy import create_engine
 from configparser import ConfigParser

@@ -34,7 +34,9 @@ class DatabaseManager:
         create_db_script_path = os.path.join(current_directory, '..', 'scripts', 'create_db.py')

         # Execute the create_db.py script
-
+        print(sys.executable)
+        result = subprocess.run([sys.executable, create_db_script_path, path_to_csv],
+                                capture_output=True, text=True)

         return

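For context, the database_connector.py change launches the create_db.py helper with subprocess.run([sys.executable, ...]), so the script runs under the same Python interpreter and environment as MEDfl itself, with its output captured. A minimal standalone sketch of that pattern follows; the script and CSV paths here are placeholders, not MEDfl's actual values:

import subprocess
import sys

# Placeholder paths for illustration only.
script_path = "scripts/create_db.py"
csv_path = "data/data.csv"

# Run the helper with the interpreter executing this process, so the child
# sees the same virtualenv/conda environment; capture stdout/stderr as text.
result = subprocess.run(
    [sys.executable, script_path, csv_path],
    capture_output=True,
    text=True,
)
print(result.returncode)
print(result.stdout)
print(result.stderr)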
MEDfl/rw/client.py CHANGED
@@ -1,9 +1,4 @@
 # client.py
-"""
-Federated Learning Client with Optional Differential Privacy.
-
-"""
-
 import argparse
 import pandas as pd
 import flwr as fl
@@ -12,22 +7,21 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import TensorDataset, DataLoader
 from sklearn.metrics import accuracy_score, roc_auc_score
-from model import Net #
+from model import Net # your model definition in model.py
 import socket
-import platform
-
+import platform

+# Differential Privacy configuration class
 class DPConfig:
     """
-    Configuration for
+    Configuration for differential privacy.

     Attributes:
         noise_multiplier (float): Noise multiplier for DP.
-        max_grad_norm (float): Maximum gradient norm
+        max_grad_norm (float): Maximum gradient norm for clipping.
         batch_size (int): Batch size for training.
-        secure_rng (bool): Whether to use
+        secure_rng (bool): Whether to use secure random generator.
     """
-
     def __init__(
         self,
         noise_multiplier: float = 1.0,
@@ -42,54 +36,44 @@ class DPConfig:


 class FlowerClient(fl.client.NumPyClient):
-    """
-    FlowerClient: A federated learning client that trains a PyTorch model
-    and optionally applies differential privacy.
-
-    """
-
     def __init__(
         self,
         server_address: str,
         data_path: str = "data/data.csv",
         dp_config: DPConfig = None,
     ):
-        """
-        Initialize client by loading data, creating model, optimizer,
-        and optionally enabling DP.
-
-        Args:
-            server_address (str): Flower server address.
-            data_path (str): Path to CSV dataset.
-            dp_config (DPConfig): Optional DP configuration.
-        """
         self.server_address = server_address

-        #
+        # 1. Load data from CSV
         df = pd.read_csv(data_path)
         X = df.iloc[:, :-1].values
         y = df.iloc[:, -1].values

+        # Convert to tensors and store for metrics
         self.X_tensor = torch.tensor(X, dtype=torch.float32)
         self.y_tensor = torch.tensor(y, dtype=torch.float32)

+        # Create DataLoader
         batch_size = dp_config.batch_size if dp_config else 32
         dataset = TensorDataset(self.X_tensor, self.y_tensor)
         self.train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

-        #
+        # 2. Initialize model, loss, optimizer
         input_dim = X.shape[1]
         self.model = Net(input_dim)
+
         self.criterion = nn.BCEWithLogitsLoss()
         self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)

-        #
+        # 3. Set up differential privacy if requested
         self.privacy_engine = None
         if dp_config is not None:
             try:
                 from opacus import PrivacyEngine

+                # Instantiate engine without secure_rng argument
                 self.privacy_engine = PrivacyEngine()
+                # Attach privacy engine with secure_rng flag
                 self.model, self.optimizer, self.train_loader = self.privacy_engine.make_private(
                     module=self.model,
                     optimizer=self.optimizer,
@@ -99,10 +83,149 @@ class FlowerClient(fl.client.NumPyClient):
                     secure_rng=dp_config.secure_rng,
                 )
             except ImportError:
-                print("
+                print("Opacus not installed, running without DP.")

     def get_parameters(self, config):
+        return [val.cpu().numpy() for val in self.model.state_dict().values()]
+
+    def set_parameters(self, parameters):
+        params_dict = zip(self.model.state_dict().keys(), parameters)
+        state_dict = {k: torch.tensor(v) for k, v in params_dict}
+        self.model.load_state_dict(state_dict, strict=True)
+
+    def fit(self, parameters, config):
+        # Update model parameters
+        self.set_parameters(parameters)
+        self.model.train()
+
+        total_loss = 0.0
+        for X_batch, y_batch in self.train_loader:
+            self.optimizer.zero_grad()
+            outputs = self.model(X_batch)
+            loss = self.criterion(outputs.squeeze(), y_batch)
+            loss.backward()
+            self.optimizer.step()
+            total_loss += loss.item() * X_batch.size(0)
+
+        avg_loss = total_loss / len(self.train_loader.dataset)
+        # Compute training metrics (accuracy, AUC) on full dataset
+        with torch.no_grad():
+            logits = self.model(self.X_tensor).squeeze()
+            probs = torch.sigmoid(logits).cpu().numpy()
+            y_true = self.y_tensor.cpu().numpy()
+            binary_preds = (probs >= 0.5).astype(int)
+            acc = accuracy_score(y_true, binary_preds)
+            auc = roc_auc_score(y_true, probs)
+
+        hostname = socket.gethostname()
+        os_type = platform.system()
+
+        metrics = {"hostname": hostname, "train_loss": avg_loss, "train_accuracy": acc, "train_auc": auc , "os_type": os_type}
+        return self.get_parameters(config), len(self.train_loader.dataset), metrics
+
+    def evaluate(self, parameters, config):
+        # Update model parameters
+        self.set_parameters(parameters)
+        self.model.eval()
+
+        total_loss = 0.0
+        all_probs = []
+        all_true = []
+        with torch.no_grad():
+            for X_batch, y_batch in self.train_loader:
+                outputs = self.model(X_batch)
+                loss = self.criterion(outputs.squeeze(), y_batch)
+                total_loss += loss.item() * X_batch.size(0)
+                probs = torch.sigmoid(outputs.squeeze()).cpu().numpy()
+                all_probs.extend(probs.tolist())
+                all_true.extend(y_batch.cpu().numpy().tolist())
+
+        avg_loss = total_loss / len(self.train_loader.dataset)
+        binary_preds = [1 if p >= 0.5 else 0 for p in all_probs]
+        acc = accuracy_score(all_true, binary_preds)
+        auc = roc_auc_score(all_true, all_probs)
+
+        metrics = {"eval_loss": avg_loss, "eval_accuracy": acc, "eval_auc": auc}
+        print(f"Evaluation metrics: {metrics}")
+        return float(avg_loss), len(self.train_loader.dataset), metrics
+
+    # client.py - Update the get_properties method
+    def get_properties(self, config):
         """
-
+        Return dataset statistics before training starts.
+
+        NOTE: Only scalar values (int, float, str, bool, bytes) are allowed.
+        Lists will cause the "not a 1:1 mapping" TypeError.
         """
-
+        num_samples = len(self.X_tensor)
+        num_features = self.X_tensor.shape[1]
+        # Convert list to comma-separated string
+        column_names = [f"feature_{i}" for i in range(num_features)]
+        columns_str = ",".join(column_names)
+
+        hostname = socket.gethostname()
+        os_type = platform.system()
+
+        return {
+            "hostname": hostname,
+            "os_type": os_type,
+            "num_samples": num_samples,
+            "num_features": num_features,
+            "columns": columns_str,  # now a single str, not a list
+        }
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Flower client with DP support and metrics")
+    parser.add_argument(
+        "--server_address",
+        type=str,
+        required=True,
+        help="Address of the Flower server (e.g., 127.0.0.1:8080)",
+    )
+    parser.add_argument(
+        "--data_path",
+        type=str,
+        default="data/data.csv",
+        help="Path to your CSV training data",
+    )
+    parser.add_argument(
+        "--dp",
+        action="store_true",
+        help="Enable differential privacy",
+    )
+    parser.add_argument(
+        "--noise_multiplier",
+        type=float,
+        default=1.0,
+        help="Noise multiplier for DP",
+    )
+    parser.add_argument(
+        "--max_grad_norm",
+        type=float,
+        default=1.0,
+        help="Clip norm for DP",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=32,
+        help="Batch size for training",
+    )
+    args = parser.parse_args()
+
+    dp_config = None
+    if args.dp:
+        dp_config = DPConfig(
+            noise_multiplier=args.noise_multiplier,
+            max_grad_norm=args.max_grad_norm,
+            batch_size=args.batch_size,
+        )
+
+    client = FlowerClient(
+        server_address=args.server_address,
+        data_path=args.data_path,
+        dp_config=dp_config,
+    )
+    fl.client.start_numpy_client(server_address=client.server_address, client=client)
{MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/RECORD CHANGED
@@ -12,7 +12,7 @@ MEDfl/LearningManager/server.py,sha256=oTgW3K1UT6m4SQBk23FIf23km_BDq9vvjeC6OgY8D
 MEDfl/LearningManager/strategy.py,sha256=BHXpwmt7jx07y45YLUs8FZry2gYQbpiV4vNbHhsksQ4,3435
 MEDfl/LearningManager/utils.py,sha256=B4RULJp-puJr724O6teI0PxnUyPV8NG-uPC6jqaiDKI,9605
 MEDfl/NetManager/__init__.py,sha256=OpgsIiBg7UA6Bfnu_kqGfEPxU8JfpPxSFU98TOeDTP0,273
-MEDfl/NetManager/database_connector.py,sha256=
+MEDfl/NetManager/database_connector.py,sha256=L46vgTuQIzuk4PVWB7Mo1THX4aerIlIlAU5ZTIVGLs8,1553
 MEDfl/NetManager/dataset.py,sha256=HTV0jrJ4Qlhl2aSJzdFU1lkxGBKtmJ390eBpwfKf_4o,2777
 MEDfl/NetManager/flsetup.py,sha256=CVu_TIU7l3G6DDnwtY6JURbhIZk7gKC3unqWnU-YtlM,11434
 MEDfl/NetManager/net_helper.py,sha256=tyfxmpbleSdfPfo2ezKT0VOvZu660v9nhBuHCpl8pG4,6812
@@ -20,7 +20,7 @@ MEDfl/NetManager/net_manager_queries.py,sha256=j-CLQPjtTLyZuFPhIcwJStD7L7xtZpkmk
 MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6829
 MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
 MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-MEDfl/rw/client.py,sha256=
+MEDfl/rw/client.py,sha256=a9JpOSH2HItXsJKqLnvmEAbBcCKvgTTU526lff8pf-k,8013
 MEDfl/rw/model.py,sha256=OAoTmOw4zGWPa_ncDqNanLeucwWHmUydKED6zlB5Hps,1510
 MEDfl/rw/server.py,sha256=TmcWUwMJ5BG7owqIVTtTC6w8bR02SESRT9lXh7BqlOg,4986
 MEDfl/rw/strategy.py,sha256=aNlyQhHslmPJdiuJjsK9hu-IUJdrWR1yuGp7pNk4LeA,5974
@@ -29,8 +29,8 @@ MEDfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
 MEDfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
 alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
-MEDfl-2.0.4.
-MEDfl-2.0.4.
-MEDfl-2.0.4.
-MEDfl-2.0.4.
-MEDfl-2.0.4.
+MEDfl-2.0.4.dev4.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+MEDfl-2.0.4.dev4.dist-info/METADATA,sha256=0nPl8L3hF9N_bUQ7SD8I6lDo0ZJcR1DrYU1RIdFbKOg,4326
+MEDfl-2.0.4.dev4.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+MEDfl-2.0.4.dev4.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
+MEDfl-2.0.4.dev4.dist-info/RECORD,,
{MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/LICENSE: File without changes
{MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/WHEEL: File without changes
{MEDfl-2.0.4.dev3.dist-info → MEDfl-2.0.4.dev4.dist-info}/top_level.txt: File without changes