MEDfl 2.0.4__py3-none-any.whl → 2.0.4.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
MEDfl/rw/model.py CHANGED
@@ -1,75 +1,19 @@
1
- # client.py
2
- import argparse
3
- import pandas as pd
4
- import flwr as fl
5
- import torch
6
1
  import torch.nn as nn
7
- import torch.optim as optim
8
- from MEDfl.rw.model import Net # your model definition in model.py
9
-
10
- class FlowerClient(fl.client.NumPyClient):
11
- def __init__(self, server_address: str, data_path: str = "data/data.csv"):
12
- self.server_address = server_address
13
-
14
- # 1. Load model
15
- self.model = Net()
16
- self.loss_fn = nn.MSELoss()
17
- self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
18
-
19
- # 2. Load data from CSV
20
- df = pd.read_csv(data_path)
21
- # Assume last column is label
22
- X = df.iloc[:, :-1].values
23
- y = df.iloc[:, -1].values
24
-
25
- self.X_train = torch.tensor(X, dtype=torch.float32)
26
- # If it's regression with single output; remove unsqueeze for multi-class
27
- self.y_train = torch.tensor(y, dtype=torch.float32).unsqueeze(1)
28
-
29
- def get_parameters(self, config):
30
- return [val.cpu().numpy() for val in self.model.state_dict().values()]
31
-
32
- def set_parameters(self, parameters):
33
- params_dict = zip(self.model.state_dict().keys(), parameters)
34
- state_dict = {k: torch.tensor(v) for k, v in params_dict}
35
- self.model.load_state_dict(state_dict, strict=True)
36
-
37
- def fit(self, parameters, config):
38
- self.set_parameters(parameters)
39
- self.model.train()
40
- for _ in range(5):
41
- self.optimizer.zero_grad()
42
- preds = self.model(self.X_train)
43
- loss = self.loss_fn(preds, self.y_train)
44
- loss.backward()
45
- self.optimizer.step()
46
- # Return updated params, number of examples, and an empty metrics dict
47
- return self.get_parameters(config), len(self.X_train), {}
48
-
49
- def evaluate(self, parameters, config):
50
- self.set_parameters(parameters)
51
- self.model.eval()
52
- with torch.no_grad():
53
- preds = self.model(self.X_train)
54
- loss = self.loss_fn(preds, self.y_train).item()
55
- return float(loss), len(self.X_train), {}
56
-
57
- if __name__ == "__main__":
58
- parser = argparse.ArgumentParser(description="Flower client")
59
- parser.add_argument(
60
- "--server_address",
61
- type=str,
62
- required=True,
63
- help="Address of the Flower server (e.g., 127.0.0.1:8080)",
64
- )
65
- parser.add_argument(
66
- "--data_path",
67
- type=str,
68
- default="data/data.csv",
69
- help="Path to your CSV training data",
70
- )
71
- args = parser.parse_args()
72
-
73
- # Instantiate and start the client
74
- client = FlowerClient(server_address=args.server_address, data_path=args.data_path)
75
- fl.client.start_numpy_client(server_address=client.server_address, client=client)
2
+ import torch.nn.functional as F
3
+
4
+ class Net(nn.Module):
5
+ def __init__(self, input_dim):
6
+ super().__init__()
7
+ self.fc1 = nn.Linear(input_dim, 64)
8
+ self.fc2 = nn.Linear(64, 32)
9
+ self.fc3 = nn.Linear(32, 1)
10
+ self.dropout = nn.Dropout(0.3)
11
+ self.batchnorm1 = nn.BatchNorm1d(64)
12
+ self.batchnorm2 = nn.BatchNorm1d(32)
13
+
14
+ def forward(self, x):
15
+ x = F.relu(self.batchnorm1(self.fc1(x)))
16
+ x = self.dropout(x)
17
+ x = F.relu(self.batchnorm2(self.fc2(x)))
18
+ x = self.dropout(x)
19
+ return self.fc3(x) # raw logits for BCEWithLogitsLoss
MEDfl/rw/strategy.py CHANGED
@@ -108,7 +108,7 @@ class Strategy:
108
108
  # Print individual client fit metrics
109
109
  print(f"\n[Server] 🔄 Round {rnd} - Client Training Metrics:")
110
110
  for i, (client_id, fit_res) in enumerate(results):
111
- print(f" Client {client_id.cid}: {fit_res.metrics}")
111
+ print(f" CTM Round {rnd} Client:{client_id.cid}: {fit_res.metrics}")
112
112
 
113
113
  # Call original aggregation function
114
114
  aggregated_params, metrics = original_agg_fit(rnd, results, failures)
@@ -125,7 +125,7 @@ class Strategy:
125
125
  # Print individual client evaluation metrics
126
126
  print(f"\n[Server] 📊 Round {rnd} - Client Evaluation Metrics:")
127
127
  for i, (client_id, eval_res) in enumerate(results):
128
- print(f" Client {client_id.cid}: {eval_res.metrics}")
128
+ print(f" CEM Round {rnd} Client:{client_id.cid}: {eval_res.metrics}")
129
129
 
130
130
  # Call original aggregation function
131
131
  loss, metrics = original_agg_eval(rnd, results, failures)
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.4
1
+ Metadata-Version: 2.1
2
2
  Name: MEDfl
3
- Version: 2.0.4
3
+ Version: 2.0.4.dev1
4
4
  Summary: Python Open-source package for simulating federated learning and differential privacy
5
5
  Home-page: https://github.com/MEDomics-UdeS/MEDfl
6
6
  Author: MEDomics consortium
@@ -35,18 +35,6 @@ Requires-Dist: plotly==5.19.0
35
35
  Requires-Dist: optuna==3.5.0
36
36
  Requires-Dist: mysql-connector-python~=9.3.0
37
37
  Requires-Dist: seaborn~=0.13.2
38
- Dynamic: author
39
- Dynamic: author-email
40
- Dynamic: classifier
41
- Dynamic: description
42
- Dynamic: description-content-type
43
- Dynamic: home-page
44
- Dynamic: keywords
45
- Dynamic: license-file
46
- Dynamic: project-url
47
- Dynamic: requires-dist
48
- Dynamic: requires-python
49
- Dynamic: summary
50
38
 
51
39
  # MEDfl: Federated Learning and Differential Privacy Simulation Tool for Tabular Data
52
40
  ![Python Versions](https://img.shields.io/badge/python-3.9-blue)
@@ -21,10 +21,10 @@ MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6
21
21
  MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
22
22
  MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  MEDfl/rw/client.py,sha256=k8y8Wxh2KNe2oy5gRD-KXpTEGCYzp7X2oF5-Z6Rk1_E,1329
24
- MEDfl/rw/model.py,sha256=a3mrACDqg4K1V4Qyhh0PBEmWL5SpUYsiEarWHwsVcGk,2704
24
+ MEDfl/rw/model.py,sha256=TKCfE4nYx75uQdgABwEMkb_ynT-xS_MxNPbGAzJ3EcQ,629
25
25
  MEDfl/rw/rwConfig.py,sha256=nK3Inv7v7Dm9gZnUnK5EqA4DmQ7TqiH4UoCZ8MlgFjA,823
26
26
  MEDfl/rw/server.py,sha256=PiCrUTlnx7rVcO9DcT-vnJF5WkOCe4eEzWXeRSUBh10,3286
27
- MEDfl/rw/strategy.py,sha256=DbgitZzW9Hu7rUnkHdRCq8sKkOGLjMMHPsHRIEVZXwM,6203
27
+ MEDfl/rw/strategy.py,sha256=sUwu0aAq6q3sKnfRimCRfps3be8s2iepGoD9NfcyjXI,6233
28
28
  MEDfl/rw/verbose_server.py,sha256=B_abnpCy43e3YrjotLFOm7cLiuiB5PSTeXD5sMP0CxA,851
29
29
  MEDfl/scripts/__init__.py,sha256=Pq1weevsPaU7MRMHfBYeyT0EOFeWLeVM6Y1DVz6jw1A,48
30
30
  MEDfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
@@ -55,8 +55,8 @@ Medfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
55
55
  Medfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
56
56
  alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
57
57
  alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
58
- medfl-2.0.4.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
59
- medfl-2.0.4.dist-info/METADATA,sha256=hFoRqPRyKiqNCbBw3RcrYEh9fSRlAZChdEGaoNrLIGQ,4579
60
- medfl-2.0.4.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
61
- medfl-2.0.4.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
62
- medfl-2.0.4.dist-info/RECORD,,
58
+ MEDfl-2.0.4.dev1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
59
+ MEDfl-2.0.4.dev1.dist-info/METADATA,sha256=rvSMS_MZCyEDMkL54d4tRwWtSCa3AIQ3wuj6yZbszOc,4326
60
+ MEDfl-2.0.4.dev1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
61
+ MEDfl-2.0.4.dev1.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
62
+ MEDfl-2.0.4.dev1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.1)
2
+ Generator: bdist_wheel (0.45.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5