hqde-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,346 @@
+ """
+ Fault tolerance module for HQDE framework.
+
+ This module implements Byzantine fault tolerance, checkpointing, and recovery
+ mechanisms for robust distributed ensemble learning.
+ """
+
+ import torch
+ import ray
+ import numpy as np
+ from typing import Dict, List, Optional, Tuple, Any
+ import time
+ import pickle
+ import hashlib
+ import logging
+ from collections import defaultdict
+
+
+ class ByzantineFaultTolerantAggregator:
+     """Byzantine fault-tolerant aggregator for ensemble weights."""
+
+     def __init__(self,
+                  byzantine_threshold: float = 0.33,
+                  outlier_detection_method: str = "median_absolute_deviation",
+                  min_reliable_sources: int = 3):
+         """
+         Initialize Byzantine fault-tolerant aggregator.
+
+         Args:
+             byzantine_threshold: Maximum fraction of Byzantine nodes to tolerate
+             outlier_detection_method: Method for detecting outliers
+             min_reliable_sources: Minimum number of reliable sources required
+         """
+         self.byzantine_threshold = byzantine_threshold
+         self.outlier_detection_method = outlier_detection_method
+         self.min_reliable_sources = min_reliable_sources
+         self.source_reliability_scores = defaultdict(float)
+         self.detection_history = defaultdict(list)
+
+     def robust_aggregation(self,
+                            weight_updates: List[Dict[str, torch.Tensor]],
+                            source_ids: List[str],
+                            confidence_scores: Optional[List[float]] = None) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+         """
+         Perform Byzantine fault-tolerant aggregation.
+
+         Args:
+             weight_updates: List of weight updates from different sources
+             source_ids: Identifiers for each source
+             confidence_scores: Optional confidence scores for each source
+
+         Returns:
+             Tuple of (aggregated_weights, fault_tolerance_metrics)
+         """
+         if len(weight_updates) != len(source_ids):
+             raise ValueError("Number of weight updates must match number of source IDs")
+
+         if len(weight_updates) < self.min_reliable_sources:
+             raise ValueError(f"Need at least {self.min_reliable_sources} sources for fault tolerance")
+
+         # Filter out potentially corrupted updates
+         reliable_updates, reliable_sources, fault_metrics = self._detect_and_filter_byzantines(
+             weight_updates, source_ids, confidence_scores
+         )
+
+         # Perform robust aggregation on reliable updates
+         if len(reliable_updates) >= self.min_reliable_sources:
+             aggregated_weights = self._geometric_median_aggregation(reliable_updates)
+         else:
+             # Fall back to an element-wise median over all updates if too few reliable sources remain
+             aggregated_weights = self._median_aggregation(weight_updates)
+             fault_metrics['fallback_used'] = True
+
+         # Update source reliability scores
+         self._update_reliability_scores(source_ids, fault_metrics['byzantine_sources'])
+
+         return aggregated_weights, fault_metrics
+
+     def _detect_and_filter_byzantines(self,
+                                       weight_updates: List[Dict[str, torch.Tensor]],
+                                       source_ids: List[str],
+                                       confidence_scores: Optional[List[float]]) -> Tuple[List[Dict[str, torch.Tensor]], List[str], Dict[str, Any]]:
+         """Detect and filter out Byzantine sources."""
+         num_sources = len(weight_updates)
+         max_byzantines = int(num_sources * self.byzantine_threshold)
+
+         byzantine_scores = []
+         fault_metrics = {
+             'byzantine_sources': [],
+             'outlier_scores': {},
+             'detection_method': self.outlier_detection_method,
+             'fallback_used': False
+         }
+
+         # Calculate outlier scores for each source
+         for i, (update, source_id) in enumerate(zip(weight_updates, source_ids)):
+             outlier_score = self._calculate_outlier_score(update, weight_updates, i)
+             byzantine_scores.append(outlier_score)
+             fault_metrics['outlier_scores'][source_id] = outlier_score
+
+         # Identify Byzantine sources
+         byzantine_indices = []
+         if max_byzantines > 0:
+             # Sort by outlier score and mark worst ones as Byzantine
+             sorted_indices = sorted(range(num_sources), key=lambda i: byzantine_scores[i], reverse=True)
+             byzantine_indices = sorted_indices[:max_byzantines]
+
+             # Additional filtering based on reliability history
+             for idx in sorted_indices:
+                 source_id = source_ids[idx]
+                 if (self.source_reliability_scores[source_id] < 0.3 and
+                         byzantine_scores[idx] > np.median(byzantine_scores) + np.std(byzantine_scores)):
+                     if idx not in byzantine_indices:
+                         byzantine_indices.append(idx)
+
+         # Filter out Byzantine sources
+         reliable_updates = []
+         reliable_sources = []
+
+         for i, (update, source_id) in enumerate(zip(weight_updates, source_ids)):
+             if i not in byzantine_indices:
+                 reliable_updates.append(update)
+                 reliable_sources.append(source_id)
+             else:
+                 fault_metrics['byzantine_sources'].append(source_id)
+
+         return reliable_updates, reliable_sources, fault_metrics
+
+     def _calculate_outlier_score(self,
+                                  target_update: Dict[str, torch.Tensor],
+                                  all_updates: List[Dict[str, torch.Tensor]],
+                                  target_index: int) -> float:
+         """Calculate outlier score for a target update."""
+         if self.outlier_detection_method == "median_absolute_deviation":
+             return self._mad_outlier_score(target_update, all_updates, target_index)
+         elif self.outlier_detection_method == "cosine_similarity":
+             return self._cosine_similarity_outlier_score(target_update, all_updates, target_index)
+         else:
+             return self._euclidean_distance_outlier_score(target_update, all_updates, target_index)
+
+     def _mad_outlier_score(self,
+                            target_update: Dict[str, torch.Tensor],
+                            all_updates: List[Dict[str, torch.Tensor]],
+                            target_index: int) -> float:
+         """Calculate outlier score using Median Absolute Deviation."""
+         total_mad_score = 0.0
+         param_count = 0
+
+         for param_name in target_update.keys():
+             # Collect parameter values from all updates
+             param_values = []
+             target_value = target_update[param_name].flatten()
+
+             for i, update in enumerate(all_updates):
+                 if param_name in update and i != target_index:
+                     param_values.append(update[param_name].flatten())
+
+             if len(param_values) < 2:
+                 continue
+
+             # Calculate median and MAD
+             stacked_values = torch.stack(param_values)
+             median_value = torch.median(stacked_values, dim=0)[0]
+
+             absolute_deviations = torch.abs(stacked_values - median_value.unsqueeze(0))
+             mad = torch.median(absolute_deviations, dim=0)[0]
+
+             # Calculate MAD score for target
+             target_deviation = torch.abs(target_value - median_value)
+             mad_score = torch.mean(target_deviation / (mad + 1e-8)).item()
+
+             total_mad_score += mad_score
+             param_count += 1
+
+         return total_mad_score / max(param_count, 1)
+
+     def _cosine_similarity_outlier_score(self,
+                                          target_update: Dict[str, torch.Tensor],
+                                          all_updates: List[Dict[str, torch.Tensor]],
+                                          target_index: int) -> float:
+         """Calculate outlier score using cosine similarity."""
+         similarities = []
+
+         # Flatten target update
+         target_flat = torch.cat([param.flatten() for param in target_update.values()])
+
+         for i, update in enumerate(all_updates):
+             if i != target_index:
+                 # Flatten comparison update
+                 try:
+                     update_flat = torch.cat([update[param_name].flatten()
+                                              for param_name in target_update.keys()
+                                              if param_name in update])
+
+                     if len(update_flat) == len(target_flat):
+                         similarity = torch.cosine_similarity(target_flat, update_flat, dim=0)
+                         similarities.append(similarity.item())
+                 except Exception:  # skip updates that cannot be flattened into a comparable vector
+                     continue
+
+         if not similarities:
+             return 0.0
+
+         # Lower similarity means higher outlier score
+         avg_similarity = np.mean(similarities)
+         return 1.0 - max(0.0, avg_similarity)
+
+     def _euclidean_distance_outlier_score(self,
+                                           target_update: Dict[str, torch.Tensor],
+                                           all_updates: List[Dict[str, torch.Tensor]],
+                                           target_index: int) -> float:
+         """Calculate outlier score using Euclidean distance."""
+         distances = []
+
+         # Flatten target update
+         target_flat = torch.cat([param.flatten() for param in target_update.values()])
+
+         for i, update in enumerate(all_updates):
+             if i != target_index:
+                 try:
+                     update_flat = torch.cat([update[param_name].flatten()
+                                              for param_name in target_update.keys()
+                                              if param_name in update])
+
+                     if len(update_flat) == len(target_flat):
+                         distance = torch.norm(target_flat - update_flat).item()
+                         distances.append(distance)
+                 except Exception:  # skip updates that cannot be flattened into a comparable vector
+                     continue
+
+         if not distances:
+             return 0.0
+
+         # Normalize by median distance
+         median_distance = np.median(distances)
+         avg_distance = np.mean(distances)
+
+         return avg_distance / (median_distance + 1e-8)
+
+     def _geometric_median_aggregation(self, weight_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+         """Aggregate weights using geometric median for robustness."""
+         if len(weight_updates) == 1:
+             return weight_updates[0].copy()
+
+         aggregated_weights = {}
+
+         for param_name in weight_updates[0].keys():
+             # Collect parameter tensors
+             param_tensors = []
+             for update in weight_updates:
+                 if param_name in update:
+                     param_tensors.append(update[param_name])
+
+             if len(param_tensors) < 2:
+                 aggregated_weights[param_name] = param_tensors[0].clone()
+                 continue
+
+             # Calculate geometric median using iterative algorithm
+             geometric_median = self._calculate_geometric_median(param_tensors)
+             aggregated_weights[param_name] = geometric_median
+
+         return aggregated_weights
+
+     def _calculate_geometric_median(self, tensors: List[torch.Tensor], max_iterations: int = 100) -> torch.Tensor:
+         """Calculate geometric median of tensor list."""
+         if len(tensors) == 1:
+             return tensors[0].clone()
+
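+         # Weiszfeld-style fixed-point iteration: each tensor is re-weighted by the inverse of
+         # its distance to the current estimate, so far-away (potentially corrupted) updates
+         # contribute progressively less to the next estimate.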
+         # Initialize with arithmetic mean
+         current_median = torch.stack(tensors).mean(dim=0)
+
+         for iteration in range(max_iterations):
+             # Calculate weights based on distances
+             distances = []
+             for tensor in tensors:
+                 dist = torch.norm(tensor - current_median)
+                 distances.append(max(dist.item(), 1e-8))  # Avoid division by zero
+
+             # Update median using weighted average
+             weights = [1.0 / dist for dist in distances]
+             weight_sum = sum(weights)
+             weights = [w / weight_sum for w in weights]
+
+             new_median = torch.zeros_like(current_median)
+             for tensor, weight in zip(tensors, weights):
+                 new_median += weight * tensor
+
+             # Check convergence
+             if torch.norm(new_median - current_median) < 1e-6:
+                 break
+
+             current_median = new_median
+
+         return current_median
+
+     def _median_aggregation(self, weight_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+         """Simple median aggregation as fallback."""
+         aggregated_weights = {}
+
+         for param_name in weight_updates[0].keys():
+             param_tensors = []
+             for update in weight_updates:
+                 if param_name in update:
+                     param_tensors.append(update[param_name])
+
+             if param_tensors:
+                 stacked_tensors = torch.stack(param_tensors)
+                 aggregated_weights[param_name] = torch.median(stacked_tensors, dim=0)[0]
+
+         return aggregated_weights
+
+     def _update_reliability_scores(self, source_ids: List[str], byzantine_sources: List[str]):
+         """Update reliability scores for sources."""
+         for source_id in source_ids:
+             if source_id in byzantine_sources:
+                 # Decrease reliability for Byzantine sources
+                 self.source_reliability_scores[source_id] = max(
+                     0.0, self.source_reliability_scores[source_id] - 0.1
+                 )
+                 self.detection_history[source_id].append(('byzantine', time.time()))
+             else:
+                 # Increase reliability for honest sources
+                 self.source_reliability_scores[source_id] = min(
+                     1.0, self.source_reliability_scores[source_id] + 0.05
+                 )
+                 self.detection_history[source_id].append(('honest', time.time()))
+
+             # Keep only recent history
+             if len(self.detection_history[source_id]) > 100:
+                 self.detection_history[source_id] = self.detection_history[source_id][-100:]
+
+     def get_reliability_statistics(self) -> Dict[str, Any]:
+         """Get reliability statistics for all sources."""
+         return {
+             'source_reliability_scores': dict(self.source_reliability_scores),
+             'detection_history_summary': {
+                 source_id: {
+                     'total_detections': len(history),
+                     'byzantine_count': sum(1 for event, _ in history if event == 'byzantine'),
+                     'honest_count': sum(1 for event, _ in history if event == 'honest')
+                 }
+                 for source_id, history in self.detection_history.items()
+             },
+             'byzantine_threshold': self.byzantine_threshold,
+             'detection_method': self.outlier_detection_method
+         }
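
For reference, a minimal usage sketch of the aggregator above. It is not part of the package: the module path hqde.fault_tolerance, the parameter name "layer.weight", and the node identifiers are illustrative assumptions, and the tensors are toy data.

import torch

from hqde.fault_tolerance import ByzantineFaultTolerantAggregator  # assumed module path

aggregator = ByzantineFaultTolerantAggregator(byzantine_threshold=0.33,
                                              min_reliable_sources=3)

# Four sources report the same parameter; one is corrupted.
updates = [
    {"layer.weight": torch.ones(4, 4)},
    {"layer.weight": torch.ones(4, 4) * 1.01},
    {"layer.weight": torch.ones(4, 4) * 0.99},
    {"layer.weight": torch.ones(4, 4) * 100.0},  # Byzantine outlier
]
source_ids = ["node-0", "node-1", "node-2", "node-3"]

weights, metrics = aggregator.robust_aggregation(updates, source_ids)
print(metrics["byzantine_sources"])    # expected to contain 'node-3'
print(weights["layer.weight"].mean())  # close to 1.0

With the default MAD detector, the corrupted update receives an outlier score orders of magnitude above the others, so it is dropped and the remaining three updates are combined with the geometric-median aggregation.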