aibt_fl-1.0.0-py3-none-any.whl
- aibt/__init__.py +77 -0
- aibt/aggregation.py +68 -0
- aibt/client.py +259 -0
- aibt/core.py +287 -0
- aibt/metrics.py +383 -0
- aibt/models.py +520 -0
- aibt/py.typed +2 -0
- aibt/utils.py +162 -0
- aibt_fl-1.0.0.dist-info/METADATA +247 -0
- aibt_fl-1.0.0.dist-info/RECORD +13 -0
- aibt_fl-1.0.0.dist-info/WHEEL +5 -0
- aibt_fl-1.0.0.dist-info/licenses/LICENSE +21 -0
- aibt_fl-1.0.0.dist-info/top_level.txt +1 -0
aibt/__init__.py
ADDED
@@ -0,0 +1,77 @@
"""
AIBT: Adversarial Information Bottleneck Training for Privacy-Preserving Federated Learning

This package provides a complete implementation of the AIBT framework for
privacy-preserving federated learning, combining:
- Information Bottleneck (IB) for feature compression
- Adversarial Training for sensitive attribute suppression
- Federated Learning for distributed training

Example:
    >>> from aibt import AIBTFL, create_aibt_model, evaluate_privacy
    >>> model = create_aibt_model(input_dim=13, num_classes=2)
    >>> aibt = AIBTFL(model=model, num_clients=10)
    >>> history = aibt.train(num_rounds=100, local_epochs=5)
"""

__version__ = "1.0.0"
__author__ = "AIBT Research Team"

# Core FL implementation
from aibt.core import AIBTFL

# Neural network models
from aibt.models import (
    GradientReversalLayer,
    GradientReversalFunction,
    VariationalEncoder,
    MLPEncoder,
    Predictor,
    Adversary,
    AIBTModel,
    create_aibt_model,
)

# Client implementation
from aibt.client import FLClient, AIBTClient

# Privacy metrics
from aibt.metrics import (
    MembershipInferenceAttack,
    AttributeInferenceAttack,
    evaluate_membership_inference,
    evaluate_attribute_inference,
    evaluate_privacy,
    evaluate_performance,
)

# Aggregation utilities
from aibt.aggregation import fedavg_aggregate

__all__ = [
    # Version
    "__version__",
    # Core
    "AIBTFL",
    # Models
    "GradientReversalLayer",
    "GradientReversalFunction",
    "VariationalEncoder",
    "MLPEncoder",
    "Predictor",
    "Adversary",
    "AIBTModel",
    "create_aibt_model",
    # Clients
    "FLClient",
    "AIBTClient",
    # Metrics
    "MembershipInferenceAttack",
    "AttributeInferenceAttack",
    "evaluate_membership_inference",
    "evaluate_attribute_inference",
    "evaluate_privacy",
    "evaluate_performance",
    # Aggregation
    "fedavg_aggregate",
]
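A quick smoke test (an editor's sketch, not part of the wheel) that the public surface declared in __all__ above resolves after installation; the expected output in the comments is an assumption:

import aibt

print(aibt.__version__)  # "1.0.0"

# Every name listed in __all__ should be reachable from the package root.
missing = [name for name in aibt.__all__ if not hasattr(aibt, name)]
print(missing)           # expected: []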
aibt/aggregation.py
ADDED
@@ -0,0 +1,68 @@
"""
FedAvg Aggregation utilities for AIBT.

Implements weighted averaging of model parameters for federated learning.
"""

import torch
from typing import Dict, List


def fedavg_aggregate(
    client_params: List[Dict[str, torch.Tensor]],
    client_sizes: List[int]
) -> Dict[str, torch.Tensor]:
    """
    FedAvg aggregation for federated learning.

    Computes weighted average of client model parameters:
        w_global = Σ(n_k/n) * w_k

    Args:
        client_params: List of model state dicts from clients
        client_sizes: List of dataset sizes per client

    Returns:
        Aggregated model state dict
    """
    if not client_params:
        raise ValueError("No client parameters provided")

    total_size = sum(client_sizes)
    aggregated = {}

    for key in client_params[0].keys():
        weighted_sum = torch.zeros_like(client_params[0][key], dtype=torch.float32)

        for i, params in enumerate(client_params):
            weight = client_sizes[i] / total_size
            weighted_sum += weight * params[key].float()

        aggregated[key] = weighted_sum

    return aggregated


def simple_average(
    client_params: List[Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """
    Simple averaging of client parameters (unweighted).

    Args:
        client_params: List of model state dicts from clients

    Returns:
        Aggregated model state dict
    """
    if not client_params:
        raise ValueError("No client parameters provided")

    num_clients = len(client_params)
    aggregated = {}

    for key in client_params[0].keys():
        stacked = torch.stack([params[key].float() for params in client_params])
        aggregated[key] = stacked.mean(dim=0)

    return aggregated
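A small worked example (an editor's sketch, not shipped with the package) of the two aggregators above: two clients with dataset sizes 100 and 200 contribute weights 1/3 and 2/3, so parameter values 0.0 and 3.0 average to 2.0 under FedAvg and 1.5 under the unweighted mean. The state-dict key name is made up for illustration.

import torch
from aibt.aggregation import fedavg_aggregate, simple_average

# Toy "state dicts" with a single hypothetical key.
client_a = {"fc.weight": torch.tensor([[0.0, 0.0]])}
client_b = {"fc.weight": torch.tensor([[3.0, 3.0]])}

# Weighted by dataset size: (100/300) * 0.0 + (200/300) * 3.0 = 2.0
weighted = fedavg_aggregate([client_a, client_b], client_sizes=[100, 200])
print(weighted["fc.weight"])   # tensor([[2., 2.]])

# Unweighted mean: (0.0 + 3.0) / 2 = 1.5
plain = simple_average([client_a, client_b])
print(plain["fc.weight"])      # tensor([[1.5000, 1.5000]])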
aibt/client.py
ADDED
@@ -0,0 +1,259 @@
"""
Federated Learning Client implementations for AIBT.

Implements:
- FLClient: Base federated learning client
- AIBTClient: AIBT-specific client with IB and adversarial training
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple, Optional
import copy
import numpy as np


class FLClient:
    """
    Federated Learning Client.
    Each client has local data and performs local training.
    """

    def __init__(
        self,
        client_id: int,
        model: nn.Module,
        train_data: Tuple[np.ndarray, np.ndarray],
        sensitive_data: Optional[np.ndarray] = None,
        batch_size: int = 32,
        learning_rate: float = 0.001,
        device: str = "cpu"
    ):
        self.client_id = client_id
        self.device = device
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        # Create local model (deep copy)
        self.model = copy.deepcopy(model).to(device)

        # Prepare data
        X, y = train_data
        self.n_samples = len(y)

        # Check for sensitive attributes
        self.has_sensitive = sensitive_data is not None

        if self.has_sensitive:
            dataset = TensorDataset(
                torch.FloatTensor(X),
                torch.LongTensor(y),
                torch.LongTensor(sensitive_data)
            )
        else:
            dataset = TensorDataset(
                torch.FloatTensor(X),
                torch.LongTensor(y)
            )

        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=len(dataset) > batch_size
        )

        # Optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate
        )

    def set_model_params(self, params: Dict[str, torch.Tensor]) -> None:
        """Load global model parameters"""
        self.model.load_state_dict(params)

    def get_model_params(self) -> Dict[str, torch.Tensor]:
        """Get local model parameters"""
        return copy.deepcopy(self.model.state_dict())

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for batch in self.dataloader:
            if self.has_sensitive:
                X, y, s = batch
                s = s.to(self.device)
            else:
                X, y = batch
                s = None

            X, y = X.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            output = self.model(X)
            if isinstance(output, tuple):
                y_pred = output[0]
            else:
                y_pred = output

            loss = F.cross_entropy(y_pred, y)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * len(y)
            pred = y_pred.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += len(y)

        return {
            "loss": total_loss / total,
            "accuracy": correct / total
        }

    def train(self, local_epochs: int) -> Dict[str, float]:
        """Train for multiple epochs"""
        metrics = {}
        for epoch in range(local_epochs):
            metrics = self.train_epoch()
        return metrics

    def evaluate(self, test_data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
        """Evaluate model on test data"""
        self.model.eval()
        X, y = test_data
        X_tensor = torch.FloatTensor(X).to(self.device)
        y_tensor = torch.LongTensor(y).to(self.device)

        with torch.no_grad():
            output = self.model(X_tensor)
            if isinstance(output, tuple):
                logits = output[0]
            else:
                logits = output

            loss = F.cross_entropy(logits, y_tensor).item()
            pred = logits.argmax(dim=1)
            accuracy = (pred == y_tensor).float().mean().item()

        return {"test_loss": loss, "test_accuracy": accuracy}


class AIBTClient(FLClient):
    """
    AIBT Client with Information Bottleneck and Adversarial Training.

    Loss = L_task + λ₁ L_KL - λ₂ L_adv

    Training procedure:
    1. Forward pass through encoder → z (with reparameterization)
    2. Task prediction via predictor
    3. Adversarial prediction via GRL + adversary
    4. Compute combined loss
    5. Update encoder and predictor
    """

    def __init__(
        self,
        *args,
        lambda_kl: float = 0.01,
        lambda_adv: float = 1.0,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lambda_kl = lambda_kl
        self.lambda_adv = lambda_adv

    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch using AIBT algorithm.

        Algorithm 1/2 from the paper:
        For each batch:
        1. z = Encoder(x) with reparameterization
        2. y_hat = Predictor(z)
        3. s_hat = Adversary(GRL(z))
        4. L = L_task + λ₁ L_KL - λ₂ L_adv
        5. Update encoder and predictor
        """
        self.model.train()
        total_loss = 0.0
        total_task_loss = 0.0
        total_kl_loss = 0.0
        total_adv_loss = 0.0
        correct = 0
        total = 0

        for batch in self.dataloader:
            if self.has_sensitive:
                X, y, s = batch
                s = s.to(self.device)
            else:
                X, y = batch
                s = None

            X, y = X.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass through AIBT model
            if hasattr(self.model, 'compute_loss'):
                # Using AIBTModel with built-in loss computation
                loss, loss_dict = self.model.compute_loss(X, y, s)

                # Get predictions for accuracy
                with torch.no_grad():
                    y_pred, _, _, _ = self.model(X)
                    pred = y_pred.argmax(dim=1)

                total_task_loss += loss_dict["task_loss"] * len(y)
                total_kl_loss += loss_dict["kl_loss"] * len(y)
                total_adv_loss += loss_dict["adv_loss"] * len(y)
            else:
                # Manual computation for non-AIBT models
                output = self.model(X)
                if isinstance(output, tuple):
                    y_pred = output[0]
                else:
                    y_pred = output

                loss = F.cross_entropy(y_pred, y)
                pred = y_pred.argmax(dim=1)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * len(y)
            correct += (pred == y).sum().item()
            total += len(y)

        return {
            "loss": total_loss / total,
            "task_loss": total_task_loss / total if total_task_loss > 0 else 0,
            "kl_loss": total_kl_loss / total if total_kl_loss > 0 else 0,
            "adv_loss": total_adv_loss / total if total_adv_loss > 0 else 0,
            "accuracy": correct / total
        }

    def get_latent_representations(self, X: np.ndarray) -> np.ndarray:
        """Get latent representations for privacy analysis"""
        self.model.eval()
        X_tensor = torch.FloatTensor(X).to(self.device)

        with torch.no_grad():
            if hasattr(self.model, 'get_latent'):
                z = self.model.get_latent(X_tensor)
            else:
                z = self.model.encoder(X_tensor)
                if isinstance(z, tuple):
                    z = z[0]

        return z.cpu().numpy()
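A runnable sketch (an editor's example, not shipped with the package) of the fallback branch in AIBTClient.train_epoch above: the model here is a plain two-layer classifier with no compute_loss method, so only the cross-entropy path runs and kl_loss/adv_loss stay at 0. The synthetic arrays, layer sizes, and label rule are placeholders.

import numpy as np
import torch.nn as nn
from aibt.client import AIBTClient

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 13)).astype("float32")   # 200 samples, 13 features
y = (X[:, 0] > 0).astype("int64")                  # synthetic binary task label
s = rng.integers(0, 2, size=200)                   # synthetic sensitive attribute

toy_model = nn.Sequential(nn.Linear(13, 32), nn.ReLU(), nn.Linear(32, 2))

client = AIBTClient(
    client_id=0,
    model=toy_model,
    train_data=(X, y),
    sensitive_data=s,
    batch_size=32,
    lambda_kl=0.01,
    lambda_adv=1.0,
)

print(client.train(local_epochs=2))   # loss/accuracy; kl_loss and adv_loss are 0 here
print(client.evaluate((X, y)))        # {"test_loss": ..., "test_accuracy": ...}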
aibt/core.py
ADDED
@@ -0,0 +1,287 @@
"""
AIBT FL: Adversarial Information Bottleneck Training for Federated Learning

Main federated learning implementation combining:
- Information Bottleneck (IB) for feature compression
- Adversarial Training for sensitive attribute suppression
- Federated Averaging for model aggregation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Any
import copy
import numpy as np

from aibt.client import AIBTClient, FLClient
from aibt.aggregation import fedavg_aggregate


class AIBTFL:
    """
    AIBT FL: Adversarial Information Bottleneck Training for Federated Learning.

    Combines:
    - Information Bottleneck (IB) for feature compression
    - Adversarial Training for sensitive attribute suppression
    - Federated Learning for distributed training

    Loss: L = L_task + λ₁ L_KL - λ₂ L_adv

    Args:
        model: Neural network model (preferably AIBTModel)
        num_clients: Number of federated learning clients
        device: Device for computation ('cpu' or 'cuda')
        learning_rate: Learning rate for local training
        batch_size: Batch size for local training
        lambda_kl: Weight for KL divergence loss (Information Bottleneck)
        lambda_adv: Weight for adversarial loss
        lambda_grl: Strength of gradient reversal
        latent_dim: Dimension of latent representation
        num_sensitive_classes: Number of sensitive attribute classes

    Example:
        >>> from aibt import AIBTFL, create_aibt_model
        >>> model = create_aibt_model(input_dim=13, num_classes=2)
        >>> aibt = AIBTFL(model=model, num_clients=10)
        >>> aibt.setup_clients(client_datasets)
        >>> history = aibt.train(num_rounds=100, local_epochs=5)
    """

    def __init__(
        self,
        model: nn.Module,
        num_clients: int,
        device: str = "cpu",
        learning_rate: float = 0.001,
        batch_size: int = 32,
        lambda_kl: float = 0.01,
        lambda_adv: float = 1.0,
        lambda_grl: float = 1.0,
        latent_dim: int = 128,
        num_sensitive_classes: int = 2
    ):
        self.device = device
        self.num_clients = num_clients
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.lambda_kl = lambda_kl
        self.lambda_adv = lambda_adv
        self.lambda_grl = lambda_grl
        self.latent_dim = latent_dim
        self.num_sensitive_classes = num_sensitive_classes
        self.method_name = "AIBT"

        # Global model
        self.global_model = copy.deepcopy(model).to(device)

        # Update model parameters if it's an AIBT model
        if hasattr(self.global_model, 'lambda_kl'):
            self.global_model.lambda_kl = lambda_kl
            self.global_model.lambda_adv = lambda_adv
        if hasattr(self.global_model, 'grl'):
            self.global_model.grl.set_lambda(lambda_grl)

        # Clients (initialized later)
        self.clients: List[AIBTClient] = []

        # Metrics tracking
        self.current_round = 0
        self.round_metrics: List[Dict[str, float]] = []

    def setup_clients(
        self,
        client_datasets: List[Tuple[np.ndarray, np.ndarray]],
        sensitive_data: Optional[List[np.ndarray]] = None
    ) -> None:
        """
        Initialize clients with their local data.

        Args:
            client_datasets: List of (X, y) tuples, one per client
            sensitive_data: Optional list of sensitive attribute arrays
        """
        self.clients = []

        for i, data in enumerate(client_datasets):
            s_data = sensitive_data[i] if sensitive_data else None
            client = self._create_client(i, data, s_data)
            self.clients.append(client)

    def _create_client(
        self,
        client_id: int,
        train_data: Tuple[np.ndarray, np.ndarray],
        sensitive_data: Optional[np.ndarray] = None
    ) -> AIBTClient:
        """Create an AIBT client"""
        return AIBTClient(
            client_id=client_id,
            model=self.global_model,
            train_data=train_data,
            sensitive_data=sensitive_data,
            batch_size=self.batch_size,
            learning_rate=self.learning_rate,
            device=self.device,
            lambda_kl=self.lambda_kl,
            lambda_adv=self.lambda_adv
        )

    def distribute_model(self) -> None:
        """Send global model to all clients"""
        global_params = self.global_model.state_dict()
        for client in self.clients:
            client.set_model_params(copy.deepcopy(global_params))

    def train_round(self, local_epochs: int) -> Dict[str, float]:
        """
        Execute one round of AIBT federated training.

        Following Algorithm 3 from the paper:
        1. Distribute global model to clients
        2. Each client trains using Algorithm 1/2
        3. Aggregate updates using FedAvg
        4. Update global model
        """
        self.current_round += 1

        # Distribute global model
        self.distribute_model()

        # Local training
        client_params = []
        client_sizes = []
        client_metrics = []

        for client in self.clients:
            # Local AIBT training
            metrics = client.train(local_epochs)
            client_metrics.append(metrics)

            # Collect parameters
            client_params.append(client.get_model_params())
            client_sizes.append(client.n_samples)

        # Aggregate using FedAvg
        aggregated_params = fedavg_aggregate(client_params, client_sizes)
        self.global_model.load_state_dict(aggregated_params)

        # Compute round metrics
        avg_loss = np.mean([m["loss"] for m in client_metrics])
        avg_acc = np.mean([m["accuracy"] for m in client_metrics])
        avg_task_loss = np.mean([m.get("task_loss", 0) for m in client_metrics])
        avg_kl_loss = np.mean([m.get("kl_loss", 0) for m in client_metrics])
        avg_adv_loss = np.mean([m.get("adv_loss", 0) for m in client_metrics])

        round_metrics = {
            "round": self.current_round,
            "train_loss": avg_loss,
            "train_accuracy": avg_acc,
            "task_loss": avg_task_loss,
            "kl_loss": avg_kl_loss,
            "adv_loss": avg_adv_loss,
        }

        self.round_metrics.append(round_metrics)
        return round_metrics

    def evaluate(self, test_data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
        """Evaluate global model on test data"""
        self.global_model.eval()

        X, y = test_data
        X_tensor = torch.FloatTensor(X).to(self.device)
        y_tensor = torch.LongTensor(y).to(self.device)

        with torch.no_grad():
            output = self.global_model(X_tensor)
            if isinstance(output, tuple):
                logits = output[0]
            else:
                logits = output

            loss = F.cross_entropy(logits, y_tensor).item()
            pred = logits.argmax(dim=1)
            accuracy = (pred == y_tensor).float().mean().item()

        return {"test_loss": loss, "test_accuracy": accuracy}

    def get_latent_representations(self, X: np.ndarray) -> np.ndarray:
        """Get latent representations from global model"""
        self.global_model.eval()
        X_tensor = torch.FloatTensor(X).to(self.device)

        with torch.no_grad():
            if hasattr(self.global_model, 'get_latent'):
                z = self.global_model.get_latent(X_tensor)
            elif hasattr(self.global_model, 'encoder'):
                z = self.global_model.encoder(X_tensor)
                if isinstance(z, tuple):
                    z = z[0]
            else:
                raise AttributeError("Model does not have encoder")

        return z.cpu().numpy()

    def train(
        self,
        num_rounds: int,
        local_epochs: int,
        test_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        verbose: bool = True
    ) -> Dict[str, Any]:
        """
        Complete AIBT federated training loop.

        Args:
            num_rounds: Number of federated learning rounds
            local_epochs: Number of local training epochs per round
            test_data: Optional (X_test, y_test) for evaluation
            verbose: Print progress

        Returns:
            Dictionary with training history and final metrics
        """
        history = {
            "round_metrics": [],
            "test_metrics": []
        }

        for round_idx in range(num_rounds):
            # Training round
            round_metrics = self.train_round(local_epochs)

            # Evaluation
            if test_data is not None:
                test_metrics = self.evaluate(test_data)
                round_metrics.update(test_metrics)

            history["round_metrics"].append(round_metrics)

            if verbose:
                print(f"Round {round_idx + 1}/{num_rounds} | "
                      f"Loss: {round_metrics['train_loss']:.4f} | "
                      f"Acc: {round_metrics['train_accuracy']:.4f}", end="")
                if test_data is not None:
                    print(f" | Test Acc: {round_metrics.get('test_accuracy', 0):.4f}")
                else:
                    print()

        # Final evaluation
        if test_data is not None:
            history["final_test_metrics"] = self.evaluate(test_data)

        return history

    def get_model(self) -> nn.Module:
        """Get the global model"""
        return self.global_model

    def get_communication_cost(self) -> int:
        """Calculate communication cost in bytes"""
        total_bytes = 0
        for param in self.global_model.parameters():
            total_bytes += param.numel() * param.element_size()
        return total_bytes
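An end-to-end sketch (an editor's example, not shipped with the package) of the loop above. The AIBT-specific model factory lives in aibt.models, which this diff does not show, so a plain classifier stands in; AIBTFL then behaves like vanilla FedAvg because the IB and adversarial terms only appear when the model provides compute_loss. Data shapes, client counts, and hyperparameters below are illustrative only.

import numpy as np
import torch.nn as nn
from aibt import AIBTFL

rng = np.random.default_rng(0)

def make_split(n):
    # Synthetic 13-feature binary classification data, one split per client.
    X = rng.normal(size=(n, 13)).astype("float32")
    y = (X[:, 0] + 0.1 * rng.normal(size=n) > 0).astype("int64")
    return X, y

client_datasets = [make_split(200) for _ in range(4)]
test_data = make_split(500)

stand_in_model = nn.Sequential(nn.Linear(13, 32), nn.ReLU(), nn.Linear(32, 2))

aibt = AIBTFL(model=stand_in_model, num_clients=4, lambda_kl=0.01, lambda_adv=1.0)
aibt.setup_clients(client_datasets)

history = aibt.train(num_rounds=5, local_epochs=2, test_data=test_data)
print(history["final_test_metrics"])
print(aibt.get_communication_cost(), "bytes of parameters per model upload/download")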