pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101

torch_geometric/llm/models/__init__.py (new file, item 86 above)

@@ -0,0 +1,23 @@
+ from .sentence_transformer import SentenceTransformer
+ from .vision_transformer import VisionTransformer
+ from .llm import LLM
+ from .txt2kg import TXT2KG
+ from .llm_judge import LLMJudge
+ from .g_retriever import GRetriever
+ from .molecule_gpt import MoleculeGPT
+ from .glem import GLEM
+ from .protein_mpnn import ProteinMPNN
+ from .git_mol import GITMol
+
+ __all__ = classes = [
+     'SentenceTransformer',
+     'VisionTransformer',
+     'LLM',
+     'LLMJudge',
+     'TXT2KG',
+     'GRetriever',
+     'MoleculeGPT',
+     'GLEM',
+     'ProteinMPNN',
+     'GITMol',
+ ]
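
Together with items 228 and 229 in the file list, this new `__init__.py` completes the move of the NLP/LLM models out of `torch_geometric.nn.nlp` into the new top-level `torch_geometric.llm.models` namespace. A minimal sketch of the import migration for downstream code, assuming the 2.8.0 nightly wheel is installed (the names come from the `__all__` list above; the exact old import line is inferred from items 228-229 and the g_retriever hunk below):

    # Old namespace, removed in 2.8.0.dev20251228:
    # from torch_geometric.nn.nlp import LLM, SentenceTransformer

    # New namespace, with the moved classes plus new additions:
    from torch_geometric.llm.models import (
        LLM,
        GRetriever,
        SentenceTransformer,
    )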

torch_geometric/{nn → llm}/models/g_retriever.py (item 87 above)

@@ -3,7 +3,7 @@ from typing import List, Optional
  import torch
  from torch import Tensor

- from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
+ from torch_geometric.llm.models.llm import LLM, MAX_NEW_TOKENS
  from torch_geometric.utils import scatter


@@ -19,17 +19,19 @@ class GRetriever(torch.nn.Module):
              :obj:`peft` for training the LLM, see
              `here <https://huggingface.co/docs/peft/en/index>`_ for details.
              (default: :obj:`False`)
-         mlp_out_channels (int, optional): The size of each graph embedding
-             after projection. (default: :obj:`4096`)
+         mlp_out_tokens (int, optional): Number of LLM prefix tokens to
+             reserve for GNN output. (default: :obj:`1`)

      .. warning::
          This module has been tested with the following HuggingFace models
+         * :obj:`llm_to_use="meta-llama/Meta-Llama-3.1-8B-Instruct"`
+         * :obj:`llm_to_use="Qwen/Qwen3-0.6B"`

-         * :obj:`llm_to_use="meta-llama/Llama-2-7b-chat-hf"`
-         * :obj:`llm_to_use="google/gemma-7b"`

-         and may not work with other models. See other models at `HuggingFace
-         Models <https://huggingface.co/models>`_ and let us know if you
+         This module should work with any HuggingFace model.
+         See other models at `HuggingFace
+         Models <https://huggingface.co/models>`_
+         and let us know if you
          encounter any issues.

      .. note::

@@ -40,14 +42,14 @@ class GRetriever(torch.nn.Module):
      def __init__(
          self,
          llm: LLM,
-         gnn: torch.nn.Module,
+         gnn: torch.nn.Module = None,
          use_lora: bool = False,
-         mlp_out_channels: int = 4096,
+         mlp_out_tokens: int = 1,
      ) -> None:
          super().__init__()

          self.llm = llm
-         self.gnn = gnn.to(self.llm.device)
+         self.gnn = gnn.to(self.llm.device) if gnn is not None else None

          self.word_embedding = self.llm.word_embedding
          self.llm_generator = self.llm.llm

@@ -73,12 +75,18 @@ class GRetriever(torch.nn.Module):
          )
          self.llm_generator = get_peft_model(self.llm_generator, config)

-         mlp_hidden_channels = self.gnn.out_channels
-         self.projector = torch.nn.Sequential(
-             torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
-             torch.nn.Sigmoid(),
-             torch.nn.Linear(mlp_hidden_channels, mlp_out_channels),
-         ).to(self.llm.device)
+         if self.gnn is not None:
+             mlp_out_channels = llm.word_embedding.embedding_dim
+             mlp_hidden_channels = self.gnn.out_channels
+             self.projector = torch.nn.Sequential(
+                 torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),
+                 torch.nn.Sigmoid(),
+                 torch.nn.Linear(mlp_hidden_channels,
+                                 mlp_out_channels * mlp_out_tokens),
+                 torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),
+             ).to(self.llm.device)
+
+         self.seq_length_stats = []

      def encode(
          self,

@@ -93,7 +101,16 @@ class GRetriever(torch.nn.Module):
          edge_attr = edge_attr.to(self.llm.device)
          batch = batch.to(self.llm.device)

-         out = self.gnn(x, edge_index, edge_attr=edge_attr)
+         model_specific_kwargs = {}
+
+         # duck typing for SGFormer to get around circular import
+         if (hasattr(self.gnn, 'trans_conv')
+                 and hasattr(self.gnn, 'graph_conv')):
+             model_specific_kwargs['batch'] = batch
+         else:
+             model_specific_kwargs['edge_attr'] = edge_attr
+
+         out = self.gnn(x, edge_index, **model_specific_kwargs)
          return scatter(out, batch, dim=0, reduce='mean')

      def forward(

@@ -122,24 +139,32 @@ class GRetriever(torch.nn.Module):
              to give to the LLM, such as textified knowledge graphs.
              (default: :obj:`None`)
          """
-         x = self.encode(x, edge_index, batch, edge_attr)
-         x = self.projector(x)
-         xs = x.split(1, dim=0)
-
-         # Handle questions without node features:
-         batch_unique = batch.unique()
-         batch_size = len(question)
-         if len(batch_unique) < batch_size:
-             xs = [
-                 xs[i] if i in batch_unique else None for i in range(batch_size)
-             ]
-
+         xs = None
+         if self.gnn is not None:
+             x = self.encode(x, edge_index, batch, edge_attr)
+             x = self.projector(x)
+             xs = x.split(1, dim=0)
+
+             # Handle case where there's more than one embedding per sample:
+             xs = [x.squeeze(0) for x in xs]
+
+             # Handle questions without node features:
+             batch_unique = batch.unique()
+             batch_size = len(question)
+             if len(batch_unique) < batch_size:
+                 xs = [
+                     xs[i] if i in batch_unique else None
+                     for i in range(batch_size)
+                 ]
          (
              inputs_embeds,
              attention_mask,
              label_input_ids,
          ) = self.llm._get_embeds(question, additional_text_context, xs, label)

+         max_seq_len = inputs_embeds.size(1)
+         self.seq_length_stats.append(max_seq_len)
+
          with self.llm.autocast_context:
              outputs = self.llm_generator(
                  inputs_embeds=inputs_embeds,

@@ -178,32 +203,39 @@ class GRetriever(torch.nn.Module):
          max_out_tokens (int, optional): How many tokens for the LLM to
              generate. (default: :obj:`32`)
          """
-         x = self.encode(x, edge_index, batch, edge_attr)
-         x = self.projector(x)
-         xs = x.split(1, dim=0)
-
-         # Handle questions without node features:
-         batch_unique = batch.unique()
-         batch_size = len(question)
-         if len(batch_unique) < batch_size:
-             xs = [
-                 xs[i] if i in batch_unique else None for i in range(batch_size)
-             ]
+         xs = None
+         if self.gnn is not None:
+             x = self.encode(x, edge_index, batch, edge_attr)
+             x = self.projector(x)
+             xs = x.split(1, dim=0)
+
+             # Handle case where there's more than one embedding per sample:
+             xs = [x.squeeze(0) for x in xs]
+
+             # Handle questions without node features:
+             batch_unique = batch.unique()
+             batch_size = len(question)
+             if len(batch_unique) < batch_size:
+                 xs = [
+                     xs[i] if i in batch_unique else None
+                     for i in range(batch_size)
+                 ]

          inputs_embeds, attention_mask, _ = self.llm._get_embeds(
              question, additional_text_context, xs)

-         bos_token = self.llm.tokenizer(
-             BOS,
-             add_special_tokens=False,
-         ).input_ids[0]
+         # bos_token = self.llm.tokenizer(
+         #     self.llm.tokenizer.bos_token_id,
+         #     add_special_tokens=False,
+         # ).input_ids[0]

          with self.llm.autocast_context:
              outputs = self.llm_generator.generate(
                  inputs_embeds=inputs_embeds,
                  max_new_tokens=max_out_tokens,
                  attention_mask=attention_mask,
-                 bos_token_id=bos_token,
+                 bos_token_id=self.llm.tokenizer.bos_token_id,
+                 pad_token_id=self.llm.tokenizer.eos_token_id,
                  use_cache=True  # Important to set!
              )


torch_geometric/llm/models/git_mol.py (new file, item 88 above)

@@ -0,0 +1,336 @@
+ from typing import List, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential
+
+ from torch_geometric.llm.models import SentenceTransformer, VisionTransformer
+ from torch_geometric.nn import GINEConv
+ from torch_geometric.utils import add_self_loops, to_dense_batch
+
+
+ class GraphEncoder(torch.nn.Module):
+     def __init__(
+         self,
+         num_layers: int,
+         in_channels: int,
+         dropout: float = 0.,
+         num_atom_type: int = 120,
+         num_chirality_tag: int = 3,
+         num_bond_type: int = 6,
+         num_bond_direction: int = 3,
+     ) -> None:
+         super().__init__()
+
+         self.num_layers = num_layers
+         self.dropout = dropout
+
+         self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels)
+         self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels)
+         self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels)
+         self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels)
+
+         self.gnns = torch.nn.ModuleList()
+         self.batch_norms = torch.nn.ModuleList()
+         for _ in range(num_layers):
+             self.gnns.append(
+                 GINEConv(
+                     nn=Sequential(
+                         Linear(in_channels, in_channels * 2),
+                         ReLU(),
+                         Linear(in_channels * 2, in_channels),
+                     ),
+                     train_eps=True,
+                     edge_dim=in_channels,
+                 ))
+             self.batch_norms.append(BatchNorm1d(in_channels))
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.xavier_uniform_(self.x_embed1.weight.data)
+         torch.nn.init.xavier_uniform_(self.x_embed2.weight.data)
+         torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data)
+         torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data)
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Tensor,
+     ) -> Tensor:
+         x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long())
+         edge_index, edge_attr = add_self_loops(
+             edge_index,
+             edge_attr,
+             fill_value=0,
+             num_nodes=x.size(0),
+         )
+         edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2(
+             edge_attr[:, 1])
+         for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)):
+             x = gnn(x, edge_index, edge_attr)
+             x = bn(x)
+             if i < self.num_layers - 1:
+                 x = F.relu(x)
+             x = F.dropout(x, self.dropout, training=self.training)
+
+         x, mask = to_dense_batch(x, batch)
+         return x, mask
+
+
+ class GITFormer(torch.nn.Module):
+     def __init__(
+         self,
+         num_query_token: int,
+         vision_graph_width: int,
+         cross_attention_freq: int = 2,
+     ):
+         super().__init__()
+         from transformers import AutoConfig, AutoModel
+
+         config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased")
+         config.encoder_width = vision_graph_width
+         # insert cross-attention layer every other block
+         config.add_cross_attention = True
+         config.is_decoder = True
+         config.cross_attention_freq = cross_attention_freq
+         config.query_length = num_query_token
+         self.Qformer = AutoModel.from_pretrained(
+             "allenai/scibert_scivocab_uncased", config=config)
+         self.query_tokens = torch.nn.Parameter(
+             torch.zeros(1, num_query_token, config.hidden_size))
+         self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range)
+
+
+ class GITMol(torch.nn.Module):
+     r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language
+     Model for Molecular Science with Graph, Image, and Text"
+     <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+     .. note::
+         For an example of using :class:`GITMol`, see
+         `examples/llm/git_mol.py <https://github.com/pyg-team/
+         pytorch_geometric/blob/master/examples/llm/git_mol.py>`_.
+     """
+     def __init__(self) -> None:
+         super().__init__()
+         # graph
+         self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)
+         self.graph_proj = Linear(16, 768)
+         self.ln_graph = LayerNorm(768)
+         # text
+         self.text_encoder = SentenceTransformer(
+             model_name='allenai/scibert_scivocab_uncased',
+             pooling_strategy='last_hidden_state',
+         )
+         self.text_proj = Linear(768, 768)
+         self.ln_text = LayerNorm(768)
+         # vision
+         self.vision_encoder = VisionTransformer(
+             model_name='microsoft/swin-base-patch4-window7-224', )
+         self.vision_proj = Linear(1024, 768)
+         self.ln_vision = LayerNorm(768)
+         # cross-attention
+         self.gitformer = GITFormer(384, 768)
+
+         self.xtm_head = torch.nn.ModuleDict({
+             'image':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+             'graph':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+             'cs_text':
+             Linear(self.gitformer.Qformer.config.hidden_size, 2),
+         })
+
+         self.xtc_proj = torch.nn.ModuleDict({
+             'image':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+             'graph':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+             'cs_text':
+             Linear(self.gitformer.Qformer.config.hidden_size, 768),
+         })
+         self.temp = torch.nn.Parameter(0.07 * torch.ones([]))
+         self.model_freeze()
+
+     def model_freeze(self) -> None:
+         for param in self.graph_encoder.parameters():
+             param.requires_grad = False
+
+         for param in self.vision_encoder.parameters():
+             param.requires_grad = False
+
+     def forward(
+         self,
+         x: Tensor,
+         edge_index: Tensor,
+         batch: Tensor,
+         edge_attr: Optional[Tensor],
+         smiles: List[str],
+         images: Tensor,
+         captions: List[str],
+     ) -> Tensor:
+         batch_size = len(smiles)
+
+         x_vision = self.vision_encoder(images)
+         x_vision = self.vision_proj(x_vision)
+         x_vision = self.ln_vision(x_vision)  # [bs, patch_len, d]
+         vision_atts = torch.ones(x_vision.size()[:-1],
+                                  dtype=torch.long).to(x_vision.device)
+         vision_targets = torch.arange(batch_size).to(x_vision.device)
+
+         x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,
+                                                  edge_attr)
+         x_graph = self.graph_proj(x_graph)
+         x_graph = self.ln_graph(x_graph)  # [bs, node_len, d]
+         graph_targets = torch.arange(batch_size).to(x_graph.device)
+
+         x_smiles = self.text_encoder.encode(smiles)  # [bs, seq_len, d]
+         smiles_atts = torch.ones(x_smiles.size()[:-1],
+                                  dtype=torch.long).to(x_smiles.device)
+         smiles_targets = torch.arange(batch_size).to(x_smiles.device)
+
+         caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids(  # noqa: E501
+             captions)
+
+         text_output = self.gitformer.Qformer(
+             caption_input_ids,
+             attention_mask=caption_attention_masks,
+             return_dict=True,
+         )
+         text_feat = F.normalize(
+             self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+         loss = 0
+         for x_embed, x_atts, x_targets, modal in zip(
+             [x_graph, x_smiles, x_vision],
+             [graph_atts, smiles_atts, vision_atts],
+             [graph_targets, smiles_targets, vision_targets],
+             ['graph', 'cs_text', 'image'],
+         ):
+             loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,
+                                         modal)
+             loss += self._calc_xtm_loss(x_embed, caption_input_ids,
+                                         caption_attention_masks, modal)
+
+         return loss / 6
+
+     def _calc_xtm_loss(
+         self,
+         x_embeds: Tensor,
+         input_ids: Tensor,
+         attention_mask: Tensor,
+         modal: str,
+     ) -> Tensor:
+         # Initializing lists to hold the original and negative samples
+         x_embeds_list = []
+         text_input_ids_list = []
+         text_attention_mask_list = []
+
+         batch_size = x_embeds.size(0)
+         for i in range(batch_size):
+             # Original samples
+             x_embeds_list.append(x_embeds[i])
+             text_input_ids_list.append(input_ids[i, :])
+             text_attention_mask_list.append(attention_mask[i, :])
+
+             if batch_size > 1:
+                 # Negative samples (neg_text_input_ids corresponds to x_embeds)
+                 neg_text_input_ids = input_ids[i - 1 if i == batch_size -
+                                                1 else i + 1, :]
+                 neg_text_attention_mask = attention_mask[i -
+                                                          1 if i == batch_size -
+                                                          1 else i + 1, :]
+                 text_input_ids_list.append(neg_text_input_ids)
+                 text_attention_mask_list.append(neg_text_attention_mask)
+                 x_embeds_list.append(x_embeds[i, :])
+
+                 # Negative samples (text_input_ids corresponds to neg_x_embeds)
+                 neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i +
+                                         1, :]
+                 x_embeds_list.append(neg_x_embeds)
+                 text_input_ids_list.append(input_ids[i, :])
+                 text_attention_mask_list.append(attention_mask[i, :])
+
+         # Stack all samples into two large tensors
+         x_embeds_all = torch.stack(x_embeds_list, dim=1) \
+             .reshape(-1, x_embeds.size(1), x_embeds.size(2))
+         text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \
+             .reshape(-1, input_ids.size(1))
+         # Create image attention masks for the concatenated tensor
+         image_attns_all = torch.ones(x_embeds_all.size()[:-1],
+                                      dtype=torch.long).to(x_embeds_all.device)
+         query_tokens_xtm = self.gitformer.query_tokens.expand(
+             text_input_ids_all.shape[0], -1, -1)
+         query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1],
+                                      dtype=torch.long).to(x_embeds_all.device)
+
+         output_xtm = self.gitformer.Qformer(
+             inputs_embeds=query_tokens_xtm,
+             attention_mask=query_attns_xtm,
+             encoder_hidden_states=x_embeds_all,
+             encoder_attention_mask=image_attns_all,
+             return_dict=True,
+         ).last_hidden_state
+
+         xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :]
+
+         xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1)
+         # Create labels: 1 for the original samples, 0 for the negative samples
+         if batch_size > 1:
+             labels = torch.cat(
+                 [torch.ones(batch_size),
+                  torch.zeros(batch_size * 2)], dim=0)
+         else:
+             labels = torch.ones(batch_size)
+         labels = labels.long().to(xtm_logit.device)
+
+         # Calculate cross entropy loss
+         return F.cross_entropy(xtm_logit, labels)
+
+     def _calc_xtc_loss(
+         self,
+         x_embeds: Tensor,
+         x_atts: Tensor,
+         x_targets: Tensor,
+         text_feat: Tensor,
+         modal: str,
+     ) -> Tensor:
+         query_tokens = self.gitformer.query_tokens.expand(
+             x_embeds.shape[0], -1, -1)
+
+         query_output = self.gitformer.Qformer(
+             inputs_embeds=query_tokens,
+             encoder_hidden_states=x_embeds,
+             encoder_attention_mask=x_atts,
+             return_dict=True,
+         ).last_hidden_state
+
+         x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)
+
+         sim_q2t = torch.matmul(
+             x_feats.unsqueeze(1),
+             text_feat.unsqueeze(-1),
+         ).squeeze(-1)
+
+         # modal-text similarity: aggregate across all query tokens
+         sim_x2t, _ = sim_q2t.max(-1)
+         sim_x2t = sim_x2t / self.temp
+
+         # text-query similarity
+         sim_t2q = torch.matmul(
+             text_feat.unsqueeze(1).unsqueeze(1),
+             x_feats.permute(0, 2, 1),
+         ).squeeze(-2)
+
+         # text-modal similarity: aggregate across all query tokens
+         sim_t2x, _ = sim_t2q.max(-1)
+         sim_t2x = sim_t2x / self.temp
+
+         loss_itc = (
+             F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) +
+             F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2
+
+         return loss_itc
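
As the forward signature above shows, `GITMol` consumes three aligned views of each molecule: a graph (`x` holds integer atom-type/chirality indices and `edge_attr` bond-type/direction indices, matching the embedding tables in `GraphEncoder`), a rendered image sized for the Swin-base encoder, and SMILES/caption strings; it returns the mean of the six contrastive and matching losses. A toy smoke-test sketch under those assumptions (the shapes and inputs are hypothetical; the real pipeline lives in `examples/llm/git_mol.py`, and instantiation downloads SciBERT and Swin weights):

    import torch
    from torch_geometric.llm.models import GITMol

    model = GITMol()

    # Two toy molecules with four atoms each; columns are integer indices
    # into the (atom type, chirality) and (bond type, direction) tables.
    x = torch.stack([torch.randint(0, 120, (8, )),
                     torch.randint(0, 3, (8, ))], dim=1)
    edge_index = torch.tensor([[0, 1, 4, 5], [1, 2, 5, 6]])
    edge_attr = torch.stack([torch.randint(0, 6, (4, )),
                             torch.randint(0, 3, (4, ))], dim=1)
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
    images = torch.rand(2, 3, 224, 224)  # Swin-base expects 224x224 RGB
    smiles = ['CC(=O)O', 'c1ccccc1']
    captions = ['acetic acid', 'benzene']

    loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)
    loss.backward()  # graph/vision encoders stay frozen via model_freeze()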