pyg-nightly 2.7.0.dev20241124__py3-none-any.whl → 2.7.0.dev20241129__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241129.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241129.dist-info}/RECORD +20 -14
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +5 -0
- torch_geometric/data/dataset.py +1 -1
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/datasets/__init__.py +2 -0
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/loader/__init__.py +2 -0
- torch_geometric/loader/rag_loader.py +106 -0
- torch_geometric/nn/models/__init__.py +2 -0
- torch_geometric/nn/models/g_retriever.py +12 -1
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/nlp/__init__.py +2 -0
- torch_geometric/nn/nlp/sentence_transformer.py +30 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/sampler/base.py +8 -0
- {pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241129.dist-info}/WHEEL +0 -0
torch_geometric/nn/models/git_mol.py
ADDED
@@ -0,0 +1,336 @@
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
+
+from torch_geometric.nn import GINEConv
+from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer
+from torch_geometric.utils import add_self_loops, to_dense_batch
+
+
+class GraphEncoder(torch.nn.Module):
+    def __init__(
+        self,
+        num_layers: int,
+        in_channels: int,
+        dropout: float = 0.,
+        num_atom_type: int = 120,
+        num_chirality_tag: int = 3,
+        num_bond_type: int = 6,
+        num_bond_direction: int = 3,
+    ) -> None:
+        super().__init__()
+
+        self.num_layers = num_layers
+        self.dropout = dropout
+
+        self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels)
+        self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels)
+        self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels)
+        self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels)
+
+        self.gnns = torch.nn.ModuleList()
+        self.batch_norms = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            self.gnns.append(
+                GINEConv(
+                    nn=Sequential(
+                        Linear(in_channels, in_channels * 2),
+                        ReLU(),
+                        Linear(in_channels * 2, in_channels),
+                    ),
+                    train_eps=True,
+                    edge_dim=in_channels,
+                ))
+            self.batch_norms.append(BatchNorm1d(in_channels))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.x_embed1.weight.data)
+        torch.nn.init.xavier_uniform_(self.x_embed2.weight.data)
+        torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data)
+        torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data)
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Tensor,
+    ) -> Tensor:
+        x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long())
+        edge_index, edge_attr = add_self_loops(
+            edge_index,
+            edge_attr,
+            fill_value=0,
+            num_nodes=x.size(0),
+        )
+        edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2(
+            edge_attr[:, 1])
+        for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)):
+            x = gnn(x, edge_index, edge_attr)
+            x = bn(x)
+            if i < self.num_layers - 1:
+                x = F.relu(x)
+            x = F.dropout(x, self.dropout, training=self.training)
+
+        x, mask = to_dense_batch(x, batch)
+        return x, mask
+
+
+class GITFormer(torch.nn.Module):
+    def __init__(
+        self,
+        num_query_token: int,
+        vision_graph_width: int,
+        cross_attention_freq: int = 2,
+    ):
+        super().__init__()
+        from transformers import AutoConfig, AutoModel
+
+        config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased")
+        config.encoder_width = vision_graph_width
+        # insert cross-attention layer every other block
+        config.add_cross_attention = True
+        config.is_decoder = True
+        config.cross_attention_freq = cross_attention_freq
+        config.query_length = num_query_token
+        self.Qformer = AutoModel.from_pretrained(
+            "allenai/scibert_scivocab_uncased", config=config)
+        self.query_tokens = torch.nn.Parameter(
+            torch.zeros(1, num_query_token, config.hidden_size))
+        self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range)
+
+
+class GITMol(torch.nn.Module):
+    r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language
+    Model for Molecular Science with Graph, Image, and Text"
+    <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+    .. note::
+        For an example of using :class:`GITMol`, see
+        `examples/llm/git_mol.py <https://github.com/pyg-team/
+        pytorch_geometric/blob/master/examples/llm/git_mol.py>`_.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        # graph
+        self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)
+        self.graph_proj = Linear(16, 768)
+        self.ln_graph = LayerNorm(768)
+        # text
+        self.text_encoder = SentenceTransformer(
+            model_name='allenai/scibert_scivocab_uncased',
+            pooling_strategy='last_hidden_state',
+        )
+        self.text_proj = Linear(768, 768)
+        self.ln_text = LayerNorm(768)
+        # vision
+        self.vision_encoder = VisionTransformer(
+            model_name='microsoft/swin-base-patch4-window7-224', )
+        self.vision_proj = Linear(1024, 768)
+        self.ln_vision = LayerNorm(768)
+        # cross-attention
+        self.gitformer = GITFormer(384, 768)
+
+        self.xtm_head = torch.nn.ModuleDict({
+            'image':
+            Linear(self.gitformer.Qformer.config.hidden_size, 2),
+            'graph':
+            Linear(self.gitformer.Qformer.config.hidden_size, 2),
+            'cs_text':
+            Linear(self.gitformer.Qformer.config.hidden_size, 2),
+        })
+
+        self.xtc_proj = torch.nn.ModuleDict({
+            'image':
+            Linear(self.gitformer.Qformer.config.hidden_size, 768),
+            'graph':
+            Linear(self.gitformer.Qformer.config.hidden_size, 768),
+            'cs_text':
+            Linear(self.gitformer.Qformer.config.hidden_size, 768),
+        })
+        self.temp = torch.nn.Parameter(0.07 * torch.ones([]))
+        self.model_freeze()
+
+    def model_freeze(self) -> None:
+        for param in self.graph_encoder.parameters():
+            param.requires_grad = False
+
+        for param in self.vision_encoder.parameters():
+            param.requires_grad = False
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Tensor,
+        edge_attr: Optional[Tensor],
+        smiles: List[str],
+        images: Tensor,
+        captions: List[str],
+    ) -> Tensor:
+        batch_size = len(smiles)
+
+        x_vision = self.vision_encoder(images)
+        x_vision = self.vision_proj(x_vision)
+        x_vision = self.ln_vision(x_vision)  # [bs, patch_len, d]
+        vision_atts = torch.ones(x_vision.size()[:-1],
+                                 dtype=torch.long).to(x_vision.device)
+        vision_targets = torch.arange(batch_size).to(x_vision.device)
+
+        x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,
+                                                 edge_attr)
+        x_graph = self.graph_proj(x_graph)
+        x_graph = self.ln_graph(x_graph)  # [bs, node_len, d]
+        graph_targets = torch.arange(batch_size).to(x_graph.device)
+
+        x_smiles = self.text_encoder.encode(smiles)  # [bs, seq_len, d]
+        smiles_atts = torch.ones(x_smiles.size()[:-1],
+                                 dtype=torch.long).to(x_smiles.device)
+        smiles_targets = torch.arange(batch_size).to(x_smiles.device)
+
+        caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids(  # noqa: E501
+            captions)
+
+        text_output = self.gitformer.Qformer(
+            caption_input_ids,
+            attention_mask=caption_attention_masks,
+            return_dict=True,
+        )
+        text_feat = F.normalize(
+            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+        loss = 0
+        for x_embed, x_atts, x_targets, modal in zip(
+            [x_graph, x_smiles, x_vision],
+            [graph_atts, smiles_atts, vision_atts],
+            [graph_targets, smiles_targets, vision_targets],
+            ['graph', 'cs_text', 'image'],
+        ):
+            loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,
+                                        modal)
+            loss += self._calc_xtm_loss(x_embed, caption_input_ids,
+                                        caption_attention_masks, modal)
+
+        return loss / 6
+
+    def _calc_xtm_loss(
+        self,
+        x_embeds: Tensor,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        modal: str,
+    ) -> Tensor:
+        # Initializing lists to hold the original and negative samples
+        x_embeds_list = []
+        text_input_ids_list = []
+        text_attention_mask_list = []
+
+        batch_size = x_embeds.size(0)
+        for i in range(batch_size):
+            # Original samples
+            x_embeds_list.append(x_embeds[i])
+            text_input_ids_list.append(input_ids[i, :])
+            text_attention_mask_list.append(attention_mask[i, :])
+
+            if batch_size > 1:
+                # Negative samples (neg_text_input_ids corresponds to x_embeds)
+                neg_text_input_ids = input_ids[i - 1 if i == batch_size -
+                                               1 else i + 1, :]
+                neg_text_attention_mask = attention_mask[i -
+                                                         1 if i == batch_size -
+                                                         1 else i + 1, :]
+                text_input_ids_list.append(neg_text_input_ids)
+                text_attention_mask_list.append(neg_text_attention_mask)
+                x_embeds_list.append(x_embeds[i, :])
+
+                # Negative samples (text_input_ids corresponds to neg_x_embeds)
+                neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i +
+                                        1, :]
+                x_embeds_list.append(neg_x_embeds)
+                text_input_ids_list.append(input_ids[i, :])
+                text_attention_mask_list.append(attention_mask[i, :])
+
+        # Stack all samples into two large tensors
+        x_embeds_all = torch.stack(x_embeds_list, dim=1) \
+            .reshape(-1, x_embeds.size(1), x_embeds.size(2))
+        text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \
+            .reshape(-1, input_ids.size(1))
+        # Create image attention masks for the concatenated tensor
+        image_attns_all = torch.ones(x_embeds_all.size()[:-1],
+                                     dtype=torch.long).to(x_embeds_all.device)
+        query_tokens_xtm = self.gitformer.query_tokens.expand(
+            text_input_ids_all.shape[0], -1, -1)
+        query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1],
+                                     dtype=torch.long).to(x_embeds_all.device)
+
+        output_xtm = self.gitformer.Qformer(
+            inputs_embeds=query_tokens_xtm,
+            attention_mask=query_attns_xtm,
+            encoder_hidden_states=x_embeds_all,
+            encoder_attention_mask=image_attns_all,
+            return_dict=True,
+        ).last_hidden_state
+
+        xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :]
+
+        xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1)
+        # Create labels: 1 for the original samples, 0 for the negative samples
+        if batch_size > 1:
+            labels = torch.cat(
+                [torch.ones(batch_size),
+                 torch.zeros(batch_size * 2)], dim=0)
+        else:
+            labels = torch.ones(batch_size)
+        labels = labels.long().to(xtm_logit.device)
+
+        # Calculate cross entropy loss
+        return F.cross_entropy(xtm_logit, labels)
+
+    def _calc_xtc_loss(
+        self,
+        x_embeds: Tensor,
+        x_atts: Tensor,
+        x_targets: Tensor,
+        text_feat: Tensor,
+        modal: str,
+    ) -> Tensor:
+        query_tokens = self.gitformer.query_tokens.expand(
+            x_embeds.shape[0], -1, -1)
+
+        query_output = self.gitformer.Qformer(
+            inputs_embeds=query_tokens,
+            encoder_hidden_states=x_embeds,
+            encoder_attention_mask=x_atts,
+            return_dict=True,
+        ).last_hidden_state
+
+        x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)
+
+        sim_q2t = torch.matmul(
+            x_feats.unsqueeze(1),
+            text_feat.unsqueeze(-1),
+        ).squeeze(-1)
+
+        # modal-text similarity: aggregate across all query tokens
+        sim_x2t, _ = sim_q2t.max(-1)
+        sim_x2t = sim_x2t / self.temp
+
+        # text-query similarity
+        sim_t2q = torch.matmul(
+            text_feat.unsqueeze(1).unsqueeze(1),
+            x_feats.permute(0, 2, 1),
+        ).squeeze(-2)
+
+        # text-modal similarity: aggregate across all query tokens
+        sim_t2x, _ = sim_t2q.max(-1)
+        sim_t2x = sim_t2x / self.temp
+
+        loss_itc = (
+            F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) +
+            F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2
+
+        return loss_itc
torch_geometric/nn/nlp/sentence_transformer.py
CHANGED
@@ -48,6 +48,36 @@ class SentenceTransformer(torch.nn.Module):
         emb = F.normalize(emb, p=2, dim=1)
         return emb
 
+    def get_input_ids(
+        self,
+        text: List[str],
+        batch_size: Optional[int] = None,
+        output_device: Optional[Union[torch.device, str]] = None,
+    ) -> Tensor:
+        is_empty = len(text) == 0
+        text = ['dummy'] if is_empty else text
+
+        batch_size = len(text) if batch_size is None else batch_size
+
+        input_ids: List[Tensor] = []
+        attention_masks: List[Tensor] = []
+        for start in range(0, len(text), batch_size):
+            token = self.tokenizer(
+                text[start:start + batch_size],
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+            )
+            input_ids.append(token.input_ids.to(self.device))
+            attention_masks.append(token.attention_mask.to(self.device))
+
+        def _out(x: List[Tensor]) -> Tensor:
+            out = torch.cat(x, dim=0) if len(x) > 1 else x[0]
+            out = out[:0] if is_empty else out
+            return out.to(output_device)
+
+        return _out(input_ids), _out(attention_masks)
+
     @property
     def device(self) -> torch.device:
         return next(iter(self.model.parameters())).device
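A brief usage sketch for the new get_input_ids() helper; the model name and strings below are placeholders. The method batches the underlying tokenizer and returns padded input_ids and attention_mask tensors on the requested device.

from torch_geometric.nn.nlp import SentenceTransformer

model = SentenceTransformer(model_name='allenai/scibert_scivocab_uncased')
input_ids, attention_mask = model.get_input_ids(
    ['first caption', 'second caption'],
    batch_size=2,          # optional: tokenize in chunks of this size
    output_device='cpu',   # optional: device of the returned tensors
)
# Both tensors have shape [num_texts, seq_len].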
torch_geometric/nn/nlp/vision_transformer.py
ADDED
@@ -0,0 +1,33 @@
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+
+
+class VisionTransformer(torch.nn.Module):
+    def __init__(
+        self,
+        model_name: str,
+    ) -> None:
+        super().__init__()
+        self.model_name = model_name
+
+        from transformers import SwinConfig, SwinModel
+
+        self.config = SwinConfig.from_pretrained(model_name)
+        self.model = SwinModel(self.config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        images: Tensor,
+        output_device: Optional[Union[torch.device, str]] = None,
+    ) -> Tensor:
+        return self.model(images).last_hidden_state.to(output_device)
+
+    @property
+    def device(self) -> torch.device:
+        return next(iter(self.model.parameters())).device
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}(model_name={self.model_name})'
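A minimal shape-check sketch for the new wrapper, assuming the same Swin checkpoint name that GITMol uses; only the configuration is fetched here, so the backbone weights are randomly initialized.

import torch

from torch_geometric.nn.nlp import VisionTransformer

model = VisionTransformer(model_name='microsoft/swin-base-patch4-window7-224')
images = torch.randn(2, 3, 224, 224)        # [batch, channels, height, width]
feats = model(images, output_device='cpu')  # Swin last_hidden_state: [2, 49, 1024]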
torch_geometric/profile/__init__.py
CHANGED
@@ -20,6 +20,7 @@ from .utils import (
     get_gpu_memory_from_nvidia_smi,
     get_model_size,
 )
+from .nvtx import nvtxit
 
 __all__ = [
     'profileit',
@@ -38,6 +39,7 @@ __all__ = [
     'get_gpu_memory_from_nvidia_smi',
     'get_gpu_memory_from_ipex',
     'benchmark',
+    'nvtxit',
 ]
 
 classes = __all__
torch_geometric/profile/nvtx.py
ADDED
@@ -0,0 +1,66 @@
+from functools import wraps
+from typing import Optional
+
+import torch
+
+CUDA_PROFILE_STARTED = False
+
+
+def begin_cuda_profile():
+    global CUDA_PROFILE_STARTED
+    prev_state = CUDA_PROFILE_STARTED
+    if prev_state is False:
+        CUDA_PROFILE_STARTED = True
+        torch.cuda.cudart().cudaProfilerStart()
+    return prev_state
+
+
+def end_cuda_profile(prev_state: bool):
+    global CUDA_PROFILE_STARTED
+    CUDA_PROFILE_STARTED = prev_state
+    if prev_state is False:
+        torch.cuda.cudart().cudaProfilerStop()
+
+
+def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
+           n_iters: Optional[int] = None):
+    """Enables NVTX profiling for a function.
+
+    Args:
+        name (Optional[str], optional): Name to give the reference frame for
+            the function being wrapped. Defaults to the name of the
+            function in code.
+        n_warmups (int, optional): Number of iters to call that function
+            before starting. Defaults to 0.
+        n_iters (Optional[int], optional): Number of iters of that function to
+            record. Defaults to all of them.
+    """
+    def nvtx(func):
+
+        nonlocal name
+        iters_so_far = 0
+        if name is None:
+            name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal iters_so_far
+            if not torch.cuda.is_available():
+                return func(*args, **kwargs)
+            elif iters_so_far < n_warmups:
+                iters_so_far += 1
+                return func(*args, **kwargs)
+            elif n_iters is None or iters_so_far < n_iters + n_warmups:
+                prev_state = begin_cuda_profile()
+                torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
+                result = func(*args, **kwargs)
+                torch.cuda.nvtx.range_pop()
+                end_cuda_profile(prev_state)
+                iters_so_far += 1
+                return result
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return nvtx
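A usage sketch for the new decorator; the wrapped function and range name below are illustrative. With CUDA available, the first call is treated as warm-up and the next three are wrapped in NVTX ranges (visible under an NVTX-aware profiler such as Nsight Systems); without CUDA the decorator is a pass-through.

import torch

from torch_geometric.profile import nvtxit

@nvtxit(name='gnn_forward', n_warmups=1, n_iters=3)
def gnn_forward(x: torch.Tensor) -> torch.Tensor:
    return x.relu()

for _ in range(5):
    gnn_forward(torch.randn(16, 8))
# Call 0 is the warm-up; calls 1-3 are recorded as NVTX ranges
# 'gnn_forward_1' .. 'gnn_forward_3'; later calls run unprofiled.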
torch_geometric/sampler/base.py
CHANGED
@@ -425,6 +425,14 @@ class NumNeighbors:
             else:
                 assert False
 
+            # Confirm that `values` only hold valid edge types:
+            if isinstance(self.values, dict):
+                edge_types_str = {EdgeTypeStr(key) for key in edge_types}
+                invalid_edge_types = set(self.values.keys()) - edge_types_str
+                if len(invalid_edge_types) > 0:
+                    raise ValueError("Not all edge types specified in "
+                                     "'num_neighbors' exist in the graph")
+
             out = {}
             for edge_type in edge_types:
                 edge_type_str = EdgeTypeStr(edge_type)
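The effect of the new check, sketched with hypothetical edge types; get_values() is the existing public accessor that reaches this code path.

from torch_geometric.sampler import NumNeighbors

# Edge types that actually exist in the graph:
edge_types = [('paper', 'cites', 'paper')]

# 'num_neighbors' refers to an edge type the graph does not have:
num_neighbors = NumNeighbors({('paper', 'written_by', 'author'): [10, 5]})

num_neighbors.get_values(edge_types)
# ValueError: Not all edge types specified in 'num_neighbors' exist in the graph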
{pyg_nightly-2.7.0.dev20241124.dist-info → pyg_nightly-2.7.0.dev20241129.dist-info}/WHEEL
File without changes