pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/nn/conv/hetero_conv.py
CHANGED
@@ -70,8 +70,8 @@ class HeteroConv(torch.nn.Module):
         for edge_type, module in convs.items():
             check_add_self_loops(module, [edge_type])
 
-        src_node_types = set([key[0] for key in convs.keys()])
-        dst_node_types = set([key[-1] for key in convs.keys()])
+        src_node_types = {key[0] for key in convs.keys()}
+        dst_node_types = {key[-1] for key in convs.keys()}
         if len(src_node_types - dst_node_types) > 0:
             warnings.warn(
                 f"There exist node types ({src_node_types - dst_node_types}) "
@@ -102,7 +102,7 @@ class HeteroConv(torch.nn.Module):
                 individual edge type, either as a :class:`torch.Tensor` of
                 shape :obj:`[2, num_edges]` or a
                 :class:`torch_sparse.SparseTensor`.
-            *args_dict (optional): Additional forward arguments of
+            *args_dict (optional): Additional forward arguments of individual
                 :class:`torch_geometric.nn.conv.MessagePassing` layers.
             **kwargs_dict (optional): Additional forward arguments of
                 individual :class:`torch_geometric.nn.conv.MessagePassing`
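For context, a minimal usage sketch of the constructor path touched above (the toy graph and channel sizes are illustrative): with a convs dict covering only ('author', 'writes', 'paper'), the set comprehensions yield src_node_types = {'author'} and dst_node_types = {'paper'}, so the warning fires for 'author'.

import torch
from torch_geometric.nn import HeteroConv, SAGEConv

x_dict = {
    'author': torch.randn(5, 8),
    'paper': torch.randn(10, 16),
}
edge_index_dict = {
    # 'author' only ever appears as a source node type, so HeteroConv warns
    # that it will never receive any messages.
    ('author', 'writes', 'paper'): torch.tensor([[0, 1, 2], [0, 3, 7]]),
}

conv = HeteroConv({
    ('author', 'writes', 'paper'): SAGEConv((8, 16), 32),
}, aggr='sum')
out_dict = conv(x_dict, edge_index_dict)  # only 'paper' embeddings returned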
torch_geometric/nn/conv/hgt_conv.py
CHANGED
@@ -67,7 +67,7 @@ class HGTConv(MessagePassing):
             for i, edge_type in enumerate(metadata[1])
         }
 
-        self.dst_node_types = set([key[-1] for key in self.edge_types])
+        self.dst_node_types = {key[-1] for key in self.edge_types}
 
         self.kqv_lin = HeteroDictLinear(self.in_channels,
                                         self.out_channels * 3)
torch_geometric/nn/conv/message_passing.py
CHANGED
@@ -6,6 +6,7 @@ from typing import (
     Any,
     Callable,
     Dict,
+    Final,
     List,
     Optional,
     OrderedDict,
@@ -19,6 +20,7 @@ from torch import Tensor
 from torch.utils.hooks import RemovableHandle
 
 from torch_geometric import EdgeIndex, is_compiling
+from torch_geometric.index import ptr2index
 from torch_geometric.inspector import Inspector, Signature
 from torch_geometric.nn.aggr import Aggregation
 from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
@@ -29,7 +31,6 @@ from torch_geometric.utils import (
     is_torch_sparse_tensor,
     to_edge_index,
 )
-from torch_geometric.utils.sparse import ptr2index
 
 FUSE_AGGRS = {'add', 'sum', 'mean', 'min', 'max'}
 HookDict = OrderedDict[int, Callable]
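ptr2index and its inverse index2ptr now live in the new torch_geometric.index module rather than torch_geometric.utils.sparse. A minimal sketch of what the pair computes (values illustrative):

import torch
from torch_geometric.index import index2ptr, ptr2index

# Convert a sorted COO row index into a CSR-style pointer vector and back:
index = torch.tensor([0, 0, 1, 1, 1, 3])
ptr = index2ptr(index, size=4)  # tensor([0, 2, 5, 5, 6])
assert ptr2index(ptr).equal(index)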
@@ -102,6 +103,10 @@ class MessagePassing(torch.nn.Module):
         'size_i', 'size_j', 'ptr', 'index', 'dim_size'
     }
 
+    # Supports `message_and_aggregate` via `EdgeIndex`.
+    # TODO Remove once migration is finished.
+    SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = False
+
     def __init__(
         self,
         aggr: Optional[Union[str, List[str], Aggregation]] = 'sum',
@@ -162,63 +167,8 @@ class MessagePassing(torch.nn.Module):
         self._edge_update_forward_pre_hooks: HookDict = OrderedDict()
         self._edge_update_forward_hooks: HookDict = OrderedDict()
 
-
-
-        # Optimize `propagate()` via `*.jinja` templates:
-        if not self.propagate.__module__.startswith(jinja_prefix):
-            try:
-                module = module_from_template(
-                    module_name=f'{jinja_prefix}_propagate',
-                    template_path=osp.join(root_dir, 'propagate.jinja'),
-                    tmp_dirname='message_passing',
-                    # Keyword arguments:
-                    modules=self.inspector._modules,
-                    collect_name='collect',
-                    signature=self._get_propagate_signature(),
-                    collect_param_dict=self.inspector.get_flat_param_dict(
-                        ['message', 'aggregate', 'update']),
-                    message_args=self.inspector.get_param_names('message'),
-                    aggregate_args=self.inspector.get_param_names('aggregate'),
-                    message_and_aggregate_args=self.inspector.get_param_names(
-                        'message_and_aggregate'),
-                    update_args=self.inspector.get_param_names('update'),
-                    fuse=self.fuse,
-                )
-
-                self.__class__._orig_propagate = self.__class__.propagate
-                self.__class__._jinja_propagate = module.propagate
-
-                self.__class__.propagate = module.propagate
-                self.__class__.collect = module.collect
-            except Exception:  # pragma: no cover
-                self.__class__._orig_propagate = self.__class__.propagate
-                self.__class__._jinja_propagate = self.__class__.propagate
-
-        # Optimize `edge_updater()` via `*.jinja` templates (if implemented):
-        if (self.inspector.implements('edge_update')
-                and not self.edge_updater.__module__.startswith(jinja_prefix)):
-            try:
-                module = module_from_template(
-                    module_name=f'{jinja_prefix}_edge_updater',
-                    template_path=osp.join(root_dir, 'edge_updater.jinja'),
-                    tmp_dirname='message_passing',
-                    # Keyword arguments:
-                    modules=self.inspector._modules,
-                    collect_name='edge_collect',
-                    signature=self._get_edge_updater_signature(),
-                    collect_param_dict=self.inspector.get_param_dict(
-                        'edge_update'),
-                )
-
-                self.__class__._orig_edge_updater = self.__class__.edge_updater
-                self.__class__._jinja_edge_updater = module.edge_updater
-
-                self.__class__.edge_updater = module.edge_updater
-                self.__class__.edge_collect = module.edge_collect
-            except Exception:  # pragma: no cover
-                self.__class__._orig_edge_updater = self.__class__.edge_updater
-                self.__class__._jinja_edge_updater = (
-                    self.__class__.edge_updater)
+        # Set jittable `propagate` and `edge_updater` function templates:
+        self._set_jittable_templates()
 
         # Explainability:
         self._explain: Optional[bool] = None
@@ -227,6 +177,7 @@ class MessagePassing(torch.nn.Module):
         self._apply_sigmoid: bool = True
 
         # Inference Decomposition:
+        self._decomposed_layers = 1
         self.decomposed_layers = decomposed_layers
 
     def reset_parameters(self) -> None:
@@ -234,6 +185,12 @@ class MessagePassing(torch.nn.Module):
         if self.aggr_module is not None:
             self.aggr_module.reset_parameters()
 
+    def __setstate__(self, data: Dict[str, Any]) -> None:
+        self.inspector = data['inspector']
+        self.fuse = data['fuse']
+        self._set_jittable_templates()
+        super().__setstate__(data)
+
     def __repr__(self) -> str:
         channels_repr = ''
         if hasattr(self, 'in_channels') and hasattr(self, 'out_channels'):
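The new __setstate__ hook restores inspector and fuse and re-attaches the generated jinja templates, since the dynamically created propagate module cannot be pickled with the instance. A minimal round-trip sketch (using a stock GCNConv; torch.load kwargs may vary by PyTorch version):

import io

import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(16, 32)
buf = io.BytesIO()
torch.save(conv, buf)  # pickles the full module
buf.seek(0)
conv2 = torch.load(buf, weights_only=False)  # __setstate__ re-sets templates

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
out = conv2(x, edge_index)  # behaves like the original instance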
@@ -247,7 +204,7 @@ class MessagePassing(torch.nn.Module):
     def _check_input(
         self,
         edge_index: Union[Tensor, SparseTensor],
-        size: Optional[Tuple[int, int]],
+        size: Optional[Tuple[Optional[int], Optional[int]]],
     ) -> List[Optional[int]]:
 
         if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
@@ -256,19 +213,20 @@ class MessagePassing(torch.nn.Module):
         if is_sparse(edge_index):
             if self.flow == 'target_to_source':
                 raise ValueError(
-                    ('Flow direction "target_to_source" is invalid for '
-                     'message propagation via `torch_sparse.SparseTensor` '
-                     'or `torch.sparse.Tensor`. If you really want to make '
-                     'use of a reverse message passing flow, pass in the '
-                     'transposed sparse tensor to the message passing '
-                     'module, e.g., `adj_t.t()`.'))
+                    'Flow direction "target_to_source" is invalid for '
+                    'message propagation via `torch_sparse.SparseTensor` '
+                    'or `torch.sparse.Tensor`. If you really want to make '
+                    'use of a reverse message passing flow, pass in the '
+                    'transposed sparse tensor to the message passing module, '
+                    'e.g., `adj_t.t()`.')
 
             if isinstance(edge_index, SparseTensor):
                 return [edge_index.size(1), edge_index.size(0)]
             return [edge_index.size(1), edge_index.size(0)]
 
         elif isinstance(edge_index, Tensor):
-            int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64)
+            int_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32,
+                          torch.int64)
 
             if edge_index.dtype not in int_dtypes:
                 raise ValueError(f"Expected 'edge_index' to be of integer "
@@ -284,9 +242,9 @@ class MessagePassing(torch.nn.Module):
         return list(size) if size is not None else [None, None]
 
         raise ValueError(
-            ('`MessagePassing.propagate` only supports integer tensors of '
-             'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
-             '`torch.sparse.Tensor` for argument `edge_index`.'))
+            '`MessagePassing.propagate` only supports integer tensors of '
+            'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
+            '`torch.sparse.Tensor` for argument `edge_index`.')
 
     def _set_size(
         self,
@@ -299,8 +257,8 @@ class MessagePassing(torch.nn.Module):
             size[dim] = src.size(self.node_dim)
         elif the_size != src.size(self.node_dim):
             raise ValueError(
-                (f'Encountered tensor with size {src.size(self.node_dim)} in '
-                 f'dimension {self.node_dim}, but expected size {the_size}.'))
+                f'Encountered tensor with size {src.size(self.node_dim)} in '
+                f'dimension {self.node_dim}, but expected size {the_size}.')
 
     def _index_select(self, src: Tensor, index) -> Tensor:
         if torch.jit.is_scripting() or is_compiling():
@@ -370,9 +328,9 @@ class MessagePassing(torch.nn.Module):
                 return src.index_select(self.node_dim, row)
 
         raise ValueError(
-            ('`MessagePassing.propagate` only supports integer tensors of '
-             'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '
-             'or `torch.sparse.Tensor` for argument `edge_index`.'))
+            '`MessagePassing.propagate` only supports integer tensors of '
+            'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '
+            'or `torch.sparse.Tensor` for argument `edge_index`.')
 
     def _collect(
         self,
@@ -459,7 +417,6 @@ class MessagePassing(torch.nn.Module):
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         r"""Runs the forward pass of the module."""
-        pass
 
     def propagate(
         self,
@@ -509,7 +466,17 @@ class MessagePassing(torch.nn.Module):
         mutable_size = self._check_input(edge_index, size)
 
         # Run "fused" message and aggregation (if applicable).
-        if is_sparse(edge_index) and self.fuse and not self.explain:
+        fuse = False
+        if self.fuse and not self.explain:
+            if is_sparse(edge_index):
+                fuse = True
+            elif (not torch.jit.is_scripting()
+                    and isinstance(edge_index, EdgeIndex)):
+                if (self.SUPPORTS_FUSED_EDGE_INDEX
+                        and edge_index.is_sorted_by_col):
+                    fuse = True
+
+        if fuse:
             coll_dict = self._collect(self._fused_user_args, edge_index,
                                       mutable_size, kwargs)
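A hedged sketch of the code path this enables: a subclass that opts in via SUPPORTS_FUSED_EDGE_INDEX and implements message_and_aggregate should be dispatched to directly whenever propagate() receives a column-sorted EdgeIndex. The class below and its scatter-based body are illustrative, not part of this release:

import torch
from torch import Tensor
from torch_geometric import EdgeIndex
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import scatter


class FusedSumConv(MessagePassing):
    SUPPORTS_FUSED_EDGE_INDEX = True  # opt in to the fused `EdgeIndex` path

    def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:
        return self.propagate(edge_index, x=x)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j  # fallback for the non-fused path

    def message_and_aggregate(self, edge_index: EdgeIndex, x: Tensor) -> Tensor:
        # Illustrative fused kernel: gather and sum-aggregate in one step.
        row, col = edge_index[0], edge_index[1]
        return scatter(x[row], col, dim=0, dim_size=x.size(0), reduce='sum')


edge_index = EdgeIndex([[0, 1, 2], [1, 2, 3]], sparse_size=(4, 4),
                       sort_order='col')
out = FusedSumConv()(torch.randn(4, 8), edge_index)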
@@ -628,7 +595,7 @@ class MessagePassing(torch.nn.Module):
                           dim=self.node_dim)
 
     @abstractmethod
-    def message_and_aggregate(self, adj_t: Adj) -> Tensor:
+    def message_and_aggregate(self, edge_index: Adj) -> Tensor:
         r"""Fuses computations of :func:`message` and :func:`aggregate` into a
         single function.
         If applicable, this saves both time and memory since messages do not
@@ -720,16 +687,20 @@ class MessagePassing(torch.nn.Module):
             raise ValueError("Inference decomposition of message passing "
                              "modules is only supported on the Python module")
 
+        if decomposed_layers == self._decomposed_layers:
+            return  # Abort early if nothing to do.
+
         self._decomposed_layers = decomposed_layers
 
         if decomposed_layers != 1:
-            self.propagate = self.__class__._orig_propagate.__get__(
-                self, MessagePassing)
+            if hasattr(self.__class__, '_orig_propagate'):
+                self.propagate = self.__class__._orig_propagate.__get__(
+                    self, MessagePassing)
 
-        elif self.explain is None or self.explain is False:
-            self.propagate = self.__class__._jinja_propagate.__get__(
-                self, MessagePassing)
-
+        elif self.explain is None or self.explain is False:
+            if hasattr(self.__class__, '_jinja_propagate'):
+                self.propagate = self.__class__._jinja_propagate.__get__(
+                    self, MessagePassing)
 
     # Explainability ##########################################################
 
@@ -743,6 +714,9 @@ class MessagePassing(torch.nn.Module):
             raise ValueError("Explainability of message passing modules "
                              "is only supported on the Python module")
 
+        if explain == self._explain:
+            return  # Abort early if nothing to do.
+
         self._explain = explain
 
         if explain is True:
@@ -753,16 +727,18 @@ class MessagePassing(torch.nn.Module):
                 funcs=['message', 'explain_message', 'aggregate', 'update'],
                 exclude=self.special_args,
             )
-            self.propagate = self.__class__._orig_propagate.__get__(
-                self, MessagePassing)
+            if hasattr(self.__class__, '_orig_propagate'):
+                self.propagate = self.__class__._orig_propagate.__get__(
+                    self, MessagePassing)
         else:
             self._user_args = self.inspector.get_flat_param_names(
                 funcs=['message', 'aggregate', 'update'],
                 exclude=self.special_args,
             )
             if self.decomposed_layers == 1:
-                self.propagate = self.__class__._jinja_propagate.__get__(
-                    self, MessagePassing)
+                if hasattr(self.__class__, '_jinja_propagate'):
+                    self.propagate = self.__class__._jinja_propagate.__get__(
+                        self, MessagePassing)
 
     def explain_message(
         self,
@@ -947,6 +923,81 @@ class MessagePassing(torch.nn.Module):
 
     # TorchScript Support #####################################################
 
+    def _set_jittable_templates(self, raise_on_error: bool = False) -> None:
+        root_dir = osp.dirname(osp.realpath(__file__))
+        jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
+        # Optimize `propagate()` via `*.jinja` templates:
+        if not self.propagate.__module__.startswith(jinja_prefix):
+            try:
+                if ('propagate' in self.__class__.__dict__
+                        and self.__class__.__dict__['propagate']
+                        != MessagePassing.propagate):
+                    raise ValueError("Cannot compile custom 'propagate' "
+                                     "method")
+
+                module = module_from_template(
+                    module_name=f'{jinja_prefix}_propagate',
+                    template_path=osp.join(root_dir, 'propagate.jinja'),
+                    tmp_dirname='message_passing',
+                    # Keyword arguments:
+                    modules=self.inspector._modules,
+                    collect_name='collect',
+                    signature=self._get_propagate_signature(),
+                    collect_param_dict=self.inspector.get_flat_param_dict(
+                        ['message', 'aggregate', 'update']),
+                    message_args=self.inspector.get_param_names('message'),
+                    aggregate_args=self.inspector.get_param_names('aggregate'),
+                    message_and_aggregate_args=self.inspector.get_param_names(
+                        'message_and_aggregate'),
+                    update_args=self.inspector.get_param_names('update'),
+                    fuse=self.fuse,
+                )
+
+                self.__class__._orig_propagate = self.__class__.propagate
+                self.__class__._jinja_propagate = module.propagate
+
+                self.__class__.propagate = module.propagate
+                self.__class__.collect = module.collect
+            except Exception as e:  # pragma: no cover
+                if raise_on_error:
+                    raise e
+                self.__class__._orig_propagate = self.__class__.propagate
+                self.__class__._jinja_propagate = self.__class__.propagate
+
+        # Optimize `edge_updater()` via `*.jinja` templates (if implemented):
+        if (self.inspector.implements('edge_update')
+                and not self.edge_updater.__module__.startswith(jinja_prefix)):
+            try:
+                if ('edge_updater' in self.__class__.__dict__
+                        and self.__class__.__dict__['edge_updater']
+                        != MessagePassing.edge_updater):
+                    raise ValueError("Cannot compile custom 'edge_updater' "
+                                     "method")
+
+                module = module_from_template(
+                    module_name=f'{jinja_prefix}_edge_updater',
+                    template_path=osp.join(root_dir, 'edge_updater.jinja'),
+                    tmp_dirname='message_passing',
+                    # Keyword arguments:
+                    modules=self.inspector._modules,
+                    collect_name='edge_collect',
+                    signature=self._get_edge_updater_signature(),
+                    collect_param_dict=self.inspector.get_param_dict(
+                        'edge_update'),
+                )
+
+                self.__class__._orig_edge_updater = self.__class__.edge_updater
+                self.__class__._jinja_edge_updater = module.edge_updater
+
+                self.__class__.edge_updater = module.edge_updater
+                self.__class__.edge_collect = module.edge_collect
+            except Exception as e:  # pragma: no cover
+                if raise_on_error:
+                    raise e
+                self.__class__._orig_edge_updater = self.__class__.edge_updater
+                self.__class__._jinja_edge_updater = (
+                    self.__class__.edge_updater)
+
     def _get_propagate_signature(self) -> Signature:
         param_dict = self.inspector.get_params_from_method_call(
             'propagate', exclude=[0, 'edge_index', 'size'])
torch_geometric/nn/conv/mixhop_conv.py
CHANGED
@@ -14,7 +14,7 @@ from torch_geometric.utils import spmm
 
 class MixHopConv(MessagePassing):
     r"""The Mix-Hop graph convolutional operator from the
-    `"MixHop: Higher-Order Graph Convolutional Architecturesvia Sparsified
+    `"MixHop: Higher-Order Graph Convolutional Architectures via Sparsified
     Neighborhood Mixing" <https://arxiv.org/abs/1905.00067>`_ paper.
 
     .. math::
torch_geometric/nn/conv/propagate.jinja
CHANGED
@@ -42,7 +42,15 @@ def propagate(
     # End Propagate Forward Pre Hook ###########################################
 
     mutable_size = self._check_input(edge_index, size)
-    fuse = is_sparse(edge_index) and self.fuse
+
+    # Run "fused" message and aggregation (if applicable).
+    fuse = False
+    if self.fuse:
+        if is_sparse(edge_index):
+            fuse = True
+        elif not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
+            if self.SUPPORTS_FUSED_EDGE_INDEX and edge_index.is_sorted_by_col:
+                fuse = True
 
     if fuse:
 
torch_geometric/nn/conv/rgcn_conv.py
CHANGED
@@ -3,11 +3,11 @@ from typing import Optional, Tuple, Union
 import torch
 from torch import Tensor
 from torch.nn import Parameter
-from torch.nn import Parameter as Param
 
 import torch_geometric.backend
 import torch_geometric.typing
 from torch_geometric import is_compiling
+from torch_geometric.index import index2ptr
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.inits import glorot, zeros
 from torch_geometric.typing import (
@@ -18,7 +18,6 @@ from torch_geometric.typing import (
     torch_sparse,
 )
 from torch_geometric.utils import index_sort, one_hot, scatter, spmm
-from torch_geometric.utils.sparse import index2ptr
 
 
 def masked_edge_index(edge_index: Adj, edge_mask: Tensor) -> Adj:
@@ -121,7 +120,8 @@ class RGCNConv(MessagePassing):
             in_channels = (in_channels, in_channels)
         self.in_channels_l = in_channels[0]
 
-        self._use_segment_matmul_heuristic_output: Optional[float] = None
+        self._use_segment_matmul_heuristic_output: torch.jit.Attribute(
+            None, Optional[float])
 
         if num_bases is not None:
             self.weight = Parameter(
@@ -143,12 +143,12 @@ class RGCNConv(MessagePassing):
             self.register_parameter('comp', None)
 
         if root_weight:
-            self.root = Param(torch.empty(in_channels[1], out_channels))
+            self.root = Parameter(torch.empty(in_channels[1], out_channels))
         else:
             self.register_parameter('root', None)
 
         if bias:
-            self.bias = Param(torch.empty(out_channels))
+            self.bias = Parameter(torch.empty(out_channels))
         else:
             self.register_parameter('bias', None)
 
torch_geometric/nn/conv/spline_conv.py
CHANGED
@@ -5,17 +5,17 @@ import torch
 from torch import Tensor, nn
 from torch.nn import Parameter
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import uniform, zeros
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
 from torch_geometric.utils.repeat import repeat
 
-try:
+if torch_geometric.typing.WITH_TORCH_SPLINE_CONV:
     from torch_spline_conv import spline_basis, spline_weighting
-except ImportError:
-    spline_basis = None
-    spline_weighting = None
+else:
+    spline_basis = spline_weighting = None
 
 
 class SplineConv(MessagePassing):
torch_geometric/nn/conv/x_conv.py
CHANGED
@@ -9,12 +9,13 @@ from torch.nn import Conv1d
 from torch.nn import Linear as L
 from torch.nn import Sequential as S
 
+import torch_geometric.typing
 from torch_geometric.nn import Reshape
 from torch_geometric.nn.inits import reset
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn_graph
-except ImportError:
+else:
     knn_graph = None
 
 
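Both modules now gate optional backends on torch_geometric.typing flags instead of a bare try/except around the import. A minimal sketch of the same guard pattern in downstream code (the build_knn helper is hypothetical):

import torch_geometric.typing

if torch_geometric.typing.WITH_TORCH_CLUSTER:
    from torch_cluster import knn_graph
else:
    knn_graph = None


def build_knn(pos, k=6):
    if knn_graph is None:  # torch-cluster is not installed
        raise ImportError("'build_knn' requires the 'torch-cluster' package")
    return knn_graph(pos, k=k)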
torch_geometric/nn/dense/linear.py
CHANGED
@@ -12,10 +12,10 @@ from torch.nn.parameter import Parameter
 import torch_geometric.backend
 import torch_geometric.typing
 from torch_geometric import is_compiling
+from torch_geometric.index import index2ptr
 from torch_geometric.nn import inits
 from torch_geometric.typing import pyg_lib
 from torch_geometric.utils import index_sort
-from torch_geometric.utils.sparse import index2ptr
 
 
 def is_uninitialized_parameter(x: Any) -> bool:
@@ -58,7 +58,7 @@ def reset_bias_(bias: Optional[Tensor], in_channels: int,
 
 
 class Linear(torch.nn.Module):
-    r"""Applies a linear tranformation to the incoming data.
+    r"""Applies a linear transformation to the incoming data.
 
     .. math::
         \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}
@@ -192,7 +192,7 @@ class Linear(torch.nn.Module):
 
 
 class HeteroLinear(torch.nn.Module):
-    r"""Applies separate linear tranformations to the incoming data according
+    r"""Applies separate linear transformations to the incoming data according
     to types.
 
     For type :math:`\kappa`, it computes
@@ -222,6 +222,8 @@ class HeteroLinear(torch.nn.Module):
           type vector :math:`(*)`
         - **output:** features :math:`(*, F_{out})`
     """
+    _timing_cache: Dict[int, Tuple[float, float]]
+
     def __init__(
         self,
         in_channels: int,
@@ -245,15 +247,17 @@ class HeteroLinear(torch.nn.Module):
         else:
             self.weight = torch.nn.Parameter(
                 torch.empty(num_types, in_channels, out_channels))
+
         if kwargs.get('bias', True):
             self.bias = Parameter(torch.empty(num_types, out_channels))
         else:
             self.register_parameter('bias', None)
-        self.reset_parameters()
 
         # Timing cache for benchmarking naive vs. segment matmul usage:
         self._timing_cache: Dict[int, Tuple[float, float]] = {}
 
+        self.reset_parameters()
+
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
         reset_weight_(self.weight, self.in_channels,
@@ -361,7 +365,8 @@ class HeteroLinear(torch.nn.Module):
 
 
 class HeteroDictLinear(torch.nn.Module):
-    r"""Applies separate linear tranformations to the incoming data dictionary.
+    r"""Applies separate linear transformations to the incoming data
+    dictionary.
 
     For key :math:`\kappa`, it computes
 
@@ -475,7 +480,7 @@ class HeteroDictLinear(torch.nn.Module):
             lin = self.lins[key]
             if is_uninitialized_parameter(lin.weight):
                 self.lins[key].initialize_parameters(None, x)
-
+                self.lins[key].reset_parameters()
         self._hook.remove()
         self.in_channels = {key: x.size(-1) for key, x in input[0].items()}
         delattr(self, '_hook')
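The HeteroLinear changes reorder initialization so that the timing cache exists before reset_parameters() runs, and the class-level _timing_cache annotation keeps TorchScript happy; the public API is unchanged. A minimal usage sketch (shapes illustrative):

import torch
from torch_geometric.nn import HeteroLinear

lin = HeteroLinear(in_channels=16, out_channels=32, num_types=3)
x = torch.randn(100, 16)
type_vec = torch.randint(0, 3, (100, ))  # one type id per row
out = lin(x, type_vec)                   # shape [100, 32]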
torch_geometric/nn/fx.py
CHANGED
@@ -18,8 +18,8 @@ class Transformer:
     :class:`~torch.nn.Module`.
     :class:`Transformer` works entirely symbolically.
 
-    Methods in the :class:`Transformer` class can be overriden to customize
-    behavior of transformation.
+    Methods in the :class:`Transformer` class can be overridden to customize
+    the behavior of transformation.
 
     .. code-block:: none
 
@@ -283,7 +283,7 @@ def symbolic_trace(
         # TODO We currently only trace top-level modules.
         return not isinstance(module, torch.nn.Sequential)
 
-    # Note: This is a hack around the fact that `
+    # Note: This is a hack around the fact that `Aggregation.__call__`
     # is not patched by the base implementation of `trace`.
     # see https://github.com/pyg-team/pytorch_geometric/pull/5021 for
     # details on the rationale
torch_geometric/nn/model_hub.py
CHANGED
@@ -4,6 +4,8 @@ from typing import Any, Dict, Optional, Union
 
 import torch
 
+from torch_geometric.io import fs
+
 try:
     from huggingface_hub import ModelHubMixin, hf_hub_download
 except ImportError:
@@ -175,7 +177,7 @@ class PyGModelHubMixin(ModelHubMixin):
 
         model = cls(dataset_name, model_name, model_kwargs)
 
-        state_dict = torch.load(model_file, map_location=map_location)
+        state_dict = fs.torch_load(model_file, map_location=map_location)
         model.load_state_dict(state_dict, strict=strict)
         model.eval()
 
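fs is PyG's fsspec-backed filesystem module, so fs.torch_load behaves like torch.load while also accepting remote paths. A minimal sketch (path illustrative, assuming fs.torch_save/fs.torch_load mirror their torch counterparts):

import torch
from torch_geometric.io import fs
from torch_geometric.nn import GCNConv

conv = GCNConv(16, 32)
fs.torch_save(conv.state_dict(), '/tmp/gcn.pth')
state_dict = fs.torch_load('/tmp/gcn.pth', map_location='cpu')
conv.load_state_dict(state_dict)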
torch_geometric/nn/models/__init__.py
CHANGED
@@ -2,7 +2,7 @@ r"""Model package."""
 
 from .mlp import MLP
 from .basic_gnn import GCN, GraphSAGE, GIN, GAT, PNA, EdgeCNN
-from .jumping_knowledge import JumpingKnowledge
+from .jumping_knowledge import JumpingKnowledge, HeteroJumpingKnowledge
 from .meta import MetaLayer
 from .node2vec import Node2Vec
 from .deep_graph_infomax import DeepGraphInfomax
@@ -28,7 +28,10 @@ from .gnnff import GNNFF
 from .pmlp import PMLP
 from .neural_fingerprint import NeuralFingerprint
 from .visnet import ViSNet
-
+from .g_retriever import GRetriever
+from .git_mol import GITMol
+from .molecule_gpt import MoleculeGPT
+from .glem import GLEM
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                       captum_output_to_dicts)
@@ -42,6 +45,7 @@ __all__ = classes = [
     'PNA',
     'EdgeCNN',
     'JumpingKnowledge',
+    'HeteroJumpingKnowledge',
     'MetaLayer',
     'Node2Vec',
     'DeepGraphInfomax',
@@ -74,4 +78,8 @@ __all__ = classes = [
     'PMLP',
     'NeuralFingerprint',
     'ViSNet',
+    'GRetriever',
+    'GITMol',
+    'MoleculeGPT',
+    'GLEM',
 ]
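The newly exported models can be imported directly from torch_geometric.nn.models; constructor arguments are omitted here since they differ per model:

from torch_geometric.nn.models import (
    GLEM,
    GITMol,
    GRetriever,
    HeteroJumpingKnowledge,
    MoleculeGPT,
)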