pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
|
@@ -233,7 +233,7 @@ class MetaPath2Vec(torch.nn.Module):
|
|
|
233
233
|
"""
|
|
234
234
|
from sklearn.linear_model import LogisticRegression
|
|
235
235
|
|
|
236
|
-
clf = LogisticRegression(solver=solver,
|
|
236
|
+
clf = LogisticRegression(*args, solver=solver,
|
|
237
237
|
**kwargs).fit(train_z.detach().cpu().numpy(),
|
|
238
238
|
train_y.detach().cpu().numpy())
|
|
239
239
|
return clf.score(test_z.detach().cpu().numpy(),
|
torch_geometric/nn/models/mlp.py
CHANGED
|
@@ -99,8 +99,10 @@ class MLP(torch.nn.Module):
|
|
|
99
99
|
act_first = act_first or kwargs.get("relu_first", False)
|
|
100
100
|
batch_norm = kwargs.get("batch_norm", None)
|
|
101
101
|
if batch_norm is not None and isinstance(batch_norm, bool):
|
|
102
|
-
warnings.warn(
|
|
103
|
-
|
|
102
|
+
warnings.warn(
|
|
103
|
+
"Argument `batch_norm` is deprecated, "
|
|
104
|
+
"please use `norm` to specify normalization layer.",
|
|
105
|
+
stacklevel=2)
|
|
104
106
|
norm = 'batch_norm' if batch_norm else None
|
|
105
107
|
batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None)
|
|
106
108
|
norm_kwargs = batch_norm_kwargs or {}
|
|
@@ -181,7 +181,7 @@ class Node2Vec(torch.nn.Module):
|
|
|
181
181
|
"""
|
|
182
182
|
from sklearn.linear_model import LogisticRegression
|
|
183
183
|
|
|
184
|
-
clf = LogisticRegression(solver=solver,
|
|
184
|
+
clf = LogisticRegression(*args, solver=solver,
|
|
185
185
|
**kwargs).fit(train_z.detach().cpu().numpy(),
|
|
186
186
|
train_y.detach().cpu().numpy())
|
|
187
187
|
return clf.score(test_z.detach().cpu().numpy(),
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from torch_geometric.nn import GATConv, GCNConv
|
|
8
|
+
from torch_geometric.nn.attention import PolynormerAttention
|
|
9
|
+
from torch_geometric.utils import to_dense_batch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Polynormer(torch.nn.Module):
|
|
13
|
+
r"""The polynormer module from the
|
|
14
|
+
`"Polynormer: polynomial-expressive graph
|
|
15
|
+
transformer in linear time"
|
|
16
|
+
<https://arxiv.org/abs/2403.01232>`_ paper.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
in_channels (int): Input channels.
|
|
20
|
+
hidden_channels (int): Hidden channels.
|
|
21
|
+
out_channels (int): Output channels.
|
|
22
|
+
local_layers (int): The number of local attention layers.
|
|
23
|
+
(default: :obj:`7`)
|
|
24
|
+
global_layers (int): The number of global attention layers.
|
|
25
|
+
(default: :obj:`2`)
|
|
26
|
+
in_dropout (float): Input dropout rate.
|
|
27
|
+
(default: :obj:`0.15`)
|
|
28
|
+
dropout (float): Dropout rate.
|
|
29
|
+
(default: :obj:`0.5`)
|
|
30
|
+
global_dropout (float): Global dropout rate.
|
|
31
|
+
(default: :obj:`0.5`)
|
|
32
|
+
heads (int): The number of heads.
|
|
33
|
+
(default: :obj:`1`)
|
|
34
|
+
beta (float): Aggregate type.
|
|
35
|
+
(default: :obj:`0.9`)
|
|
36
|
+
qk_shared (bool optional): Whether weight of query and key are shared.
|
|
37
|
+
(default: :obj:`True`)
|
|
38
|
+
pre_ln (bool): Pre layer normalization.
|
|
39
|
+
(default: :obj:`False`)
|
|
40
|
+
post_bn (bool): Post batch normalization.
|
|
41
|
+
(default: :obj:`True`)
|
|
42
|
+
local_attn (bool): Whether use local attention.
|
|
43
|
+
(default: :obj:`False`)
|
|
44
|
+
"""
|
|
45
|
+
def __init__(
|
|
46
|
+
self,
|
|
47
|
+
in_channels: int,
|
|
48
|
+
hidden_channels: int,
|
|
49
|
+
out_channels: int,
|
|
50
|
+
local_layers: int = 7,
|
|
51
|
+
global_layers: int = 2,
|
|
52
|
+
in_dropout: float = 0.15,
|
|
53
|
+
dropout: float = 0.5,
|
|
54
|
+
global_dropout: float = 0.5,
|
|
55
|
+
heads: int = 1,
|
|
56
|
+
beta: float = 0.9,
|
|
57
|
+
qk_shared: bool = False,
|
|
58
|
+
pre_ln: bool = False,
|
|
59
|
+
post_bn: bool = True,
|
|
60
|
+
local_attn: bool = False,
|
|
61
|
+
) -> None:
|
|
62
|
+
super().__init__()
|
|
63
|
+
self._global = False
|
|
64
|
+
self.in_drop = in_dropout
|
|
65
|
+
self.dropout = dropout
|
|
66
|
+
self.pre_ln = pre_ln
|
|
67
|
+
self.post_bn = post_bn
|
|
68
|
+
|
|
69
|
+
self.beta = beta
|
|
70
|
+
|
|
71
|
+
self.h_lins = torch.nn.ModuleList()
|
|
72
|
+
self.local_convs = torch.nn.ModuleList()
|
|
73
|
+
self.lins = torch.nn.ModuleList()
|
|
74
|
+
self.lns = torch.nn.ModuleList()
|
|
75
|
+
if self.pre_ln:
|
|
76
|
+
self.pre_lns = torch.nn.ModuleList()
|
|
77
|
+
if self.post_bn:
|
|
78
|
+
self.post_bns = torch.nn.ModuleList()
|
|
79
|
+
|
|
80
|
+
# first layer
|
|
81
|
+
inner_channels = heads * hidden_channels
|
|
82
|
+
self.h_lins.append(torch.nn.Linear(in_channels, inner_channels))
|
|
83
|
+
if local_attn:
|
|
84
|
+
self.local_convs.append(
|
|
85
|
+
GATConv(in_channels, hidden_channels, heads=heads, concat=True,
|
|
86
|
+
add_self_loops=False, bias=False))
|
|
87
|
+
else:
|
|
88
|
+
self.local_convs.append(
|
|
89
|
+
GCNConv(in_channels, inner_channels, cached=False,
|
|
90
|
+
normalize=True))
|
|
91
|
+
|
|
92
|
+
self.lins.append(torch.nn.Linear(in_channels, inner_channels))
|
|
93
|
+
self.lns.append(torch.nn.LayerNorm(inner_channels))
|
|
94
|
+
if self.pre_ln:
|
|
95
|
+
self.pre_lns.append(torch.nn.LayerNorm(in_channels))
|
|
96
|
+
if self.post_bn:
|
|
97
|
+
self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
|
|
98
|
+
|
|
99
|
+
# following layers
|
|
100
|
+
for _ in range(local_layers - 1):
|
|
101
|
+
self.h_lins.append(torch.nn.Linear(inner_channels, inner_channels))
|
|
102
|
+
if local_attn:
|
|
103
|
+
self.local_convs.append(
|
|
104
|
+
GATConv(inner_channels, hidden_channels, heads=heads,
|
|
105
|
+
concat=True, add_self_loops=False, bias=False))
|
|
106
|
+
else:
|
|
107
|
+
self.local_convs.append(
|
|
108
|
+
GCNConv(inner_channels, inner_channels, cached=False,
|
|
109
|
+
normalize=True))
|
|
110
|
+
|
|
111
|
+
self.lins.append(torch.nn.Linear(inner_channels, inner_channels))
|
|
112
|
+
self.lns.append(torch.nn.LayerNorm(inner_channels))
|
|
113
|
+
if self.pre_ln:
|
|
114
|
+
self.pre_lns.append(torch.nn.LayerNorm(heads *
|
|
115
|
+
hidden_channels))
|
|
116
|
+
if self.post_bn:
|
|
117
|
+
self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
|
|
118
|
+
|
|
119
|
+
self.lin_in = torch.nn.Linear(in_channels, inner_channels)
|
|
120
|
+
self.ln = torch.nn.LayerNorm(inner_channels)
|
|
121
|
+
|
|
122
|
+
self.global_attn = torch.nn.ModuleList()
|
|
123
|
+
for _ in range(global_layers):
|
|
124
|
+
self.global_attn.append(
|
|
125
|
+
PolynormerAttention(
|
|
126
|
+
channels=hidden_channels,
|
|
127
|
+
heads=heads,
|
|
128
|
+
head_channels=hidden_channels,
|
|
129
|
+
beta=beta,
|
|
130
|
+
dropout=global_dropout,
|
|
131
|
+
qk_shared=qk_shared,
|
|
132
|
+
))
|
|
133
|
+
self.pred_local = torch.nn.Linear(inner_channels, out_channels)
|
|
134
|
+
self.pred_global = torch.nn.Linear(inner_channels, out_channels)
|
|
135
|
+
self.reset_parameters()
|
|
136
|
+
|
|
137
|
+
def reset_parameters(self) -> None:
|
|
138
|
+
for local_conv in self.local_convs:
|
|
139
|
+
local_conv.reset_parameters()
|
|
140
|
+
for attn in self.global_attn:
|
|
141
|
+
attn.reset_parameters()
|
|
142
|
+
for lin in self.lins:
|
|
143
|
+
lin.reset_parameters()
|
|
144
|
+
for h_lin in self.h_lins:
|
|
145
|
+
h_lin.reset_parameters()
|
|
146
|
+
for ln in self.lns:
|
|
147
|
+
ln.reset_parameters()
|
|
148
|
+
if self.pre_ln:
|
|
149
|
+
for p_ln in self.pre_lns:
|
|
150
|
+
p_ln.reset_parameters()
|
|
151
|
+
if self.post_bn:
|
|
152
|
+
for p_bn in self.post_bns:
|
|
153
|
+
p_bn.reset_parameters()
|
|
154
|
+
self.lin_in.reset_parameters()
|
|
155
|
+
self.ln.reset_parameters()
|
|
156
|
+
self.pred_local.reset_parameters()
|
|
157
|
+
self.pred_global.reset_parameters()
|
|
158
|
+
|
|
159
|
+
def forward(
|
|
160
|
+
self,
|
|
161
|
+
x: Tensor,
|
|
162
|
+
edge_index: Tensor,
|
|
163
|
+
batch: Optional[Tensor],
|
|
164
|
+
) -> Tensor:
|
|
165
|
+
r"""Forward pass.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
x (torch.Tensor): The input node features.
|
|
169
|
+
edge_index (torch.Tensor or SparseTensor): The edge indices.
|
|
170
|
+
batch (torch.Tensor, optional): The batch vector
|
|
171
|
+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
|
|
172
|
+
each element to a specific example.
|
|
173
|
+
"""
|
|
174
|
+
x = F.dropout(x, p=self.in_drop, training=self.training)
|
|
175
|
+
|
|
176
|
+
# equivariant local attention
|
|
177
|
+
x_local = 0
|
|
178
|
+
for i, local_conv in enumerate(self.local_convs):
|
|
179
|
+
if self.pre_ln:
|
|
180
|
+
x = self.pre_lns[i](x)
|
|
181
|
+
h = self.h_lins[i](x)
|
|
182
|
+
h = F.relu(h)
|
|
183
|
+
x = local_conv(x, edge_index) + self.lins[i](x)
|
|
184
|
+
if self.post_bn:
|
|
185
|
+
x = self.post_bns[i](x)
|
|
186
|
+
x = F.relu(x)
|
|
187
|
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
|
188
|
+
x = (1 - self.beta) * self.lns[i](h * x) + self.beta * x
|
|
189
|
+
x_local = x_local + x
|
|
190
|
+
|
|
191
|
+
# equivariant global attention
|
|
192
|
+
if self._global:
|
|
193
|
+
batch, indices = batch.sort()
|
|
194
|
+
rev_perm = torch.empty_like(indices)
|
|
195
|
+
rev_perm[indices] = torch.arange(len(indices),
|
|
196
|
+
device=indices.device)
|
|
197
|
+
x_local = self.ln(x_local[indices])
|
|
198
|
+
x_global, mask = to_dense_batch(x_local, batch)
|
|
199
|
+
for attn in self.global_attn:
|
|
200
|
+
x_global = attn(x_global, mask)
|
|
201
|
+
x = x_global[mask][rev_perm]
|
|
202
|
+
x = self.pred_global(x)
|
|
203
|
+
else:
|
|
204
|
+
x = self.pred_local(x_local)
|
|
205
|
+
|
|
206
|
+
return F.log_softmax(x, dim=-1)
|
|
@@ -196,8 +196,8 @@ class InvertibleModule(torch.nn.Module):
|
|
|
196
196
|
class GroupAddRev(InvertibleModule):
|
|
197
197
|
r"""The Grouped Reversible GNN module from the `"Graph Neural Networks with
|
|
198
198
|
1000 Layers" <https://arxiv.org/abs/2106.07476>`_ paper.
|
|
199
|
-
This module enables training of
|
|
200
|
-
independent of the number of layers.
|
|
199
|
+
This module enables training of arbitrary deep GNNs with a memory
|
|
200
|
+
complexity independent of the number of layers.
|
|
201
201
|
|
|
202
202
|
It does so by partitioning input node features :math:`\mathbf{X}` into
|
|
203
203
|
:math:`C` groups across the feature dimension. Then, a grouped reversible
|
|
@@ -249,7 +249,7 @@ class GroupAddRev(InvertibleModule):
|
|
|
249
249
|
else:
|
|
250
250
|
assert num_groups is not None, "Please specific 'num_groups'"
|
|
251
251
|
self.convs = torch.nn.ModuleList([conv])
|
|
252
|
-
for
|
|
252
|
+
for _ in range(num_groups - 1):
|
|
253
253
|
conv = copy.deepcopy(self.convs[0])
|
|
254
254
|
if hasattr(conv, 'reset_parameters'):
|
|
255
255
|
conv.reset_parameters()
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from torch_geometric.nn.attention import SGFormerAttention
|
|
8
|
+
from torch_geometric.nn.conv import GCNConv
|
|
9
|
+
from torch_geometric.utils import to_dense_batch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GraphModule(torch.nn.Module):
|
|
13
|
+
def __init__(
|
|
14
|
+
self,
|
|
15
|
+
in_channels,
|
|
16
|
+
hidden_channels,
|
|
17
|
+
num_layers=2,
|
|
18
|
+
dropout=0.5,
|
|
19
|
+
):
|
|
20
|
+
super().__init__()
|
|
21
|
+
|
|
22
|
+
self.convs = torch.nn.ModuleList()
|
|
23
|
+
self.fcs = torch.nn.ModuleList()
|
|
24
|
+
self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
|
|
25
|
+
|
|
26
|
+
self.bns = torch.nn.ModuleList()
|
|
27
|
+
self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
|
|
28
|
+
for _ in range(num_layers):
|
|
29
|
+
self.convs.append(GCNConv(hidden_channels, hidden_channels))
|
|
30
|
+
self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
|
|
31
|
+
|
|
32
|
+
self.dropout = dropout
|
|
33
|
+
self.activation = F.relu
|
|
34
|
+
|
|
35
|
+
def reset_parameters(self):
|
|
36
|
+
for conv in self.convs:
|
|
37
|
+
conv.reset_parameters()
|
|
38
|
+
for bn in self.bns:
|
|
39
|
+
bn.reset_parameters()
|
|
40
|
+
for fc in self.fcs:
|
|
41
|
+
fc.reset_parameters()
|
|
42
|
+
|
|
43
|
+
def forward(self, x, edge_index):
|
|
44
|
+
x = self.fcs[0](x)
|
|
45
|
+
x = self.bns[0](x)
|
|
46
|
+
x = self.activation(x)
|
|
47
|
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
|
48
|
+
last_x = x
|
|
49
|
+
|
|
50
|
+
for i, conv in enumerate(self.convs):
|
|
51
|
+
x = conv(x, edge_index)
|
|
52
|
+
x = self.bns[i + 1](x)
|
|
53
|
+
x = self.activation(x)
|
|
54
|
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
|
55
|
+
x = x + last_x
|
|
56
|
+
return x
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SGModule(torch.nn.Module):
|
|
60
|
+
def __init__(
|
|
61
|
+
self,
|
|
62
|
+
in_channels,
|
|
63
|
+
hidden_channels,
|
|
64
|
+
num_layers=2,
|
|
65
|
+
num_heads=1,
|
|
66
|
+
dropout=0.5,
|
|
67
|
+
):
|
|
68
|
+
super().__init__()
|
|
69
|
+
|
|
70
|
+
self.attns = torch.nn.ModuleList()
|
|
71
|
+
self.fcs = torch.nn.ModuleList()
|
|
72
|
+
self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
|
|
73
|
+
self.bns = torch.nn.ModuleList()
|
|
74
|
+
self.bns.append(torch.nn.LayerNorm(hidden_channels))
|
|
75
|
+
for _ in range(num_layers):
|
|
76
|
+
self.attns.append(
|
|
77
|
+
SGFormerAttention(hidden_channels, num_heads, hidden_channels))
|
|
78
|
+
self.bns.append(torch.nn.LayerNorm(hidden_channels))
|
|
79
|
+
|
|
80
|
+
self.dropout = dropout
|
|
81
|
+
self.activation = F.relu
|
|
82
|
+
|
|
83
|
+
def reset_parameters(self):
|
|
84
|
+
for attn in self.attns:
|
|
85
|
+
attn.reset_parameters()
|
|
86
|
+
for bn in self.bns:
|
|
87
|
+
bn.reset_parameters()
|
|
88
|
+
for fc in self.fcs:
|
|
89
|
+
fc.reset_parameters()
|
|
90
|
+
|
|
91
|
+
def forward(self, x: Tensor, batch: Tensor):
|
|
92
|
+
# to dense batch expects sorted batch
|
|
93
|
+
batch, indices = batch.sort(stable=True)
|
|
94
|
+
rev_perm = torch.empty_like(indices)
|
|
95
|
+
rev_perm[indices] = torch.arange(len(indices), device=indices.device)
|
|
96
|
+
x = x[indices]
|
|
97
|
+
x, mask = to_dense_batch(x, batch)
|
|
98
|
+
layer_ = []
|
|
99
|
+
|
|
100
|
+
# input MLP layer
|
|
101
|
+
x = self.fcs[0](x)
|
|
102
|
+
x = self.bns[0](x)
|
|
103
|
+
x = self.activation(x)
|
|
104
|
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
|
105
|
+
|
|
106
|
+
# store as residual link
|
|
107
|
+
layer_.append(x)
|
|
108
|
+
|
|
109
|
+
for i, attn in enumerate(self.attns):
|
|
110
|
+
x = attn(x, mask)
|
|
111
|
+
x = (x + layer_[i]) / 2.
|
|
112
|
+
x = self.bns[i + 1](x)
|
|
113
|
+
x = self.activation(x)
|
|
114
|
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
|
115
|
+
layer_.append(x)
|
|
116
|
+
|
|
117
|
+
x_mask = x[mask]
|
|
118
|
+
# reverse the sorting
|
|
119
|
+
unsorted_x_mask = x_mask[rev_perm]
|
|
120
|
+
return unsorted_x_mask
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class SGFormer(torch.nn.Module):
|
|
124
|
+
r"""The sgformer module from the
|
|
125
|
+
`"SGFormer: Simplifying and Empowering Transformers for
|
|
126
|
+
Large-Graph Representations"
|
|
127
|
+
<https://arxiv.org/abs/2306.10759>`_ paper.
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
in_channels (int): Input channels.
|
|
131
|
+
hidden_channels (int): Hidden channels.
|
|
132
|
+
out_channels (int): Output channels.
|
|
133
|
+
trans_num_layers (int): The number of layers for all-pair attention.
|
|
134
|
+
(default: :obj:`2`)
|
|
135
|
+
trans_num_heads (int): The number of heads for attention.
|
|
136
|
+
(default: :obj:`1`)
|
|
137
|
+
trans_dropout (float): Global dropout rate.
|
|
138
|
+
(default: :obj:`0.5`)
|
|
139
|
+
gnn_num_layers (int): The number of layers for GNN.
|
|
140
|
+
(default: :obj:`3`)
|
|
141
|
+
gnn_dropout (float): GNN dropout rate.
|
|
142
|
+
(default: :obj:`0.5`)
|
|
143
|
+
graph_weight (float): The weight balance global and gnn module.
|
|
144
|
+
(default: :obj:`0.5`)
|
|
145
|
+
aggregate (str): Aggregate type.
|
|
146
|
+
(default: :obj:`add`)
|
|
147
|
+
"""
|
|
148
|
+
def __init__(
|
|
149
|
+
self,
|
|
150
|
+
in_channels: int,
|
|
151
|
+
hidden_channels: int,
|
|
152
|
+
out_channels: int,
|
|
153
|
+
trans_num_layers: int = 2,
|
|
154
|
+
trans_num_heads: int = 1,
|
|
155
|
+
trans_dropout: float = 0.5,
|
|
156
|
+
gnn_num_layers: int = 3,
|
|
157
|
+
gnn_dropout: float = 0.5,
|
|
158
|
+
graph_weight: float = 0.5,
|
|
159
|
+
aggregate: str = 'add',
|
|
160
|
+
):
|
|
161
|
+
super().__init__()
|
|
162
|
+
self.trans_conv = SGModule(
|
|
163
|
+
in_channels,
|
|
164
|
+
hidden_channels,
|
|
165
|
+
trans_num_layers,
|
|
166
|
+
trans_num_heads,
|
|
167
|
+
trans_dropout,
|
|
168
|
+
)
|
|
169
|
+
self.graph_conv = GraphModule(
|
|
170
|
+
in_channels,
|
|
171
|
+
hidden_channels,
|
|
172
|
+
gnn_num_layers,
|
|
173
|
+
gnn_dropout,
|
|
174
|
+
)
|
|
175
|
+
self.graph_weight = graph_weight
|
|
176
|
+
|
|
177
|
+
self.aggregate = aggregate
|
|
178
|
+
|
|
179
|
+
if aggregate == 'add':
|
|
180
|
+
self.fc = torch.nn.Linear(hidden_channels, out_channels)
|
|
181
|
+
elif aggregate == 'cat':
|
|
182
|
+
self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)
|
|
183
|
+
else:
|
|
184
|
+
raise ValueError(f'Invalid aggregate type:{aggregate}')
|
|
185
|
+
|
|
186
|
+
self.params1 = list(self.trans_conv.parameters())
|
|
187
|
+
self.params2 = list(self.graph_conv.parameters())
|
|
188
|
+
self.params2.extend(list(self.fc.parameters()))
|
|
189
|
+
|
|
190
|
+
self.out_channels = out_channels
|
|
191
|
+
|
|
192
|
+
def reset_parameters(self) -> None:
|
|
193
|
+
self.trans_conv.reset_parameters()
|
|
194
|
+
self.graph_conv.reset_parameters()
|
|
195
|
+
self.fc.reset_parameters()
|
|
196
|
+
|
|
197
|
+
def forward(
|
|
198
|
+
self,
|
|
199
|
+
x: Tensor,
|
|
200
|
+
edge_index: Tensor,
|
|
201
|
+
batch: Optional[Tensor],
|
|
202
|
+
) -> Tensor:
|
|
203
|
+
r"""Forward pass.
|
|
204
|
+
|
|
205
|
+
Args:
|
|
206
|
+
x (torch.Tensor): The input node features.
|
|
207
|
+
edge_index (torch.Tensor or SparseTensor): The edge indices.
|
|
208
|
+
batch (torch.Tensor, optional): The batch vector
|
|
209
|
+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
|
|
210
|
+
each element to a specific example.
|
|
211
|
+
"""
|
|
212
|
+
x1 = self.trans_conv(x, batch)
|
|
213
|
+
x2 = self.graph_conv(x, edge_index)
|
|
214
|
+
if self.aggregate == 'add':
|
|
215
|
+
x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
|
|
216
|
+
else:
|
|
217
|
+
x = torch.cat((x1, x2), dim=1)
|
|
218
|
+
x = self.fc(x)
|
|
219
|
+
return F.log_softmax(x, dim=-1)
|
|
@@ -45,7 +45,7 @@ class SignedGCN(torch.nn.Module):
|
|
|
45
45
|
self.conv1 = SignedConv(in_channels, hidden_channels // 2,
|
|
46
46
|
first_aggr=True)
|
|
47
47
|
self.convs = torch.nn.ModuleList()
|
|
48
|
-
for
|
|
48
|
+
for _ in range(num_layers - 1):
|
|
49
49
|
self.convs.append(
|
|
50
50
|
SignedConv(hidden_channels // 2, hidden_channels // 2,
|
|
51
51
|
first_aggr=False))
|
|
@@ -11,7 +11,7 @@ from torch_geometric.utils import scatter
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class CosineCutoff(torch.nn.Module):
|
|
14
|
-
r"""
|
|
14
|
+
r"""Applies a cosine cutoff to the input distances.
|
|
15
15
|
|
|
16
16
|
.. math::
|
|
17
17
|
\text{cutoffs} =
|
|
@@ -572,7 +572,7 @@ class ViS_MP(MessagePassing):
|
|
|
572
572
|
d_ij: Tensor,
|
|
573
573
|
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
|
|
574
574
|
r"""Computes the residual scalar and vector features of the nodes and
|
|
575
|
-
scalar
|
|
575
|
+
scalar features of the edges.
|
|
576
576
|
|
|
577
577
|
Args:
|
|
578
578
|
x (torch.Tensor): The scalar features of the nodes.
|
|
@@ -39,6 +39,8 @@ class BatchNorm(torch.nn.Module):
|
|
|
39
39
|
with only a single element will work as during in evaluation.
|
|
40
40
|
That is the running mean and variance will be used.
|
|
41
41
|
Requires :obj:`track_running_stats=True`. (default: :obj:`False`)
|
|
42
|
+
device (torch.device, optional): The device to use for the module.
|
|
43
|
+
(default: :obj:`None`)
|
|
42
44
|
"""
|
|
43
45
|
def __init__(
|
|
44
46
|
self,
|
|
@@ -48,6 +50,7 @@ class BatchNorm(torch.nn.Module):
|
|
|
48
50
|
affine: bool = True,
|
|
49
51
|
track_running_stats: bool = True,
|
|
50
52
|
allow_single_element: bool = False,
|
|
53
|
+
device: Optional[torch.device] = None,
|
|
51
54
|
):
|
|
52
55
|
super().__init__()
|
|
53
56
|
|
|
@@ -56,7 +59,7 @@ class BatchNorm(torch.nn.Module):
|
|
|
56
59
|
"'track_running_stats' to be set to `True`")
|
|
57
60
|
|
|
58
61
|
self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
|
|
59
|
-
track_running_stats)
|
|
62
|
+
track_running_stats, device=device)
|
|
60
63
|
self.in_channels = in_channels
|
|
61
64
|
self.allow_single_element = allow_single_element
|
|
62
65
|
|
|
@@ -114,6 +117,8 @@ class HeteroBatchNorm(torch.nn.Module):
|
|
|
114
117
|
:obj:`False`, this module does not track such statistics and always
|
|
115
118
|
uses batch statistics in both training and eval modes.
|
|
116
119
|
(default: :obj:`True`)
|
|
120
|
+
device (torch.device, optional): The device to use for the module.
|
|
121
|
+
(default: :obj:`None`)
|
|
117
122
|
"""
|
|
118
123
|
def __init__(
|
|
119
124
|
self,
|
|
@@ -123,6 +128,7 @@ class HeteroBatchNorm(torch.nn.Module):
|
|
|
123
128
|
momentum: Optional[float] = 0.1,
|
|
124
129
|
affine: bool = True,
|
|
125
130
|
track_running_stats: bool = True,
|
|
131
|
+
device: Optional[torch.device] = None,
|
|
126
132
|
):
|
|
127
133
|
super().__init__()
|
|
128
134
|
|
|
@@ -134,17 +140,21 @@ class HeteroBatchNorm(torch.nn.Module):
|
|
|
134
140
|
self.track_running_stats = track_running_stats
|
|
135
141
|
|
|
136
142
|
if self.affine:
|
|
137
|
-
self.weight = Parameter(
|
|
138
|
-
|
|
143
|
+
self.weight = Parameter(
|
|
144
|
+
torch.empty(num_types, in_channels, device=device))
|
|
145
|
+
self.bias = Parameter(
|
|
146
|
+
torch.empty(num_types, in_channels, device=device))
|
|
139
147
|
else:
|
|
140
148
|
self.register_parameter('weight', None)
|
|
141
149
|
self.register_parameter('bias', None)
|
|
142
150
|
|
|
143
151
|
if self.track_running_stats:
|
|
144
|
-
self.register_buffer(
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
152
|
+
self.register_buffer(
|
|
153
|
+
'running_mean',
|
|
154
|
+
torch.empty(num_types, in_channels, device=device))
|
|
155
|
+
self.register_buffer(
|
|
156
|
+
'running_var',
|
|
157
|
+
torch.empty(num_types, in_channels, device=device))
|
|
148
158
|
self.register_buffer('num_batches_tracked', torch.tensor(0))
|
|
149
159
|
else:
|
|
150
160
|
self.register_buffer('running_mean', None)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
from torch import Tensor
|
|
3
5
|
from torch.nn import BatchNorm1d, Linear
|
|
@@ -39,6 +41,8 @@ class DiffGroupNorm(torch.nn.Module):
|
|
|
39
41
|
:obj:`False`, this module does not track such statistics and always
|
|
40
42
|
uses batch statistics in both training and eval modes.
|
|
41
43
|
(default: :obj:`True`)
|
|
44
|
+
device (torch.device, optional): The device to use for the module.
|
|
45
|
+
(default: :obj:`None`)
|
|
42
46
|
"""
|
|
43
47
|
def __init__(
|
|
44
48
|
self,
|
|
@@ -49,6 +53,7 @@ class DiffGroupNorm(torch.nn.Module):
|
|
|
49
53
|
momentum: float = 0.1,
|
|
50
54
|
affine: bool = True,
|
|
51
55
|
track_running_stats: bool = True,
|
|
56
|
+
device: Optional[torch.device] = None,
|
|
52
57
|
):
|
|
53
58
|
super().__init__()
|
|
54
59
|
|
|
@@ -56,9 +61,9 @@ class DiffGroupNorm(torch.nn.Module):
|
|
|
56
61
|
self.groups = groups
|
|
57
62
|
self.lamda = lamda
|
|
58
63
|
|
|
59
|
-
self.lin = Linear(in_channels, groups, bias=False)
|
|
64
|
+
self.lin = Linear(in_channels, groups, bias=False, device=device)
|
|
60
65
|
self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine,
|
|
61
|
-
track_running_stats)
|
|
66
|
+
track_running_stats, device=device)
|
|
62
67
|
|
|
63
68
|
self.reset_parameters()
|
|
64
69
|
|
|
@@ -26,16 +26,21 @@ class GraphNorm(torch.nn.Module):
|
|
|
26
26
|
in_channels (int): Size of each input sample.
|
|
27
27
|
eps (float, optional): A value added to the denominator for numerical
|
|
28
28
|
stability. (default: :obj:`1e-5`)
|
|
29
|
+
device (torch.device, optional): The device to use for the module.
|
|
30
|
+
(default: :obj:`None`)
|
|
29
31
|
"""
|
|
30
|
-
def __init__(self, in_channels: int, eps: float = 1e-5
|
|
32
|
+
def __init__(self, in_channels: int, eps: float = 1e-5,
|
|
33
|
+
device: Optional[torch.device] = None):
|
|
31
34
|
super().__init__()
|
|
32
35
|
|
|
33
36
|
self.in_channels = in_channels
|
|
34
37
|
self.eps = eps
|
|
35
38
|
|
|
36
|
-
self.weight = torch.nn.Parameter(
|
|
37
|
-
|
|
38
|
-
self.
|
|
39
|
+
self.weight = torch.nn.Parameter(
|
|
40
|
+
torch.empty(in_channels, device=device))
|
|
41
|
+
self.bias = torch.nn.Parameter(torch.empty(in_channels, device=device))
|
|
42
|
+
self.mean_scale = torch.nn.Parameter(
|
|
43
|
+
torch.empty(in_channels, device=device))
|
|
39
44
|
|
|
40
45
|
self.reset_parameters()
|
|
41
46
|
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
|
+
import torch
|
|
3
4
|
import torch.nn.functional as F
|
|
4
5
|
from torch import Tensor
|
|
5
6
|
from torch.nn.modules.instancenorm import _InstanceNorm
|
|
@@ -36,6 +37,8 @@ class InstanceNorm(_InstanceNorm):
|
|
|
36
37
|
:obj:`False`, this module does not track such statistics and always
|
|
37
38
|
uses instance statistics in both training and eval modes.
|
|
38
39
|
(default: :obj:`False`)
|
|
40
|
+
device (torch.device, optional): The device to use for the module.
|
|
41
|
+
(default: :obj:`None`)
|
|
39
42
|
"""
|
|
40
43
|
def __init__(
|
|
41
44
|
self,
|
|
@@ -44,9 +47,10 @@ class InstanceNorm(_InstanceNorm):
|
|
|
44
47
|
momentum: float = 0.1,
|
|
45
48
|
affine: bool = False,
|
|
46
49
|
track_running_stats: bool = False,
|
|
50
|
+
device: Optional[torch.device] = None,
|
|
47
51
|
):
|
|
48
52
|
super().__init__(in_channels, eps, momentum, affine,
|
|
49
|
-
track_running_stats)
|
|
53
|
+
track_running_stats, device=device)
|
|
50
54
|
|
|
51
55
|
def reset_parameters(self):
|
|
52
56
|
r"""Resets all learnable parameters of the module."""
|