hqde 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hqde might be problematic; see the package registry page for more details.

@@ -0,0 +1,399 @@
1
+ """
2
+ Hierarchical aggregation module for HQDE framework.
3
+
4
+ This module implements tree-structured aggregation with adaptive topology
5
+ optimization and communication-efficient ensemble weight combination.
6
+ """
7
+
8
+ import torch
9
+ import ray
10
+ import numpy as np
11
+ from typing import Dict, List, Optional, Tuple, Any
12
+ import math
13
+ import time
14
+ import logging
15
+ from collections import defaultdict
16
+
17
+
18
@ray.remote
class AggregationNode:
    """Individual node in the hierarchical aggregation tree.

    Each node collects weight dictionaries from its sources (child nodes, or
    ensemble members at the leaf level), combines them with a selectable
    aggregation rule, and stores the result for retrieval by its parent.
    Instances are Ray actors; all public methods are invoked via ``.remote()``.
    """

    def __init__(self, node_id: str, level: int, max_children: int = 4):
        """Initialize the aggregation node.

        Args:
            node_id: Unique identifier for this node.
            level: Depth of this node in the tree (0 is the root level).
            max_children: Maximum number of child nodes this node accepts.
        """
        self.node_id = node_id
        self.level = level
        self.max_children = max_children
        self.children = []        # ids of child nodes
        self.parent = None        # id of parent node; None for the root
        self.local_weights = {}   # source_id -> {'weights', 'metadata', 'timestamp'}
        self.aggregated_weights = {}  # result of the last aggregate_local_weights()
        # NOTE(review): 'communication_time' and 'data_size_processed' are
        # initialized here but never updated anywhere in this class.
        self.performance_metrics = {
            'processing_time': 0.0,
            'communication_time': 0.0,
            'data_size_processed': 0
        }

    def add_child(self, child_node_id: str) -> bool:
        """Add a child node to this aggregation node.

        Returns:
            True if the child was added, False if the node is already full.
        """
        if len(self.children) < self.max_children:
            self.children.append(child_node_id)
            return True
        return False

    def set_parent(self, parent_node_id: str):
        """Set the parent node for this aggregation node."""
        self.parent = parent_node_id

    def receive_weights(self, source_id: str, weights: Dict[str, torch.Tensor],
                        metadata: Optional[Dict[str, Any]] = None):
        """Receive weights from a child node or ensemble member.

        Args:
            source_id: Identifier of the sender; a later send from the same
                source overwrites the earlier entry.
            weights: Mapping of parameter name -> weight tensor.
            metadata: Optional per-source metadata. Only the key
                'importance_score' is consumed during weighted aggregation.
        """
        self.local_weights[source_id] = {
            'weights': weights,
            'metadata': metadata or {},
            'timestamp': time.time()
        }

    def aggregate_local_weights(self, aggregation_method: str = "weighted_mean") -> Dict[str, torch.Tensor]:
        """Aggregate weights from all sources at this node.

        Args:
            aggregation_method: One of "weighted_mean" (softmax over the
                sources' 'importance_score' metadata, default 1.0), "median"
                (element-wise median across sources), or anything else for a
                plain element-wise mean.

        Returns:
            Mapping of parameter name -> aggregated tensor. Empty dict when
            no weights have been received.
        """
        if not self.local_weights:
            return {}

        start_time = time.time()

        # Group tensors (and their importance scores) by parameter name.
        param_weights = defaultdict(list)
        param_metadata = defaultdict(list)

        for source_id, data in self.local_weights.items():
            weights = data['weights']
            metadata = data['metadata']

            for param_name, weight_tensor in weights.items():
                param_weights[param_name].append(weight_tensor)
                param_metadata[param_name].append(metadata.get('importance_score', 1.0))

        # Aggregate each parameter independently.
        aggregated = {}
        for param_name, weight_list in param_weights.items():
            if aggregation_method == "weighted_mean":
                # Softmax-normalize importance scores so they form a convex
                # combination over the contributing sources.
                importance_scores = param_metadata[param_name]
                weights_tensor = torch.tensor(importance_scores, dtype=torch.float32)
                weights_normalized = torch.softmax(weights_tensor, dim=0)

                aggregated_param = torch.zeros_like(weight_list[0])
                for weight, norm_weight in zip(weight_list, weights_normalized):
                    aggregated_param += norm_weight * weight

                aggregated[param_name] = aggregated_param

            elif aggregation_method == "median":
                stacked_weights = torch.stack(weight_list)
                aggregated[param_name] = torch.median(stacked_weights, dim=0)[0]

            else:  # default to mean
                stacked_weights = torch.stack(weight_list)
                aggregated[param_name] = torch.mean(stacked_weights, dim=0)

        self.aggregated_weights = aggregated
        self.performance_metrics['processing_time'] = time.time() - start_time

        return aggregated

    def get_aggregated_weights(self) -> Dict[str, torch.Tensor]:
        """Get a shallow copy of the aggregated weights from this node."""
        return self.aggregated_weights.copy()

    def clear_local_weights(self):
        """Clear received weights to free memory (keeps aggregated result)."""
        self.local_weights.clear()

    def get_node_info(self) -> Dict[str, Any]:
        """Get structural and performance information about this node."""
        return {
            'node_id': self.node_id,
            'level': self.level,
            'num_children': len(self.children),
            'children': self.children,
            'parent': self.parent,
            'num_local_weights': len(self.local_weights),
            'performance_metrics': self.performance_metrics
        }
120
+
121
+
122
class HierarchicalAggregator:
    """Hierarchical aggregation system for distributed ensemble learning.

    Builds a tree of ``AggregationNode`` Ray actors whose leaf count equals
    the number of ensemble members, then aggregates ensemble weights
    bottom-up through the tree so no single node combines more than
    ``tree_branching_factor`` inputs at a time.
    """

    def __init__(self,
                 num_ensemble_members: int,
                 tree_branching_factor: int = 4,
                 adaptive_topology: bool = True):
        """
        Initialize hierarchical aggregator.

        Args:
            num_ensemble_members: Number of ensemble members
            tree_branching_factor: Maximum children per node in aggregation tree
            adaptive_topology: Whether to use adaptive topology optimization
        """
        self.num_ensemble_members = num_ensemble_members
        self.tree_branching_factor = tree_branching_factor
        self.adaptive_topology = adaptive_topology

        self.nodes = {}           # node_id -> AggregationNode actor handle
        self.tree_structure = {}  # level index (0 = root) -> list of node ids
        self.root_node_id = None

        # Performance monitoring
        self.aggregation_metrics = {
            'total_aggregation_time': 0.0,
            'communication_overhead': 0.0,
            'tree_depth': 0,
            'total_nodes': 0
        }

        self._build_aggregation_tree()

    def _build_aggregation_tree(self):
        """Build the hierarchical aggregation tree of Ray actor nodes."""
        if self.num_ensemble_members <= 0:
            return

        # Compute how many nodes each level needs, leaves first, shrinking
        # by the branching factor until a single root remains.
        tree_levels = []
        current_level_size = self.num_ensemble_members
        while current_level_size > 1:
            tree_levels.append(current_level_size)
            current_level_size = math.ceil(current_level_size / self.tree_branching_factor)

        tree_levels.append(1)  # Root node
        tree_levels.reverse()  # Index 0 is the root level, last index the leaves

        self.aggregation_metrics['tree_depth'] = len(tree_levels) - 1
        self.aggregation_metrics['total_nodes'] = sum(tree_levels)

        # Create actor nodes for each level.
        level_nodes = {}
        for level_idx, num_nodes in enumerate(tree_levels):
            level_nodes[level_idx] = []
            for node_idx in range(num_nodes):
                node_id = f"agg_node_{level_idx}_{node_idx}"
                self.nodes[node_id] = AggregationNode.remote(
                    node_id, level_idx, self.tree_branching_factor)
                level_nodes[level_idx].append(node_id)

        # Wire up parent-child relationships. Batch all the remote calls and
        # wait once instead of blocking on every link individually.
        link_futures = []
        for level_idx in range(len(tree_levels) - 1):
            parent_nodes = level_nodes[level_idx]
            child_nodes = level_nodes[level_idx + 1]

            for child_idx, child_node_id in enumerate(child_nodes):
                parent_idx = child_idx // self.tree_branching_factor
                if parent_idx < len(parent_nodes):
                    parent_node_id = parent_nodes[parent_idx]
                    link_futures.append(
                        self.nodes[parent_node_id].add_child.remote(child_node_id))
                    link_futures.append(
                        self.nodes[child_node_id].set_parent.remote(parent_node_id))
        if link_futures:
            ray.get(link_futures)

        self.root_node_id = level_nodes[0][0]
        self.tree_structure = level_nodes

    def aggregate_ensemble_weights(self,
                                   ensemble_weights: List[Dict[str, torch.Tensor]],
                                   ensemble_metadata: Optional[List[Dict[str, Any]]] = None,
                                   aggregation_method: str = "weighted_mean") -> Dict[str, torch.Tensor]:
        """
        Perform hierarchical aggregation of ensemble weights.

        Args:
            ensemble_weights: List of weight dictionaries from ensemble members
            ensemble_metadata: Optional metadata for each ensemble member
            aggregation_method: Per-node aggregation rule ("weighted_mean",
                "median", or anything else for plain mean); previously this
                was hard-coded to "weighted_mean".

        Returns:
            Hierarchically aggregated weights

        Raises:
            ValueError: If the number of weight dicts does not match the
                configured number of ensemble members.
        """
        if len(ensemble_weights) != self.num_ensemble_members:
            raise ValueError(f"Expected {self.num_ensemble_members} ensemble members, got {len(ensemble_weights)}")

        # No tree was built (zero members) -> nothing to aggregate.
        if not self.tree_structure:
            return {}

        start_time = time.time()

        # Distribute weights to leaf nodes
        self._distribute_weights_to_leaves(ensemble_weights, ensemble_metadata)

        # Perform bottom-up aggregation
        aggregated_weights = self._perform_bottom_up_aggregation(aggregation_method)

        # Update performance metrics
        self.aggregation_metrics['total_aggregation_time'] = time.time() - start_time

        return aggregated_weights

    def _distribute_weights_to_leaves(self,
                                      ensemble_weights: List[Dict[str, torch.Tensor]],
                                      ensemble_metadata: Optional[List[Dict[str, Any]]]):
        """Distribute ensemble weights round-robin across the leaf nodes."""
        if not ensemble_metadata:
            # One fresh dict per member; `[{}] * n` would alias a single
            # shared dict across all entries.
            ensemble_metadata = [{} for _ in ensemble_weights]

        # Leaf nodes live at the highest level index in tree_structure.
        max_level = max(self.tree_structure.keys())
        leaf_nodes = self.tree_structure[max_level]

        distribution_futures = []
        for i, (weights, metadata) in enumerate(zip(ensemble_weights, ensemble_metadata)):
            leaf_node_id = leaf_nodes[i % len(leaf_nodes)]
            leaf_node = self.nodes[leaf_node_id]

            source_id = f"ensemble_member_{i}"
            distribution_futures.append(
                leaf_node.receive_weights.remote(source_id, weights, metadata))

        # Wait for all distributions to complete
        ray.get(distribution_futures)

    def _perform_bottom_up_aggregation(self, aggregation_method: str = "weighted_mean") -> Dict[str, torch.Tensor]:
        """Aggregate level by level from the leaves up to the root.

        Args:
            aggregation_method: Aggregation rule forwarded to each node.

        Returns:
            The root node's aggregated weights, or an empty dict if no
            root exists.
        """
        # Process levels from bottom (largest level index) to top.
        for level in sorted(self.tree_structure.keys(), reverse=True):
            level_nodes = self.tree_structure[level]

            # Kick off aggregation at every node in this level.
            aggregation_futures = []
            for node_id in level_nodes:
                node = self.nodes[node_id]
                future = node.aggregate_local_weights.remote(aggregation_method)
                aggregation_futures.append((node_id, future))

            # Collect per-node results for this level.
            level_results = {}
            for node_id, future in aggregation_futures:
                level_results[node_id] = ray.get(future)

            # If not at the root level, forward each result to its parent.
            if level > 0:
                parent_transmission_futures = []
                for node_id in level_nodes:
                    if node_id in level_results:
                        node = self.nodes[node_id]
                        parent_info = ray.get(node.get_node_info.remote())
                        parent_id = parent_info['parent']

                        if parent_id and parent_id in self.nodes:
                            parent_node = self.nodes[parent_id]
                            weights = level_results[node_id]
                            metadata = {'source_level': level, 'source_node': node_id}

                            parent_transmission_futures.append(
                                parent_node.receive_weights.remote(node_id, weights, metadata))

                ray.get(parent_transmission_futures)

                # Free the received weights at this level (root keeps its
                # inputs so they remain inspectable).
                clear_futures = [
                    self.nodes[node_id].clear_local_weights.remote()
                    for node_id in level_nodes
                ]
                ray.get(clear_futures)

        # The root holds the final aggregate.
        if self.root_node_id:
            root_node = self.nodes[self.root_node_id]
            return ray.get(root_node.get_aggregated_weights.remote())
        return {}

    def optimize_tree_topology(self,
                               node_performance_data: Dict[str, Dict[str, float]],
                               network_topology: Optional[Dict[str, Any]] = None):
        """
        Optimize tree topology based on node performance and network characteristics.

        Args:
            node_performance_data: Performance metrics for each node
            network_topology: Network topology information (currently unused)
        """
        if not self.adaptive_topology:
            return

        # Flag nodes whose processing or communication times exceed fixed
        # thresholds (1.0s / 0.5s).
        bottleneck_nodes = []
        for node_id, performance in node_performance_data.items():
            processing_time = performance.get('processing_time', 0)
            communication_time = performance.get('communication_time', 0)

            if processing_time > 1.0 or communication_time > 0.5:  # Thresholds
                bottleneck_nodes.append(node_id)

        # Rebalance tree if bottlenecks detected
        if bottleneck_nodes:
            self._rebalance_tree(bottleneck_nodes, node_performance_data)

    def _rebalance_tree(self,
                        bottleneck_nodes: List[str],
                        performance_data: Dict[str, Dict[str, float]]):
        """Rebalance tree structure to address bottlenecks.

        This is a simplified strategy: any bottleneck node holding more
        children than the branching factor is flagged for splitting.
        """
        for bottleneck_node_id in bottleneck_nodes:
            if bottleneck_node_id in self.nodes:
                node_info = ray.get(self.nodes[bottleneck_node_id].get_node_info.remote())

                # If node has too many children, redistribute them
                if len(node_info['children']) > self.tree_branching_factor:
                    self._split_overloaded_node(bottleneck_node_id, node_info)

    def _split_overloaded_node(self, node_id: str, node_info: Dict[str, Any]):
        """Split an overloaded node by creating intermediate nodes.

        Placeholder: currently only logs the intended action; no nodes are
        actually created or rewired.
        """
        logging.info(f"Would split overloaded node {node_id} with {len(node_info['children'])} children")

    def get_aggregation_statistics(self) -> Dict[str, Any]:
        """Get statistics about the hierarchical aggregation process."""
        # Fan out info requests to every node, then collect the results.
        node_stat_futures = [
            (node_id, node.get_node_info.remote())
            for node_id, node in self.nodes.items()
        ]

        node_statistics = {}
        for node_id, future in node_stat_futures:
            node_statistics[node_id] = ray.get(future)

        return {
            'aggregation_metrics': self.aggregation_metrics,
            'tree_structure': self.tree_structure,
            'node_statistics': node_statistics,
            'num_ensemble_members': self.num_ensemble_members,
            'tree_branching_factor': self.tree_branching_factor
        }

    def cleanup(self):
        """Cleanup hierarchical aggregation resources.

        Dropping the actor handles lets Ray reclaim the remote actors.
        """
        self.nodes.clear()
        self.tree_structure.clear()