rc-foundry 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/utils/ddp.py +1 -1
- foundry/utils/logging.py +1 -1
- foundry/version.py +2 -2
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/METADATA +6 -2
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/RECORD +22 -22
- rf3/cli.py +13 -4
- rf3/inference.py +3 -1
- rfd3/engine.py +11 -3
- rfd3/inference/datasets.py +1 -1
- rfd3/inference/input_parsing.py +31 -0
- rfd3/inference/symmetry/atom_array.py +78 -9
- rfd3/inference/symmetry/checks.py +12 -4
- rfd3/inference/symmetry/frames.py +248 -0
- rfd3/inference/symmetry/symmetry_utils.py +5 -5
- rfd3/model/inference_sampler.py +11 -1
- rfd3/model/layers/block_utils.py +33 -33
- rfd3/model/layers/chunked_pairwise.py +84 -82
- rfd3/transforms/symmetry.py +16 -7
- rfd3/utils/inference.py +4 -28
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/licenses/LICENSE.md +0 -0
rfd3/model/layers/chunked_pairwise.py
CHANGED
@@ -30,32 +30,32 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
 
     def compute_pairs_chunked(
         self,
-        query_pos: torch.Tensor, # [
-        key_pos: torch.Tensor, # [
-        valid_mask: torch.Tensor, # [
+        query_pos: torch.Tensor, # [B, 3]
+        key_pos: torch.Tensor, # [B, k, 3]
+        valid_mask: torch.Tensor, # [B, k, 1]
     ) -> torch.Tensor:
         """
         Compute pairwise embeddings for specific query-key pairs.
 
         Args:
-            query_pos: Query positions [
-            key_pos: Key positions [
-            valid_mask: Valid pair mask [
+            query_pos: Query positions [B, 3]
+            key_pos: Key positions [B, k, 3]
+            valid_mask: Valid pair mask [B, k, 1]
 
         Returns:
-            P_sparse: Pairwise embeddings [
+            P_sparse: Pairwise embeddings [B, k, c_atompair]
         """
-
+        B, k = key_pos.shape[:2]
 
-        # Compute pairwise distances: [
-        D_pairs = query_pos.unsqueeze(1) - key_pos # [
+        # Compute pairwise distances: [B, k, 3]
+        D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
 
         if self.embed_frame:
             # Embed pairwise distances
-            P_pairs = self.process_d(D_pairs) * valid_mask # [
+            P_pairs = self.process_d(D_pairs) * valid_mask # [B, k, c_atompair]
 
             # Add inverse distance embedding
-            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [
+            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [B, k, 1]
             inv_dist = 1 / (1 + norm_sq)
             P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask
 
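The hunk above corrects the shape comments and derives `B, k` directly from `key_pos`. A minimal standalone sketch of the same distance-plus-inverse-distance embedding, assuming toy dimensions and plain `nn.Linear` stand-ins for the module's `process_d` / `process_inverse_dist` projections:

```python
import torch
import torch.nn as nn

# Assumed stand-ins for the module's learned projections.
process_d = nn.Linear(3, 16)             # embeds [B, k, 3] displacement vectors
process_inverse_dist = nn.Linear(1, 16)  # embeds [B, k, 1] inverse squared distances

B, k = 4, 8
query_pos = torch.randn(B, 3)
key_pos = torch.randn(B, k, 3)
valid_mask = torch.ones(B, k, 1)

# Displacement between each query and its k keys: [B, 1, 3] - [B, k, 3] -> [B, k, 3]
D_pairs = query_pos.unsqueeze(1) - key_pos

# Distance embedding, zeroed on invalid pairs: [B, k, 16]
P_pairs = process_d(D_pairs) * valid_mask

# Inverse squared-distance term 1 / (1 + ||d||^2): [B, k, 1]
norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2
P_pairs = P_pairs + process_inverse_dist(1 / (1 + norm_sq)) * valid_mask

print(P_pairs.shape)  # torch.Size([4, 8, 16])
```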
@@ -95,19 +95,19 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
 
     def compute_pairs_chunked(
         self,
-        query_pos: torch.Tensor, # [
-        key_pos: torch.Tensor, # [
-        valid_mask: torch.Tensor, # [
+        query_pos: torch.Tensor, # [B, 3]
+        key_pos: torch.Tensor, # [B, k, 3]
+        valid_mask: torch.Tensor, # [B, k, 1]
     ) -> torch.Tensor:
         """
         Compute sinusoidal distance embeddings for specific query-key pairs.
         """
-
+        B, k = key_pos.shape[:2]
         device = query_pos.device
 
         # Compute pairwise distances
-        D_pairs = query_pos.unsqueeze(1) - key_pos # [
-        dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [
+        D_pairs = query_pos.unsqueeze(1) - key_pos # [B, k, 3]
+        dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [B, k]
 
         # Sinusoidal embedding
         half_dim = self.n_freqs
@@ -117,13 +117,13 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
             / half_dim
         ) # [n_freqs]
 
-        angles = dist_matrix.unsqueeze(-1) * freq # [
+        angles = dist_matrix.unsqueeze(-1) * freq # [B, k, n_freqs]
         sin_embed = torch.sin(angles)
         cos_embed = torch.cos(angles)
-        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [
+        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [B, k, 2*n_freqs]
 
         # Linear projection
-        P_pairs = self.output_proj(sincos_embed) # [
+        P_pairs = self.output_proj(sincos_embed) # [B, k, c_atompair]
         P_pairs = P_pairs * valid_mask
 
         # Add linear embedding of valid mask
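The two hunks above pin down the sinusoidal path: per-pair distances, a sin/cos feature bank over `n_freqs` frequencies, and a linear projection to `c_atompair` channels. A standalone sketch with toy sizes; the geometric frequency schedule here is an assumption, since only the `/ half_dim` division is visible in the diff:

```python
import math
import torch
import torch.nn as nn

n_freqs, c_atompair = 8, 16  # assumed toy sizes
output_proj = nn.Linear(2 * n_freqs, c_atompair)

B, k = 4, 8
query_pos = torch.randn(B, 3)
key_pos = torch.randn(B, k, 3)
valid_mask = torch.ones(B, k, 1)

# Distances between each query and its k keys: [B, k]
dist_matrix = torch.linalg.norm(query_pos.unsqueeze(1) - key_pos, dim=-1)

# Geometrically spaced frequencies (assumed schedule): [n_freqs]
half_dim = n_freqs
freq = torch.exp(-math.log(10000.0) * torch.arange(half_dim) / half_dim)

angles = dist_matrix.unsqueeze(-1) * freq                           # [B, k, n_freqs]
sincos = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # [B, k, 2*n_freqs]
P_pairs = output_proj(sincos) * valid_mask                          # [B, k, c_atompair]

print(P_pairs.shape)  # torch.Size([4, 8, 16])
```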
@@ -191,8 +191,8 @@ class ChunkedPairwiseEmbedder(nn.Module):
     def forward_chunked(
         self,
         f: dict,
-        indices: torch.Tensor, # [
-        C_L: torch.Tensor, # [
+        indices: torch.Tensor, # [B, L, k] - sparse attention indices
+        C_L: torch.Tensor, # [B, L, c_token] - atom features
         Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
         tok_idx: torch.Tensor, # [L] - atom to token mapping
     ) -> torch.Tensor:
@@ -208,20 +208,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
 
         Args:
             f: Feature dictionary
-            indices: Sparse attention indices [
-            C_L: Atom-level features [
+            indices: Sparse attention indices [B, L, k]
+            C_L: Atom-level features [B, L, c_token]
             Z_init_II: Token-level pair features [I, I, c_z]
             tok_idx: Atom to token mapping [L]
 
         Returns:
-            P_LL_sparse: Sparse pairwise features [
+            P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
         """
-
+        B, L, k = indices.shape
         device = indices.device
 
         # Initialize sparse P_LL
         P_LL_sparse = torch.zeros(
-
+            B, L, k, self.c_atompair, device=device, dtype=C_L.dtype
         )
 
         # Handle both batched and non-batched C_L
@@ -237,71 +237,72 @@ class ChunkedPairwiseEmbedder(nn.Module):
         if valid_indices.dim() == 2: # [L, k] - add batch dimension
             valid_indices = valid_indices.unsqueeze(0).expand(
                 C_L.shape[0], -1, -1
-            ) # [
+            ) # [B, L, k]
 
         # 1. Motif position embedding (if exists)
         if self.motif_pos_embedder is not None and "motif_pos" in f:
             motif_pos = f["motif_pos"] # [L, 3]
             is_motif = f["is_motif_atom_with_fixed_coord"] # [L]
-
+            is_motif_idx = torch.where(is_motif)[0]
             # For each query position
-            for l in
-
-
-
-
-
-
-
-
-
-
-
-
-
-            P_LL_sparse[:, l, :, :] += motif_pairs
+            for l in is_motif_idx:
+                key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices
+                key_pos = motif_pos[key_indices] # [B, k, 3]
+                query_pos = motif_pos[l].unsqueeze(0).expand(B, -1) # [B, 3]
+
+                # Valid mask: both query and keys must be motif
+                key_is_motif = is_motif[key_indices] # [B, k]
+                valid_mask = key_is_motif.unsqueeze(-1).float() # [B, k, 1]
+
+                if valid_mask.sum() > 0:
+                    motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
+                        query_pos, key_pos, valid_mask
+                    )
+                    P_LL_sparse[:, l, :, :] += motif_pairs
 
         # 2. Reference position embedding (if exists)
         if self.ref_pos_embedder is not None and "ref_pos" in f:
             ref_pos = f["ref_pos"] # [L, 3]
             ref_space_uid = f["ref_space_uid"] # [L]
             is_motif_seq = f["is_motif_atom_with_fixed_seq"] # [L]
-
-            for l in
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            P_LL_sparse[:, l, :, :] += ref_pairs
+            is_motif_seq_idx = torch.where(is_motif_seq)[0]
+            for l in is_motif_seq_idx:
+                key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices
+                key_pos = ref_pos[key_indices] # [B, k, 3]
+                query_pos = ref_pos[l].unsqueeze(0).expand(B, -1) # [B, 3]
+
+                # Valid mask: same token and both have sequence
+                key_space_uid = ref_space_uid[key_indices] # [B, k]
+                key_is_motif_seq = is_motif_seq[key_indices] # [B, k]
+
+                same_token = key_space_uid == ref_space_uid[l] # [B, k]
+                valid_mask = (
+                    (same_token & key_is_motif_seq).unsqueeze(-1).float()
+                ) # [B, k, 1]
+
+                if valid_mask.sum() > 0:
+                    ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
+                        query_pos, key_pos, valid_mask
+                    )
+                    P_LL_sparse[:, l, :, :] += ref_pairs
 
         # 3. Single embedding terms (broadcasted)
+        # Expand C_L to match valid_indices batch dimension
+        if C_L.shape[0] != B:
+            C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
         # Gather key features for each query
+        C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
         C_L_keys = torch.gather(
-
+            C_L_queries,
             1,
             valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
-        ) # [
-        C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [D, L, k, c_token]
+        ) # [B, L, k, c_token]
 
         # Add single embeddings - match standard implementation structure
         # Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
-        # We need to broadcast from [
-        single_l = self.process_single_l(C_L_queries) # [
-        single_m = self.process_single_m(C_L_keys) # [
+        # We need to broadcast from [B, L, k, c_atompair] to match this
+        single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
+        single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
         P_LL_sparse += single_l + single_m
 
         # 4. Token pair features Z_init_II
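The core of the rewritten block is a batched gather: for every atom `l`, pull the features of its `k` neighbours selected by `valid_indices`. The pattern in isolation, with assumed toy shapes:

```python
import torch

B, L, k, c_token = 2, 5, 3, 4
C_L = torch.randn(B, L, c_token)          # per-atom features
indices = torch.randint(0, L, (B, L, k))  # k neighbour indices per atom

# Broadcast features along the neighbour axis: [B, L, k, c_token]
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)

# Gather along the atom axis so that C_L_keys[b, l, j] == C_L[b, indices[b, l, j]]
C_L_keys = torch.gather(
    C_L_queries, 1, indices.unsqueeze(-1).expand(-1, -1, -1, c_token)
)

b, l, j = 0, 2, 1
assert torch.equal(C_L_keys[b, l, j], C_L[b, indices[b, l, j]])
print(C_L_keys.shape)  # torch.Size([2, 5, 3, 4])
```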
@@ -312,15 +313,16 @@ class ChunkedPairwiseEmbedder(nn.Module):
         else:
             tok_idx_expanded = tok_idx
 
-
+        # Expand tok_idx_expanded to match valid_indices batch dimension
+        if tok_idx_expanded.shape[0] != B:
+            tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
+        tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
         # Use valid_indices for token mapping as well
-        tok_keys = torch.gather(
-            tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices
-        ) # [D, L, k]
+        tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
 
         # Gather Z_init_II[tok_queries, tok_keys] with safe indexing
         # Z_init_II shape is [I, I, c_z] (3D), not 4D
-        # tok_queries shape: [
+        # tok_queries shape: [B, L, k] - each value is a token index
         # We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
 
         I_z, I_z2, c_z = Z_init_II.shape
@@ -338,20 +340,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
         # Then we need to gather the sparse version
 
         Z_pairs_processed = torch.zeros(
-
+            B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
         )
 
-        for
+        for b in range(B):
             # For this batch, get the token queries and keys
-            tq = tok_queries[
-            tk = tok_keys[
+            tq = tok_queries[b] # [L, k]
+            tk = tok_keys[b] # [L, k]
 
             # Ensure indices are within bounds
             tq = torch.clamp(tq, 0, I_z - 1)
             tk = torch.clamp(tk, 0, I_z2 - 1)
 
             # Apply the double token indexing like standard implementation
-            Z_pairs_processed[
+            Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
 
         P_LL_sparse += Z_pairs_processed
 
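The per-batch line `Z_pairs_processed[b] = Z_processed[tq, tk]` relies on advanced indexing with two index tensors of the same shape, which broadcast together and index the first two dimensions elementwise. A small sketch with assumed sizes:

```python
import torch

I, c_z = 6, 4   # number of tokens, pair-feature channels (assumed)
L, k = 5, 3
Z = torch.randn(I, I, c_z)         # token-pair features [I, I, c_z]
tq = torch.randint(0, I, (L, k))   # query token index per (atom, neighbour)
tk = torch.randint(0, I, (L, k))   # key token index per (atom, neighbour)

# Z_pairs[l, j] == Z[tq[l, j], tk[l, j]], giving shape [L, k, c_z]
Z_pairs = Z[tq, tk]

l, j = 2, 1
assert torch.equal(Z_pairs[l, j], Z[tq[l, j], tk[l, j]])
print(Z_pairs.shape)  # torch.Size([5, 3, 4])
```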
rfd3/transforms/symmetry.py
CHANGED
@@ -60,13 +60,22 @@ class AddSymmetryFeats(Transform):
         )
         TIDs = torch.from_numpy(atom_array.get_annotation("sym_transform_id"))
 
-
-
-
-
-
-
-
+        # Get unique transforms by TID (more robust than unique_consecutive on each array)
+        unique_TIDs, inverse_indices = torch.unique(TIDs, return_inverse=True)
+
+        # Get the first occurrence of each unique TID
+        first_occurrence = torch.zeros(len(unique_TIDs), dtype=torch.long)
+        for i in range(len(TIDs)):
+            tid_idx = inverse_indices[i]
+            if first_occurrence[tid_idx] == 0 or i < first_occurrence[tid_idx]:
+                first_occurrence[tid_idx] = i
+
+        # Extract Ori, X, Y for each unique transform
+        Oris = Oris[first_occurrence]
+        Xs = Xs[first_occurrence]
+        Ys = Ys[first_occurrence]
+        TIDs = unique_TIDs
+
         Rs, Ts = framecoords_to_RTs(Oris, Xs, Ys)
 
         for R, T, transform_id in zip(Rs, Ts, TIDs):
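The added block selects one representative frame (Ori, X, Y) per unique transform ID by recording each ID's first occurrence. The same bookkeeping can be reproduced in isolation; the sketch below is only an illustration with made-up TIDs, computed with a min-reduction scatter rather than the explicit loop used in the hunk:

```python
import torch

TIDs = torch.tensor([2, 2, 0, 0, 1, 1, 2])  # assumed per-atom transform IDs

unique_TIDs, inverse_indices = torch.unique(TIDs, return_inverse=True)

# First occurrence of each unique TID: scatter atom positions with an "amin" reduction.
positions = torch.arange(len(TIDs))
first_occurrence = torch.full((len(unique_TIDs),), len(TIDs), dtype=torch.long)
first_occurrence = first_occurrence.scatter_reduce(
    0, inverse_indices, positions, reduce="amin"
)

print(unique_TIDs)       # tensor([0, 1, 2])
print(first_occurrence)  # tensor([2, 4, 0])
```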
rfd3/utils/inference.py
CHANGED
@@ -3,7 +3,6 @@ Utilities for inference input preparation
 """
 
 import logging
-import os
 from os import PathLike
 from typing import Dict
 
@@ -365,32 +364,6 @@ def inference_load_(
     return data
 
 
-def ensure_input_is_abspath(args: dict, path: PathLike | None):
-    """
-    Ensures the input source is an absolute path if exists, if not it will convert
-
-    args:
-        spec: Inference specification for atom array
-        path: None or file to which the input is relative to.
-    """
-    if isinstance(args, str):
-        raise ValueError(
-            "Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
-                args
-            )
-        )
-    if "input" not in args or not exists(args["input"]):
-        return args
-    input = args["input"]
-    if not os.path.isabs(input):
-        input = os.path.abspath(os.path.join(os.path.dirname(path), input))
-        ranked_logger.info(
-            f"Input source path is relative, converted to absolute path: {input}"
-        )
-        args["input"] = input
-    return args
-
-
 def ensure_inference_sampler_matches_design_spec(
     design_spec: dict, inference_sampler: dict | None = None
 ):
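The removed `ensure_input_is_abspath` helper implemented a common pattern: resolving a relative `input` path against the directory of the file that referenced it. In isolation (hypothetical names, not the package's API), that pattern looks like:

```python
import os

def resolve_relative_to(input_path: str, reference_file: str) -> str:
    """Resolve input_path against the directory containing reference_file."""
    if os.path.isabs(input_path):
        return input_path
    return os.path.abspath(os.path.join(os.path.dirname(reference_file), input_path))

print(resolve_relative_to("inputs/design.pdb", "/configs/run.json"))
# /configs/inputs/design.pdb
```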
@@ -401,7 +374,10 @@ def ensure_inference_sampler_matches_design_spec(
         inference_sampler: Inference sampler dictionary
     """
     has_symmetry_specification = [
-        True
+        True
+        if "symmetry" in item.keys() and item.get("symmetry") is not None
+        else False
+        for item in design_spec.values()
     ]
     if any(has_symmetry_specification):
         if (
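For reference, the expanded comprehension yields one boolean per design-spec entry; with a toy spec (keys and values here are assumptions about the expected structure) it behaves as follows:

```python
design_spec = {
    "A": {"length": 120, "symmetry": "C3"},  # entry with a symmetry request
    "B": {"length": 80},                     # entry without one
}

has_symmetry_specification = [
    True
    if "symmetry" in item.keys() and item.get("symmetry") is not None
    else False
    for item in design_spec.values()
]

print(has_symmetry_specification)       # [True, False]
print(any(has_symmetry_specification))  # True
```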