pyg-nightly 2.7.0.dev20250701__py3-none-any.whl → 2.7.0.dev20250702__py3-none-any.whl
This diff compares the contents of two package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between the two releases as they appear in the registry.
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250702.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250702.dist-info}/RECORD +14 -14
- torch_geometric/__init__.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/encoding.py +12 -3
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +12 -4
- torch_geometric/nn/norm/msg_norm.py +8 -2
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250702.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250702.dist-info}/licenses/LICENSE +0 -0
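Aside from the version bump and the corresponding RECORD/METADATA updates, every source change in this release threads a new optional `device` argument through the constructors of the listed `torch_geometric.nn` modules, so parameters and buffers can be allocated directly on a target device instead of being created on CPU and moved afterwards. A minimal sketch of the resulting usage, assuming this nightly build (module and argument names are taken from the diffs below; sizes are illustrative):

    import torch
    from torch_geometric.nn import BatchNorm

    # Pick a target device; 'cpu' keeps the sketch runnable everywhere.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # With the new `device` argument, the module's parameters and running
    # statistics are created directly on `device`.
    norm = BatchNorm(in_channels=64, device=device)

    x = torch.randn(128, 64, device=device)  # 128 nodes, 64 features
    out = norm(x)
    assert out.device == x.device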
{pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250702.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250701
+Version: 2.7.0.dev20250702
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
{pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250702.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=
+torch_geometric/__init__.py,sha256=ap-t4q8f9aTE0oAW_K5390u2Mlk8-S76rdeUEgPzglo,2250
 torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -295,7 +295,7 @@ torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7
 torch_geometric/metrics/link_pred.py,sha256=dtaI39JB-WqE1B-raiElns6xySRwmkbb9izbcyt6xHI,30886
 torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
 torch_geometric/nn/data_parallel.py,sha256=YiybTWoSFyfSzlXAamZ_-y1f7B6tvDEFHOuy_AyJz9Q,4761
-torch_geometric/nn/encoding.py,sha256=
+torch_geometric/nn/encoding.py,sha256=3DCOCO-XFt-lMb97sHWGN-4KeGUFY5lVo9P00SzrCNk,3559
 torch_geometric/nn/fx.py,sha256=PDtaHJAgodh4xf8FNl4fVxPGZJDbRaq3Q9z8qb1DNNI,16063
 torch_geometric/nn/glob.py,sha256=MdHjcUlHmFmTevzwND1_x7dXXJPzIDTBJRGOrGdZ8dQ,1088
 torch_geometric/nn/inits.py,sha256=_8FqacCLPz5Ft2zB5s6dtKGTKWtfrLyCLLuv1QvyKjk,2457
@@ -325,13 +325,13 @@ torch_geometric/nn/aggr/lcm.py,sha256=TcNqEvHnWpqOc9RFFioBAssQaUhOgMpH1_ovOmgl3w
 torch_geometric/nn/aggr/lstm.py,sha256=AdLa4rDd8t_X-GADDTOzRFuifSA0tIYVGKfoOckVtUE,2214
 torch_geometric/nn/aggr/mlp.py,sha256=sHQ4vQcZ-h2aOfFIBiXpAjr2lj7zHT3_TyqQr3WUjxQ,2514
 torch_geometric/nn/aggr/multi.py,sha256=theSIaDlLjGUyAtqDvOFORRpI9gYoZMXUtypX1PV5NQ,8170
-torch_geometric/nn/aggr/patch_transformer.py,sha256=
+torch_geometric/nn/aggr/patch_transformer.py,sha256=tWWBqBIuIPJfvFhkEs-S8cdEhuU1qxHsxoLh_ZnHznw,5498
 torch_geometric/nn/aggr/quantile.py,sha256=sRnKyt4CXr9RmjoPyTl4VUvXgSCMl9PG-fhCGsSZ76c,6189
 torch_geometric/nn/aggr/scaler.py,sha256=GV6gxUFBoKYMQTGybwzoPh708OY6k6chtUYmCIbFGXk,4638
 torch_geometric/nn/aggr/set2set.py,sha256=4GdmsjbBIrap3CG2naeFNsYe5eE-fhrNQOXM1-TIxyM,2446
 torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQVbvnaYHIraiuM,4213
 torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
-torch_geometric/nn/aggr/utils.py,sha256=
+torch_geometric/nn/aggr/utils.py,sha256=SQvdc0g6p_E2j0prA14MW2ekjEDvV-g545N0Q85uc-o,8625
 torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
 torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
 torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
@@ -472,14 +472,14 @@ torch_geometric/nn/nlp/llm.py,sha256=DAv9jOZKXKQNVU2pNMyS1q8gVUtlin_unc6FjLhOYto
 torch_geometric/nn/nlp/sentence_transformer.py,sha256=q5M7SGtrUzoSiNhKCGFb7JatWiukdhNF6zdq2yiqxwE,4475
 torch_geometric/nn/nlp/vision_transformer.py,sha256=diVBefjIynzYs8WBlcpTeSVnw1PUecHY--B9Yd-W2hA,863
 torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
-torch_geometric/nn/norm/batch_norm.py,sha256=
-torch_geometric/nn/norm/diff_group_norm.py,sha256=
-torch_geometric/nn/norm/graph_norm.py,sha256=
+torch_geometric/nn/norm/batch_norm.py,sha256=fzUNmpdCUsMnNcso3PKDUdWc0UQvziK80-w0ZC0Vb8U,8706
+torch_geometric/nn/norm/diff_group_norm.py,sha256=mT0gM5a8txcAFNwZGKFu12qnNF5Pn95zrMx-RisRsh4,4938
+torch_geometric/nn/norm/graph_norm.py,sha256=VRmpi2jNYRQWXzX6Z0FmBxdEiV-EYXNPbGAGC0XNKH8,2964
 torch_geometric/nn/norm/graph_size_norm.py,sha256=sh5Nue1Ix2jC1T7o7KqOw0_TAOcpZ4VbYzhADWE97-M,1491
-torch_geometric/nn/norm/instance_norm.py,sha256=
-torch_geometric/nn/norm/layer_norm.py,sha256=
+torch_geometric/nn/norm/instance_norm.py,sha256=L8VquSF7Jh5xfxA4YEcSO3IZ4fWR7VIiSukNeRXi4z0,4870
+torch_geometric/nn/norm/layer_norm.py,sha256=XiEyoXdDta6vlInLfwbJVsEgTkBFG6PJm6SpK99-cPE,8243
 torch_geometric/nn/norm/mean_subtraction_norm.py,sha256=KVHOp413mw7obwAN09Le6XdgobtCXpi4UKpjpG1M550,1322
-torch_geometric/nn/norm/msg_norm.py,sha256=
+torch_geometric/nn/norm/msg_norm.py,sha256=NiV51ce1JgxVY6GbzktoSslDnZKWJrMJYZc_eBxz-pg,1903
 torch_geometric/nn/norm/pair_norm.py,sha256=IfHMiVYw_xsy035NakbPGdQVaVC-Ue3Oxwo651Vc47I,2824
 torch_geometric/nn/pool/__init__.py,sha256=VU9cPdLC-MPgt1kfS0ZwehfSD3g0V30VQuR1Wo0mzZE,14250
 torch_geometric/nn/pool/approx_knn.py,sha256=n7C8Cbar6o5tJcuAbzhM5hqMK26hW8dm5DopuocidO0,3967
@@ -640,7 +640,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
 torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
-pyg_nightly-2.7.0.
+pyg_nightly-2.7.0.dev20250702.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250702.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250702.dist-info/METADATA,sha256=66AyTfnfJvD0er8ePN_vOUgj6tD76JJy4QPaIvkh8bw,63005
+pyg_nightly-2.7.0.dev20250702.dist-info/RECORD,,
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250701'
+__version__ = '2.7.0.dev20250702'
 
 __all__ = [
     'Index',
torch_geometric/nn/aggr/patch_transformer.py
CHANGED
@@ -32,6 +32,8 @@ class PatchTransformerAggregation(Aggregation):
         aggr (str or list[str], optional): The aggregation module, *e.g.*,
             :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
             :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -43,6 +45,7 @@ class PatchTransformerAggregation(Aggregation):
         heads: int = 1,
         dropout: float = 0.0,
         aggr: Union[str, List[str]] = 'mean',
+        device: Optional[torch.device] = None,
     ) -> None:
         super().__init__()
 
@@ -55,12 +58,13 @@ class PatchTransformerAggregation(Aggregation):
         for aggr in self.aggrs:
             assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']
 
-        self.lin = torch.nn.Linear(in_channels, hidden_channels)
+        self.lin = torch.nn.Linear(in_channels, hidden_channels, device=device)
         self.pad_projector = torch.nn.Linear(
             patch_size * hidden_channels,
             hidden_channels,
+            device=device,
         )
-        self.pe = PositionalEncoding(hidden_channels)
+        self.pe = PositionalEncoding(hidden_channels, device=device)
 
         self.blocks = torch.nn.ModuleList([
             MultiheadAttentionBlock(
@@ -68,12 +72,14 @@ class PatchTransformerAggregation(Aggregation):
                 heads=heads,
                 layer_norm=True,
                 dropout=dropout,
+                device=device,
             ) for _ in range(num_transformer_blocks)
         ])
 
         self.fc = torch.nn.Linear(
             hidden_channels * len(self.aggrs),
             out_channels,
+            device=device,
         )
 
     def reset_parameters(self) -> None:
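A usage sketch of the constructor above with the new `device` argument. Keyword arguments are used throughout so the sketch does not depend on the exact positional order of parameters not shown in this hunk; all sizes are illustrative:

    import torch
    from torch_geometric.nn.aggr import PatchTransformerAggregation

    aggr = PatchTransformerAggregation(
        in_channels=16,
        out_channels=16,
        patch_size=4,
        hidden_channels=32,
        heads=2,
        aggr=['mean', 'max'],
        device=torch.device('cpu'),  # new in this diff
    )

    x = torch.randn(10, 16)                  # 10 elements, 16 features each
    index = torch.tensor([0] * 5 + [1] * 5)  # two groups of five elements
    out = aggr(x, index)                     # one 16-dim vector per group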
torch_geometric/nn/aggr/utils.py
CHANGED
@@ -26,9 +26,11 @@ class MultiheadAttentionBlock(torch.nn.Module):
             normalization. (default: :obj:`True`)
         dropout (float, optional): Dropout probability of attention weights.
             (default: :obj:`0`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
     def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True,
-                 dropout: float = 0.0):
+                 dropout: float = 0.0, device: Optional[torch.device] = None):
         super().__init__()
 
         self.channels = channels
@@ -40,10 +42,13 @@ class MultiheadAttentionBlock(torch.nn.Module):
             heads,
             batch_first=True,
             dropout=dropout,
+            device=device,
         )
-        self.lin = Linear(channels, channels)
-        self.layer_norm1 = LayerNorm(channels) if layer_norm else None
-        self.layer_norm2 = LayerNorm(channels) if layer_norm else None
+        self.lin = Linear(channels, channels, device=device)
+        self.layer_norm1 = LayerNorm(channels,
+                                     device=device) if layer_norm else None
+        self.layer_norm2 = LayerNorm(channels,
+                                     device=device) if layer_norm else None
 
     def reset_parameters(self):
         self.attn._reset_parameters()
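A sketch of constructing the block on a target device, assuming the usual MAB convention that the forward pass attends from `x` over `y` (both batched, since the wrapped `torch.nn.MultiheadAttention` is created with `batch_first=True`); shapes are illustrative:

    import torch
    from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock

    block = MultiheadAttentionBlock(
        channels=32,
        heads=4,
        device=torch.device('cpu'),  # new in this diff
    )

    x = torch.randn(2, 5, 32)  # [batch_size, num_queries, channels]
    y = torch.randn(2, 7, 32)  # [batch_size, num_keys, channels]
    out = block(x, y)          # -> [2, 5, 32]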
torch_geometric/nn/encoding.py
CHANGED
@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 from torch import Tensor
@@ -23,12 +24,15 @@ class PositionalEncoding(torch.nn.Module):
         granularity (float, optional): The granularity of the positions. If
             set to smaller value, the encoder will capture more fine-grained
             changes in positions. (default: :obj:`1.0`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
         out_channels: int,
         base_freq: float = 1e-4,
         granularity: float = 1.0,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -40,7 +44,8 @@ class PositionalEncoding(torch.nn.Module):
         self.base_freq = base_freq
         self.granularity = granularity
 
-        frequency = torch.logspace(0, 1, out_channels // 2, base_freq)
+        frequency = torch.logspace(0, 1, out_channels // 2, base_freq,
+                                   device=device)
         self.register_buffer('frequency', frequency)
 
         self.reset_parameters()
@@ -75,13 +80,17 @@ class TemporalEncoding(torch.nn.Module):
 
     Args:
         out_channels (int): Size :math:`d` of each output sample.
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, out_channels: int):
+    def __init__(self, out_channels: int,
+                 device: Optional[torch.device] = None):
         super().__init__()
         self.out_channels = out_channels
 
         sqrt = math.sqrt(out_channels)
-        weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels).view(1, -1)
+        weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels,
+                                            device=device).view(1, -1)
         self.register_buffer('weight', weight)
 
         self.reset_parameters()
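Both encoders now accept `device` so their fixed frequency/weight buffers are allocated in place. A short sketch (sizes illustrative):

    import torch
    from torch_geometric.nn import PositionalEncoding, TemporalEncoding

    device = torch.device('cpu')

    pos_enc = PositionalEncoding(out_channels=8, device=device)
    time_enc = TemporalEncoding(out_channels=8, device=device)

    t = torch.arange(4, dtype=torch.float, device=device)  # positions/timestamps
    print(pos_enc(t).shape)   # torch.Size([4, 8])
    print(time_enc(t).shape)  # torch.Size([4, 8])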
torch_geometric/nn/norm/batch_norm.py
CHANGED
@@ -39,6 +39,8 @@ class BatchNorm(torch.nn.Module):
             with only a single element will work as during in evaluation.
             That is the running mean and variance will be used.
             Requires :obj:`track_running_stats=True`. (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -48,6 +50,7 @@ class BatchNorm(torch.nn.Module):
         affine: bool = True,
         track_running_stats: bool = True,
         allow_single_element: bool = False,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -56,7 +59,7 @@ class BatchNorm(torch.nn.Module):
                              "'track_running_stats' to be set to `True`")
 
         self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
-                                           track_running_stats)
+                                           track_running_stats, device=device)
         self.in_channels = in_channels
         self.allow_single_element = allow_single_element
 
@@ -114,6 +117,8 @@ class HeteroBatchNorm(torch.nn.Module):
             :obj:`False`, this module does not track such statistics and always
             uses batch statistics in both training and eval modes.
             (default: :obj:`True`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -123,6 +128,7 @@ class HeteroBatchNorm(torch.nn.Module):
         momentum: Optional[float] = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -134,17 +140,21 @@ class HeteroBatchNorm(torch.nn.Module):
         self.track_running_stats = track_running_stats
 
         if self.affine:
-            self.weight = Parameter(torch.empty(num_types, in_channels))
-            self.bias = Parameter(torch.empty(num_types, in_channels))
+            self.weight = Parameter(
+                torch.empty(num_types, in_channels, device=device))
+            self.bias = Parameter(
+                torch.empty(num_types, in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
 
         if self.track_running_stats:
-            self.register_buffer('running_mean',
-                                 torch.empty(num_types, in_channels))
-            self.register_buffer('running_var',
-                                 torch.empty(num_types, in_channels))
+            self.register_buffer(
+                'running_mean',
+                torch.empty(num_types, in_channels, device=device))
+            self.register_buffer(
+                'running_var',
+                torch.empty(num_types, in_channels, device=device))
             self.register_buffer('num_batches_tracked', torch.tensor(0))
         else:
             self.register_buffer('running_mean', None)
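`BatchNorm` simply forwards `device` to the wrapped `torch.nn.BatchNorm1d`, while `HeteroBatchNorm` allocates its per-type parameters and running statistics directly on the device. A sketch of the hetero case (type count and shapes illustrative):

    import torch
    from torch_geometric.nn.norm import HeteroBatchNorm

    norm = HeteroBatchNorm(
        in_channels=16,
        num_types=3,
        device=torch.device('cpu'),  # new in this diff
    )

    x = torch.randn(100, 16)
    type_vec = torch.randint(0, 3, (100,))  # a node-type index per row
    out = norm(x, type_vec)                 # normalized with per-type statistics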
torch_geometric/nn/norm/diff_group_norm.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.nn import BatchNorm1d, Linear
@@ -39,6 +41,8 @@ class DiffGroupNorm(torch.nn.Module):
             :obj:`False`, this module does not track such statistics and always
             uses batch statistics in both training and eval modes.
             (default: :obj:`True`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -49,6 +53,7 @@ class DiffGroupNorm(torch.nn.Module):
         momentum: float = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -56,9 +61,9 @@ class DiffGroupNorm(torch.nn.Module):
         self.groups = groups
         self.lamda = lamda
 
-        self.lin = Linear(in_channels, groups, bias=False)
+        self.lin = Linear(in_channels, groups, bias=False, device=device)
         self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine,
-                                track_running_stats)
+                                track_running_stats, device=device)
 
         self.reset_parameters()
 
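Here `device` reaches both the soft-assignment `Linear` layer and the wrapped `BatchNorm1d`. A usage sketch (sizes illustrative):

    import torch
    from torch_geometric.nn import DiffGroupNorm

    norm = DiffGroupNorm(in_channels=16, groups=4, device=torch.device('cpu'))

    x = torch.randn(100, 16)
    out = norm(x)  # soft-assigns nodes to 4 groups, normalizes within each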
torch_geometric/nn/norm/graph_norm.py
CHANGED
@@ -26,16 +26,21 @@ class GraphNorm(torch.nn.Module):
         in_channels (int): Size of each input sample.
         eps (float, optional): A value added to the denominator for numerical
             stability. (default: :obj:`1e-5`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, in_channels: int, eps: float = 1e-5):
+    def __init__(self, in_channels: int, eps: float = 1e-5,
+                 device: Optional[torch.device] = None):
         super().__init__()
 
         self.in_channels = in_channels
         self.eps = eps
 
-        self.weight = torch.nn.Parameter(torch.empty(in_channels))
-        self.bias = torch.nn.Parameter(torch.empty(in_channels))
-        self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))
+        self.weight = torch.nn.Parameter(
+            torch.empty(in_channels, device=device))
+        self.bias = torch.nn.Parameter(torch.empty(in_channels, device=device))
+        self.mean_scale = torch.nn.Parameter(
+            torch.empty(in_channels, device=device))
 
         self.reset_parameters()
 
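A sketch of `GraphNorm` with the new argument; the optional `batch` vector assigns each node to its graph:

    import torch
    from torch_geometric.nn import GraphNorm

    norm = GraphNorm(in_channels=16, device=torch.device('cpu'))

    x = torch.randn(6, 16)
    batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs, three nodes each
    out = norm(x, batch)                      # statistics computed per graph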
torch_geometric/nn/norm/instance_norm.py
CHANGED
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.modules.instancenorm import _InstanceNorm
@@ -36,6 +37,8 @@ class InstanceNorm(_InstanceNorm):
             :obj:`False`, this module does not track such statistics and always
             uses instance statistics in both training and eval modes.
             (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -44,9 +47,10 @@ class InstanceNorm(_InstanceNorm):
         momentum: float = 0.1,
         affine: bool = False,
         track_running_stats: bool = False,
+        device: Optional[torch.device] = None,
     ):
         super().__init__(in_channels, eps, momentum, affine,
-                         track_running_stats)
+                         track_running_stats, device=device)
 
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
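`InstanceNorm` forwards `device` straight to the `_InstanceNorm` base class (hence the new `import torch` for the type annotation). A usage sketch:

    import torch
    from torch_geometric.nn import InstanceNorm

    norm = InstanceNorm(in_channels=16, device=torch.device('cpu'))

    x = torch.randn(6, 16)
    batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs, three nodes each
    out = norm(x, batch)                      # per-graph instance statistics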
torch_geometric/nn/norm/layer_norm.py
CHANGED
@@ -35,6 +35,8 @@ class LayerNorm(torch.nn.Module):
             is used, each graph will be considered as an element to be
             normalized. If `"node"` is used, each node will be considered as
             an element to be normalized. (default: :obj:`"graph"`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -42,6 +44,7 @@ class LayerNorm(torch.nn.Module):
         eps: float = 1e-5,
         affine: bool = True,
         mode: str = 'graph',
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -51,8 +54,8 @@ class LayerNorm(torch.nn.Module):
         self.mode = mode
 
         if affine:
-            self.weight = Parameter(torch.empty(in_channels))
-            self.bias = Parameter(torch.empty(in_channels))
+            self.weight = Parameter(torch.empty(in_channels, device=device))
+            self.bias = Parameter(torch.empty(in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
@@ -134,6 +137,8 @@ class HeteroLayerNorm(torch.nn.Module):
             normalization (:obj:`"node"`). If `"node"` is used, each node will
             be considered as an element to be normalized.
             (default: :obj:`"node"`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -142,6 +147,7 @@ class HeteroLayerNorm(torch.nn.Module):
         eps: float = 1e-5,
         affine: bool = True,
         mode: str = 'node',
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
         assert mode == 'node'
@@ -152,8 +158,10 @@ class HeteroLayerNorm(torch.nn.Module):
         self.affine = affine
 
         if affine:
-            self.weight = Parameter(torch.empty(num_types, in_channels))
-            self.bias = Parameter(torch.empty(num_types, in_channels))
+            self.weight = Parameter(
+                torch.empty(num_types, in_channels, device=device))
+            self.bias = Parameter(
+                torch.empty(num_types, in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
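A sketch of the homogeneous `LayerNorm`; `mode='graph'` normalizes across all nodes of each graph, while `mode='node'` normalizes each node individually:

    import torch
    from torch_geometric.nn import LayerNorm

    norm = LayerNorm(in_channels=16, mode='graph', device=torch.device('cpu'))

    x = torch.randn(6, 16)
    batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs, three nodes each
    out = norm(x, batch)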
torch_geometric/nn/norm/msg_norm.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor
@@ -19,10 +21,14 @@ class MessageNorm(torch.nn.Module):
         learn_scale (bool, optional): If set to :obj:`True`, will learn the
             scaling factor :math:`s` of message normalization.
             (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, learn_scale: bool = False):
+    def __init__(self, learn_scale: bool = False,
+                 device: Optional[torch.device] = None):
         super().__init__()
-        self.scale = Parameter(torch.empty(1), requires_grad=learn_scale)
+        self.scale = Parameter(torch.empty(1, device=device),
+                               requires_grad=learn_scale)
         self.reset_parameters()
 
     def reset_parameters(self):
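Finally, `MessageNorm` allocates its single scale parameter on the device. A sketch of normalizing aggregated messages against node features:

    import torch
    from torch_geometric.nn import MessageNorm

    norm = MessageNorm(learn_scale=True, device=torch.device('cpu'))

    x = torch.randn(4, 16)    # node features
    msg = torch.randn(4, 16)  # aggregated messages
    out = norm(x, msg)        # messages rescaled relative to the norm of `x`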
{pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250702.dist-info}/WHEEL
RENAMED
File without changes
{pyg_nightly-2.7.0.dev20250701.dist-info → pyg_nightly-2.7.0.dev20250702.dist-info}/licenses/LICENSE
RENAMED
File without changes