pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +26 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +16 -14
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/data.py +13 -8
- torch_geometric/data/database.py +15 -7
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +13 -22
- torch_geometric/data/graph_store.py +0 -4
- torch_geometric/data/hetero_data.py +4 -4
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/storage.py +15 -5
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +11 -1
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +7 -1
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +4 -3
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +2 -2
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +17 -8
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +20 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +2 -3
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +9 -3
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/base.py +0 -1
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +100 -82
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +5 -4
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +3 -4
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +1 -2
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +7 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +203 -77
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +24 -15
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/testing/decorators.py +17 -22
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +4 -4
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +2 -2
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +31 -5
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +37 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +5 -5
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +1 -1
- torch_geometric/utils/smiles.py +66 -28
- torch_geometric/utils/sparse.py +25 -10
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/nn/sequential.py
CHANGED
```diff
@@ -1,25 +1,33 @@
+import copy
+import inspect
 import os.path as osp
 import random
-from typing import ...
+import sys
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import torch
 from torch import Tensor
 
-from torch_geometric.inspector import ...
+from torch_geometric.inspector import Parameter, Signature, eval_type, split
 from torch_geometric.template import module_from_template
 
 
 class Child(NamedTuple):
     name: str
-    module: Callable
     param_names: List[str]
     return_names: List[str]
 
 
-def Sequential(
-    input_args: str,
-    modules: List[Union[Tuple[Callable, str], Callable]],
-) -> torch.nn.Module:
+class Sequential(torch.nn.Module):
     r"""An extension of the :class:`torch.nn.Sequential` container in order to
     define a sequential GNN model.
 
@@ -74,74 +82,192 @@ def Sequential(
     :obj:`OrderedDict` of modules (and function header definitions) can
     be passed.
     """
- ... (old lines 77-102 elided in the source diff)
-        else:
-            raise ValueError(f"Expected tuple of length 2 (got {module})")
-
-        if i == 0 and desc is None:
-            raise ValueError("Requires signature for first module")
-        if not callable(module):
-            raise ValueError(f"Expected callable module (got {module})")
-        if desc is not None and not isinstance(desc, str):
-            raise ValueError(f"Expected type hint representation (got {desc})")
-
-        if desc is not None:
-            signature = desc.split('->')
-            if len(signature) != 2:
-                raise ValueError(f"Failed to parse arguments (got '{desc}')")
-            param_names = [v.strip() for v in signature[0].split(',')]
-            return_names = [v.strip() for v in signature[1].split(',')]
-            child = Child(name, module, param_names, return_names)
+    _children: List[Child]
+
+    def __init__(
+        self,
+        input_args: str,
+        modules: List[Union[Tuple[Callable, str], Callable]],
+    ) -> None:
+        super().__init__()
+
+        caller_path = inspect.stack()[1].filename
+        self._caller_module = osp.splitext(osp.basename(caller_path))[0]
+
+        _globals = copy.copy(globals())
+        _globals.update(sys.modules['__main__'].__dict__)
+        if self._caller_module in sys.modules:
+            _globals.update(sys.modules[self._caller_module].__dict__)
+
+        signature = input_args.split('->')
+        if len(signature) == 1:
+            args_repr = signature[0]
+            return_type_repr = 'Tensor'
+            return_type = Tensor
+        elif len(signature) == 2:
+            args_repr = signature[0]
+            return_type_repr = signature[1].strip()
+            return_type = eval_type(return_type_repr, _globals)
         else:
- ... (old lines 121-147 elided in the source diff)
+            raise ValueError(f"Failed to parse arguments (got '{input_args}')")
+
+        param_dict: Dict[str, Parameter] = {}
+        for arg in split(args_repr, sep=','):
+            signature = arg.split(':')
+            if len(signature) == 1:
+                name = signature[0].strip()
+                param_dict[name] = Parameter(
+                    name=name,
+                    type=Tensor,
+                    type_repr='Tensor',
+                    default=inspect._empty,
+                )
+            elif len(signature) == 2:
+                name = signature[0].strip()
+                param_dict[name] = Parameter(
+                    name=name,
+                    type=eval_type(signature[1].strip(), _globals),
+                    type_repr=signature[1].strip(),
+                    default=inspect._empty,
+                )
+            else:
+                raise ValueError(f"Failed to parse argument "
+                                 f"(got '{arg.strip()}')")
+
+        self.signature = Signature(param_dict, return_type, return_type_repr)
+
+        if not isinstance(modules, dict):
+            modules = {
+                f'module_{i}': module
+                for i, module in enumerate(modules)
+            }
+        if len(modules) == 0:
+            raise ValueError(f"'{self.__class__.__name__}' expects a "
+                             f"non-empty list of modules")
+
+        self._children: List[Child] = []
+        for i, (name, module) in enumerate(modules.items()):
+            desc: Optional[str] = None
+            if isinstance(module, (tuple, list)):
+                if len(module) == 1:
+                    module = module[0]
+                elif len(module) == 2:
+                    module, desc = module
+                else:
+                    raise ValueError(f"Expected tuple of length 2 "
+                                     f"(got {module})")
+
+            if i == 0 and desc is None:
+                raise ValueError("Signature for first module required")
+            if not callable(module):
+                raise ValueError(f"Expected callable module (got {module})")
+            if desc is not None and not isinstance(desc, str):
+                raise ValueError(f"Expected type hint representation "
+                                 f"(got {desc})")
+
+            if desc is not None:
+                signature = desc.split('->')
+                if len(signature) != 2:
+                    raise ValueError(
+                        f"Failed to parse arguments (got '{desc}')")
+                param_names = [v.strip() for v in signature[0].split(',')]
+                return_names = [v.strip() for v in signature[1].split(',')]
+                child = Child(name, param_names, return_names)
+            else:
+                param_names = self._children[-1].return_names
+                child = Child(name, param_names, param_names)
+
+            setattr(self, name, module)
+            self._children.append(child)
+
+        self._set_jittable_template()
+
+    def reset_parameters(self) -> None:
+        r"""Resets all learnable parameters of the module."""
+        for child in self._children:
+            module = getattr(self, child.name)
+            if hasattr(module, 'reset_parameters'):
+                module.reset_parameters()
+
+    def __len__(self) -> int:
+        return len(self._children)
+
+    def __getitem__(self, idx: int) -> torch.nn.Module:
+        return getattr(self, self._children[idx].name)
+
+    def __setstate__(self, data: Dict[str, Any]) -> None:
+        super().__setstate__(data)
+        self._set_jittable_template()
+
+    def __repr__(self) -> str:
+        module_descs = [
+            f"{', '.join(c.param_names)} -> {', '.join(c.return_names)}"
+            for c in self._children
+        ]
+        module_reprs = [
+            f' ({i}) - {self[i]}: {module_descs[i]}' for i in range(len(self))
+        ]
+        return '{}(\n{}\n)'.format(
+            self.__class__.__name__,
+            '\n'.join(module_reprs),
+        )
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        """"""  # noqa: D419
+        value_dict = {
+            name: arg
+            for name, arg in zip(self.signature.param_dict.keys(), args)
+        }
+        for key, arg in kwargs.items():
+            if key in value_dict:
+                raise TypeError(f"'{self.__class__.__name__}' got multiple "
+                                f"values for argument '{key}'")
+            value_dict[key] = arg
+
+        for child in self._children:
+            args = [value_dict[name] for name in child.param_names]
+            outs = getattr(self, child.name)(*args)
+            if len(child.return_names) == 1:
+                value_dict[child.return_names[0]] = outs
+            else:
+                for name, out in zip(child.return_names, outs):
+                    value_dict[name] = out
+
+        return outs
+
+    # TorchScript Support #####################################################
+
+    def _set_jittable_template(self, raise_on_error: bool = False) -> None:
+        try:  # Optimize `forward()` via `*.jinja` templates:
+            if ('forward' in self.__class__.__dict__ and
+                    self.__class__.__dict__['forward'] != Sequential.forward):
+                raise ValueError("Cannot compile custom 'forward' method")
+
+            root_dir = osp.dirname(osp.realpath(__file__))
+            uid = '%06x' % random.randrange(16**6)
+            jinja_prefix = f'{self.__module__}_{self.__class__.__name__}_{uid}'
+            module = module_from_template(
+                module_name=jinja_prefix,
+                template_path=osp.join(root_dir, 'sequential.jinja'),
+                tmp_dirname='sequential',
+                # Keyword arguments:
+                modules=[self._caller_module],
+                signature=self.signature,
+                children=self._children,
+            )
+
+            self.forward = module.forward.__get__(self)
+
+            # NOTE We override `forward` on the class level here in order to
+            # support `torch.jit.trace` - this is generally dangerous to do,
+            # and limits `torch.jit.trace` to a single `Sequential` module:
+            self.__class__.forward = module.forward
+        except Exception as e:  # pragma: no cover
+            if raise_on_error:
+                raise e
+
+    def __prepare_scriptable__(self) -> 'Sequential':
+        # Prevent type sharing when scripting `Sequential` modules:
+        type_store = torch.jit._recursive.concrete_type_store.type_store
+        type_store.pop(self.__class__, None)
+        return self
```
torch_geometric/nn/summary.py
CHANGED
```diff
@@ -141,7 +141,7 @@ def get_shape(inputs: Any) -> str:
 def postprocess(info_list: List[dict]) -> List[dict]:
     for idx, info in enumerate(info_list):
         depth = info['depth']
-        if idx > 0:  # root module (0) is ...
+        if idx > 0:  # root module (0) is excluded
             if depth == 1:
                 prefix = '├─'
             else:
```
torch_geometric/nn/to_hetero_with_bases_transformer.py
CHANGED
```diff
@@ -272,7 +272,6 @@ class ToHeteroWithBasesTransformer(Transformer):
                 args=(value, self.find_by_name('edge_offset_dict')),
                 name=f'{value.name}__split')
 
-            pass
         elif isinstance(value, Node):
             self.graph.inserting_before(node)
             return self.graph.create_node(
@@ -309,6 +308,24 @@ class ToHeteroWithBasesTransformer(Transformer):
 ###############################################################################
 
 
+# We make use of a post-message computation hook to inject the
+# basis re-weighting for each individual edge type.
+# This currently requires us to set `conv.fuse = False`, which leads
+# to a materialization of messages.
+def hook(module, inputs, output):
+    assert isinstance(module._edge_type, Tensor)
+    if module._edge_type.size(0) != output.size(-2):
+        raise ValueError(
+            f"Number of messages ({output.size(0)}) does not match "
+            f"with the number of original edges "
+            f"({module._edge_type.size(0)}). Does your message "
+            f"passing layer create additional self-loops? Try to "
+            f"remove them via 'add_self_loops=False'")
+    weight = module.edge_type_weight.view(-1)[module._edge_type]
+    weight = weight.view([1] * (output.dim() - 2) + [-1, 1])
+    return weight * output
+
+
 class HeteroBasisConv(torch.nn.Module):
     # A wrapper layer that applies the basis-decomposition technique to a
     # heterogeneous graph.
@@ -319,23 +336,6 @@ class HeteroBasisConv(torch.nn.Module):
         self.num_relations = num_relations
         self.num_bases = num_bases
 
-        # We make use of a post-message computation hook to inject the
-        # basis re-weighting for each individual edge type.
-        # This currently requires us to set `conv.fuse = False`, which leads
-        # to a materialization of messages.
-        def hook(module, inputs, output):
-            assert isinstance(module._edge_type, Tensor)
-            if module._edge_type.size(0) != output.size(-2):
-                raise ValueError(
-                    f"Number of messages ({output.size(0)}) does not match "
-                    f"with the number of original edges "
-                    f"({module._edge_type.size(0)}). Does your message "
-                    f"passing layer create additional self-loops? Try to "
-                    f"remove them via 'add_self_loops=False'")
-            weight = module.edge_type_weight.view(-1)[module._edge_type]
-            weight = weight.view([1] * (output.dim() - 2) + [-1, 1])
-            return weight * output
-
         params = list(module.parameters())
         device = params[0].device if len(params) > 0 else 'cpu'
 
@@ -468,7 +468,7 @@ def get_edge_type(
 ###############################################################################
 
 # These methods are used to group the individual type-wise components into a
-# ...
+# unified single representation.
 
 
 def group_node_placeholder(input_dict: Dict[NodeType, Tensor],
```
torch_geometric/profile/__init__.py
CHANGED
```diff
@@ -20,6 +20,7 @@ from .utils import (
     get_gpu_memory_from_nvidia_smi,
     get_model_size,
 )
+from .nvtx import nvtxit
 
 __all__ = [
     'profileit',
@@ -38,6 +39,7 @@ __all__ = [
     'get_gpu_memory_from_nvidia_smi',
     'get_gpu_memory_from_ipex',
     'benchmark',
+    'nvtxit',
 ]
 
 classes = __all__
```
torch_geometric/profile/nvtx.py
ADDED
```diff
@@ -0,0 +1,66 @@
+from functools import wraps
+from typing import Optional
+
+import torch
+
+CUDA_PROFILE_STARTED = False
+
+
+def begin_cuda_profile():
+    global CUDA_PROFILE_STARTED
+    prev_state = CUDA_PROFILE_STARTED
+    if prev_state is False:
+        CUDA_PROFILE_STARTED = True
+        torch.cuda.cudart().cudaProfilerStart()
+    return prev_state
+
+
+def end_cuda_profile(prev_state: bool):
+    global CUDA_PROFILE_STARTED
+    CUDA_PROFILE_STARTED = prev_state
+    if prev_state is False:
+        torch.cuda.cudart().cudaProfilerStop()
+
+
+def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
+           n_iters: Optional[int] = None):
+    """Enables NVTX profiling for a function.
+
+    Args:
+        name (Optional[str], optional): Name to give the reference frame for
+            the function being wrapped. Defaults to the name of the
+            function in code.
+        n_warmups (int, optional): Number of iters to call that function
+            before starting. Defaults to 0.
+        n_iters (Optional[int], optional): Number of iters of that function to
+            record. Defaults to all of them.
+    """
+    def nvtx(func):
+
+        nonlocal name
+        iters_so_far = 0
+        if name is None:
+            name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal iters_so_far
+            if not torch.cuda.is_available():
+                return func(*args, **kwargs)
+            elif iters_so_far < n_warmups:
+                iters_so_far += 1
+                return func(*args, **kwargs)
+            elif n_iters is None or iters_so_far < n_iters + n_warmups:
+                prev_state = begin_cuda_profile()
+                torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
+                result = func(*args, **kwargs)
+                torch.cuda.nvtx.range_pop()
+                end_cuda_profile(prev_state)
+                iters_so_far += 1
+                return result
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return nvtx
```
torch_geometric/profile/profiler.py
CHANGED
```diff
@@ -5,6 +5,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple
 import torch
 import torch.profiler as torch_profiler
 
+import torch_geometric.typing
+
 # predefined namedtuple for variable setting (global template)
 Trace = namedtuple('Trace', ['path', 'leaf', 'module'])
 
@@ -325,6 +327,8 @@ def _flatten_tree(t, depth=0):
 
 
 def _build_measure_tuple(events: List, occurrences: List) -> NamedTuple:
+    device_str = 'device' if torch_geometric.typing.WITH_PT24 else 'cuda'
+
     # memory profiling supported in torch >= 1.6
     self_cpu_memory = None
     has_self_cpu_memory = any(
@@ -339,29 +343,34 @@ def _build_measure_tuple(events: List, occurrences: List) -> NamedTuple:
         [getattr(e, "cpu_memory_usage", 0) or 0 for e in events])
     self_cuda_memory = None
     has_self_cuda_memory = any(
-        hasattr(e, "self_cuda_memory_usage") for e in events)
+        hasattr(e, f"self_{device_str}_memory_usage") for e in events)
     if has_self_cuda_memory:
-        self_cuda_memory = sum(
-            [getattr(e, "self_cuda_memory_usage", 0) or 0 for e in events])
+        self_cuda_memory = sum([
+            getattr(e, f"self_{device_str}_memory_usage", 0) or 0
+            for e in events
+        ])
     cuda_memory = None
-    has_cuda_memory = any(hasattr(e, "cuda_memory_usage") for e in events)
+    has_cuda_memory = any(
+        hasattr(e, f"{device_str}_memory_usage") for e in events)
     if has_cuda_memory:
         cuda_memory = sum(
-            [getattr(e, "cuda_memory_usage", 0) or 0 for e in events])
+            [getattr(e, f"{device_str}_memory_usage", 0) or 0 for e in events])
 
     # self CUDA time supported in torch >= 1.7
     self_cuda_total = None
     has_self_cuda_time = any(
-        hasattr(e, "self_cuda_time_total") for e in events)
+        hasattr(e, f"self_{device_str}_time_total") for e in events)
     if has_self_cuda_time:
-        self_cuda_total = sum(
-            [getattr(e, "self_cuda_time_total", 0) or 0 for e in events])
+        self_cuda_total = sum([
+            getattr(e, f"self_{device_str}_time_total", 0) or 0 for e in events
+        ])
 
     return Measure(
         self_cpu_total=sum([e.self_cpu_time_total or 0 for e in events]),
         cpu_total=sum([e.cpu_time_total or 0 for e in events]),
         self_cuda_total=self_cuda_total,
-        cuda_total=sum([e.cuda_time_total or 0 for e in events]),
+        cuda_total=sum(
+            [getattr(e, f"{device_str}_time_total") or 0 for e in events]),
         self_cpu_memory=self_cpu_memory,
         cpu_memory=cpu_memory,
         self_cuda_memory=self_cuda_memory,
@@ -436,10 +445,10 @@ def format_time(time_us: int) -> str:
     US_IN_SECOND = 1000.0 * 1000.0
     US_IN_MS = 1000.0
     if time_us >= US_IN_SECOND:
-        return '{:.3f}s'.format(time_us / US_IN_SECOND)
+        return f'{time_us / US_IN_SECOND:.3f}s'
     if time_us >= US_IN_MS:
-        return '{:.3f}ms'.format(time_us / US_IN_MS)
-    return '{:.3f}us'.format(time_us)
+        return f'{time_us / US_IN_MS:.3f}ms'
+    return f'{time_us:.3f}us'
@@ -448,10 +457,10 @@ def format_memory(nbytes: int) -> str:
     MB = 1024 * KB
     GB = 1024 * MB
     if (abs(nbytes) >= GB):
-        return '{:.2f} Gb'.format(nbytes * 1.0 / GB)
+        return f'{nbytes * 1.0 / GB:.2f} Gb'
     elif (abs(nbytes) >= MB):
-        return '{:.2f} Mb'.format(nbytes * 1.0 / MB)
+        return f'{nbytes * 1.0 / MB:.2f} Mb'
     elif (abs(nbytes) >= KB):
-        return '{:.2f} Kb'.format(nbytes * 1.0 / KB)
+        return f'{nbytes * 1.0 / KB:.2f} Kb'
     else:
         return str(nbytes) + ' b'
```
torch_geometric/resolver.py
CHANGED
```diff
@@ -39,5 +39,5 @@ def resolver(
             return obj
         return cls
 
-    choices = ...
+    choices = {cls.__name__ for cls in classes} | set(class_dict.keys())
     raise ValueError(f"Could not resolve '{query}' among choices {choices}")
```
torch_geometric/sampler/base.py
CHANGED
```diff
@@ -5,7 +5,7 @@ from abc import ABC
 from collections import defaultdict
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
 from torch import Tensor
@@ -425,6 +425,14 @@ class NumNeighbors:
         else:
             assert False
 
+        # Confirm that `values` only hold valid edge types:
+        if isinstance(self.values, dict):
+            edge_types_str = {EdgeTypeStr(key) for key in edge_types}
+            invalid_edge_types = set(self.values.keys()) - edge_types_str
+            if len(invalid_edge_types) > 0:
+                raise ValueError("Not all edge types specified in "
+                                 "'num_neighbors' exist in the graph")
+
         out = {}
         for edge_type in edge_types:
             edge_type_str = EdgeTypeStr(edge_type)
@@ -444,7 +452,7 @@ class NumNeighbors:
         out = copy.copy(self.values)
 
         if isinstance(out, dict):
-            num_hops = ...
+            num_hops = {len(v) for v in out.values()}
            if len(num_hops) > 1:
                raise ValueError(f"Number of hops must be the same across all "
                                 f"edge types (got {len(num_hops)} different "
@@ -533,24 +541,31 @@ class NegativeSampling(CastMixin):
             destination nodes for each positive source node.
         amount (int or float, optional): The ratio of sampled negative edges to
             the number of positive edges. (default: :obj:`1`)
-        weight (torch.Tensor, optional): A node-level vector determining the
-            sampling of nodes. Does not necessarily need to sum up to one.
-            If not given, negative nodes will be sampled uniformly.
+        src_weight (torch.Tensor, optional): A node-level vector determining
+            the sampling of source nodes. Does not necessarily need to sum up
+            to one. If not given, negative nodes will be sampled uniformly.
+            (default: :obj:`None`)
+        dst_weight (torch.Tensor, optional): A node-level vector determining
+            the sampling of destination nodes. Does not necessarily need to sum
+            up to one. If not given, negative nodes will be sampled uniformly.
             (default: :obj:`None`)
     """
     mode: NegativeSamplingMode
     amount: Union[int, float] = 1
-    weight: Optional[Tensor] = None
+    src_weight: Optional[Tensor] = None
+    dst_weight: Optional[Tensor] = None
 
     def __init__(
         self,
         mode: Union[NegativeSamplingMode, str],
         amount: Union[int, float] = 1,
-        weight: Optional[Tensor] = None,
+        src_weight: Optional[Tensor] = None,
+        dst_weight: Optional[Tensor] = None,
    ):
        self.mode = NegativeSamplingMode(mode)
        self.amount = amount
-        self.weight = weight
+        self.src_weight = src_weight
+        self.dst_weight = dst_weight
 
        if self.amount <= 0:
            raise ValueError(f"The attribute 'amount' needs to be positive "
@@ -571,22 +586,28 @@ class NegativeSampling(CastMixin):
     def is_triplet(self) -> bool:
         return self.mode == NegativeSamplingMode.triplet
 
-    def sample(self, num_samples: int,
-               num_nodes: Optional[int] = None) -> Tensor:
+    def sample(
+        self,
+        num_samples: int,
+        endpoint: Literal['src', 'dst'],
+        num_nodes: Optional[int] = None,
+    ) -> Tensor:
         r"""Generates :obj:`num_samples` negative samples."""
-        if self.weight is None:
+        weight = self.src_weight if endpoint == 'src' else self.dst_weight
+
+        if weight is None:
             if num_nodes is None:
                 raise ValueError(
                     f"Cannot sample negatives in '{self.__class__.__name__}' "
                     f"without passing the 'num_nodes' argument")
             return torch.randint(num_nodes, (num_samples, ))
 
-        if num_nodes is not None and self.weight.numel() != num_nodes:
+        if num_nodes is not None and weight.numel() != num_nodes:
             raise ValueError(
                 f"The 'weight' attribute in '{self.__class__.__name__}' "
                 f"needs to match the number of nodes {num_nodes} "
                 f"(got {self.weight.numel()})")
-        return torch.multinomial(self.weight, num_samples, replacement=True)
+        return torch.multinomial(weight, num_samples, replacement=True)
 
 
 class BaseSampler(ABC):
```
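`NegativeSampling` thus replaces its single `weight` vector with separate `src_weight` and `dst_weight` vectors, and `sample()` now picks between them via a `Literal['src', 'dst']` argument. A minimal sketch of the new API; the node counts and weight vectors below are made-up placeholders (any non-negative per-node weights work, and they need not sum to one):

```python
# Sketch: endpoint-specific, weighted negative sampling.
import torch

from torch_geometric.sampler import NegativeSampling

num_src, num_dst = 100, 50
src_weight = torch.rand(num_src)  # e.g. proportional to source-node degree
dst_weight = torch.rand(num_dst)

neg = NegativeSampling(
    mode='binary',          # or 'triplet'
    amount=2,               # two negatives per positive edge
    src_weight=src_weight,
    dst_weight=dst_weight,
)

neg_src = neg.sample(10, endpoint='src', num_nodes=num_src)  # 10 src indices
neg_dst = neg.sample(10, endpoint='dst', num_nodes=num_dst)  # 10 dst indices
```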
|