MEDfl 2.0.4.dev1__py3-none-any.whl → 2.0.4.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/rw/client.py +98 -29
- MEDfl/rw/model.py +28 -0
- MEDfl/rw/server.py +71 -18
- MEDfl/rw/strategy.py +72 -78
- {MEDfl-2.0.4.dev1.dist-info → MEDfl-2.0.4.dev3.dist-info}/METADATA +1 -1
- MEDfl-2.0.4.dev3.dist-info/RECORD +36 -0
- MEDfl/rw/rwConfig.py +0 -21
- MEDfl/rw/verbose_server.py +0 -21
- MEDfl-2.0.4.dev1.dist-info/RECORD +0 -62
- Medfl/LearningManager/__init__.py +0 -13
- Medfl/LearningManager/client.py +0 -150
- Medfl/LearningManager/dynamicModal.py +0 -287
- Medfl/LearningManager/federated_dataset.py +0 -60
- Medfl/LearningManager/flpipeline.py +0 -192
- Medfl/LearningManager/model.py +0 -223
- Medfl/LearningManager/params.yaml +0 -14
- Medfl/LearningManager/params_optimiser.py +0 -442
- Medfl/LearningManager/plot.py +0 -229
- Medfl/LearningManager/server.py +0 -181
- Medfl/LearningManager/strategy.py +0 -82
- Medfl/LearningManager/utils.py +0 -331
- Medfl/NetManager/__init__.py +0 -10
- Medfl/NetManager/database_connector.py +0 -43
- Medfl/NetManager/dataset.py +0 -92
- Medfl/NetManager/flsetup.py +0 -320
- Medfl/NetManager/net_helper.py +0 -254
- Medfl/NetManager/net_manager_queries.py +0 -142
- Medfl/NetManager/network.py +0 -194
- Medfl/NetManager/node.py +0 -184
- Medfl/__init__.py +0 -3
- Medfl/scripts/__init__.py +0 -2
- Medfl/scripts/base.py +0 -30
- Medfl/scripts/create_db.py +0 -126
- {MEDfl-2.0.4.dev1.dist-info → MEDfl-2.0.4.dev3.dist-info}/LICENSE +0 -0
- {MEDfl-2.0.4.dev1.dist-info → MEDfl-2.0.4.dev3.dist-info}/WHEEL +0 -0
- {MEDfl-2.0.4.dev1.dist-info → MEDfl-2.0.4.dev3.dist-info}/top_level.txt +0 -0
MEDfl/rw/client.py
CHANGED
@@ -1,39 +1,108 @@
|
|
1
|
+
# client.py
|
2
|
+
"""
|
3
|
+
Federated Learning Client with Optional Differential Privacy.
|
4
|
+
|
5
|
+
"""
|
6
|
+
|
7
|
+
import argparse
|
8
|
+
import pandas as pd
|
1
9
|
import flwr as fl
|
2
10
|
import torch
|
3
11
|
import torch.nn as nn
|
4
12
|
import torch.optim as optim
|
5
|
-
from
|
13
|
+
from torch.utils.data import TensorDataset, DataLoader
|
14
|
+
from sklearn.metrics import accuracy_score, roc_auc_score
|
15
|
+
from model import Net # Local model definition
|
16
|
+
import socket
|
17
|
+
import platform
|
18
|
+
|
19
|
+
|
20
|
+
class DPConfig:
|
21
|
+
"""
|
22
|
+
Configuration for Differential Privacy (DP) settings.
|
23
|
+
|
24
|
+
Attributes:
|
25
|
+
noise_multiplier (float): Noise multiplier for DP.
|
26
|
+
max_grad_norm (float): Maximum gradient norm (clipping threshold).
|
27
|
+
batch_size (int): Batch size for training.
|
28
|
+
secure_rng (bool): Whether to use a secure RNG for DP noise.
|
29
|
+
"""
|
30
|
+
|
31
|
+
def __init__(
|
32
|
+
self,
|
33
|
+
noise_multiplier: float = 1.0,
|
34
|
+
max_grad_norm: float = 1.0,
|
35
|
+
batch_size: int = 32,
|
36
|
+
secure_rng: bool = False,
|
37
|
+
):
|
38
|
+
self.noise_multiplier = noise_multiplier
|
39
|
+
self.max_grad_norm = max_grad_norm
|
40
|
+
self.batch_size = batch_size
|
41
|
+
self.secure_rng = secure_rng
|
6
42
|
|
7
|
-
# Dummy training data
|
8
|
-
X_train = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
9
|
-
y_train = torch.tensor([[1.0], [0.0]])
|
10
43
|
|
11
44
|
class FlowerClient(fl.client.NumPyClient):
|
12
|
-
|
13
|
-
|
14
|
-
|
45
|
+
"""
|
46
|
+
FlowerClient: A federated learning client that trains a PyTorch model
|
47
|
+
and optionally applies differential privacy.
|
48
|
+
|
49
|
+
"""
|
50
|
+
|
51
|
+
def __init__(
|
52
|
+
self,
|
53
|
+
server_address: str,
|
54
|
+
data_path: str = "data/data.csv",
|
55
|
+
dp_config: DPConfig = None,
|
56
|
+
):
|
57
|
+
"""
|
58
|
+
Initialize client by loading data, creating model, optimizer,
|
59
|
+
and optionally enabling DP.
|
60
|
+
|
61
|
+
Args:
|
62
|
+
server_address (str): Flower server address.
|
63
|
+
data_path (str): Path to CSV dataset.
|
64
|
+
dp_config (DPConfig): Optional DP configuration.
|
65
|
+
"""
|
66
|
+
self.server_address = server_address
|
67
|
+
|
68
|
+
# ---------- Load Data ----------
|
69
|
+
df = pd.read_csv(data_path)
|
70
|
+
X = df.iloc[:, :-1].values
|
71
|
+
y = df.iloc[:, -1].values
|
72
|
+
|
73
|
+
self.X_tensor = torch.tensor(X, dtype=torch.float32)
|
74
|
+
self.y_tensor = torch.tensor(y, dtype=torch.float32)
|
75
|
+
|
76
|
+
batch_size = dp_config.batch_size if dp_config else 32
|
77
|
+
dataset = TensorDataset(self.X_tensor, self.y_tensor)
|
78
|
+
self.train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
79
|
+
|
80
|
+
# ---------- Model and Optimizer ----------
|
81
|
+
input_dim = X.shape[1]
|
82
|
+
self.model = Net(input_dim)
|
83
|
+
self.criterion = nn.BCEWithLogitsLoss()
|
15
84
|
self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
|
16
85
|
|
86
|
+
# ---------- Differential Privacy ----------
|
87
|
+
self.privacy_engine = None
|
88
|
+
if dp_config is not None:
|
89
|
+
try:
|
90
|
+
from opacus import PrivacyEngine
|
91
|
+
|
92
|
+
self.privacy_engine = PrivacyEngine()
|
93
|
+
self.model, self.optimizer, self.train_loader = self.privacy_engine.make_private(
|
94
|
+
module=self.model,
|
95
|
+
optimizer=self.optimizer,
|
96
|
+
data_loader=self.train_loader,
|
97
|
+
noise_multiplier=dp_config.noise_multiplier,
|
98
|
+
max_grad_norm=dp_config.max_grad_norm,
|
99
|
+
secure_rng=dp_config.secure_rng,
|
100
|
+
)
|
101
|
+
except ImportError:
|
102
|
+
print("⚠️ Opacus not installed, running without DP.")
|
103
|
+
|
17
104
|
def get_parameters(self, config):
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
state_dict = {k: torch.tensor(v) for k, v in params_dict}
|
23
|
-
self.model.load_state_dict(state_dict, strict=True)
|
24
|
-
|
25
|
-
def fit(self, parameters, config):
|
26
|
-
self.set_parameters(parameters)
|
27
|
-
self.model.train()
|
28
|
-
for _ in range(5):
|
29
|
-
self.optimizer.zero_grad()
|
30
|
-
output = self.model(X_train)
|
31
|
-
loss = self.loss_fn(output, y_train)
|
32
|
-
loss.backward()
|
33
|
-
self.optimizer.step()
|
34
|
-
return self.get_parameters(config), len(X_train), {}
|
35
|
-
|
36
|
-
def evaluate(self, parameters, config):
|
37
|
-
return 0.5, len(X_train), {}
|
38
|
-
|
39
|
-
fl.client.start_numpy_client(server_address="100.65.215.27:8080", client=FlowerClient())
|
105
|
+
"""
|
106
|
+
Get model parameters as a list of NumPy arrays.
|
107
|
+
"""
|
108
|
+
return [val.cpu().numpy() for val in self.mo]()
|
MEDfl/rw/model.py
CHANGED
@@ -2,18 +2,46 @@ import torch.nn as nn
|
|
2
2
|
import torch.nn.functional as F
|
3
3
|
|
4
4
|
class Net(nn.Module):
|
5
|
+
"""
|
6
|
+
Net defines a simple feedforward neural network with two hidden layers
|
7
|
+
|
8
|
+
"""
|
9
|
+
|
5
10
|
def __init__(self, input_dim):
|
11
|
+
"""
|
12
|
+
Initialize the layers of the network.
|
13
|
+
|
14
|
+
Args:
|
15
|
+
input_dim (int): Number of input features.
|
16
|
+
"""
|
6
17
|
super().__init__()
|
18
|
+
# First fully connected layer: input_dim → 64
|
7
19
|
self.fc1 = nn.Linear(input_dim, 64)
|
20
|
+
# Second fully connected layer: 64 → 32
|
8
21
|
self.fc2 = nn.Linear(64, 32)
|
22
|
+
# Output layer: 32 → 1
|
9
23
|
self.fc3 = nn.Linear(32, 1)
|
24
|
+
# Dropout with 30% probability
|
10
25
|
self.dropout = nn.Dropout(0.3)
|
26
|
+
# Batch normalization layers for hidden layers
|
11
27
|
self.batchnorm1 = nn.BatchNorm1d(64)
|
12
28
|
self.batchnorm2 = nn.BatchNorm1d(32)
|
13
29
|
|
14
30
|
def forward(self, x):
|
31
|
+
"""
|
32
|
+
Define the forward pass of the network.
|
33
|
+
|
34
|
+
Args:
|
35
|
+
x (torch.Tensor): Input tensor of shape (batch_size, input_dim).
|
36
|
+
|
37
|
+
Returns:
|
38
|
+
torch.Tensor: Output logits of shape (batch_size, 1).
|
39
|
+
"""
|
40
|
+
# Hidden layer 1: linear → batchnorm → ReLU → dropout
|
15
41
|
x = F.relu(self.batchnorm1(self.fc1(x)))
|
16
42
|
x = self.dropout(x)
|
43
|
+
# Hidden layer 2: linear → batchnorm → ReLU → dropout
|
17
44
|
x = F.relu(self.batchnorm2(self.fc2(x)))
|
18
45
|
x = self.dropout(x)
|
46
|
+
# Output layer: linear
|
19
47
|
return self.fc3(x) # raw logits for BCEWithLogitsLoss
|
MEDfl/rw/server.py
CHANGED
@@ -1,21 +1,32 @@
|
|
1
1
|
import flwr as fl
|
2
2
|
from flwr.server.strategy import FedAvg
|
3
3
|
from flwr.server.server import ServerConfig
|
4
|
-
from typing import Optional, Any
|
4
|
+
from typing import Optional, Any
|
5
5
|
from MEDfl.rw.strategy import Strategy
|
6
|
-
|
7
|
-
import time
|
6
|
+
import asyncio
|
8
7
|
from flwr.server.client_manager import ClientManager
|
9
8
|
from flwr.server.client_proxy import ClientProxy
|
10
9
|
from flwr.common import GetPropertiesIns
|
11
|
-
import
|
10
|
+
from flwr.common import GetPropertiesIns
|
11
|
+
|
12
12
|
|
13
13
|
class FederatedServer:
|
14
14
|
"""
|
15
|
-
|
16
|
-
|
17
|
-
|
15
|
+
FederatedServer wraps the launch and configuration of a Flower federated learning server.
|
16
|
+
|
17
|
+
Attributes:
|
18
|
+
server_address (str): Server host and port in the format "host:port".
|
19
|
+
server_config (ServerConfig): Configuration for the Flower server.
|
20
|
+
strategy_wrapper (Strategy): Wrapper around the actual Flower strategy.
|
21
|
+
strategy (flwr.server.Strategy): Actual Flower strategy instance.
|
22
|
+
certificates (Any): Optional TLS certificates.
|
23
|
+
connected_clients (list): List of connected client IDs.
|
24
|
+
|
25
|
+
Methods:
|
26
|
+
start():
|
27
|
+
Launch the Flower server with the specified strategy and log client connections.
|
18
28
|
"""
|
29
|
+
|
19
30
|
def __init__(
|
20
31
|
self,
|
21
32
|
host: str = "0.0.0.0",
|
@@ -24,29 +35,42 @@ class FederatedServer:
|
|
24
35
|
strategy: Optional[Strategy] = None,
|
25
36
|
certificates: Optional[Any] = None,
|
26
37
|
):
|
38
|
+
"""
|
39
|
+
Initialize the FederatedServer.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
host (str): Hostname or IP to bind the server to.
|
43
|
+
port (int): Port to listen on.
|
44
|
+
num_rounds (int): Number of federated learning rounds to execute.
|
45
|
+
strategy (Optional[Strategy]): Optional custom strategy wrapper.
|
46
|
+
certificates (Optional[Any]): Optional TLS certificates.
|
47
|
+
"""
|
48
|
+
# Server address and configuration
|
27
49
|
self.server_address = f"{host}:{port}"
|
28
50
|
self.server_config = ServerConfig(num_rounds=num_rounds)
|
29
|
-
|
51
|
+
|
52
|
+
# Use custom or default strategy
|
30
53
|
self.strategy_wrapper = strategy or Strategy()
|
31
|
-
# Build the actual Flower strategy object
|
32
54
|
self.strategy_wrapper.create_strategy()
|
33
55
|
if self.strategy_wrapper.strategy_object is None:
|
34
56
|
raise ValueError("Strategy object not initialized. Call create_strategy() first.")
|
35
57
|
self.strategy = self.strategy_wrapper.strategy_object
|
58
|
+
|
36
59
|
self.certificates = certificates
|
37
|
-
self.connected_clients = []
|
60
|
+
self.connected_clients = [] # Track connected client IDs
|
61
|
+
|
38
62
|
|
39
63
|
def start(self) -> None:
|
40
64
|
"""
|
41
|
-
Start the Flower server with the configured strategy.
|
42
|
-
Now tracks and logs client connections before starting.
|
65
|
+
Start the Flower server with the configured strategy and track client connections.
|
43
66
|
"""
|
44
67
|
print(f"Using strategy: {self.strategy_wrapper.name}")
|
45
68
|
print(f"Starting Flower server on {self.server_address} with strategy {self.strategy_wrapper.name}")
|
46
|
-
|
47
|
-
#
|
69
|
+
|
70
|
+
# Use a custom client manager that logs client connections
|
48
71
|
client_manager = TrackingClientManager(self)
|
49
|
-
|
72
|
+
|
73
|
+
# Launch the Flower server
|
50
74
|
fl.server.start_server(
|
51
75
|
server_address=self.server_address,
|
52
76
|
config=self.server_config,
|
@@ -55,28 +79,57 @@ class FederatedServer:
|
|
55
79
|
client_manager=client_manager,
|
56
80
|
)
|
57
81
|
|
82
|
+
|
58
83
|
class TrackingClientManager(fl.server.client_manager.SimpleClientManager):
|
59
84
|
"""
|
60
|
-
|
85
|
+
TrackingClientManager extends the default SimpleClientManager to log client connections.
|
86
|
+
|
87
|
+
Attributes:
|
88
|
+
server (FederatedServer): The FederatedServer instance this manager belongs to.
|
89
|
+
client_properties (dict): Placeholder for storing client-specific properties.
|
61
90
|
"""
|
91
|
+
|
62
92
|
def __init__(self, server: FederatedServer):
|
93
|
+
"""
|
94
|
+
Initialize the TrackingClientManager.
|
95
|
+
|
96
|
+
Args:
|
97
|
+
server (FederatedServer): Reference to the FederatedServer.
|
98
|
+
"""
|
63
99
|
super().__init__()
|
64
100
|
self.server = server
|
65
|
-
self.client_properties = {}
|
101
|
+
self.client_properties = {}
|
66
102
|
|
67
103
|
def register(self, client: ClientProxy) -> bool:
|
104
|
+
"""
|
105
|
+
Register a client and log its connection.
|
106
|
+
|
107
|
+
Args:
|
108
|
+
client (ClientProxy): The client proxy being registered.
|
109
|
+
|
110
|
+
Returns:
|
111
|
+
bool: True if the client was registered successfully.
|
112
|
+
"""
|
68
113
|
success = super().register(client)
|
69
114
|
if success and client.cid not in self.server.connected_clients:
|
70
|
-
# Run the
|
115
|
+
# Run the asynchronous hostname fetch synchronously
|
71
116
|
asyncio.run(self._fetch_and_log_hostname(client))
|
72
117
|
return success
|
73
118
|
|
74
119
|
async def _fetch_and_log_hostname(self, client: ClientProxy):
|
120
|
+
"""
|
121
|
+
Asynchronously fetch and log the client's hostname or CID.
|
122
|
+
|
123
|
+
Args:
|
124
|
+
client (ClientProxy): The client proxy.
|
125
|
+
"""
|
126
|
+
# Optional: uncomment to fetch hostname from client properties
|
75
127
|
# try:
|
76
128
|
# ins = GetPropertiesIns(config={})
|
77
129
|
# props = await client.get_properties(ins=ins, timeout=10.0, group_id=0)
|
78
130
|
# hostname = props.properties.get("hostname", "unknown")
|
79
131
|
# except Exception as e:
|
80
132
|
# hostname = f"Error: {e}"
|
133
|
+
|
81
134
|
print(f"✅ Client connected - CID: {client.cid}")
|
82
135
|
self.server.connected_clients.append(client.cid)
|
MEDfl/rw/strategy.py
CHANGED
@@ -1,53 +1,43 @@
|
|
1
1
|
import flwr as fl
|
2
2
|
from typing import Callable, Optional, Dict, Any, Tuple, List
|
3
|
+
from flwr.common import GetPropertiesIns
|
4
|
+
from flwr.server.client_manager import ClientManager
|
5
|
+
from flwr.server.client_proxy import ClientProxy
|
3
6
|
|
4
|
-
#
|
7
|
+
# ===================================================
|
8
|
+
# Custom metric aggregation functions
|
9
|
+
# ===================================================
|
5
10
|
|
6
|
-
from typing import List, Tuple, Dict
|
7
|
-
|
8
|
-
# Custom aggregation for client-returned metrics
|
9
11
|
def aggregate_fit_metrics(
|
10
12
|
results: List[Tuple[int, Dict[str, float]]]
|
11
13
|
) -> Dict[str, float]:
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
"""
|
17
|
-
# Sum total examples
|
18
|
-
total_examples = sum(num_examples for num_examples, _ in results)
|
19
|
-
# Weighted averages
|
20
|
-
loss = sum(metrics.get("train_loss", 0.0) * num_examples
|
21
|
-
for num_examples, metrics in results) / total_examples
|
22
|
-
accuracy = sum(metrics.get("train_accuracy", 0.0) * num_examples
|
23
|
-
for num_examples, metrics in results) / total_examples
|
24
|
-
auc = sum(metrics.get("train_auc", 0.0) * num_examples
|
25
|
-
for num_examples, metrics in results) / total_examples
|
26
|
-
return {"train_loss": loss, "train_accuracy": accuracy, "train_auc": auc}
|
27
|
-
|
14
|
+
total = sum(n for n, _ in results)
|
15
|
+
loss = sum(m.get("train_loss", 0.0) * n for n, m in results) / total
|
16
|
+
acc = sum(m.get("train_accuracy", 0.0) * n for n, m in results) / total
|
17
|
+
auc = sum(m.get("train_auc", 0.0) * n for n, m in results) / total
|
18
|
+
return {"train_loss": loss, "train_accuracy": acc, "train_auc": auc}
|
28
19
|
|
29
20
|
def aggregate_eval_metrics(
|
30
21
|
results: List[Tuple[int, Dict[str, float]]]
|
31
22
|
) -> Dict[str, float]:
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
"""
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
for num_examples, metrics in results) / total_examples
|
42
|
-
auc = sum(metrics.get("eval_auc", 0.0) * num_examples
|
43
|
-
for num_examples, metrics in results) / total_examples
|
44
|
-
return {"eval_loss": loss, "eval_accuracy": accuracy, "eval_auc": auc}
|
23
|
+
total = sum(n for n, _ in results)
|
24
|
+
loss = sum(m.get("eval_loss", 0.0) * n for n, m in results) / total
|
25
|
+
acc = sum(m.get("eval_accuracy", 0.0) * n for n, m in results) / total
|
26
|
+
auc = sum(m.get("eval_auc", 0.0) * n for n, m in results) / total
|
27
|
+
return {"eval_loss": loss, "eval_accuracy": acc, "eval_auc": auc}
|
28
|
+
|
29
|
+
# ===================================================
|
30
|
+
# Strategy Wrapper
|
31
|
+
# ===================================================
|
45
32
|
|
46
33
|
class Strategy:
|
47
34
|
"""
|
48
|
-
|
49
|
-
|
35
|
+
Flower Strategy wrapper:
|
36
|
+
- Custom metric aggregation
|
37
|
+
- Per-client & aggregated metric logging
|
38
|
+
- Synchronous get_properties() inspection in configure_fit()
|
50
39
|
"""
|
40
|
+
|
51
41
|
def __init__(
|
52
42
|
self,
|
53
43
|
name: str = "FedAvg",
|
@@ -57,16 +47,9 @@ class Strategy:
|
|
57
47
|
min_evaluate_clients: int = 2,
|
58
48
|
min_available_clients: int = 2,
|
59
49
|
initial_parameters: Optional[List[Any]] = None,
|
60
|
-
evaluate_fn: Optional[Callable
|
61
|
-
|
62
|
-
|
63
|
-
]] = None,
|
64
|
-
fit_metrics_aggregation_fn: Optional[
|
65
|
-
Callable[[List[Tuple[int, fl.common.FitRes]]], Dict[str, float]]
|
66
|
-
] = None,
|
67
|
-
evaluate_metrics_aggregation_fn: Optional[
|
68
|
-
Callable[[List[Tuple[int, fl.common.EvaluateRes]]], Dict[str, float]]
|
69
|
-
] = None,
|
50
|
+
evaluate_fn: Optional[Callable] = None,
|
51
|
+
fit_metrics_aggregation_fn: Optional[Callable] = None,
|
52
|
+
evaluate_metrics_aggregation_fn: Optional[Callable] = None,
|
70
53
|
) -> None:
|
71
54
|
self.name = name
|
72
55
|
self.fraction_fit = fraction_fit
|
@@ -76,13 +59,13 @@ class Strategy:
|
|
76
59
|
self.min_available_clients = min_available_clients
|
77
60
|
self.initial_parameters = initial_parameters or []
|
78
61
|
self.evaluate_fn = evaluate_fn
|
79
|
-
|
62
|
+
|
80
63
|
self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn or aggregate_fit_metrics
|
81
64
|
self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn or aggregate_eval_metrics
|
65
|
+
|
82
66
|
self.strategy_object: Optional[fl.server.strategy.Strategy] = None
|
83
67
|
|
84
68
|
def create_strategy(self) -> None:
|
85
|
-
# 1) Instantiate the underlying Flower strategy
|
86
69
|
StrategyClass = getattr(fl.server.strategy, self.name)
|
87
70
|
params: Dict[str, Any] = {
|
88
71
|
"fraction_fit": self.fraction_fit,
|
@@ -91,50 +74,61 @@ class Strategy:
|
|
91
74
|
"min_evaluate_clients": self.min_evaluate_clients,
|
92
75
|
"min_available_clients": self.min_available_clients,
|
93
76
|
"evaluate_fn": self.evaluate_fn,
|
94
|
-
# Plug in our custom aggregators
|
95
77
|
"fit_metrics_aggregation_fn": self.fit_metrics_aggregation_fn,
|
96
78
|
"evaluate_metrics_aggregation_fn": self.evaluate_metrics_aggregation_fn,
|
97
79
|
}
|
98
80
|
if self.initial_parameters:
|
99
|
-
params["initial_parameters"] = fl.common.ndarrays_to_parameters(
|
100
|
-
self.initial_parameters
|
101
|
-
)
|
81
|
+
params["initial_parameters"] = fl.common.ndarrays_to_parameters(self.initial_parameters)
|
102
82
|
|
103
83
|
strat = StrategyClass(**params)
|
104
84
|
|
105
|
-
#
|
85
|
+
# Wrap aggregate_fit to log metrics
|
106
86
|
original_agg_fit = strat.aggregate_fit
|
107
|
-
def
|
108
|
-
|
109
|
-
print(f"\n[Server] 🔄 Round {rnd} - Client Training Metrics:")
|
87
|
+
def logged_agg_fit(server_round, results, failures):
|
88
|
+
print(f"\n[Server] 🔄 Round {server_round} - Client Training Metrics:")
|
110
89
|
for i, (client_id, fit_res) in enumerate(results):
|
111
|
-
print(f" CTM Round {
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
# Print aggregated metrics
|
117
|
-
print(f"[Server] ✅ Round {rnd} - Aggregated Training Metrics: {metrics}\n")
|
118
|
-
return aggregated_params, metrics
|
119
|
-
|
120
|
-
strat.aggregate_fit = logged_aggregate_fit # type: ignore
|
90
|
+
print(f" CTM Round {server_round} Client:{client_id.cid}: {fit_res.metrics}")
|
91
|
+
agg_params, metrics = original_agg_fit(server_round, results, failures)
|
92
|
+
print(f"[Server] ✅ Round {server_round} - Aggregated Training Metrics: {metrics}\n")
|
93
|
+
return agg_params, metrics
|
94
|
+
strat.aggregate_fit = logged_agg_fit
|
121
95
|
|
122
|
-
#
|
96
|
+
# Wrap aggregate_evaluate to log metrics
|
123
97
|
original_agg_eval = strat.aggregate_evaluate
|
124
|
-
def
|
125
|
-
|
126
|
-
print(f"\n[Server] 📊 Round {rnd} - Client Evaluation Metrics:")
|
98
|
+
def logged_agg_eval(server_round, results, failures):
|
99
|
+
print(f"\n[Server] 📊 Round {server_round} - Client Evaluation Metrics:")
|
127
100
|
for i, (client_id, eval_res) in enumerate(results):
|
128
|
-
print(f" CEM Round {
|
129
|
-
|
130
|
-
|
131
|
-
loss, metrics
|
132
|
-
|
133
|
-
# Print aggregated metrics
|
134
|
-
print(f"[Server] ✅ Round {rnd} - Aggregated Evaluation Metrics:")
|
135
|
-
print(f" Loss: {loss}, Metrics: {metrics}\n")
|
101
|
+
print(f" CEM Round {server_round} Client:{client_id.cid}: {eval_res.metrics}")
|
102
|
+
loss, metrics = original_agg_eval(server_round, results, failures)
|
103
|
+
print(f"[Server] ✅ Round {server_round} - Aggregated Evaluation Metrics:")
|
104
|
+
print(f" Loss: {loss}, Metrics: {metrics}\n")
|
136
105
|
return loss, metrics
|
106
|
+
strat.aggregate_evaluate = logged_agg_eval
|
107
|
+
|
108
|
+
# Wrap configure_fit to inspect client properties synchronously
|
109
|
+
original_conf_fit = strat.configure_fit
|
110
|
+
def wrapped_conf_fit(
|
111
|
+
server_round: int,
|
112
|
+
parameters,
|
113
|
+
client_manager: ClientManager
|
114
|
+
):
|
115
|
+
selected = original_conf_fit(
|
116
|
+
server_round=server_round,
|
117
|
+
parameters=parameters,
|
118
|
+
client_manager=client_manager
|
119
|
+
)
|
120
|
+
|
121
|
+
# Synchronously fetch & log properties
|
122
|
+
ins = GetPropertiesIns(config={})
|
123
|
+
for client, _ in selected:
|
124
|
+
try:
|
125
|
+
props = client.get_properties(ins=ins, timeout=10.0, group_id=0)
|
126
|
+
print(f"\n📋 [Round {server_round}] Client {client.cid} Properties: {props.properties}")
|
127
|
+
|
128
|
+
except Exception as e:
|
129
|
+
print(f"⚠️ Failed to get properties from {client.cid}: {e}")
|
137
130
|
|
138
|
-
|
131
|
+
return selected
|
139
132
|
|
133
|
+
strat.configure_fit = wrapped_conf_fit
|
140
134
|
self.strategy_object = strat
|
@@ -0,0 +1,36 @@
|
|
1
|
+
MEDfl/__init__.py,sha256=70DmtU4C3A-1XYoaYm0moXBe-YGJ2FhEe3ga5SQVTts,97
|
2
|
+
MEDfl/LearningManager/__init__.py,sha256=IMHJVeyx5ew0U_90LNMNCd4QISzWv3XCCri7fQRvcsM,341
|
3
|
+
MEDfl/LearningManager/client.py,sha256=9Y_Zb0yxvCxx3dVCPQ1bXS5mCKasylSBnoVj-RDN270,5933
|
4
|
+
MEDfl/LearningManager/dynamicModal.py,sha256=q8u7xPpj_TdZnSr8kYj0Xx7Sdz-diXsKBAfVce8-qSU,10534
|
5
|
+
MEDfl/LearningManager/federated_dataset.py,sha256=InsZ5Rys2dgqaPxVyP5G3TrJMwiCNHOoTd3tCpUwUVM,2081
|
6
|
+
MEDfl/LearningManager/flpipeline.py,sha256=5lT2uod5EqnkRQ04cgm0gYyZz0djumfIYipCrzX1fdo,7111
|
7
|
+
MEDfl/LearningManager/model.py,sha256=vp8FIMxBdz3FTF5wJaea2IO_WGeANLZgBxTKVe3gW3Q,7456
|
8
|
+
MEDfl/LearningManager/params.yaml,sha256=Ix1cNtlWr3vDC0te6pipl5w8iLADO6dZvwm633-VaIA,436
|
9
|
+
MEDfl/LearningManager/params_optimiser.py,sha256=8e0gCt4imwQHlNSJ3A2EAuc3wSr6yfSI6JDghohfmZQ,17618
|
10
|
+
MEDfl/LearningManager/plot.py,sha256=A6Z8wC8J-H-OmWBPKqwK5eiTB9vzOBGMaFv1SaNA9Js,7698
|
11
|
+
MEDfl/LearningManager/server.py,sha256=oTgW3K1UT6m4SQBk23FIf23km_BDq9vvjeC6OgY8DNw,7077
|
12
|
+
MEDfl/LearningManager/strategy.py,sha256=BHXpwmt7jx07y45YLUs8FZry2gYQbpiV4vNbHhsksQ4,3435
|
13
|
+
MEDfl/LearningManager/utils.py,sha256=B4RULJp-puJr724O6teI0PxnUyPV8NG-uPC6jqaiDKI,9605
|
14
|
+
MEDfl/NetManager/__init__.py,sha256=OpgsIiBg7UA6Bfnu_kqGfEPxU8JfpPxSFU98TOeDTP0,273
|
15
|
+
MEDfl/NetManager/database_connector.py,sha256=G8DAsD_pAIK1U67x3Q8gmSJGW7iJyxQ_NE5lWpT-P0Q,1474
|
16
|
+
MEDfl/NetManager/dataset.py,sha256=HTV0jrJ4Qlhl2aSJzdFU1lkxGBKtmJ390eBpwfKf_4o,2777
|
17
|
+
MEDfl/NetManager/flsetup.py,sha256=CVu_TIU7l3G6DDnwtY6JURbhIZk7gKC3unqWnU-YtlM,11434
|
18
|
+
MEDfl/NetManager/net_helper.py,sha256=tyfxmpbleSdfPfo2ezKT0VOvZu660v9nhBuHCpl8pG4,6812
|
19
|
+
MEDfl/NetManager/net_manager_queries.py,sha256=j-CLQPjtTLyZuFPhIcwJStD7L7xtZpkmkhe_h3pDuTs,4086
|
20
|
+
MEDfl/NetManager/network.py,sha256=5t705fzWc-BRg-QPAbAcDv5ckDGzsPwj_Q5V0iTgkx0,6829
|
21
|
+
MEDfl/NetManager/node.py,sha256=t90QuYZ8M1X_AG1bwTta0CnlOuodqkmpVda2K7NOgHc,6542
|
22
|
+
MEDfl/rw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
+
MEDfl/rw/client.py,sha256=AQnLvM-pWVA2Md0GiRLB9C4s8S3TjcjjvN3U25ftpX4,3463
|
24
|
+
MEDfl/rw/model.py,sha256=OAoTmOw4zGWPa_ncDqNanLeucwWHmUydKED6zlB5Hps,1510
|
25
|
+
MEDfl/rw/server.py,sha256=TmcWUwMJ5BG7owqIVTtTC6w8bR02SESRT9lXh7BqlOg,4986
|
26
|
+
MEDfl/rw/strategy.py,sha256=aNlyQhHslmPJdiuJjsK9hu-IUJdrWR1yuGp7pNk4LeA,5974
|
27
|
+
MEDfl/scripts/__init__.py,sha256=Pq1weevsPaU7MRMHfBYeyT0EOFeWLeVM6Y1DVz6jw1A,48
|
28
|
+
MEDfl/scripts/base.py,sha256=QrmG7gkiPYkAy-5tXxJgJmOSLGAKeIVH6i0jq7G9xnA,752
|
29
|
+
MEDfl/scripts/create_db.py,sha256=MnFtZkTueRZ-3qXPNX4JsXjOKj-4mlkxoRhSFdRcvJw,3817
|
30
|
+
alembic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
31
|
+
alembic/env.py,sha256=-aSZ6SlJeK1ZeqHgM-54hOi9LhJRFP0SZGjut-JnY-4,1588
|
32
|
+
MEDfl-2.0.4.dev3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
33
|
+
MEDfl-2.0.4.dev3.dist-info/METADATA,sha256=s5S0Ztt85jbo0x9kzkGRQDNrI9cMCed-qBWO4EA2q10,4326
|
34
|
+
MEDfl-2.0.4.dev3.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
35
|
+
MEDfl-2.0.4.dev3.dist-info/top_level.txt,sha256=dIL9X8HOFuaVSzpg40DVveDPrymWRoShHtspH7kkjdI,14
|
36
|
+
MEDfl-2.0.4.dev3.dist-info/RECORD,,
|
MEDfl/rw/rwConfig.py
DELETED
@@ -1,21 +0,0 @@
|
|
1
|
-
from dataclasses import dataclass
|
2
|
-
|
3
|
-
@dataclass
|
4
|
-
class RealWorldConfig:
|
5
|
-
"""
|
6
|
-
Configuration pour un déploiement fédéré en « real world ».
|
7
|
-
|
8
|
-
Attributes:
|
9
|
-
server_address: Adresse et port du serveur Flower (ex: "0.0.0.0:8080").
|
10
|
-
num_rounds: Nombre total de tours (rounds) de fédération.
|
11
|
-
fraction_fit: Fraction des clients participant à la phase de fit chaque round.
|
12
|
-
fraction_eval: Fraction des clients participant à la phase d'évaluation chaque round.
|
13
|
-
min_fit_clients: Nombre minimum de clients requis pour lancer la phase de fit.
|
14
|
-
min_eval_clients: Nombre minimum de clients requis pour la phase d'évaluation.
|
15
|
-
"""
|
16
|
-
server_address: str
|
17
|
-
num_rounds: int
|
18
|
-
fraction_fit: float
|
19
|
-
fraction_eval: float
|
20
|
-
min_fit_clients: int
|
21
|
-
min_eval_clients: int
|
MEDfl/rw/verbose_server.py
DELETED
@@ -1,21 +0,0 @@
|
|
1
|
-
from flwr.server import Server
|
2
|
-
from flwr.server.client_manager import SimpleClientManager
|
3
|
-
from flwr.server.client_proxy import ClientProxy
|
4
|
-
from flwr.common.logger import log
|
5
|
-
from logging import INFO
|
6
|
-
|
7
|
-
class VerboseServer(Server):
|
8
|
-
def __init__(self, strategy):
|
9
|
-
super().__init__(client_manager=SimpleClientManager(), strategy=strategy)
|
10
|
-
|
11
|
-
def client_manager_fn(self):
|
12
|
-
return self.client_manager
|
13
|
-
|
14
|
-
def on_client_connect(self, client: ClientProxy):
|
15
|
-
super().on_client_connect(client)
|
16
|
-
log(INFO, f"[Server] ➕ Client connected: {client.cid}")
|
17
|
-
log(INFO, f"[Server] Currently connected: {len(self.client_manager.all().values())} clients")
|
18
|
-
|
19
|
-
def on_client_disconnect(self, client: ClientProxy):
|
20
|
-
super().on_client_disconnect(client)
|
21
|
-
log(INFO, f"[Server] ❌ Client disconnected: {client.cid}")
|