rc-foundry 0.1.6__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/inference_engines/checkpoint_registry.py +58 -11
- foundry/utils/alignment.py +10 -2
- foundry/utils/ddp.py +1 -1
- foundry/utils/logging.py +1 -1
- foundry/version.py +2 -2
- foundry_cli/download_checkpoints.py +66 -66
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/METADATA +30 -21
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/RECORD +31 -31
- rf3/cli.py +13 -4
- rf3/inference.py +3 -1
- rfd3/configs/datasets/train/pdb/af3_train_interface.yaml +1 -1
- rfd3/configs/inference_engine/rfdiffusion3.yaml +2 -2
- rfd3/configs/model/samplers/symmetry.yaml +1 -1
- rfd3/engine.py +28 -12
- rfd3/inference/datasets.py +1 -1
- rfd3/inference/input_parsing.py +32 -1
- rfd3/inference/legacy_input_parsing.py +17 -1
- rfd3/inference/parsing.py +1 -0
- rfd3/inference/symmetry/atom_array.py +78 -13
- rfd3/inference/symmetry/checks.py +62 -29
- rfd3/inference/symmetry/frames.py +256 -5
- rfd3/inference/symmetry/symmetry_utils.py +39 -61
- rfd3/model/inference_sampler.py +11 -1
- rfd3/model/layers/block_utils.py +33 -33
- rfd3/model/layers/chunked_pairwise.py +84 -82
- rfd3/run_inference.py +3 -1
- rfd3/transforms/symmetry.py +16 -7
- rfd3/utils/inference.py +21 -22
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.9.dist-info}/licenses/LICENSE.md +0 -0
rfd3/model/layers/block_utils.py
CHANGED

@@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
 
     Parameters
     ----------
-    P_LK_indices : (
+    P_LK_indices : (B, L, k) LongTensor
        Key indices | P_LK_indices[d, i, k] = global atom index for which atom i attends to.
-    P_LK : (
+    P_LK : (B, L, k, c) FloatTensor
        Key features to scatter add into
 
-    P_LA_indices : (
+    P_LA_indices : (B, L, a) LongTensor
        Additional feature indices to scatter into P_LK.
-    P_LA : (
+    P_LA : (B, L, a, c) FloatTensor
        Features corresponding to P_LA.
 
    Both index tensors contain indices representing D batch dim,

@@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
 
     """
     # Handle case when indices and P_LA don't have batch dimensions
-
+    B, L, k = P_LK_indices.shape
     if P_LA_indices.ndim == 2:
-        P_LA_indices = P_LA_indices.unsqueeze(0).expand(
+        P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
     if P_LA_src.ndim == 3:
-        P_LA_src = P_LA_src.unsqueeze(0).expand(
+        P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1)
     assert (
         P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
     ), "Channel dims do not match, got: {} vs {}".format(
         P_LA_src.shape[-1], P_LK_tgt.shape[-1]
     )
 
-    matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (
+    matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (B, L, a, k)
     if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
         raise ValueError("Found multiple scatter indices for some atoms")
     elif not torch.all(matches.sum(dim=-1) <= 1):
         raise ValueError("Did not find a scatter index for every atom")
-    k_indices = matches.long().argmax(dim=-1)  # (
+    k_indices = matches.long().argmax(dim=-1)  # (B, L, a)
     scatter_indices = k_indices.unsqueeze(-1).expand(
         -1, -1, -1, P_LK_tgt.shape[-1]
-    )  # (
+    )  # (B, L, a, c)
     P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
     return P_LK_tgt
 
 
 def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
     """
-    values : (
-    idx : (
-    returns: (
+    values : (B, L, C)
+    idx : (B, L, k)
+    returns: (B, L, k, C)
     """
-
+    B, L, C = values.shape
     k = idx.shape[-1]
 
-    # (
+    # (B, L, 1, C) → stride-0 along k → (B, L, k, C)
     src = values.unsqueeze(2).expand(-1, -1, k, -1)
-    idx = idx.unsqueeze(-1).expand(-1, -1, -1, C)  # (
+    idx = idx.unsqueeze(-1).expand(-1, -1, -1, C)  # (B, L, k, C)
 
     return torch.gather(src, 1, idx)  # dim=1 is the L-axis
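
The shape comments above describe an index-matching scatter-add: for every query atom, each extra feature in `P_LA_src` is routed into the `P_LK_tgt` key slot whose global atom index matches. A minimal, self-contained sketch of that routing with toy tensors (all names and sizes below are illustrative, not taken from the package):

```python
import torch

# Toy sizes: B=1 batch, L=2 query atoms, k=3 key slots, a=2 extra entries, c=4 channels.
B, L, k, a, c = 1, 2, 3, 2, 4

P_LK_indices = torch.tensor([[[5, 7, 9], [1, 2, 3]]])   # (B, L, k) global atom index per key slot
P_LK_tgt = torch.zeros(B, L, k, c)                      # (B, L, k, c) target features
P_LA_indices = torch.tensor([[[7, 9], [3, 1]]])         # (B, L, a) global atom index per extra feature
P_LA_src = torch.ones(B, L, a, c)                       # (B, L, a, c) features to add

# Match each extra index against the key indices of the same query atom.
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)   # (B, L, a, k)
k_slot = matches.long().argmax(dim=-1)                               # (B, L, a)

# Scatter-add each extra feature into the matching key slot (dim=2 is the k axis).
scatter_idx = k_slot.unsqueeze(-1).expand(-1, -1, -1, c)
out = P_LK_tgt.scatter_add(dim=2, index=scatter_idx, src=P_LA_src)
print(out[0, 0, :, 0])   # tensor([0., 1., 1.]): the slots for atoms 7 and 9 received features
```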

@@ -196,7 +196,7 @@ def create_attention_indices(
     X_L = torch.randn(
         (1, L, 3), device=device, dtype=torch.float
     )  # [L, 3] - random
-    D_LL = torch.cdist(X_L, X_L, p=2)  # [
+    D_LL = torch.cdist(X_L, X_L, p=2)  # [B, L, L] - pairwise atom distances
 
     # Create attention indices using neighbour distances
     base_mask = ~f["unindexing_pair_mask"][

@@ -231,7 +231,7 @@ def create_attention_indices(
         k_max=k_actual,
         chain_id=chain_ids,
         base_mask=base_mask,
-    )  # [
+    )  # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
 
     return attn_indices
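
For context, the `[B, L, k]` attention indices annotated above can be derived from a pairwise distance matrix by taking the k smallest distances per query atom. A rough sketch under that assumption (the package's index builder also applies chain and base masks, which are omitted here):

```python
import torch

B, L, k = 1, 6, 3                                    # toy sizes
X_L = torch.randn(B, L, 3)                           # random atom coordinates
D_LL = torch.cdist(X_L, X_L, p=2)                    # [B, L, L] pairwise distances

# k nearest neighbours per query atom (each atom is its own closest neighbour).
attn_indices = D_LL.topk(k, dim=-1, largest=False).indices   # [B, L, k]
```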

@@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain(
 
     Args:
         tok_idx: atom to token mapping
-        D_LL: pairwise distances [
+        D_LL: pairwise distances [B, L, L]
         n_seq_neighbours: number of sequence neighbors
         k_intra: number of intra-chain attention keys
         k_inter: number of inter-chain attention keys

@@ -253,29 +253,29 @@ def get_sparse_attention_indices_with_inter_chain(
         base_mask: base mask for valid pairs
 
     Returns:
-        attn_indices: [
+        attn_indices: [B, L, k_total] where k_total = k_intra + k_inter
     """
-
+    B, L, _ = D_LL.shape
 
     # Get regular intra-chain indices (limited to k_intra)
     intra_indices = get_sparse_attention_indices(
         tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
-    )  # [
+    )  # [B, L, k_intra]
 
     # Get inter-chain indices for clash avoidance
-    inter_indices = torch.zeros(
-
-    for
-    for
-    query_chain = chain_id[
+    inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device)
+    unique_chains = torch.unique(chain_id)
+    for b in range(B):
+        for c in unique_chains:
+            query_chain = chain_id[c]
 
             # Find atoms from different chains
-            other_chain_mask = (chain_id != query_chain) & base_mask[
+            other_chain_mask = (chain_id != query_chain) & base_mask[c, :]
             other_chain_atoms = torch.where(other_chain_mask)[0]
 
             if len(other_chain_atoms) > 0:
                 # Get distances to other chains
-                distances_to_other = D_LL[
+                distances_to_other = D_LL[b, c, other_chain_atoms]
 
                 # Select k_inter closest atoms from other chains
                 n_select = min(k_inter, len(other_chain_atoms))

@@ -283,23 +283,23 @@ def get_sparse_attention_indices_with_inter_chain(
                 selected_atoms = other_chain_atoms[closest_idx]
 
                 # Fill inter-chain indices
-                inter_indices[
+                inter_indices[b, c, :n_select] = selected_atoms
                 # Pad with random atoms if needed
                 if n_select < k_inter:
                     padding = torch.randint(
                         0, L, (k_inter - n_select,), device=D_LL.device
                     )
-                    inter_indices[
+                    inter_indices[b, c, n_select:] = padding
             else:
                 # No other chains found, fill with random indices
-                inter_indices[
+                inter_indices[b, c, :] = torch.randint(
                     0, L, (k_inter,), device=D_LL.device
                 )
 
     # Combine intra and inter chain indices
     combined_indices = torch.cat(
         [intra_indices, inter_indices], dim=-1
-    )  # [
+    )  # [B, L, k_total]
 
     return combined_indices
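
The inter-chain branch above selects, for each query atom, the `k_inter` spatially closest atoms belonging to a different chain. A toy, single-query sketch of that selection (variable names mirror the diff, but the data is made up):

```python
import torch

L, k_inter = 8, 2
chain_id = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])    # two chains of four atoms each
D_LL = torch.rand(L, L)                               # toy distance matrix
q = 1                                                 # query atom index

# Atoms on a different chain than the query.
other_chain_atoms = torch.where(chain_id != chain_id[q])[0]

# Indices of the k_inter closest inter-chain atoms.
n_select = min(k_inter, len(other_chain_atoms))
closest_idx = D_LL[q, other_chain_atoms].topk(n_select, largest=False).indices
selected_atoms = other_chain_atoms[closest_idx]       # global atom indices
```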

rfd3/model/layers/chunked_pairwise.py
CHANGED

@@ -30,32 +30,32 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
 
     def compute_pairs_chunked(
         self,
-        query_pos: torch.Tensor,  # [
-        key_pos: torch.Tensor,  # [
-        valid_mask: torch.Tensor,  # [
+        query_pos: torch.Tensor,  # [B, 3]
+        key_pos: torch.Tensor,  # [B, k, 3]
+        valid_mask: torch.Tensor,  # [B, k, 1]
     ) -> torch.Tensor:
         """
         Compute pairwise embeddings for specific query-key pairs.
 
         Args:
-            query_pos: Query positions [
-            key_pos: Key positions [
-            valid_mask: Valid pair mask [
+            query_pos: Query positions [B, 3]
+            key_pos: Key positions [B, k, 3]
+            valid_mask: Valid pair mask [B, k, 1]
 
         Returns:
-            P_sparse: Pairwise embeddings [
+            P_sparse: Pairwise embeddings [B, k, c_atompair]
         """
-
+        B, k = key_pos.shape[:2]
 
-        # Compute pairwise distances: [
-        D_pairs = query_pos.unsqueeze(1) - key_pos  # [
+        # Compute pairwise distances: [B, k, 3]
+        D_pairs = query_pos.unsqueeze(1) - key_pos  # [B, 1, 3] - [B, k, 3] = [B, k, 3]
 
         if self.embed_frame:
             # Embed pairwise distances
-            P_pairs = self.process_d(D_pairs) * valid_mask  # [
+            P_pairs = self.process_d(D_pairs) * valid_mask  # [B, k, c_atompair]
 
             # Add inverse distance embedding
-            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2  # [
+            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2  # [B, k, 1]
             inv_dist = 1 / (1 + norm_sq)
             P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask
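
The inverse-distance term above maps each query-key offset to a bounded feature, 1 / (1 + ||d||^2), which is largest for nearby pairs. A small standalone sketch of just that computation (toy shapes; the learned `process_*` projections are omitted):

```python
import torch

B, k = 2, 4
query_pos = torch.randn(B, 3)
key_pos = torch.randn(B, k, 3)

D_pairs = query_pos.unsqueeze(1) - key_pos                         # [B, 1, 3] - [B, k, 3] = [B, k, 3]
norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2    # [B, k, 1]
inv_dist = 1 / (1 + norm_sq)                                       # in (0, 1], close to 1 for nearby pairs
```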

@@ -95,19 +95,19 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
 
     def compute_pairs_chunked(
         self,
-        query_pos: torch.Tensor,  # [
-        key_pos: torch.Tensor,  # [
-        valid_mask: torch.Tensor,  # [
+        query_pos: torch.Tensor,  # [B, 3]
+        key_pos: torch.Tensor,  # [B, k, 3]
+        valid_mask: torch.Tensor,  # [B, k, 1]
     ) -> torch.Tensor:
         """
         Compute sinusoidal distance embeddings for specific query-key pairs.
         """
-
+        B, k = key_pos.shape[:2]
         device = query_pos.device
 
         # Compute pairwise distances
-        D_pairs = query_pos.unsqueeze(1) - key_pos  # [
-        dist_matrix = torch.linalg.norm(D_pairs, dim=-1)  # [
+        D_pairs = query_pos.unsqueeze(1) - key_pos  # [B, k, 3]
+        dist_matrix = torch.linalg.norm(D_pairs, dim=-1)  # [B, k]
 
         # Sinusoidal embedding
         half_dim = self.n_freqs

@@ -117,13 +117,13 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
             / half_dim
         )  # [n_freqs]
 
-        angles = dist_matrix.unsqueeze(-1) * freq  # [
+        angles = dist_matrix.unsqueeze(-1) * freq  # [B, k, n_freqs]
         sin_embed = torch.sin(angles)
         cos_embed = torch.cos(angles)
-        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1)  # [
+        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1)  # [B, k, 2*n_freqs]
 
         # Linear projection
-        P_pairs = self.output_proj(sincos_embed)  # [
+        P_pairs = self.output_proj(sincos_embed)  # [B, k, c_atompair]
         P_pairs = P_pairs * valid_mask
 
         # Add linear embedding of valid mask
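
The sinusoidal branch embeds each query-key distance with log-spaced frequencies, concatenates sin and cos features, and then applies a linear projection. A sketch of the embedding step under that reading (the frequency schedule is only partially visible in the diff, so the constants below are assumptions):

```python
import torch

def sinusoidal_dist_embed(dist: torch.Tensor, n_freqs: int = 8) -> torch.Tensor:
    """Embed distances [B, k] into [B, k, 2 * n_freqs] sin/cos features."""
    freq = torch.exp(
        -torch.arange(n_freqs, dtype=dist.dtype, device=dist.device)
        * torch.log(torch.tensor(10000.0, dtype=dist.dtype))
        / n_freqs
    )                                                   # [n_freqs], log-spaced
    angles = dist.unsqueeze(-1) * freq                  # [B, k, n_freqs]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

emb = sinusoidal_dist_embed(torch.rand(2, 5))           # [2, 5, 16]
```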

@@ -191,8 +191,8 @@ class ChunkedPairwiseEmbedder(nn.Module):
     def forward_chunked(
         self,
         f: dict,
-        indices: torch.Tensor,  # [
-        C_L: torch.Tensor,  # [
+        indices: torch.Tensor,  # [B, L, k] - sparse attention indices
+        C_L: torch.Tensor,  # [B, L, c_token] - atom features
         Z_init_II: torch.Tensor,  # [I, I, c_z] - token pair features
         tok_idx: torch.Tensor,  # [L] - atom to token mapping
     ) -> torch.Tensor:

@@ -208,20 +208,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
 
         Args:
             f: Feature dictionary
-            indices: Sparse attention indices [
-            C_L: Atom-level features [
+            indices: Sparse attention indices [B, L, k]
+            C_L: Atom-level features [B, L, c_token]
             Z_init_II: Token-level pair features [I, I, c_z]
             tok_idx: Atom to token mapping [L]
 
         Returns:
-            P_LL_sparse: Sparse pairwise features [
+            P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
         """
-
+        B, L, k = indices.shape
         device = indices.device
 
         # Initialize sparse P_LL
         P_LL_sparse = torch.zeros(
-
+            B, L, k, self.c_atompair, device=device, dtype=C_L.dtype
         )
 
         # Handle both batched and non-batched C_L

@@ -237,71 +237,72 @@ class ChunkedPairwiseEmbedder(nn.Module):
         if valid_indices.dim() == 2:  # [L, k] - add batch dimension
             valid_indices = valid_indices.unsqueeze(0).expand(
                 C_L.shape[0], -1, -1
-            )  # [
+            )  # [B, L, k]
 
         # 1. Motif position embedding (if exists)
         if self.motif_pos_embedder is not None and "motif_pos" in f:
             motif_pos = f["motif_pos"]  # [L, 3]
             is_motif = f["is_motif_atom_with_fixed_coord"]  # [L]
-
+            is_motif_idx = torch.where(is_motif)[0]
             # For each query position
-            for l in
-
-
-
-
-
-
-
-
-
-
-
-
-
-            P_LL_sparse[:, l, :, :] += motif_pairs
+            for l in is_motif_idx:
+                key_indices = valid_indices[:, l, :]  # [B, k] - use clamped indices
+                key_pos = motif_pos[key_indices]  # [B, k, 3]
+                query_pos = motif_pos[l].unsqueeze(0).expand(B, -1)  # [B, 3]
+
+                # Valid mask: both query and keys must be motif
+                key_is_motif = is_motif[key_indices]  # [B, k]
+                valid_mask = key_is_motif.unsqueeze(-1).float()  # [B, k, 1]
+
+                if valid_mask.sum() > 0:
+                    motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
+                        query_pos, key_pos, valid_mask
+                    )
+                    P_LL_sparse[:, l, :, :] += motif_pairs
 
         # 2. Reference position embedding (if exists)
         if self.ref_pos_embedder is not None and "ref_pos" in f:
             ref_pos = f["ref_pos"]  # [L, 3]
             ref_space_uid = f["ref_space_uid"]  # [L]
             is_motif_seq = f["is_motif_atom_with_fixed_seq"]  # [L]
-
-            for l in
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            P_LL_sparse[:, l, :, :] += ref_pairs
+            is_motif_seq_idx = torch.where(is_motif_seq)[0]
+            for l in is_motif_seq_idx:
+                key_indices = valid_indices[:, l, :]  # [B, k] - use clamped indices
+                key_pos = ref_pos[key_indices]  # [B, k, 3]
+                query_pos = ref_pos[l].unsqueeze(0).expand(B, -1)  # [B, 3]
+
+                # Valid mask: same token and both have sequence
+                key_space_uid = ref_space_uid[key_indices]  # [B, k]
+                key_is_motif_seq = is_motif_seq[key_indices]  # [B, k]
+
+                same_token = key_space_uid == ref_space_uid[l]  # [B, k]
+                valid_mask = (
+                    (same_token & key_is_motif_seq).unsqueeze(-1).float()
+                )  # [B, k, 1]
+
+                if valid_mask.sum() > 0:
+                    ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
+                        query_pos, key_pos, valid_mask
+                    )
+                    P_LL_sparse[:, l, :, :] += ref_pairs
 
         # 3. Single embedding terms (broadcasted)
+        # Expand C_L to match valid_indices batch dimension
+        if C_L.shape[0] != B:
+            C_L = C_L.expand(B, -1, -1)  # [B, L, c_token]
         # Gather key features for each query
+        C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)  # [B, L, k, c_token]
         C_L_keys = torch.gather(
-
+            C_L_queries,
             1,
             valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
-        )  # [
-        C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)  # [D, L, k, c_token]
+        )  # [B, L, k, c_token]
 
         # Add single embeddings - match standard implementation structure
         # Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
-        # We need to broadcast from [
-        single_l = self.process_single_l(C_L_queries)  # [
-        single_m = self.process_single_m(C_L_keys)  # [
+        # We need to broadcast from [B, L, k, c_atompair] to match this
+        single_l = self.process_single_l(C_L_queries)  # [B, L, k, c_atompair]
+        single_m = self.process_single_m(C_L_keys)  # [B, L, k, c_atompair]
         P_LL_sparse += single_l + single_m
 
         # 4. Token pair features Z_init_II
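
Step 3 above gathers, for every query atom, the features of its k key atoms by expanding `C_L` and gathering along the L axis. A toy reproduction of that gather pattern (shapes and names are illustrative):

```python
import torch

B, L, k, c = 1, 5, 3, 2
C_L = torch.randn(B, L, c)                        # per-atom features
valid_indices = torch.randint(0, L, (B, L, k))    # key atom index per (query, slot)

C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)                  # [B, L, k, c]
idx = valid_indices.unsqueeze(-1).expand(-1, -1, -1, c)               # [B, L, k, c]
C_L_keys = torch.gather(C_L_queries, 1, idx)                          # [B, L, k, c]

# Sanity check: C_L_keys[b, l, j] == C_L[b, valid_indices[b, l, j]]
assert torch.equal(C_L_keys[0, 2, 1], C_L[0, valid_indices[0, 2, 1]])
```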

@@ -312,15 +313,16 @@ class ChunkedPairwiseEmbedder(nn.Module):
         else:
             tok_idx_expanded = tok_idx
 
-
+        # Expand tok_idx_expanded to match valid_indices batch dimension
+        if tok_idx_expanded.shape[0] != B:
+            tok_idx_expanded = tok_idx_expanded.expand(B, -1)  # [B, L]
+        tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k)  # [B, L, k]
         # Use valid_indices for token mapping as well
-        tok_keys = torch.gather(
-            tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices
-        )  # [D, L, k]
+        tok_keys = torch.gather(tok_queries, 1, valid_indices)  # [B, L, k]
 
         # Gather Z_init_II[tok_queries, tok_keys] with safe indexing
         # Z_init_II shape is [I, I, c_z] (3D), not 4D
-        # tok_queries shape: [
+        # tok_queries shape: [B, L, k] - each value is a token index
        # We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
 
         I_z, I_z2, c_z = Z_init_II.shape

@@ -338,20 +340,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
         # Then we need to gather the sparse version
 
         Z_pairs_processed = torch.zeros(
-
+            B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
         )
 
-        for
+        for b in range(B):
             # For this batch, get the token queries and keys
-            tq = tok_queries[
-            tk = tok_keys[
+            tq = tok_queries[b]  # [L, k]
+            tk = tok_keys[b]  # [L, k]
 
             # Ensure indices are within bounds
             tq = torch.clamp(tq, 0, I_z - 1)
             tk = torch.clamp(tk, 0, I_z2 - 1)
 
             # Apply the double token indexing like standard implementation
-            Z_pairs_processed[
+            Z_pairs_processed[b] = Z_processed[tq, tk]  # [L, k, c_atompair]
 
         P_LL_sparse += Z_pairs_processed
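
The per-batch loop relies on advanced indexing, `Z_processed[tq, tk]`, to pull a token-pair feature for every (query token, key token) pair. A small sketch of that double indexing in isolation (toy tensors):

```python
import torch

I, c_z, L, k = 4, 2, 3, 2
Z = torch.randn(I, I, c_z)                # token-pair features [I, I, c_z]
tq = torch.randint(0, I, (L, k))          # query token per (atom, key slot)
tk = torch.randint(0, I, (L, k))          # key token per (atom, key slot)

pairs = Z[tq, tk]                         # [L, k, c_z]; pairs[l, j] == Z[tq[l, j], tk[l, j]]
assert pairs.shape == (L, k, c_z)
```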
rfd3/run_inference.py
CHANGED

@@ -12,7 +12,9 @@ load_dotenv(override=True)
 
 # For pip-installed package, configs should be relative to this file
 # Adjust this path based on where configs are bundled in the package
-_config_path = os.path.join(
+_config_path = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
+)
 
 
 @hydra.main(
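
The change completes the join so the bundled `configs` directory is resolved three levels above the module. An equivalent way to express the same lookup with `pathlib` (illustrative only, not what the package uses):

```python
from pathlib import Path

# Three parent directories up from this file, then the bundled "configs" directory.
_config_path = Path(__file__).resolve().parents[2] / "configs"
```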
rfd3/transforms/symmetry.py
CHANGED

@@ -60,13 +60,22 @@ class AddSymmetryFeats(Transform):
         )
         TIDs = torch.from_numpy(atom_array.get_annotation("sym_transform_id"))
 
-
-
-
-
-
-
-
+        # Get unique transforms by TID (more robust than unique_consecutive on each array)
+        unique_TIDs, inverse_indices = torch.unique(TIDs, return_inverse=True)
+
+        # Get the first occurrence of each unique TID
+        first_occurrence = torch.zeros(len(unique_TIDs), dtype=torch.long)
+        for i in range(len(TIDs)):
+            tid_idx = inverse_indices[i]
+            if first_occurrence[tid_idx] == 0 or i < first_occurrence[tid_idx]:
+                first_occurrence[tid_idx] = i
+
+        # Extract Ori, X, Y for each unique transform
+        Oris = Oris[first_occurrence]
+        Xs = Xs[first_occurrence]
+        Ys = Ys[first_occurrence]
+        TIDs = unique_TIDs
+
         Rs, Ts = framecoords_to_RTs(Oris, Xs, Ys)
 
         for R, T, transform_id in zip(Rs, Ts, TIDs):
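
The new block deduplicates symmetry transforms by keeping the first occurrence of each `sym_transform_id`. If a vectorized variant is ever wanted, the same first-occurrence indices can be computed without the Python loop, assuming a PyTorch version that provides `scatter_reduce_` (roughly 1.12+); a sketch with made-up data:

```python
import torch

TIDs = torch.tensor([2, 2, 0, 0, 1, 1, 2])
unique_TIDs, inverse = torch.unique(TIDs, return_inverse=True)

# Minimum position per unique value = index of its first occurrence.
positions = torch.arange(len(TIDs))
first_occurrence = torch.full((len(unique_TIDs),), len(TIDs), dtype=torch.long)
first_occurrence.scatter_reduce_(0, inverse, positions, reduce="amin")
# unique_TIDs -> tensor([0, 1, 2]); first_occurrence -> tensor([2, 4, 0])
```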
rfd3/utils/inference.py
CHANGED

@@ -3,7 +3,6 @@ Utilities for inference input preparation
 """
 
 import logging
-import os
 from os import PathLike
 from typing import Dict
 

@@ -365,30 +364,30 @@ def inference_load_(
     return data
 
 
-def
+def ensure_inference_sampler_matches_design_spec(
+    design_spec: dict, inference_sampler: dict | None = None
+):
     """
-
-
-
-
-    path: None or file to which the input is relative to.
+    Ensure the inference sampler is set to the correct sampler for the design specification.
+    Args:
+        design_spec: Design specification dictionary
+        inference_sampler: Inference sampler dictionary
     """
-
-
-
-
+    has_symmetry_specification = [
+        True
+        if "symmetry" in item.keys() and item.get("symmetry") is not None
+        else False
+        for item in design_spec.values()
+    ]
+    if any(has_symmetry_specification):
+        if (
+            inference_sampler is None
+            or inference_sampler.get("kind", "default") != "symmetry"
+        ):
+            raise ValueError(
+                "You requested for symmetric designs, but inference sampler is not set to symmetry. "
+                "Please add inference_sampler.kind='symmetry' to your command."
             )
-            )
-    if "input" not in args or not exists(args["input"]):
-        return args
-    input = args["input"]
-    if not os.path.isabs(input):
-        input = os.path.abspath(os.path.join(os.path.dirname(path), input))
-    ranked_logger.info(
-        f"Input source path is relative, converted to absolute path: {input}"
-    )
-    args["input"] = input
-    return args
 
 
 #################################################################################
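
A hypothetical call pattern for the new check (the import path follows the wheel layout shown above; the `design_spec` contents are invented for illustration):

```python
from rfd3.utils.inference import ensure_inference_sampler_matches_design_spec

design_spec = {
    "binder": {"length": 120, "symmetry": "C3"},   # symmetric design requested
    "target": {"length": 80},
}

# Passes: the sampler kind matches the symmetric design request.
ensure_inference_sampler_matches_design_spec(design_spec, inference_sampler={"kind": "symmetry"})

# Raises ValueError: symmetry requested, but a non-symmetry sampler was configured.
try:
    ensure_inference_sampler_matches_design_spec(design_spec, inference_sampler={"kind": "default"})
except ValueError as err:
    print(err)
```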