pyg-nightly 2.7.0.dev20250428__py3-none-any.whl → 2.7.0.dev20250430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyg-nightly
3
- Version: 2.7.0.dev20250428
3
+ Version: 2.7.0.dev20250430
4
4
  Summary: Graph Neural Network Library for PyTorch
5
5
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
6
6
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
1
- torch_geometric/__init__.py,sha256=Gbg4XAZMSi28NjwxD1wMjvY-M4y1bxpYHz0c6bjkLbs,1978
1
+ torch_geometric/__init__.py,sha256=27uNYndI_FSG4HKFDUYSjNUa4aLud2vk66mSkuYtSAA,1978
2
2
  torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
3
3
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
4
4
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -193,7 +193,7 @@ torch_geometric/distributed/utils.py,sha256=FGrr3qw7hx7EQaIjjqasurloCFJ9q_0jt8jd
193
193
  torch_geometric/explain/__init__.py,sha256=pRxVB33zsxhED1StRWdHboQWh3e06__g9N298Hzi42Y,359
194
194
  torch_geometric/explain/config.py,sha256=_0j67NAwPwjrWHPncNywCT-oKyMiryJNxufxVN1BFlM,7834
195
195
  torch_geometric/explain/explainer.py,sha256=8_NZTmlT4WO9RgKxpSUQRt3rbVwFURF5bSWOPlfOLjA,10667
196
- torch_geometric/explain/explanation.py,sha256=Z2NlgavEnq0QadEr6p6pxAhV6lU7WrlcJLFWbTdsmvg,14903
196
+ torch_geometric/explain/explanation.py,sha256=Bt8THLn-CSrvEFisdT9DX9fnOMaqficsChSCI9uhyQw,18873
197
197
  torch_geometric/explain/algorithm/__init__.py,sha256=fE29xbd0bPxg-EfrB2BDmmY9QnyO-7TgvYduGHofm5o,496
198
198
  torch_geometric/explain/algorithm/attention_explainer.py,sha256=65iGLmOt00ERtBDVxAoydIchykdWZU24aXzSzUGzQEI,11304
199
199
  torch_geometric/explain/algorithm/base.py,sha256=wwJcREUFKDLFUDjRa9o4X3DWqQgMvhS3Iciwb6Evtjc,6922
@@ -633,10 +633,10 @@ torch_geometric/utils/repeat.py,sha256=RxCoRoEisaP6NouXPPW5tY1Rn-tIfrmpJPm0qGP6W
633
633
  torch_geometric/utils/smiles.py,sha256=lGQ2BwJ49uBrQfIxxPz8ceTO9Jo-XCjlLxs1ql3xrsA,7130
634
634
  torch_geometric/utils/sparse.py,sha256=uYd0oPrp5XN0c2Zc15f-00rhhVMfLnRMqNcqcmILNKQ,25519
635
635
  torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
636
- torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
637
- torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
636
+ torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
637
+ torch_geometric/visualization/graph.py,sha256=PoI9tjbEXZVkMUg4CvTLbzqtEfzUwMUcsw57DNBEU0s,14311
638
638
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
639
- pyg_nightly-2.7.0.dev20250428.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
640
- pyg_nightly-2.7.0.dev20250428.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
641
- pyg_nightly-2.7.0.dev20250428.dist-info/METADATA,sha256=FIDRBCg3wfpTqyNTcBx3DGEuYtTJ3pvyK7aoAnHFVIs,62979
642
- pyg_nightly-2.7.0.dev20250428.dist-info/RECORD,,
639
+ pyg_nightly-2.7.0.dev20250430.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
640
+ pyg_nightly-2.7.0.dev20250430.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
641
+ pyg_nightly-2.7.0.dev20250430.dist-info/METADATA,sha256=7daZJ-7pS7DuGftzAxPDtQv8_RfVOflLc_-ov66cGDk,62979
642
+ pyg_nightly-2.7.0.dev20250430.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
31
31
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
32
32
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
33
33
 
34
- __version__ = '2.7.0.dev20250428'
34
+ __version__ = '2.7.0.dev20250430'
35
35
 
36
36
  __all__ = [
37
37
  'Index',
@@ -1,5 +1,5 @@
1
1
  import copy
2
- from typing import Dict, List, Optional, Union
2
+ from typing import Dict, List, Optional, Tuple, Union
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
@@ -8,7 +8,10 @@ from torch_geometric.data.data import Data, warn_or_raise
8
8
  from torch_geometric.data.hetero_data import HeteroData
9
9
  from torch_geometric.explain.config import ThresholdConfig, ThresholdType
10
10
  from torch_geometric.typing import EdgeType, NodeType
11
- from torch_geometric.visualization import visualize_graph
11
+ from torch_geometric.visualization import (
12
+ visualize_graph,
13
+ visualize_hetero_graph,
14
+ )
12
15
 
13
16
 
14
17
  class ExplanationMixin:
@@ -362,6 +365,87 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
362
365
 
363
366
  return _visualize_score(score, all_feat_labels, path, top_k)
364
367
 
368
def visualize_graph(
    self,
    path: Optional[str] = None,
    node_labels: Optional[Dict[NodeType, List[str]]] = None,
    node_size_range: Tuple[float, float] = (50, 500),
    node_opacity_range: Tuple[float, float] = (0.2, 1.0),
    edge_width_range: Tuple[float, float] = (0.1, 2.0),
    edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
) -> None:
    r"""Visualizes the explanation subgraph using networkx, with edge
    opacity corresponding to edge importance and node colors
    corresponding to node types.

    Args:
        path (str, optional): The path to where the plot is saved.
            If set to :obj:`None`, will visualize the plot on-the-fly.
            (default: :obj:`None`)
        node_labels (Dict[NodeType, List[str]], optional): The display
            names of nodes for each node type that will be shown in the
            visualization. (default: :obj:`None`)
        node_size_range (Tuple[float, float], optional): The minimum and
            maximum node size in the visualization.
            (default: :obj:`(50, 500)`)
        node_opacity_range (Tuple[float, float], optional): The minimum and
            maximum node opacity in the visualization.
            (default: :obj:`(0.2, 1.0)`)
        edge_width_range (Tuple[float, float], optional): The minimum and
            maximum edge width in the visualization.
            (default: :obj:`(0.1, 2.0)`)
        edge_opacity_range (Tuple[float, float], optional): The minimum and
            maximum edge opacity in the visualization.
            (default: :obj:`(0.2, 1.0)`)

    Raises:
        ValueError: If :obj:`node_labels` contains a node type that is not
            part of this explanation, or if the number of labels for a
            node type differs from its number of nodes.
    """
    # Validate node labels if provided
    if node_labels is not None:
        for node_type, labels in node_labels.items():
            if node_type not in self.node_types:
                raise ValueError(
                    f"Node type '{node_type}' in node_labels "
                    f"does not exist in the explanation graph")
            if len(labels) != self[node_type].num_nodes:
                raise ValueError(f"Number of labels for node type "
                                 f"'{node_type}' (got {len(labels)}) does "
                                 f"not match the number of nodes "
                                 f"(got {self[node_type].num_nodes})")

    # NOTE(review): labels are validated against the full explanation but
    # are forwarded unchanged together with indices taken from the
    # explanation subgraph -- confirm that `get_explanation_subgraph()`
    # keeps node indexing compatible with `node_labels`.
    subgraph = self.get_explanation_subgraph()

    # Prepare edge indices and weights for each edge type
    edge_index_dict = {}
    edge_weight_dict = {}
    for edge_type in subgraph.edge_types:
        # Edge types touching a node type named 'x' are skipped (mirrors
        # the node-type skip below).
        if edge_type[0] == 'x' or edge_type[-1] == 'x':
            continue
        edge_store = subgraph[edge_type]
        edge_index = edge_store.edge_index
        edge_mask = edge_store.get('edge_mask')
        if edge_mask is None:
            # Only allocate the uniform fallback when no mask is stored,
            # and keep it on the same device as the edges it describes.
            edge_mask = torch.ones(edge_index.size(1),
                                   device=edge_index.device)
        edge_index_dict[edge_type] = edge_index
        edge_weight_dict[edge_type] = edge_mask

    # Prepare node weights for each node type
    node_weight_dict = {}
    for node_type in subgraph.node_types:
        if node_type == 'x':  # Skip the global store
            continue
        node_store = subgraph[node_type]
        node_mask = node_store.get('node_mask')
        if node_mask is None:
            # Uniform importance when the explanation holds no node mask.
            node_mask = torch.ones(node_store.num_nodes)
        # NOTE(review): `squeeze(-1)` only yields one weight per node when
        # the trailing mask dimension has size 1 -- confirm the behaviour
        # wanted for attribute-level masks.
        node_weight_dict[node_type] = node_mask.squeeze(-1)

    # Call the visualization function
    visualize_hetero_graph(
        edge_index_dict=edge_index_dict,
        edge_weight_dict=edge_weight_dict,
        path=path,
        node_labels_dict=node_labels,
        node_weight_dict=node_weight_dict,
        node_size_range=node_size_range,
        node_opacity_range=node_opacity_range,
        edge_width_range=edge_width_range,
        edge_opacity_range=edge_opacity_range,
    )
448
+
365
449
 
366
450
  def _visualize_score(
367
451
  score: torch.Tensor,
@@ -1,9 +1,10 @@
1
1
  r"""Visualization package."""
2
2
 
3
- from .graph import visualize_graph
3
+ from .graph import visualize_graph, visualize_hetero_graph
4
4
  from .influence import influence
5
5
 
6
6
  __all__ = [
7
7
  'visualize_graph',
8
+ 'visualize_hetero_graph',
8
9
  'influence',
9
10
  ]
@@ -1,5 +1,5 @@
1
1
  from math import sqrt
2
- from typing import Any, List, Optional
2
+ from typing import Any, Dict, List, Optional, Set, Tuple
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
@@ -150,3 +150,249 @@ def _visualize_graph_via_networkx(
150
150
  plt.show()
151
151
 
152
152
  plt.close()
153
+
154
+
155
+ def visualize_hetero_graph(
156
+ edge_index_dict: Dict[Tuple[str, str, str], Tensor],
157
+ edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
158
+ path: Optional[str] = None,
159
+ backend: Optional[str] = None,
160
+ node_labels_dict: Optional[Dict[str, List[str]]] = None,
161
+ node_weight_dict: Optional[Dict[str, Tensor]] = None,
162
+ node_size_range: Tuple[float, float] = (50, 500),
163
+ node_opacity_range: Tuple[float, float] = (0.2, 1.0),
164
+ edge_width_range: Tuple[float, float] = (0.1, 2.0),
165
+ edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
166
+ ) -> Any:
167
+ """Visualizes a heterogeneous graph using networkx."""
168
+ if backend is not None and backend != "networkx":
169
+ raise ValueError("Only 'networkx' backend is supported")
170
+
171
+ # Filter out edges with 0 weight
172
+ filtered_edge_index_dict = {}
173
+ filtered_edge_weight_dict = {}
174
+ for edge_type in edge_index_dict.keys():
175
+ mask = edge_weight_dict[edge_type] > 0
176
+ if mask.sum() > 0:
177
+ filtered_edge_index_dict[edge_type] = edge_index_dict[
178
+ edge_type][:, mask]
179
+ filtered_edge_weight_dict[edge_type] = edge_weight_dict[edge_type][
180
+ mask]
181
+
182
+ # Get all unique nodes that are still in the filtered edges
183
+ remaining_nodes: Dict[str, Set[int]] = {}
184
+ for edge_type, edge_index in filtered_edge_index_dict.items():
185
+ src_type, _, dst_type = edge_type
186
+ if src_type not in remaining_nodes:
187
+ remaining_nodes[src_type] = set()
188
+ if dst_type not in remaining_nodes:
189
+ remaining_nodes[dst_type] = set()
190
+ remaining_nodes[src_type].update(edge_index[0].tolist())
191
+ remaining_nodes[dst_type].update(edge_index[1].tolist())
192
+
193
+ # Filter node weights to only include remaining nodes
194
+ if node_weight_dict is not None:
195
+ filtered_node_weight_dict = {}
196
+ for node_type, weights in node_weight_dict.items():
197
+ if node_type in remaining_nodes:
198
+ mask = torch.zeros(len(weights), dtype=torch.bool)
199
+ mask[list(remaining_nodes[node_type])] = True
200
+ filtered_node_weight_dict[node_type] = weights[mask]
201
+ node_weight_dict = filtered_node_weight_dict
202
+
203
+ # Filter node labels to only include remaining nodes
204
+ if node_labels_dict is not None:
205
+ filtered_node_labels_dict = {}
206
+ for node_type, labels in node_labels_dict.items():
207
+ if node_type in remaining_nodes:
208
+ filtered_node_labels_dict[node_type] = [
209
+ label for i, label in enumerate(labels)
210
+ if i in remaining_nodes[node_type]
211
+ ]
212
+ node_labels_dict = filtered_node_labels_dict
213
+
214
+ return _visualize_hetero_graph_via_networkx(
215
+ filtered_edge_index_dict,
216
+ filtered_edge_weight_dict,
217
+ path,
218
+ node_labels_dict,
219
+ node_weight_dict,
220
+ node_size_range,
221
+ node_opacity_range,
222
+ edge_width_range,
223
+ edge_opacity_range,
224
+ )
225
+
226
+
227
+ def _visualize_hetero_graph_via_networkx(
228
+ edge_index_dict: Dict[Tuple[str, str, str], Tensor],
229
+ edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
230
+ path: Optional[str] = None,
231
+ node_labels_dict: Optional[Dict[str, List[str]]] = None,
232
+ node_weight_dict: Optional[Dict[str, Tensor]] = None,
233
+ node_size_range: Tuple[float, float] = (50, 500),
234
+ node_opacity_range: Tuple[float, float] = (0.2, 1.0),
235
+ edge_width_range: Tuple[float, float] = (0.1, 2.0),
236
+ edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
237
+ ) -> Any:
238
+ import matplotlib.pyplot as plt
239
+ import networkx as nx
240
+
241
+ g = nx.DiGraph()
242
+ node_offsets: Dict[str, int] = {}
243
+ current_offset = 0
244
+
245
+ # First, collect all unique node types and their counts
246
+ node_types = set()
247
+ node_counts: Dict[str, int] = {}
248
+ remaining_nodes: Dict[str, Set[int]] = {
249
+ } # Track which nodes are actually present in edges
250
+
251
+ # Get all unique nodes that are in the edges
252
+ for edge_type in edge_index_dict.keys():
253
+ src_type, _, dst_type = edge_type
254
+ node_types.add(src_type)
255
+ node_types.add(dst_type)
256
+
257
+ if src_type not in remaining_nodes:
258
+ remaining_nodes[src_type] = set()
259
+ if dst_type not in remaining_nodes:
260
+ remaining_nodes[dst_type] = set()
261
+
262
+ remaining_nodes[src_type].update(
263
+ edge_index_dict[edge_type][0].tolist())
264
+ remaining_nodes[dst_type].update(
265
+ edge_index_dict[edge_type][1].tolist())
266
+
267
+ # Set node counts based on remaining nodes
268
+ for node_type in node_types:
269
+ node_counts[node_type] = len(remaining_nodes[node_type])
270
+
271
+ # Add nodes for each node type
272
+ for node_type in node_types:
273
+ num_nodes = node_counts[node_type]
274
+ node_offsets[node_type] = current_offset
275
+
276
+ # Get node weights if provided
277
+ weights = None
278
+ if node_weight_dict is not None and node_type in node_weight_dict:
279
+ weights = node_weight_dict[node_type]
280
+ if len(weights) != num_nodes:
281
+ raise ValueError(f"Number of weights for node type "
282
+ f"{node_type} ({len(weights)}) does not "
283
+ f"match number of nodes ({num_nodes})")
284
+
285
+ for i in range(num_nodes):
286
+ node_id = current_offset + i
287
+ label = (node_labels_dict[node_type][i]
288
+ if node_labels_dict is not None
289
+ and node_type in node_labels_dict else "")
290
+
291
+ # Calculate node size and opacity if weights provided
292
+ size = node_size_range[1]
293
+ opacity = node_opacity_range[1]
294
+ if weights is not None:
295
+ w = weights[i].item()
296
+ size = node_size_range[0] + w * \
297
+ (node_size_range[1] - node_size_range[0])
298
+ opacity = node_opacity_range[0] + w * \
299
+ (node_opacity_range[1] - node_opacity_range[0])
300
+
301
+ g.add_node(node_id, label=label, type=node_type, size=size,
302
+ alpha=opacity)
303
+
304
+ current_offset += num_nodes
305
+
306
+ # Add edges with remapped node indices
307
+ for edge_type, edge_index in edge_index_dict.items():
308
+ src_type, _, dst_type = edge_type
309
+ edge_weight = edge_weight_dict[edge_type]
310
+ src_offset = node_offsets[src_type]
311
+ dst_offset = node_offsets[dst_type]
312
+
313
+ # Create mappings for source and target nodes
314
+ src_mapping = {
315
+ old_idx: new_idx
316
+ for new_idx, old_idx in enumerate(sorted(
317
+ remaining_nodes[src_type]))
318
+ }
319
+ dst_mapping = {
320
+ old_idx: new_idx
321
+ for new_idx, old_idx in enumerate(sorted(
322
+ remaining_nodes[dst_type]))
323
+ }
324
+
325
+ for (src, dst), w in zip(edge_index.t().tolist(),
326
+ edge_weight.tolist()):
327
+ # Remap node indices
328
+ new_src = src_mapping[src] + src_offset
329
+ new_dst = dst_mapping[dst] + dst_offset
330
+
331
+ # Calculate edge width and opacity based on weight
332
+ width = edge_width_range[0] + w * \
333
+ (edge_width_range[1] - edge_width_range[0])
334
+ opacity = edge_opacity_range[0] + w * \
335
+ (edge_opacity_range[1] - edge_opacity_range[0])
336
+ g.add_edge(new_src, new_dst, width=width, alpha=opacity)
337
+
338
+ # Draw the graph
339
+ ax = plt.gca()
340
+ pos = nx.arf_layout(g)
341
+
342
+ # Draw edges with arrows
343
+ for src, dst, data in g.edges(data=True):
344
+ ax.annotate(
345
+ '',
346
+ xy=pos[src],
347
+ xytext=pos[dst],
348
+ arrowprops=dict(
349
+ arrowstyle="<-",
350
+ alpha=data['alpha'],
351
+ linewidth=data['width'],
352
+ shrinkA=sqrt(g.nodes[src]['size']) / 2.0,
353
+ shrinkB=sqrt(g.nodes[dst]['size']) / 2.0,
354
+ connectionstyle="arc3,rad=0.1",
355
+ ),
356
+ )
357
+
358
+ # Draw nodes colored by type
359
+ node_colors = []
360
+ node_sizes = []
361
+ node_alphas = []
362
+
363
+ # Use matplotlib tab20 colormap for consistent coloring
364
+ tab10_cmap = plt.cm.tab10 # type: ignore[attr-defined]
365
+ node_type_colors: Dict[str, Any] = {} # Store color for each node type
366
+ for node in g.nodes():
367
+ node_type = g.nodes[node]['type']
368
+ # Assign a consistent color for each node type
369
+ if node_type not in node_type_colors:
370
+ color_idx = len(node_type_colors) % 10 # Cycle through colors
371
+ node_type_colors[node_type] = tab10_cmap(color_idx)
372
+ node_colors.append(node_type_colors[node_type])
373
+ node_sizes.append(g.nodes[node]['size'])
374
+ node_alphas.append(g.nodes[node]['alpha'])
375
+
376
+ nx.draw_networkx_nodes(g, pos, node_size=node_sizes,
377
+ node_color=node_colors, margins=0.1,
378
+ alpha=node_alphas)
379
+
380
+ # Draw labels
381
+ labels = nx.get_node_attributes(g, 'label')
382
+ nx.draw_networkx_labels(g, pos, labels, font_size=10)
383
+
384
+ # Add legend
385
+ legend_elements = []
386
+ for node_type, color in node_type_colors.items():
387
+ legend_elements.append(
388
+ plt.Line2D([0], [0], marker='o', color='w', label=node_type,
389
+ markerfacecolor=color, markersize=10))
390
+ ax.legend(handles=legend_elements, loc='upper right',
391
+ bbox_to_anchor=(0.9, 1))
392
+
393
+ if path is not None:
394
+ plt.savefig(path, bbox_inches='tight')
395
+ else:
396
+ plt.show()
397
+
398
+ plt.close()