MEDfl 2.0.4.dev0__py3-none-any.whl → 2.0.4.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
MEDfl/rw/model.py CHANGED
@@ -1,75 +1,19 @@
1
- # client.py
2
- import argparse
3
- import pandas as pd
4
- import flwr as fl
5
- import torch
6
1
  import torch.nn as nn
7
- import torch.optim as optim
8
- from MEDfl.rw.model import Net # your model definition in model.py
9
-
10
- class FlowerClient(fl.client.NumPyClient):
11
- def __init__(self, server_address: str, data_path: str = "data/data.csv"):
12
- self.server_address = server_address
13
-
14
- # 1. Load model
15
- self.model = Net()
16
- self.loss_fn = nn.MSELoss()
17
- self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
18
-
19
- # 2. Load data from CSV
20
- df = pd.read_csv(data_path)
21
- # Assume last column is label
22
- X = df.iloc[:, :-1].values
23
- y = df.iloc[:, -1].values
24
-
25
- self.X_train = torch.tensor(X, dtype=torch.float32)
26
- # If it's regression with single output; remove unsqueeze for multi-class
27
- self.y_train = torch.tensor(y, dtype=torch.float32).unsqueeze(1)
28
-
29
- def get_parameters(self, config):
30
- return [val.cpu().numpy() for val in self.model.state_dict().values()]
31
-
32
- def set_parameters(self, parameters):
33
- params_dict = zip(self.model.state_dict().keys(), parameters)
34
- state_dict = {k: torch.tensor(v) for k, v in params_dict}
35
- self.model.load_state_dict(state_dict, strict=True)
36
-
37
- def fit(self, parameters, config):
38
- self.set_parameters(parameters)
39
- self.model.train()
40
- for _ in range(5):
41
- self.optimizer.zero_grad()
42
- preds = self.model(self.X_train)
43
- loss = self.loss_fn(preds, self.y_train)
44
- loss.backward()
45
- self.optimizer.step()
46
- # Return updated params, number of examples, and an empty metrics dict
47
- return self.get_parameters(config), len(self.X_train), {}
48
-
49
- def evaluate(self, parameters, config):
50
- self.set_parameters(parameters)
51
- self.model.eval()
52
- with torch.no_grad():
53
- preds = self.model(self.X_train)
54
- loss = self.loss_fn(preds, self.y_train).item()
55
- return float(loss), len(self.X_train), {}
56
-
57
- if __name__ == "__main__":
58
- parser = argparse.ArgumentParser(description="Flower client")
59
- parser.add_argument(
60
- "--server_address",
61
- type=str,
62
- required=True,
63
- help="Address of the Flower server (e.g., 127.0.0.1:8080)",
64
- )
65
- parser.add_argument(
66
- "--data_path",
67
- type=str,
68
- default="data/data.csv",
69
- help="Path to your CSV training data",
70
- )
71
- args = parser.parse_args()
72
-
73
- # Instantiate and start the client
74
- client = FlowerClient(server_address=args.server_address, data_path=args.data_path)
75
- fl.client.start_numpy_client(server_address=client.server_address, client=client)
2
+ import torch.nn.functional as F
3
+
4
+ class Net(nn.Module):
5
+ def __init__(self, input_dim):
6
+ super().__init__()
7
+ self.fc1 = nn.Linear(input_dim, 64)
8
+ self.fc2 = nn.Linear(64, 32)
9
+ self.fc3 = nn.Linear(32, 1)
10
+ self.dropout = nn.Dropout(0.3)
11
+ self.batchnorm1 = nn.BatchNorm1d(64)
12
+ self.batchnorm2 = nn.BatchNorm1d(32)
13
+
14
+ def forward(self, x):
15
+ x = F.relu(self.batchnorm1(self.fc1(x)))
16
+ x = self.dropout(x)
17
+ x = F.relu(self.batchnorm2(self.fc2(x)))
18
+ x = self.dropout(x)
19
+ return self.fc3(x) # raw logits for BCEWithLogitsLoss
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.4
1
+ Metadata-Version: 2.1
2
2
  Name: MEDfl
3
- Version: 2.0.4.dev0
3
+ Version: 2.0.4.dev1
4
4
  Summary: Python Open-source package for simulating federated learning and differential privacy
5
5
  Home-page: https://github.com/MEDomics-UdeS/MEDfl
6
6
  Author: MEDomics consortium
@@ -35,18 +35,6 @@ Requires-Dist: plotly==5.19.0
35
35
  Requires-Dist: optuna==3.5.0
36
36
  Requires-Dist: mysql-connector-python~=9.3.0
37
37
  Requires-Dist: seaborn~=0.13.2
38
- Dynamic: author
39
- Dynamic: author-email
40
- Dynamic: classifier
41
- Dynamic: description
42
- Dynamic: description-content-type
43
- Dynamic: home-page
44
- Dynamic: keywords
45
- Dynamic: license-file
46
- Dynamic: project-url
47
- Dynamic: requires-dist
48
- Dynamic: requires-python
49
- Dynamic: summary
50
38
 
51
39
  # MEDfl: Federated Learning and Differential Privacy Simulation Tool for Tabular Data
52
40
  ![Python Versions](https://img.shields.io/badge/python-3.9-blue)
@@ -21,7 +21,7 @@ MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6
21
21
  MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
22
22
  MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  MEDfl/rw/client.py,sha256=k8y8Wxh2KNe2oy5gRD-KXpTEGCYzp7X2oF5-Z6Rk1_E,1329
24
- MEDfl/rw/model.py,sha256=a3mrACDqg4K1V4Qyhh0PBEmWL5SpUYsiEarWHwsVcGk,2704
24
+ MEDfl/rw/model.py,sha256=TKCfE4nYx75uQdgABwEMkb_ynT-xS_MxNPbGAzJ3EcQ,629
25
25
  MEDfl/rw/rwConfig.py,sha256=nK3Inv7v7Dm9gZnUnK5EqA4DmQ7TqiH4UoCZ8MlgFjA,823
26
26
  MEDfl/rw/server.py,sha256=PiCrUTlnx7rVcO9DcT-vnJF5WkOCe4eEzWXeRSUBh10,3286
27
27
  MEDfl/rw/strategy.py,sha256=sUwu0aAq6q3sKnfRimCRfps3be8s2iepGoD9NfcyjXI,6233
@@ -55,8 +55,8 @@ Medfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
55
55
  Medfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
56
56
  alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
57
57
  alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
58
- medfl-2.0.4.dev0.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
59
- medfl-2.0.4.dev0.dist-info/METADATA,sha256=JjQfuGkhohyJ3a1WkkKtn-GicXO4tSbCQSqxK2AaTQ4,4584
60
- medfl-2.0.4.dev0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
61
- medfl-2.0.4.dev0.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
62
- medfl-2.0.4.dev0.dist-info/RECORD,,
58
+ MEDfl-2.0.4.dev1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
59
+ MEDfl-2.0.4.dev1.dist-info/METADATA,sha256=rvSMS_MZCyEDMkL54d4tRwWtSCa3AIQ3wuj6yZbszOc,4326
60
+ MEDfl-2.0.4.dev1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
61
+ MEDfl-2.0.4.dev1.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
62
+ MEDfl-2.0.4.dev1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.1)
2
+ Generator: bdist_wheel (0.45.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5