pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff compares the contents of two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of pyg-nightly might be problematic.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/distributed/local_feature_store.py CHANGED
@@ -15,6 +15,7 @@ from torch_geometric.distributed.rpc import (
     rpc_async,
     rpc_register,
 )
+from torch_geometric.io import fs
 from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType
 
 
@@ -415,11 +416,11 @@ class LocalFeatureStore(FeatureStore):
 
         node_feats: Optional[Dict[str, Any]] = None
         if osp.exists(osp.join(part_dir, 'node_feats.pt')):
-            node_feats = torch.load(osp.join(part_dir, 'node_feats.pt'))
+            node_feats = fs.torch_load(osp.join(part_dir, 'node_feats.pt'))
 
         edge_feats: Optional[Dict[str, Any]] = None
         if osp.exists(osp.join(part_dir, 'edge_feats.pt')):
-            edge_feats = torch.load(osp.join(part_dir, 'edge_feats.pt'))
+            edge_feats = fs.torch_load(osp.join(part_dir, 'edge_feats.pt'))
 
         if not meta['is_hetero'] and node_feats is not None:
             feat_store.put_global_id(node_feats['global_id'], group_name=None)
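Note: the pattern in this hunk recurs throughout the distributed code below: plain `torch.load` calls are replaced by `fs.torch_load` from `torch_geometric.io`, which resolves paths through fsspec so the same call works for local and remote filesystems. A minimal usage sketch (the directory name is illustrative, not from this diff):

    import os.path as osp

    from torch_geometric.io import fs

    part_dir = 'partition0'  # illustrative path
    if osp.exists(osp.join(part_dir, 'node_feats.pt')):
        # Same semantics as torch.load, but routed through fsspec:
        node_feats = fs.torch_load(osp.join(part_dir, 'node_feats.pt'))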
torch_geometric/distributed/local_graph_store.py CHANGED
@@ -6,6 +6,7 @@ from torch import Tensor
 
 from torch_geometric.data import EdgeAttr, GraphStore
 from torch_geometric.distributed.partition import load_partition_info
+from torch_geometric.io import fs
 from torch_geometric.typing import EdgeTensorType, EdgeType, NodeType
 from torch_geometric.utils import sort_edge_index
 
@@ -185,7 +186,7 @@ class LocalGraphStore(GraphStore):
         graph_store.edge_pb = edge_pb
         graph_store.meta = meta
 
-        graph_data = torch.load(osp.join(part_dir, 'graph.pt'))
+        graph_data = fs.torch_load(osp.join(part_dir, 'graph.pt'))
         graph_store.is_sorted = meta['is_sorted']
 
         if not meta['is_hetero']:
torch_geometric/distributed/partition.py CHANGED
@@ -3,15 +3,16 @@ import logging
 import os
 import os.path as osp
 from collections import defaultdict
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 
 import torch_geometric.distributed as pyg_dist
 from torch_geometric.data import Data, HeteroData
+from torch_geometric.io import fs
 from torch_geometric.loader.cluster import ClusterData
 from torch_geometric.sampler.utils import sort_csc
-from torch_geometric.typing import EdgeType, NodeType
+from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType
 
 
 class Partitioner:
@@ -303,7 +304,7 @@ class Partitioner:
         elif self.is_node_level_time:
             node_time = data.time
 
-            # Sort by column to avoid keeping track of
+            # Sort by column to avoid keeping track of permutations in
             # `NeighborSampler` when converting to CSC format:
             global_row, global_col, perm = sort_csc(
                 global_row, global_col, node_time, edge_time)
@@ -360,7 +361,7 @@ class Partitioner:
             'edge_types': self.edge_types,
             'node_offset': list(node_offset.values()) if node_offset else None,
             'is_hetero': self.is_hetero,
-            'is_sorted': True,  # Based on
+            'is_sorted': True,  # Based on column/destination.
         }
         with open(osp.join(self.root, 'META.json'), 'w') as f:
             json.dump(meta, f)
@@ -380,21 +381,21 @@ def load_partition_info(
     assert osp.exists(partition_dir)
 
     if meta['is_hetero'] is False:
-        node_pb = torch.load(osp.join(root_dir, 'node_map.pt'))
-        edge_pb = torch.load(osp.join(root_dir, 'edge_map.pt'))
+        node_pb = fs.torch_load(osp.join(root_dir, 'node_map.pt'))
+        edge_pb = fs.torch_load(osp.join(root_dir, 'edge_map.pt'))
 
         return (meta, num_partitions, partition_idx, node_pb, edge_pb)
     else:
         node_pb_dict = {}
         node_pb_dir = osp.join(root_dir, 'node_map')
         for ntype in meta['node_types']:
-            node_pb_dict[ntype] = torch.load(
+            node_pb_dict[ntype] = fs.torch_load(
                 osp.join(node_pb_dir, f'{pyg_dist.utils.as_str(ntype)}.pt'))
 
         edge_pb_dict = {}
         edge_pb_dir = osp.join(root_dir, 'edge_map')
         for etype in meta['edge_types']:
-            edge_pb_dict[tuple(etype)] = torch.load(
+            edge_pb_dict[tuple(etype)] = fs.torch_load(
                 osp.join(edge_pb_dir, f'{pyg_dist.utils.as_str(etype)}.pt'))
 
         return (meta, num_partitions, partition_idx, node_pb_dict,
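From the paths touched in these partition.py hunks, the on-disk layout that load_partition_info reads looks roughly as follows (a sketch inferred from this diff, not from the documentation):

    # root/
    # ├── META.json          # 'node_types', 'edge_types', 'node_offset',
    # │                      # 'is_hetero', 'is_sorted', ...
    # ├── node_map.pt        # homogeneous case: node partition book
    # ├── edge_map.pt        # homogeneous case: edge partition book
    # ├── node_map/          # heterogeneous case: one file per node type
    # │   └── <node_type>.pt
    # └── edge_map/          # heterogeneous case: one file per edge type
    #     └── <edge_type>.pt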
torch_geometric/distributed/rpc.py CHANGED
@@ -92,7 +92,7 @@ def shutdown_rpc(id: str = None, graceful: bool = True,
 class RPCRouter:
     r"""A router to get the worker based on the partition ID."""
     def __init__(self, partition_to_workers: List[List[str]]):
-        for
+        for rpc_worker_list in partition_to_workers:
             if len(rpc_worker_list) == 0:
                 raise ValueError('No RPC worker is in worker list')
         self.partition_to_workers = partition_to_workers
@@ -120,7 +120,7 @@ def rpc_partition_to_workers(
     partition_to_workers = [[] for _ in range(num_partitions)]
     gathered_results = global_all_gather(
         (ctx.role, num_partitions, current_partition_idx))
-    for worker_name, (
+    for worker_name, (_, _, idx) in gathered_results.items():
         partition_to_workers[idx].append(worker_name)
     return partition_to_workers
 
@@ -144,7 +144,7 @@ _rpc_call_pool: Dict[int, RPCCallBase] = {}
 @rpc_require_initialized
 def rpc_register(call: RPCCallBase) -> int:
     r"""Registers a call for RPC requests."""
-    global _rpc_call_id
+    global _rpc_call_id
 
     with _rpc_call_lock:
         call_id = _rpc_call_id
torch_geometric/edge_index.py CHANGED
@@ -17,6 +17,7 @@ from typing import (
     overload,
 )
 
+import numpy as np
 import torch
 import torch.utils._pytree as pytree
 from torch import Tensor
@@ -173,7 +174,7 @@ class EdgeIndex(Tensor):
     :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
     lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
 
-    This representation ensures
+    This representation ensures optimal computation in GNN message passing
     schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
     workflows.
 
|
|
|
183
184
|
|
|
184
185
|
edge_index = EdgeIndex(
|
|
185
186
|
[[0, 1, 1, 2],
|
|
186
|
-
[1, 0, 2, 1]]
|
|
187
|
+
[1, 0, 2, 1]],
|
|
187
188
|
sparse_size=(3, 3),
|
|
188
189
|
sort_order='row',
|
|
189
190
|
is_undirected=True,
|
|
@@ -210,7 +211,7 @@ class EdgeIndex(Tensor):
         assert not edge_index.is_undirected
 
         # Sparse-Dense Matrix Multiplication:
-        out = edge_index.flip(0) @
+        out = edge_index.flip(0) @ torch.randn(3, 16)
         assert out.size() == (3, 16)
     """
     # See "https://pytorch.org/docs/stable/notes/extending.html"
@@ -297,8 +298,7 @@ class EdgeIndex(Tensor):
             indptr = None
             data = torch.stack([row, col], dim=0)
 
-            if (torch_geometric.typing.WITH_PT112
-                    and data.layout == torch.sparse_csc):
+            if data.layout == torch.sparse_csc:
                 row = data.row_indices()
                 indptr = data.ccol_indices()
 
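This hunk drops the `WITH_PT112` guard (the supported PyTorch floor has moved past 1.12), so constructing an EdgeIndex from a CSC tensor is now unconditional. A minimal sketch of that entry point:

    import torch

    from torch_geometric import EdgeIndex

    dense = torch.tensor([[0., 1., 0.],
                          [1., 0., 1.],
                          [0., 1., 0.]])

    # Per the branch above, a CSC input also seeds the cached CSC index
    # pointer via data.ccol_indices():
    edge_index = EdgeIndex(dense.to_sparse_csc())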
@@ -325,7 +325,7 @@ class EdgeIndex(Tensor):
         elif sparse_size[0] is None and sparse_size[1] is not None:
             sparse_size = (sparse_size[1], sparse_size[1])
 
-        out = Tensor._make_wrapper_subclass(
+        out = Tensor._make_wrapper_subclass(
             cls,
             size=data.size(),
             strides=data.stride(),
@@ -803,7 +803,7 @@ class EdgeIndex(Tensor):
 
         size = self.get_sparse_size()
         if value is not None and value.dim() > 1:
-            size = size + value.size()[1:]
+            size = size + value.size()[1:]
 
         out = torch.full(size, fill_value, dtype=dtype, device=self.device)
         out[self._data[0], self._data[1]] = value if value is not None else 1
@@ -820,19 +820,28 @@ class EdgeIndex(Tensor):
             :obj:`1.0`. (default: :obj:`None`)
         """
         value = self._get_value() if value is None else value
-        out = torch.sparse_coo_tensor(
+
+        if not torch_geometric.typing.WITH_PT21:
+            out = torch.sparse_coo_tensor(
+                indices=self._data,
+                values=value,
+                size=self.get_sparse_size(),
+                device=self.device,
+                requires_grad=value.requires_grad,
+            )
+            if self.is_sorted_by_row:
+                out = out._coalesced_(True)
+            return out
+
+        return torch.sparse_coo_tensor(
             indices=self._data,
             values=value,
             size=self.get_sparse_size(),
             device=self.device,
             requires_grad=value.requires_grad,
+            is_coalesced=True if self.is_sorted_by_row else None,
         )
 
-        if self.is_sorted_by_row:
-            out = out._coalesced_(True)
-
-        return out
-
     def to_sparse_csr(  # type: ignore
         self,
         value: Optional[Tensor] = None,
@@ -872,10 +881,6 @@ class EdgeIndex(Tensor):
             If not specified, non-zero elements will be assigned a value of
             :obj:`1.0`. (default: :obj:`None`)
         """
-        if not torch_geometric.typing.WITH_PT112:
-            raise NotImplementedError(
-                "'to_sparse_csc' not supported for PyTorch < 1.12")
-
         (colptr, row), perm = self.get_csc()
         if value is not None and perm is not None:
             value = value[perm]
@@ -912,7 +917,7 @@ class EdgeIndex(Tensor):
             return self.to_sparse_coo(value)
         if layout == torch.sparse_csr:
             return self.to_sparse_csr(value)
-        if torch_geometric.typing.WITH_PT112 and layout == torch.sparse_csc:
+        if layout == torch.sparse_csc:
             return self.to_sparse_csc(value)
 
         raise ValueError(f"Unexpected tensor layout (got '{layout}')")
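The net effect of the to_sparse_coo change above: on PyTorch >= 2.1 the conversion passes `is_coalesced` directly to `torch.sparse_coo_tensor` instead of calling the private `_coalesced_` setter afterwards. A minimal usage sketch:

    import torch

    from torch_geometric import EdgeIndex

    edge_index = EdgeIndex(
        [[0, 1, 1, 2],
         [1, 0, 2, 1]],
        sparse_size=(3, 3),
        sort_order='row',
    )

    adj = edge_index.to_sparse(layout=torch.sparse_coo)
    assert adj.is_coalesced()  # row-sorted input is flagged as coalesced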
@@ -1181,10 +1186,10 @@ class EdgeIndex(Tensor):
         return edge_index
 
     # Prevent auto-wrapping outputs back into the proper subclass type:
-    __torch_function__ = torch._C._disabled_torch_function_impl
+    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore
 
     @classmethod
-    def __torch_dispatch__(
+    def __torch_dispatch__(  # type: ignore
         cls: Type,
         func: Callable[..., Any],
         types: Iterable[Type[Any]],
@@ -1237,6 +1242,14 @@ class EdgeIndex(Tensor):
         return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
                                                indent, force_newline=False)
 
+    def tolist(self) -> List[Any]:
+        """"""  # noqa: D419
+        return self._data.tolist()
+
+    def numpy(self, *, force: bool = False) -> np.ndarray:
+        """"""  # noqa: D419
+        return self._data.numpy(force=force)
+
     # Helpers #################################################################
 
     def _shallow_copy(self) -> 'EdgeIndex':
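The new tolist and numpy passthroughs delegate to the wrapped `_data` tensor; since EdgeIndex is built via `Tensor._make_wrapper_subclass` (see above) and holds no storage of its own, the inherited Tensor methods would not work directly. A short sketch:

    from torch_geometric import EdgeIndex

    edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]])

    assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
    arr = edge_index.numpy()  # plain numpy.ndarray of shape (2, 4)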
@@ -1469,7 +1482,7 @@ def _slice(
     step: int = 1,
 ) -> Union[EdgeIndex, Tensor]:
 
-    if ((start is None or start <= -input.size(dim))
+    if ((start is None or start == 0 or start <= -input.size(dim))
             and (end is None or end > input.size(dim)) and step == 1):
         return input._shallow_copy()  # No-op.
 
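The _slice fix adds an explicit `start == 0` to the no-op test, so a full-range slice written with a zero start now takes the cheap shallow-copy path and keeps the cached metadata. A sketch:

    from torch_geometric import EdgeIndex

    edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

    out = edge_index[:, 0:]  # start == 0 now recognized as a no-op slice
    assert isinstance(out, EdgeIndex)
    assert out.is_sorted_by_row  # sort metadata survives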
@@ -1928,7 +1941,7 @@ def _spmm(
     if transpose and not input.is_sorted_by_col:
         cls_name = input.__class__.__name__
         raise ValueError(f"'matmul(..., transpose=True)' requires "
-                         f"'{cls_name}' to be sorted by
+                         f"'{cls_name}' to be sorted by columns")
 
     if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()
             and other.is_cuda):  # pragma: no cover
torch_geometric/explain/algorithm/attention_explainer.py CHANGED
@@ -1,13 +1,14 @@
 import logging
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union, overload
 
 import torch
 from torch import Tensor
 
-from torch_geometric.explain import Explanation
+from torch_geometric.explain import Explanation, HeteroExplanation
 from torch_geometric.explain.algorithm import ExplainerAlgorithm
 from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
 from torch_geometric.nn.conv.message_passing import MessagePassing
+from torch_geometric.typing import EdgeType, NodeType
 
 
 class AttentionExplainer(ExplainerAlgorithm):
@@ -26,7 +27,9 @@ class AttentionExplainer(ExplainerAlgorithm):
     def __init__(self, reduce: str = 'max'):
         super().__init__()
         self.reduce = reduce
+        self.is_hetero = False
 
+    @overload
     def forward(
         self,
         model: torch.nn.Module,
@@ -37,65 +40,252 @@ class AttentionExplainer(ExplainerAlgorithm):
         index: Optional[Union[int, Tensor]] = None,
         **kwargs,
     ) -> Explanation:
-        if isinstance(x, dict) or isinstance(edge_index, dict):
-            raise ValueError(f"Heterogeneous graphs not yet supported in "
-                             f"'{self.__class__.__name__}'")
+        ...
 
-
-
-
-
-
-
+    @overload
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> HeteroExplanation:
+        ...
+
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> Union[Explanation, HeteroExplanation]:
+        """Generate explanations based on attention coefficients."""
+        self.is_hetero = isinstance(x, dict)
+
+        # Collect attention coefficients
+        alphas_dict = self._collect_attention_coefficients(
+            model, x, edge_index, **kwargs)
+
+        # Process attention coefficients
+        if self.is_hetero:
+            return self._create_hetero_explanation(model, alphas_dict,
+                                                   edge_index, index, x)
+        else:
+            return self._create_homo_explanation(model, alphas_dict,
+                                                 edge_index, index, x)
+
+    @overload
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        **kwargs,
+    ) -> List[Tensor]:
+        ...
+
+    @overload
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Dict[NodeType, Tensor],
+        edge_index: Dict[EdgeType, Tensor],
+        **kwargs,
+    ) -> Dict[EdgeType, List[Tensor]]:
+        ...
+
+    def _collect_attention_coefficients(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
+        **kwargs,
+    ) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:
+        """Collect attention coefficients from model layers."""
+        if self.is_hetero:
+            # For heterogeneous graphs, store alphas by edge type
+            alphas_dict: Dict[EdgeType, List[Tensor]] = {}
+
+            # Get list of edge types
+            edge_types = list(edge_index.keys())
+
+            # Hook function to capture attention coefficients by edge type
+            def hook(module, msg_kwargs, out):
+                # Find edge type from the module's full name
+                module_name = getattr(module, '_name', None)
+                if module_name is None:
+                    return
 
-
+                edge_type = None
+                for edge_tuple in edge_types:
+                    src_type, edge_name, dst_type = edge_tuple
+                    # Check if all components appear in the module name in
+                    # order
+                    try:
+                        src_idx = module_name.index(src_type)
+                        edge_idx = module_name.index(edge_name, src_idx)
+                        dst_idx = module_name.index(dst_type, edge_idx)
+                        if src_idx < edge_idx < dst_idx:
+                            edge_type = edge_tuple
+                            break
+                    except ValueError:  # Component not found
+                        continue
+
+                if edge_type is None:
+                    return
+
+                if edge_type not in alphas_dict:
+                    alphas_dict[edge_type] = []
+
+                # Extract alpha from message kwargs or module
+                if 'alpha' in msg_kwargs[0]:
+                    alphas_dict[edge_type].append(
+                        msg_kwargs[0]['alpha'].detach())
+                elif getattr(module, '_alpha', None) is not None:
+                    alphas_dict[edge_type].append(module._alpha.detach())
+        else:
+            # For homogeneous graphs, store all alphas in a list
+            alphas: List[Tensor] = []
 
-
-
-
-
-
+            def hook(module, msg_kwargs, out):
+                if 'alpha' in msg_kwargs[0]:
+                    alphas.append(msg_kwargs[0]['alpha'].detach())
+                elif getattr(module, '_alpha', None) is not None:
+                    alphas.append(module._alpha.detach())
 
+        # Register hooks for all message passing modules
         hook_handles = []
-        for module in model.
-            if
-
+        for name, module in model.named_modules():
+            if isinstance(module,
+                          MessagePassing) and module.explain is not False:
+                # Store name for hetero graph lookup in the hook
+                if self.is_hetero:
+                    module._name = name
+
                 hook_handles.append(module.register_message_forward_hook(hook))
 
+        # Forward pass to collect attention coefficients.
         model(x, edge_index, **kwargs)
 
-
+        # Remove hooks
+        for handle in hook_handles:
             handle.remove()
 
-        if
-
-
-
+        # Check if we collected any attention coefficients.
+        if self.is_hetero:
+            if not alphas_dict:
+                raise ValueError(
+                    "Could not collect any attention coefficients. "
+                    "Please ensure that your model is using "
+                    "attention-based GNN layers.")
+            return alphas_dict
+        else:
+            if not alphas:
+                raise ValueError(
+                    "Could not collect any attention coefficients. "
+                    "Please ensure that your model is using "
+                    "attention-based GNN layers.")
+            return alphas
 
+    def _process_attention_coefficients(
+        self,
+        alphas: List[Tensor],
+        edge_index_size: int,
+    ) -> Tensor:
+        """Process collected attention coefficients into a single mask."""
         for i, alpha in enumerate(alphas):
-
+            # Ensure alpha doesn't exceed edge_index size
+            alpha = alpha[:edge_index_size]
+
+            # Reduce multi-head attention
             if alpha.dim() == 2:
                 alpha = getattr(torch, self.reduce)(alpha, dim=-1)
-                if isinstance(alpha, tuple):  #
+                if isinstance(alpha, tuple):  # Handle torch.max output
                     alpha = alpha[0]
             elif alpha.dim() > 2:
-                raise ValueError(f"
+                raise ValueError(f"Cannot reduce attention coefficients of "
                                  f"shape {list(alpha.size())}")
             alphas[i] = alpha
 
+        # Combine attention coefficients across layers
         if len(alphas) > 1:
             alpha = torch.stack(alphas, dim=-1)
             alpha = getattr(torch, self.reduce)(alpha, dim=-1)
-            if isinstance(alpha, tuple):  #
+            if isinstance(alpha, tuple):  # Handle torch.max output
                 alpha = alpha[0]
         else:
             alpha = alphas[0]
 
+        return alpha
+
+    def _create_homo_explanation(
+        self,
+        model: torch.nn.Module,
+        alphas: List[Tensor],
+        edge_index: Tensor,
+        index: Optional[Union[int, Tensor]],
+        x: Tensor,
+    ) -> Explanation:
+        """Create explanation for homogeneous graph."""
+        # Get hard edge mask for node-level tasks
+        hard_edge_mask = None
+        if self.model_config.task_level == ModelTaskLevel.node:
+            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
+                                                     num_nodes=x.size(0))
+
+        # Process attention coefficients
+        alpha = self._process_attention_coefficients(alphas,
+                                                     edge_index.size(1))
+
+        # Post-process mask with hard edge mask if needed
         alpha = self._post_process_mask(alpha, hard_edge_mask,
                                         apply_sigmoid=False)
 
         return Explanation(edge_mask=alpha)
 
+    def _create_hetero_explanation(
+        self,
+        model: torch.nn.Module,
+        alphas_dict: Dict[EdgeType, List[Tensor]],
+        edge_index: Dict[EdgeType, Tensor],
+        index: Optional[Union[int, Tensor]],
+        x: Dict[NodeType, Tensor],
+    ) -> HeteroExplanation:
+        """Create explanation for heterogeneous graph."""
+        edge_masks_dict = {}
+
+        # Process each edge type separately
+        for edge_type, alphas in alphas_dict.items():
+            if not alphas:
+                continue
+
+            # Get hard edge mask for node-level tasks
+            hard_edge_mask = None
+            if self.model_config.task_level == ModelTaskLevel.node:
+                src_type, _, dst_type = edge_type
+                _, hard_edge_mask = self._get_hard_masks(
+                    model, index, edge_index[edge_type],
+                    num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))
+
+            # Process attention coefficients for this edge type
+            alpha = self._process_attention_coefficients(
+                alphas, edge_index[edge_type].size(1))
+
+            # Apply hard mask if available
+            edge_masks_dict[edge_type] = self._post_process_mask(
+                alpha, hard_edge_mask, apply_sigmoid=False)
+
+        # Create heterogeneous explanation
+        explanation = HeteroExplanation()
+        explanation.set_value_dict('edge_mask', edge_masks_dict)
+        return explanation
+
     def supports(self) -> bool:
         explanation_type = self.explainer_config.explanation_type
         if explanation_type != ExplanationType.model:
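With this change, dict-valued inputs produce a HeteroExplanation with one edge mask per edge type instead of raising. A minimal sketch, assuming `model` is an attention-based heterogeneous GNN (e.g. GAT layers converted with to_hetero) and `data` is a HeteroData object; both are stand-ins, not defined in this diff:

    from torch_geometric.explain import Explainer
    from torch_geometric.explain.algorithm import AttentionExplainer

    explainer = Explainer(
        model=model,
        algorithm=AttentionExplainer(reduce='max'),
        explanation_type='model',
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',
            return_type='log_probs',
        ),
    )

    # Dict inputs now return a HeteroExplanation:
    explanation = explainer(data.x_dict, data.edge_index_dict, index=0)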
torch_geometric/explain/algorithm/base.py CHANGED
@@ -166,7 +166,7 @@ class ExplainerAlgorithm(torch.nn.Module):
         elif self.model_config.return_type == ModelReturnType.probs:
             loss_fn = F.binary_cross_entropy
         else:
-            assert False
+            raise AssertionError()
 
         return loss_fn(y_hat.view_as(y), y.float())
 
@@ -183,7 +183,7 @@ class ExplainerAlgorithm(torch.nn.Module):
         elif self.model_config.return_type == ModelReturnType.log_probs:
             loss_fn = F.nll_loss
         else:
-            assert False
+            raise AssertionError()
 
         return loss_fn(y_hat, y)
 
torch_geometric/explain/algorithm/captum.py CHANGED
@@ -190,7 +190,7 @@ def to_captum_input(
 
     Args:
         x (torch.Tensor or Dict[NodeType, torch.Tensor]): The node features.
-            For heterogeneous graphs this is a dictionary holding node
+            For heterogeneous graphs this is a dictionary holding node features
             for each node type.
         edge_index(torch.Tensor or Dict[EdgeType, torch.Tensor]): The edge
             indices. For heterogeneous graphs this is a dictionary holding the
torch_geometric/explain/algorithm/captum_explainer.py CHANGED
@@ -73,7 +73,8 @@ class CaptumExplainer(ExplainerAlgorithm):
                 f"{self.attribution_method_class.__name__}")
 
         if kwargs.get('internal_batch_size', 1) != 1:
-            warnings.warn("Overriding 'internal_batch_size' to 1")
+            warnings.warn("Overriding 'internal_batch_size' to 1",
+                          stacklevel=2)
 
         if 'internal_batch_size' in self._get_attribute_parameters():
             kwargs['internal_batch_size'] = 1
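For context on the `stacklevel=2` addition: it attributes the warning to the caller's line rather than to the `warnings.warn` call itself. A minimal, self-contained illustration (`configure` is a hypothetical stand-in, not PyG API):

    import warnings

    def configure(internal_batch_size: int = 1) -> None:
        if internal_batch_size != 1:
            # stacklevel=2 reports the caller's location below:
            warnings.warn("Overriding 'internal_batch_size' to 1",
                          stacklevel=2)

    configure(internal_batch_size=4)  # the warning points at this line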