pyg-nightly 2.7.0.dev20250701__py3-none-any.whl → 2.7.0.dev20250702__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.

Note: this version of pyg-nightly has been flagged as a potentially problematic release.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pyg-nightly
- Version: 2.7.0.dev20250701
+ Version: 2.7.0.dev20250702
  Summary: Graph Neural Network Library for PyTorch
  Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
  Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
- torch_geometric/__init__.py,sha256=DhH0xww6Nw5tVjgkem8O7_mt90VkyoLFJQzTSdB69n8,2250
+ torch_geometric/__init__.py,sha256=ap-t4q8f9aTE0oAW_K5390u2Mlk8-S76rdeUEgPzglo,2250
  torch_geometric/_compile.py,sha256=9yqMTBKatZPr40WavJz9FjNi7pQj8YZAZOyZmmRGXgc,1351
  torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
  torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -295,7 +295,7 @@ torch_geometric/metrics/__init__.py,sha256=3krvDobW6vV5yHTjq2S2pmOXxNfysNG26muq7
  torch_geometric/metrics/link_pred.py,sha256=dtaI39JB-WqE1B-raiElns6xySRwmkbb9izbcyt6xHI,30886
  torch_geometric/nn/__init__.py,sha256=kQHHHUxFDht2ztD-XFQuv98TvC8MdodaFsIjAvltJBw,874
  torch_geometric/nn/data_parallel.py,sha256=YiybTWoSFyfSzlXAamZ_-y1f7B6tvDEFHOuy_AyJz9Q,4761
- torch_geometric/nn/encoding.py,sha256=QNjwWczYExZ1wRGBmpuqYbn6tB7NC4BU-DEgzjhcZqw,3115
+ torch_geometric/nn/encoding.py,sha256=3DCOCO-XFt-lMb97sHWGN-4KeGUFY5lVo9P00SzrCNk,3559
  torch_geometric/nn/fx.py,sha256=PDtaHJAgodh4xf8FNl4fVxPGZJDbRaq3Q9z8qb1DNNI,16063
  torch_geometric/nn/glob.py,sha256=MdHjcUlHmFmTevzwND1_x7dXXJPzIDTBJRGOrGdZ8dQ,1088
  torch_geometric/nn/inits.py,sha256=_8FqacCLPz5Ft2zB5s6dtKGTKWtfrLyCLLuv1QvyKjk,2457
@@ -325,13 +325,13 @@ torch_geometric/nn/aggr/lcm.py,sha256=TcNqEvHnWpqOc9RFFioBAssQaUhOgMpH1_ovOmgl3w
  torch_geometric/nn/aggr/lstm.py,sha256=AdLa4rDd8t_X-GADDTOzRFuifSA0tIYVGKfoOckVtUE,2214
  torch_geometric/nn/aggr/mlp.py,sha256=sHQ4vQcZ-h2aOfFIBiXpAjr2lj7zHT3_TyqQr3WUjxQ,2514
  torch_geometric/nn/aggr/multi.py,sha256=theSIaDlLjGUyAtqDvOFORRpI9gYoZMXUtypX1PV5NQ,8170
- torch_geometric/nn/aggr/patch_transformer.py,sha256=SP--1IaXrHWjjGgH7yIPeO84b5NAwn65zHxaTid119o,5234
+ torch_geometric/nn/aggr/patch_transformer.py,sha256=tWWBqBIuIPJfvFhkEs-S8cdEhuU1qxHsxoLh_ZnHznw,5498
  torch_geometric/nn/aggr/quantile.py,sha256=sRnKyt4CXr9RmjoPyTl4VUvXgSCMl9PG-fhCGsSZ76c,6189
  torch_geometric/nn/aggr/scaler.py,sha256=GV6gxUFBoKYMQTGybwzoPh708OY6k6chtUYmCIbFGXk,4638
  torch_geometric/nn/aggr/set2set.py,sha256=4GdmsjbBIrap3CG2naeFNsYe5eE-fhrNQOXM1-TIxyM,2446
  torch_geometric/nn/aggr/set_transformer.py,sha256=FG7_JizpFX14M6VSCwLSjYXYdJ1ZiQVbvnaYHIraiuM,4213
  torch_geometric/nn/aggr/sort.py,sha256=bvOOWnFkNOBOZih4rqVZQsjfeDX3vmXo1bpPSFD846w,2507
- torch_geometric/nn/aggr/utils.py,sha256=CLJ-ZrVWYIOBpdhQBLAz94dj3cMKKKc3qwGr4DFbiCU,8338
+ torch_geometric/nn/aggr/utils.py,sha256=SQvdc0g6p_E2j0prA14MW2ekjEDvV-g545N0Q85uc-o,8625
  torch_geometric/nn/aggr/variance_preserving.py,sha256=fu-U_aGYpVLpgSFvVg0ONMe6nqoyv8tZ6Y35qMYTf9w,1126
  torch_geometric/nn/attention/__init__.py,sha256=wLKTmlfP7qL9sZHy4cmDFHEtdwa-MEKE1dT51L1_w10,192
  torch_geometric/nn/attention/performer.py,sha256=2PCDn4_-oNTao2-DkXIaoi18anP01OxRELF2pvp-jk8,7357
@@ -472,14 +472,14 @@ torch_geometric/nn/nlp/llm.py,sha256=DAv9jOZKXKQNVU2pNMyS1q8gVUtlin_unc6FjLhOYto
  torch_geometric/nn/nlp/sentence_transformer.py,sha256=q5M7SGtrUzoSiNhKCGFb7JatWiukdhNF6zdq2yiqxwE,4475
  torch_geometric/nn/nlp/vision_transformer.py,sha256=diVBefjIynzYs8WBlcpTeSVnw1PUecHY--B9Yd-W2hA,863
  torch_geometric/nn/norm/__init__.py,sha256=u2qIDrkbeuObGVXSAIftAlvSd6ouGTtxznCfD-59UiA,669
- torch_geometric/nn/norm/batch_norm.py,sha256=sJKrinHGwA-noIgteg1RD2W06rd0zskD-rXuY-36glY,8283
- torch_geometric/nn/norm/diff_group_norm.py,sha256=b57XvNekrUYGDjNJlGeqvaMGNJmHwopSF0_yyBWlLuA,4722
- torch_geometric/nn/norm/graph_norm.py,sha256=Tld_9_dzst4yEw58DZo4U--4QryA6pP2bsNfmqEDgrY,2727
+ torch_geometric/nn/norm/batch_norm.py,sha256=fzUNmpdCUsMnNcso3PKDUdWc0UQvziK80-w0ZC0Vb8U,8706
+ torch_geometric/nn/norm/diff_group_norm.py,sha256=mT0gM5a8txcAFNwZGKFu12qnNF5Pn95zrMx-RisRsh4,4938
+ torch_geometric/nn/norm/graph_norm.py,sha256=VRmpi2jNYRQWXzX6Z0FmBxdEiV-EYXNPbGAGC0XNKH8,2964
  torch_geometric/nn/norm/graph_size_norm.py,sha256=sh5Nue1Ix2jC1T7o7KqOw0_TAOcpZ4VbYzhADWE97-M,1491
- torch_geometric/nn/norm/instance_norm.py,sha256=lUCZccuQNY8gfYUz-YRrNeSVckYuIJSFaW_m2HMp3iY,4685
- torch_geometric/nn/norm/layer_norm.py,sha256=m7a7Uoyx0zZhAxJ6Kj5N5DOg4zQnojO1FA1i661wW80,7835
+ torch_geometric/nn/norm/instance_norm.py,sha256=L8VquSF7Jh5xfxA4YEcSO3IZ4fWR7VIiSukNeRXi4z0,4870
+ torch_geometric/nn/norm/layer_norm.py,sha256=XiEyoXdDta6vlInLfwbJVsEgTkBFG6PJm6SpK99-cPE,8243
  torch_geometric/nn/norm/mean_subtraction_norm.py,sha256=KVHOp413mw7obwAN09Le6XdgobtCXpi4UKpjpG1M550,1322
- torch_geometric/nn/norm/msg_norm.py,sha256=zaQtqhs55LU-e6KPC4ylaSdge4KvEoseqOt7pmAzi2s,1662
+ torch_geometric/nn/norm/msg_norm.py,sha256=NiV51ce1JgxVY6GbzktoSslDnZKWJrMJYZc_eBxz-pg,1903
  torch_geometric/nn/norm/pair_norm.py,sha256=IfHMiVYw_xsy035NakbPGdQVaVC-Ue3Oxwo651Vc47I,2824
  torch_geometric/nn/pool/__init__.py,sha256=VU9cPdLC-MPgt1kfS0ZwehfSD3g0V30VQuR1Wo0mzZE,14250
  torch_geometric/nn/pool/approx_knn.py,sha256=n7C8Cbar6o5tJcuAbzhM5hqMK26hW8dm5DopuocidO0,3967
@@ -640,7 +640,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
  torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
  torch_geometric/visualization/graph.py,sha256=mfZHXYfiU-CWMtfawYc80IxVwVmtK9hbIkSKhM_j7oI,14311
  torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
- pyg_nightly-2.7.0.dev20250701.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
- pyg_nightly-2.7.0.dev20250701.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
- pyg_nightly-2.7.0.dev20250701.dist-info/METADATA,sha256=a6gNTxOjyu0eVS2qc87xRDkE8TutOw6BvtZGm15nUTI,63005
- pyg_nightly-2.7.0.dev20250701.dist-info/RECORD,,
+ pyg_nightly-2.7.0.dev20250702.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+ pyg_nightly-2.7.0.dev20250702.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+ pyg_nightly-2.7.0.dev20250702.dist-info/METADATA,sha256=66AyTfnfJvD0er8ePN_vOUgj6tD76JJy4QPaIvkh8bw,63005
+ pyg_nightly-2.7.0.dev20250702.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

- __version__ = '2.7.0.dev20250701'
+ __version__ = '2.7.0.dev20250702'

  __all__ = [
      'Index',
@@ -32,6 +32,8 @@ class PatchTransformerAggregation(Aggregation):
          aggr (str or list[str], optional): The aggregation module, *e.g.*,
              :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
              :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
+         device (torch.device, optional): The device of the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
@@ -43,6 +45,7 @@ class PatchTransformerAggregation(Aggregation):
          heads: int = 1,
          dropout: float = 0.0,
          aggr: Union[str, List[str]] = 'mean',
+         device: Optional[torch.device] = None,
      ) -> None:
          super().__init__()

@@ -55,12 +58,13 @@ class PatchTransformerAggregation(Aggregation):
          for aggr in self.aggrs:
              assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']

-         self.lin = torch.nn.Linear(in_channels, hidden_channels)
+         self.lin = torch.nn.Linear(in_channels, hidden_channels, device=device)
          self.pad_projector = torch.nn.Linear(
              patch_size * hidden_channels,
              hidden_channels,
+             device=device,
          )
-         self.pe = PositionalEncoding(hidden_channels)
+         self.pe = PositionalEncoding(hidden_channels, device=device)

          self.blocks = torch.nn.ModuleList([
              MultiheadAttentionBlock(
@@ -68,12 +72,14 @@ class PatchTransformerAggregation(Aggregation):
                  heads=heads,
                  layer_norm=True,
                  dropout=dropout,
+                 device=device,
              ) for _ in range(num_transformer_blocks)
          ])

          self.fc = torch.nn.Linear(
              hidden_channels * len(self.aggrs),
              out_channels,
+             device=device,
          )

      def reset_parameters(self) -> None:
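The hunks above add a `device` keyword to PatchTransformerAggregation that is forwarded to its linear layers, positional encoding, and attention blocks, so the module can be built directly on the target device instead of being moved there afterwards. A minimal usage sketch; the channel sizes are illustrative, and argument names other than `device` follow the identifiers visible in the diff body rather than a verified signature:

    import torch
    from torch_geometric.nn.aggr import PatchTransformerAggregation

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # All submodules are created on `device`; no follow-up .to(device) needed.
    aggr = PatchTransformerAggregation(
        in_channels=64,            # illustrative sizes, not taken from the diff
        hidden_channels=64,
        out_channels=64,
        patch_size=8,
        num_transformer_blocks=1,
        heads=2,
        dropout=0.1,
        aggr=['mean', 'max'],
        device=device,
    )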
@@ -26,9 +26,11 @@ class MultiheadAttentionBlock(torch.nn.Module):
              normalization. (default: :obj:`True`)
          dropout (float, optional): Dropout probability of attention weights.
              (default: :obj:`0`)
+         device (torch.device, optional): The device of the module.
+             (default: :obj:`None`)
      """
      def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True,
-                  dropout: float = 0.0):
+                  dropout: float = 0.0, device: Optional[torch.device] = None):
          super().__init__()

          self.channels = channels
@@ -40,10 +42,13 @@ class MultiheadAttentionBlock(torch.nn.Module):
              heads,
              batch_first=True,
              dropout=dropout,
+             device=device,
          )
-         self.lin = Linear(channels, channels)
-         self.layer_norm1 = LayerNorm(channels) if layer_norm else None
-         self.layer_norm2 = LayerNorm(channels) if layer_norm else None
+         self.lin = Linear(channels, channels, device=device)
+         self.layer_norm1 = LayerNorm(channels,
+                                      device=device) if layer_norm else None
+         self.layer_norm2 = LayerNorm(channels,
+                                      device=device) if layer_norm else None

      def reset_parameters(self):
          self.attn._reset_parameters()
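The full updated signature of MultiheadAttentionBlock is visible in the hunk above. A small sketch of constructing the block directly on a GPU; the forward call follows the block's existing attention-of-x-over-y interface, and the shapes are illustrative:

    import torch
    from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Attention, the output projection, and both layer norms live on `device`.
    block = MultiheadAttentionBlock(channels=64, heads=4, layer_norm=True,
                                    dropout=0.1, device=device)

    x = torch.randn(2, 10, 64, device=device)  # (batch, set size, channels)
    y = torch.randn(2, 12, 64, device=device)
    out = block(x, y)                          # -> shape (2, 10, 64)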
@@ -1,4 +1,5 @@
  import math
+ from typing import Optional

  import torch
  from torch import Tensor
@@ -23,12 +24,15 @@ class PositionalEncoding(torch.nn.Module):
          granularity (float, optional): The granularity of the positions. If
              set to smaller value, the encoder will capture more fine-grained
              changes in positions. (default: :obj:`1.0`)
+         device (torch.device, optional): The device of the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
          out_channels: int,
          base_freq: float = 1e-4,
          granularity: float = 1.0,
+         device: Optional[torch.device] = None,
      ):
          super().__init__()

@@ -40,7 +44,8 @@
          self.base_freq = base_freq
          self.granularity = granularity

-         frequency = torch.logspace(0, 1, out_channels // 2, base_freq)
+         frequency = torch.logspace(0, 1, out_channels // 2, base_freq,
+                                    device=device)
          self.register_buffer('frequency', frequency)

          self.reset_parameters()
@@ -75,13 +80,17 @@ class TemporalEncoding(torch.nn.Module):

      Args:
          out_channels (int): Size :math:`d` of each output sample.
+         device (torch.device, optional): The device of the module.
+             (default: :obj:`None`)
      """
-     def __init__(self, out_channels: int):
+     def __init__(self, out_channels: int,
+                  device: Optional[torch.device] = None):
          super().__init__()
          self.out_channels = out_channels

          sqrt = math.sqrt(out_channels)
-         weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels).view(1, -1)
+         weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels,
+                                             device=device).view(1, -1)
          self.register_buffer('weight', weight)

          self.reset_parameters()
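Both encoders now allocate their frequency and weight buffers on the requested device at construction time. A minimal sketch with illustrative sizes, assuming the usual top-level exports in torch_geometric.nn:

    import torch
    from torch_geometric.nn import PositionalEncoding, TemporalEncoding

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pos_enc = PositionalEncoding(out_channels=64, device=device)
    tmp_enc = TemporalEncoding(out_channels=64, device=device)

    t = torch.arange(100, dtype=torch.float, device=device)  # positions / timestamps
    pos_enc(t).shape  # torch.Size([100, 64])
    tmp_enc(t).shape  # torch.Size([100, 64])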
@@ -39,6 +39,8 @@ class BatchNorm(torch.nn.Module):
              with only a single element will work as during in evaluation.
              That is the running mean and variance will be used.
              Requires :obj:`track_running_stats=True`. (default: :obj:`False`)
+         device (torch.device, optional): The device to use for the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
@@ -48,6 +50,7 @@
          affine: bool = True,
          track_running_stats: bool = True,
          allow_single_element: bool = False,
+         device: Optional[torch.device] = None,
      ):
          super().__init__()

@@ -56,7 +59,7 @@
                               "'track_running_stats' to be set to `True`")

          self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
-                                            track_running_stats)
+                                            track_running_stats, device=device)
          self.in_channels = in_channels
          self.allow_single_element = allow_single_element

@@ -114,6 +117,8 @@ class HeteroBatchNorm(torch.nn.Module):
              :obj:`False`, this module does not track such statistics and always
              uses batch statistics in both training and eval modes.
              (default: :obj:`True`)
+         device (torch.device, optional): The device to use for the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
@@ -123,6 +128,7 @@
          momentum: Optional[float] = 0.1,
          affine: bool = True,
          track_running_stats: bool = True,
+         device: Optional[torch.device] = None,
      ):
          super().__init__()

@@ -134,17 +140,21 @@
          self.track_running_stats = track_running_stats

          if self.affine:
-             self.weight = Parameter(torch.empty(num_types, in_channels))
-             self.bias = Parameter(torch.empty(num_types, in_channels))
+             self.weight = Parameter(
+                 torch.empty(num_types, in_channels, device=device))
+             self.bias = Parameter(
+                 torch.empty(num_types, in_channels, device=device))
          else:
              self.register_parameter('weight', None)
              self.register_parameter('bias', None)

          if self.track_running_stats:
-             self.register_buffer('running_mean',
-                                  torch.empty(num_types, in_channels))
-             self.register_buffer('running_var',
-                                  torch.empty(num_types, in_channels))
+             self.register_buffer(
+                 'running_mean',
+                 torch.empty(num_types, in_channels, device=device))
+             self.register_buffer(
+                 'running_var',
+                 torch.empty(num_types, in_channels, device=device))
              self.register_buffer('num_batches_tracked', torch.tensor(0))
          else:
              self.register_buffer('running_mean', None)
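With these changes, both normalization layers can be built directly on the target device: the wrapped torch.nn.BatchNorm1d and the per-type parameters and running statistics are allocated there. A minimal sketch with illustrative sizes; the forward calls follow the modules' existing interfaces:

    import torch
    from torch_geometric.nn import BatchNorm, HeteroBatchNorm

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    norm = BatchNorm(in_channels=64, device=device)
    hetero_norm = HeteroBatchNorm(in_channels=64, num_types=3, device=device)

    x = torch.randn(32, 64, device=device)                   # node features
    type_vec = torch.randint(0, 3, (32, ), device=device)    # node type per row
    out = norm(x)
    out_hetero = hetero_norm(x, type_vec)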
@@ -1,3 +1,5 @@
+ from typing import Optional
+
  import torch
  from torch import Tensor
  from torch.nn import BatchNorm1d, Linear
@@ -39,6 +41,8 @@ class DiffGroupNorm(torch.nn.Module):
              :obj:`False`, this module does not track such statistics and always
              uses batch statistics in both training and eval modes.
              (default: :obj:`True`)
+         device (torch.device, optional): The device to use for the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
@@ -49,6 +53,7 @@
          momentum: float = 0.1,
          affine: bool = True,
          track_running_stats: bool = True,
+         device: Optional[torch.device] = None,
      ):
          super().__init__()

@@ -56,9 +61,9 @@
          self.groups = groups
          self.lamda = lamda

-         self.lin = Linear(in_channels, groups, bias=False)
+         self.lin = Linear(in_channels, groups, bias=False, device=device)
          self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine,
-                                 track_running_stats)
+                                 track_running_stats, device=device)

          self.reset_parameters()

@@ -26,16 +26,21 @@ class GraphNorm(torch.nn.Module):
          in_channels (int): Size of each input sample.
          eps (float, optional): A value added to the denominator for numerical
              stability. (default: :obj:`1e-5`)
+         device (torch.device, optional): The device to use for the module.
+             (default: :obj:`None`)
      """
-     def __init__(self, in_channels: int, eps: float = 1e-5):
+     def __init__(self, in_channels: int, eps: float = 1e-5,
+                  device: Optional[torch.device] = None):
          super().__init__()

          self.in_channels = in_channels
          self.eps = eps

-         self.weight = torch.nn.Parameter(torch.empty(in_channels))
-         self.bias = torch.nn.Parameter(torch.empty(in_channels))
-         self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))
+         self.weight = torch.nn.Parameter(
+             torch.empty(in_channels, device=device))
+         self.bias = torch.nn.Parameter(torch.empty(in_channels, device=device))
+         self.mean_scale = torch.nn.Parameter(
+             torch.empty(in_channels, device=device))

          self.reset_parameters()

@@ -1,5 +1,6 @@
  from typing import Optional

+ import torch
  import torch.nn.functional as F
  from torch import Tensor
  from torch.nn.modules.instancenorm import _InstanceNorm
@@ -36,6 +37,8 @@ class InstanceNorm(_InstanceNorm):
              :obj:`False`, this module does not track such statistics and always
              uses instance statistics in both training and eval modes.
              (default: :obj:`False`)
+         device (torch.device, optional): The device to use for the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
@@ -44,9 +47,10 @@
          momentum: float = 0.1,
          affine: bool = False,
          track_running_stats: bool = False,
+         device: Optional[torch.device] = None,
      ):
          super().__init__(in_channels, eps, momentum, affine,
-                          track_running_stats)
+                          track_running_stats, device=device)

      def reset_parameters(self):
          r"""Resets all learnable parameters of the module."""
@@ -35,6 +35,8 @@ class LayerNorm(torch.nn.Module):
              is used, each graph will be considered as an element to be
              normalized. If `"node"` is used, each node will be considered as
              an element to be normalized. (default: :obj:`"graph"`)
+         device (torch.device, optional): The device to use for the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
@@ -42,6 +44,7 @@
          eps: float = 1e-5,
          affine: bool = True,
          mode: str = 'graph',
+         device: Optional[torch.device] = None,
      ):
          super().__init__()

@@ -51,8 +54,8 @@
          self.mode = mode

          if affine:
-             self.weight = Parameter(torch.empty(in_channels))
-             self.bias = Parameter(torch.empty(in_channels))
+             self.weight = Parameter(torch.empty(in_channels, device=device))
+             self.bias = Parameter(torch.empty(in_channels, device=device))
          else:
              self.register_parameter('weight', None)
              self.register_parameter('bias', None)
@@ -134,6 +137,8 @@ class HeteroLayerNorm(torch.nn.Module):
              normalization (:obj:`"node"`). If `"node"` is used, each node will
              be considered as an element to be normalized.
              (default: :obj:`"node"`)
+         device (torch.device, optional): The device to use for the module.
+             (default: :obj:`None`)
      """
      def __init__(
          self,
@@ -142,6 +147,7 @@
          eps: float = 1e-5,
          affine: bool = True,
          mode: str = 'node',
+         device: Optional[torch.device] = None,
      ):
          super().__init__()
          assert mode == 'node'
@@ -152,8 +158,10 @@
          self.affine = affine

          if affine:
-             self.weight = Parameter(torch.empty(num_types, in_channels))
-             self.bias = Parameter(torch.empty(num_types, in_channels))
+             self.weight = Parameter(
+                 torch.empty(num_types, in_channels, device=device))
+             self.bias = Parameter(
+                 torch.empty(num_types, in_channels, device=device))
          else:
              self.register_parameter('weight', None)
              self.register_parameter('bias', None)
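As with the other norm layers, the affine weight and bias tensors are now created on the requested device. A minimal sketch with illustrative sizes; the forward calls follow the existing interfaces:

    import torch
    from torch_geometric.nn import HeteroLayerNorm, LayerNorm

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ln = LayerNorm(in_channels=64, mode='graph', device=device)
    hln = HeteroLayerNorm(in_channels=64, num_types=3, device=device)

    x = torch.randn(32, 64, device=device)
    batch = torch.zeros(32, dtype=torch.long, device=device)
    type_vec = torch.randint(0, 3, (32, ), device=device)
    out = ln(x, batch)
    out_hetero = hln(x, type_vec)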
@@ -1,3 +1,5 @@
+ from typing import Optional
+
  import torch
  import torch.nn.functional as F
  from torch import Tensor
@@ -19,10 +21,14 @@ class MessageNorm(torch.nn.Module):
          learn_scale (bool, optional): If set to :obj:`True`, will learn the
              scaling factor :math:`s` of message normalization.
              (default: :obj:`False`)
+         device (torch.device, optional): The device to use for the module.
+             (default: :obj:`None`)
      """
-     def __init__(self, learn_scale: bool = False):
+     def __init__(self, learn_scale: bool = False,
+                  device: Optional[torch.device] = None):
          super().__init__()
-         self.scale = Parameter(torch.empty(1), requires_grad=learn_scale)
+         self.scale = Parameter(torch.empty(1, device=device),
+                                requires_grad=learn_scale)
          self.reset_parameters()

      def reset_parameters(self):
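Finally, MessageNorm's (optionally learnable) scale parameter is likewise allocated on the requested device. A minimal sketch; the forward call follows the module's existing (x, msg) interface, and the sizes are illustrative:

    import torch
    from torch_geometric.nn import MessageNorm

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    msg_norm = MessageNorm(learn_scale=True, device=device)

    x = torch.randn(32, 64, device=device)    # node features
    msg = torch.randn(32, 64, device=device)  # aggregated messages
    out = msg_norm(x, msg)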