aibt_fl-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aibt/models.py ADDED
@@ -0,0 +1,520 @@
+ """
+ Neural Network Components for AIBT Framework
+
+ Implements:
+ - Gradient Reversal Layer (GRL) for adversarial training
+ - Variational Encoder for Information Bottleneck
+ - MLP Encoder for tabular data
+ - Predictor (Task Head)
+ - Adversary (Sensitive Attribute Classifier)
+ - Complete AIBTModel
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Tuple, List, Optional
+ import math
+
+
+ # ============================================================================
+ # GRADIENT REVERSAL LAYER
+ # ============================================================================
+
+ class GradientReversalFunction(torch.autograd.Function):
+     """
+     Gradient Reversal Layer for adversarial training.
+     Forward pass: identity function
+     Backward pass: negate gradients and scale by lambda
+
+     Reference: Ganin & Lempitsky, "Domain-Adversarial Training of Neural Networks"
+     """
+
+     @staticmethod
+     def forward(ctx, x: torch.Tensor, lambda_grl: float) -> torch.Tensor:
+         ctx.lambda_grl = lambda_grl
+         return x.view_as(x)
+
+     @staticmethod
+     def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
+         return -ctx.lambda_grl * grad_output, None
+
+
+ class GradientReversalLayer(nn.Module):
+     """Wrapper module for Gradient Reversal"""
+
+     def __init__(self, lambda_grl: float = 1.0):
+         super().__init__()
+         self.lambda_grl = lambda_grl
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return GradientReversalFunction.apply(x, self.lambda_grl)
+
+     def set_lambda(self, lambda_grl: float) -> None:
+         """Update lambda value (useful for scheduling)"""
+         self.lambda_grl = lambda_grl
+
+
+ def grad_reverse(x: torch.Tensor, lambda_grl: float = 1.0) -> torch.Tensor:
+     """Functional version of gradient reversal"""
+     return GradientReversalFunction.apply(x, lambda_grl)
+
+
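The gradient reversal mechanics above are easy to sanity-check: the forward pass is the identity and the backward pass multiplies incoming gradients by -lambda_grl. A minimal sketch follows (an editorial illustration, not part of the packaged source; it assumes only PyTorch and the aibt.models module defined in this file):

import math
import torch
from aibt.models import GradientReversalLayer

x = torch.randn(4, 8, requires_grad=True)
grl = GradientReversalLayer(lambda_grl=0.5)
grl(x).sum().backward()
# Forward is the identity, so the upstream gradient is all ones;
# the custom backward scales it by -lambda_grl.
assert torch.allclose(x.grad, torch.full_like(x, -0.5))

# set_lambda() pairs naturally with the ramp-up schedule from Ganin & Lempitsky
# (DANN), where p is the training progress in [0, 1]:
p = 0.3
grl.set_lambda(2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0)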
+ # ============================================================================
+ # VARIATIONAL ENCODER (Information Bottleneck)
+ # ============================================================================
+
+ class VariationalEncoder(nn.Module):
+     """
+     Variational Encoder for Information Bottleneck.
+     Outputs mean (mu) and log-variance (logvar) for reparameterization trick.
+
+     z = mu + sigma * epsilon, where epsilon ~ N(0, I)
+     """
+
+     def __init__(
+         self,
+         input_dim: int,
+         hidden_dims: List[int],
+         latent_dim: int,
+         dropout: float = 0.1
+     ):
+         super().__init__()
+
+         self.input_dim = input_dim
+         self.latent_dim = latent_dim
+
+         # Build encoder layers
+         layers = []
+         prev_dim = input_dim
+
+         for hidden_dim in hidden_dims:
+             layers.extend([
+                 nn.Linear(prev_dim, hidden_dim),
+                 nn.BatchNorm1d(hidden_dim),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(dropout)
+             ])
+             prev_dim = hidden_dim
+
+         self.encoder = nn.Sequential(*layers)
+
+         # Mean and log-variance heads
+         self.fc_mu = nn.Linear(prev_dim, latent_dim)
+         self.fc_logvar = nn.Linear(prev_dim, latent_dim)
+
+         self._init_weights()
+
+     def _init_weights(self) -> None:
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Encode input to mu and logvar"""
+         h = self.encoder(x)
+         mu = self.fc_mu(h)
+         logvar = self.fc_logvar(h)
+         return mu, logvar
+
+     def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+         """
+         Reparameterization trick: z = mu + sigma * epsilon
+         """
+         if self.training:
+             std = torch.exp(0.5 * logvar)
+             eps = torch.randn_like(std)
+             return mu + eps * std
+         return mu
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Forward pass.
+         Returns: z (latent), mu, logvar
+         """
+         mu, logvar = self.encode(x)
+         z = self.reparameterize(mu, logvar)
+         return z, mu, logvar
+
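A quick shape and behaviour check of the variational encoder (a sketch with arbitrary dimensions, not part of the package): during training z is sampled via the reparameterization trick, while in eval mode forward() returns mu deterministically.

import torch
from aibt.models import VariationalEncoder

enc = VariationalEncoder(input_dim=32, hidden_dims=[64, 32], latent_dim=16)
x = torch.randn(8, 32)

enc.train()
z, mu, logvar = enc(x)                   # z = mu + exp(0.5 * logvar) * eps
print(z.shape, mu.shape, logvar.shape)   # torch.Size([8, 16]) each

enc.eval()
z_eval, mu_eval, _ = enc(x)              # no sampling in eval mode
assert torch.equal(z_eval, mu_eval)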
+
+ # ============================================================================
+ # MLP ENCODER (For Tabular Data)
+ # ============================================================================
+
+ class MLPEncoder(nn.Module):
+     """
+     MLP Encoder for tabular datasets.
+     Can be variational or deterministic.
+     """
+
+     def __init__(
+         self,
+         input_dim: int,
+         hidden_dims: List[int],
+         latent_dim: int,
+         variational: bool = True,
+         dropout: float = 0.1
+     ):
+         super().__init__()
+
+         self.input_dim = input_dim
+         self.latent_dim = latent_dim
+         self.variational = variational
+
+         # Build encoder layers
+         layers = []
+         prev_dim = input_dim
+
+         for hidden_dim in hidden_dims:
+             layers.extend([
+                 nn.Linear(prev_dim, hidden_dim),
+                 nn.BatchNorm1d(hidden_dim),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(dropout)
+             ])
+             prev_dim = hidden_dim
+
+         self.encoder = nn.Sequential(*layers)
+
+         if variational:
+             self.fc_mu = nn.Linear(prev_dim, latent_dim)
+             self.fc_logvar = nn.Linear(prev_dim, latent_dim)
+         else:
+             self.fc_out = nn.Linear(prev_dim, latent_dim)
+
+         self._init_weights()
+
+     def _init_weights(self) -> None:
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor):
+         h = self.encoder(x)
+
+         if self.variational:
+             mu = self.fc_mu(h)
+             logvar = self.fc_logvar(h)
+
+             if self.training:
+                 std = torch.exp(0.5 * logvar)
+                 eps = torch.randn_like(std)
+                 z = mu + eps * std
+             else:
+                 z = mu
+
+             return z, mu, logvar
+         else:
+             return self.fc_out(h)
+
+
+ # ============================================================================
+ # PREDICTOR (Task Head)
+ # ============================================================================
+
+ class Predictor(nn.Module):
+     """
+     Predictor network for task-specific predictions.
+     Takes latent representation z and outputs class logits (no softmax; used with F.cross_entropy).
+     """
+
+     def __init__(
+         self,
+         latent_dim: int,
+         hidden_dims: List[int],
+         num_classes: int,
+         dropout: float = 0.1
+     ):
+         super().__init__()
+
+         layers = []
+         prev_dim = latent_dim
+
+         for hidden_dim in hidden_dims:
+             layers.extend([
+                 nn.Linear(prev_dim, hidden_dim),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(dropout)
+             ])
+             prev_dim = hidden_dim
+
+         layers.append(nn.Linear(prev_dim, num_classes))
+
+         self.predictor = nn.Sequential(*layers)
+         self._init_weights()
+
+     def _init_weights(self) -> None:
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, z: torch.Tensor) -> torch.Tensor:
+         return self.predictor(z)
+
+
+ # ============================================================================
+ # ADVERSARY (Sensitive Attribute Classifier)
+ # ============================================================================
+
+ class Adversary(nn.Module):
+     """
+     Adversary network for predicting sensitive attributes.
+     Used with GRL to enforce privacy in latent representation.
+     """
+
+     def __init__(
+         self,
+         latent_dim: int,
+         hidden_dims: List[int],
+         num_sensitive_classes: int,
+         dropout: float = 0.1
+     ):
+         super().__init__()
+
+         layers = []
+         prev_dim = latent_dim
+
+         for hidden_dim in hidden_dims:
+             layers.extend([
+                 nn.Linear(prev_dim, hidden_dim),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(dropout)
+             ])
+             prev_dim = hidden_dim
+
+         layers.append(nn.Linear(prev_dim, num_sensitive_classes))
+
+         self.adversary = nn.Sequential(*layers)
+         self._init_weights()
+
+     def _init_weights(self) -> None:
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, z: torch.Tensor) -> torch.Tensor:
+         return self.adversary(z)
+
+
+ # ============================================================================
+ # COMPLETE AIBT MODEL
+ # ============================================================================
+
+ class AIBTModel(nn.Module):
+     """
+     Complete AIBT (Adversarial Information Bottleneck Training) Model.
+
+     Combines:
+     - Encoder (Variational for IB)
+     - Predictor (Task head)
+     - Adversary (Sensitive attribute classifier)
+     - Gradient Reversal Layer (GRL)
+
+     Loss: L = L_task + λ₁ L_KL - λ₂ L_adv
+     """
+
+     def __init__(
+         self,
+         encoder: nn.Module,
+         predictor: nn.Module,
+         adversary: nn.Module,
+         lambda_kl: float = 0.01,
+         lambda_adv: float = 1.0,
+         lambda_grl: float = 1.0
+     ):
+         super().__init__()
+
+         self.encoder = encoder
+         self.predictor = predictor
+         self.adversary = adversary
+         self.grl = GradientReversalLayer(lambda_grl)
+
+         self.lambda_kl = lambda_kl
+         self.lambda_adv = lambda_adv
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         return_latent: bool = False
+     ) -> Tuple[torch.Tensor, ...]:
+         """
+         Forward pass.
+
+         Returns:
+             y_pred: Task predictions
+             s_pred: Sensitive attribute predictions (through GRL)
+             mu: Mean (for KL loss)
+             logvar: Log variance (for KL loss)
+             z: Latent representation (appended only when return_latent=True)
+         """
+         # Encode
+         encoder_out = self.encoder(x)
+
+         if isinstance(encoder_out, tuple) and len(encoder_out) == 3:
+             z, mu, logvar = encoder_out
+         else:
+             z = encoder_out
+             mu = z
+             logvar = torch.zeros_like(z)
+
+         # Task prediction
+         y_pred = self.predictor(z)
+
+         # Adversarial prediction (with gradient reversal)
+         z_reversed = self.grl(z)
+         s_pred = self.adversary(z_reversed)
+
+         if return_latent:
+             return y_pred, s_pred, mu, logvar, z
+
+         return y_pred, s_pred, mu, logvar
+
+     def compute_kl_loss(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+         """
+         Compute KL divergence loss for variational bottleneck.
+         KL(q(z|x) || p(z)) where p(z) = N(0, I)
+
+         KL = -0.5 * sum(1 + log(σ²) - μ² - σ²)
+         """
+         kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
+         return kl_loss.mean()
+
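The closed-form term above is the KL divergence between a diagonal Gaussian posterior and a standard normal prior. A tiny numeric check (editorial sketch, not part of the package): with mu = 0 and logvar = 0 the posterior equals the prior and the KL collapses to zero.

import torch

mu = torch.zeros(2, 4)
logvar = torch.zeros(2, 4)
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
assert kl.item() == 0.0  # q(z|x) = N(0, I) matches the prior exactly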
+     def compute_loss(
+         self,
+         x: torch.Tensor,
+         y: torch.Tensor,
+         s: Optional[torch.Tensor] = None
+     ) -> Tuple[torch.Tensor, dict]:
+         """
+         Compute complete AIBT loss.
+
+         L = L_task + λ₁ L_KL - λ₂ L_adv (the minus sign on L_adv is realized by the GRL, so the term is added with a plus below)
+
+         Returns:
+             total_loss: Combined loss for backpropagation
+             loss_dict: Dictionary of individual loss components
+         """
+         y_pred, s_pred, mu, logvar = self.forward(x)
+
+         # Task loss
+         task_loss = F.cross_entropy(y_pred, y)
+
+         # KL divergence loss
+         kl_loss = self.compute_kl_loss(mu, logvar)
+
+         # Adversarial loss (if sensitive attributes provided)
+         if s is not None:
+             adv_loss = F.cross_entropy(s_pred, s)
+         else:
+             adv_loss = torch.tensor(0.0, device=x.device)
+
+         # Combined loss: L = L_task + λ₁ L_KL - λ₂ L_adv
+         # Note: The minus sign is already handled by GRL
+         total_loss = task_loss + self.lambda_kl * kl_loss + self.lambda_adv * adv_loss
+
+         loss_dict = {
+             "total_loss": total_loss.item(),
+             "task_loss": task_loss.item(),
+             "kl_loss": kl_loss.item(),
+             "adv_loss": adv_loss.item() if isinstance(adv_loss, torch.Tensor) else adv_loss,
+         }
+
+         return total_loss, loss_dict
+
+     def get_latent(self, x: torch.Tensor) -> torch.Tensor:
+         """Get latent representation for privacy analysis"""
+         self.eval()  # note: the model is left in eval mode afterwards
+         with torch.no_grad():
+             encoder_out = self.encoder(x)
+             if isinstance(encoder_out, tuple):
+                 z = encoder_out[0]
+             else:
+                 z = encoder_out
+         return z
+
+     def predict(self, x: torch.Tensor) -> torch.Tensor:
+         """Get task predictions"""
+         self.eval()  # note: the model is left in eval mode afterwards
+         with torch.no_grad():
+             y_pred, _, _, _ = self.forward(x)
+         return y_pred
+
+
+ # ============================================================================
+ # MODEL FACTORY
+ # ============================================================================
+
+ def create_aibt_model(
+     input_dim: int,
+     num_classes: int,
+     latent_dim: int = 64,
+     hidden_dims: Optional[List[int]] = None,
+     num_sensitive_classes: int = 2,
+     lambda_kl: float = 0.01,
+     lambda_adv: float = 1.0,
+     lambda_grl: float = 1.0,
+     dropout: float = 0.1
+ ) -> AIBTModel:
+     """
+     Factory function to create an AIBT model.
+
+     Args:
+         input_dim: Input feature dimension
+         num_classes: Number of output classes
+         latent_dim: Dimension of latent representation
+         hidden_dims: Hidden layer dimensions (default: [128, 64])
+         num_sensitive_classes: Number of sensitive attribute classes
+         lambda_kl: KL divergence weight
+         lambda_adv: Adversarial loss weight
+         lambda_grl: Gradient reversal strength
+         dropout: Dropout rate
+
+     Returns:
+         AIBTModel: Complete AIBT model
+     """
+     if hidden_dims is None:
+         hidden_dims = [128, 64]
+
+     # Create encoder
+     encoder = MLPEncoder(
+         input_dim=input_dim,
+         hidden_dims=hidden_dims,
+         latent_dim=latent_dim,
+         variational=True,
+         dropout=dropout
+     )
+
+     # Create predictor
+     predictor = Predictor(
+         latent_dim=latent_dim,
+         hidden_dims=[64],
+         num_classes=num_classes,
+         dropout=dropout
+     )
+
+     # Create adversary
+     adversary = Adversary(
+         latent_dim=latent_dim,
+         hidden_dims=[64],
+         num_sensitive_classes=num_sensitive_classes,
+         dropout=dropout
+     )
+
+     # Create complete model
+     model = AIBTModel(
+         encoder=encoder,
+         predictor=predictor,
+         adversary=adversary,
+         lambda_kl=lambda_kl,
+         lambda_adv=lambda_adv,
+         lambda_grl=lambda_grl
+     )
+
+     return model
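To tie the pieces together, here is a hypothetical end-to-end training step built with the factory above (an editorial sketch on synthetic data; the optimizer choice, dimensions, and hyperparameters are illustrative assumptions, not prescribed by the package):

import torch
from aibt.models import create_aibt_model

model = create_aibt_model(input_dim=32, num_classes=2, latent_dim=16,
                          num_sensitive_classes=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(64, 32)           # features
y = torch.randint(0, 2, (64,))    # task labels
s = torch.randint(0, 2, (64,))    # sensitive attribute labels

model.train()
total_loss, parts = model.compute_loss(x, y, s)
optimizer.zero_grad()
total_loss.backward()   # the GRL flips the adversary's gradient w.r.t. the encoder
optimizer.step()
print(parts)            # {'total_loss': ..., 'task_loss': ..., 'kl_loss': ..., 'adv_loss': ...}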
aibt/py.typed ADDED
@@ -0,0 +1,2 @@
+ # Marker file for PEP 561
+ # This file indicates that the package supports type hints
aibt/utils.py ADDED
@@ -0,0 +1,162 @@
+ """
+ Utility functions for AIBT Framework
+ """
+
+ import torch
+ import numpy as np
+ from typing import Tuple, List, Optional
+
+
+ def get_device(prefer_gpu: bool = True) -> str:
+     """
+     Get the best available device.
+
+     Args:
+         prefer_gpu: If True, prefer GPU if available
+
+     Returns:
+         Device string ('cuda' or 'cpu')
+     """
+     if prefer_gpu and torch.cuda.is_available():
+         return "cuda"
+     return "cpu"
+
+
+ def set_seed(seed: int = 42) -> None:
+     """
+     Set random seed for reproducibility.
+
+     Args:
+         seed: Random seed
+     """
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+
+ def split_data_iid(
+     X: np.ndarray,
+     y: np.ndarray,
+     num_clients: int,
+     seed: int = 42
+ ) -> List[Tuple[np.ndarray, np.ndarray]]:
+     """
+     Split data IID among clients.
+
+     Args:
+         X: Features
+         y: Labels
+         num_clients: Number of clients
+         seed: Random seed
+
+     Returns:
+         List of (X, y) tuples for each client
+     """
+     np.random.seed(seed)
+     n_samples = len(y)
+     indices = np.random.permutation(n_samples)
+
+     # Split indices among clients
+     splits = np.array_split(indices, num_clients)
+
+     client_data = []
+     for split in splits:
+         client_data.append((X[split], y[split]))
+
+     return client_data
+
+
+ def split_data_non_iid(
+     X: np.ndarray,
+     y: np.ndarray,
+     num_clients: int,
+     alpha: float = 0.5,
+     seed: int = 42
+ ) -> List[Tuple[np.ndarray, np.ndarray]]:
+     """
+     Split data non-IID among clients using Dirichlet distribution.
+
+     Args:
+         X: Features
+         y: Labels
+         num_clients: Number of clients
+         alpha: Dirichlet concentration parameter (smaller = more non-IID)
+         seed: Random seed
+
+     Returns:
+         List of (X, y) tuples for each client
+     """
+     np.random.seed(seed)
+     n_samples = len(y)
+     num_classes = len(np.unique(y))
+
+     # Get indices for each class (labels are assumed to be encoded as 0..num_classes-1)
+     class_indices = {c: np.where(y == c)[0] for c in range(num_classes)}
+
+     # Sample proportions from Dirichlet distribution
+     proportions = np.random.dirichlet([alpha] * num_clients, num_classes)
+
+     # Assign samples to clients
+     client_indices = [[] for _ in range(num_clients)]
+
+     for c in range(num_classes):
+         indices = class_indices[c]
+         np.random.shuffle(indices)
+
+         # Split class samples according to proportions
+         splits = np.split(
+             indices,
+             (np.cumsum(proportions[c]) * len(indices)).astype(int)[:-1]
+         )
+
+         for client_id, split in enumerate(splits):
+             client_indices[client_id].extend(split)
+
+     # Create client datasets
+     client_data = []
+     for indices in client_indices:
+         indices = np.array(indices)
+         if len(indices) > 0:
+             client_data.append((X[indices], y[indices]))
+         else:
+             # Empty client, add minimal data
+             client_data.append((X[:1], y[:1]))
+
+     return client_data
+
+
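A small illustration of the two partitioners (an editorial sketch on synthetic arrays, not part of the package): smaller alpha skews each client's label distribution more heavily.

import numpy as np
from aibt.utils import split_data_iid, split_data_non_iid

X = np.random.randn(1000, 20)
y = np.random.randint(0, 4, size=1000)   # labels must be 0..num_classes-1

iid_parts = split_data_iid(X, y, num_clients=5)
skewed_parts = split_data_non_iid(X, y, num_clients=5, alpha=0.1)

for cid, (Xc, yc) in enumerate(skewed_parts):
    print(cid, np.bincount(yc, minlength=4))   # per-client label counts, usually far from uniform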
+ def normalize_data(
+     X_train: np.ndarray,
+     X_test: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray]:
+     """
+     Normalize data using training statistics.
+
+     Args:
+         X_train: Training features
+         X_test: Test features
+
+     Returns:
+         Normalized (X_train, X_test)
+     """
+     mean = X_train.mean(axis=0)
+     std = X_train.std(axis=0) + 1e-8
+
+     X_train_norm = (X_train - mean) / std
+     X_test_norm = (X_test - mean) / std
+
+     return X_train_norm, X_test_norm
+
+
+ def count_parameters(model: torch.nn.Module) -> int:
+     """
+     Count trainable parameters in a model.
+
+     Args:
+         model: PyTorch model
+
+     Returns:
+         Number of trainable parameters
+     """
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)