pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (205)
  1. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
  2. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
  3. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +26 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +16 -14
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/data.py +13 -8
  12. torch_geometric/data/database.py +15 -7
  13. torch_geometric/data/dataset.py +14 -6
  14. torch_geometric/data/feature_store.py +13 -22
  15. torch_geometric/data/graph_store.py +0 -4
  16. torch_geometric/data/hetero_data.py +4 -4
  17. torch_geometric/data/in_memory_dataset.py +2 -4
  18. torch_geometric/data/large_graph_indexer.py +677 -0
  19. torch_geometric/data/lightning/datamodule.py +4 -4
  20. torch_geometric/data/storage.py +15 -5
  21. torch_geometric/data/summary.py +14 -4
  22. torch_geometric/data/temporal.py +1 -2
  23. torch_geometric/datasets/__init__.py +11 -1
  24. torch_geometric/datasets/actor.py +9 -11
  25. torch_geometric/datasets/airfrans.py +15 -18
  26. torch_geometric/datasets/airports.py +10 -12
  27. torch_geometric/datasets/amazon.py +8 -11
  28. torch_geometric/datasets/amazon_book.py +9 -10
  29. torch_geometric/datasets/amazon_products.py +9 -10
  30. torch_geometric/datasets/aminer.py +8 -9
  31. torch_geometric/datasets/aqsol.py +10 -13
  32. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  33. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  34. torch_geometric/datasets/ba_shapes.py +5 -6
  35. torch_geometric/datasets/bitcoin_otc.py +1 -1
  36. torch_geometric/datasets/brca_tgca.py +1 -1
  37. torch_geometric/datasets/dblp.py +2 -1
  38. torch_geometric/datasets/dbp15k.py +2 -2
  39. torch_geometric/datasets/fake.py +1 -3
  40. torch_geometric/datasets/flickr.py +2 -1
  41. torch_geometric/datasets/freebase.py +1 -1
  42. torch_geometric/datasets/gdelt_lite.py +3 -2
  43. torch_geometric/datasets/ged_dataset.py +3 -2
  44. torch_geometric/datasets/git_mol_dataset.py +263 -0
  45. torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
  46. torch_geometric/datasets/hgb_dataset.py +8 -8
  47. torch_geometric/datasets/imdb.py +2 -1
  48. torch_geometric/datasets/last_fm.py +2 -1
  49. torch_geometric/datasets/linkx_dataset.py +4 -3
  50. torch_geometric/datasets/lrgb.py +3 -5
  51. torch_geometric/datasets/malnet_tiny.py +4 -3
  52. torch_geometric/datasets/mnist_superpixels.py +2 -3
  53. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  54. torch_geometric/datasets/molecule_net.py +7 -1
  55. torch_geometric/datasets/motif_generator/base.py +0 -1
  56. torch_geometric/datasets/neurograph.py +1 -3
  57. torch_geometric/datasets/ogb_mag.py +1 -1
  58. torch_geometric/datasets/opf.py +239 -0
  59. torch_geometric/datasets/ose_gvcs.py +1 -1
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  62. torch_geometric/datasets/pcqm4m.py +2 -1
  63. torch_geometric/datasets/ppi.py +1 -1
  64. torch_geometric/datasets/qm9.py +4 -3
  65. torch_geometric/datasets/reddit.py +2 -1
  66. torch_geometric/datasets/reddit2.py +2 -1
  67. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  68. torch_geometric/datasets/s3dis.py +2 -2
  69. torch_geometric/datasets/shapenet.py +3 -3
  70. torch_geometric/datasets/shrec2016.py +2 -2
  71. torch_geometric/datasets/tag_dataset.py +350 -0
  72. torch_geometric/datasets/upfd.py +2 -1
  73. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  74. torch_geometric/datasets/webkb.py +2 -2
  75. torch_geometric/datasets/wikics.py +1 -1
  76. torch_geometric/datasets/wikidata.py +3 -2
  77. torch_geometric/datasets/wikipedia_network.py +2 -2
  78. torch_geometric/datasets/word_net.py +2 -2
  79. torch_geometric/datasets/yelp.py +2 -1
  80. torch_geometric/datasets/zinc.py +1 -1
  81. torch_geometric/device.py +42 -0
  82. torch_geometric/distributed/local_feature_store.py +3 -2
  83. torch_geometric/distributed/local_graph_store.py +2 -1
  84. torch_geometric/distributed/partition.py +9 -8
  85. torch_geometric/edge_index.py +17 -8
  86. torch_geometric/explain/algorithm/base.py +0 -1
  87. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  88. torch_geometric/explain/explanation.py +2 -2
  89. torch_geometric/graphgym/checkpoint.py +2 -1
  90. torch_geometric/graphgym/logger.py +4 -4
  91. torch_geometric/graphgym/loss.py +1 -1
  92. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  93. torch_geometric/index.py +20 -7
  94. torch_geometric/inspector.py +6 -2
  95. torch_geometric/io/fs.py +28 -2
  96. torch_geometric/io/npz.py +2 -1
  97. torch_geometric/io/off.py +2 -2
  98. torch_geometric/io/sdf.py +2 -2
  99. torch_geometric/io/tu.py +2 -3
  100. torch_geometric/loader/__init__.py +4 -0
  101. torch_geometric/loader/cluster.py +9 -3
  102. torch_geometric/loader/graph_saint.py +2 -1
  103. torch_geometric/loader/ibmb_loader.py +12 -4
  104. torch_geometric/loader/mixin.py +1 -1
  105. torch_geometric/loader/neighbor_loader.py +1 -1
  106. torch_geometric/loader/neighbor_sampler.py +2 -2
  107. torch_geometric/loader/prefetch.py +1 -1
  108. torch_geometric/loader/rag_loader.py +107 -0
  109. torch_geometric/loader/zip_loader.py +10 -0
  110. torch_geometric/metrics/__init__.py +11 -2
  111. torch_geometric/metrics/link_pred.py +159 -34
  112. torch_geometric/nn/aggr/__init__.py +2 -0
  113. torch_geometric/nn/aggr/attention.py +0 -2
  114. torch_geometric/nn/aggr/base.py +2 -4
  115. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  116. torch_geometric/nn/aggr/set_transformer.py +1 -1
  117. torch_geometric/nn/attention/__init__.py +5 -1
  118. torch_geometric/nn/attention/qformer.py +71 -0
  119. torch_geometric/nn/conv/collect.jinja +6 -3
  120. torch_geometric/nn/conv/cugraph/base.py +0 -1
  121. torch_geometric/nn/conv/edge_conv.py +3 -2
  122. torch_geometric/nn/conv/gat_conv.py +35 -7
  123. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  124. torch_geometric/nn/conv/general_conv.py +1 -1
  125. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  126. torch_geometric/nn/conv/hetero_conv.py +3 -3
  127. torch_geometric/nn/conv/hgt_conv.py +1 -1
  128. torch_geometric/nn/conv/message_passing.py +100 -82
  129. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  130. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  131. torch_geometric/nn/conv/spline_conv.py +4 -4
  132. torch_geometric/nn/conv/x_conv.py +3 -2
  133. torch_geometric/nn/dense/linear.py +5 -4
  134. torch_geometric/nn/fx.py +3 -3
  135. torch_geometric/nn/model_hub.py +3 -1
  136. torch_geometric/nn/models/__init__.py +10 -2
  137. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  138. torch_geometric/nn/models/dimenet_utils.py +5 -7
  139. torch_geometric/nn/models/g_retriever.py +230 -0
  140. torch_geometric/nn/models/git_mol.py +336 -0
  141. torch_geometric/nn/models/glem.py +385 -0
  142. torch_geometric/nn/models/gnnff.py +0 -1
  143. torch_geometric/nn/models/graph_unet.py +12 -3
  144. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  145. torch_geometric/nn/models/lightgcn.py +1 -1
  146. torch_geometric/nn/models/metapath2vec.py +3 -4
  147. torch_geometric/nn/models/molecule_gpt.py +222 -0
  148. torch_geometric/nn/models/node2vec.py +1 -2
  149. torch_geometric/nn/models/schnet.py +2 -1
  150. torch_geometric/nn/models/signed_gcn.py +3 -3
  151. torch_geometric/nn/module_dict.py +2 -2
  152. torch_geometric/nn/nlp/__init__.py +9 -0
  153. torch_geometric/nn/nlp/llm.py +322 -0
  154. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  155. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  156. torch_geometric/nn/norm/batch_norm.py +1 -1
  157. torch_geometric/nn/parameter_dict.py +2 -2
  158. torch_geometric/nn/pool/__init__.py +7 -5
  159. torch_geometric/nn/pool/cluster_pool.py +145 -0
  160. torch_geometric/nn/pool/connect/base.py +0 -1
  161. torch_geometric/nn/pool/edge_pool.py +1 -1
  162. torch_geometric/nn/pool/graclus.py +4 -2
  163. torch_geometric/nn/pool/select/base.py +0 -1
  164. torch_geometric/nn/pool/voxel_grid.py +3 -2
  165. torch_geometric/nn/resolver.py +1 -1
  166. torch_geometric/nn/sequential.jinja +10 -23
  167. torch_geometric/nn/sequential.py +203 -77
  168. torch_geometric/nn/summary.py +1 -1
  169. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  170. torch_geometric/profile/__init__.py +2 -0
  171. torch_geometric/profile/nvtx.py +66 -0
  172. torch_geometric/profile/profiler.py +24 -15
  173. torch_geometric/resolver.py +1 -1
  174. torch_geometric/sampler/base.py +34 -13
  175. torch_geometric/sampler/neighbor_sampler.py +11 -10
  176. torch_geometric/testing/decorators.py +17 -22
  177. torch_geometric/transforms/__init__.py +2 -0
  178. torch_geometric/transforms/add_metapaths.py +4 -4
  179. torch_geometric/transforms/add_positional_encoding.py +1 -1
  180. torch_geometric/transforms/delaunay.py +65 -14
  181. torch_geometric/transforms/face_to_edge.py +32 -3
  182. torch_geometric/transforms/gdc.py +7 -6
  183. torch_geometric/transforms/laplacian_lambda_max.py +2 -2
  184. torch_geometric/transforms/mask.py +5 -1
  185. torch_geometric/transforms/node_property_split.py +1 -2
  186. torch_geometric/transforms/pad.py +7 -6
  187. torch_geometric/transforms/random_link_split.py +1 -1
  188. torch_geometric/transforms/remove_self_loops.py +36 -0
  189. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  190. torch_geometric/transforms/virtual_node.py +2 -1
  191. torch_geometric/typing.py +31 -5
  192. torch_geometric/utils/__init__.py +5 -1
  193. torch_geometric/utils/_negative_sampling.py +1 -1
  194. torch_geometric/utils/_normalize_edge_index.py +46 -0
  195. torch_geometric/utils/_scatter.py +37 -12
  196. torch_geometric/utils/_subgraph.py +4 -0
  197. torch_geometric/utils/_tree_decomposition.py +2 -2
  198. torch_geometric/utils/augmentation.py +1 -1
  199. torch_geometric/utils/convert.py +5 -5
  200. torch_geometric/utils/geodesic.py +24 -22
  201. torch_geometric/utils/hetero.py +1 -1
  202. torch_geometric/utils/map.py +1 -1
  203. torch_geometric/utils/smiles.py +66 -28
  204. torch_geometric/utils/sparse.py +25 -10
  205. torch_geometric/visualization/graph.py +3 -4
torch_geometric/nn/models/g_retriever.py (new file)
@@ -0,0 +1,230 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor
+
+ from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
+ from torch_geometric.utils import scatter
+
+
+ class GRetriever(torch.nn.Module):
+     r"""The G-Retriever model from the `"G-Retriever: Retrieval-Augmented
+     Generation for Textual Graph Understanding and Question Answering"
+     <https://arxiv.org/abs/2402.07630>`_ paper.
+
+     Args:
+         llm (LLM): The LLM to use.
+         gnn (torch.nn.Module): The GNN to use.
+         use_lora (bool, optional): If set to :obj:`True`, will use LoRA from
+             :obj:`peft` for training the LLM, see
+             `here <https://huggingface.co/docs/peft/en/index>`_ for details.
+             (default: :obj:`False`)
+         mlp_out_channels (int, optional): The size of each graph embedding
+             after projection. (default: :obj:`4096`)
+         mlp_out_tokens (int, optional): The number of LLM prefix tokens to
+             reserve for the GNN output. (default: :obj:`1`)
+
+     .. warning::
+         This module has been tested with the following HuggingFace models
+
+         * :obj:`llm_to_use="meta-llama/Llama-2-7b-chat-hf"`
+         * :obj:`llm_to_use="google/gemma-7b"`
+
+         and may not work with other models. See other models at `HuggingFace
+         Models <https://huggingface.co/models>`_ and let us know if you
+         encounter any issues.
+
+     .. note::
+         For an example of using :class:`GRetriever`, see
+         `examples/llm/g_retriever.py <https://github.com/pyg-team/
+         pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_.
+     """
+     def __init__(
+         self,
+         llm: LLM,
+         gnn: torch.nn.Module,
+         use_lora: bool = False,
+         mlp_out_channels: int = 4096,
+         mlp_out_tokens: int = 1,
+     ) -> None:
+         super().__init__()
+
+         self.llm = llm
+         self.gnn = gnn.to(self.llm.device)
+
+         self.word_embedding = self.llm.word_embedding
+         self.llm_generator = self.llm.llm
+         if use_lora:
+             from peft import (
+                 LoraConfig,
+                 get_peft_model,
+                 prepare_model_for_kbit_training,
+             )
+             self.llm_generator = prepare_model_for_kbit_training(
+                 self.llm_generator)
+             lora_r: int = 8
+             lora_alpha: int = 16
+             lora_dropout: float = 0.05
+             lora_target_modules = ['q_proj', 'v_proj']
+             config = LoraConfig(
+                 r=lora_r,
+                 lora_alpha=lora_alpha,
+                 target_modules=lora_target_modules,
+                 lora_dropout=lora_dropout,
+                 bias='none',
+                 task_type='CAUSAL_LM',
+             )
+             self.llm_generator = get_peft_model(self.llm_generator, config)
+
+         mlp_hidden_channels = self.gnn.out_channels
+         self.projector = torch.nn.Sequential(
+             torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
+             torch.nn.Sigmoid(),
+             torch.nn.Linear(mlp_hidden_channels,
+                             mlp_out_channels * mlp_out_tokens),
+             torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
+         ).to(self.llm.device)
+
+     def encode(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+     ) -> Tensor:
+         x = x.to(self.llm.device)
+         edge_index = edge_index.to(self.llm.device)
+         if edge_attr is not None:
+             edge_attr = edge_attr.to(self.llm.device)
+         batch = batch.to(self.llm.device)
+
+         out = self.gnn(x, edge_index, edge_attr=edge_attr)
+         return scatter(out, batch, dim=0, reduce='mean')
+
+     def forward(
+         self,
+         question: List[str],
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         label: List[str],
+         edge_attr: Optional[Tensor] = None,
+         additional_text_context: Optional[List[str]] = None,
+     ) -> Tensor:
+         r"""The forward pass, returning the language-modeling loss.
+
+         Args:
+             question (List[str]): The questions/prompts.
+             x (torch.Tensor): The input node features.
+             edge_index (torch.Tensor): The edge indices.
+             batch (torch.Tensor): The batch vector
+                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                 each element to a specific example.
+             label (List[str]): The answers/labels.
+             edge_attr (torch.Tensor, optional): The edge features (if
+                 supported by the GNN). (default: :obj:`None`)
+             additional_text_context (List[str], optional): Additional context
+                 to give to the LLM, such as textified knowledge graphs.
+                 (default: :obj:`None`)
+         """
+         x = self.encode(x, edge_index, batch, edge_attr)
+         x = self.projector(x)
+         xs = x.split(1, dim=0)
+
+         # Handle the case where there is more than one embedding per sample:
+         xs = [x.squeeze(0) for x in xs]
+
+         # Handle questions without node features:
+         batch_unique = batch.unique()
+         batch_size = len(question)
+         if len(batch_unique) < batch_size:
+             xs = [
+                 xs[i] if i in batch_unique else None
+                 for i in range(batch_size)
+             ]
+
+         (
+             inputs_embeds,
+             attention_mask,
+             label_input_ids,
+         ) = self.llm._get_embeds(question, additional_text_context, xs, label)
+
+         with self.llm.autocast_context:
+             outputs = self.llm_generator(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+                 return_dict=True,
+                 labels=label_input_ids,
+             )
+
+         return outputs.loss
+
+     @torch.no_grad()
+     def inference(
+         self,
+         question: List[str],
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor] = None,
+         additional_text_context: Optional[List[str]] = None,
+         max_out_tokens: Optional[int] = MAX_NEW_TOKENS,
+     ) -> List[str]:
+         r"""The inference pass, returning the generated answers.
+
+         Args:
+             question (List[str]): The questions/prompts.
+             x (torch.Tensor): The input node features.
+             edge_index (torch.Tensor): The edge indices.
+             batch (torch.Tensor): The batch vector
+                 :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                 each element to a specific example.
+             edge_attr (torch.Tensor, optional): The edge features (if
+                 supported by the GNN). (default: :obj:`None`)
+             additional_text_context (List[str], optional): Additional context
+                 to give to the LLM, such as textified knowledge graphs.
+                 (default: :obj:`None`)
+             max_out_tokens (int, optional): The maximum number of tokens for
+                 the LLM to generate. (default: :obj:`32`)
+         """
+         x = self.encode(x, edge_index, batch, edge_attr)
+         x = self.projector(x)
+         xs = x.split(1, dim=0)
+
+         # Handle the case where there is more than one embedding per sample:
+         xs = [x.squeeze(0) for x in xs]
+
+         # Handle questions without node features:
+         batch_unique = batch.unique()
+         batch_size = len(question)
+         if len(batch_unique) < batch_size:
+             xs = [
+                 xs[i] if i in batch_unique else None
+                 for i in range(batch_size)
+             ]
+
+         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
+             question, additional_text_context, xs)
+
+         bos_token = self.llm.tokenizer(
+             BOS,
+             add_special_tokens=False,
+         ).input_ids[0]
+
+         with self.llm.autocast_context:
+             outputs = self.llm_generator.generate(
+                 inputs_embeds=inputs_embeds,
+                 max_new_tokens=max_out_tokens,
+                 attention_mask=attention_mask,
+                 bos_token_id=bos_token,
+                 use_cache=True,  # Important to set!
+             )
+
+         return self.llm.tokenizer.batch_decode(
+             outputs,
+             skip_special_tokens=True,
+         )
+
+     def __repr__(self) -> str:
+         return (f'{self.__class__.__name__}(\n'
+                 f'  llm={self.llm},\n'
+                 f'  gnn={self.gnn},\n'
+                 f')')
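For orientation, here is a minimal training/inference sketch of the new GRetriever module. It is illustrative only: the LLM wrapper's constructor arguments (`model_name`, `num_params`) are assumed from `torch_geometric/nn/nlp/llm.py` (also added in this release), the GAT hyper-parameters are arbitrary, and the toy graph is fabricated.

```python
import torch

from torch_geometric.nn.models import GAT, GRetriever
from torch_geometric.nn.nlp import LLM

# Assumed constructor arguments; see torch_geometric/nn/nlp/llm.py for
# the exact signature of the LLM wrapper:
llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)

# Any GNN exposing an `out_channels` attribute can serve as the encoder:
gnn = GAT(in_channels=1024, hidden_channels=1024, out_channels=1024,
          num_layers=4, heads=4)

model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=4096)

# A fabricated single-graph batch with four nodes:
x = torch.randn(4, 1024)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
batch = torch.zeros(4, dtype=torch.long)

# Training: the forward pass returns the language-modeling loss:
loss = model(
    question=['What does this graph describe?'],
    x=x,
    edge_index=edge_index,
    batch=batch,
    label=['An example answer.'],
)
loss.backward()

# Inference: generates up to `max_out_tokens` new tokens per question:
preds = model.inference(
    question=['What does this graph describe?'],
    x=x,
    edge_index=edge_index,
    batch=batch,
)
```

The graph embedding is mean-pooled per example, projected, and prepended to the question's token embeddings as soft prefix tokens, so the LLM itself never needs architectural changes.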
torch_geometric/nn/models/git_mol.py (new file)
@@ -0,0 +1,336 @@
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
+
+ from torch_geometric.nn import GINEConv
+ from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer
+ from torch_geometric.utils import add_self_loops, to_dense_batch
+
+
+ class GraphEncoder(torch.nn.Module):
+     def __init__(
+         self,
+         num_layers: int,
+         in_channels: int,
+         dropout: float = 0.,
+         num_atom_type: int = 120,
+         num_chirality_tag: int = 3,
+         num_bond_type: int = 6,
+         num_bond_direction: int = 3,
+     ) -> None:
+         super().__init__()
+
+         self.num_layers = num_layers
+         self.dropout = dropout
+
+         self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels)
+         self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels)
+         self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels)
+         self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels)
+
+         self.gnns = torch.nn.ModuleList()
+         self.batch_norms = torch.nn.ModuleList()
+         for _ in range(num_layers):
+             self.gnns.append(
+                 GINEConv(
+                     nn=Sequential(
+                         Linear(in_channels, in_channels * 2),
+                         ReLU(),
+                         Linear(in_channels * 2, in_channels),
+                     ),
+                     train_eps=True,
+                     edge_dim=in_channels,
+                 ))
+             self.batch_norms.append(BatchNorm1d(in_channels))
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.xavier_uniform_(self.x_embed1.weight.data)
+         torch.nn.init.xavier_uniform_(self.x_embed2.weight.data)
+         torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data)
+         torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data)
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Tensor,
+     ) -> Tuple[Tensor, Tensor]:
+         x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long())
+         edge_index, edge_attr = add_self_loops(
+             edge_index,
+             edge_attr,
+             fill_value=0,
+             num_nodes=x.size(0),
+         )
+         edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2(
+             edge_attr[:, 1])
+         for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)):
+             x = gnn(x, edge_index, edge_attr)
+             x = bn(x)
+             if i < self.num_layers - 1:
+                 x = F.relu(x)
+             x = F.dropout(x, self.dropout, training=self.training)
+
+         x, mask = to_dense_batch(x, batch)
+         return x, mask
+
+
+ class GITFormer(torch.nn.Module):
+     def __init__(
+         self,
+         num_query_token: int,
+         vision_graph_width: int,
+         cross_attention_freq: int = 2,
+     ) -> None:
+         super().__init__()
+         from transformers import AutoConfig, AutoModel
+
+         config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased")
+         config.encoder_width = vision_graph_width
+         # Insert a cross-attention layer every other block:
+         config.add_cross_attention = True
+         config.is_decoder = True
+         config.cross_attention_freq = cross_attention_freq
+         config.query_length = num_query_token
+         self.Qformer = AutoModel.from_pretrained(
+             "allenai/scibert_scivocab_uncased", config=config)
+         self.query_tokens = torch.nn.Parameter(
+             torch.zeros(1, num_query_token, config.hidden_size))
+         self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range)
+
+
+ class GITMol(torch.nn.Module):
+     r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language
+     Model for Molecular Science with Graph, Image, and Text"
+     <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+     .. note::
+         For an example of using :class:`GITMol`, see
+         `examples/llm/git_mol.py <https://github.com/pyg-team/
+         pytorch_geometric/blob/master/examples/llm/git_mol.py>`_.
+     """
+     def __init__(self) -> None:
+         super().__init__()
+         # Graph:
+         self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)
+         self.graph_proj = Linear(16, 768)
+         self.ln_graph = LayerNorm(768)
+         # Text:
+         self.text_encoder = SentenceTransformer(
+             model_name='allenai/scibert_scivocab_uncased',
+             pooling_strategy='last_hidden_state',
+         )
+         self.text_proj = Linear(768, 768)
+         self.ln_text = LayerNorm(768)
+         # Vision:
+         self.vision_encoder = VisionTransformer(
+             model_name='microsoft/swin-base-patch4-window7-224')
+         self.vision_proj = Linear(1024, 768)
+         self.ln_vision = LayerNorm(768)
+         # Cross-attention:
+         self.gitformer = GITFormer(384, 768)
+
+         self.xtm_head = torch.nn.ModuleDict({
+             'image': Linear(self.gitformer.Qformer.config.hidden_size, 2),
+             'graph': Linear(self.gitformer.Qformer.config.hidden_size, 2),
+             'cs_text': Linear(self.gitformer.Qformer.config.hidden_size, 2),
+         })
+
+         self.xtc_proj = torch.nn.ModuleDict({
+             'image': Linear(self.gitformer.Qformer.config.hidden_size, 768),
+             'graph': Linear(self.gitformer.Qformer.config.hidden_size, 768),
+             'cs_text': Linear(self.gitformer.Qformer.config.hidden_size, 768),
+         })
+         self.temp = torch.nn.Parameter(0.07 * torch.ones([]))
+         self.model_freeze()
+
+     def model_freeze(self) -> None:
+         for param in self.graph_encoder.parameters():
+             param.requires_grad = False
+
+         for param in self.vision_encoder.parameters():
+             param.requires_grad = False
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+         smiles: List[str],
+         images: Tensor,
+         captions: List[str],
+     ) -> Tensor:
+         batch_size = len(smiles)
+
+         x_vision = self.vision_encoder(images)
+         x_vision = self.vision_proj(x_vision)
+         x_vision = self.ln_vision(x_vision)  # [bs, patch_len, d]
+         vision_atts = torch.ones(x_vision.size()[:-1],
+                                  dtype=torch.long).to(x_vision.device)
+         vision_targets = torch.arange(batch_size).to(x_vision.device)
+
+         x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,
+                                                  edge_attr)
+         x_graph = self.graph_proj(x_graph)
+         x_graph = self.ln_graph(x_graph)  # [bs, node_len, d]
+         graph_targets = torch.arange(batch_size).to(x_graph.device)
+
+         x_smiles = self.text_encoder.encode(smiles)  # [bs, seq_len, d]
+         smiles_atts = torch.ones(x_smiles.size()[:-1],
+                                  dtype=torch.long).to(x_smiles.device)
+         smiles_targets = torch.arange(batch_size).to(x_smiles.device)
+
+         caption_input_ids, caption_attention_masks = \
+             self.text_encoder.get_input_ids(captions)
+
+         text_output = self.gitformer.Qformer(
+             caption_input_ids,
+             attention_mask=caption_attention_masks,
+             return_dict=True,
+         )
+         text_feat = F.normalize(
+             self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+         loss = 0
+         for x_embed, x_atts, x_targets, modal in zip(
+                 [x_graph, x_smiles, x_vision],
+                 [graph_atts, smiles_atts, vision_atts],
+                 [graph_targets, smiles_targets, vision_targets],
+                 ['graph', 'cs_text', 'image'],
+         ):
+             loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,
+                                         modal)
+             loss += self._calc_xtm_loss(x_embed, caption_input_ids,
+                                         caption_attention_masks, modal)
+
+         # Average over three modalities with two loss terms each:
+         return loss / 6
+
+     def _calc_xtm_loss(
+         self,
+         x_embeds: Tensor,
+         input_ids: Tensor,
+         attention_mask: Tensor,
+         modal: str,
+     ) -> Tensor:
+         # Initialize lists to hold the original and negative samples:
+         x_embeds_list = []
+         text_input_ids_list = []
+         text_attention_mask_list = []
+
+         batch_size = x_embeds.size(0)
+         for i in range(batch_size):
+             # Original samples:
+             x_embeds_list.append(x_embeds[i])
+             text_input_ids_list.append(input_ids[i, :])
+             text_attention_mask_list.append(attention_mask[i, :])
+
+             if batch_size > 1:
+                 # Negative samples (neg_text_input_ids corresponds to
+                 # x_embeds):
+                 neg_text_input_ids = input_ids[i - 1 if i == batch_size -
+                                                1 else i + 1, :]
+                 neg_text_attention_mask = attention_mask[
+                     i - 1 if i == batch_size - 1 else i + 1, :]
+                 text_input_ids_list.append(neg_text_input_ids)
+                 text_attention_mask_list.append(neg_text_attention_mask)
+                 x_embeds_list.append(x_embeds[i, :])
+
+                 # Negative samples (text_input_ids corresponds to
+                 # neg_x_embeds):
+                 neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i +
+                                         1, :]
+                 x_embeds_list.append(neg_x_embeds)
+                 text_input_ids_list.append(input_ids[i, :])
+                 text_attention_mask_list.append(attention_mask[i, :])
+
+         # Stack all samples into two large tensors:
+         x_embeds_all = torch.stack(x_embeds_list, dim=1) \
+             .reshape(-1, x_embeds.size(1), x_embeds.size(2))
+         text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \
+             .reshape(-1, input_ids.size(1))
+         # Create image attention masks for the concatenated tensor:
+         image_attns_all = torch.ones(x_embeds_all.size()[:-1],
+                                      dtype=torch.long).to(x_embeds_all.device)
+         query_tokens_xtm = self.gitformer.query_tokens.expand(
+             text_input_ids_all.shape[0], -1, -1)
+         query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1],
+                                      dtype=torch.long).to(x_embeds_all.device)
+
+         output_xtm = self.gitformer.Qformer(
+             inputs_embeds=query_tokens_xtm,
+             attention_mask=query_attns_xtm,
+             encoder_hidden_states=x_embeds_all,
+             encoder_attention_mask=image_attns_all,
+             return_dict=True,
+         ).last_hidden_state
+
+         xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :]
+
+         xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1)
+         # Create labels: 1 for the original and 0 for the negative samples:
+         if batch_size > 1:
+             labels = torch.cat(
+                 [torch.ones(batch_size),
+                  torch.zeros(batch_size * 2)], dim=0)
+         else:
+             labels = torch.ones(batch_size)
+         labels = labels.long().to(xtm_logit.device)
+
+         # Calculate the cross-entropy loss:
+         return F.cross_entropy(xtm_logit, labels)
+
+     def _calc_xtc_loss(
+         self,
+         x_embeds: Tensor,
+         x_atts: Tensor,
+         x_targets: Tensor,
+         text_feat: Tensor,
+         modal: str,
+     ) -> Tensor:
+         query_tokens = self.gitformer.query_tokens.expand(
+             x_embeds.shape[0], -1, -1)
+
+         query_output = self.gitformer.Qformer(
+             inputs_embeds=query_tokens,
+             encoder_hidden_states=x_embeds,
+             encoder_attention_mask=x_atts,
+             return_dict=True,
+         ).last_hidden_state
+
+         x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)
+
+         sim_q2t = torch.matmul(
+             x_feats.unsqueeze(1),
+             text_feat.unsqueeze(-1),
+         ).squeeze(-1)
+
+         # Modal-text similarity: aggregate across all query tokens:
+         sim_x2t, _ = sim_q2t.max(-1)
+         sim_x2t = sim_x2t / self.temp
+
+         # Text-query similarity:
+         sim_t2q = torch.matmul(
+             text_feat.unsqueeze(1).unsqueeze(1),
+             x_feats.permute(0, 2, 1),
+         ).squeeze(-2)
+
+         # Text-modal similarity: aggregate across all query tokens:
+         sim_t2x, _ = sim_t2q.max(-1)
+         sim_t2x = sim_t2x / self.temp
+
+         loss_itc = (
+             F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) +
+             F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2
+
+         return loss_itc
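Likewise, a minimal sketch of driving GITMol end-to-end on fabricated inputs. The index ranges follow the GraphEncoder defaults above (atom types < 120, chirality tags < 3, bond types < 6, bond directions < 3), and the 224x224 image size is assumed from the Swin checkpoint name; the SciBERT and Swin weights are downloaded from HuggingFace on first use.

```python
import torch

from torch_geometric.nn.models import GITMol

model = GITMol()  # Encoders and the Q-Former are configured internally.

# Two fabricated molecules with three nodes and two edges each. Node
# features are (atom type, chirality tag) indices and edge features are
# (bond type, bond direction) indices, matching the GraphEncoder above:
x = torch.stack([
    torch.randint(0, 120, (6, )),
    torch.randint(0, 3, (6, )),
], dim=1)
edge_index = torch.tensor([[0, 1, 3, 4], [1, 2, 4, 5]])
edge_attr = torch.stack([
    torch.randint(0, 6, (4, )),
    torch.randint(0, 3, (4, )),
], dim=1)
batch = torch.tensor([0, 0, 0, 1, 1, 1])

images = torch.randn(2, 3, 224, 224)  # Assumed Swin input resolution.
smiles = ['CCO', 'c1ccccc1']
captions = ['ethanol', 'benzene']

# Returns the combined contrastive (xtc) and matching (xtm) loss,
# averaged over the graph, SMILES, and image modalities:
loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)
loss.backward()  # Graph and vision encoders stay frozen via model_freeze().
```

Note that the graph and vision encoders are frozen at construction time, so gradients only flow into the projections, the text encoder, and the Q-Former.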