pyg-nightly 2.7.0.dev20241124__py3-none-any.whl → 2.7.0.dev20241126__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,336 @@
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
+
+ from torch_geometric.nn import GINEConv
+ from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer
+ from torch_geometric.utils import add_self_loops, to_dense_batch
+
+
+ class GraphEncoder(torch.nn.Module):
+     def __init__(
+         self,
+         num_layers: int,
+         in_channels: int,
+         dropout: float = 0.,
+         num_atom_type: int = 120,
+         num_chirality_tag: int = 3,
+         num_bond_type: int = 6,
+         num_bond_direction: int = 3,
+     ) -> None:
+         super().__init__()
+
+         self.num_layers = num_layers
+         self.dropout = dropout
+
+         self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels)
+         self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels)
+         self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels)
+         self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels)
+
+         self.gnns = torch.nn.ModuleList()
+         self.batch_norms = torch.nn.ModuleList()
+         for _ in range(num_layers):
+             self.gnns.append(
+                 GINEConv(
+                     nn=Sequential(
+                         Linear(in_channels, in_channels * 2),
+                         ReLU(),
+                         Linear(in_channels * 2, in_channels),
+                     ),
+                     train_eps=True,
+                     edge_dim=in_channels,
+                 ))
+             self.batch_norms.append(BatchNorm1d(in_channels))
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.xavier_uniform_(self.x_embed1.weight.data)
+         torch.nn.init.xavier_uniform_(self.x_embed2.weight.data)
+         torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data)
+         torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data)
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Tensor,
+     ) -> Tuple[Tensor, Tensor]:
+         x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long())
+         edge_index, edge_attr = add_self_loops(
+             edge_index,
+             edge_attr,
+             fill_value=0,
+             num_nodes=x.size(0),
+         )
+         edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2(
+             edge_attr[:, 1])
+         for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)):
+             x = gnn(x, edge_index, edge_attr)
+             x = bn(x)
+             if i < self.num_layers - 1:
+                 x = F.relu(x)
+             x = F.dropout(x, self.dropout, training=self.training)
+
+         x, mask = to_dense_batch(x, batch)
+         return x, mask
+
+
+ class GITFormer(torch.nn.Module):
+     def __init__(
+         self,
+         num_query_token: int,
+         vision_graph_width: int,
+         cross_attention_freq: int = 2,
+     ):
+         super().__init__()
+         from transformers import AutoConfig, AutoModel
+
+         config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased")
+         config.encoder_width = vision_graph_width
+         # insert cross-attention layer every other block
+         config.add_cross_attention = True
+         config.is_decoder = True
+         config.cross_attention_freq = cross_attention_freq
+         config.query_length = num_query_token
+         self.Qformer = AutoModel.from_pretrained(
+             "allenai/scibert_scivocab_uncased", config=config)
+         self.query_tokens = torch.nn.Parameter(
+             torch.zeros(1, num_query_token, config.hidden_size))
+         self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range)
+
+
+ class GITMol(torch.nn.Module):
+     r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language
+     Model for Molecular Science with Graph, Image, and Text"
+     <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+     .. note::
+         For an example of using :class:`GITMol`, see
+         `examples/llm/git_mol.py <https://github.com/pyg-team/
+         pytorch_geometric/blob/master/examples/llm/git_mol.py>`_.
+     """
+     def __init__(self) -> None:
+         super().__init__()
+         # graph
+         self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)
+         self.graph_proj = Linear(16, 768)
+         self.ln_graph = LayerNorm(768)
+         # text
+         self.text_encoder = SentenceTransformer(
+             model_name='allenai/scibert_scivocab_uncased',
+             pooling_strategy='last_hidden_state',
+         )
+         self.text_proj = Linear(768, 768)
+         self.ln_text = LayerNorm(768)
+         # vision
+         self.vision_encoder = VisionTransformer(
+             model_name='microsoft/swin-base-patch4-window7-224', )
+         self.vision_proj = Linear(1024, 768)
+         self.ln_vision = LayerNorm(768)
+         # cross-attention
+         self.gitformer = GITFormer(384, 768)
+
+         self.xtm_head = torch.nn.ModuleDict({
+             'image':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+             'graph':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+             'cs_text':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+         })
+
+         self.xtc_proj = torch.nn.ModuleDict({
+             'image':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+             'graph':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+             'cs_text':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+         })
+         self.temp = torch.nn.Parameter(0.07 * torch.ones([]))
+         self.model_freeze()
+
+     def model_freeze(self) -> None:
+         for param in self.graph_encoder.parameters():
+             param.requires_grad = False
+
+         for param in self.vision_encoder.parameters():
+             param.requires_grad = False
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+         smiles: List[str],
+         images: Tensor,
+         captions: List[str],
+     ) -> Tensor:
+         batch_size = len(smiles)
+
+         x_vision = self.vision_encoder(images)
+         x_vision = self.vision_proj(x_vision)
+         x_vision = self.ln_vision(x_vision)  # [bs, patch_len, d]
+         vision_atts = torch.ones(x_vision.size()[:-1],
+                                  dtype=torch.long).to(x_vision.device)
+         vision_targets = torch.arange(batch_size).to(x_vision.device)
+
+         x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,
+                                                  edge_attr)
+         x_graph = self.graph_proj(x_graph)
+         x_graph = self.ln_graph(x_graph)  # [bs, node_len, d]
+         graph_targets = torch.arange(batch_size).to(x_graph.device)
+
+         x_smiles = self.text_encoder.encode(smiles)  # [bs, seq_len, d]
+         smiles_atts = torch.ones(x_smiles.size()[:-1],
+                                  dtype=torch.long).to(x_smiles.device)
+         smiles_targets = torch.arange(batch_size).to(x_smiles.device)
+
+         caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids(  # noqa: E501
+             captions)
+
+         text_output = self.gitformer.Qformer(
+             caption_input_ids,
+             attention_mask=caption_attention_masks,
+             return_dict=True,
+         )
+         text_feat = F.normalize(
+             self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+         loss = 0
+         for x_embed, x_atts, x_targets, modal in zip(
+                 [x_graph, x_smiles, x_vision],
+                 [graph_atts, smiles_atts, vision_atts],
+                 [graph_targets, smiles_targets, vision_targets],
+                 ['graph', 'cs_text', 'image'],
+         ):
+             loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,
+                                         modal)
+             loss += self._calc_xtm_loss(x_embed, caption_input_ids,
+                                         caption_attention_masks, modal)
+
+         return loss / 6
+
+     def _calc_xtm_loss(
+         self,
+         x_embeds: Tensor,
+         input_ids: Tensor,
+         attention_mask: Tensor,
+         modal: str,
+     ) -> Tensor:
+         # Initializing lists to hold the original and negative samples
+         x_embeds_list = []
+         text_input_ids_list = []
+         text_attention_mask_list = []
+
+         batch_size = x_embeds.size(0)
+         for i in range(batch_size):
+             # Original samples
+             x_embeds_list.append(x_embeds[i])
+             text_input_ids_list.append(input_ids[i, :])
+             text_attention_mask_list.append(attention_mask[i, :])
+
+             if batch_size > 1:
+                 # Negative samples (neg_text_input_ids corresponds to x_embeds)
+                 neg_text_input_ids = input_ids[i - 1 if i == batch_size -
+                                                1 else i + 1, :]
+                 neg_text_attention_mask = attention_mask[i -
+                                                          1 if i == batch_size -
+                                                          1 else i + 1, :]
+                 text_input_ids_list.append(neg_text_input_ids)
+                 text_attention_mask_list.append(neg_text_attention_mask)
+                 x_embeds_list.append(x_embeds[i, :])
+
+                 # Negative samples (text_input_ids corresponds to neg_x_embeds)
+                 neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i +
+                                         1, :]
+                 x_embeds_list.append(neg_x_embeds)
+                 text_input_ids_list.append(input_ids[i, :])
+                 text_attention_mask_list.append(attention_mask[i, :])
+
+         # Stack all samples into two large tensors
+         x_embeds_all = torch.stack(x_embeds_list, dim=1) \
+             .reshape(-1, x_embeds.size(1), x_embeds.size(2))
+         text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \
+             .reshape(-1, input_ids.size(1))
+         # Create image attention masks for the concatenated tensor
+         image_attns_all = torch.ones(x_embeds_all.size()[:-1],
+                                      dtype=torch.long).to(x_embeds_all.device)
+         query_tokens_xtm = self.gitformer.query_tokens.expand(
+             text_input_ids_all.shape[0], -1, -1)
+         query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1],
+                                      dtype=torch.long).to(x_embeds_all.device)
+
+         output_xtm = self.gitformer.Qformer(
+             inputs_embeds=query_tokens_xtm,
+             attention_mask=query_attns_xtm,
+             encoder_hidden_states=x_embeds_all,
+             encoder_attention_mask=image_attns_all,
+             return_dict=True,
+         ).last_hidden_state
+
+         xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :]
+
+         xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1)
+         # Create labels: 1 for the original samples, 0 for the negative samples
+         if batch_size > 1:
+             labels = torch.cat(
+                 [torch.ones(batch_size),
+                  torch.zeros(batch_size * 2)], dim=0)
+         else:
+             labels = torch.ones(batch_size)
+         labels = labels.long().to(xtm_logit.device)
+
+         # Calculate cross entropy loss
+         return F.cross_entropy(xtm_logit, labels)
+
+     def _calc_xtc_loss(
+         self,
+         x_embeds: Tensor,
+         x_atts: Tensor,
+         x_targets: Tensor,
+         text_feat: Tensor,
+         modal: str,
+     ) -> Tensor:
+         query_tokens = self.gitformer.query_tokens.expand(
+             x_embeds.shape[0], -1, -1)
+
+         query_output = self.gitformer.Qformer(
+             inputs_embeds=query_tokens,
+             encoder_hidden_states=x_embeds,
+             encoder_attention_mask=x_atts,
+             return_dict=True,
+         ).last_hidden_state
+
+         x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)
+
+         sim_q2t = torch.matmul(
+             x_feats.unsqueeze(1),
+             text_feat.unsqueeze(-1),
+         ).squeeze(-1)
+
+         # modal-text similarity: aggregate across all query tokens
+         sim_x2t, _ = sim_q2t.max(-1)
+         sim_x2t = sim_x2t / self.temp
+
+         # text-query similarity
+         sim_t2q = torch.matmul(
+             text_feat.unsqueeze(1).unsqueeze(1),
+             x_feats.permute(0, 2, 1),
+         ).squeeze(-2)
+
+         # text-modal similarity: aggregate across all query tokens
+         sim_t2x, _ = sim_t2q.max(-1)
+         sim_t2x = sim_t2x / self.temp
+
+         loss_itc = (
+             F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) +
+             F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2
+
+         return loss_itc
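
To make the intent of the new module concrete, here is a minimal usage sketch (not part of the diff). It assumes GITMol is exported from torch_geometric.nn.models, as the rest of the model zoo is, and that the pretrained SciBERT and Swin checkpoints can be downloaded; the molecule batch below is made-up toy data.

import torch
from torch_geometric.nn.models import GITMol

model = GITMol()

# A toy batch of two molecules with five atoms and eight bonds in total.
x = torch.stack([
    torch.randint(0, 120, (5, )),  # atom-type index (< num_atom_type)
    torch.randint(0, 3, (5, )),    # chirality tag (< num_chirality_tag)
], dim=1)
edge_index = torch.randint(0, 5, (2, 8))
edge_attr = torch.stack([
    torch.randint(0, 6, (8, )),    # bond-type index (< num_bond_type)
    torch.randint(0, 3, (8, )),    # bond-direction index (< num_bond_direction)
], dim=1)
batch = torch.tensor([0, 0, 0, 1, 1])

images = torch.randn(2, 3, 224, 224)  # Swin expects 224x224 RGB inputs
smiles = ['CC(=O)O', 'c1ccccc1']
captions = ['acetic acid', 'benzene']

loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)
loss.backward()  # gradients flow only into the non-frozen Q-Former/projection weights

The returned value is the average of the six cross-modal matching and contrastive losses (graph/SMILES/image, each paired with the caption text).
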
@@ -1,7 +1,9 @@
  from .sentence_transformer import SentenceTransformer
+ from .vision_transformer import VisionTransformer
  from .llm import LLM
 
  __all__ = classes = [
      'SentenceTransformer',
+     'VisionTransformer',
      'LLM',
  ]
@@ -48,6 +48,36 @@ class SentenceTransformer(torch.nn.Module):
          emb = F.normalize(emb, p=2, dim=1)
          return emb
 
+     def get_input_ids(
+         self,
+         text: List[str],
+         batch_size: Optional[int] = None,
+         output_device: Optional[Union[torch.device, str]] = None,
+     ) -> Tensor:
+         is_empty = len(text) == 0
+         text = ['dummy'] if is_empty else text
+
+         batch_size = len(text) if batch_size is None else batch_size
+
+         input_ids: List[Tensor] = []
+         attention_masks: List[Tensor] = []
+         for start in range(0, len(text), batch_size):
+             token = self.tokenizer(
+                 text[start:start + batch_size],
+                 padding=True,
+                 truncation=True,
+                 return_tensors='pt',
+             )
+             input_ids.append(token.input_ids.to(self.device))
+             attention_masks.append(token.attention_mask.to(self.device))
+
+         def _out(x: List[Tensor]) -> Tensor:
+             out = torch.cat(x, dim=0) if len(x) > 1 else x[0]
+             out = out[:0] if is_empty else out
+             return out.to(output_device)
+
+         return _out(input_ids), _out(attention_masks)
+
      @property
      def device(self) -> torch.device:
          return next(iter(self.model.parameters())).device
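
A short sketch of how the new get_input_ids helper can be called (the caption strings are placeholders; the constructor arguments mirror how GITMol instantiates its text encoder above):

from torch_geometric.nn.nlp import SentenceTransformer

encoder = SentenceTransformer(
    model_name='allenai/scibert_scivocab_uncased',
    pooling_strategy='last_hidden_state',
)
input_ids, attention_mask = encoder.get_input_ids(
    ['an aromatic six-membered ring', 'a simple carboxylic acid'],
)
# Both tensors have shape [2, seq_len]; GITMol feeds them to its Q-Former as
# the caption tokens for the matching and contrastive losses.
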
@@ -0,0 +1,33 @@
+ from typing import Optional, Union
+
+ import torch
+ from torch import Tensor
+
+
+ class VisionTransformer(torch.nn.Module):
+     def __init__(
+         self,
+         model_name: str,
+     ) -> None:
+         super().__init__()
+         self.model_name = model_name
+
+         from transformers import SwinConfig, SwinModel
+
+         self.config = SwinConfig.from_pretrained(model_name)
+         self.model = SwinModel(self.config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         images: Tensor,
+         output_device: Optional[Union[torch.device, str]] = None,
+     ) -> Tensor:
+         return self.model(images).last_hidden_state.to(output_device)
+
+     @property
+     def device(self) -> torch.device:
+         return next(iter(self.model.parameters())).device
+
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}(model_name={self.model_name})'
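
A minimal sketch of the new wrapper (not part of the diff). Note that only the Swin config is fetched from the Hub; SwinModel(self.config) starts from randomly initialized weights, so pretrained weights have to be loaded separately if needed. The shapes below follow swin-base-patch4-window7-224, matching the Linear(1024, 768) vision projection in GITMol.

import torch
from torch_geometric.nn.nlp import VisionTransformer

vit = VisionTransformer(model_name='microsoft/swin-base-patch4-window7-224')
images = torch.randn(2, 3, 224, 224)      # two 224x224 RGB images
feats = vit(images, output_device='cpu')  # [2, 49, 1024] patch features
print(vit)  # VisionTransformer(model_name=microsoft/swin-base-patch4-window7-224)
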
@@ -20,6 +20,7 @@ from .utils import (
      get_gpu_memory_from_nvidia_smi,
      get_model_size,
  )
+ from .nvtx import nvtxit
 
  __all__ = [
      'profileit',
@@ -38,6 +39,7 @@ __all__ = [
      'get_gpu_memory_from_nvidia_smi',
      'get_gpu_memory_from_ipex',
      'benchmark',
+     'nvtxit',
  ]
 
  classes = __all__
@@ -0,0 +1,66 @@
+ from functools import wraps
+ from typing import Optional
+
+ import torch
+
+ CUDA_PROFILE_STARTED = False
+
+
+ def begin_cuda_profile():
+     global CUDA_PROFILE_STARTED
+     prev_state = CUDA_PROFILE_STARTED
+     if prev_state is False:
+         CUDA_PROFILE_STARTED = True
+         torch.cuda.cudart().cudaProfilerStart()
+     return prev_state
+
+
+ def end_cuda_profile(prev_state: bool):
+     global CUDA_PROFILE_STARTED
+     CUDA_PROFILE_STARTED = prev_state
+     if prev_state is False:
+         torch.cuda.cudart().cudaProfilerStop()
+
+
+ def nvtxit(name: Optional[str] = None, n_warmups: int = 0,
+            n_iters: Optional[int] = None):
+     """Enables NVTX profiling for a function.
+
+     Args:
+         name (str, optional): Name to give the NVTX range that wraps the
+             function being profiled. Defaults to the name of the
+             function in code.
+         n_warmups (int, optional): Number of iterations to run the function
+             before recording starts. Defaults to 0.
+         n_iters (int, optional): Number of iterations of the function to
+             record. Defaults to all of them.
+     """
+     def nvtx(func):
+
+         nonlocal name
+         iters_so_far = 0
+         if name is None:
+             name = func.__name__
+
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             nonlocal iters_so_far
+             if not torch.cuda.is_available():
+                 return func(*args, **kwargs)
+             elif iters_so_far < n_warmups:
+                 iters_so_far += 1
+                 return func(*args, **kwargs)
+             elif n_iters is None or iters_so_far < n_iters + n_warmups:
+                 prev_state = begin_cuda_profile()
+                 torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}")
+                 result = func(*args, **kwargs)
+                 torch.cuda.nvtx.range_pop()
+                 end_cuda_profile(prev_state)
+                 iters_so_far += 1
+                 return result
+             else:
+                 return func(*args, **kwargs)
+
+         return wrapper
+
+     return nvtx
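
Finally, a usage sketch for the new decorator (the function name and iteration counts are made up), assuming the edited __init__ above belongs to torch_geometric.profile. Each recorded call is bracketed by cudaProfilerStart/cudaProfilerStop, so it pairs naturally with Nsight Systems' capture-range mode; without a CUDA device the decorator simply calls the function unchanged.

import torch
from torch_geometric.profile import nvtxit


@nvtxit(name='gnn_step', n_warmups=2, n_iters=10)
def gnn_step(model, x, edge_index):
    return model(x, edge_index)

# Run under Nsight Systems, e.g.:
#   nsys profile --capture-range=cudaProfilerApi python train.py
# The first two calls run without instrumentation; the next ten are wrapped in
# NVTX ranges named 'gnn_step_2' ... 'gnn_step_11'; later calls run unprofiled.
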