boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/modules/transformers.py
@@ -0,0 +1,322 @@
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+ from torch import nn, sigmoid
+ from torch.nn import (
+     LayerNorm,
+     Linear,
+     Module,
+     ModuleList,
+     Sequential,
+ )
+
+ from boltz.model.layers.attention import AttentionPairBias
+ from boltz.model.modules.utils import LinearNoBias, SwiGLU, default
+
+
+ class AdaLN(Module):
+     """Adaptive Layer Normalization"""
+
+     def __init__(self, dim, dim_single_cond):
+         """Initialize the adaptive layer normalization.
+
+         Parameters
+         ----------
+         dim : int
+             The input dimension.
+         dim_single_cond : int
+             The single condition dimension.
+
+         """
+         super().__init__()
+         self.a_norm = LayerNorm(dim, elementwise_affine=False, bias=False)
+         self.s_norm = LayerNorm(dim_single_cond, bias=False)
+         self.s_scale = Linear(dim_single_cond, dim)
+         self.s_bias = LinearNoBias(dim_single_cond, dim)
+
+     def forward(self, a, s):
+         a = self.a_norm(a)
+         s = self.s_norm(s)
+         a = sigmoid(self.s_scale(s)) * a + self.s_bias(s)
+         return a
+
+
+ class ConditionedTransitionBlock(Module):
+     """Conditioned Transition Block"""
+
+     def __init__(self, dim_single, dim_single_cond, expansion_factor=2):
+         """Initialize the conditioned transition block.
+
+         Parameters
+         ----------
+         dim_single : int
+             The single dimension.
+         dim_single_cond : int
+             The single condition dimension.
+         expansion_factor : int, optional
+             The expansion factor, by default 2
+
+         """
+         super().__init__()
+
+         self.adaln = AdaLN(dim_single, dim_single_cond)
+
+         dim_inner = int(dim_single * expansion_factor)
+         self.swish_gate = Sequential(
+             LinearNoBias(dim_single, dim_inner * 2),
+             SwiGLU(),
+         )
+         self.a_to_b = LinearNoBias(dim_single, dim_inner)
+         self.b_to_a = LinearNoBias(dim_inner, dim_single)
+
+         output_projection_linear = Linear(dim_single_cond, dim_single)
+         nn.init.zeros_(output_projection_linear.weight)
+         nn.init.constant_(output_projection_linear.bias, -2.0)
+
+         self.output_projection = nn.Sequential(output_projection_linear, nn.Sigmoid())
+
+     def forward(
+         self,
+         a,
+         s,
+     ):
+         a = self.adaln(a, s)
+         b = self.swish_gate(a) * self.a_to_b(a)
+         a = self.output_projection(s) * self.b_to_a(b)
+
+         return a
+
+
+ class DiffusionTransformer(Module):
+     """Diffusion Transformer"""
+
+     def __init__(
+         self,
+         depth,
+         heads,
+         dim=384,
+         dim_single_cond=None,
+         dim_pairwise=128,
+         activation_checkpointing=False,
+         offload_to_cpu=False,
+     ):
+         """Initialize the diffusion transformer.
+
+         Parameters
+         ----------
+         depth : int
+             The depth.
+         heads : int
+             The number of heads.
+         dim : int, optional
+             The dimension, by default 384
+         dim_single_cond : int, optional
+             The single condition dimension, by default None
+         dim_pairwise : int, optional
+             The pairwise dimension, by default 128
+         activation_checkpointing : bool, optional
+             Whether to use activation checkpointing, by default False
+         offload_to_cpu : bool, optional
+             Whether to offload to CPU, by default False
+
+         """
+         super().__init__()
+         self.activation_checkpointing = activation_checkpointing
+         dim_single_cond = default(dim_single_cond, dim)
+
+         self.layers = ModuleList()
+         for _ in range(depth):
+             if activation_checkpointing:
+                 self.layers.append(
+                     checkpoint_wrapper(
+                         DiffusionTransformerLayer(
+                             heads,
+                             dim,
+                             dim_single_cond,
+                             dim_pairwise,
+                         ),
+                         offload_to_cpu=offload_to_cpu,
+                     )
+                 )
+             else:
+                 self.layers.append(
+                     DiffusionTransformerLayer(
+                         heads,
+                         dim,
+                         dim_single_cond,
+                         dim_pairwise,
+                     )
+                 )
+
+     def forward(
+         self,
+         a,
+         s,
+         z,
+         mask=None,
+         to_keys=None,
+         multiplicity=1,
+         model_cache=None,
+     ):
+         for i, layer in enumerate(self.layers):
+             layer_cache = None
+             if model_cache is not None:
+                 prefix_cache = "layer_" + str(i)
+                 if prefix_cache not in model_cache:
+                     model_cache[prefix_cache] = {}
+                 layer_cache = model_cache[prefix_cache]
+             a = layer(
+                 a,
+                 s,
+                 z,
+                 mask=mask,
+                 to_keys=to_keys,
+                 multiplicity=multiplicity,
+                 layer_cache=layer_cache,
+             )
+         return a
+
+
+ class DiffusionTransformerLayer(Module):
+     """Diffusion Transformer Layer"""
+
+     def __init__(
+         self,
+         heads,
+         dim=384,
+         dim_single_cond=None,
+         dim_pairwise=128,
+     ):
+         """Initialize the diffusion transformer layer.
+
+         Parameters
+         ----------
+         heads : int
+             The number of heads.
+         dim : int, optional
+             The dimension, by default 384
+         dim_single_cond : int, optional
+             The single condition dimension, by default None
+         dim_pairwise : int, optional
+             The pairwise dimension, by default 128
+
+         """
+         super().__init__()
+
+         dim_single_cond = default(dim_single_cond, dim)
+
+         self.adaln = AdaLN(dim, dim_single_cond)
+
+         self.pair_bias_attn = AttentionPairBias(
+             c_s=dim, c_z=dim_pairwise, num_heads=heads, initial_norm=False
+         )
+
+         self.output_projection_linear = Linear(dim_single_cond, dim)
+         nn.init.zeros_(self.output_projection_linear.weight)
+         nn.init.constant_(self.output_projection_linear.bias, -2.0)
+
+         self.output_projection = nn.Sequential(
+             self.output_projection_linear, nn.Sigmoid()
+         )
+         self.transition = ConditionedTransitionBlock(
+             dim_single=dim, dim_single_cond=dim_single_cond
+         )
+
+     def forward(
+         self,
+         a,
+         s,
+         z,
+         mask=None,
+         to_keys=None,
+         multiplicity=1,
+         layer_cache=None,
+     ):
+         b = self.adaln(a, s)
+         b = self.pair_bias_attn(
+             s=b,
+             z=z,
+             mask=mask,
+             multiplicity=multiplicity,
+             to_keys=to_keys,
+             model_cache=layer_cache,
+         )
+         b = self.output_projection(s) * b
+
+         # NOTE: Added residual connection!
+         a = a + b
+         a = a + self.transition(a, s)
+         return a
+
+
+ class AtomTransformer(Module):
+     """Atom Transformer"""
+
+     def __init__(
+         self,
+         attn_window_queries=None,
+         attn_window_keys=None,
+         **diffusion_transformer_kwargs,
+     ):
+         """Initialize the atom transformer.
+
+         Parameters
+         ----------
+         attn_window_queries : int, optional
+             The attention window queries, by default None
+         attn_window_keys : int, optional
+             The attention window keys, by default None
+         diffusion_transformer_kwargs : dict
+             The diffusion transformer keyword arguments
+
+         """
+         super().__init__()
+         self.attn_window_queries = attn_window_queries
+         self.attn_window_keys = attn_window_keys
+         self.diffusion_transformer = DiffusionTransformer(
+             **diffusion_transformer_kwargs
+         )
+
+     def forward(
+         self,
+         q,
+         c,
+         p,
+         to_keys=None,
+         mask=None,
+         multiplicity=1,
+         model_cache=None,
+     ):
+         W = self.attn_window_queries
+         H = self.attn_window_keys
+
+         if W is not None:
+             B, N, D = q.shape
+             NW = N // W
+
+             # reshape tokens
+             q = q.view((B * NW, W, -1))
+             c = c.view((B * NW, W, -1))
+             if mask is not None:
+                 mask = mask.view(B * NW, W)
+             p = p.view((p.shape[0] * NW, W, H, -1))
+
+             to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)
+         else:
+             to_keys_new = None
+
+         # main transformer
+         q = self.diffusion_transformer(
+             a=q,
+             s=c,
+             z=p,
+             mask=mask.float(),
+             multiplicity=multiplicity,
+             to_keys=to_keys_new,
+             model_cache=model_cache,
+         )
+
+         if W is not None:
+             q = q.view((B, NW * W, D))
+
+         return q
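
For orientation, the conditioned transition block above can be exercised on its own. The following is a minimal, hypothetical usage sketch and is not part of the package; it assumes the wheel and its dependencies (torch, fairscale) are installed, that the import path matches boltz/model/modules/transformers.py from the file list, and that LinearNoBias, SwiGLU, and default behave as their names suggest.

    import torch
    from boltz.model.modules.transformers import ConditionedTransitionBlock

    # token activations `a` and conditioning `s`, shaped (batch, tokens, channels)
    block = ConditionedTransitionBlock(dim_single=64, dim_single_cond=32)
    a = torch.randn(2, 16, 64)
    s = torch.randn(2, 16, 32)

    out = block(a, s)   # AdaLN -> SwiGLU-gated transition -> sigmoid-gated output
    print(out.shape)    # expected: torch.Size([2, 16, 64])

The sigmoid output projection is initialized with bias -2.0, so at initialization the block's contribution is damped rather than zeroed outright.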
boltz/model/modules/transformersv2.py
@@ -0,0 +1,261 @@
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+ import torch
+ from torch import nn, sigmoid
+ from torch.nn import (
+     LayerNorm,
+     Linear,
+     Module,
+     ModuleList,
+     Sequential,
+ )
+
+ from boltz.model.layers.attentionv2 import AttentionPairBias
+ from boltz.model.modules.utils import LinearNoBias, SwiGLU, default
+
+
+ class AdaLN(Module):
+     """Algorithm 26"""
+
+     def __init__(self, dim, dim_single_cond):
+         super().__init__()
+         self.a_norm = LayerNorm(dim, elementwise_affine=False, bias=False)
+         self.s_norm = LayerNorm(dim_single_cond, bias=False)
+         self.s_scale = Linear(dim_single_cond, dim)
+         self.s_bias = LinearNoBias(dim_single_cond, dim)
+
+     def forward(self, a, s):
+         a = self.a_norm(a)
+         s = self.s_norm(s)
+         a = sigmoid(self.s_scale(s)) * a + self.s_bias(s)
+         return a
+
+
+ class ConditionedTransitionBlock(Module):
+     """Algorithm 25"""
+
+     def __init__(self, dim_single, dim_single_cond, expansion_factor=2):
+         super().__init__()
+
+         self.adaln = AdaLN(dim_single, dim_single_cond)
+
+         dim_inner = int(dim_single * expansion_factor)
+         self.swish_gate = Sequential(
+             LinearNoBias(dim_single, dim_inner * 2),
+             SwiGLU(),
+         )
+         self.a_to_b = LinearNoBias(dim_single, dim_inner)
+         self.b_to_a = LinearNoBias(dim_inner, dim_single)
+
+         output_projection_linear = Linear(dim_single_cond, dim_single)
+         nn.init.zeros_(output_projection_linear.weight)
+         nn.init.constant_(output_projection_linear.bias, -2.0)
+
+         self.output_projection = nn.Sequential(output_projection_linear, nn.Sigmoid())
+
+     def forward(
+         self,
+         a,  # Float['... d']
+         s,
+     ):  # -> Float['... d']:
+         a = self.adaln(a, s)
+         b = self.swish_gate(a) * self.a_to_b(a)
+         a = self.output_projection(s) * self.b_to_a(b)
+
+         return a
+
+
+ class DiffusionTransformer(Module):
+     """Algorithm 23"""
+
+     def __init__(
+         self,
+         depth,
+         heads,
+         dim=384,
+         dim_single_cond=None,
+         pair_bias_attn=True,
+         activation_checkpointing=False,
+         post_layer_norm=False,
+     ):
+         super().__init__()
+         self.activation_checkpointing = activation_checkpointing
+         dim_single_cond = default(dim_single_cond, dim)
+         self.pair_bias_attn = pair_bias_attn
+
+         self.layers = ModuleList()
+         for _ in range(depth):
+             self.layers.append(
+                 DiffusionTransformerLayer(
+                     heads,
+                     dim,
+                     dim_single_cond,
+                     post_layer_norm,
+                 )
+             )
+
+     def forward(
+         self,
+         a,  # Float['bm n d'],
+         s,  # Float['bm n ds'],
+         bias=None,  # Float['b n n dp']
+         mask=None,  # Bool['b n'] | None = None
+         to_keys=None,
+         multiplicity=1,
+     ):
+         if self.pair_bias_attn:
+             B, N, M, D = bias.shape
+             L = len(self.layers)
+             bias = bias.view(B, N, M, L, D // L)
+
+         for i, layer in enumerate(self.layers):
+             if self.pair_bias_attn:
+                 bias_l = bias[:, :, :, i]
+             else:
+                 bias_l = None
+
+             if self.activation_checkpointing and self.training:
+                 a = torch.utils.checkpoint.checkpoint(
+                     layer,
+                     a,
+                     s,
+                     bias_l,
+                     mask,
+                     to_keys,
+                     multiplicity,
+                 )
+
+             else:
+                 a = layer(
+                     a,  # Float['bm n d'],
+                     s,  # Float['bm n ds'],
+                     bias_l,  # Float['b n n dp']
+                     mask,  # Bool['b n'] | None = None
+                     to_keys,
+                     multiplicity,
+                 )
+         return a
+
+
+ class DiffusionTransformerLayer(Module):
+     """Algorithm 23"""
+
+     def __init__(
+         self,
+         heads,
+         dim=384,
+         dim_single_cond=None,
+         post_layer_norm=False,
+     ):
+         super().__init__()
+
+         dim_single_cond = default(dim_single_cond, dim)
+
+         self.adaln = AdaLN(dim, dim_single_cond)
+         self.pair_bias_attn = AttentionPairBias(
+             c_s=dim, num_heads=heads, compute_pair_bias=False
+         )
+
+         self.output_projection_linear = Linear(dim_single_cond, dim)
+         nn.init.zeros_(self.output_projection_linear.weight)
+         nn.init.constant_(self.output_projection_linear.bias, -2.0)
+
+         self.output_projection = nn.Sequential(
+             self.output_projection_linear, nn.Sigmoid()
+         )
+         self.transition = ConditionedTransitionBlock(
+             dim_single=dim, dim_single_cond=dim_single_cond
+         )
+
+         if post_layer_norm:
+             self.post_lnorm = nn.LayerNorm(dim)
+         else:
+             self.post_lnorm = nn.Identity()
+
+     def forward(
+         self,
+         a,  # Float['bm n d'],
+         s,  # Float['bm n ds'],
+         bias=None,  # Float['b n n dp']
+         mask=None,  # Bool['b n'] | None = None
+         to_keys=None,
+         multiplicity=1,
+     ):
+         b = self.adaln(a, s)
+
+         k_in = b
+         if to_keys is not None:
+             k_in = to_keys(b)
+             mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
+
+         if self.pair_bias_attn:
+             b = self.pair_bias_attn(
+                 s=b,
+                 z=bias,
+                 mask=mask,
+                 multiplicity=multiplicity,
+                 k_in=k_in,
+             )
+         else:
+             b = self.no_pair_bias_attn(s=b, mask=mask, k_in=k_in)
+
+         b = self.output_projection(s) * b
+
+         a = a + b
+         a = a + self.transition(a, s)
+
+         a = self.post_lnorm(a)
+         return a
+
+
+ class AtomTransformer(Module):
+     """Algorithm 7"""
+
+     def __init__(
+         self,
+         attn_window_queries,
+         attn_window_keys,
+         **diffusion_transformer_kwargs,
+     ):
+         super().__init__()
+         self.attn_window_queries = attn_window_queries
+         self.attn_window_keys = attn_window_keys
+         self.diffusion_transformer = DiffusionTransformer(
+             **diffusion_transformer_kwargs
+         )
+
+     def forward(
+         self,
+         q,  # Float['b m d'],
+         c,  # Float['b m ds'],
+         bias,  # Float['b m m dp']
+         to_keys,
+         mask,  # Bool['b m'] | None = None
+         multiplicity=1,
+     ):
+         W = self.attn_window_queries
+         H = self.attn_window_keys
+
+         B, N, D = q.shape
+         NW = N // W
+
+         # reshape tokens
+         q = q.view((B * NW, W, -1))
+         c = c.view((B * NW, W, -1))
+         mask = mask.view(B * NW, W)
+         bias = bias.view((bias.shape[0] * NW, W, H, -1))
+
+         to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)
+
+         # main transformer
+         q = self.diffusion_transformer(
+             a=q,
+             s=c,
+             bias=bias,
+             mask=mask.float(),
+             multiplicity=multiplicity,
+             to_keys=to_keys_new,
+         )
+
+         q = q.view((B, NW * W, D))
+         return q
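
One difference worth noting between the two files: in transformersv2.py the pairwise bias is computed once and split across layers via bias.view(B, N, M, L, D // L), so its last dimension must be divisible by the transformer depth. The following plain-torch shape sketch (illustrative values only, not part of the package) shows the same reshape that DiffusionTransformer.forward performs before its layer loop.

    import torch

    depth, per_layer_channels = 4, 16               # illustrative values only
    B, N, M = 1, 8, 8
    bias = torch.randn(B, N, M, depth * per_layer_channels)

    # split the bias channels into one slice per layer
    chunks = bias.view(B, N, M, depth, bias.shape[-1] // depth)
    bias_layer0 = chunks[:, :, :, 0]                # slice handed to layer 0
    print(bias_layer0.shape)                        # torch.Size([1, 8, 8, 16])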