MEDfl 2.0.1__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
MEDfl/rw/__init__.py ADDED
File without changes
MEDfl/rw/client.py ADDED
@@ -0,0 +1,39 @@
1
+ import flwr as fl
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from model import Net
6
+
7
+ # Dummy training data
8
+ X_train = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
9
+ y_train = torch.tensor([[1.0], [0.0]])
10
+
11
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains ``Net`` on the module-level dummy data.

    The client reads the module-level tensors ``X_train`` / ``y_train`` and
    optimises the model with SGD (lr=0.01) against an MSE loss.
    """

    def __init__(self):
        self.model = Net()
        self.loss_fn = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)

    def get_parameters(self, config):
        """Return the model's state-dict values as a list of NumPy arrays."""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load ``parameters`` into the model.

        Arrays are paired with the model's state-dict keys in order;
        ``strict=True`` makes ``load_state_dict`` raise if a key is missing.
        """
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Run 5 full-batch SGD steps starting from ``parameters``.

        Returns:
            The updated parameters, the number of training examples, and an
            empty metrics dict.
        """
        self.set_parameters(parameters)
        self.model.train()
        for _ in range(5):
            self.optimizer.zero_grad()
            output = self.model(X_train)
            loss = self.loss_fn(output, y_train)
            loss.backward()
            self.optimizer.step()
        return self.get_parameters(config), len(X_train), {}

    def evaluate(self, parameters, config):
        """Compute the loss of the model described by ``parameters``.

        Previously this returned a constant 0.5 and ignored ``parameters``,
        so the reported loss never reflected the model. It now loads the
        parameters and measures the loss on the dummy data.

        Returns:
            The loss as a float, the number of examples, and an empty
            metrics dict.
        """
        self.set_parameters(parameters)
        self.model.eval()
        # No gradients are needed for evaluation.
        with torch.no_grad():
            output = self.model(X_train)
            loss = self.loss_fn(output, y_train).item()
        return float(loss), len(X_train), {}
38
+
39
# Connect to the Flower server when this module is executed.
# NOTE(review): the address is hard-coded and this call is not behind a
# ``__main__`` guard, so it also runs on import - confirm that is intended.
fl.client.start_numpy_client(
    client=FlowerClient(),
    server_address="100.65.215.27:8080",
)
MEDfl/rw/model.py ADDED
@@ -0,0 +1,75 @@
1
+ # client.py
2
+ import argparse
3
+ import pandas as pd
4
+ import flwr as fl
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from MEDfl.rw.model import Net # your model definition in model.py
9
+
10
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains ``Net`` on data loaded from a CSV file.

    Every CSV column except the last is used as a feature; the last column
    is used as the label.
    """

    def __init__(self, server_address: str, data_path: str = "data/data.csv"):
        self.server_address = server_address

        # Model, loss and optimizer.
        self.model = Net()
        self.loss_fn = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)

        # Training data from the CSV file.
        frame = pd.read_csv(data_path)
        features = frame.iloc[:, :-1].values
        labels = frame.iloc[:, -1].values

        self.X_train = torch.tensor(features, dtype=torch.float32)
        # unsqueeze(1) turns the label vector into a single-column target
        # (single-output regression); drop it for multi-class targets.
        self.y_train = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)

    def get_parameters(self, config):
        """Return the model's state-dict values as a list of NumPy arrays."""
        weights = []
        for tensor in self.model.state_dict().values():
            weights.append(tensor.cpu().numpy())
        return weights

    def set_parameters(self, parameters):
        """Load ``parameters`` into the model, paired with state-dict keys in order."""
        new_state = {}
        for key, array in zip(self.model.state_dict().keys(), parameters):
            new_state[key] = torch.tensor(array)
        self.model.load_state_dict(new_state, strict=True)

    def fit(self, parameters, config):
        """Run 5 full-batch SGD steps starting from ``parameters``.

        Returns the updated parameters, the number of examples, and an
        empty metrics dict.
        """
        self.set_parameters(parameters)
        self.model.train()
        for _step in range(5):
            self.optimizer.zero_grad()
            predictions = self.model(self.X_train)
            batch_loss = self.loss_fn(predictions, self.y_train)
            batch_loss.backward()
            self.optimizer.step()
        updated = self.get_parameters(config)
        return updated, len(self.X_train), {}

    def evaluate(self, parameters, config):
        """Compute the loss of ``parameters`` on the training tensors.

        Returns the loss as a float, the number of examples, and an empty
        metrics dict.
        """
        self.set_parameters(parameters)
        self.model.eval()
        with torch.no_grad():
            predictions = self.model(self.X_train)
            loss_value = self.loss_fn(predictions, self.y_train).item()
        return float(loss_value), len(self.X_train), {}
56
+
57
if __name__ == "__main__":
    # Command-line entry point: parse the options, build the client, connect.
    cli = argparse.ArgumentParser(description="Flower client")
    cli.add_argument(
        "--server_address",
        required=True,
        type=str,
        help="Address of the Flower server (e.g., 127.0.0.1:8080)",
    )
    cli.add_argument(
        "--data_path",
        default="data/data.csv",
        type=str,
        help="Path to your CSV training data",
    )
    options = cli.parse_args()

    # Build the client first so the CSV is loaded before connecting.
    flower_client = FlowerClient(
        server_address=options.server_address,
        data_path=options.data_path,
    )
    fl.client.start_numpy_client(
        server_address=flower_client.server_address,
        client=flower_client,
    )
MEDfl/rw/rwConfig.py ADDED
@@ -0,0 +1,21 @@
1
+ from dataclasses import dataclass
2
+
3
@dataclass
class RealWorldConfig:
    """
    Configuration for a "real world" federated deployment.

    Attributes:
        server_address: Address and port of the Flower server (e.g. "0.0.0.0:8080").
        num_rounds: Total number of federation rounds.
        fraction_fit: Fraction of clients taking part in the fit phase each round.
        fraction_eval: Fraction of clients taking part in the evaluation phase each round.
        min_fit_clients: Minimum number of clients required to start the fit phase.
        min_eval_clients: Minimum number of clients required for the evaluation phase.
    """

    # Server endpoint and number of rounds.
    server_address: str
    num_rounds: int

    # Per-round client sampling fractions.
    fraction_fit: float
    fraction_eval: float

    # Minimum client counts.
    min_fit_clients: int
    min_eval_clients: int
MEDfl/rw/server.py ADDED
@@ -0,0 +1,42 @@
1
+
2
+ import flwr as fl
3
+ from flwr.server.strategy import FedAvg
4
+ from flwr.server.server import ServerConfig
5
+ from typing import Optional, Any
6
+ from MEDfl.rw.strategy import Strategy
7
+
8
class FederatedServer:
    """
    Wrapper for launching a Flower federated-learning server,
    using a Strategy instance as its strategy attribute.

    Attributes:
        server_address: ``"host:port"`` string the server binds to.
        server_config: Flower ``ServerConfig`` carrying the number of rounds.
        strategy_wrapper: The MEDfl ``Strategy`` wrapper in use.
        strategy: The Flower strategy object built by the wrapper.
        certificates: Passed through to ``fl.server.start_server``.
    """
    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8080,
        num_rounds: int = 3,
        strategy: Optional[Strategy] = None,
        certificates: Optional[Any] = None,
    ):
        """Build the server configuration and the underlying Flower strategy.

        Raises:
            ValueError: If the wrapper's ``create_strategy()`` leaves
                ``strategy_object`` unset.
        """
        self.server_address = f"{host}:{port}"
        self.server_config = ServerConfig(num_rounds=num_rounds)
        # If no custom strategy provided, use default
        self.strategy_wrapper = strategy or Strategy()
        # Build the actual Flower strategy object
        self.strategy_wrapper.create_strategy()
        if self.strategy_wrapper.strategy_object is None:
            # create_strategy() was called just above, so the old advice to
            # "call create_strategy() first" was misleading.
            raise ValueError(
                "Strategy object not initialized: create_strategy() did not "
                "produce a Flower strategy."
            )
        self.strategy = self.strategy_wrapper.strategy_object
        self.certificates = certificates

    def start(self) -> None:
        """
        Start the Flower server with the configured strategy.
        """
        print(f"Starting Flower server on {self.server_address} with strategy {self.strategy_wrapper.name}")
        fl.server.start_server(
            server_address=self.server_address,
            config=self.server_config,
            strategy=self.strategy,
            certificates=self.certificates,
        )
MEDfl/rw/strategy.py ADDED
@@ -0,0 +1,122 @@
1
+ import flwr as fl
2
+ from typing import Callable, Optional, Dict, Any, Tuple, List
3
+
4
+ # Custom aggregation for client-returned metrics
5
+
6
+ from typing import List, Tuple, Dict
7
+
8
+ # Custom aggregation for client-returned metrics
9
def aggregate_fit_metrics(
    results: List[Tuple[int, Dict[str, float]]]
) -> Dict[str, float]:
    """
    Weighted aggregation of training metrics across clients.

    Each tuple is (num_examples, metrics_dict). The metrics 'train_loss',
    'train_accuracy' and 'train_auc' are averaged, weighted by num_examples;
    a metric missing from a client's dict counts as 0.0 for that client.

    Returns:
        A dict with the three aggregated metrics. If ``results`` is empty or
        the total number of examples is 0, every metric is 0.0 (previously
        this raised ZeroDivisionError).
    """
    metric_names = ("train_loss", "train_accuracy", "train_auc")
    total_examples = sum(num_examples for num_examples, _ in results)
    if total_examples == 0:
        # Nothing to weight by: keep the key schema and avoid dividing by zero.
        return {name: 0.0 for name in metric_names}
    return {
        name: sum(metrics.get(name, 0.0) * num_examples
                  for num_examples, metrics in results) / total_examples
        for name in metric_names
    }
27
+
28
+
29
def aggregate_eval_metrics(
    results: List[Tuple[int, Dict[str, float]]]
) -> Dict[str, float]:
    """
    Weighted aggregation of evaluation metrics across clients.

    Each tuple is (num_examples, metrics_dict). The metrics 'eval_loss',
    'eval_accuracy' and 'eval_auc' are averaged, weighted by num_examples;
    a metric missing from a client's dict counts as 0.0 for that client.

    Returns:
        A dict with the three aggregated metrics. If ``results`` is empty or
        the total number of examples is 0, every metric is 0.0 (previously
        this raised ZeroDivisionError).
    """
    metric_names = ("eval_loss", "eval_accuracy", "eval_auc")
    total_examples = sum(num_examples for num_examples, _ in results)
    if total_examples == 0:
        # Nothing to weight by: keep the key schema and avoid dividing by zero.
        return {name: 0.0 for name in metric_names}
    return {
        name: sum(metrics.get(name, 0.0) * num_examples
                  for num_examples, metrics in results) / total_examples
        for name in metric_names
    }
45
+
46
class Strategy:
    """
    A wrapper for Flower server strategies, with custom metric aggregation
    and console logs on aggregation/evaluation completion.

    The wrapper only stores its configuration; the Flower strategy itself is
    built by ``create_strategy()`` and stored in ``strategy_object``.
    """
    def __init__(
        self,
        name: str = "FedAvg",
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters: Optional[List[Any]] = None,
        evaluate_fn: Optional[Callable[
            [int, fl.common.Parameters, Dict[str, Any]],
            Tuple[float, Dict[str, float]]
        ]] = None,
        fit_metrics_aggregation_fn: Optional[
            Callable[[List[Tuple[int, fl.common.FitRes]]], Dict[str, float]]
        ] = None,
        evaluate_metrics_aggregation_fn: Optional[
            Callable[[List[Tuple[int, fl.common.EvaluateRes]]], Dict[str, float]]
        ] = None,
    ) -> None:
        """Store the strategy configuration.

        Args:
            name: Name of a class looked up in ``fl.server.strategy``.
            fraction_fit, fraction_evaluate, min_fit_clients,
            min_evaluate_clients, min_available_clients, evaluate_fn:
                Forwarded unchanged to the Flower strategy constructor.
            initial_parameters: Optional list of arrays; when non-empty it is
                converted with ``fl.common.ndarrays_to_parameters``.
            fit_metrics_aggregation_fn: Defaults to ``aggregate_fit_metrics``.
            evaluate_metrics_aggregation_fn: Defaults to
                ``aggregate_eval_metrics``.
        """
        self.name = name
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.initial_parameters = initial_parameters or []
        self.evaluate_fn = evaluate_fn
        # Use custom aggregators if provided, else default to ours
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn or aggregate_fit_metrics
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn or aggregate_eval_metrics
        self.strategy_object: Optional[fl.server.strategy.Strategy] = None

    def create_strategy(self) -> None:
        """Instantiate the Flower strategy and store it in ``strategy_object``.

        The strategy's ``aggregate_fit`` and ``aggregate_evaluate`` are
        wrapped so that each completed round prints a one-line summary.

        Raises:
            AttributeError: If ``self.name`` is not an attribute of
                ``fl.server.strategy``.
        """
        # 1) Instantiate the underlying Flower strategy
        try:
            strategy_cls = getattr(fl.server.strategy, self.name)
        except AttributeError as err:
            # Same exception type as before, with an actionable message.
            raise AttributeError(
                f"Unknown Flower strategy '{self.name}': "
                "no such class in flwr.server.strategy"
            ) from err

        # NOTE(review): these keyword arguments match FedAvg-style
        # constructors; confirm other strategy classes accept all of them.
        params: Dict[str, Any] = {
            "fraction_fit": self.fraction_fit,
            "fraction_evaluate": self.fraction_evaluate,
            "min_fit_clients": self.min_fit_clients,
            "min_evaluate_clients": self.min_evaluate_clients,
            "min_available_clients": self.min_available_clients,
            "evaluate_fn": self.evaluate_fn,
            # Plug in our custom aggregators
            "fit_metrics_aggregation_fn": self.fit_metrics_aggregation_fn,
            "evaluate_metrics_aggregation_fn": self.evaluate_metrics_aggregation_fn,
        }
        if self.initial_parameters:
            params["initial_parameters"] = fl.common.ndarrays_to_parameters(
                self.initial_parameters
            )

        strat = strategy_cls(**params)

        # 2) Wrap aggregate_fit to log
        original_agg_fit = strat.aggregate_fit

        def logged_aggregate_fit(rnd, results, failures):
            aggregated_params, metrics = original_agg_fit(rnd, results, failures)
            print(f"[Server] ✔ Round {rnd} fit complete → Metrics: {metrics}")
            return aggregated_params, metrics

        strat.aggregate_fit = logged_aggregate_fit  # type: ignore

        # 3) Wrap aggregate_evaluate to log
        original_agg_eval = strat.aggregate_evaluate

        def logged_aggregate_evaluate(rnd, results, failures):
            loss, metrics = original_agg_eval(rnd, results, failures)
            # The leftover debug dump of the raw `results` list was removed;
            # only the round summary is printed.
            print(f"[Server] ✔ Round {rnd} eval complete → Loss: {loss}, Metrics: {metrics}")
            return loss, metrics

        strat.aggregate_evaluate = logged_aggregate_evaluate  # type: ignore

        self.strategy_object = strat
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: MEDfl
3
- Version: 2.0.1
3
+ Version: 2.0.3
4
4
  Summary: Python Open-source package for simulating federated learning and differential privacy
5
5
  Home-page: https://github.com/MEDomics-UdeS/MEDfl
6
6
  Author: MEDomics consortium
@@ -19,6 +19,12 @@ MEDfl/NetManager/net_helper.py,sha256=tyfxmpbleSdfPfo2ezKT0VOvZu660v9nhBuHCpl8pG
19
19
  MEDfl/NetManager/net_manager_queries.py,sha256=j-CLQPjtTLyZuFPhIcwJStD7L7xtZpkmkhe_h3pDuTs,4086
20
20
  MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6829
21
21
  MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
22
+ MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
+ MEDfl/rw/client.py,sha256=k8y8Wxh2KNe2oy5gRD-KXpTEGCYzp7X2oF5-Z6Rk1_E,1329
24
+ MEDfl/rw/model.py,sha256=a3mrACDqg4K1V4Qyhh0PBEmWL5SpUYsiEarWHwsVcGk,2704
25
+ MEDfl/rw/rwConfig.py,sha256=nK3Inv7v7Dm9gZnUnK5EqA4DmQ7TqiH4UoCZ8MlgFjA,823
26
+ MEDfl/rw/server.py,sha256=JTexQ5KVOrXWmGOMoLstiVrwUNDHaEhWYnImvAF1Fiw,1557
27
+ MEDfl/rw/strategy.py,sha256=NvpDpuU5_4Xv9vYyfvGLeaHaYd9h7V-b330G-PgPFCE,5469
22
28
  MEDfl/scripts/__init__.py,sha256=Pq1weevsPaU7MRMHfBYeyT0EOFeWLeVM6Y1DVz6jw1A,48
23
29
  MEDfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
24
30
  MEDfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
@@ -48,8 +54,8 @@ Medfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
48
54
  Medfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
49
55
  alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
56
  alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
51
- medfl-2.0.1.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
52
- medfl-2.0.1.dist-info/METADATA,sha256=zHjfr88Etr5-fG8de55QinubpfHZBoSl15_iF-peN64,4579
53
- medfl-2.0.1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
54
- medfl-2.0.1.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
55
- medfl-2.0.1.dist-info/RECORD,,
57
+ medfl-2.0.3.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
58
+ medfl-2.0.3.dist-info/METADATA,sha256=U9hqDT2qdSxEyYvtVFKQmjjuHx51XJ-RaKmg8ruzu_M,4579
59
+ medfl-2.0.3.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
60
+ medfl-2.0.3.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
61
+ medfl-2.0.3.dist-info/RECORD,,
File without changes