pyg-nightly 2.7.0.dev20250701__py3-none-any.whl → 2.7.0.dev20250703__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic; see the registry's release page for more details.

@@ -0,0 +1,304 @@
1
+ from itertools import product
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from torch_geometric.nn import knn_graph
8
+ from torch_geometric.nn.conv import MessagePassing
9
+ from torch_geometric.utils import to_dense_adj, to_dense_batch
10
+
11
+
12
class PositionWiseFeedForward(torch.nn.Module):
    r"""Two-layer feed-forward network applied independently to each row.

    Expands the features from :obj:`in_channels` to :obj:`hidden_channels`,
    applies a GELU non-linearity and projects back to :obj:`in_channels`.

    Args:
        in_channels (int): Size of the input (and output) features.
        hidden_channels (int): Size of the intermediate hidden layer.
    """
    def __init__(self, in_channels: int, hidden_channels: int) -> None:
        super().__init__()
        expand = torch.nn.Linear(in_channels, hidden_channels)
        contract = torch.nn.Linear(hidden_channels, in_channels)
        # Kept as a single `Sequential` named `out` so parameter names
        # (`out.0.*`, `out.2.*`) stay stable in saved state dicts.
        self.out = torch.nn.Sequential(expand, torch.nn.GELU(), contract)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.out(x)
        return y
23
+
24
+
25
class PositionalEncoding(torch.nn.Module):
    r"""Learned embedding of clipped relative residue offsets.

    Offsets are shifted by :obj:`max_relative_feature` and clipped to
    :obj:`[0, 2 * max_relative_feature]`. Entries whose :obj:`mask` is zero
    (in :class:`ProteinMPNN`: residue pairs on different chains) are mapped
    to the extra bucket :obj:`2 * max_relative_feature + 1`.

    Args:
        hidden_channels (int): Size of each embedding vector.
        max_relative_feature (int, optional): Largest absolute offset that
            gets its own bucket. (default: :obj:`32`)
    """
    def __init__(self, hidden_channels: int,
                 max_relative_feature: int = 32) -> None:
        super().__init__()
        self.max_relative_feature = max_relative_feature
        # 2 * max + 1 offset buckets plus one bucket for masked-out entries.
        self.emb = torch.nn.Embedding(2 * max_relative_feature + 2,
                                      hidden_channels)

    def forward(self, offset: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        r"""Embeds relative offsets.

        Args:
            offset (torch.Tensor): Relative offsets (any shape).
            mask (torch.Tensor): Tensor broadcastable with :obj:`offset`.
                Non-zero / :obj:`True` entries use the clipped offset; zero /
                :obj:`False` entries use the dedicated "masked" bucket.
                Both 0/1 integer masks and boolean masks are accepted.

        Returns:
            torch.Tensor: Embeddings with a trailing
            :obj:`hidden_channels` dimension.
        """
        max_rel = self.max_relative_feature
        clipped = torch.clip(offset + max_rel, 0, 2 * max_rel)
        # `torch.where` rather than `a * mask + b * (1 - mask)`: identical for
        # 0/1 masks, but also works for boolean masks (no `1 - bool_tensor`).
        masked_bucket = torch.full_like(clipped, 2 * max_rel + 1)
        d = torch.where(mask.bool(), clipped, masked_bucket)
        return self.emb(d.long())
38
+
39
+
40
class Encoder(MessagePassing):
    r"""ProteinMPNN encoder layer that updates node and edge features.

    Node update: messages :math:`\textrm{MLP}_v([x_i, x_j, e_{ij}])` are
    aggregated per target node (the :class:`MessagePassing` default
    aggregation), divided by :obj:`scale` and added residually, followed by
    a residual position-wise feed-forward block.

    Edge update: :math:`e \leftarrow \textrm{LN}(e + \textrm{MLP}_e(
    [x_{src}, x_{dst}, e]))`, computed from the *updated* node features.

    Args:
        in_channels (int): Size of the concatenated MLP input, i.e. two node
            feature vectors plus one edge feature vector
            (:obj:`3 * hidden_channels` in :class:`ProteinMPNN`).
        hidden_channels (int): Size of node and edge features.
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.1`)
        scale (float, optional): Divisor applied to aggregated messages.
            (default: :obj:`30`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        dropout: float = 0.1,
        scale: float = 30,
    ) -> None:
        super().__init__()
        # Node-message MLP, used in `message`.
        self.out_v = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
        )
        # Edge-update MLP, used at the end of `forward`.
        self.out_e = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
        )
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(hidden_channels)
        self.norm2 = torch.nn.LayerNorm(hidden_channels)
        self.norm3 = torch.nn.LayerNorm(hidden_channels)
        self.scale = scale
        self.dense = PositionWiseFeedForward(hidden_channels,
                                             hidden_channels * 4)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Returns the updated :obj:`(x, edge_attr)` pair."""
        # x: [N, d_v]
        # edge_index: [2, E]
        # edge_attr: [E, d_e]

        # Node update: scaled message aggregation, then feed-forward; each
        # step is residual + layer norm.
        h_message = self.propagate(x=x, edge_index=edge_index,
                                   edge_attr=edge_attr)
        dh = h_message / self.scale
        x = self.norm1(x + self.dropout1(dh))
        dh = self.dense(x)
        x = self.norm2(x + self.dropout2(dh))

        # Edge update from the updated node features.
        src, dst = edge_index
        h_e = torch.cat([x[src], x[dst], edge_attr], dim=-1)
        h_e = self.out_e(h_e)
        edge_attr = self.norm3(edge_attr + self.dropout3(h_e))
        return x, edge_attr

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        h = torch.cat([x_i, x_j, edge_attr], dim=-1)  # [E, 2*d_v + d_e]
        # Node messages go through the node MLP `out_v`; `out_e` is reserved
        # for the edge update in `forward`.
        return self.out_v(h)  # [E, d_v]
102
+
103
+
104
class Decoder(MessagePassing):
    r"""ProteinMPNN decoder layer that updates node features only.

    Each message is :math:`\textrm{MLP}_v([x_i, x_j, e_{ij}, m_{ij} s_j])`,
    where :math:`s_j` is the label embedding of the source node and
    :math:`m_{ij}` is the per-edge :obj:`mask` (expected to be 0/1: 1 when
    the source label may be seen). Aggregated messages are divided by
    :obj:`scale` and added residually, followed by a residual position-wise
    feed-forward block.

    Args:
        in_channels (int): Size of the concatenated MLP input
            (:obj:`4 * hidden_channels` in :class:`ProteinMPNN`).
        hidden_channels (int): Size of node features.
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.1`)
        scale (float, optional): Divisor applied to aggregated messages.
            (default: :obj:`30`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        dropout: float = 0.1,
        scale: float = 30,
    ) -> None:
        super().__init__()
        self.out_v = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
        )
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(hidden_channels)
        self.norm2 = torch.nn.LayerNorm(hidden_channels)
        self.scale = scale
        self.dense = PositionWiseFeedForward(hidden_channels,
                                             hidden_channels * 4)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        x_label: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        r"""Returns updated node features.

        :obj:`mask` holds one value per edge (:obj:`[E, 1]` as built by
        :class:`ProteinMPNN`).
        """
        # x: [N, d_v]
        # edge_index: [2, E]
        # edge_attr: [E, d_e]
        h_message = self.propagate(x=x, x_label=x_label, edge_index=edge_index,
                                   edge_attr=edge_attr, mask=mask)
        dh = h_message / self.scale
        x = self.norm1(x + self.dropout1(dh))
        dh = self.dense(x)
        x = self.norm2(x + self.dropout2(dh))
        return x

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
                x_label_j: torch.Tensor, edge_attr: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        # Hide the source label where `mask` is 0. For a 0/1 mask this equals
        # blending `[x_j, e, s_j]` with `[x_j, e, 0]`, but builds a single
        # tensor instead of two concatenations plus a blend.
        h = torch.cat([x_i, x_j, edge_attr, x_label_j * mask], dim=-1)
        return self.out_v(h)
156
+
157
+
158
class ProteinMPNN(torch.nn.Module):
    r"""ProteinMPNN-style network predicting residue classes from backbone
    coordinates with an encoder/autoregressive-decoder message passing
    scheme.

    Args:
        hidden_dim (int, optional): Size of node and edge features.
            (default: :obj:`128`)
        num_encoder_layers (int, optional): Number of :class:`Encoder`
            layers. (default: :obj:`3`)
        num_decoder_layers (int, optional): Number of :class:`Decoder`
            layers. (default: :obj:`3`)
        num_neighbors (int, optional): Number of nearest neighbors per
            residue by :math:`C_\alpha` distance, the residue itself
            included. (default: :obj:`30`)
        num_rbf (int, optional): Number of radial basis functions per
            atom-pair distance. (default: :obj:`16`)
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.1`)
        augment_eps (float, optional): Standard deviation of the Gaussian
            noise added to coordinates in training mode.
            (default: :obj:`0.2`)
        num_positional_embedding (int, optional): Size of the relative
            position embedding. (default: :obj:`16`)
        vocab_size (int, optional): Number of residue classes.
            (default: :obj:`21`)
    """
    def __init__(
        self,
        hidden_dim: int = 128,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        num_neighbors: int = 30,
        num_rbf: int = 16,
        dropout: float = 0.1,
        augment_eps: float = 0.2,
        num_positional_embedding: int = 16,
        vocab_size: int = 21,
    ) -> None:
        super().__init__()
        self.augment_eps = augment_eps
        self.hidden_dim = hidden_dim
        self.num_neighbors = num_neighbors
        self.num_rbf = num_rbf
        self.embedding = PositionalEncoding(num_positional_embedding)
        # 5 atoms (N, Ca, C, O, virtual Cb) give 5 * 5 = 25 atom pairs per
        # edge, each expanded into `num_rbf` radial basis features.
        num_rbf_features = 25 * num_rbf
        self.edge_mlp = torch.nn.Sequential(
            torch.nn.Linear(num_positional_embedding + num_rbf_features,
                            hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )
        self.label_embedding = torch.nn.Embedding(vocab_size, hidden_dim)
        self.encoder_layers = torch.nn.ModuleList([
            Encoder(hidden_dim * 3, hidden_dim, dropout)
            for _ in range(num_encoder_layers)
        ])

        self.decoder_layers = torch.nn.ModuleList([
            Decoder(hidden_dim * 4, hidden_dim, dropout)
            for _ in range(num_decoder_layers)
        ])
        self.output = torch.nn.Linear(hidden_dim, vocab_size)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Applies Xavier-uniform init to every parameter with
        :obj:`dim() > 1`; 1-D parameters keep their default init."""
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

    def _featurize(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        batch: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Builds the k-NN residue graph and its RBF distance features.

        Only residues with a non-zero :obj:`mask` enter the graph; the
        returned :obj:`edge_index` indexes into the full residue set. The
        second output has :obj:`25 * num_rbf` features per edge.
        """
        N, Ca, C, O = (x[:, i, :] for i in range(4))  # noqa: E741
        # Virtual C-beta position derived from the backbone atoms.
        b = Ca - N
        c = C - Ca
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

        valid_mask = mask.bool()
        edge_index = knn_graph(Ca[valid_mask], k=self.num_neighbors,
                               batch=batch[valid_mask], loop=True)
        # Map compacted (valid-only) node ids back to original residue ids.
        original_indices = torch.arange(Ca.size(0),
                                        device=x.device)[valid_mask]
        edge_index = original_indices[edge_index]
        row, col = edge_index

        rbf_all = []
        for A, B in product([N, Ca, C, O, Cb], repeat=2):
            distances = torch.sqrt(torch.sum((A[row] - B[col])**2, 1) + 1e-6)
            rbf_all.append(self._rbf(distances))

        return edge_index, torch.cat(rbf_all, dim=-1)

    def _rbf(self, D: torch.Tensor) -> torch.Tensor:
        r"""Expands distances into :obj:`num_rbf` Gaussian features with
        centers evenly spaced in :obj:`[2, 22]`."""
        D_min, D_max, D_count = 2., 22., self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
        D_mu = D_mu.view([1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
        return RBF

    @staticmethod
    def _decoding_mask(
        design_mask: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        r"""Samples a decoding order and returns the per-edge autoregressive
        mask of shape :obj:`[E, 1]`.

        Residues with :obj:`design_mask == 0` (visible) get near-zero sort
        keys and so, with overwhelming probability, precede the residues to
        design, which follow in random order. Entry :obj:`e` is 1 iff the
        source of edge :obj:`e` is decoded before its target, so a target
        only sees labels of already decoded neighbors.
        """
        dense_mask, _ = to_dense_batch(design_mask, batch)  # [B, N_max]
        decoding_order = torch.argsort(
            (dense_mask + 1e-4) *
            torch.abs(torch.randn(dense_mask.shape, device=device)))
        mask_size = dense_mask.size(1)
        permutation_matrix_reverse = F.one_hot(decoding_order,
                                               num_classes=mask_size).float()
        # order_mask_backward[b, q, p] == 1 iff residue p precedes q.
        order_mask_backward = torch.einsum(
            'ij, biq, bjp->bqp',
            1 - torch.triu(torch.ones(mask_size, mask_size, device=device)),
            permutation_matrix_reverse,
            permutation_matrix_reverse,
        )
        # Position of each node inside its own graph; this is the layout
        # used by `to_dense_batch` (which assumes `batch` is sorted).
        num_nodes = torch.bincount(batch)
        ptr = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(0)[:-1]])
        local_idx = torch.arange(batch.numel(), device=device) - ptr[batch]
        # Gather per edge (order-independent): messages flow from
        # edge_index[0] (source) to edge_index[1] (target), so look up
        # [graph, target, source].
        src, dst = edge_index
        mask_attend = order_mask_backward[batch[dst], local_idx[dst],
                                          local_idx[src]]
        return mask_attend.unsqueeze(-1)

    def forward(
        self,
        x: torch.Tensor,
        chain_seq_label: torch.Tensor,
        mask: torch.Tensor,
        chain_mask_all: torch.Tensor,
        residue_idx: torch.Tensor,
        chain_encoding_all: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computes per-residue log-probabilities over residue classes.

        Args:
            x (torch.Tensor): Backbone coordinates; :obj:`x[:, 0:4, :]` are
                read as N, Ca, C, O with a last dimension of size 3
                (presumably :obj:`[num_residues, 4, 3]` — confirm).
            chain_seq_label (torch.Tensor): Residue class indices in
                :obj:`[0, vocab_size)`.
            mask (torch.Tensor): Non-zero for residues with valid
                coordinates; only those enter the k-NN graph.
            chain_mask_all (torch.Tensor): 1 for residues to design
                (decoded), 0 for visible residues.
            residue_idx (torch.Tensor): Residue positions used for relative
                offsets.
            chain_encoding_all (torch.Tensor): Chain identifier per residue.
            batch (torch.Tensor): Graph assignment vector (assumed sorted).

        Returns:
            torch.Tensor: Log-softmax scores of shape
            :obj:`[num_residues, vocab_size]`.
        """
        device = x.device
        if self.training and self.augment_eps > 0:
            x = x + self.augment_eps * torch.randn_like(x)

        edge_index, edge_attr = self._featurize(x, mask, batch)

        row, col = edge_index
        offset = residue_idx[row] - residue_idx[col]
        # 1 if both endpoints lie on the same chain, 0 otherwise.
        e_chains = ((chain_encoding_all[row] -
                     chain_encoding_all[col]) == 0).long()
        e_pos = self.embedding(offset, e_chains)
        h_e = self.edge_mlp(torch.cat([edge_attr, e_pos], dim=-1))
        h_v = torch.zeros(x.size(0), self.hidden_dim, device=device)

        # encoder
        for encoder in self.encoder_layers:
            h_v, h_e = encoder(h_v, edge_index, h_e)

        # autoregressive mask: 0 - visible - encoder, 1 - masked - decoder
        h_label = self.label_embedding(chain_seq_label)
        mask_attend = self._decoding_mask(chain_mask_all * mask, edge_index,
                                          batch, device)

        # decoder
        for decoder in self.decoder_layers:
            h_v = decoder(
                h_v,
                edge_index,
                h_e,
                h_label,
                mask_attend,
            )

        logits = self.output(h_v)
        return F.log_softmax(logits, dim=-1)
@@ -39,6 +39,8 @@ class BatchNorm(torch.nn.Module):
39
39
  with only a single element will work as during in evaluation.
40
40
  That is the running mean and variance will be used.
41
41
  Requires :obj:`track_running_stats=True`. (default: :obj:`False`)
42
+ device (torch.device, optional): The device to use for the module.
43
+ (default: :obj:`None`)
42
44
  """
43
45
  def __init__(
44
46
  self,
@@ -48,6 +50,7 @@ class BatchNorm(torch.nn.Module):
48
50
  affine: bool = True,
49
51
  track_running_stats: bool = True,
50
52
  allow_single_element: bool = False,
53
+ device: Optional[torch.device] = None,
51
54
  ):
52
55
  super().__init__()
53
56
 
@@ -56,7 +59,7 @@ class BatchNorm(torch.nn.Module):
56
59
  "'track_running_stats' to be set to `True`")
57
60
 
58
61
  self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
59
- track_running_stats)
62
+ track_running_stats, device=device)
60
63
  self.in_channels = in_channels
61
64
  self.allow_single_element = allow_single_element
62
65
 
@@ -114,6 +117,8 @@ class HeteroBatchNorm(torch.nn.Module):
114
117
  :obj:`False`, this module does not track such statistics and always
115
118
  uses batch statistics in both training and eval modes.
116
119
  (default: :obj:`True`)
120
+ device (torch.device, optional): The device to use for the module.
121
+ (default: :obj:`None`)
117
122
  """
118
123
  def __init__(
119
124
  self,
@@ -123,6 +128,7 @@ class HeteroBatchNorm(torch.nn.Module):
123
128
  momentum: Optional[float] = 0.1,
124
129
  affine: bool = True,
125
130
  track_running_stats: bool = True,
131
+ device: Optional[torch.device] = None,
126
132
  ):
127
133
  super().__init__()
128
134
 
@@ -134,17 +140,21 @@ class HeteroBatchNorm(torch.nn.Module):
134
140
  self.track_running_stats = track_running_stats
135
141
 
136
142
  if self.affine:
137
- self.weight = Parameter(torch.empty(num_types, in_channels))
138
- self.bias = Parameter(torch.empty(num_types, in_channels))
143
+ self.weight = Parameter(
144
+ torch.empty(num_types, in_channels, device=device))
145
+ self.bias = Parameter(
146
+ torch.empty(num_types, in_channels, device=device))
139
147
  else:
140
148
  self.register_parameter('weight', None)
141
149
  self.register_parameter('bias', None)
142
150
 
143
151
  if self.track_running_stats:
144
- self.register_buffer('running_mean',
145
- torch.empty(num_types, in_channels))
146
- self.register_buffer('running_var',
147
- torch.empty(num_types, in_channels))
152
+ self.register_buffer(
153
+ 'running_mean',
154
+ torch.empty(num_types, in_channels, device=device))
155
+ self.register_buffer(
156
+ 'running_var',
157
+ torch.empty(num_types, in_channels, device=device))
148
158
  self.register_buffer('num_batches_tracked', torch.tensor(0))
149
159
  else:
150
160
  self.register_buffer('running_mean', None)
@@ -1,3 +1,5 @@
1
+ from typing import Optional
2
+
1
3
  import torch
2
4
  from torch import Tensor
3
5
  from torch.nn import BatchNorm1d, Linear
@@ -39,6 +41,8 @@ class DiffGroupNorm(torch.nn.Module):
39
41
  :obj:`False`, this module does not track such statistics and always
40
42
  uses batch statistics in both training and eval modes.
41
43
  (default: :obj:`True`)
44
+ device (torch.device, optional): The device to use for the module.
45
+ (default: :obj:`None`)
42
46
  """
43
47
  def __init__(
44
48
  self,
@@ -49,6 +53,7 @@ class DiffGroupNorm(torch.nn.Module):
49
53
  momentum: float = 0.1,
50
54
  affine: bool = True,
51
55
  track_running_stats: bool = True,
56
+ device: Optional[torch.device] = None,
52
57
  ):
53
58
  super().__init__()
54
59
 
@@ -56,9 +61,9 @@ class DiffGroupNorm(torch.nn.Module):
56
61
  self.groups = groups
57
62
  self.lamda = lamda
58
63
 
59
- self.lin = Linear(in_channels, groups, bias=False)
64
+ self.lin = Linear(in_channels, groups, bias=False, device=device)
60
65
  self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine,
61
- track_running_stats)
66
+ track_running_stats, device=device)
62
67
 
63
68
  self.reset_parameters()
64
69
 
@@ -26,16 +26,21 @@ class GraphNorm(torch.nn.Module):
26
26
  in_channels (int): Size of each input sample.
27
27
  eps (float, optional): A value added to the denominator for numerical
28
28
  stability. (default: :obj:`1e-5`)
29
+ device (torch.device, optional): The device to use for the module.
30
+ (default: :obj:`None`)
29
31
  """
30
- def __init__(self, in_channels: int, eps: float = 1e-5):
32
+ def __init__(self, in_channels: int, eps: float = 1e-5,
33
+ device: Optional[torch.device] = None):
31
34
  super().__init__()
32
35
 
33
36
  self.in_channels = in_channels
34
37
  self.eps = eps
35
38
 
36
- self.weight = torch.nn.Parameter(torch.empty(in_channels))
37
- self.bias = torch.nn.Parameter(torch.empty(in_channels))
38
- self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))
39
+ self.weight = torch.nn.Parameter(
40
+ torch.empty(in_channels, device=device))
41
+ self.bias = torch.nn.Parameter(torch.empty(in_channels, device=device))
42
+ self.mean_scale = torch.nn.Parameter(
43
+ torch.empty(in_channels, device=device))
39
44
 
40
45
  self.reset_parameters()
41
46
 
@@ -1,5 +1,6 @@
1
1
  from typing import Optional
2
2
 
3
+ import torch
3
4
  import torch.nn.functional as F
4
5
  from torch import Tensor
5
6
  from torch.nn.modules.instancenorm import _InstanceNorm
@@ -36,6 +37,8 @@ class InstanceNorm(_InstanceNorm):
36
37
  :obj:`False`, this module does not track such statistics and always
37
38
  uses instance statistics in both training and eval modes.
38
39
  (default: :obj:`False`)
40
+ device (torch.device, optional): The device to use for the module.
41
+ (default: :obj:`None`)
39
42
  """
40
43
  def __init__(
41
44
  self,
@@ -44,9 +47,10 @@ class InstanceNorm(_InstanceNorm):
44
47
  momentum: float = 0.1,
45
48
  affine: bool = False,
46
49
  track_running_stats: bool = False,
50
+ device: Optional[torch.device] = None,
47
51
  ):
48
52
  super().__init__(in_channels, eps, momentum, affine,
49
- track_running_stats)
53
+ track_running_stats, device=device)
50
54
 
51
55
  def reset_parameters(self):
52
56
  r"""Resets all learnable parameters of the module."""
@@ -35,6 +35,8 @@ class LayerNorm(torch.nn.Module):
35
35
  is used, each graph will be considered as an element to be
36
36
  normalized. If `"node"` is used, each node will be considered as
37
37
  an element to be normalized. (default: :obj:`"graph"`)
38
+ device (torch.device, optional): The device to use for the module.
39
+ (default: :obj:`None`)
38
40
  """
39
41
  def __init__(
40
42
  self,
@@ -42,6 +44,7 @@ class LayerNorm(torch.nn.Module):
42
44
  eps: float = 1e-5,
43
45
  affine: bool = True,
44
46
  mode: str = 'graph',
47
+ device: Optional[torch.device] = None,
45
48
  ):
46
49
  super().__init__()
47
50
 
@@ -51,8 +54,8 @@ class LayerNorm(torch.nn.Module):
51
54
  self.mode = mode
52
55
 
53
56
  if affine:
54
- self.weight = Parameter(torch.empty(in_channels))
55
- self.bias = Parameter(torch.empty(in_channels))
57
+ self.weight = Parameter(torch.empty(in_channels, device=device))
58
+ self.bias = Parameter(torch.empty(in_channels, device=device))
56
59
  else:
57
60
  self.register_parameter('weight', None)
58
61
  self.register_parameter('bias', None)
@@ -134,6 +137,8 @@ class HeteroLayerNorm(torch.nn.Module):
134
137
  normalization (:obj:`"node"`). If `"node"` is used, each node will
135
138
  be considered as an element to be normalized.
136
139
  (default: :obj:`"node"`)
140
+ device (torch.device, optional): The device to use for the module.
141
+ (default: :obj:`None`)
137
142
  """
138
143
  def __init__(
139
144
  self,
@@ -142,6 +147,7 @@ class HeteroLayerNorm(torch.nn.Module):
142
147
  eps: float = 1e-5,
143
148
  affine: bool = True,
144
149
  mode: str = 'node',
150
+ device: Optional[torch.device] = None,
145
151
  ):
146
152
  super().__init__()
147
153
  assert mode == 'node'
@@ -152,8 +158,10 @@ class HeteroLayerNorm(torch.nn.Module):
152
158
  self.affine = affine
153
159
 
154
160
  if affine:
155
- self.weight = Parameter(torch.empty(num_types, in_channels))
156
- self.bias = Parameter(torch.empty(num_types, in_channels))
161
+ self.weight = Parameter(
162
+ torch.empty(num_types, in_channels, device=device))
163
+ self.bias = Parameter(
164
+ torch.empty(num_types, in_channels, device=device))
157
165
  else:
158
166
  self.register_parameter('weight', None)
159
167
  self.register_parameter('bias', None)
@@ -1,3 +1,5 @@
1
+ from typing import Optional
2
+
1
3
  import torch
2
4
  import torch.nn.functional as F
3
5
  from torch import Tensor
@@ -19,10 +21,14 @@ class MessageNorm(torch.nn.Module):
19
21
  learn_scale (bool, optional): If set to :obj:`True`, will learn the
20
22
  scaling factor :math:`s` of message normalization.
21
23
  (default: :obj:`False`)
24
+ device (torch.device, optional): The device to use for the module.
25
+ (default: :obj:`None`)
22
26
  """
23
- def __init__(self, learn_scale: bool = False):
27
+ def __init__(self, learn_scale: bool = False,
28
+ device: Optional[torch.device] = None):
24
29
  super().__init__()
25
- self.scale = Parameter(torch.empty(1), requires_grad=learn_scale)
30
+ self.scale = Parameter(torch.empty(1, device=device),
31
+ requires_grad=learn_scale)
26
32
  self.reset_parameters()
27
33
 
28
34
  def reset_parameters(self):
@@ -452,15 +452,22 @@ def to_cugraph(
452
452
  g = cugraph.Graph(directed=directed)
453
453
  df = cudf.from_dlpack(to_dlpack(edge_index.t()))
454
454
 
455
+ df = cudf.DataFrame({
456
+ 'source':
457
+ cudf.from_dlpack(to_dlpack(edge_index[0])),
458
+ 'destination':
459
+ cudf.from_dlpack(to_dlpack(edge_index[1])),
460
+ })
461
+
455
462
  if edge_weight is not None:
456
463
  assert edge_weight.dim() == 1
457
- df['2'] = cudf.from_dlpack(to_dlpack(edge_weight))
464
+ df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))
458
465
 
459
466
  g.from_cudf_edgelist(
460
467
  df,
461
- source=0,
462
- destination=1,
463
- edge_attr='2' if edge_weight is not None else None,
468
+ source='source',
469
+ destination='destination',
470
+ edge_attr='weight' if edge_weight is not None else None,
464
471
  renumber=relabel_nodes,
465
472
  )
466
473
 
@@ -476,13 +483,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
476
483
  """
477
484
  df = g.view_edge_list()
478
485
 
479
- src = from_dlpack(df[0].to_dlpack()).long()
480
- dst = from_dlpack(df[1].to_dlpack()).long()
486
+ src = from_dlpack(df[g.source_columns].to_dlpack()).long()
487
+ dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()
481
488
  edge_index = torch.stack([src, dst], dim=0)
482
489
 
483
490
  edge_weight = None
484
- if '2' in df:
485
- edge_weight = from_dlpack(df['2'].to_dlpack())
491
+ if g.weight_column is not None:
492
+ edge_weight = from_dlpack(df[g.weight_column].to_dlpack())
486
493
 
487
494
  return edge_index, edge_weight
488
495
 
@@ -148,7 +148,7 @@ def from_smiles(
148
148
  """
149
149
  from rdkit import Chem, RDLogger
150
150
 
151
- RDLogger.DisableLog('rdApp.*') # type: ignore
151
+ RDLogger.DisableLog('rdApp.*')
152
152
 
153
153
  mol = Chem.MolFromSmiles(smiles)
154
154