pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -0,0 +1,333 @@
1
+ from itertools import product
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from torch_geometric.nn import knn_graph
8
+ from torch_geometric.nn.conv import MessagePassing
9
+ from torch_geometric.utils import to_dense_adj, to_dense_batch
10
+
11
+
12
class PositionWiseFeedForward(torch.nn.Module):
    r"""Two-layer position-wise feed-forward network.

    Projects the last dimension from :obj:`in_channels` to
    :obj:`hidden_channels`, applies a GELU non-linearity, and projects back
    to :obj:`in_channels`, so the output has the same shape as the input.

    Args:
        in_channels (int): Size of the last input (and output) dimension.
        hidden_channels (int): Size of the intermediate projection.
    """
    def __init__(self, in_channels: int, hidden_channels: int) -> None:
        super().__init__()
        layers = [
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, in_channels),
        ]
        # Kept as one ``Sequential`` named ``out`` so state-dict keys
        # (``out.0.weight``, ``out.2.bias``, ...) stay stable.
        self.out = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Applies the feed-forward network to the last dimension of
        :obj:`x`."""
        projected = self.out(x)
        return projected
23
+
24
+
25
class PositionalEncoding(torch.nn.Module):
    r"""Learned embedding of clipped relative offsets.

    Offsets are shifted by :obj:`max_relative_feature` and clipped into
    :obj:`[0, 2 * max_relative_feature]`. Wherever :obj:`mask` is zero, the
    extra bucket :obj:`2 * max_relative_feature + 1` is used instead.

    Args:
        hidden_channels (int): Size of each embedding vector.
        max_relative_feature (int, optional): Largest absolute offset that
            receives its own bucket. (default: :obj:`32`)
    """
    def __init__(self, hidden_channels: int,
                 max_relative_feature: int = 32) -> None:
        super().__init__()
        self.max_relative_feature = max_relative_feature
        # 2 * max + 1 clipped offset buckets plus one "masked" bucket.
        num_buckets = 2 * max_relative_feature + 2
        self.emb = torch.nn.Embedding(num_buckets, hidden_channels)

    def forward(self, offset: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        r"""Embeds :obj:`offset`, routing entries with a zero :obj:`mask`
        to the dedicated masked bucket."""
        max_rel = self.max_relative_feature
        clipped = torch.clamp(offset + max_rel, 0, 2 * max_rel)
        masked_bucket = 2 * max_rel + 1
        bucket = clipped * mask + (1 - mask) * masked_bucket
        return self.emb(bucket.long())
38
+
39
+
40
class Encoder(MessagePassing):
    r"""A single ProteinMPNN encoder layer.

    Node features are updated from per-edge messages, which are aggregated,
    divided by :obj:`scale`, and followed by a position-wise feed-forward
    block. Edge features are then updated from the refreshed features of
    both endpoints.

    Args:
        in_channels (int): Size of the concatenation of two node feature
            vectors and one edge feature vector, which is the input of both
            internal MLPs.
        hidden_channels (int): Size of node and edge features.
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.1`)
        scale (float, optional): Divisor applied to the aggregated messages.
            (default: :obj:`30`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        dropout: float = 0.1,
        scale: float = 30,
    ) -> None:
        super().__init__()
        # Message MLP for the node update, applied to ``[x_i, x_j, e_ij]``.
        self.out_v = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
        )
        # Separate MLP for the edge update.
        self.out_e = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
        )
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(hidden_channels)
        self.norm2 = torch.nn.LayerNorm(hidden_channels)
        self.norm3 = torch.nn.LayerNorm(hidden_channels)
        self.scale = scale
        self.dense = PositionWiseFeedForward(hidden_channels,
                                             hidden_channels * 4)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Runs one encoder step.

        Args:
            x (torch.Tensor): Node features :obj:`[N, hidden_channels]`.
            edge_index (torch.Tensor): Graph connectivity :obj:`[2, E]`.
            edge_attr (torch.Tensor): Edge features
                :obj:`[E, hidden_channels]`.

        Returns:
            The updated node features and the updated edge features.
        """
        # Node update. Messages are aggregated with this layer's aggregation;
        # no ``aggr`` is passed, so the ``MessagePassing`` default applies.
        h_message = self.propagate(x=x, edge_index=edge_index,
                                   edge_attr=edge_attr)
        dh = h_message / self.scale
        x = self.norm1(x + self.dropout1(dh))
        dh = self.dense(x)
        x = self.norm2(x + self.dropout2(dh))

        # Edge update from the refreshed endpoint features, in the order
        # [x[edge_index[0]], x[edge_index[1]], edge_attr].
        src, dst = edge_index
        h_e = torch.cat([x[src], x[dst], edge_attr], dim=-1)
        h_e = self.out_e(h_e)
        edge_attr = self.norm3(edge_attr + self.dropout3(h_e))
        return x, edge_attr

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        # Node messages use the dedicated message MLP ``out_v``; ``out_e``
        # is reserved for the edge update in ``forward``.
        h = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.out_v(h)
102
+
103
+
104
class Decoder(MessagePassing):
    r"""A single ProteinMPNN decoder layer.

    Each message combines the target node, the source node, the edge, and
    the source node's label embedding. The label part is replaced by zeros
    wherever the per-edge :obj:`mask` is zero. Aggregated messages are
    divided by :obj:`scale`, then a position-wise feed-forward block follows.

    Args:
        in_channels (int): Size of the message MLP input, *i.e.* two node
            feature vectors, one edge feature vector and one label vector.
        hidden_channels (int): Size of node, edge and label features.
        dropout (float, optional): Dropout probability.
            (default: :obj:`0.1`)
        scale (float, optional): Divisor applied to the aggregated messages.
            (default: :obj:`30`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        dropout: float = 0.1,
        scale: float = 30,
    ) -> None:
        super().__init__()
        mlp_layers = [
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_channels, hidden_channels),
        ]
        self.out_v = torch.nn.Sequential(*mlp_layers)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(hidden_channels)
        self.norm2 = torch.nn.LayerNorm(hidden_channels)
        self.scale = scale
        self.dense = PositionWiseFeedForward(hidden_channels,
                                             hidden_channels * 4)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        x_label: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        r"""Runs one decoder step and returns the updated node features.

        :obj:`x_label` holds per-node label embeddings. :obj:`mask` holds one
        value per edge (broadcast over features) that selects whether the
        source node's label is visible in that edge's message.
        """
        aggregated = self.propagate(x=x, x_label=x_label,
                                    edge_index=edge_index,
                                    edge_attr=edge_attr, mask=mask)
        x = self.norm1(x + self.dropout1(aggregated / self.scale))
        x = self.norm2(x + self.dropout2(self.dense(x)))
        return x

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,
                x_label_j: torch.Tensor, edge_attr: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        with_label = torch.cat([x_j, edge_attr, x_label_j], dim=-1)
        without_label = torch.cat(
            [x_j, edge_attr, torch.zeros_like(x_label_j)], dim=-1)
        # Blend: mask == 1 exposes the source label, mask == 0 hides it.
        neighbor = with_label * mask + without_label * (1 - mask)
        return self.out_v(torch.cat([x_i, neighbor], dim=-1))
156
+
157
+
158
class ProteinMPNN(torch.nn.Module):
    r"""The ProteinMPNN model from the `"Robust deep learning--based
    protein sequence design using ProteinMPNN"
    <https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.

    Args:
        hidden_dim (int): Hidden channels.
            (default: :obj:`128`)
        num_encoder_layers (int): Number of encoder layers.
            (default: :obj:`3`)
        num_decoder_layers (int): Number of decoder layers.
            (default: :obj:`3`)
        num_neighbors (int): Number of neighbors of each residue in the
            :math:`k`-NN graph built on C-alpha coordinates.
            (default: :obj:`30`)
        num_rbf (int): Number of radial basis functions per atom pair.
            (default: :obj:`16`)
        dropout (float): Dropout rate.
            (default: :obj:`0.1`)
        augment_eps (float): Standard deviation of the Gaussian noise added
            to the input coordinates during training.
            (default: :obj:`0.2`)
        num_positional_embedding (int): Size of the relative positional
            embedding. (default: :obj:`16`)
        vocab_size (int): Size of the residue vocabulary.
            (default: :obj:`21`)

    .. note::
        For an example of using :class:`ProteinMPNN`, see
        `examples/llm/protein_mpnn.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/llm/protein_mpnn.py>`_.
    """
    def __init__(
        self,
        hidden_dim: int = 128,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        num_neighbors: int = 30,
        num_rbf: int = 16,
        dropout: float = 0.1,
        augment_eps: float = 0.2,
        num_positional_embedding: int = 16,
        vocab_size: int = 21,
    ) -> None:
        super().__init__()
        self.augment_eps = augment_eps
        self.hidden_dim = hidden_dim
        self.num_neighbors = num_neighbors
        self.num_rbf = num_rbf
        self.embedding = PositionalEncoding(num_positional_embedding)
        # `_featurize` emits `num_rbf` features for each of the 5 x 5
        # ordered pairs of backbone atoms (N, Ca, C, O, Cb).
        num_rbf_features = 25 * num_rbf
        self.edge_mlp = torch.nn.Sequential(
            torch.nn.Linear(num_positional_embedding + num_rbf_features,
                            hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )
        self.label_embedding = torch.nn.Embedding(vocab_size, hidden_dim)
        self.encoder_layers = torch.nn.ModuleList([
            Encoder(hidden_dim * 3, hidden_dim, dropout)
            for _ in range(num_encoder_layers)
        ])

        self.decoder_layers = torch.nn.ModuleList([
            Decoder(hidden_dim * 4, hidden_dim, dropout)
            for _ in range(num_decoder_layers)
        ])
        self.output = torch.nn.Linear(hidden_dim, vocab_size)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Applies Xavier-uniform initialization to every parameter with
        more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

    def _featurize(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        batch: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Builds the residue graph and its distance features.

        Args:
            x (torch.Tensor): Backbone coordinates. :obj:`x[:, 0:4]` are read
                as the N, C-alpha, C and O atoms, each a 3-D point.
            mask (torch.Tensor): Per-residue mask. Residues with a zero entry
                are left out of the :math:`k`-NN graph.
            batch (torch.Tensor): Batch vector of the residues.

        Returns:
            The edge indices, expressed in the unfiltered residue indexing,
            and for every edge the RBF-encoded distances of all 25 ordered
            pairs of the atoms N, C-alpha, C, O and C-beta.
        """
        N, Ca, C, O = (x[:, i, :] for i in range(4))  # noqa: E741
        b = Ca - N
        c = C - Ca
        a = torch.cross(b, c, dim=-1)
        # C-beta position reconstructed from the three backbone vectors.
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

        valid_mask = mask.bool()
        edge_index = knn_graph(Ca[valid_mask], k=self.num_neighbors,
                               batch=batch[valid_mask], loop=True)

        # Map indices of the filtered residue set back to the full set.
        original_indices = torch.arange(Ca.size(0),
                                        device=x.device)[valid_mask]
        edge_index = original_indices[edge_index]

        row, col = edge_index
        rbf_all = []
        for A, B in product([N, Ca, C, O, Cb], repeat=2):
            distances = torch.sqrt(torch.sum((A[row] - B[col])**2, 1) + 1e-6)
            rbf_all.append(self._rbf(distances))

        return edge_index, torch.cat(rbf_all, dim=-1)

    def _rbf(self, D: torch.Tensor) -> torch.Tensor:
        r"""Expands distances into :obj:`num_rbf` Gaussian features with
        centers evenly spaced in :obj:`[2, 22]` and width
        :obj:`(22 - 2) / num_rbf`. Adds a trailing dimension of size
        :obj:`num_rbf`."""
        D_min, D_max, D_count = 2., 22., self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
        D_mu = D_mu.view([1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
        return RBF

    @staticmethod
    def _decoding_mask(
        chain_mask: torch.Tensor,
        batch: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        r"""Samples a random decoding order and returns a per-edge mask.

        The sort key is :obj:`(chain_mask + 1e-4) * |noise|`, so residues with
        a zero :obj:`chain_mask` get a near-zero key and are placed first in
        the order.

        Returns:
            A float tensor of shape :obj:`[E, 1]` that is :obj:`1` where the
            edge's source residue is decoded strictly before its target
            residue, *i.e.* where the target may see the source's label.
        """
        dense_chain_mask, node_mask = to_dense_batch(chain_mask, batch)
        noise = torch.abs(
            torch.randn(dense_chain_mask.shape,
                        device=dense_chain_mask.device))
        # decoding_order[b, k] = local index of the node decoded at step k.
        decoding_order = torch.argsort((dense_chain_mask + 1e-4) * noise)
        # Inverse permutation: position[b, n] = step at which n is decoded.
        position = torch.argsort(decoding_order, dim=-1)
        # Back to the flat node layout (same order as the input nodes).
        node_position = position[node_mask]
        # Computing the mask per edge keeps it aligned with ``edge_index``.
        src, dst = edge_index
        return (node_position[src] < node_position[dst]).float().unsqueeze(-1)

    def forward(
        self,
        x: torch.Tensor,
        chain_seq_label: torch.Tensor,
        mask: torch.Tensor,
        chain_mask_all: torch.Tensor,
        residue_idx: torch.Tensor,
        chain_encoding_all: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        r"""Predicts residue-type log-probabilities.

        Args:
            x (torch.Tensor): Backbone coordinates per residue; see
                :meth:`_featurize`.
            chain_seq_label (torch.Tensor): Integer residue labels, embedded
                with a :obj:`vocab_size`-entry embedding.
            mask (torch.Tensor): Per-residue mask of residues to include.
            chain_mask_all (torch.Tensor): Per-residue mask. Residues where
                :obj:`chain_mask_all * mask` is zero are decoded first.
            residue_idx (torch.Tensor): Per-residue index, used for relative
                positional offsets along edges.
            chain_encoding_all (torch.Tensor): Per-residue chain identifier.
                Edges between equal identifiers are treated as intra-chain.
            batch (torch.Tensor): Batch vector of the residues.

        Returns:
            Log-probabilities of shape :obj:`[num_residues, vocab_size]`.
        """
        if self.training and self.augment_eps > 0:
            x = x + self.augment_eps * torch.randn_like(x)

        edge_index, edge_attr = self._featurize(x, mask, batch)

        row, col = edge_index
        offset = residue_idx[row] - residue_idx[col]
        # 1 for edges within the same chain, 0 across chains.
        e_chains = ((chain_encoding_all[row] -
                     chain_encoding_all[col]) == 0).long()
        e_pos = self.embedding(offset, e_chains)
        h_e = self.edge_mlp(torch.cat([edge_attr, e_pos], dim=-1))
        h_v = torch.zeros(x.size(0), self.hidden_dim, device=x.device)

        # encoder
        for encoder in self.encoder_layers:
            h_v, h_e = encoder(h_v, edge_index, h_e)

        # decoder, restricted by a randomly sampled decoding order
        h_label = self.label_embedding(chain_seq_label)
        mask_attend = self._decoding_mask(chain_mask_all * mask, batch,
                                          edge_index)
        for decoder in self.decoder_layers:
            h_v = decoder(
                h_v,
                edge_index,
                h_e,
                h_label,
                mask_attend,
            )

        logits = self.output(h_v)
        return F.log_softmax(logits, dim=-1)
@@ -0,0 +1,188 @@
1
from enum import Enum
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm
8
+
9
+
10
class PoolingStrategy(Enum):
    r"""Ways of reducing token embeddings to one vector per input.

    The string values are what :class:`SentenceTransformer` accepts for its
    :obj:`pooling_strategy` argument.
    """
    # Average over tokens where the attention mask is set.
    MEAN = 'mean'
    # Embedding of the last attended token.
    LAST = 'last'
    # Embedding of the first token.
    CLS = 'cls'
    # Unpooled token embeddings from the model's last hidden state.
    LAST_HIDDEN_STATE = 'last_hidden_state'
15
+
16
+
17
class SentenceTransformer(torch.nn.Module):
    r"""A wrapper around a Sentence-Transformer from HuggingFace.

    Args:
        model_name (str): The HuggingFace model name, *e.g.*, :obj:`"BERT"`.
        pooling_strategy (str, optional): The pooling strategy to use
            for generating node embeddings, one of :obj:`"mean"`,
            :obj:`"last"`, :obj:`"cls"` or :obj:`"last_hidden_state"`.
            (default: :obj:`"mean"`)
    """
    def __init__(
        self,
        model_name: str,
        pooling_strategy: Union[PoolingStrategy, str] = 'mean',
    ) -> None:
        super().__init__()

        self.model_name = model_name
        self.pooling_strategy = PoolingStrategy(pooling_strategy)

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # Fall back to the EOS token when no pad token is defined.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Maximum sequence length from the model configuration (e.g. 8192
        # for models like ModernBERT). Some models define a max sequence
        # length in their configuration, others only in the tokenizer. As a
        # heuristic, a probe string is padded to the tokenizer's
        # ``max_length`` and the smaller of the two limits is kept.
        self.max_seq_length = self.model.config.max_position_embeddings
        probe_tokens = self.tokenizer("hacky heuristic", padding='max_length',
                                      return_tensors='pt')
        self.max_seq_length = min(self.max_seq_length,
                                  probe_tokens.input_ids.shape[1])

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        r"""Embeds a batch of tokenized inputs.

        Args:
            input_ids (torch.Tensor): Token indices.
            attention_mask (torch.Tensor): Mask of non-padding tokens.

        Returns:
            The embeddings pooled according to :obj:`pooling_strategy` and
            L2-normalized along dimension 1.
        """
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)

        emb = out[0]  # First element contains all token embeddings.
        if self.pooling_strategy == PoolingStrategy.MEAN:
            emb = mean_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST:
            emb = last_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:
            emb = out.last_hidden_state
        else:
            assert self.pooling_strategy == PoolingStrategy.CLS
            emb = emb[:, 0, :]

        # NOTE(review): with LAST_HIDDEN_STATE the embedding keeps its token
        # dimension, so ``dim=1`` normalizes across tokens rather than
        # features -- confirm this is intended.
        emb = F.normalize(emb, p=2, dim=1)
        return emb

    def _tokenize(self, text: List[str]):
        r"""Tokenizes one batch of strings, padded to the longest string in
        the batch and truncated to :obj:`max_seq_length`."""
        return self.tokenizer(
            text,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=self.max_seq_length,
        )

    def _pad_and_cat(self, tensors: List[Tensor], value: int) -> Tensor:
        r"""Concatenates per-batch token tensors along dimension 0.

        Each batch is padded independently by the tokenizer, so batches can
        differ in sequence length. Shorter ones are padded with :obj:`value`
        on the tokenizer's padding side before concatenation.
        """
        if len(tensors) == 1:
            return tensors[0]
        max_len = max(t.size(1) for t in tensors)
        pad_left = getattr(self.tokenizer, 'padding_side', 'right') == 'left'
        padded = []
        for t in tensors:
            diff = max_len - t.size(1)
            pad = (diff, 0) if pad_left else (0, diff)
            padded.append(F.pad(t, pad, value=value))
        return torch.cat(padded, dim=0)

    def get_input_ids(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Tokenizes :obj:`text` in batches.

        Args:
            text (List[str]): Strings to tokenize.
            batch_size (int, optional): Number of strings tokenized at once.
                Defaults to all of them. (default: :obj:`None`)
            output_device (Union[torch.device, str], optional): Device of the
                returned tensors. If :obj:`None`, they stay on the model's
                device. (default: :obj:`None`)

        Returns:
            The input IDs and attention masks, with one row per string and all
            rows padded to a common length.
        """
        is_empty = len(text) == 0
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        input_ids: List[Tensor] = []
        attention_masks: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self._tokenize(text[start:start + batch_size])
            input_ids.append(token.input_ids.to(self.device))
            attention_masks.append(token.attention_mask.to(self.device))

        def _out(x: List[Tensor], pad_value: int) -> Tensor:
            out = self._pad_and_cat(x, pad_value)
            out = out[:0] if is_empty else out
            return out.to(output_device)

        return (_out(input_ids, self.tokenizer.pad_token_id),
                _out(attention_masks, 0))

    @property
    def device(self) -> torch.device:
        return next(iter(self.model.parameters())).device

    def _embed_batch(
        self,
        token,
        output_device: Optional[Union[torch.device, str]],
    ) -> Tensor:
        r"""Embeds one tokenized batch. If the model is not on the CPU and
        the forward pass raises, the batch is retried on the CPU. The model
        is always moved back to its original device afterwards."""
        model_device = self.device
        try:
            return self(
                input_ids=token.input_ids.to(model_device),
                attention_mask=token.attention_mask.to(model_device),
            ).to(output_device)
        except Exception:
            if model_device.type == 'cpu':
                # Retrying on the same device cannot help.
                raise
            # fallback to using CPU for huge strings that cause OOMs
            print("Sentence Transformer failed on cuda, trying w/ cpu...")
            self.model = self.model.to("cpu")
            try:
                emb = self(
                    input_ids=token.input_ids.to("cpu"),
                    attention_mask=token.attention_mask.to("cpu"),
                )
            finally:
                self.model = self.model.to(model_device)
            # Place the result where successful batches end up, so that the
            # final concatenation does not mix devices.
            target = model_device if output_device is None else output_device
            return emb.to(target)

    @torch.no_grad()
    def encode(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
        verbose: bool = False,
    ) -> Tensor:
        r"""Main function for users. Converts strings to embeddings.

        Args:
            text (List[str]): List of strings to embed.
            batch_size (int, optional): How many strings to process at once.
                Defaults to processing all at once, but this may lead to
                OOM errors. (default: :obj:`None`)
            output_device (Union[torch.device, str], optional): Device of the
                returned embeddings. If :obj:`None`, they stay on the model's
                device. (default: :obj:`None`)
            verbose (bool, optional): Controls the verbosity of outputs.
                (default: :obj:`False`)
        """
        is_empty = len(text) == 0
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        embs: List[Tensor] = []
        loader = range(0, len(text), batch_size)
        if verbose:
            loader = tqdm(
                loader, desc="Encoding " + str(len(text)) +
                " strings w/ SentenceTransformer")
        for start in loader:
            token = self._tokenize(text[start:start + batch_size])
            embs.append(self._embed_batch(token, output_device))

        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
        out = out[:0] if is_empty else out
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'
173
+
174
+
175
def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Averages token embeddings over positions where
    :obj:`attention_mask` is set.

    Args:
        emb (torch.Tensor): Token embeddings indexed as
            :obj:`[batch, token, feature]`.
        attention_mask (torch.Tensor): Mask indexed as :obj:`[batch, token]`.

    Returns:
        One vector per batch entry. Rows with no attended token yield zeros,
        because the token count is clamped to at least :obj:`1e-9`.
    """
    weights = attention_mask.unsqueeze(-1).to(emb.dtype)  # [B, T, 1]
    summed = (emb * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)  # broadcast over features
    return summed / counts
178
+
179
+
180
def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Selects the embedding of the last attended token of each sequence.

    If every sequence attends to its final position (left padding, which is
    always used for decoder LLMs), the last column is returned directly.
    Otherwise, the index of the last token is derived from the number of
    attended tokens per row.
    """
    num_rows = attention_mask.size(0)
    attends_last = attention_mask[:, -1].sum()
    if attends_last == num_rows:
        return emb[:, -1]

    last_index = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(emb.size(0), device=emb.device)
    return emb[row_index, last_index]