boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/layers/triangular_attention/utils.py
@@ -0,0 +1,380 @@
+ # Copyright 2021 AlQuraishi Laboratory
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import partial
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+
+ import torch
+
+
+ def add(m1, m2, inplace):
+     # The first operation in a checkpoint can't be in-place, but it's
+     # nice to have in-place addition during inference. Thus...
+     if not inplace:
+         m1 = m1 + m2
+     else:
+         m1 += m2
+
+     return m1
+
+
+ def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
+     zero_index = -1 * len(inds)
+     first_inds = list(range(len(tensor.shape[:zero_index])))
+     return tensor.permute(first_inds + [zero_index + i for i in inds])
+
+
+ def is_fp16_enabled():
+     # Autocast world
+     fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
+     fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
+
+     return fp16_enabled
+
+
+ # With tree_map, a poor man's JAX tree_map
+ def dict_map(fn, dic, leaf_type):
+     new_dict = {}
+     for k, v in dic.items():
+         if type(v) is dict:
+             new_dict[k] = dict_map(fn, v, leaf_type)
+         else:
+             new_dict[k] = tree_map(fn, v, leaf_type)
+
+     return new_dict
+
+
+ def tree_map(fn, tree, leaf_type):
+     if isinstance(tree, dict):
+         return dict_map(fn, tree, leaf_type)
+     elif isinstance(tree, list):
+         return [tree_map(fn, x, leaf_type) for x in tree]
+     elif isinstance(tree, tuple):
+         return tuple([tree_map(fn, x, leaf_type) for x in tree])
+     elif isinstance(tree, leaf_type):
+         return fn(tree)
+     else:
+         raise ValueError(f"Tree of type {type(tree)} not supported")
+
+
+ tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
+
+
+ def flatten_final_dims(t: torch.Tensor, no_dims: int):
+     return t.reshape(t.shape[:-no_dims] + (-1,))
+
+
+ def _fetch_dims(tree):
+     shapes = []
+     tree_type = type(tree)
+     if tree_type is dict:
+         for v in tree.values():
+             shapes.extend(_fetch_dims(v))
+     elif tree_type is list or tree_type is tuple:
+         for t in tree:
+             shapes.extend(_fetch_dims(t))
+     elif tree_type is torch.Tensor:
+         shapes.append(tree.shape)
+     else:
+         raise ValueError("Not supported")
+
+     return shapes
+
+
+ @torch.jit.ignore
+ def _flat_idx_to_idx(
+     flat_idx: int,
+     dims: Tuple[int],
+ ) -> Tuple[int]:
+     idx = []
+     for d in reversed(dims):
+         idx.append(flat_idx % d)
+         flat_idx = flat_idx // d
+
+     return tuple(reversed(idx))
+
+
+ @torch.jit.ignore
+ def _get_minimal_slice_set(
+     start: Sequence[int],
+     end: Sequence[int],
+     dims: Sequence[int],
+     start_edges: Optional[Sequence[bool]] = None,
+     end_edges: Optional[Sequence[bool]] = None,
+ ) -> Sequence[Tuple[int]]:
+     """
+     Produces an ordered sequence of tensor slices that, when used in
+     sequence on a tensor with shape dims, yields tensors that contain every
+     leaf in the contiguous range [start, end]. Care is taken to yield a
+     short sequence of slices, and perhaps even the shortest possible (I'm
+     pretty sure it's the latter).
+
+     end is INCLUSIVE.
+     """
+
+     # start_edges and end_edges both indicate whether, starting from any given
+     # dimension, the start/end index is at the top/bottom edge of the
+     # corresponding tensor, modeled as a tree
+     def reduce_edge_list(l):
+         tally = 1
+         for i in range(len(l)):
+             reversed_idx = -1 * (i + 1)
+             l[reversed_idx] *= tally
+             tally = l[reversed_idx]
+
+     if start_edges is None:
+         start_edges = [s == 0 for s in start]
+         reduce_edge_list(start_edges)
+     if end_edges is None:
+         end_edges = [e == (d - 1) for e, d in zip(end, dims)]
+         reduce_edge_list(end_edges)
+
+     # Base cases. Either start/end are empty and we're done, or the final,
+     # one-dimensional tensor can be simply sliced
+     if len(start) == 0:
+         return [tuple()]
+     elif len(start) == 1:
+         return [(slice(start[0], end[0] + 1),)]
+
+     slices = []
+     path = []
+
+     # Dimensions common to start and end can be selected directly
+     for s, e in zip(start, end):
+         if s == e:
+             path.append(slice(s, s + 1))
+         else:
+             break
+
+     path = tuple(path)
+     divergence_idx = len(path)
+
+     # start == end, and we're done
+     if divergence_idx == len(dims):
+         return [tuple(path)]
+
+     def upper():
+         sdi = start[divergence_idx]
+         return [
+             path + (slice(sdi, sdi + 1),) + s
+             for s in _get_minimal_slice_set(
+                 start[divergence_idx + 1 :],
+                 [d - 1 for d in dims[divergence_idx + 1 :]],
+                 dims[divergence_idx + 1 :],
+                 start_edges=start_edges[divergence_idx + 1 :],
+                 end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
+             )
+         ]
+
+     def lower():
+         edi = end[divergence_idx]
+         return [
+             path + (slice(edi, edi + 1),) + s
+             for s in _get_minimal_slice_set(
+                 [0 for _ in start[divergence_idx + 1 :]],
+                 end[divergence_idx + 1 :],
+                 dims[divergence_idx + 1 :],
+                 start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
+                 end_edges=end_edges[divergence_idx + 1 :],
+             )
+         ]
+
+     # If both start and end are at the edges of the subtree rooted at
+     # divergence_idx, we can just select the whole subtree at once
+     if start_edges[divergence_idx] and end_edges[divergence_idx]:
+         slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
+     # If just start is at the edge, we can grab almost all of the subtree,
+     # treating only the ragged bottom edge as an edge case
+     elif start_edges[divergence_idx]:
+         slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
+         slices.extend(lower())
+     # Analogous to the previous case, but the top is ragged this time
+     elif end_edges[divergence_idx]:
+         slices.extend(upper())
+         slices.append(
+             path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
+         )
+     # If both sides of the range are ragged, we need to handle both sides
+     # separately. If there's contiguous meat in between them, we can index it
+     # in one big chunk
+     else:
+         slices.extend(upper())
+         middle_ground = end[divergence_idx] - start[divergence_idx]
+         if middle_ground > 1:
+             slices.append(
+                 path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
+             )
+         slices.extend(lower())
+
+     return [tuple(s) for s in slices]
+
+
+ @torch.jit.ignore
+ def _chunk_slice(
+     t: torch.Tensor,
+     flat_start: int,
+     flat_end: int,
+     no_batch_dims: int,
+ ) -> torch.Tensor:
+     """
+     Equivalent to
+
+         t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
+
+     but without the need for the initial reshape call, which can be
+     memory-intensive in certain situations. The only reshape operations
+     in this function are performed on sub-tensors that scale with
+     (flat_end - flat_start), the chunk size.
+     """
+
+     batch_dims = t.shape[:no_batch_dims]
+     start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
+     # _get_minimal_slice_set is inclusive
+     end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
+
+     # Get an ordered list of slices to perform
+     slices = _get_minimal_slice_set(
+         start_idx,
+         end_idx,
+         batch_dims,
+     )
+
+     sliced_tensors = [t[s] for s in slices]
+
+     return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])
+
+
+ def chunk_layer(
+     layer: Callable,
+     inputs: Dict[str, Any],
+     chunk_size: int,
+     no_batch_dims: int,
+     low_mem: bool = False,
+     _out: Any = None,
+     _add_into_out: bool = False,
+ ) -> Any:
+     """
+     Implements the "chunking" procedure described in section 1.11.8.
+
+     Layer outputs and inputs are assumed to be simple "pytrees,"
+     consisting only of (arbitrarily nested) lists, tuples, and dicts with
+     torch.Tensor leaves.
+
+     Args:
+         layer:
+             The layer to be applied chunk-wise
+         inputs:
+             A (non-nested) dictionary of keyworded inputs. All leaves must
+             be tensors and must share the same batch dimensions.
+         chunk_size:
+             The number of sub-batches per chunk. If multiple batch
+             dimensions are specified, a "sub-batch" is defined as a single
+             indexing of all batch dimensions simultaneously (s.t. the
+             number of sub-batches is the product of the batch dimensions).
+         no_batch_dims:
+             How many of the initial dimensions of each input tensor can
+             be considered batch dimensions.
+         low_mem:
+             Avoids flattening potentially large input tensors. Unnecessary
+             in most cases, and is ever so slightly slower than the default
+             setting.
+     Returns:
+         The reassembled output of the layer on the inputs.
+     """
+     if not (len(inputs) > 0):
+         raise ValueError("Must provide at least one input")
+
+     initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
+     orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
+
+     def _prep_inputs(t):
+         if not low_mem:
+             if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
+                 t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
+             t = t.reshape(-1, *t.shape[no_batch_dims:])
+         else:
+             t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
+         return t
+
+     prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
+     prepped_outputs = None
+     if _out is not None:
+         reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
+         prepped_outputs = tensor_tree_map(reshape_fn, _out)
+
+     flat_batch_dim = 1
+     for d in orig_batch_dims:
+         flat_batch_dim *= d
+
+     no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
+
+     i = 0
+     out = prepped_outputs
+     for _ in range(no_chunks):
+         # Chunk the input
+         if not low_mem:
+             select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
+         else:
+             select_chunk = partial(
+                 _chunk_slice,
+                 flat_start=i,
+                 flat_end=min(flat_batch_dim, i + chunk_size),
+                 no_batch_dims=len(orig_batch_dims),
+             )
+
+         chunks = tensor_tree_map(select_chunk, prepped_inputs)
+
+         # Run the layer on the chunk
+         output_chunk = layer(**chunks)
+
+         # Allocate space for the output
+         if out is None:
+             allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
+             out = tensor_tree_map(allocate, output_chunk)
+
+         # Put the chunk in its pre-allocated space
+         out_type = type(output_chunk)
+         if out_type is dict:
+
+             def assign(d1, d2):
+                 for k, v in d1.items():
+                     if type(v) is dict:
+                         assign(v, d2[k])
+                     else:
+                         if _add_into_out:
+                             v[i : i + chunk_size] += d2[k]
+                         else:
+                             v[i : i + chunk_size] = d2[k]
+
+             assign(out, output_chunk)
+         elif out_type is tuple:
+             for x1, x2 in zip(out, output_chunk):
+                 if _add_into_out:
+                     x1[i : i + chunk_size] += x2
+                 else:
+                     x1[i : i + chunk_size] = x2
+         elif out_type is torch.Tensor:
+             if _add_into_out:
+                 out[i : i + chunk_size] += output_chunk
+             else:
+                 out[i : i + chunk_size] = output_chunk
+         else:
+             raise ValueError("Not supported")
+
+         i += chunk_size
+
+     reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
+     out = tensor_tree_map(reshape, out)
+
+     return out
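
Note: the hunk above (the +380 entry, boltz/model/layers/triangular_attention/utils.py) carries OpenFold's Apache license header. A minimal usage sketch of `chunk_layer`, not part of the wheel; `toy_layer` and the shapes are illustrative. It shows how the batch dimensions are flattened and the layer is applied `chunk_size` sub-batches at a time:

    import torch

    from boltz.model.layers.triangular_attention.utils import chunk_layer

    def toy_layer(z: torch.Tensor) -> torch.Tensor:
        # Stand-in for an expensive per-sub-batch computation
        return z * 2.0 + 1.0

    z = torch.randn(4, 8, 16)  # batch dims (4, 8), feature dim 16
    out = chunk_layer(
        toy_layer,
        {"z": z},
        chunk_size=5,      # 4 * 8 = 32 sub-batches, processed 5 at a time
        no_batch_dims=2,   # treat the first two dims as batch dims
    )
    assert torch.allclose(out, toy_layer(z))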
boltz/model/layers/triangular_mult.py
@@ -0,0 +1,212 @@
+ import torch
+ from cuequivariance_torch.primitives.triangle import triangle_multiplicative_update
+ from torch import Tensor, nn
+
+ from boltz.model.layers import initialize as init
+
+
+ @torch.compiler.disable
+ def kernel_triangular_mult(
+     x,
+     direction,
+     mask,
+     norm_in_weight,
+     norm_in_bias,
+     p_in_weight,
+     g_in_weight,
+     norm_out_weight,
+     norm_out_bias,
+     p_out_weight,
+     g_out_weight,
+     eps,
+ ):
+     return triangle_multiplicative_update(
+         x,
+         direction=direction,
+         mask=mask,
+         norm_in_weight=norm_in_weight,
+         norm_in_bias=norm_in_bias,
+         p_in_weight=p_in_weight,
+         g_in_weight=g_in_weight,
+         norm_out_weight=norm_out_weight,
+         norm_out_bias=norm_out_bias,
+         p_out_weight=p_out_weight,
+         g_out_weight=g_out_weight,
+         eps=eps,
+     )
+
+
+ class TriangleMultiplicationOutgoing(nn.Module):
+     """TriangleMultiplicationOutgoing."""
+
+     def __init__(self, dim: int = 128) -> None:
+         """Initialize the TriangularUpdate module.
+
+         Parameters
+         ----------
+         dim: int
+             The dimension of the input, default 128
+
+         """
+         super().__init__()
+
+         self.norm_in = nn.LayerNorm(dim, eps=1e-5)
+         self.p_in = nn.Linear(dim, 2 * dim, bias=False)
+         self.g_in = nn.Linear(dim, 2 * dim, bias=False)
+
+         self.norm_out = nn.LayerNorm(dim)
+         self.p_out = nn.Linear(dim, dim, bias=False)
+         self.g_out = nn.Linear(dim, dim, bias=False)
+
+         init.bias_init_one_(self.norm_in.weight)
+         init.bias_init_zero_(self.norm_in.bias)
+
+         init.lecun_normal_init_(self.p_in.weight)
+         init.gating_init_(self.g_in.weight)
+
+         init.bias_init_one_(self.norm_out.weight)
+         init.bias_init_zero_(self.norm_out.bias)
+
+         init.final_init_(self.p_out.weight)
+         init.gating_init_(self.g_out.weight)
+
+     def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
+         """Perform a forward pass.
+
+         Parameters
+         ----------
+         x: torch.Tensor
+             The input data of shape (B, N, N, D)
+         mask: torch.Tensor
+             The input mask of shape (B, N, N)
+         use_kernels: bool
+             Whether to use the kernel
+
+         Returns
+         -------
+         x: torch.Tensor
+             The output data of shape (B, N, N, D)
+
+         """
+         if use_kernels:
+             return kernel_triangular_mult(
+                 x,
+                 direction="outgoing",
+                 mask=mask,
+                 norm_in_weight=self.norm_in.weight,
+                 norm_in_bias=self.norm_in.bias,
+                 p_in_weight=self.p_in.weight,
+                 g_in_weight=self.g_in.weight,
+                 norm_out_weight=self.norm_out.weight,
+                 norm_out_bias=self.norm_out.bias,
+                 p_out_weight=self.p_out.weight,
+                 g_out_weight=self.g_out.weight,
+                 eps=1e-5,
+             )
+
108
+ x = self.norm_in(x)
109
+ x_in = x
110
+ x = self.p_in(x) * self.g_in(x).sigmoid()
111
+
112
+ # Apply mask
113
+ x = x * mask.unsqueeze(-1)
114
+
115
+ # Split input and cast to float
116
+ a, b = torch.chunk(x.float(), 2, dim=-1)
117
+
118
+ # Triangular projection
119
+ x = torch.einsum("bikd,bjkd->bijd", a, b)
120
+
121
+ # Output gating
122
+ x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
123
+
124
+ return x
125
+
126
+
127
+ class TriangleMultiplicationIncoming(nn.Module):
128
+ """TriangleMultiplicationIncoming."""
129
+
130
+ def __init__(self, dim: int = 128) -> None:
131
+ """Initialize the TriangularUpdate module.
132
+
133
+ Parameters
134
+ ----------
135
+ dim: int
136
+ The dimension of the input, default 128
137
+
138
+ """
139
+ super().__init__()
140
+
141
+ self.norm_in = nn.LayerNorm(dim, eps=1e-5)
142
+ self.p_in = nn.Linear(dim, 2 * dim, bias=False)
143
+ self.g_in = nn.Linear(dim, 2 * dim, bias=False)
144
+
145
+ self.norm_out = nn.LayerNorm(dim)
146
+ self.p_out = nn.Linear(dim, dim, bias=False)
147
+ self.g_out = nn.Linear(dim, dim, bias=False)
148
+
149
+ init.bias_init_one_(self.norm_in.weight)
150
+ init.bias_init_zero_(self.norm_in.bias)
151
+
152
+ init.lecun_normal_init_(self.p_in.weight)
153
+ init.gating_init_(self.g_in.weight)
154
+
155
+ init.bias_init_one_(self.norm_out.weight)
156
+ init.bias_init_zero_(self.norm_out.bias)
157
+
158
+ init.final_init_(self.p_out.weight)
159
+ init.gating_init_(self.g_out.weight)
160
+
161
+ def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
162
+ """Perform a forward pass.
163
+
164
+ Parameters
165
+ ----------
166
+ x: torch.Tensor
167
+ The input data of shape (B, N, N, D)
168
+ mask: torch.Tensor
169
+ The input mask of shape (B, N, N)
170
+ use_kernels: bool
171
+ Whether to use the kernel
172
+
173
+ Returns
174
+ -------
175
+ x: torch.Tensor
176
+ The output data of shape (B, N, N, D)
177
+
178
+ """
179
+ if use_kernels:
180
+ return kernel_triangular_mult(
181
+ x,
182
+ direction="incoming",
183
+ mask=mask,
184
+ norm_in_weight=self.norm_in.weight,
185
+ norm_in_bias=self.norm_in.bias,
186
+ p_in_weight=self.p_in.weight,
187
+ g_in_weight=self.g_in.weight,
188
+ norm_out_weight=self.norm_out.weight,
189
+ norm_out_bias=self.norm_out.bias,
190
+ p_out_weight=self.p_out.weight,
191
+ g_out_weight=self.g_out.weight,
192
+ eps=1e-5,
193
+ )
194
+
195
+ # Input gating: D -> D
196
+ x = self.norm_in(x)
197
+ x_in = x
198
+ x = self.p_in(x) * self.g_in(x).sigmoid()
199
+
200
+ # Apply mask
201
+ x = x * mask.unsqueeze(-1)
202
+
203
+ # Split input and cast to float
204
+ a, b = torch.chunk(x.float(), 2, dim=-1)
205
+
206
+ # Triangular projection
207
+ x = torch.einsum("bkid,bkjd->bijd", a, b)
208
+
209
+ # Output gating
210
+ x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
211
+
212
+ return x
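
Note: a minimal usage sketch for the modules above, not part of the wheel; the shapes are illustrative. The einsum path shown runs anywhere, while `use_kernels=True` additionally requires the cuequivariance_torch CUDA kernel:

    import torch

    from boltz.model.layers.triangular_mult import TriangleMultiplicationOutgoing

    module = TriangleMultiplicationOutgoing(dim=32)
    x = torch.randn(1, 16, 16, 32)  # (B, N, N, D) pair representation
    mask = torch.ones(1, 16, 16)    # (B, N, N); 1 marks valid pairs
    out = module(x, mask, use_kernels=False)
    assert out.shape == x.shape  # the update is shape-preserving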
boltz/model/loss/__init__.py
File without changes
boltz/model/loss/bfactor.py
@@ -0,0 +1,49 @@
+ import torch
+ from torch import Tensor
+
+
+ def bfactor_loss_fn(
+     output: dict[str, Tensor],
+     feats: dict[str, Tensor],
+ ) -> Tensor:
+     """Compute the bfactor loss.
+
+     Parameters
+     ----------
+     output : dict[str, Tensor]
+         Output of the model
+     feats : dict[str, Tensor]
+         Input features
+
+     Returns
+     -------
+     Tensor
+         The globally averaged loss.
+
+     """
+     with torch.autocast("cuda", enabled=False):
+         # Get the predicted B-factor bin logits
+         pred = output["pbfactor"].float()  # (B, L, bins)
+         bins = pred.shape[2]  # num_bins
+         token_to_rep_atom = feats["token_to_rep_atom"]
+
+         # Compute target histogram
+         bfactor_atom = feats["bfactor"].unsqueeze(-1)  # (B, L, 1)
+         bfactor_token = torch.bmm(token_to_rep_atom.float(), bfactor_atom)
+
+         boundaries = torch.linspace(0, 100, bins - 1, device=bfactor_token.device)
+         bfactor_token_bin = (bfactor_token > boundaries).sum(dim=-1).long()
+         bfactor_target = torch.nn.functional.one_hot(
+             bfactor_token_bin, num_classes=bins
+         )
+
+         # Mask out tokens without a valid B-factor
+         token_mask = (bfactor_token > 1e-5).squeeze(-1).float()
+
+         # Compute the bfactor loss (cross-entropy over the bins)
+         errors = -1 * torch.sum(
+             bfactor_target * torch.nn.functional.log_softmax(pred, dim=-1),
+             dim=-1,
+         )
+         loss = torch.sum(errors * token_mask) / (torch.sum(token_mask) + 1e-5)
+     return loss
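
Note: the binning above maps each token B-factor onto `bins` classes over [0, 100] by counting how many bin edges it exceeds. A standalone sketch of just that step (`bins=22` and the sample values are illustrative, not the model's configuration):

    import torch

    bins = 22
    boundaries = torch.linspace(0, 100, bins - 1)  # 21 edges: 0, 5, ..., 100
    vals = torch.tensor([0.0, 37.2, 250.0])        # values above 100 land in the last bin
    bin_idx = (vals.unsqueeze(-1) > boundaries).sum(dim=-1)  # tensor([ 0,  8, 21])
    target = torch.nn.functional.one_hot(bin_idx, num_classes=bins)  # (3, 22) one-hot rows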