MEDfl 2.0.4.dev0__py3-none-any.whl → 2.0.4.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/rw/model.py +18 -74
- {medfl-2.0.4.dev0.dist-info → MEDfl-2.0.4.dev1.dist-info}/METADATA +2 -14
- {medfl-2.0.4.dev0.dist-info → MEDfl-2.0.4.dev1.dist-info}/RECORD +6 -6
- {medfl-2.0.4.dev0.dist-info → MEDfl-2.0.4.dev1.dist-info}/WHEEL +1 -1
- {medfl-2.0.4.dev0.dist-info/licenses → MEDfl-2.0.4.dev1.dist-info}/LICENSE +0 -0
- {medfl-2.0.4.dev0.dist-info → MEDfl-2.0.4.dev1.dist-info}/top_level.txt +0 -0
MEDfl/rw/model.py
CHANGED
@@ -1,75 +1,19 @@
|
|
1
|
-
# client.py
|
2
|
-
import argparse
|
3
|
-
import pandas as pd
|
4
|
-
import flwr as fl
|
5
|
-
import torch
|
6
1
|
import torch.nn as nn
|
7
|
-
import torch.
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
self.
|
13
|
-
|
14
|
-
|
15
|
-
self.
|
16
|
-
self.
|
17
|
-
self.
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
self.X_train = torch.tensor(X, dtype=torch.float32)
|
26
|
-
# If it's regression with single output; remove unsqueeze for multi-class
|
27
|
-
self.y_train = torch.tensor(y, dtype=torch.float32).unsqueeze(1)
|
28
|
-
|
29
|
-
def get_parameters(self, config):
|
30
|
-
return [val.cpu().numpy() for val in self.model.state_dict().values()]
|
31
|
-
|
32
|
-
def set_parameters(self, parameters):
|
33
|
-
params_dict = zip(self.model.state_dict().keys(), parameters)
|
34
|
-
state_dict = {k: torch.tensor(v) for k, v in params_dict}
|
35
|
-
self.model.load_state_dict(state_dict, strict=True)
|
36
|
-
|
37
|
-
def fit(self, parameters, config):
|
38
|
-
self.set_parameters(parameters)
|
39
|
-
self.model.train()
|
40
|
-
for _ in range(5):
|
41
|
-
self.optimizer.zero_grad()
|
42
|
-
preds = self.model(self.X_train)
|
43
|
-
loss = self.loss_fn(preds, self.y_train)
|
44
|
-
loss.backward()
|
45
|
-
self.optimizer.step()
|
46
|
-
# Return updated params, number of examples, and an empty metrics dict
|
47
|
-
return self.get_parameters(config), len(self.X_train), {}
|
48
|
-
|
49
|
-
def evaluate(self, parameters, config):
|
50
|
-
self.set_parameters(parameters)
|
51
|
-
self.model.eval()
|
52
|
-
with torch.no_grad():
|
53
|
-
preds = self.model(self.X_train)
|
54
|
-
loss = self.loss_fn(preds, self.y_train).item()
|
55
|
-
return float(loss), len(self.X_train), {}
|
56
|
-
|
57
|
-
if __name__ == "__main__":
|
58
|
-
parser = argparse.ArgumentParser(description="Flower client")
|
59
|
-
parser.add_argument(
|
60
|
-
"--server_address",
|
61
|
-
type=str,
|
62
|
-
required=True,
|
63
|
-
help="Address of the Flower server (e.g., 127.0.0.1:8080)",
|
64
|
-
)
|
65
|
-
parser.add_argument(
|
66
|
-
"--data_path",
|
67
|
-
type=str,
|
68
|
-
default="data/data.csv",
|
69
|
-
help="Path to your CSV training data",
|
70
|
-
)
|
71
|
-
args = parser.parse_args()
|
72
|
-
|
73
|
-
# Instantiate and start the client
|
74
|
-
client = FlowerClient(server_address=args.server_address, data_path=args.data_path)
|
75
|
-
fl.client.start_numpy_client(server_address=client.server_address, client=client)
|
2
|
+
import torch.nn.functional as F
|
3
|
+
|
4
|
+
class Net(nn.Module):
|
5
|
+
def __init__(self, input_dim):
|
6
|
+
super().__init__()
|
7
|
+
self.fc1 = nn.Linear(input_dim, 64)
|
8
|
+
self.fc2 = nn.Linear(64, 32)
|
9
|
+
self.fc3 = nn.Linear(32, 1)
|
10
|
+
self.dropout = nn.Dropout(0.3)
|
11
|
+
self.batchnorm1 = nn.BatchNorm1d(64)
|
12
|
+
self.batchnorm2 = nn.BatchNorm1d(32)
|
13
|
+
|
14
|
+
def forward(self, x):
|
15
|
+
x = F.relu(self.batchnorm1(self.fc1(x)))
|
16
|
+
x = self.dropout(x)
|
17
|
+
x = F.relu(self.batchnorm2(self.fc2(x)))
|
18
|
+
x = self.dropout(x)
|
19
|
+
return self.fc3(x) # raw logits for BCEWithLogitsLoss
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.1
|
2
2
|
Name: MEDfl
|
3
|
-
Version: 2.0.4.
|
3
|
+
Version: 2.0.4.dev1
|
4
4
|
Summary: Python Open-source package for simulating federated learning and differential privacy
|
5
5
|
Home-page: https://github.com/MEDomics-UdeS/MEDfl
|
6
6
|
Author: MEDomics consortium
|
@@ -35,18 +35,6 @@ Requires-Dist: plotly==5.19.0
|
|
35
35
|
Requires-Dist: optuna==3.5.0
|
36
36
|
Requires-Dist: mysql-connector-python~=9.3.0
|
37
37
|
Requires-Dist: seaborn~=0.13.2
|
38
|
-
Dynamic: author
|
39
|
-
Dynamic: author-email
|
40
|
-
Dynamic: classifier
|
41
|
-
Dynamic: description
|
42
|
-
Dynamic: description-content-type
|
43
|
-
Dynamic: home-page
|
44
|
-
Dynamic: keywords
|
45
|
-
Dynamic: license-file
|
46
|
-
Dynamic: project-url
|
47
|
-
Dynamic: requires-dist
|
48
|
-
Dynamic: requires-python
|
49
|
-
Dynamic: summary
|
50
38
|
|
51
39
|
# MEDfl: Federated Learning and Differential Privacy Simulation Tool for Tabular Data
|
52
40
|

|
@@ -21,7 +21,7 @@ MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6
|
|
21
21
|
MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
|
22
22
|
MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
23
|
MEDfl/rw/client.py,sha256=k8y8Wxh2KNe2oy5gRD-KXpTEGCYzp7X2oF5-Z6Rk1_E,1329
|
24
|
-
MEDfl/rw/model.py,sha256=
|
24
|
+
MEDfl/rw/model.py,sha256=TKCfE4nYx75uQdgABwEMkb_ynT-xS_MxNPbGAzJ3EcQ,629
|
25
25
|
MEDfl/rw/rwConfig.py,sha256=nK3Inv7v7Dm9gZnUnK5EqA4DmQ7TqiH4UoCZ8MlgFjA,823
|
26
26
|
MEDfl/rw/server.py,sha256=PiCrUTlnx7rVcO9DcT-vnJF5WkOCe4eEzWXeRSUBh10,3286
|
27
27
|
MEDfl/rw/strategy.py,sha256=sUwu0aAq6q3sKnfRimCRfps3be8s2iepGoD9NfcyjXI,6233
|
@@ -55,8 +55,8 @@ Medfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
|
|
55
55
|
Medfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
|
56
56
|
alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
57
57
|
alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
58
|
+
MEDfl-2.0.4.dev1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
59
|
+
MEDfl-2.0.4.dev1.dist-info/METADATA,sha256=rvSMS_MZCyEDMkL54d4tRwWtSCa3AIQ3wuj6yZbszOc,4326
|
60
|
+
MEDfl-2.0.4.dev1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
61
|
+
MEDfl-2.0.4.dev1.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
|
62
|
+
MEDfl-2.0.4.dev1.dist-info/RECORD,,
|
File without changes
|
File without changes
|