pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -30,11 +30,13 @@ class LayerNorm(torch.nn.Module):
|
|
|
30
30
|
affine (bool, optional): If set to :obj:`True`, this module has
|
|
31
31
|
learnable affine parameters :math:`\gamma` and :math:`\beta`.
|
|
32
32
|
(default: :obj:`True`)
|
|
33
|
-
mode (str,
|
|
33
|
+
mode (str, optional): The normalization mode to use for layer
|
|
34
34
|
normalization (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"`
|
|
35
35
|
is used, each graph will be considered as an element to be
|
|
36
36
|
normalized. If `"node"` is used, each node will be considered as
|
|
37
37
|
an element to be normalized. (default: :obj:`"graph"`)
|
|
38
|
+
device (torch.device, optional): The device to use for the module.
|
|
39
|
+
(default: :obj:`None`)
|
|
38
40
|
"""
|
|
39
41
|
def __init__(
|
|
40
42
|
self,
|
|
@@ -42,6 +44,7 @@ class LayerNorm(torch.nn.Module):
|
|
|
42
44
|
eps: float = 1e-5,
|
|
43
45
|
affine: bool = True,
|
|
44
46
|
mode: str = 'graph',
|
|
47
|
+
device: Optional[torch.device] = None,
|
|
45
48
|
):
|
|
46
49
|
super().__init__()
|
|
47
50
|
|
|
@@ -51,8 +54,8 @@ class LayerNorm(torch.nn.Module):
|
|
|
51
54
|
self.mode = mode
|
|
52
55
|
|
|
53
56
|
if affine:
|
|
54
|
-
self.weight = Parameter(torch.empty(in_channels))
|
|
55
|
-
self.bias = Parameter(torch.empty(in_channels))
|
|
57
|
+
self.weight = Parameter(torch.empty(in_channels, device=device))
|
|
58
|
+
self.bias = Parameter(torch.empty(in_channels, device=device))
|
|
56
59
|
else:
|
|
57
60
|
self.register_parameter('weight', None)
|
|
58
61
|
self.register_parameter('bias', None)
|
|
@@ -108,7 +111,7 @@ class LayerNorm(torch.nn.Module):
|
|
|
108
111
|
return F.layer_norm(x, (self.in_channels, ), self.weight,
|
|
109
112
|
self.bias, self.eps)
|
|
110
113
|
|
|
111
|
-
raise ValueError(f"
|
|
114
|
+
raise ValueError(f"Unknownn normalization mode: {self.mode}")
|
|
112
115
|
|
|
113
116
|
def __repr__(self):
|
|
114
117
|
return (f'{self.__class__.__name__}({self.in_channels}, '
|
|
@@ -130,10 +133,12 @@ class HeteroLayerNorm(torch.nn.Module):
|
|
|
130
133
|
affine (bool, optional): If set to :obj:`True`, this module has
|
|
131
134
|
learnable affine parameters :math:`\gamma` and :math:`\beta`.
|
|
132
135
|
(default: :obj:`True`)
|
|
133
|
-
mode (str,
|
|
136
|
+
mode (str, optional): The normalization mode to use for layer
|
|
134
137
|
normalization (:obj:`"node"`). If `"node"` is used, each node will
|
|
135
138
|
be considered as an element to be normalized.
|
|
136
139
|
(default: :obj:`"node"`)
|
|
140
|
+
device (torch.device, optional): The device to use for the module.
|
|
141
|
+
(default: :obj:`None`)
|
|
137
142
|
"""
|
|
138
143
|
def __init__(
|
|
139
144
|
self,
|
|
@@ -142,6 +147,7 @@ class HeteroLayerNorm(torch.nn.Module):
|
|
|
142
147
|
eps: float = 1e-5,
|
|
143
148
|
affine: bool = True,
|
|
144
149
|
mode: str = 'node',
|
|
150
|
+
device: Optional[torch.device] = None,
|
|
145
151
|
):
|
|
146
152
|
super().__init__()
|
|
147
153
|
assert mode == 'node'
|
|
@@ -152,8 +158,10 @@ class HeteroLayerNorm(torch.nn.Module):
|
|
|
152
158
|
self.affine = affine
|
|
153
159
|
|
|
154
160
|
if affine:
|
|
155
|
-
self.weight = Parameter(
|
|
156
|
-
|
|
161
|
+
self.weight = Parameter(
|
|
162
|
+
torch.empty(num_types, in_channels, device=device))
|
|
163
|
+
self.bias = Parameter(
|
|
164
|
+
torch.empty(num_types, in_channels, device=device))
|
|
157
165
|
else:
|
|
158
166
|
self.register_parameter('weight', None)
|
|
159
167
|
self.register_parameter('bias', None)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
import torch.nn.functional as F
|
|
3
5
|
from torch import Tensor
|
|
@@ -19,10 +21,14 @@ class MessageNorm(torch.nn.Module):
|
|
|
19
21
|
learn_scale (bool, optional): If set to :obj:`True`, will learn the
|
|
20
22
|
scaling factor :math:`s` of message normalization.
|
|
21
23
|
(default: :obj:`False`)
|
|
24
|
+
device (torch.device, optional): The device to use for the module.
|
|
25
|
+
(default: :obj:`None`)
|
|
22
26
|
"""
|
|
23
|
-
def __init__(self, learn_scale: bool = False
|
|
27
|
+
def __init__(self, learn_scale: bool = False,
|
|
28
|
+
device: Optional[torch.device] = None):
|
|
24
29
|
super().__init__()
|
|
25
|
-
self.scale = Parameter(torch.empty(1
|
|
30
|
+
self.scale = Parameter(torch.empty(1, device=device),
|
|
31
|
+
requires_grad=learn_scale)
|
|
26
32
|
self.reset_parameters()
|
|
27
33
|
|
|
28
34
|
def reset_parameters(self):
|
|
@@ -7,18 +7,19 @@ from torch import Tensor
|
|
|
7
7
|
import torch_geometric.typing
|
|
8
8
|
from torch_geometric.typing import OptTensor, torch_cluster
|
|
9
9
|
|
|
10
|
-
from .asap import ASAPooling
|
|
11
10
|
from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
|
|
12
|
-
from .edge_pool import EdgePooling
|
|
13
11
|
from .glob import global_add_pool, global_max_pool, global_mean_pool
|
|
14
12
|
from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
|
|
15
13
|
ApproxMIPSKNNIndex)
|
|
16
14
|
from .graclus import graclus
|
|
17
15
|
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
|
|
18
|
-
from .mem_pool import MemPooling
|
|
19
|
-
from .pan_pool import PANPooling
|
|
20
|
-
from .sag_pool import SAGPooling
|
|
21
16
|
from .topk_pool import TopKPooling
|
|
17
|
+
from .sag_pool import SAGPooling
|
|
18
|
+
from .edge_pool import EdgePooling
|
|
19
|
+
from .cluster_pool import ClusterPooling
|
|
20
|
+
from .asap import ASAPooling
|
|
21
|
+
from .pan_pool import PANPooling
|
|
22
|
+
from .mem_pool import MemPooling
|
|
22
23
|
from .voxel_grid import voxel_grid
|
|
23
24
|
from .approx_knn import approx_knn, approx_knn_graph
|
|
24
25
|
|
|
@@ -162,8 +163,10 @@ def knn_graph(
|
|
|
162
163
|
:rtype: :class:`torch.Tensor`
|
|
163
164
|
"""
|
|
164
165
|
if batch is not None and x.device != batch.device:
|
|
165
|
-
warnings.warn(
|
|
166
|
-
|
|
166
|
+
warnings.warn(
|
|
167
|
+
"Input tensor 'x' and 'batch' are on different devices "
|
|
168
|
+
"in 'knn_graph'. Performing blocking device transfer",
|
|
169
|
+
stacklevel=2)
|
|
167
170
|
batch = batch.to(x.device)
|
|
168
171
|
|
|
169
172
|
if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
|
|
@@ -284,8 +287,10 @@ def radius_graph(
|
|
|
284
287
|
inputs to GPU before proceeding.
|
|
285
288
|
"""
|
|
286
289
|
if batch is not None and x.device != batch.device:
|
|
287
|
-
warnings.warn(
|
|
288
|
-
|
|
290
|
+
warnings.warn(
|
|
291
|
+
"Input tensor 'x' and 'batch' are on different devices "
|
|
292
|
+
"in 'radius_graph'. Performing blocking device transfer",
|
|
293
|
+
stacklevel=2)
|
|
289
294
|
batch = batch.to(x.device)
|
|
290
295
|
|
|
291
296
|
if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:
|
|
@@ -344,6 +349,7 @@ __all__ = [
|
|
|
344
349
|
'TopKPooling',
|
|
345
350
|
'SAGPooling',
|
|
346
351
|
'EdgePooling',
|
|
352
|
+
'ClusterPooling',
|
|
347
353
|
'ASAPooling',
|
|
348
354
|
'PANPooling',
|
|
349
355
|
'MemPooling',
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from typing import NamedTuple, Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from torch_geometric.utils import (
|
|
8
|
+
dense_to_sparse,
|
|
9
|
+
one_hot,
|
|
10
|
+
to_dense_adj,
|
|
11
|
+
to_scipy_sparse_matrix,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class UnpoolInfo(NamedTuple):
|
|
16
|
+
edge_index: Tensor
|
|
17
|
+
cluster: Tensor
|
|
18
|
+
batch: Tensor
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ClusterPooling(torch.nn.Module):
|
|
22
|
+
r"""The cluster pooling operator from the `"Edge-Based Graph Component
|
|
23
|
+
Pooling" <https://arxiv.org/abs/2409.11856>`_ paper.
|
|
24
|
+
:class:`ClusterPooling` computes a score for each edge.
|
|
25
|
+
Based on the selected edges, graph clusters are calculated and compressed
|
|
26
|
+
to one node using the injective :obj:`"sum"` aggregation function.
|
|
27
|
+
Edges are remapped based on the nodes created by each cluster and the
|
|
28
|
+
original edges.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
in_channels (int): Size of each input sample.
|
|
32
|
+
edge_score_method (str, optional): The function to apply
|
|
33
|
+
to compute the edge score from raw edge scores (:obj:`"tanh"`,
|
|
34
|
+
:obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
|
|
35
|
+
dropout (float, optional): The probability with
|
|
36
|
+
which to drop edge scores during training. (default: :obj:`0.0`)
|
|
37
|
+
threshold (float, optional): The threshold of edge scores. If set to
|
|
38
|
+
:obj:`None`, will be automatically inferred depending on
|
|
39
|
+
:obj:`edge_score_method`. (default: :obj:`None`)
|
|
40
|
+
"""
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
in_channels: int,
|
|
44
|
+
edge_score_method: str = 'tanh',
|
|
45
|
+
dropout: float = 0.0,
|
|
46
|
+
threshold: Optional[float] = None,
|
|
47
|
+
):
|
|
48
|
+
super().__init__()
|
|
49
|
+
assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']
|
|
50
|
+
|
|
51
|
+
if threshold is None:
|
|
52
|
+
threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0
|
|
53
|
+
|
|
54
|
+
self.in_channels = in_channels
|
|
55
|
+
self.edge_score_method = edge_score_method
|
|
56
|
+
self.dropout = dropout
|
|
57
|
+
self.threshold = threshold
|
|
58
|
+
|
|
59
|
+
self.lin = torch.nn.Linear(2 * in_channels, 1)
|
|
60
|
+
|
|
61
|
+
def reset_parameters(self):
|
|
62
|
+
r"""Resets all learnable parameters of the module."""
|
|
63
|
+
self.lin.reset_parameters()
|
|
64
|
+
|
|
65
|
+
def forward(
|
|
66
|
+
self,
|
|
67
|
+
x: Tensor,
|
|
68
|
+
edge_index: Tensor,
|
|
69
|
+
batch: Tensor,
|
|
70
|
+
) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
|
|
71
|
+
r"""Forward pass.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
x (torch.Tensor): The node features.
|
|
75
|
+
edge_index (torch.Tensor): The edge indices.
|
|
76
|
+
batch (torch.Tensor): Batch vector
|
|
77
|
+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
|
|
78
|
+
each node to a specific example.
|
|
79
|
+
|
|
80
|
+
Return types:
|
|
81
|
+
* **x** *(torch.Tensor)* - The pooled node features.
|
|
82
|
+
* **edge_index** *(torch.Tensor)* - The coarsened edge indices.
|
|
83
|
+
* **batch** *(torch.Tensor)* - The coarsened batch vector.
|
|
84
|
+
* **unpool_info** *(UnpoolInfo)* - Information that can be consumed
|
|
85
|
+
for unpooling.
|
|
86
|
+
"""
|
|
87
|
+
mask = edge_index[0] != edge_index[1]
|
|
88
|
+
edge_index = edge_index[:, mask]
|
|
89
|
+
|
|
90
|
+
edge_attr = torch.cat(
|
|
91
|
+
[x[edge_index[0]], x[edge_index[1]]],
|
|
92
|
+
dim=-1,
|
|
93
|
+
)
|
|
94
|
+
edge_score = self.lin(edge_attr).view(-1)
|
|
95
|
+
edge_score = F.dropout(edge_score, p=self.dropout,
|
|
96
|
+
training=self.training)
|
|
97
|
+
|
|
98
|
+
if self.edge_score_method == 'tanh':
|
|
99
|
+
edge_score = edge_score.tanh()
|
|
100
|
+
elif self.edge_score_method == 'sigmoid':
|
|
101
|
+
edge_score = edge_score.sigmoid()
|
|
102
|
+
else:
|
|
103
|
+
assert self.edge_score_method == 'log_softmax'
|
|
104
|
+
edge_score = F.log_softmax(edge_score, dim=0)
|
|
105
|
+
|
|
106
|
+
return self._merge_edges(x, edge_index, batch, edge_score)
|
|
107
|
+
|
|
108
|
+
def _merge_edges(
|
|
109
|
+
self,
|
|
110
|
+
x: Tensor,
|
|
111
|
+
edge_index: Tensor,
|
|
112
|
+
batch: Tensor,
|
|
113
|
+
edge_score: Tensor,
|
|
114
|
+
) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
|
|
115
|
+
|
|
116
|
+
from scipy.sparse.csgraph import connected_components
|
|
117
|
+
|
|
118
|
+
edge_contract = edge_index[:, edge_score > self.threshold]
|
|
119
|
+
|
|
120
|
+
adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
|
|
121
|
+
_, cluster_np = connected_components(adj, directed=True,
|
|
122
|
+
connection="weak")
|
|
123
|
+
|
|
124
|
+
cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
|
|
125
|
+
C = one_hot(cluster)
|
|
126
|
+
A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
|
|
127
|
+
S = to_dense_adj(edge_index, edge_attr=edge_score,
|
|
128
|
+
max_num_nodes=x.size(0)).squeeze(0)
|
|
129
|
+
|
|
130
|
+
A_contract = to_dense_adj(edge_contract,
|
|
131
|
+
max_num_nodes=x.size(0)).squeeze(0)
|
|
132
|
+
nodes_single = ((A_contract.sum(dim=-1) +
|
|
133
|
+
A_contract.sum(dim=-2)) == 0).nonzero()
|
|
134
|
+
S[nodes_single, nodes_single] = 1.0
|
|
135
|
+
|
|
136
|
+
x_out = (S @ C).t() @ x
|
|
137
|
+
edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
|
|
138
|
+
batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
|
|
139
|
+
unpool_info = UnpoolInfo(edge_index, cluster, batch)
|
|
140
|
+
|
|
141
|
+
return x_out, edge_index_out, batch_out, unpool_info
|
|
142
|
+
|
|
143
|
+
def __repr__(self) -> str:
|
|
144
|
+
return f'{self.__class__.__name__}({self.in_channels})'
|
|
@@ -4,7 +4,6 @@ from typing import Optional
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
|
|
7
|
-
import torch_geometric.typing
|
|
8
7
|
from torch_geometric.nn.pool.select import SelectOutput
|
|
9
8
|
|
|
10
9
|
|
|
@@ -49,8 +48,7 @@ class ConnectOutput:
|
|
|
49
48
|
self.batch = batch
|
|
50
49
|
|
|
51
50
|
|
|
52
|
-
|
|
53
|
-
ConnectOutput = torch.jit.script(ConnectOutput)
|
|
51
|
+
ConnectOutput = torch.jit.script(ConnectOutput)
|
|
54
52
|
|
|
55
53
|
|
|
56
54
|
class Connect(torch.nn.Module):
|
torch_geometric/nn/pool/knn.py
CHANGED
|
@@ -91,9 +91,10 @@ class KNNIndex:
|
|
|
91
91
|
if hasattr(self.index, 'reserveMemory'):
|
|
92
92
|
self.index.reserveMemory(self.reserve)
|
|
93
93
|
else:
|
|
94
|
-
warnings.warn(
|
|
95
|
-
|
|
96
|
-
|
|
94
|
+
warnings.warn(
|
|
95
|
+
f"'{self.index.__class__.__name__}' "
|
|
96
|
+
f"does not support pre-allocation of "
|
|
97
|
+
f"memory", stacklevel=2)
|
|
97
98
|
|
|
98
99
|
self.index.train(emb)
|
|
99
100
|
|
|
@@ -135,14 +136,16 @@ class KNNIndex:
|
|
|
135
136
|
query_k = min(query_k, self.numel)
|
|
136
137
|
|
|
137
138
|
if k > 2048: # `faiss` supports up-to `k=2048`:
|
|
138
|
-
warnings.warn(
|
|
139
|
-
|
|
140
|
-
|
|
139
|
+
warnings.warn(
|
|
140
|
+
f"Capping 'k' to faiss' upper limit of 2048 "
|
|
141
|
+
f"(got {k}). This may cause some relevant items to "
|
|
142
|
+
f"not be retrieved.", stacklevel=2)
|
|
141
143
|
elif query_k > 2048:
|
|
142
|
-
warnings.warn(
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
144
|
+
warnings.warn(
|
|
145
|
+
f"Capping 'k' to faiss' upper limit of 2048 "
|
|
146
|
+
f"(got {k} which got extended to {query_k} due to "
|
|
147
|
+
f"the exclusion of existing links). This may cause "
|
|
148
|
+
f"some relevant items to not be retrieved.", stacklevel=2)
|
|
146
149
|
query_k = 2048
|
|
147
150
|
|
|
148
151
|
score, index = self.index.search(emb.detach(), query_k)
|
|
@@ -4,8 +4,6 @@ from typing import Optional
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
|
|
7
|
-
import torch_geometric.typing
|
|
8
|
-
|
|
9
7
|
|
|
10
8
|
@dataclass(init=False)
|
|
11
9
|
class SelectOutput:
|
|
@@ -64,8 +62,7 @@ class SelectOutput:
|
|
|
64
62
|
self.weight = weight
|
|
65
63
|
|
|
66
64
|
|
|
67
|
-
|
|
68
|
-
SelectOutput = torch.jit.script(SelectOutput)
|
|
65
|
+
SelectOutput = torch.jit.script(SelectOutput)
|
|
69
66
|
|
|
70
67
|
|
|
71
68
|
class Select(torch.nn.Module):
|
torch_geometric/nn/summary.py
CHANGED
|
@@ -141,7 +141,7 @@ def get_shape(inputs: Any) -> str:
|
|
|
141
141
|
def postprocess(info_list: List[dict]) -> List[dict]:
|
|
142
142
|
for idx, info in enumerate(info_list):
|
|
143
143
|
depth = info['depth']
|
|
144
|
-
if idx > 0: # root module (0) is
|
|
144
|
+
if idx > 0: # root module (0) is excluded
|
|
145
145
|
if depth == 1:
|
|
146
146
|
prefix = '├─'
|
|
147
147
|
else:
|
|
@@ -108,9 +108,10 @@ class ToHeteroMessagePassing(torch.nn.Module):
|
|
|
108
108
|
|
|
109
109
|
if (not hasattr(module, 'reset_parameters')
|
|
110
110
|
and sum([p.numel() for p in module.parameters()]) > 0):
|
|
111
|
-
warnings.warn(
|
|
112
|
-
|
|
113
|
-
|
|
111
|
+
warnings.warn(
|
|
112
|
+
f"'{module}' will be duplicated, but its parameters "
|
|
113
|
+
f"cannot be reset. To suppress this warning, add a "
|
|
114
|
+
f"'reset_parameters()' method to '{module}'", stacklevel=2)
|
|
114
115
|
|
|
115
116
|
convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
|
|
116
117
|
self.hetero_module = HeteroConv(convs, aggr)
|
|
@@ -157,7 +157,7 @@ class ToHeteroTransformer(Transformer):
|
|
|
157
157
|
f"There exist node types ({unused_node_types}) whose "
|
|
158
158
|
f"representations do not get updated during message passing "
|
|
159
159
|
f"as they do not occur as destination type in any edge type. "
|
|
160
|
-
f"This may lead to unexpected behavior.")
|
|
160
|
+
f"This may lead to unexpected behavior.", stacklevel=2)
|
|
161
161
|
|
|
162
162
|
names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
|
|
163
163
|
for name in names:
|
|
@@ -166,7 +166,7 @@ class ToHeteroTransformer(Transformer):
|
|
|
166
166
|
f"The type '{name}' contains invalid characters which "
|
|
167
167
|
f"may lead to unexpected behavior. To avoid any issues, "
|
|
168
168
|
f"ensure that your types only contain letters, numbers "
|
|
169
|
-
f"and underscores.")
|
|
169
|
+
f"and underscores.", stacklevel=2)
|
|
170
170
|
|
|
171
171
|
def placeholder(self, node: Node, target: Any, name: str):
|
|
172
172
|
# Adds a `get` call to the input dictionary for every node-type or
|
|
@@ -379,7 +379,7 @@ class ToHeteroTransformer(Transformer):
|
|
|
379
379
|
warnings.warn(
|
|
380
380
|
f"'{target}' will be duplicated, but its parameters "
|
|
381
381
|
f"cannot be reset. To suppress this warning, add a "
|
|
382
|
-
f"'reset_parameters()' method to '{target}'")
|
|
382
|
+
f"'reset_parameters()' method to '{target}'", stacklevel=2)
|
|
383
383
|
|
|
384
384
|
return module_dict
|
|
385
385
|
|
|
@@ -165,7 +165,7 @@ class ToHeteroWithBasesTransformer(Transformer):
|
|
|
165
165
|
f"There exist node types ({unused_node_types}) whose "
|
|
166
166
|
f"representations do not get updated during message passing "
|
|
167
167
|
f"as they do not occur as destination type in any edge type. "
|
|
168
|
-
f"This may lead to unexpected behavior.")
|
|
168
|
+
f"This may lead to unexpected behavior.", stacklevel=2)
|
|
169
169
|
|
|
170
170
|
names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
|
|
171
171
|
for name in names:
|
|
@@ -174,7 +174,7 @@ class ToHeteroWithBasesTransformer(Transformer):
|
|
|
174
174
|
f"The type '{name}' contains invalid characters which "
|
|
175
175
|
f"may lead to unexpected behavior. To avoid any issues, "
|
|
176
176
|
f"ensure that your types only contain letters, numbers "
|
|
177
|
-
f"and underscores.")
|
|
177
|
+
f"and underscores.", stacklevel=2)
|
|
178
178
|
|
|
179
179
|
def transform(self) -> GraphModule:
|
|
180
180
|
self._node_offset_dict_initialized = False
|
|
@@ -361,7 +361,7 @@ class HeteroBasisConv(torch.nn.Module):
|
|
|
361
361
|
warnings.warn(
|
|
362
362
|
f"'{conv}' will be duplicated, but its parameters cannot "
|
|
363
363
|
f"be reset. To suppress this warning, add a "
|
|
364
|
-
f"'reset_parameters()' method to '{conv}'")
|
|
364
|
+
f"'reset_parameters()' method to '{conv}'", stacklevel=2)
|
|
365
365
|
torch.nn.init.xavier_uniform_(conv.edge_type_weight)
|
|
366
366
|
|
|
367
367
|
def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:
|
|
@@ -380,7 +380,7 @@ class HeteroBasisConv(torch.nn.Module):
|
|
|
380
380
|
|
|
381
381
|
|
|
382
382
|
class LinearAlign(torch.nn.Module):
|
|
383
|
-
# Aligns
|
|
383
|
+
# Aligns representations to the same dimensionality. Note that this will
|
|
384
384
|
# create lazy modules, and as such requires a forward pass in order to
|
|
385
385
|
# initialize parameters.
|
|
386
386
|
def __init__(self, keys: List[Union[NodeType, EdgeType]],
|
|
@@ -468,7 +468,7 @@ def get_edge_type(
|
|
|
468
468
|
###############################################################################
|
|
469
469
|
|
|
470
470
|
# These methods are used to group the individual type-wise components into a
|
|
471
|
-
#
|
|
471
|
+
# unified single representation.
|
|
472
472
|
|
|
473
473
|
|
|
474
474
|
def group_node_placeholder(input_dict: Dict[NodeType, Tensor],
|
|
@@ -20,6 +20,7 @@ from .utils import (
|
|
|
20
20
|
get_gpu_memory_from_nvidia_smi,
|
|
21
21
|
get_model_size,
|
|
22
22
|
)
|
|
23
|
+
from .nvtx import nvtxit
|
|
23
24
|
|
|
24
25
|
__all__ = [
|
|
25
26
|
'profileit',
|
|
@@ -38,6 +39,7 @@ __all__ = [
|
|
|
38
39
|
'get_gpu_memory_from_nvidia_smi',
|
|
39
40
|
'get_gpu_memory_from_ipex',
|
|
40
41
|
'benchmark',
|
|
42
|
+
'nvtxit',
|
|
41
43
|
]
|
|
42
44
|
|
|
43
45
|
classes = __all__
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from functools import wraps
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
# Module-level flag recording whether `begin_cuda_profile` has issued
# `cudaProfilerStart()` that has not yet been undone by `end_cuda_profile`.
# Nested profiled calls see `True` here and therefore do not start/stop the
# profiler a second time.
CUDA_PROFILE_STARTED = False
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def begin_cuda_profile():
    """Start the CUDA profiler unless an enclosing call already started it.

    Returns:
        bool: The value of ``CUDA_PROFILE_STARTED`` before this call. Hand it
        back to ``end_cuda_profile`` so that only the outermost caller stops
        the profiler.
    """
    global CUDA_PROFILE_STARTED
    already_running = CUDA_PROFILE_STARTED
    if already_running is not False:
        # Some outer call owns the profiler session; leave it alone.
        return already_running
    CUDA_PROFILE_STARTED = True
    torch.cuda.cudart().cudaProfilerStart()
    return already_running
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def end_cuda_profile(prev_state: bool):
    """Restore the profiler state saved by ``begin_cuda_profile``.

    Args:
        prev_state (bool): The value returned by the matching
            ``begin_cuda_profile`` call. When it is ``False``, this call closes
            the outermost profiled region, so the CUDA profiler is stopped.
    """
    global CUDA_PROFILE_STARTED
    CUDA_PROFILE_STARTED = prev_state
    if prev_state is not False:
        # An enclosing region is still active; keep profiling.
        return
    torch.cuda.cudart().cudaProfilerStop()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
           n_iters: Optional[int] = None):
    """Enables NVTX profiling for a function.

    Each recorded call is wrapped in an NVTX range named ``f"{name}_{i}"``,
    where ``i`` counts calls of the wrapped function (warm-up calls
    included). The CUDA profiler is started for the duration of the
    outermost profiled call. If CUDA is not available, the wrapped function
    is simply called directly.

    Args:
        name (Optional[str], optional): Name to give the reference frame for
            the function being wrapped. Defaults to the name of the
            function in code.
        n_warmups (int, optional): Number of iters to call that function
            before starting. Defaults to 0.
        n_iters (Optional[int], optional): Number of iters of that function to
            record. Defaults to all of them.
    """
    def nvtx(func):
        # Resolve the range name per decorated function. Rebinding the
        # enclosing `name` (via `nonlocal`) would leak the first function's
        # name into every later function decorated by the same `nvtx`.
        range_name = func.__name__ if name is None else name
        iters_so_far = 0

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal iters_so_far
            if not torch.cuda.is_available():
                return func(*args, **kwargs)
            if iters_so_far < n_warmups:
                iters_so_far += 1
                return func(*args, **kwargs)
            if n_iters is not None and iters_so_far >= n_iters + n_warmups:
                # Recording budget exhausted: run without profiling.
                return func(*args, **kwargs)

            prev_state = begin_cuda_profile()
            torch.cuda.nvtx.range_push(f"{range_name}_{iters_so_far}")
            try:
                return func(*args, **kwargs)
            finally:
                # Always close the range and restore the profiler state, even
                # if `func` raises. Otherwise later ranges would nest inside a
                # dangling one and the profiler would be left running. The
                # counter also advances so the next range gets a fresh label.
                torch.cuda.nvtx.range_pop()
                end_cuda_profile(prev_state)
                iters_so_far += 1

        return wrapper

    return nvtx
|
|
@@ -5,6 +5,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple
|
|
|
5
5
|
import torch
|
|
6
6
|
import torch.profiler as torch_profiler
|
|
7
7
|
|
|
8
|
+
import torch_geometric.typing
|
|
9
|
+
|
|
8
10
|
# predefined namedtuple for variable setting (global template)
|
|
9
11
|
Trace = namedtuple('Trace', ['path', 'leaf', 'module'])
|
|
10
12
|
|
|
@@ -325,6 +327,8 @@ def _flatten_tree(t, depth=0):
|
|
|
325
327
|
|
|
326
328
|
|
|
327
329
|
def _build_measure_tuple(events: List, occurrences: List) -> NamedTuple:
|
|
330
|
+
device_str = 'device' if torch_geometric.typing.WITH_PT24 else 'cuda'
|
|
331
|
+
|
|
328
332
|
# memory profiling supported in torch >= 1.6
|
|
329
333
|
self_cpu_memory = None
|
|
330
334
|
has_self_cpu_memory = any(
|
|
@@ -339,29 +343,34 @@ def _build_measure_tuple(events: List, occurrences: List) -> NamedTuple:
|
|
|
339
343
|
[getattr(e, "cpu_memory_usage", 0) or 0 for e in events])
|
|
340
344
|
self_cuda_memory = None
|
|
341
345
|
has_self_cuda_memory = any(
|
|
342
|
-
hasattr(e, "
|
|
346
|
+
hasattr(e, f"self_{device_str}_memory_usage") for e in events)
|
|
343
347
|
if has_self_cuda_memory:
|
|
344
|
-
self_cuda_memory = sum(
|
|
345
|
-
|
|
348
|
+
self_cuda_memory = sum([
|
|
349
|
+
getattr(e, f"self_{device_str}_memory_usage", 0) or 0
|
|
350
|
+
for e in events
|
|
351
|
+
])
|
|
346
352
|
cuda_memory = None
|
|
347
|
-
has_cuda_memory = any(
|
|
353
|
+
has_cuda_memory = any(
|
|
354
|
+
hasattr(e, f"{device_str}_memory_usage") for e in events)
|
|
348
355
|
if has_cuda_memory:
|
|
349
356
|
cuda_memory = sum(
|
|
350
|
-
[getattr(e, "
|
|
357
|
+
[getattr(e, f"{device_str}_memory_usage", 0) or 0 for e in events])
|
|
351
358
|
|
|
352
359
|
# self CUDA time supported in torch >= 1.7
|
|
353
360
|
self_cuda_total = None
|
|
354
361
|
has_self_cuda_time = any(
|
|
355
|
-
hasattr(e, "
|
|
362
|
+
hasattr(e, f"self_{device_str}_time_total") for e in events)
|
|
356
363
|
if has_self_cuda_time:
|
|
357
|
-
self_cuda_total = sum(
|
|
358
|
-
|
|
364
|
+
self_cuda_total = sum([
|
|
365
|
+
getattr(e, f"self_{device_str}_time_total", 0) or 0 for e in events
|
|
366
|
+
])
|
|
359
367
|
|
|
360
368
|
return Measure(
|
|
361
369
|
self_cpu_total=sum([e.self_cpu_time_total or 0 for e in events]),
|
|
362
370
|
cpu_total=sum([e.cpu_time_total or 0 for e in events]),
|
|
363
371
|
self_cuda_total=self_cuda_total,
|
|
364
|
-
cuda_total=sum(
|
|
372
|
+
cuda_total=sum(
|
|
373
|
+
[getattr(e, f"{device_str}_time_total") or 0 for e in events]),
|
|
365
374
|
self_cpu_memory=self_cpu_memory,
|
|
366
375
|
cpu_memory=cpu_memory,
|
|
367
376
|
self_cuda_memory=self_cuda_memory,
|