MEDfl 2.0.1__py3-none-any.whl → 2.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/rw/__init__.py +0 -0
- MEDfl/rw/client.py +39 -0
- MEDfl/rw/model.py +75 -0
- MEDfl/rw/rwConfig.py +21 -0
- MEDfl/rw/server.py +42 -0
- MEDfl/rw/strategy.py +122 -0
- {medfl-2.0.1.dist-info → medfl-2.0.3.dist-info}/METADATA +1 -1
- {medfl-2.0.1.dist-info → medfl-2.0.3.dist-info}/RECORD +11 -5
- {medfl-2.0.1.dist-info → medfl-2.0.3.dist-info}/WHEEL +0 -0
- {medfl-2.0.1.dist-info → medfl-2.0.3.dist-info}/licenses/LICENSE +0 -0
- {medfl-2.0.1.dist-info → medfl-2.0.3.dist-info}/top_level.txt +0 -0
MEDfl/rw/__init__.py
ADDED
File without changes
|
MEDfl/rw/client.py
ADDED
@@ -0,0 +1,39 @@
|
|
1
|
+
import flwr as fl
|
2
|
+
import torch
|
3
|
+
import torch.nn as nn
|
4
|
+
import torch.optim as optim
|
5
|
+
from model import Net
|
6
|
+
|
7
|
+
# Toy regression dataset used by FlowerClient below: two samples with two
# features each, and one target value per sample (shape (2, 1)).
X_train = torch.tensor([[1.0, 2.0],
                        [3.0, 4.0]])
y_train = torch.tensor([[1.0],
                        [0.0]])
|
10
|
+
|
11
|
+
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains ``Net`` on the module-level toy dataset.

    Both ``fit`` and ``evaluate`` use ``X_train``/``y_train`` with an MSE loss.
    Metrics are reported under ``train_loss`` / ``eval_loss``, the keys read by
    the server-side aggregators in ``MEDfl.rw.strategy``.
    """

    def __init__(self):
        self.model = Net()
        self.loss_fn = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)

    def get_parameters(self, config):
        """Return the model's state_dict tensors as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load a list of arrays (state_dict order) into the model, strictly."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Load ``parameters``, run 5 full-batch SGD steps, and return the update.

        Returns:
            (updated parameters, number of training examples, metrics), where
            metrics holds ``train_loss``: the loss computed in the last step,
            before that step's optimizer update.
        """
        self.set_parameters(parameters)
        self.model.train()
        for _ in range(5):
            self.optimizer.zero_grad()
            output = self.model(X_train)
            loss = self.loss_fn(output, y_train)
            loss.backward()
            self.optimizer.step()
        return self.get_parameters(config), len(X_train), {"train_loss": float(loss.item())}

    def evaluate(self, parameters, config):
        """Evaluate the received global ``parameters`` on the local data.

        Previously this ignored ``parameters`` and returned a constant 0.5;
        it now loads them and computes the real MSE loss without gradients.

        Returns:
            (loss, number of examples, {"eval_loss": loss}).
        """
        self.set_parameters(parameters)
        self.model.eval()
        with torch.no_grad():
            loss = self.loss_fn(self.model(X_train), y_train).item()
        return float(loss), len(X_train), {"eval_loss": float(loss)}
|
38
|
+
|
39
|
+
if __name__ == "__main__":
    # Connect only when executed as a script: importing this module must not
    # open a network connection. The address can be overridden on the command
    # line and defaults to the previously hard-coded one.
    import argparse

    parser = argparse.ArgumentParser(description="Flower client (toy data)")
    parser.add_argument(
        "--server_address",
        type=str,
        default="100.65.215.27:8080",
        help="Address of the Flower server (host:port)",
    )
    args = parser.parse_args()
    fl.client.start_numpy_client(server_address=args.server_address, client=FlowerClient())
|
MEDfl/rw/model.py
ADDED
@@ -0,0 +1,75 @@
|
|
1
|
+
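# Flower client script (CSV-backed). NOTE(review): despite its file name, this module
# does not define Net and imports it from MEDfl.rw.model (itself), which fails with
# ImportError; Net must be defined here or imported from the module that really defines it.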
|
2
|
+
import argparse
|
3
|
+
import pandas as pd
|
4
|
+
import flwr as fl
|
5
|
+
import torch
|
6
|
+
import torch.nn as nn
|
7
|
+
import torch.optim as optim
|
8
|
+
from MEDfl.rw.model import Net # your model definition in model.py
|
9
|
+
|
10
|
+
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains ``Net`` on a local CSV file.

    Every CSV column except the last is a feature and the last column is the
    target. Training and evaluation both use the full loaded dataset with an
    MSE loss. Metrics are reported under ``train_loss`` / ``eval_loss``, the
    keys read by the server-side aggregators in ``MEDfl.rw.strategy``.
    """

    def __init__(self, server_address: str, data_path: str = "data/data.csv"):
        """Build the model/optimizer and load the training data.

        Args:
            server_address: Flower server address; stored so the caller can
                pass it to the client start function.
            data_path: Path to the CSV training data (last column = label).
        """
        self.server_address = server_address

        # 1. Model, loss and optimizer.
        self.model = Net()
        self.loss_fn = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)

        # 2. Data: all columns but the last are features, the last is the label.
        df = pd.read_csv(data_path)
        X = df.iloc[:, :-1].values
        y = df.iloc[:, -1].values

        self.X_train = torch.tensor(X, dtype=torch.float32)
        # Single-output regression target of shape (n, 1); drop the unsqueeze
        # for a multi-class setup.
        self.y_train = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

    def get_parameters(self, config):
        """Return the model's state_dict tensors as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load a list of arrays (state_dict order) into the model, strictly."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Load ``parameters``, run 5 full-batch SGD steps, and return the update.

        Returns:
            (updated parameters, number of training examples, metrics), where
            metrics holds ``train_loss``: the loss computed in the last step,
            before that step's optimizer update.
        """
        self.set_parameters(parameters)
        self.model.train()
        for _ in range(5):
            self.optimizer.zero_grad()
            preds = self.model(self.X_train)
            loss = self.loss_fn(preds, self.y_train)
            loss.backward()
            self.optimizer.step()
        return self.get_parameters(config), len(self.X_train), {"train_loss": float(loss.item())}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local data without tracking gradients.

        Returns:
            (loss, number of examples, {"eval_loss": loss}).
        """
        self.set_parameters(parameters)
        self.model.eval()
        with torch.no_grad():
            preds = self.model(self.X_train)
            loss = self.loss_fn(preds, self.y_train).item()
        return float(loss), len(self.X_train), {"eval_loss": float(loss)}
|
56
|
+
|
57
|
+
if __name__ == "__main__":
    # Command-line entry point: parse the server address and data path, then
    # connect a FlowerClient to the server.
    cli = argparse.ArgumentParser(description="Flower client")
    cli.add_argument("--server_address", type=str, required=True,
                     help="Address of the Flower server (e.g., 127.0.0.1:8080)")
    cli.add_argument("--data_path", type=str, default="data/data.csv",
                     help="Path to your CSV training data")
    opts = cli.parse_args()

    flower_client = FlowerClient(server_address=opts.server_address, data_path=opts.data_path)
    fl.client.start_numpy_client(server_address=flower_client.server_address, client=flower_client)
|
MEDfl/rw/rwConfig.py
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
@dataclass
class RealWorldConfig:
    """
    Configuration for a real-world federated deployment.

    Attributes:
        server_address: Flower server address and port (e.g. "0.0.0.0:8080").
        num_rounds: Total number of federated rounds.
        fraction_fit: Fraction of clients taking part in the fit phase each round.
        fraction_eval: Fraction of clients taking part in the evaluation phase each round.
        min_fit_clients: Minimum number of clients required to start the fit phase.
        min_eval_clients: Minimum number of clients required for the evaluation phase.

    Raises:
        ValueError: on construction, if a fraction lies outside [0, 1] or a
            round/client count is negative.
    """
    server_address: str
    num_rounds: int
    fraction_fit: float
    fraction_eval: float
    min_fit_clients: int
    min_eval_clients: int

    def __post_init__(self) -> None:
        # Reject clearly invalid settings at construction time instead of at use time.
        for name in ("fraction_fit", "fraction_eval"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value!r}")
        for name in ("num_rounds", "min_fit_clients", "min_eval_clients"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be non-negative, got {value!r}")
|
MEDfl/rw/server.py
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
|
2
|
+
import flwr as fl
|
3
|
+
from flwr.server.strategy import FedAvg
|
4
|
+
from flwr.server.server import ServerConfig
|
5
|
+
from typing import Optional, Any
|
6
|
+
from MEDfl.rw.strategy import Strategy
|
7
|
+
|
8
|
+
class FederatedServer:
    """Launch a Flower federated-learning server driven by a ``Strategy`` wrapper.

    The wrapper's ``create_strategy()`` is called during construction, and the
    Flower strategy it builds is exposed as ``self.strategy``.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8080,
        num_rounds: int = 3,
        strategy: Optional[Strategy] = None,
        certificates: Optional[Any] = None,
    ):
        self.server_address = f"{host}:{port}"
        self.server_config = ServerConfig(num_rounds=num_rounds)

        # Fall back to a default Strategy when none (or a falsy one) is given.
        self.strategy_wrapper = strategy if strategy else Strategy()
        self.strategy_wrapper.create_strategy()

        built_strategy = self.strategy_wrapper.strategy_object
        if built_strategy is None:
            raise ValueError("Strategy object not initialized. Call create_strategy() first.")

        self.strategy = built_strategy
        self.certificates = certificates

    def start(self) -> None:
        """Print a startup banner, then run the Flower server with the stored settings."""
        banner = f"Starting Flower server on {self.server_address} with strategy {self.strategy_wrapper.name}"
        print(banner)
        fl.server.start_server(
            server_address=self.server_address,
            config=self.server_config,
            strategy=self.strategy,
            certificates=self.certificates,
        )
|
MEDfl/rw/strategy.py
ADDED
@@ -0,0 +1,122 @@
|
|
1
|
+
import flwr as fl
|
2
|
+
from typing import Callable, Optional, Dict, Any, Tuple, List
|
3
|
+
|
4
|
+
# Custom aggregation for client-returned metrics
|
5
|
+
|
6
|
+
from typing import List, Tuple, Dict
|
7
|
+
|
8
|
+
# Custom aggregation for client-returned metrics
|
9
|
+
def aggregate_fit_metrics(
    results: List[Tuple[int, Dict[str, float]]]
) -> Dict[str, float]:
    """
    Weighted aggregation of training metrics across clients.

    Each element of ``results`` is ``(num_examples, metrics_dict)``. The
    metrics ``train_loss``, ``train_accuracy`` and ``train_auc`` are averaged,
    weighted by ``num_examples``. A client that does not report a metric
    contributes 0.0 for it.

    Returns:
        Dict with keys ``train_loss``, ``train_accuracy`` and ``train_auc``.
        If ``results`` is empty or reports zero examples in total, every value
        is 0.0 (previously this raised ``ZeroDivisionError``).
    """
    keys = ("train_loss", "train_accuracy", "train_auc")
    total_examples = sum(num_examples for num_examples, _ in results)
    if total_examples == 0:
        return {key: 0.0 for key in keys}
    return {
        key: sum(metrics.get(key, 0.0) * num_examples for num_examples, metrics in results)
        / total_examples
        for key in keys
    }
|
27
|
+
|
28
|
+
|
29
|
+
def aggregate_eval_metrics(
    results: List[Tuple[int, Dict[str, float]]]
) -> Dict[str, float]:
    """
    Weighted aggregation of evaluation metrics across clients.

    Each element of ``results`` is ``(num_examples, metrics_dict)``. The
    metrics ``eval_loss``, ``eval_accuracy`` and ``eval_auc`` are averaged,
    weighted by ``num_examples``. A client that does not report a metric
    contributes 0.0 for it.

    Returns:
        Dict with keys ``eval_loss``, ``eval_accuracy`` and ``eval_auc``.
        If ``results`` is empty or reports zero examples in total, every value
        is 0.0 (previously this raised ``ZeroDivisionError``).
    """
    keys = ("eval_loss", "eval_accuracy", "eval_auc")
    total_examples = sum(num_examples for num_examples, _ in results)
    if total_examples == 0:
        return {key: 0.0 for key in keys}
    return {
        key: sum(metrics.get(key, 0.0) * num_examples for num_examples, metrics in results)
        / total_examples
        for key in keys
    }
|
45
|
+
|
46
|
+
class Strategy:
    """
    A wrapper for Flower server strategies, with custom metric aggregation
    and console logs on aggregation/evaluation completion.

    The constructor only stores configuration; call :meth:`create_strategy`
    to build the underlying Flower strategy, which is then available as
    ``strategy_object``.
    """

    def __init__(
        self,
        name: str = "FedAvg",
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters: Optional[List[Any]] = None,
        evaluate_fn: Optional[
            Callable[[int, List[Any], Dict[str, Any]], Optional[Tuple[float, Dict[str, Any]]]]
        ] = None,
        fit_metrics_aggregation_fn: Optional[
            Callable[[List[Tuple[int, Dict[str, Any]]]], Dict[str, Any]]
        ] = None,
        evaluate_metrics_aggregation_fn: Optional[
            Callable[[List[Tuple[int, Dict[str, Any]]]], Dict[str, Any]]
        ] = None,
    ) -> None:
        """
        Args:
            name: Name of a strategy class in ``flwr.server.strategy`` (e.g. "FedAvg").
            fraction_fit, fraction_evaluate, min_fit_clients, min_evaluate_clients,
            min_available_clients, evaluate_fn: Forwarded unchanged to the strategy.
            initial_parameters: Optional list of arrays converted with
                ``fl.common.ndarrays_to_parameters``; skipped when empty.
            fit_metrics_aggregation_fn / evaluate_metrics_aggregation_fn:
                Functions receiving ``[(num_examples, metrics_dict), ...]``;
                default to ``aggregate_fit_metrics`` / ``aggregate_eval_metrics``.
        """
        self.name = name
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        # `is None` instead of `or`: truth-testing an array argument would raise.
        self.initial_parameters = [] if initial_parameters is None else initial_parameters
        self.evaluate_fn = evaluate_fn
        # Use custom aggregators if provided, else default to ours
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn or aggregate_fit_metrics
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn or aggregate_eval_metrics
        self.strategy_object: Optional[fl.server.strategy.Strategy] = None

    def create_strategy(self) -> None:
        """Instantiate the Flower strategy and wrap its aggregation methods with logging.

        Stores the result in ``self.strategy_object``.

        Raises:
            AttributeError: if ``fl.server.strategy`` has no class named ``self.name``.
        """
        # 1) Instantiate the underlying Flower strategy
        try:
            strategy_class = getattr(fl.server.strategy, self.name)
        except AttributeError as err:
            raise AttributeError(
                f"Unknown Flower strategy '{self.name}' (not found in flwr.server.strategy)"
            ) from err
        params: Dict[str, Any] = {
            "fraction_fit": self.fraction_fit,
            "fraction_evaluate": self.fraction_evaluate,
            "min_fit_clients": self.min_fit_clients,
            "min_evaluate_clients": self.min_evaluate_clients,
            "min_available_clients": self.min_available_clients,
            "evaluate_fn": self.evaluate_fn,
            # Plug in our custom aggregators
            "fit_metrics_aggregation_fn": self.fit_metrics_aggregation_fn,
            "evaluate_metrics_aggregation_fn": self.evaluate_metrics_aggregation_fn,
        }
        if len(self.initial_parameters) > 0:
            params["initial_parameters"] = fl.common.ndarrays_to_parameters(
                self.initial_parameters
            )

        strat = strategy_class(**params)

        # 2) Wrap aggregate_fit to log the aggregated metrics of each round
        original_agg_fit = strat.aggregate_fit

        def logged_aggregate_fit(rnd, results, failures):
            aggregated_params, metrics = original_agg_fit(rnd, results, failures)
            print(f"[Server] ✔ Round {rnd} fit complete → Metrics: {metrics}")
            return aggregated_params, metrics

        strat.aggregate_fit = logged_aggregate_fit  # type: ignore

        # 3) Wrap aggregate_evaluate to log loss and metrics of each round
        #    (the former debug dump of the raw client results was removed).
        original_agg_eval = strat.aggregate_evaluate

        def logged_aggregate_evaluate(rnd, results, failures):
            loss, metrics = original_agg_eval(rnd, results, failures)
            print(f"[Server] ✔ Round {rnd} eval complete → Loss: {loss}, Metrics: {metrics}")
            return loss, metrics

        strat.aggregate_evaluate = logged_aggregate_evaluate  # type: ignore

        self.strategy_object = strat
|
@@ -19,6 +19,12 @@ MEDfl/NetManager/net_helper.py,sha256=tyfxmpbleSdfPfo2ezKT0VOvZu660v9nhBuHCpl8pG
|
|
19
19
|
MEDfl/NetManager/net_manager_queries.py,sha256=j-CLQPjtTLyZuFPhIcwJStD7L7xtZpkmkhe_h3pDuTs,4086
|
20
20
|
MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6829
|
21
21
|
MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
|
22
|
+
MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
+
MEDfl/rw/client.py,sha256=k8y8Wxh2KNe2oy5gRD-KXpTEGCYzp7X2oF5-Z6Rk1_E,1329
|
24
|
+
MEDfl/rw/model.py,sha256=a3mrACDqg4K1V4Qyhh0PBEmWL5SpUYsiEarWHwsVcGk,2704
|
25
|
+
MEDfl/rw/rwConfig.py,sha256=nK3Inv7v7Dm9gZnUnK5EqA4DmQ7TqiH4UoCZ8MlgFjA,823
|
26
|
+
MEDfl/rw/server.py,sha256=JTexQ5KVOrXWmGOMoLstiVrwUNDHaEhWYnImvAF1Fiw,1557
|
27
|
+
MEDfl/rw/strategy.py,sha256=NvpDpuU5_4Xv9vYyfvGLeaHaYd9h7V-b330G-PgPFCE,5469
|
22
28
|
MEDfl/scripts/__init__.py,sha256=Pq1weevsPaU7MRMHfBYeyT0EOFeWLeVM6Y1DVz6jw1A,48
|
23
29
|
MEDfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
|
24
30
|
MEDfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
|
@@ -48,8 +54,8 @@ Medfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
|
|
48
54
|
Medfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
|
49
55
|
alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
50
56
|
alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
|
51
|
-
medfl-2.0.
|
52
|
-
medfl-2.0.
|
53
|
-
medfl-2.0.
|
54
|
-
medfl-2.0.
|
55
|
-
medfl-2.0.
|
57
|
+
medfl-2.0.3.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
58
|
+
medfl-2.0.3.dist-info/METADATA,sha256=U9hqDT2qdSxEyYvtVFKQmjjuHx51XJ-RaKmg8ruzu_M,4579
|
59
|
+
medfl-2.0.3.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
60
|
+
medfl-2.0.3.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
|
61
|
+
medfl-2.0.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|