hqde-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hqde might be problematic.
- hqde/__init__.py +62 -0
- hqde/__main__.py +0 -0
- hqde/core/__init__.py +23 -0
- hqde/core/hqde_system.py +380 -0
- hqde/distributed/__init__.py +18 -0
- hqde/distributed/fault_tolerance.py +346 -0
- hqde/distributed/hierarchical_aggregator.py +399 -0
- hqde/distributed/load_balancer.py +498 -0
- hqde/distributed/mapreduce_ensemble.py +394 -0
- hqde/py.typed +0 -0
- hqde/quantum/__init__.py +17 -0
- hqde/quantum/quantum_aggregator.py +291 -0
- hqde/quantum/quantum_noise.py +284 -0
- hqde/quantum/quantum_optimization.py +336 -0
- hqde/utils/__init__.py +20 -0
- hqde/utils/config_manager.py +9 -0
- hqde/utils/data_utils.py +13 -0
- hqde/utils/performance_monitor.py +465 -0
- hqde/utils/visualization.py +9 -0
- hqde-0.1.0.dist-info/METADATA +237 -0
- hqde-0.1.0.dist-info/RECORD +24 -0
- hqde-0.1.0.dist-info/WHEEL +5 -0
- hqde-0.1.0.dist-info/licenses/LICENSE +21 -0
- hqde-0.1.0.dist-info/top_level.txt +1 -0
hqde/distributed/fault_tolerance.py
@@ -0,0 +1,346 @@
+"""
+Fault tolerance module for HQDE framework.
+
+This module implements Byzantine fault tolerance, checkpointing, and recovery
+mechanisms for robust distributed ensemble learning.
+"""
+
+import torch
+import ray
+import numpy as np
+from typing import Dict, List, Optional, Tuple, Any
+import time
+import pickle
+import hashlib
+import logging
+from collections import defaultdict
+
+
+class ByzantineFaultTolerantAggregator:
+    """Byzantine fault-tolerant aggregator for ensemble weights."""
+
+    def __init__(self,
+                 byzantine_threshold: float = 0.33,
+                 outlier_detection_method: str = "median_absolute_deviation",
+                 min_reliable_sources: int = 3):
+        """
+        Initialize Byzantine fault-tolerant aggregator.
+
+        Args:
+            byzantine_threshold: Maximum fraction of Byzantine nodes to tolerate
+            outlier_detection_method: Method for detecting outliers
+            min_reliable_sources: Minimum number of reliable sources required
+        """
+        self.byzantine_threshold = byzantine_threshold
+        self.outlier_detection_method = outlier_detection_method
+        self.min_reliable_sources = min_reliable_sources
+        self.source_reliability_scores = defaultdict(float)
+        self.detection_history = defaultdict(list)
+
+    def robust_aggregation(self,
+                           weight_updates: List[Dict[str, torch.Tensor]],
+                           source_ids: List[str],
+                           confidence_scores: Optional[List[float]] = None) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """
+        Perform Byzantine fault-tolerant aggregation.
+
+        Args:
+            weight_updates: List of weight updates from different sources
+            source_ids: Identifiers for each source
+            confidence_scores: Optional confidence scores for each source
+
+        Returns:
+            Tuple of (aggregated_weights, fault_tolerance_metrics)
+        """
+        if len(weight_updates) != len(source_ids):
+            raise ValueError("Number of weight updates must match number of source IDs")
+
+        if len(weight_updates) < self.min_reliable_sources:
+            raise ValueError(f"Need at least {self.min_reliable_sources} sources for fault tolerance")
+
+        # Filter out potentially corrupted updates
+        reliable_updates, reliable_sources, fault_metrics = self._detect_and_filter_byzantines(
+            weight_updates, source_ids, confidence_scores
+        )
+
+        # Perform robust aggregation on reliable updates
+        if len(reliable_updates) >= self.min_reliable_sources:
+            aggregated_weights = self._geometric_median_aggregation(reliable_updates)
+        else:
+            # Fallback to simple median if not enough reliable sources
+            aggregated_weights = self._median_aggregation(weight_updates)
+            fault_metrics['fallback_used'] = True
+
+        # Update source reliability scores
+        self._update_reliability_scores(source_ids, fault_metrics['byzantine_sources'])
+
+        return aggregated_weights, fault_metrics
+
+    def _detect_and_filter_byzantines(self,
+                                      weight_updates: List[Dict[str, torch.Tensor]],
+                                      source_ids: List[str],
+                                      confidence_scores: Optional[List[float]]) -> Tuple[List[Dict[str, torch.Tensor]], List[str], Dict[str, Any]]:
+        """Detect and filter out Byzantine sources."""
+        num_sources = len(weight_updates)
+        max_byzantines = int(num_sources * self.byzantine_threshold)
+
+        byzantine_scores = []
+        fault_metrics = {
+            'byzantine_sources': [],
+            'outlier_scores': {},
+            'detection_method': self.outlier_detection_method,
+            'fallback_used': False
+        }
+
+        # Calculate outlier scores for each source
+        for i, (update, source_id) in enumerate(zip(weight_updates, source_ids)):
+            outlier_score = self._calculate_outlier_score(update, weight_updates, i)
+            byzantine_scores.append(outlier_score)
+            fault_metrics['outlier_scores'][source_id] = outlier_score
+
+        # Identify Byzantine sources
+        byzantine_indices = []
+        if max_byzantines > 0:
+            # Sort by outlier score and mark worst ones as Byzantine
+            sorted_indices = sorted(range(num_sources), key=lambda i: byzantine_scores[i], reverse=True)
+            byzantine_indices = sorted_indices[:max_byzantines]
+
+            # Additional filtering based on reliability history
+            for idx in sorted_indices:
+                source_id = source_ids[idx]
+                if (self.source_reliability_scores[source_id] < 0.3 and
+                    byzantine_scores[idx] > np.median(byzantine_scores) + np.std(byzantine_scores)):
+                    if idx not in byzantine_indices:
+                        byzantine_indices.append(idx)
+
+        # Filter out Byzantine sources
+        reliable_updates = []
+        reliable_sources = []
+
+        for i, (update, source_id) in enumerate(zip(weight_updates, source_ids)):
+            if i not in byzantine_indices:
+                reliable_updates.append(update)
+                reliable_sources.append(source_id)
+            else:
+                fault_metrics['byzantine_sources'].append(source_id)
+
+        return reliable_updates, reliable_sources, fault_metrics
+
+    def _calculate_outlier_score(self,
+                                 target_update: Dict[str, torch.Tensor],
+                                 all_updates: List[Dict[str, torch.Tensor]],
+                                 target_index: int) -> float:
+        """Calculate outlier score for a target update."""
+        if self.outlier_detection_method == "median_absolute_deviation":
+            return self._mad_outlier_score(target_update, all_updates, target_index)
+        elif self.outlier_detection_method == "cosine_similarity":
+            return self._cosine_similarity_outlier_score(target_update, all_updates, target_index)
+        else:
+            return self._euclidean_distance_outlier_score(target_update, all_updates, target_index)
+
+    def _mad_outlier_score(self,
+                           target_update: Dict[str, torch.Tensor],
+                           all_updates: List[Dict[str, torch.Tensor]],
+                           target_index: int) -> float:
+        """Calculate outlier score using Median Absolute Deviation."""
+        total_mad_score = 0.0
+        param_count = 0
+
+        for param_name in target_update.keys():
+            # Collect parameter values from all updates
+            param_values = []
+            target_value = target_update[param_name].flatten()
+
+            for i, update in enumerate(all_updates):
+                if param_name in update and i != target_index:
+                    param_values.append(update[param_name].flatten())
+
+            if len(param_values) < 2:
+                continue
+
+            # Calculate median and MAD
+            stacked_values = torch.stack(param_values)
+            median_value = torch.median(stacked_values, dim=0)[0]
+
+            absolute_deviations = torch.abs(stacked_values - median_value.unsqueeze(0))
+            mad = torch.median(absolute_deviations, dim=0)[0]
+
+            # Calculate MAD score for target
+            target_deviation = torch.abs(target_value - median_value)
+            mad_score = torch.mean(target_deviation / (mad + 1e-8)).item()
+
+            total_mad_score += mad_score
+            param_count += 1
+
+        return total_mad_score / max(param_count, 1)
+
+    def _cosine_similarity_outlier_score(self,
+                                         target_update: Dict[str, torch.Tensor],
+                                         all_updates: List[Dict[str, torch.Tensor]],
+                                         target_index: int) -> float:
+        """Calculate outlier score using cosine similarity."""
+        similarities = []
+
+        # Flatten target update
+        target_flat = torch.cat([param.flatten() for param in target_update.values()])
+
+        for i, update in enumerate(all_updates):
+            if i != target_index:
+                # Flatten comparison update
+                try:
+                    update_flat = torch.cat([update[param_name].flatten()
+                                             for param_name in target_update.keys()
+                                             if param_name in update])
+
+                    if len(update_flat) == len(target_flat):
+                        similarity = torch.cosine_similarity(target_flat, update_flat, dim=0)
+                        similarities.append(similarity.item())
+                except:
+                    continue
+
+        if not similarities:
+            return 0.0
+
+        # Lower similarity means higher outlier score
+        avg_similarity = np.mean(similarities)
+        return 1.0 - max(0.0, avg_similarity)
+
+    def _euclidean_distance_outlier_score(self,
+                                          target_update: Dict[str, torch.Tensor],
+                                          all_updates: List[Dict[str, torch.Tensor]],
+                                          target_index: int) -> float:
+        """Calculate outlier score using Euclidean distance."""
+        distances = []
+
+        # Flatten target update
+        target_flat = torch.cat([param.flatten() for param in target_update.values()])
+
+        for i, update in enumerate(all_updates):
+            if i != target_index:
+                try:
+                    update_flat = torch.cat([update[param_name].flatten()
+                                             for param_name in target_update.keys()
+                                             if param_name in update])
+
+                    if len(update_flat) == len(target_flat):
+                        distance = torch.norm(target_flat - update_flat).item()
+                        distances.append(distance)
+                except:
+                    continue
+
+        if not distances:
+            return 0.0
+
+        # Normalize by median distance
+        median_distance = np.median(distances)
+        avg_distance = np.mean(distances)
+
+        return avg_distance / (median_distance + 1e-8)
+
+    def _geometric_median_aggregation(self, weight_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+        """Aggregate weights using geometric median for robustness."""
+        if len(weight_updates) == 1:
+            return weight_updates[0].copy()
+
+        aggregated_weights = {}
+
+        for param_name in weight_updates[0].keys():
+            # Collect parameter tensors
+            param_tensors = []
+            for update in weight_updates:
+                if param_name in update:
+                    param_tensors.append(update[param_name])
+
+            if len(param_tensors) < 2:
+                aggregated_weights[param_name] = param_tensors[0].clone()
+                continue
+
+            # Calculate geometric median using iterative algorithm
+            geometric_median = self._calculate_geometric_median(param_tensors)
+            aggregated_weights[param_name] = geometric_median
+
+        return aggregated_weights
+
+    def _calculate_geometric_median(self, tensors: List[torch.Tensor], max_iterations: int = 100) -> torch.Tensor:
+        """Calculate geometric median of tensor list."""
+        if len(tensors) == 1:
+            return tensors[0].clone()
+
+        # Initialize with arithmetic mean
+        current_median = torch.stack(tensors).mean(dim=0)
+
+        for iteration in range(max_iterations):
+            # Calculate weights based on distances
+            distances = []
+            for tensor in tensors:
+                dist = torch.norm(tensor - current_median)
+                distances.append(max(dist.item(), 1e-8))  # Avoid division by zero
+
+            # Update median using weighted average
+            weights = [1.0 / dist for dist in distances]
+            weight_sum = sum(weights)
+            weights = [w / weight_sum for w in weights]
+
+            new_median = torch.zeros_like(current_median)
+            for tensor, weight in zip(tensors, weights):
+                new_median += weight * tensor
+
+            # Check convergence
+            if torch.norm(new_median - current_median) < 1e-6:
+                break
+
+            current_median = new_median
+
+        return current_median
+
+    def _median_aggregation(self, weight_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+        """Simple median aggregation as fallback."""
+        aggregated_weights = {}
+
+        for param_name in weight_updates[0].keys():
+            param_tensors = []
+            for update in weight_updates:
+                if param_name in update:
+                    param_tensors.append(update[param_name])
+
+            if param_tensors:
+                stacked_tensors = torch.stack(param_tensors)
+                aggregated_weights[param_name] = torch.median(stacked_tensors, dim=0)[0]
+
+        return aggregated_weights
+
+    def _update_reliability_scores(self, source_ids: List[str], byzantine_sources: List[str]):
+        """Update reliability scores for sources."""
+        for source_id in source_ids:
+            if source_id in byzantine_sources:
+                # Decrease reliability for Byzantine sources
+                self.source_reliability_scores[source_id] = max(
+                    0.0, self.source_reliability_scores[source_id] - 0.1
+                )
+                self.detection_history[source_id].append(('byzantine', time.time()))
+            else:
+                # Increase reliability for honest sources
+                self.source_reliability_scores[source_id] = min(
+                    1.0, self.source_reliability_scores[source_id] + 0.05
+                )
+                self.detection_history[source_id].append(('honest', time.time()))
+
+            # Keep only recent history
+            if len(self.detection_history[source_id]) > 100:
+                self.detection_history[source_id] = self.detection_history[source_id][-100:]
+
+    def get_reliability_statistics(self) -> Dict[str, Any]:
+        """Get reliability statistics for all sources."""
+        return {
+            'source_reliability_scores': dict(self.source_reliability_scores),
+            'detection_history_summary': {
+                source_id: {
+                    'total_detections': len(history),
+                    'byzantine_count': sum(1 for event, _ in history if event == 'byzantine'),
+                    'honest_count': sum(1 for event, _ in history if event == 'honest')
+                }
+                for source_id, history in self.detection_history.items()
+            },
+            'byzantine_threshold': self.byzantine_threshold,
+            'detection_method': self.outlier_detection_method
+        }