aibt-fl 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aibt/__init__.py ADDED
@@ -0,0 +1,77 @@
"""
AIBT: Adversarial Information Bottleneck Training for Privacy-Preserving Federated Learning

This package provides a complete implementation of the AIBT framework for
privacy-preserving federated learning, combining:
- Information Bottleneck (IB) for feature compression
- Adversarial Training for sensitive attribute suppression
- Federated Learning for distributed training

Example:
    >>> from aibt import AIBTFL, create_aibt_model
    >>> model = create_aibt_model(input_dim=13, num_classes=2)
    >>> aibt = AIBTFL(model=model, num_clients=10)
    >>> aibt.setup_clients(client_datasets)  # list of (X, y) arrays, one per client
    >>> history = aibt.train(num_rounds=100, local_epochs=5)

``setup_clients`` must be called before ``train``: without clients there are
no parameters to aggregate and training fails.
"""

__version__ = "1.0.0"
__author__ = "AIBT Research Team"
19
+
20
+ # Core FL implementation
21
+ from aibt.core import AIBTFL
22
+
23
+ # Neural network models
24
+ from aibt.models import (
25
+ GradientReversalLayer,
26
+ GradientReversalFunction,
27
+ VariationalEncoder,
28
+ MLPEncoder,
29
+ Predictor,
30
+ Adversary,
31
+ AIBTModel,
32
+ create_aibt_model,
33
+ )
34
+
35
+ # Client implementation
36
+ from aibt.client import FLClient, AIBTClient
37
+
38
+ # Privacy metrics
39
+ from aibt.metrics import (
40
+ MembershipInferenceAttack,
41
+ AttributeInferenceAttack,
42
+ evaluate_membership_inference,
43
+ evaluate_attribute_inference,
44
+ evaluate_privacy,
45
+ evaluate_performance,
46
+ )
47
+
48
+ # Aggregation utilities
49
+ from aibt.aggregation import fedavg_aggregate
50
+
51
# Public API exported by ``from aibt import *``, grouped by defining submodule.
__all__ = [
    "__version__",
    # aibt.core
    "AIBTFL",
    # aibt.models
    "GradientReversalLayer", "GradientReversalFunction",
    "VariationalEncoder", "MLPEncoder",
    "Predictor", "Adversary",
    "AIBTModel", "create_aibt_model",
    # aibt.client
    "FLClient", "AIBTClient",
    # aibt.metrics
    "MembershipInferenceAttack", "AttributeInferenceAttack",
    "evaluate_membership_inference", "evaluate_attribute_inference",
    "evaluate_privacy", "evaluate_performance",
    # aibt.aggregation
    "fedavg_aggregate",
]
aibt/aggregation.py ADDED
@@ -0,0 +1,68 @@
"""
Parameter-aggregation helpers used by AIBT federated training.

Provides dataset-size-weighted averaging (FedAvg) and plain unweighted
averaging of client model state dicts.
"""
6
+
7
+ import torch
8
+ from typing import Dict, List
9
+
10
+
11
def fedavg_aggregate(
    client_params: List[Dict[str, torch.Tensor]],
    client_sizes: List[int]
) -> Dict[str, torch.Tensor]:
    """
    FedAvg aggregation for federated learning.

    Computes the dataset-size-weighted average of client model parameters:
        w_global = Σ(n_k/n) * w_k

    The sum is accumulated in at least float32 precision, and each aggregated
    tensor is cast back to the dtype of the first client's tensor. This keeps
    float64 parameters at full precision and keeps integer buffers integral
    (rounded to the nearest value).

    Args:
        client_params: List of model state dicts from clients. Every dict is
            expected to have the same keys and shapes as the first one.
        client_sizes: List of dataset sizes per client, aligned index-by-index
            with ``client_params``.

    Returns:
        Aggregated model state dict.

    Raises:
        ValueError: If no parameters are given, if ``client_sizes`` does not
            have exactly one entry per client, or if the sizes do not sum to a
            positive number.
    """
    if not client_params:
        raise ValueError("No client parameters provided")
    if len(client_sizes) != len(client_params):
        raise ValueError(
            f"Expected {len(client_params)} client sizes, got {len(client_sizes)}"
        )

    total_size = sum(client_sizes)
    if total_size <= 0:
        raise ValueError("Total client dataset size must be positive")

    # Normalised weights n_k / n, computed once instead of once per key.
    weights = [size / total_size for size in client_sizes]
    aggregated = {}

    for key, reference in client_params[0].items():
        is_float = reference.is_floating_point()
        # Half/bfloat16 are summed in float32, float64 stays float64; integer
        # buffers are averaged in float64 and rounded back below.
        acc_dtype = (
            torch.promote_types(reference.dtype, torch.float32)
            if is_float
            else torch.float64
        )
        weighted_sum = torch.zeros_like(reference, dtype=acc_dtype)

        for weight, params in zip(weights, client_params):
            weighted_sum += weight * params[key].to(acc_dtype)

        if not is_float:
            weighted_sum = weighted_sum.round()
        aggregated[key] = weighted_sum.to(reference.dtype)

    return aggregated
44
+
45
+
46
def simple_average(
    client_params: List[Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """
    Simple averaging of client parameters (unweighted).

    Every client contributes equally, regardless of how much data it holds.
    The mean is computed in at least float32 precision, and each result is
    cast back to the dtype of the first client's tensor. Integer buffers are
    rounded to the nearest value.

    Args:
        client_params: List of model state dicts from clients. Every dict is
            expected to have the same keys and shapes as the first one.

    Returns:
        Aggregated model state dict.

    Raises:
        ValueError: If ``client_params`` is empty.
    """
    if not client_params:
        raise ValueError("No client parameters provided")

    aggregated = {}

    for key, reference in client_params[0].items():
        is_float = reference.is_floating_point()
        # Half/bfloat16 averaged in float32, float64 kept; integers via float64.
        acc_dtype = (
            torch.promote_types(reference.dtype, torch.float32)
            if is_float
            else torch.float64
        )
        stacked = torch.stack([params[key].to(acc_dtype) for params in client_params])
        mean = stacked.mean(dim=0)
        if not is_float:
            mean = mean.round()
        aggregated[key] = mean.to(reference.dtype)

    return aggregated
aibt/client.py ADDED
@@ -0,0 +1,259 @@
"""
Client-side training logic for AIBT federated learning.

Classes:
- FLClient: generic federated client that trains a private copy of the model
  with cross-entropy on its own local data.
- AIBTClient: FLClient variant that uses the model's combined AIBT loss
  (Information Bottleneck + adversarial training) when the model provides it.
"""
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import DataLoader, TensorDataset
13
+ from typing import Dict, Tuple, Optional
14
+ import copy
15
+ import numpy as np
16
+
17
+
18
class FLClient:
    """
    Federated Learning Client.

    Owns a private copy of the model and a shuffled DataLoader over its local
    data, and trains that copy with Adam on a cross-entropy objective.
    """

    def __init__(
        self,
        client_id: int,
        model: nn.Module,
        train_data: Tuple[np.ndarray, np.ndarray],
        sensitive_data: Optional[np.ndarray] = None,
        batch_size: int = 32,
        learning_rate: float = 0.001,
        device: str = "cpu"
    ):
        self.client_id = client_id
        self.device = device
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        # Independent replica: local training never mutates the caller's model.
        self.model = copy.deepcopy(model).to(device)

        features, labels = train_data
        self.n_samples = len(labels)
        self.has_sensitive = sensitive_data is not None

        # Dataset columns: features, labels and (optionally) sensitive attributes.
        columns = [torch.FloatTensor(features), torch.LongTensor(labels)]
        if self.has_sensitive:
            columns.append(torch.LongTensor(sensitive_data))
        dataset = TensorDataset(*columns)

        # Drop the ragged last batch only when the data exceeds one batch,
        # so small datasets still produce a batch every epoch.
        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=len(dataset) > batch_size,
        )

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

    def set_model_params(self, params: Dict[str, torch.Tensor]) -> None:
        """Overwrite the local model's weights with the given state dict."""
        self.model.load_state_dict(params)

    def get_model_params(self) -> Dict[str, torch.Tensor]:
        """Return an independent (deep-copied) snapshot of the local state dict."""
        snapshot = self.model.state_dict()
        return copy.deepcopy(snapshot)

    def train_epoch(self) -> Dict[str, float]:
        """
        Run one optimisation pass over the local DataLoader.

        Returns:
            Sample-weighted mean ``loss`` and ``accuracy`` over the epoch.
        """
        self.model.train()
        loss_sum = 0.0
        hits = 0
        seen = 0

        for batch in self.dataloader:
            # Sensitive attributes are not used by the plain client objective.
            if self.has_sensitive:
                inputs, targets, _sensitive = batch
            else:
                inputs, targets = batch
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            out = self.model(inputs)
            # Tuple outputs are supported; their first element is used as logits.
            logits = out[0] if isinstance(out, tuple) else out

            loss = F.cross_entropy(logits, targets)
            loss.backward()
            self.optimizer.step()

            batch_n = len(targets)
            loss_sum += loss.item() * batch_n
            hits += (logits.argmax(dim=1) == targets).sum().item()
            seen += batch_n

        return {"loss": loss_sum / seen, "accuracy": hits / seen}

    def train(self, local_epochs: int) -> Dict[str, float]:
        """Train for ``local_epochs`` epochs and return the final epoch's metrics ({} if none ran)."""
        last_metrics = {}
        for _ in range(local_epochs):
            last_metrics = self.train_epoch()
        return last_metrics

    def evaluate(self, test_data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
        """Compute cross-entropy loss and accuracy of the local model on ``test_data`` in a single batch."""
        self.model.eval()
        features, labels = test_data
        inputs = torch.FloatTensor(features).to(self.device)
        targets = torch.LongTensor(labels).to(self.device)

        with torch.no_grad():
            out = self.model(inputs)
            logits = out[0] if isinstance(out, tuple) else out
            loss = F.cross_entropy(logits, targets).item()
            accuracy = (logits.argmax(dim=1) == targets).float().mean().item()

        return {"test_loss": loss, "test_accuracy": accuracy}
148
+
149
+
150
class AIBTClient(FLClient):
    """
    AIBT Client with Information Bottleneck and Adversarial Training.

    Loss = L_task + λ₁ L_KL - λ₂ L_adv

    Training procedure:
    1. Forward pass through encoder → z (with reparameterization)
    2. Task prediction via predictor
    3. Adversarial prediction via GRL + adversary
    4. Compute combined loss
    5. Update encoder and predictor

    If the local model exposes ``compute_loss(X, y, s)``, the combined loss is
    delegated to it. Otherwise the client falls back to plain cross-entropy.
    """

    def __init__(
        self,
        *args,
        lambda_kl: float = 0.01,
        lambda_adv: float = 1.0,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # NOTE: these weights are recorded here but are not passed to
        # ``compute_loss``. The weights actually applied are the ones stored on
        # the model itself; AIBTFL writes them onto the global model before
        # cloning, when that model has a ``lambda_kl`` attribute.
        self.lambda_kl = lambda_kl
        self.lambda_adv = lambda_adv

    @staticmethod
    def _first_output(output):
        """Return ``output[0]`` for tuple outputs, otherwise ``output`` itself."""
        return output[0] if isinstance(output, tuple) else output

    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch using AIBT algorithm.

        Algorithm 1/2 from the paper:
        For each batch:
            1. z = Encoder(x) with reparameterization
            2. y_hat = Predictor(z)
            3. s_hat = Adversary(GRL(z))
            4. L = L_task + λ₁ L_KL - λ₂ L_adv
            5. Update encoder and predictor

        Returns:
            Sample-weighted epoch means of ``loss``, ``task_loss``, ``kl_loss``,
            ``adv_loss`` and ``accuracy``. The three component losses are 0
            when the model has no ``compute_loss`` (cross-entropy fallback).
        """
        self.model.train()
        uses_aibt_loss = hasattr(self.model, 'compute_loss')
        total_loss = 0.0
        total_task_loss = 0.0
        total_kl_loss = 0.0
        total_adv_loss = 0.0
        correct = 0
        total = 0

        for batch in self.dataloader:
            if self.has_sensitive:
                X, y, s = batch
                s = s.to(self.device)
            else:
                X, y = batch
                s = None

            X, y = X.to(self.device), y.to(self.device)
            n_batch = len(y)

            self.optimizer.zero_grad()

            if uses_aibt_loss:
                # Using AIBTModel with built-in loss computation
                loss, loss_dict = self.model.compute_loss(X, y, s)

                # Predictions are not taken from compute_loss's outputs, so a
                # second, gradient-free forward pass supplies them. It runs in
                # train mode, so stochastic layers/sampling may differ from the
                # loss pass (and BatchNorm stats, if any, are updated again).
                with torch.no_grad():
                    logits = self._first_output(self.model(X))

                total_task_loss += loss_dict["task_loss"] * n_batch
                total_kl_loss += loss_dict["kl_loss"] * n_batch
                total_adv_loss += loss_dict["adv_loss"] * n_batch
            else:
                # Manual computation for non-AIBT models
                logits = self._first_output(self.model(X))
                loss = F.cross_entropy(logits, y)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * n_batch
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += n_batch

        # Plain division (no "> 0" guard) so negative component losses are
        # reported faithfully; the accumulators are 0 in the fallback path.
        return {
            "loss": total_loss / total,
            "task_loss": total_task_loss / total,
            "kl_loss": total_kl_loss / total,
            "adv_loss": total_adv_loss / total,
            "accuracy": correct / total,
        }

    def get_latent_representations(self, X: np.ndarray) -> np.ndarray:
        """
        Get latent representations for privacy analysis.

        Uses ``model.get_latent`` when available, otherwise ``model.encoder``.
        For tuple outputs, the first element is taken.

        Raises:
            AttributeError: If the model has neither ``get_latent`` nor ``encoder``.
        """
        self.model.eval()
        X_tensor = torch.FloatTensor(X).to(self.device)

        with torch.no_grad():
            if hasattr(self.model, 'get_latent'):
                z = self.model.get_latent(X_tensor)
            elif hasattr(self.model, 'encoder'):
                z = self.model.encoder(X_tensor)
            else:
                raise AttributeError("Model does not have encoder")
            z = self._first_output(z)

        return z.cpu().numpy()
aibt/core.py ADDED
@@ -0,0 +1,287 @@
"""
AIBT FL: federated training driver for Adversarial Information Bottleneck Training.

Defines ``AIBTFL``, the server-side orchestrator. It sends the global model to
AIBT clients, runs their local training, and merges the results with Federated
Averaging. The training objective combines:
- Information Bottleneck (IB) for feature compression
- Adversarial Training for sensitive attribute suppression
"""
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from typing import Dict, List, Tuple, Optional, Any
14
+ import copy
15
+ import numpy as np
16
+
17
+ from aibt.client import AIBTClient, FLClient
18
+ from aibt.aggregation import fedavg_aggregate
19
+
20
+
21
class AIBTFL:
    """
    AIBT FL: Adversarial Information Bottleneck Training for Federated Learning.

    Combines:
    - Information Bottleneck (IB) for feature compression
    - Adversarial Training for sensitive attribute suppression
    - Federated Learning for distributed training

    Loss: L = L_task + λ₁ L_KL - λ₂ L_adv

    Args:
        model: Neural network model (preferably AIBTModel). It is deep-copied,
            so the caller's instance is not modified.
        num_clients: Number of federated learning clients. It is only stored by
            this class; the actual client count comes from ``setup_clients``.
        device: Device for computation ('cpu' or 'cuda')
        learning_rate: Learning rate for local training
        batch_size: Batch size for local training
        lambda_kl: Weight for KL divergence loss (Information Bottleneck).
            Written onto the global model if it has a ``lambda_kl`` attribute.
        lambda_adv: Weight for adversarial loss (same condition as ``lambda_kl``).
        lambda_grl: Strength of gradient reversal, applied through
            ``model.grl.set_lambda`` when the model has both ``lambda_kl``
            and ``grl``.
        latent_dim: Dimension of latent representation (not used by this class)
        num_sensitive_classes: Number of sensitive attribute classes (not used
            by this class)

    Example:
        >>> from aibt import AIBTFL, create_aibt_model
        >>> model = create_aibt_model(input_dim=13, num_classes=2)
        >>> aibt = AIBTFL(model=model, num_clients=10)
        >>> aibt.setup_clients(client_datasets)
        >>> history = aibt.train(num_rounds=100, local_epochs=5)
    """

    def __init__(
        self,
        model: nn.Module,
        num_clients: int,
        device: str = "cpu",
        learning_rate: float = 0.001,
        batch_size: int = 32,
        lambda_kl: float = 0.01,
        lambda_adv: float = 1.0,
        lambda_grl: float = 1.0,
        latent_dim: int = 128,
        num_sensitive_classes: int = 2
    ):
        self.device = device
        self.num_clients = num_clients
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.lambda_kl = lambda_kl
        self.lambda_adv = lambda_adv
        self.lambda_grl = lambda_grl
        self.latent_dim = latent_dim
        self.num_sensitive_classes = num_sensitive_classes
        self.method_name = "AIBT"

        # Global model (private copy of the caller's model)
        self.global_model = copy.deepcopy(model).to(device)

        # Push the loss weights onto the model so compute_loss (which receives
        # no weights from the clients) uses the values configured here.
        if hasattr(self.global_model, 'lambda_kl'):
            self.global_model.lambda_kl = lambda_kl
            self.global_model.lambda_adv = lambda_adv
            if hasattr(self.global_model, 'grl'):
                self.global_model.grl.set_lambda(lambda_grl)

        # Clients (initialized later by setup_clients)
        self.clients: List[AIBTClient] = []

        # Metrics tracking
        self.current_round = 0
        self.round_metrics: List[Dict[str, float]] = []

    def setup_clients(
        self,
        client_datasets: List[Tuple[np.ndarray, np.ndarray]],
        sensitive_data: Optional[List[np.ndarray]] = None
    ) -> None:
        """
        Initialize clients with their local data, replacing any existing clients.

        Args:
            client_datasets: List of (X, y) tuples, one per client
            sensitive_data: Optional list of sensitive attribute arrays, one per
                client dataset. ``None`` or an empty sequence means no
                sensitive attributes are used.

        Raises:
            ValueError: If ``sensitive_data`` is non-empty but does not have
                exactly one entry per client dataset.
        """
        # ``is not None`` + len() rather than truthiness, so array-like inputs
        # do not raise "truth value is ambiguous"; empty still means "none".
        use_sensitive = sensitive_data is not None and len(sensitive_data) > 0
        if use_sensitive and len(sensitive_data) != len(client_datasets):
            raise ValueError(
                f"sensitive_data has {len(sensitive_data)} entries but "
                f"{len(client_datasets)} client datasets were given"
            )

        # Build the full list before assigning, so a failure while creating a
        # client leaves the previous client list untouched.
        self.clients = [
            self._create_client(i, data, sensitive_data[i] if use_sensitive else None)
            for i, data in enumerate(client_datasets)
        ]

    def _create_client(
        self,
        client_id: int,
        train_data: Tuple[np.ndarray, np.ndarray],
        sensitive_data: Optional[np.ndarray] = None
    ) -> AIBTClient:
        """Create an AIBT client holding its own copy of the current global model."""
        return AIBTClient(
            client_id=client_id,
            model=self.global_model,
            train_data=train_data,
            sensitive_data=sensitive_data,
            batch_size=self.batch_size,
            learning_rate=self.learning_rate,
            device=self.device,
            lambda_kl=self.lambda_kl,
            lambda_adv=self.lambda_adv
        )

    def distribute_model(self) -> None:
        """Send global model to all clients"""
        global_params = self.global_model.state_dict()
        for client in self.clients:
            client.set_model_params(copy.deepcopy(global_params))

    def train_round(self, local_epochs: int) -> Dict[str, float]:
        """
        Execute one round of AIBT federated training.

        Following Algorithm 3 from the paper:
        1. Distribute global model to clients
        2. Each client trains using Algorithm 1/2
        3. Aggregate updates using FedAvg
        4. Update global model

        The aggregation weights clients by sample count, but the reported
        round metrics are unweighted means over clients.

        Raises:
            ValueError: If no clients have been set up (``setup_clients``).
        """
        # Fail before touching any state so current_round stays consistent.
        if not self.clients:
            raise ValueError(
                "No clients configured; call setup_clients() before training"
            )

        self.current_round += 1

        # Distribute global model
        self.distribute_model()

        # Local training
        client_params = []
        client_sizes = []
        client_metrics = []

        for client in self.clients:
            # Local AIBT training
            metrics = client.train(local_epochs)
            client_metrics.append(metrics)

            # Collect parameters
            client_params.append(client.get_model_params())
            client_sizes.append(client.n_samples)

        # Aggregate using FedAvg (weighted by local dataset size)
        aggregated_params = fedavg_aggregate(client_params, client_sizes)
        self.global_model.load_state_dict(aggregated_params)

        # Compute round metrics (simple mean across clients)
        avg_loss = np.mean([m["loss"] for m in client_metrics])
        avg_acc = np.mean([m["accuracy"] for m in client_metrics])
        avg_task_loss = np.mean([m.get("task_loss", 0) for m in client_metrics])
        avg_kl_loss = np.mean([m.get("kl_loss", 0) for m in client_metrics])
        avg_adv_loss = np.mean([m.get("adv_loss", 0) for m in client_metrics])

        round_metrics = {
            "round": self.current_round,
            "train_loss": avg_loss,
            "train_accuracy": avg_acc,
            "task_loss": avg_task_loss,
            "kl_loss": avg_kl_loss,
            "adv_loss": avg_adv_loss,
        }

        self.round_metrics.append(round_metrics)
        return round_metrics

    def evaluate(self, test_data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
        """Evaluate the global model on test data in a single batch (loss and accuracy)."""
        self.global_model.eval()

        X, y = test_data
        X_tensor = torch.FloatTensor(X).to(self.device)
        y_tensor = torch.LongTensor(y).to(self.device)

        with torch.no_grad():
            output = self.global_model(X_tensor)
            # For tuple outputs, the first element is treated as the logits.
            logits = output[0] if isinstance(output, tuple) else output

            loss = F.cross_entropy(logits, y_tensor).item()
            pred = logits.argmax(dim=1)
            accuracy = (pred == y_tensor).float().mean().item()

        return {"test_loss": loss, "test_accuracy": accuracy}

    def get_latent_representations(self, X: np.ndarray) -> np.ndarray:
        """
        Get latent representations from the global model.

        Uses ``get_latent`` when available, otherwise ``encoder``. For tuple
        outputs, the first element is taken.

        Raises:
            AttributeError: If the model has neither ``get_latent`` nor ``encoder``.
        """
        self.global_model.eval()
        X_tensor = torch.FloatTensor(X).to(self.device)

        with torch.no_grad():
            if hasattr(self.global_model, 'get_latent'):
                z = self.global_model.get_latent(X_tensor)
            elif hasattr(self.global_model, 'encoder'):
                z = self.global_model.encoder(X_tensor)
            else:
                raise AttributeError("Model does not have encoder")
            if isinstance(z, tuple):
                z = z[0]

        return z.cpu().numpy()

    def train(
        self,
        num_rounds: int,
        local_epochs: int,
        test_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        verbose: bool = True
    ) -> Dict[str, Any]:
        """
        Complete AIBT federated training loop.

        Args:
            num_rounds: Number of federated learning rounds
            local_epochs: Number of local training epochs per round
            test_data: Optional (X_test, y_test) for evaluation
            verbose: Print progress

        Returns:
            Dictionary with training history ("round_metrics", "test_metrics")
            and, when ``test_data`` is given, "final_test_metrics".

        Raises:
            ValueError: If no clients have been set up (``setup_clients``).
        """
        history = {
            "round_metrics": [],
            "test_metrics": []
        }

        for round_idx in range(num_rounds):
            # Training round
            round_metrics = self.train_round(local_epochs)

            # Evaluation (merged into the same dict that self.round_metrics holds)
            if test_data is not None:
                test_metrics = self.evaluate(test_data)
                round_metrics.update(test_metrics)

            history["round_metrics"].append(round_metrics)

            if verbose:
                print(f"Round {round_idx + 1}/{num_rounds} | "
                      f"Loss: {round_metrics['train_loss']:.4f} | "
                      f"Acc: {round_metrics['train_accuracy']:.4f}", end="")
                if test_data is not None:
                    print(f" | Test Acc: {round_metrics.get('test_accuracy', 0):.4f}")
                else:
                    print()

        # Final evaluation
        if test_data is not None:
            history["final_test_metrics"] = self.evaluate(test_data)

        return history

    def get_model(self) -> nn.Module:
        """Get the global model"""
        return self.global_model

    def get_communication_cost(self) -> int:
        """
        Size in bytes of the global model's parameters (one model copy).

        Only ``parameters()`` are counted. State-dict buffers, which
        ``distribute_model`` also transmits, are not included.
        """
        total_bytes = 0
        for param in self.global_model.parameters():
            total_bytes += param.numel() * param.element_size()
        return total_bytes