rc-foundry 0.1.6__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
 
  Parameters
  ----------
- P_LK_indices : (D, L, k) LongTensor
+ P_LK_indices : (B, L, k) LongTensor
  Key indices | P_LK_indices[d, i, k] = global atom index for which atom i attends to.
- P_LK : (D, L, k, c) FloatTensor
+ P_LK : (B, L, k, c) FloatTensor
  Key features to scatter add into
 
- P_LA_indices : (D, L, a) LongTensor
+ P_LA_indices : (B, L, a) LongTensor
  Additional feature indices to scatter into P_LK.
- P_LA : (D, L, a, c) FloatTensor
+ P_LA : (B, L, a, c) FloatTensor
  Features corresponding to P_LA.
 
  Both index tensors contain indices representing D batch dim,
@@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
 
  """
  # Handle case when indices and P_LA don't have batch dimensions
- D, L, k = P_LK_indices.shape
+ B, L, k = P_LK_indices.shape
  if P_LA_indices.ndim == 2:
- P_LA_indices = P_LA_indices.unsqueeze(0).expand(D, -1, -1)
+ P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
  if P_LA_src.ndim == 3:
- P_LA_src = P_LA_src.unsqueeze(0).expand(D, -1, -1)
+ P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1)
  assert (
  P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
  ), "Channel dims do not match, got: {} vs {}".format(
  P_LA_src.shape[-1], P_LK_tgt.shape[-1]
  )
 
- matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (D, L, a, k)
+ matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (B, L, a, k)
  if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
  raise ValueError("Found multiple scatter indices for some atoms")
  elif not torch.all(matches.sum(dim=-1) <= 1):
  raise ValueError("Did not find a scatter index for every atom")
- k_indices = matches.long().argmax(dim=-1) # (D, L, a)
+ k_indices = matches.long().argmax(dim=-1) # (B, L, a)
  scatter_indices = k_indices.unsqueeze(-1).expand(
  -1, -1, -1, P_LK_tgt.shape[-1]
- ) # (D, L, a, c)
+ ) # (B, L, a, c)
  P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
  return P_LK_tgt
 
 
  def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
  """
- values : (D, L, C)
- idx : (D, L, k)
- returns: (D, L, k, C)
+ values : (B, L, C)
+ idx : (B, L, k)
+ returns: (B, L, k, C)
  """
- D, L, C = values.shape
+ B, L, C = values.shape
  k = idx.shape[-1]
 
- # (D, L, 1, C) → stride-0 along k → (D, L, k, C)
+ # (B, L, 1, C) → stride-0 along k → (B, L, k, C)
  src = values.unsqueeze(2).expand(-1, -1, k, -1)
- idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (D, L, k, C)
+ idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (B, L, k, C)
 
  return torch.gather(src, 1, idx) # dim=1 is the L-axis
 
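The hunks above only rename the batch dimension `D` to `B`; the index-matching scatter itself is unchanged. For readers skimming the diff, a minimal sketch of that pattern with made-up shapes (editorial illustration, not code from the package):

```python
import torch

B, L, a, k, c = 1, 2, 2, 3, 1
P_LK_indices = torch.tensor([[[4, 7, 9], [1, 3, 5]]])  # (B, L, k) global atom ids per key slot
P_LA_indices = torch.tensor([[[7, 9], [3, 1]]])        # (B, L, a) ids whose features get merged in
P_LK_tgt = torch.zeros(B, L, k, c)
P_LA_src = torch.ones(B, L, a, c)

# Match each extra index against the key indices of its query row
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (B, L, a, k)
k_indices = matches.long().argmax(dim=-1)                           # (B, L, a): key slot per extra index
scatter_indices = k_indices.unsqueeze(-1).expand(-1, -1, -1, c)     # (B, L, a, c)
out = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)

# Query 0 matched key slots 1 and 2 (ids 7, 9); query 1 matched slots 1 and 0
print(out.squeeze(-1))  # tensor([[[0., 1., 1.], [1., 1., 0.]]])
```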
@@ -196,7 +196,7 @@ def create_attention_indices(
  X_L = torch.randn(
  (1, L, 3), device=device, dtype=torch.float
  ) # [L, 3] - random
- D_LL = torch.cdist(X_L, X_L, p=2) # [D, L, L] - pairwise atom distances
+ D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances
 
  # Create attention indices using neighbour distances
  base_mask = ~f["unindexing_pair_mask"][
@@ -231,7 +231,7 @@ def create_attention_indices(
  k_max=k_actual,
  chain_id=chain_ids,
  base_mask=base_mask,
- ) # [D, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
+ ) # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
 
  return attn_indices
 
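For context on what `get_sparse_attention_indices` consumes here: a hedged sketch of how distance-based attention indices are typically derived from a `cdist` matrix. The shapes and the all-true mask are assumptions for illustration, not the package's implementation:

```python
import torch

B, L, k = 1, 16, 4
X_L = torch.randn(B, L, 3)
D_LL = torch.cdist(X_L, X_L, p=2)               # [B, L, L] pairwise distances

base_mask = torch.ones(L, L, dtype=torch.bool)  # hypothetical valid-pair mask
D_masked = D_LL.masked_fill(~base_mask, float("inf"))

# k nearest neighbours per query atom = k smallest distances per row
_, attn_indices = torch.topk(D_masked, k, dim=-1, largest=False)  # [B, L, k]
assert attn_indices.shape == (B, L, k)
```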
@@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain(
 
  Args:
  tok_idx: atom to token mapping
- D_LL: pairwise distances [D, L, L]
+ D_LL: pairwise distances [B, L, L]
  n_seq_neighbours: number of sequence neighbors
  k_intra: number of intra-chain attention keys
  k_inter: number of inter-chain attention keys
@@ -253,29 +253,29 @@ def get_sparse_attention_indices_with_inter_chain(
  base_mask: base mask for valid pairs
 
  Returns:
- attn_indices: [D, L, k_total] where k_total = k_intra + k_inter
+ attn_indices: [B, L, k_total] where k_total = k_intra + k_inter
  """
- D, L, _ = D_LL.shape
+ B, L, _ = D_LL.shape
 
  # Get regular intra-chain indices (limited to k_intra)
  intra_indices = get_sparse_attention_indices(
  tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
- ) # [D, L, k_intra]
+ ) # [B, L, k_intra]
 
  # Get inter-chain indices for clash avoidance
- inter_indices = torch.zeros(D, L, k_inter, dtype=torch.long, device=D_LL.device)
-
- for d in range(D):
- for l in range(L):
- query_chain = chain_id[l]
+ inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device)
+ unique_chains = torch.unique(chain_id)
+ for b in range(B):
+ for c in unique_chains:
+ query_chain = chain_id[c]
 
  # Find atoms from different chains
- other_chain_mask = (chain_id != query_chain) & base_mask[l, :]
+ other_chain_mask = (chain_id != query_chain) & base_mask[c, :]
  other_chain_atoms = torch.where(other_chain_mask)[0]
 
  if len(other_chain_atoms) > 0:
  # Get distances to other chains
- distances_to_other = D_LL[d, l, other_chain_atoms]
+ distances_to_other = D_LL[b, c, other_chain_atoms]
 
  # Select k_inter closest atoms from other chains
  n_select = min(k_inter, len(other_chain_atoms))
@@ -283,23 +283,23 @@ def get_sparse_attention_indices_with_inter_chain(
  selected_atoms = other_chain_atoms[closest_idx]
 
  # Fill inter-chain indices
- inter_indices[d, l, :n_select] = selected_atoms
+ inter_indices[b, c, :n_select] = selected_atoms
  # Pad with random atoms if needed
  if n_select < k_inter:
  padding = torch.randint(
  0, L, (k_inter - n_select,), device=D_LL.device
  )
- inter_indices[d, l, n_select:] = padding
+ inter_indices[b, c, n_select:] = padding
  else:
  # No other chains found, fill with random indices
- inter_indices[d, l, :] = torch.randint(
+ inter_indices[b, c, :] = torch.randint(
  0, L, (k_inter,), device=D_LL.device
  )
 
  # Combine intra and inter chain indices
  combined_indices = torch.cat(
  [intra_indices, inter_indices], dim=-1
- ) # [D, L, k_total]
+ ) # [B, L, k_total]
 
  return combined_indices
 
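The new inter-chain selection above iterates over batches and chains in Python. As a point of comparison only, a vectorized sketch of the same "k_inter closest atoms from other chains" step, assuming `chain_id` is `[L]` and `base_mask` is `[L, L]`. This is an illustration, not a drop-in replacement; rows with fewer than `k_inter` valid atoms would still need the padding branch:

```python
import torch

B, L, k_inter = 1, 12, 3
D_LL = torch.rand(B, L, L)
chain_id = torch.tensor([0] * 6 + [1] * 6)      # [L], two toy chains
base_mask = torch.ones(L, L, dtype=torch.bool)  # [L, L]

other_chain = chain_id.unsqueeze(0) != chain_id.unsqueeze(1)  # [L, L], True across chains
valid = other_chain & base_mask
D_masked = D_LL.masked_fill(~valid, float("inf"))             # invalid pairs pushed to +inf

# k_inter closest atoms from *other* chains, for every atom at once
_, inter_indices = torch.topk(D_masked, k_inter, dim=-1, largest=False)  # [B, L, k_inter]
```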
@@ -30,32 +30,32 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
 
  def compute_pairs_chunked(
  self,
- query_pos: torch.Tensor, # [D, 3]
- key_pos: torch.Tensor, # [D, k, 3]
- valid_mask: torch.Tensor, # [D, k, 1]
+ query_pos: torch.Tensor, # [B, 3]
+ key_pos: torch.Tensor, # [B, k, 3]
+ valid_mask: torch.Tensor, # [B, k, 1]
  ) -> torch.Tensor:
  """
  Compute pairwise embeddings for specific query-key pairs.
 
  Args:
- query_pos: Query positions [D, 3]
- key_pos: Key positions [D, k, 3]
- valid_mask: Valid pair mask [D, k, 1]
+ query_pos: Query positions [B, 3]
+ key_pos: Key positions [B, k, 3]
+ valid_mask: Valid pair mask [B, k, 1]
 
  Returns:
- P_sparse: Pairwise embeddings [D, k, c_atompair]
+ P_sparse: Pairwise embeddings [B, k, c_atompair]
  """
- D, k = key_pos.shape[:2]
+ B, k = key_pos.shape[:2]
 
- # Compute pairwise distances: [D, k, 3]
- D_pairs = query_pos.unsqueeze(1) - key_pos # [D, 1, 3] - [D, k, 3] = [D, k, 3]
+ # Compute pairwise distances: [B, k, 3]
+ D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
 
  if self.embed_frame:
  # Embed pairwise distances
- P_pairs = self.process_d(D_pairs) * valid_mask # [D, k, c_atompair]
+ P_pairs = self.process_d(D_pairs) * valid_mask # [B, k, c_atompair]
 
  # Add inverse distance embedding
- norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [D, k, 1]
+ norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [B, k, 1]
  inv_dist = 1 / (1 + norm_sq)
  P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask
 
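A standalone sketch of the inverse-distance term used in this hunk: `1 / (1 + ||d||^2)` is bounded in `(0, 1]` and largest for nearby atoms. `process_inverse_dist` is the module's own learned projection; a plain `nn.Linear` stands in for it here as an assumption for illustration:

```python
import torch
import torch.nn as nn

B, k, c_atompair = 2, 5, 8
query_pos = torch.randn(B, 3)
key_pos = torch.randn(B, k, 3)
valid_mask = torch.ones(B, k, 1)

D_pairs = query_pos.unsqueeze(1) - key_pos                        # [B, k, 3]
norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2   # [B, k, 1]
inv_dist = 1 / (1 + norm_sq)                                      # in (0, 1], ~1 for coincident atoms

process_inverse_dist = nn.Linear(1, c_atompair)                   # stand-in for the module's projection
P_pairs = process_inverse_dist(inv_dist) * valid_mask             # [B, k, c_atompair]
```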
@@ -95,19 +95,19 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
 
  def compute_pairs_chunked(
  self,
- query_pos: torch.Tensor, # [D, 3]
- key_pos: torch.Tensor, # [D, k, 3]
- valid_mask: torch.Tensor, # [D, k, 1]
+ query_pos: torch.Tensor, # [B, 3]
+ key_pos: torch.Tensor, # [B, k, 3]
+ valid_mask: torch.Tensor, # [B, k, 1]
  ) -> torch.Tensor:
  """
  Compute sinusoidal distance embeddings for specific query-key pairs.
  """
- D, k = key_pos.shape[:2]
+ B, k = key_pos.shape[:2]
  device = query_pos.device
 
  # Compute pairwise distances
- D_pairs = query_pos.unsqueeze(1) - key_pos # [D, k, 3]
- dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [D, k]
+ D_pairs = query_pos.unsqueeze(1) - key_pos # [B, k, 3]
+ dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [B, k]
 
  # Sinusoidal embedding
  half_dim = self.n_freqs
@@ -117,13 +117,13 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
  / half_dim
  ) # [n_freqs]
 
- angles = dist_matrix.unsqueeze(-1) * freq # [D, k, n_freqs]
+ angles = dist_matrix.unsqueeze(-1) * freq # [B, k, n_freqs]
  sin_embed = torch.sin(angles)
  cos_embed = torch.cos(angles)
- sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [D, k, 2*n_freqs]
+ sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [B, k, 2*n_freqs]
 
  # Linear projection
- P_pairs = self.output_proj(sincos_embed) # [D, k, c_atompair]
+ P_pairs = self.output_proj(sincos_embed) # [B, k, c_atompair]
  P_pairs = P_pairs * valid_mask
 
  # Add linear embedding of valid mask
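The frequency schedule itself is truncated out of this hunk; a self-contained sketch of the sinusoidal featurization, assuming a standard geometric schedule (the package's actual `freq` computation may differ):

```python
import math

import torch

B, k, n_freqs = 2, 5, 4
dist_matrix = torch.rand(B, k) * 10.0  # [B, k] toy pairwise distances

half_dim = n_freqs
freq = torch.exp(
    -math.log(10000.0) * torch.arange(half_dim, dtype=torch.float) / half_dim
)  # [n_freqs], assumed geometric schedule
angles = dist_matrix.unsqueeze(-1) * freq                                 # [B, k, n_freqs]
sincos_embed = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # [B, k, 2*n_freqs]
# A learned linear layer (output_proj above) then maps this to c_atompair channels.
```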
@@ -191,8 +191,8 @@ class ChunkedPairwiseEmbedder(nn.Module):
  def forward_chunked(
  self,
  f: dict,
- indices: torch.Tensor, # [D, L, k] - sparse attention indices
- C_L: torch.Tensor, # [D, L, c_token] - atom features
+ indices: torch.Tensor, # [B, L, k] - sparse attention indices
+ C_L: torch.Tensor, # [B, L, c_token] - atom features
  Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
  tok_idx: torch.Tensor, # [L] - atom to token mapping
  ) -> torch.Tensor:
@@ -208,20 +208,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
 
  Args:
  f: Feature dictionary
- indices: Sparse attention indices [D, L, k]
- C_L: Atom-level features [D, L, c_token]
+ indices: Sparse attention indices [B, L, k]
+ C_L: Atom-level features [B, L, c_token]
  Z_init_II: Token-level pair features [I, I, c_z]
  tok_idx: Atom to token mapping [L]
 
  Returns:
- P_LL_sparse: Sparse pairwise features [D, L, k, c_atompair]
+ P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
  """
- D, L, k = indices.shape
+ B, L, k = indices.shape
  device = indices.device
 
  # Initialize sparse P_LL
  P_LL_sparse = torch.zeros(
- D, L, k, self.c_atompair, device=device, dtype=C_L.dtype
+ B, L, k, self.c_atompair, device=device, dtype=C_L.dtype
  )
 
  # Handle both batched and non-batched C_L
@@ -237,71 +237,72 @@ class ChunkedPairwiseEmbedder(nn.Module):
  if valid_indices.dim() == 2: # [L, k] - add batch dimension
  valid_indices = valid_indices.unsqueeze(0).expand(
  C_L.shape[0], -1, -1
- ) # [D, L, k]
+ ) # [B, L, k]
 
  # 1. Motif position embedding (if exists)
  if self.motif_pos_embedder is not None and "motif_pos" in f:
  motif_pos = f["motif_pos"] # [L, 3]
  is_motif = f["is_motif_atom_with_fixed_coord"] # [L]
-
+ is_motif_idx = torch.where(is_motif)[0]
  # For each query position
- for l in range(L):
- if is_motif[l]: # Only compute if query is motif
- key_indices = valid_indices[:, l, :] # [D, k] - use clamped indices
- key_pos = motif_pos[key_indices] # [D, k, 3]
- query_pos = motif_pos[l].unsqueeze(0).expand(D, -1) # [D, 3]
-
- # Valid mask: both query and keys must be motif
- key_is_motif = is_motif[key_indices] # [D, k]
- valid_mask = key_is_motif.unsqueeze(-1).float() # [D, k, 1]
-
- if valid_mask.sum() > 0:
- motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
- query_pos, key_pos, valid_mask
- )
- P_LL_sparse[:, l, :, :] += motif_pairs
+ for l in is_motif_idx:
+ key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices
+ key_pos = motif_pos[key_indices] # [B, k, 3]
+ query_pos = motif_pos[l].unsqueeze(0).expand(B, -1) # [B, 3]
+
+ # Valid mask: both query and keys must be motif
+ key_is_motif = is_motif[key_indices] # [B, k]
+ valid_mask = key_is_motif.unsqueeze(-1).float() # [B, k, 1]
+
+ if valid_mask.sum() > 0:
+ motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
+ query_pos, key_pos, valid_mask
+ )
+ P_LL_sparse[:, l, :, :] += motif_pairs
 
  # 2. Reference position embedding (if exists)
  if self.ref_pos_embedder is not None and "ref_pos" in f:
  ref_pos = f["ref_pos"] # [L, 3]
  ref_space_uid = f["ref_space_uid"] # [L]
  is_motif_seq = f["is_motif_atom_with_fixed_seq"] # [L]
-
- for l in range(L):
- if is_motif_seq[l]: # Only compute if query has sequence
- key_indices = valid_indices[:, l, :] # [D, k] - use clamped indices
- key_pos = ref_pos[key_indices] # [D, k, 3]
- query_pos = ref_pos[l].unsqueeze(0).expand(D, -1) # [D, 3]
-
- # Valid mask: same token and both have sequence
- key_space_uid = ref_space_uid[key_indices] # [D, k]
- key_is_motif_seq = is_motif_seq[key_indices] # [D, k]
-
- same_token = key_space_uid == ref_space_uid[l] # [D, k]
- valid_mask = (
- (same_token & key_is_motif_seq).unsqueeze(-1).float()
- ) # [D, k, 1]
-
- if valid_mask.sum() > 0:
- ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
- query_pos, key_pos, valid_mask
- )
- P_LL_sparse[:, l, :, :] += ref_pairs
+ is_motif_seq_idx = torch.where(is_motif_seq)[0]
+ for l in is_motif_seq_idx:
+ key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices
+ key_pos = ref_pos[key_indices] # [B, k, 3]
+ query_pos = ref_pos[l].unsqueeze(0).expand(B, -1) # [B, 3]
+
+ # Valid mask: same token and both have sequence
+ key_space_uid = ref_space_uid[key_indices] # [B, k]
+ key_is_motif_seq = is_motif_seq[key_indices] # [B, k]
+
+ same_token = key_space_uid == ref_space_uid[l] # [B, k]
+ valid_mask = (
+ (same_token & key_is_motif_seq).unsqueeze(-1).float()
+ ) # [B, k, 1]
+
+ if valid_mask.sum() > 0:
+ ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
+ query_pos, key_pos, valid_mask
+ )
+ P_LL_sparse[:, l, :, :] += ref_pairs
 
  # 3. Single embedding terms (broadcasted)
+ # Expand C_L to match valid_indices batch dimension
+ if C_L.shape[0] != B:
+ C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
  # Gather key features for each query
+ C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
  C_L_keys = torch.gather(
- C_L.unsqueeze(2).expand(-1, -1, k, -1),
+ C_L_queries,
  1,
  valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
- ) # [D, L, k, c_token]
- C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [D, L, k, c_token]
+ ) # [B, L, k, c_token]
 
  # Add single embeddings - match standard implementation structure
  # Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
- # We need to broadcast from [D, L, k, c_atompair] to match this
- single_l = self.process_single_l(C_L_queries) # [D, L, k, c_atompair]
- single_m = self.process_single_m(C_L_keys) # [D, L, k, c_atompair]
+ # We need to broadcast from [B, L, k, c_atompair] to match this
+ single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
+ single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
  P_LL_sparse += single_l + single_m
 
  # 4. Token pair features Z_init_II
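The reordered gather above is behavior-preserving: `C_L_keys[b, l, s]` is still `C_L[b, valid_indices[b, l, s]]`. A toy check of that gather semantics, with illustrative shapes:

```python
import torch

B, L, k, c = 1, 4, 2, 3
C_L = torch.arange(B * L * c, dtype=torch.float).reshape(B, L, c)
valid_indices = torch.randint(0, L, (B, L, k))

C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)  # [B, L, k, c]
C_L_keys = torch.gather(
    C_L_queries, 1, valid_indices.unsqueeze(-1).expand(-1, -1, -1, c)
)  # [B, L, k, c]

# Loop equivalent: each slot picks the feature vector of the attended atom
for b in range(B):
    for l in range(L):
        for s in range(k):
            assert torch.equal(C_L_keys[b, l, s], C_L[b, valid_indices[b, l, s]])
```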
@@ -312,15 +313,16 @@ class ChunkedPairwiseEmbedder(nn.Module):
  else:
  tok_idx_expanded = tok_idx
 
- tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [D, L, k]
+ # Expand tok_idx_expanded to match valid_indices batch dimension
+ if tok_idx_expanded.shape[0] != B:
+ tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
+ tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
  # Use valid_indices for token mapping as well
- tok_keys = torch.gather(
- tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices
- ) # [D, L, k]
+ tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
 
  # Gather Z_init_II[tok_queries, tok_keys] with safe indexing
  # Z_init_II shape is [I, I, c_z] (3D), not 4D
- # tok_queries shape: [D, L, k] - each value is a token index
+ # tok_queries shape: [B, L, k] - each value is a token index
  # We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
 
  I_z, I_z2, c_z = Z_init_II.shape
@@ -338,20 +340,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
  # Then we need to gather the sparse version
 
  Z_pairs_processed = torch.zeros(
- D, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
+ B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
  )
 
- for d in range(D):
+ for b in range(B):
  # For this batch, get the token queries and keys
- tq = tok_queries[d] # [L, k]
- tk = tok_keys[d] # [L, k]
+ tq = tok_queries[b] # [L, k]
+ tk = tok_keys[b] # [L, k]
 
  # Ensure indices are within bounds
  tq = torch.clamp(tq, 0, I_z - 1)
  tk = torch.clamp(tk, 0, I_z2 - 1)
 
  # Apply the double token indexing like standard implementation
- Z_pairs_processed[d] = Z_processed[tq, tk] # [L, k, c_atompair]
+ Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
 
  P_LL_sparse += Z_pairs_processed
 
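The per-batch `Z_processed[tq, tk]` above relies on advanced indexing with two integer tensors: with `tq` and `tk` of shape `[L, k]`, the two index tensors broadcast and the result is `[L, k, c]`, one pair feature per (query token, key token) pair. A toy demonstration:

```python
import torch

I, c = 5, 3
Z_processed = torch.randn(I, I, c)  # token-pair features after projection
L, k = 4, 2
tq = torch.randint(0, I, (L, k))    # query token per (atom, slot)
tk = torch.randint(0, I, (L, k))    # key token per (atom, slot)

Z_pairs = Z_processed[tq, tk]       # [L, k, c] via broadcast advanced indexing
assert torch.equal(Z_pairs[1, 0], Z_processed[tq[1, 0], tk[1, 0]])
```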
rfd3/run_inference.py CHANGED
@@ -12,7 +12,9 @@ load_dotenv(override=True)
 
  # For pip-installed package, configs should be relative to this file
  # Adjust this path based on where configs are bundled in the package
- _config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs")
+ _config_path = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
+ )
 
 
  @hydra.main(
@@ -60,13 +60,22 @@ class AddSymmetryFeats(Transform):
  )
  TIDs = torch.from_numpy(atom_array.get_annotation("sym_transform_id"))
 
- Oris = torch.unique_consecutive(Oris, dim=0)
- Xs = torch.unique_consecutive(Xs, dim=0)
- Ys = torch.unique_consecutive(Ys, dim=0)
- TIDs = torch.unique_consecutive(TIDs, dim=0)
- # the case in which there is only rotation (no translation), Ori = [0,0,0]
- if len(Oris) == 1 and (Oris == 0).all():
- Oris = Oris.repeat(len(Xs), 1)
+ # Get unique transforms by TID (more robust than unique_consecutive on each array)
+ unique_TIDs, inverse_indices = torch.unique(TIDs, return_inverse=True)
+
+ # Get the first occurrence of each unique TID
+ first_occurrence = torch.zeros(len(unique_TIDs), dtype=torch.long)
+ for i in range(len(TIDs)):
+ tid_idx = inverse_indices[i]
+ if first_occurrence[tid_idx] == 0 or i < first_occurrence[tid_idx]:
+ first_occurrence[tid_idx] = i
+
+ # Extract Ori, X, Y for each unique transform
+ Oris = Oris[first_occurrence]
+ Xs = Xs[first_occurrence]
+ Ys = Ys[first_occurrence]
+ TIDs = unique_TIDs
+
  Rs, Ts = framecoords_to_RTs(Oris, Xs, Ys)
 
  for R, T, transform_id in zip(Rs, Ts, TIDs):
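The first-occurrence loop introduced here can also be expressed without Python-level iteration; a sketch assuming 1-D `TIDs` and PyTorch >= 1.13 for `Tensor.scatter_reduce` (an equivalent formulation for reference, not the package's code):

```python
import torch

TIDs = torch.tensor([2, 2, 0, 1, 0, 1])
unique_TIDs, inverse_indices = torch.unique(TIDs, return_inverse=True)

# amin-reduce each row position into its TID's bucket -> index of first occurrence
positions = torch.arange(len(TIDs))
first_occurrence = torch.full((len(unique_TIDs),), len(TIDs), dtype=torch.long)
first_occurrence = first_occurrence.scatter_reduce(
    0, inverse_indices, positions, reduce="amin"
)
print(unique_TIDs)       # tensor([0, 1, 2])
print(first_occurrence)  # tensor([2, 3, 0])
```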
rfd3/utils/inference.py CHANGED
@@ -3,7 +3,6 @@ Utilities for inference input preparation
  """
 
  import logging
- import os
  from os import PathLike
  from typing import Dict
 
@@ -365,30 +364,30 @@ def inference_load_(
  return data
 
 
- def ensure_input_is_abspath(args: dict, path: PathLike | None):
+ def ensure_inference_sampler_matches_design_spec(
+ design_spec: dict, inference_sampler: dict | None = None
+ ):
  """
- Ensures the input source is an absolute path if exists, if not it will convert
-
- args:
- spec: Inference specification for atom array
- path: None or file to which the input is relative to.
+ Ensure the inference sampler is set to the correct sampler for the design specification.
+ Args:
+ design_spec: Design specification dictionary
+ inference_sampler: Inference sampler dictionary
  """
- if isinstance(args, str):
- raise ValueError(
- "Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
- args
+ has_symmetry_specification = [
+ True
+ if "symmetry" in item.keys() and item.get("symmetry") is not None
+ else False
+ for item in design_spec.values()
+ ]
+ if any(has_symmetry_specification):
+ if (
+ inference_sampler is None
+ or inference_sampler.get("kind", "default") != "symmetry"
+ ):
+ raise ValueError(
+ "You requested for symmetric designs, but inference sampler is not set to symmetry. "
+ "Please add inference_sampler.kind='symmetry' to your command."
  )
- )
- if "input" not in args or not exists(args["input"]):
- return args
- input = args["input"]
- if not os.path.isabs(input):
- input = os.path.abspath(os.path.join(os.path.dirname(path), input))
- ranked_logger.info(
- f"Input source path is relative, converted to absolute path: {input}"
- )
- args["input"] = input
- return args
 
 
  #################################################################################
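A hedged usage sketch of the new `ensure_inference_sampler_matches_design_spec` helper, with a made-up `design_spec` schema (the real schema and call sites live elsewhere in the package):

```python
# Hypothetical design_spec: only a non-None "symmetry" key matters for this check
design_spec = {
    "chain_A": {"symmetry": "C3", "length": 120},
    "chain_B": {"length": 80},
}

ensure_inference_sampler_matches_design_spec(
    design_spec, inference_sampler={"kind": "symmetry"}
)  # passes

# With inference_sampler=None (or any kind other than "symmetry") the same call raises:
# ValueError: You requested for symmetric designs, but inference sampler is not
# set to symmetry. Please add inference_sampler.kind='symmetry' to your command.
```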