aibt_fl-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aibt/__init__.py +77 -0
- aibt/aggregation.py +68 -0
- aibt/client.py +259 -0
- aibt/core.py +287 -0
- aibt/metrics.py +383 -0
- aibt/models.py +520 -0
- aibt/py.typed +2 -0
- aibt/utils.py +162 -0
- aibt_fl-1.0.0.dist-info/METADATA +247 -0
- aibt_fl-1.0.0.dist-info/RECORD +13 -0
- aibt_fl-1.0.0.dist-info/WHEEL +5 -0
- aibt_fl-1.0.0.dist-info/licenses/LICENSE +21 -0
- aibt_fl-1.0.0.dist-info/top_level.txt +1 -0
aibt/models.py
ADDED
@@ -0,0 +1,520 @@
"""
Neural Network Components for AIBT Framework

Implements:
- Gradient Reversal Layer (GRL) for adversarial training
- Variational Encoder for Information Bottleneck
- MLP Encoder for tabular data
- Predictor (Task Head)
- Adversary (Sensitive Attribute Classifier)
- Complete AIBTModel
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Optional
import math


# ============================================================================
# GRADIENT REVERSAL LAYER
# ============================================================================

class GradientReversalFunction(torch.autograd.Function):
    """
    Gradient Reversal Layer for adversarial training.
    Forward pass: identity function.
    Backward pass: negate gradients and scale by lambda.

    Reference: Ganin & Lempitsky, "Domain-Adversarial Training of Neural Networks"
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambda_grl: float) -> torch.Tensor:
        ctx.lambda_grl = lambda_grl
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        return -ctx.lambda_grl * grad_output, None


class GradientReversalLayer(nn.Module):
    """Wrapper module for Gradient Reversal"""

    def __init__(self, lambda_grl: float = 1.0):
        super().__init__()
        self.lambda_grl = lambda_grl

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return GradientReversalFunction.apply(x, self.lambda_grl)

    def set_lambda(self, lambda_grl: float) -> None:
        """Update lambda value (useful for scheduling)"""
        self.lambda_grl = lambda_grl


def grad_reverse(x: torch.Tensor, lambda_grl: float = 1.0) -> torch.Tensor:
    """Functional version of gradient reversal"""
    return GradientReversalFunction.apply(x, lambda_grl)

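A quick sanity check of the reversal behaviour (an editor's illustrative sketch, not part of the package): the forward pass is the identity, while gradients flowing back through the layer are negated and scaled by lambda_grl.

# Editor's sketch: GradientReversalLayer leaves activations unchanged but flips gradients.
import torch
from aibt.models import GradientReversalLayer

grl = GradientReversalLayer(lambda_grl=0.5)
x = torch.randn(4, 8, requires_grad=True)
y = grl(x)                  # identity in the forward pass
y.sum().backward()          # gradient of the sum w.r.t. x would be all ones without the GRL
print(torch.allclose(x.grad, torch.full_like(x, -0.5)))   # True: gradients become -lambda_grl
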
# ============================================================================
# VARIATIONAL ENCODER (Information Bottleneck)
# ============================================================================

class VariationalEncoder(nn.Module):
    """
    Variational Encoder for Information Bottleneck.
    Outputs mean (mu) and log-variance (logvar) for the reparameterization trick.

    z = mu + sigma * epsilon, where epsilon ~ N(0, I)
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        latent_dim: int,
        dropout: float = 0.1
    ):
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Build encoder layers
        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim

        self.encoder = nn.Sequential(*layers)

        # Mean and log-variance heads
        self.fc_mu = nn.Linear(prev_dim, latent_dim)
        self.fc_logvar = nn.Linear(prev_dim, latent_dim)

        self._init_weights()

    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input to mu and logvar"""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: z = mu + sigma * epsilon
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass.
        Returns: z (latent), mu, logvar
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

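An illustrative shape check for the encoder above (editor's sketch; the dimensions are arbitrary). In training mode z is sampled via the reparameterization trick; in eval mode the mean is returned directly.

# Editor's sketch: 16 samples with 32 features encoded to an 8-dimensional latent code.
import torch
from aibt.models import VariationalEncoder

enc = VariationalEncoder(input_dim=32, hidden_dims=[64, 32], latent_dim=8)
x = torch.randn(16, 32)

enc.train()
z, mu, logvar = enc(x)
print(z.shape, mu.shape, logvar.shape)    # torch.Size([16, 8]) for all three

enc.eval()
z_eval, mu_eval, _ = enc(x)
print(torch.equal(z_eval, mu_eval))       # True: no sampling outside training
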
# ============================================================================
# MLP ENCODER (For Tabular Data)
# ============================================================================

class MLPEncoder(nn.Module):
    """
    MLP Encoder for tabular datasets.
    Can be variational or deterministic.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        latent_dim: int,
        variational: bool = True,
        dropout: float = 0.1
    ):
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.variational = variational

        # Build encoder layers
        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim

        self.encoder = nn.Sequential(*layers)

        if variational:
            self.fc_mu = nn.Linear(prev_dim, latent_dim)
            self.fc_logvar = nn.Linear(prev_dim, latent_dim)
        else:
            self.fc_out = nn.Linear(prev_dim, latent_dim)

        self._init_weights()

    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor):
        h = self.encoder(x)

        if self.variational:
            mu = self.fc_mu(h)
            logvar = self.fc_logvar(h)

            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar
        else:
            return self.fc_out(h)

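Note that the return type of MLPEncoder.forward depends on the variational flag: a (z, mu, logvar) tuple when variational=True, a single tensor otherwise. A short sketch of both modes (editor's illustration, not part of the package):

# Editor's sketch: the two return forms of MLPEncoder.
import torch
from aibt.models import MLPEncoder

x = torch.randn(5, 20)

var_enc = MLPEncoder(input_dim=20, hidden_dims=[32], latent_dim=8, variational=True)
z, mu, logvar = var_enc(x)     # tuple of three (5, 8) tensors

det_enc = MLPEncoder(input_dim=20, hidden_dims=[32], latent_dim=8, variational=False)
z_det = det_enc(x)             # single (5, 8) tensor
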
# ============================================================================
# PREDICTOR (Task Head)
# ============================================================================

class Predictor(nn.Module):
    """
    Predictor network for task-specific predictions.
    Takes latent representation z and outputs class logits
    (consumed directly by F.cross_entropy; apply softmax for probabilities).
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dims: List[int],
        num_classes: int,
        dropout: float = 0.1
    ):
        super().__init__()

        layers = []
        prev_dim = latent_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, num_classes))

        self.predictor = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.predictor(z)

# ============================================================================
# ADVERSARY (Sensitive Attribute Classifier)
# ============================================================================

class Adversary(nn.Module):
    """
    Adversary network for predicting sensitive attributes.
    Used with GRL to enforce privacy in the latent representation.
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dims: List[int],
        num_sensitive_classes: int,
        dropout: float = 0.1
    ):
        super().__init__()

        layers = []
        prev_dim = latent_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, num_sensitive_classes))

        self.adversary = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.adversary(z)

# ============================================================================
# COMPLETE AIBT MODEL
# ============================================================================

class AIBTModel(nn.Module):
    """
    Complete AIBT (Adversarial Information Bottleneck Training) Model.

    Combines:
    - Encoder (Variational for IB)
    - Predictor (Task head)
    - Adversary (Sensitive attribute classifier)
    - Gradient Reversal Layer (GRL)

    Loss: L = L_task + λ₁ L_KL - λ₂ L_adv
    """

    def __init__(
        self,
        encoder: nn.Module,
        predictor: nn.Module,
        adversary: nn.Module,
        lambda_kl: float = 0.01,
        lambda_adv: float = 1.0,
        lambda_grl: float = 1.0
    ):
        super().__init__()

        self.encoder = encoder
        self.predictor = predictor
        self.adversary = adversary
        self.grl = GradientReversalLayer(lambda_grl)

        self.lambda_kl = lambda_kl
        self.lambda_adv = lambda_adv

    def forward(
        self,
        x: torch.Tensor,
        return_latent: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Returns:
            y_pred: Task predictions
            s_pred: Sensitive attribute predictions (through GRL)
            mu: Mean (for KL loss)
            logvar: Log variance (for KL loss)
            z: Latent representation (appended only when return_latent=True)
        """
        # Encode
        encoder_out = self.encoder(x)

        if isinstance(encoder_out, tuple) and len(encoder_out) == 3:
            z, mu, logvar = encoder_out
        else:
            z = encoder_out
            mu = z
            logvar = torch.zeros_like(z)

        # Task prediction
        y_pred = self.predictor(z)

        # Adversarial prediction (with gradient reversal)
        z_reversed = self.grl(z)
        s_pred = self.adversary(z_reversed)

        if return_latent:
            return y_pred, s_pred, mu, logvar, z

        return y_pred, s_pred, mu, logvar

    def compute_kl_loss(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Compute KL divergence loss for variational bottleneck.
        KL(q(z|x) || p(z)) where p(z) = N(0, I)

        KL = -0.5 * sum(1 + log(σ²) - μ² - σ²)
        """
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        return kl_loss.mean()

    def compute_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        s: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute complete AIBT loss.

        L = L_task + λ₁ L_KL - λ₂ L_adv

        Returns:
            total_loss: Combined loss for backpropagation
            loss_dict: Dictionary of individual loss components
        """
        y_pred, s_pred, mu, logvar = self.forward(x)

        # Task loss
        task_loss = F.cross_entropy(y_pred, y)

        # KL divergence loss
        kl_loss = self.compute_kl_loss(mu, logvar)

        # Adversarial loss (if sensitive attributes provided)
        if s is not None:
            adv_loss = F.cross_entropy(s_pred, s)
        else:
            adv_loss = torch.tensor(0.0, device=x.device)

        # Combined loss: L = L_task + λ₁ L_KL - λ₂ L_adv
        # Note: the adversarial term is added with a plus sign here because the
        # minus sign in the objective is already applied by the gradient
        # reversal layer during backpropagation.
        total_loss = task_loss + self.lambda_kl * kl_loss + self.lambda_adv * adv_loss

        loss_dict = {
            "total_loss": total_loss.item(),
            "task_loss": task_loss.item(),
            "kl_loss": kl_loss.item(),
            "adv_loss": adv_loss.item() if isinstance(adv_loss, torch.Tensor) else adv_loss,
        }

        return total_loss, loss_dict

    def get_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Get latent representation for privacy analysis"""
        self.eval()
        with torch.no_grad():
            encoder_out = self.encoder(x)
            if isinstance(encoder_out, tuple):
                z = encoder_out[0]
            else:
                z = encoder_out
            return z

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Get task predictions"""
        self.eval()
        with torch.no_grad():
            y_pred, _, _, _ = self.forward(x)
            return y_pred

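A minimal single-batch training step against synthetic data (editor's sketch, not part of the package; all dimensions and hyperparameters here are arbitrary). The adversarial term is added to the total loss because the gradient reversal layer supplies the minus sign during backpropagation.

# Editor's sketch: one optimisation step of the full AIBT model.
import torch
from aibt.models import MLPEncoder, Predictor, Adversary, AIBTModel

encoder = MLPEncoder(input_dim=32, hidden_dims=[64], latent_dim=16, variational=True)
predictor = Predictor(latent_dim=16, hidden_dims=[32], num_classes=2)
adversary = Adversary(latent_dim=16, hidden_dims=[32], num_sensitive_classes=2)
model = AIBTModel(encoder, predictor, adversary, lambda_kl=0.01, lambda_adv=1.0, lambda_grl=1.0)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(64, 32)                  # features
y = torch.randint(0, 2, (64,))           # task labels
s = torch.randint(0, 2, (64,))           # sensitive-attribute labels

model.train()
optimizer.zero_grad()
total_loss, loss_dict = model.compute_loss(x, y, s)
total_loss.backward()                    # GRL flips the adversary gradient w.r.t. the encoder
optimizer.step()
print(loss_dict)                         # per-component losses as plain floats
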
# ============================================================================
# MODEL FACTORY
# ============================================================================

def create_aibt_model(
    input_dim: int,
    num_classes: int,
    latent_dim: int = 64,
    hidden_dims: Optional[List[int]] = None,
    num_sensitive_classes: int = 2,
    lambda_kl: float = 0.01,
    lambda_adv: float = 1.0,
    lambda_grl: float = 1.0,
    dropout: float = 0.1
) -> AIBTModel:
    """
    Factory function to create an AIBT model.

    Args:
        input_dim: Input feature dimension
        num_classes: Number of output classes
        latent_dim: Dimension of latent representation
        hidden_dims: Hidden layer dimensions (default: [128, 64])
        num_sensitive_classes: Number of sensitive attribute classes
        lambda_kl: KL divergence weight
        lambda_adv: Adversarial loss weight
        lambda_grl: Gradient reversal strength
        dropout: Dropout rate

    Returns:
        AIBTModel: Complete AIBT model
    """
    if hidden_dims is None:
        hidden_dims = [128, 64]

    # Create encoder
    encoder = MLPEncoder(
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        latent_dim=latent_dim,
        variational=True,
        dropout=dropout
    )

    # Create predictor
    predictor = Predictor(
        latent_dim=latent_dim,
        hidden_dims=[64],
        num_classes=num_classes,
        dropout=dropout
    )

    # Create adversary
    adversary = Adversary(
        latent_dim=latent_dim,
        hidden_dims=[64],
        num_sensitive_classes=num_sensitive_classes,
        dropout=dropout
    )

    # Create complete model
    model = AIBTModel(
        encoder=encoder,
        predictor=predictor,
        adversary=adversary,
        lambda_kl=lambda_kl,
        lambda_adv=lambda_adv,
        lambda_grl=lambda_grl
    )

    return model
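A brief illustration of the factory defaults (editor's sketch; the feature dimension is arbitrary):

# Editor's sketch: build a model via the factory and run eval-mode inference.
import torch
from aibt.models import create_aibt_model

model = create_aibt_model(input_dim=30, num_classes=2, latent_dim=32,
                          num_sensitive_classes=2)
x = torch.randn(8, 30)
logits = model.predict(x)      # task logits, shape (8, 2)
z = model.get_latent(x)        # latent codes, shape (8, 32)
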
aibt/py.typed
ADDED
aibt/utils.py
ADDED
@@ -0,0 +1,162 @@
"""
Utility functions for AIBT Framework
"""

import torch
import numpy as np
from typing import Tuple, List, Optional


def get_device(prefer_gpu: bool = True) -> str:
    """
    Get the best available device.

    Args:
        prefer_gpu: If True, prefer GPU if available

    Returns:
        Device string ('cuda' or 'cpu')
    """
    if prefer_gpu and torch.cuda.is_available():
        return "cuda"
    return "cpu"


def set_seed(seed: int = 42) -> None:
    """
    Set random seed for reproducibility.

    Args:
        seed: Random seed
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def split_data_iid(
    X: np.ndarray,
    y: np.ndarray,
    num_clients: int,
    seed: int = 42
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Split data IID among clients.

    Args:
        X: Features
        y: Labels
        num_clients: Number of clients
        seed: Random seed

    Returns:
        List of (X, y) tuples for each client
    """
    np.random.seed(seed)
    n_samples = len(y)
    indices = np.random.permutation(n_samples)

    # Split indices among clients
    splits = np.array_split(indices, num_clients)

    client_data = []
    for split in splits:
        client_data.append((X[split], y[split]))

    return client_data

def split_data_non_iid(
    X: np.ndarray,
    y: np.ndarray,
    num_clients: int,
    alpha: float = 0.5,
    seed: int = 42
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Split data non-IID among clients using a Dirichlet distribution.

    Args:
        X: Features
        y: Labels
        num_clients: Number of clients
        alpha: Dirichlet concentration parameter (smaller = more non-IID)
        seed: Random seed

    Returns:
        List of (X, y) tuples for each client
    """
    np.random.seed(seed)
    n_samples = len(y)
    num_classes = len(np.unique(y))

    # Get indices for each class
    class_indices = {c: np.where(y == c)[0] for c in range(num_classes)}

    # Sample proportions from Dirichlet distribution
    proportions = np.random.dirichlet([alpha] * num_clients, num_classes)

    # Assign samples to clients
    client_indices = [[] for _ in range(num_clients)]

    for c in range(num_classes):
        indices = class_indices[c]
        np.random.shuffle(indices)

        # Split class samples according to proportions
        splits = np.split(
            indices,
            (np.cumsum(proportions[c]) * len(indices)).astype(int)[:-1]
        )

        for client_id, split in enumerate(splits):
            client_indices[client_id].extend(split)

    # Create client datasets
    client_data = []
    for indices in client_indices:
        indices = np.array(indices)
        if len(indices) > 0:
            client_data.append((X[indices], y[indices]))
        else:
            # Empty client, add minimal data
            client_data.append((X[:1], y[:1]))

    return client_data

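An illustrative comparison of the two partitioning helpers on synthetic data (editor's sketch; the sizes and label counts are arbitrary). Smaller alpha values concentrate each class on fewer clients.

# Editor's sketch: partition a toy dataset across 5 clients, IID and non-IID.
import numpy as np
from aibt.utils import split_data_iid, split_data_non_iid

X = np.random.randn(1000, 20)
y = np.random.randint(0, 4, size=1000)

iid_parts = split_data_iid(X, y, num_clients=5)
non_iid_parts = split_data_non_iid(X, y, num_clients=5, alpha=0.1)

for i, (Xi, yi) in enumerate(non_iid_parts):
    print(f"client {i}: {len(yi)} samples, label counts {np.bincount(yi, minlength=4)}")
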
def normalize_data(
    X_train: np.ndarray,
    X_test: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Normalize data using training statistics.

    Args:
        X_train: Training features
        X_test: Test features

    Returns:
        Normalized (X_train, X_test)
    """
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0) + 1e-8

    X_train_norm = (X_train - mean) / std
    X_test_norm = (X_test - mean) / std

    return X_train_norm, X_test_norm


def count_parameters(model: torch.nn.Module) -> int:
    """
    Count trainable parameters in a model.

    Args:
        model: PyTorch model

    Returns:
        Number of trainable parameters
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
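Taken together, a typical preprocessing flow might look like the following (editor's sketch, not part of the package; create_aibt_model is the factory from aibt/models.py above):

# Editor's sketch: seeding, device selection, normalisation, and a parameter count.
import numpy as np
from aibt.models import create_aibt_model
from aibt.utils import set_seed, get_device, normalize_data, count_parameters

set_seed(42)
device = get_device()                                 # 'cuda' if available, else 'cpu'

X_train = np.random.randn(800, 30)
X_test = np.random.randn(200, 30)
X_train, X_test = normalize_data(X_train, X_test)     # z-score using training statistics

model = create_aibt_model(input_dim=30, num_classes=2).to(device)
print(count_parameters(model))                        # number of trainable parameters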