rc-foundry 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

@@ -30,32 +30,32 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
 
     def compute_pairs_chunked(
         self,
-        query_pos: torch.Tensor,  # [D, 3]
-        key_pos: torch.Tensor,  # [D, k, 3]
-        valid_mask: torch.Tensor,  # [D, k, 1]
+        query_pos: torch.Tensor,  # [B, 3]
+        key_pos: torch.Tensor,  # [B, k, 3]
+        valid_mask: torch.Tensor,  # [B, k, 1]
     ) -> torch.Tensor:
         """
         Compute pairwise embeddings for specific query-key pairs.
 
         Args:
-            query_pos: Query positions [D, 3]
-            key_pos: Key positions [D, k, 3]
-            valid_mask: Valid pair mask [D, k, 1]
+            query_pos: Query positions [B, 3]
+            key_pos: Key positions [B, k, 3]
+            valid_mask: Valid pair mask [B, k, 1]
 
         Returns:
-            P_sparse: Pairwise embeddings [D, k, c_atompair]
+            P_sparse: Pairwise embeddings [B, k, c_atompair]
         """
-        D, k = key_pos.shape[:2]
+        B, k = key_pos.shape[:2]
 
-        # Compute pairwise distances: [D, k, 3]
-        D_pairs = query_pos.unsqueeze(1) - key_pos  # [D, 1, 3] - [D, k, 3] = [D, k, 3]
+        # Compute pairwise distances: [B, k, 3]
+        D_pairs = query_pos.unsqueeze(1) - key_pos  # [B, 1, 3] - [B, k, 3] = [B, k, 3]
 
         if self.embed_frame:
             # Embed pairwise distances
-            P_pairs = self.process_d(D_pairs) * valid_mask  # [D, k, c_atompair]
+            P_pairs = self.process_d(D_pairs) * valid_mask  # [B, k, c_atompair]
 
             # Add inverse distance embedding
-            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2  # [D, k, 1]
+            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2  # [B, k, 1]
             inv_dist = 1 / (1 + norm_sq)
             P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask
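Review note: this hunk is purely a naming fix. The leading dimension is a batch of diffusion samples, previously labeled `D`, which collided with the displacement tensor `D_pairs`; it is renamed `B` throughout. For orientation, here is a self-contained sketch of the featurization this method performs, assuming `process_d` and `process_inverse_dist` are plain `nn.Linear` layers (the shape comments suggest this; the diff does not show their definitions):

```python
import torch
import torch.nn as nn

# Sketch of compute_pairs_chunked's featurization under assumed Linear layers.
c_atompair = 16
process_d = nn.Linear(3, c_atompair)
process_inverse_dist = nn.Linear(1, c_atompair)

B, k = 4, 8
query_pos = torch.randn(B, 3)      # [B, 3]
key_pos = torch.randn(B, k, 3)     # [B, k, 3]
valid_mask = torch.ones(B, k, 1)   # [B, k, 1]

D_pairs = query_pos.unsqueeze(1) - key_pos      # [B, k, 3]
norm_sq = D_pairs.pow(2).sum(-1, keepdim=True)  # [B, k, 1]
inv_dist = 1 / (1 + norm_sq)                    # bounded in (0, 1]
P_pairs = process_d(D_pairs) * valid_mask
P_pairs = P_pairs + process_inverse_dist(inv_dist) * valid_mask
print(P_pairs.shape)  # torch.Size([4, 8, 16])
```

The `1 / (1 + ||d||^2)` form keeps the inverse-distance feature finite at zero separation and bounded in (0, 1].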
@@ -95,19 +95,19 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
 
     def compute_pairs_chunked(
         self,
-        query_pos: torch.Tensor,  # [D, 3]
-        key_pos: torch.Tensor,  # [D, k, 3]
-        valid_mask: torch.Tensor,  # [D, k, 1]
+        query_pos: torch.Tensor,  # [B, 3]
+        key_pos: torch.Tensor,  # [B, k, 3]
+        valid_mask: torch.Tensor,  # [B, k, 1]
     ) -> torch.Tensor:
         """
         Compute sinusoidal distance embeddings for specific query-key pairs.
         """
-        D, k = key_pos.shape[:2]
+        B, k = key_pos.shape[:2]
         device = query_pos.device
 
         # Compute pairwise distances
-        D_pairs = query_pos.unsqueeze(1) - key_pos  # [D, k, 3]
-        dist_matrix = torch.linalg.norm(D_pairs, dim=-1)  # [D, k]
+        D_pairs = query_pos.unsqueeze(1) - key_pos  # [B, k, 3]
+        dist_matrix = torch.linalg.norm(D_pairs, dim=-1)  # [B, k]
 
         # Sinusoidal embedding
         half_dim = self.n_freqs
@@ -117,13 +117,13 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
             / half_dim
         )  # [n_freqs]
 
-        angles = dist_matrix.unsqueeze(-1) * freq  # [D, k, n_freqs]
+        angles = dist_matrix.unsqueeze(-1) * freq  # [B, k, n_freqs]
         sin_embed = torch.sin(angles)
         cos_embed = torch.cos(angles)
-        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1)  # [D, k, 2*n_freqs]
+        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1)  # [B, k, 2*n_freqs]
 
         # Linear projection
-        P_pairs = self.output_proj(sincos_embed)  # [D, k, c_atompair]
+        P_pairs = self.output_proj(sincos_embed)  # [B, k, c_atompair]
         P_pairs = P_pairs * valid_mask
 
         # Add linear embedding of valid mask
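Review note: the same `D` to `B` rename, applied to the sinusoidal embedder. As context for the two hunks above, a runnable sketch of the whole pipeline; the frequency schedule is assumed to be the standard `exp(-log(max_period) * i / half_dim)` form, since the diff only shows its `/ half_dim` tail:

```python
import math
import torch
import torch.nn as nn

# Sketch of the sinusoidal distance featurization; max_period, the exact
# schedule, and the layer sizes are assumptions.
n_freqs, c_atompair, max_period = 16, 32, 100.0
output_proj = nn.Linear(2 * n_freqs, c_atompair)

B, k = 4, 8
query_pos, key_pos = torch.randn(B, 3), torch.randn(B, k, 3)
valid_mask = torch.ones(B, k, 1)

dist_matrix = torch.linalg.norm(query_pos.unsqueeze(1) - key_pos, dim=-1)  # [B, k]
half_dim = n_freqs
freq = torch.exp(-math.log(max_period) * torch.arange(half_dim) / half_dim)
angles = dist_matrix.unsqueeze(-1) * freq                   # [B, k, n_freqs]
sincos_embed = torch.cat([angles.sin(), angles.cos()], -1)  # [B, k, 2*n_freqs]
P_pairs = output_proj(sincos_embed) * valid_mask            # [B, k, c_atompair]
```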
@@ -191,8 +191,8 @@ class ChunkedPairwiseEmbedder(nn.Module):
     def forward_chunked(
         self,
         f: dict,
-        indices: torch.Tensor,  # [D, L, k] - sparse attention indices
-        C_L: torch.Tensor,  # [D, L, c_token] - atom features
+        indices: torch.Tensor,  # [B, L, k] - sparse attention indices
+        C_L: torch.Tensor,  # [B, L, c_token] - atom features
         Z_init_II: torch.Tensor,  # [I, I, c_z] - token pair features
         tok_idx: torch.Tensor,  # [L] - atom to token mapping
     ) -> torch.Tensor:
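Review note: this hunk only relabels the shape comments in the signature. A hypothetical shape check of the contract (the sizes and the `embedder` handle are illustrative, not taken from the package):

```python
import torch

B, L, k, I, c_token, c_z = 2, 100, 32, 25, 128, 64
indices = torch.randint(0, L, (B, L, k))  # sparse attention indices [B, L, k]
C_L = torch.randn(B, L, c_token)          # atom features [B, L, c_token]
Z_init_II = torch.randn(I, I, c_z)        # token pair features [I, I, c_z]
tok_idx = torch.randint(0, I, (L,))       # atom -> token mapping [L]

# With a constructed embedder, the call and its result shape would be:
# P_LL_sparse = embedder.forward_chunked(f, indices, C_L, Z_init_II, tok_idx)
# P_LL_sparse.shape == (B, L, k, embedder.c_atompair)
```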
@@ -208,20 +208,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
 
         Args:
             f: Feature dictionary
-            indices: Sparse attention indices [D, L, k]
-            C_L: Atom-level features [D, L, c_token]
+            indices: Sparse attention indices [B, L, k]
+            C_L: Atom-level features [B, L, c_token]
             Z_init_II: Token-level pair features [I, I, c_z]
             tok_idx: Atom to token mapping [L]
 
         Returns:
-            P_LL_sparse: Sparse pairwise features [D, L, k, c_atompair]
+            P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
         """
-        D, L, k = indices.shape
+        B, L, k = indices.shape
         device = indices.device
 
         # Initialize sparse P_LL
         P_LL_sparse = torch.zeros(
-            D, L, k, self.c_atompair, device=device, dtype=C_L.dtype
+            B, L, k, self.c_atompair, device=device, dtype=C_L.dtype
         )
 
         # Handle both batched and non-batched C_L
@@ -237,71 +237,72 @@ class ChunkedPairwiseEmbedder(nn.Module):
         if valid_indices.dim() == 2:  # [L, k] - add batch dimension
             valid_indices = valid_indices.unsqueeze(0).expand(
                 C_L.shape[0], -1, -1
-            )  # [D, L, k]
+            )  # [B, L, k]
 
         # 1. Motif position embedding (if exists)
         if self.motif_pos_embedder is not None and "motif_pos" in f:
             motif_pos = f["motif_pos"]  # [L, 3]
             is_motif = f["is_motif_atom_with_fixed_coord"]  # [L]
-
+            is_motif_idx = torch.where(is_motif)[0]
             # For each query position
-            for l in range(L):
-                if is_motif[l]:  # Only compute if query is motif
-                    key_indices = valid_indices[:, l, :]  # [D, k] - use clamped indices
-                    key_pos = motif_pos[key_indices]  # [D, k, 3]
-                    query_pos = motif_pos[l].unsqueeze(0).expand(D, -1)  # [D, 3]
-
-                    # Valid mask: both query and keys must be motif
-                    key_is_motif = is_motif[key_indices]  # [D, k]
-                    valid_mask = key_is_motif.unsqueeze(-1).float()  # [D, k, 1]
-
-                    if valid_mask.sum() > 0:
-                        motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
-                            query_pos, key_pos, valid_mask
-                        )
-                        P_LL_sparse[:, l, :, :] += motif_pairs
+            for l in is_motif_idx:
+                key_indices = valid_indices[:, l, :]  # [B, k] - use clamped indices
+                key_pos = motif_pos[key_indices]  # [B, k, 3]
+                query_pos = motif_pos[l].unsqueeze(0).expand(B, -1)  # [B, 3]
+
+                # Valid mask: both query and keys must be motif
+                key_is_motif = is_motif[key_indices]  # [B, k]
+                valid_mask = key_is_motif.unsqueeze(-1).float()  # [B, k, 1]
+
+                if valid_mask.sum() > 0:
+                    motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
+                        query_pos, key_pos, valid_mask
+                    )
+                    P_LL_sparse[:, l, :, :] += motif_pairs
 
         # 2. Reference position embedding (if exists)
         if self.ref_pos_embedder is not None and "ref_pos" in f:
             ref_pos = f["ref_pos"]  # [L, 3]
             ref_space_uid = f["ref_space_uid"]  # [L]
             is_motif_seq = f["is_motif_atom_with_fixed_seq"]  # [L]
-
-            for l in range(L):
-                if is_motif_seq[l]:  # Only compute if query has sequence
-                    key_indices = valid_indices[:, l, :]  # [D, k] - use clamped indices
-                    key_pos = ref_pos[key_indices]  # [D, k, 3]
-                    query_pos = ref_pos[l].unsqueeze(0).expand(D, -1)  # [D, 3]
-
-                    # Valid mask: same token and both have sequence
-                    key_space_uid = ref_space_uid[key_indices]  # [D, k]
-                    key_is_motif_seq = is_motif_seq[key_indices]  # [D, k]
-
-                    same_token = key_space_uid == ref_space_uid[l]  # [D, k]
-                    valid_mask = (
-                        (same_token & key_is_motif_seq).unsqueeze(-1).float()
-                    )  # [D, k, 1]
-
-                    if valid_mask.sum() > 0:
-                        ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
-                            query_pos, key_pos, valid_mask
-                        )
-                        P_LL_sparse[:, l, :, :] += ref_pairs
+            is_motif_seq_idx = torch.where(is_motif_seq)[0]
+            for l in is_motif_seq_idx:
+                key_indices = valid_indices[:, l, :]  # [B, k] - use clamped indices
+                key_pos = ref_pos[key_indices]  # [B, k, 3]
+                query_pos = ref_pos[l].unsqueeze(0).expand(B, -1)  # [B, 3]
+
+                # Valid mask: same token and both have sequence
+                key_space_uid = ref_space_uid[key_indices]  # [B, k]
+                key_is_motif_seq = is_motif_seq[key_indices]  # [B, k]
+
+                same_token = key_space_uid == ref_space_uid[l]  # [B, k]
+                valid_mask = (
+                    (same_token & key_is_motif_seq).unsqueeze(-1).float()
+                )  # [B, k, 1]
+
+                if valid_mask.sum() > 0:
+                    ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
+                        query_pos, key_pos, valid_mask
+                    )
+                    P_LL_sparse[:, l, :, :] += ref_pairs
 
         # 3. Single embedding terms (broadcasted)
+        # Expand C_L to match valid_indices batch dimension
+        if C_L.shape[0] != B:
+            C_L = C_L.expand(B, -1, -1)  # [B, L, c_token]
         # Gather key features for each query
+        C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)  # [B, L, k, c_token]
         C_L_keys = torch.gather(
-            C_L.unsqueeze(2).expand(-1, -1, k, -1),
+            C_L_queries,
             1,
             valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
-        )  # [D, L, k, c_token]
-        C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)  # [D, L, k, c_token]
+        )  # [B, L, k, c_token]
 
         # Add single embeddings - match standard implementation structure
         # Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
-        # We need to broadcast from [D, L, k, c_atompair] to match this
-        single_l = self.process_single_l(C_L_queries)  # [D, L, k, c_atompair]
-        single_m = self.process_single_m(C_L_keys)  # [D, L, k, c_atompair]
+        # We need to broadcast from [B, L, k, c_atompair] to match this
+        single_l = self.process_single_l(C_L_queries)  # [B, L, k, c_atompair]
+        single_m = self.process_single_m(C_L_keys)  # [B, L, k, c_atompair]
         P_LL_sparse += single_l + single_m
 
         # 4. Token pair features Z_init_II
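Review note: beyond the rename, this hunk swaps `for l in range(L): if mask[l]: ...` for iteration over `torch.where(mask)[0]`, so non-motif queries are never visited; the result is unchanged, only the loop count drops. A minimal illustration of the pattern:

```python
import torch

L = 1000
is_motif = torch.zeros(L, dtype=torch.bool)
is_motif[[3, 17, 42]] = True

# Before: L iterations, each testing a 0-dim bool tensor.
hits_before = [l for l in range(L) if is_motif[l]]

# After: only as many iterations as there are True entries.
is_motif_idx = torch.where(is_motif)[0]
hits_after = [int(l) for l in is_motif_idx]

assert hits_before == hits_after  # same positions, far fewer iterations
```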
@@ -312,15 +313,16 @@ class ChunkedPairwiseEmbedder(nn.Module):
         else:
             tok_idx_expanded = tok_idx
 
-        tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k)  # [D, L, k]
+        # Expand tok_idx_expanded to match valid_indices batch dimension
+        if tok_idx_expanded.shape[0] != B:
+            tok_idx_expanded = tok_idx_expanded.expand(B, -1)  # [B, L]
+        tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k)  # [B, L, k]
         # Use valid_indices for token mapping as well
-        tok_keys = torch.gather(
-            tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices
-        )  # [D, L, k]
+        tok_keys = torch.gather(tok_queries, 1, valid_indices)  # [B, L, k]
 
         # Gather Z_init_II[tok_queries, tok_keys] with safe indexing
         # Z_init_II shape is [I, I, c_z] (3D), not 4D
-        # tok_queries shape: [D, L, k] - each value is a token index
+        # tok_queries shape: [B, L, k] - each value is a token index
         # We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
 
         I_z, I_z2, c_z = Z_init_II.shape
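Review note: `tok_keys` now reuses the already-built `tok_queries` instead of re-expanding `tok_idx_expanded` inside the `gather` call. The two forms agree because `torch.gather(x, 1, idx)[b, l, j] == x[b, idx[b, l, j], j]`; a quick check under the shapes used here:

```python
import torch

B, L, k, I = 2, 5, 3, 7
tok_idx = torch.randint(0, I, (B, L))
valid_indices = torch.randint(0, L, (B, L, k))

tok_queries = tok_idx.unsqueeze(2).expand(-1, -1, k)    # [B, L, k]
tok_keys = torch.gather(tok_queries, 1, valid_indices)  # [B, L, k]

# Equivalent direct indexing, spelled out:
b = torch.arange(B)[:, None, None]
j = torch.arange(k)[None, None, :]
assert torch.equal(tok_keys, tok_queries[b, valid_indices, j])
```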
@@ -338,20 +340,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
         # Then we need to gather the sparse version
 
         Z_pairs_processed = torch.zeros(
-            D, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
+            B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
         )
 
-        for d in range(D):
+        for b in range(B):
             # For this batch, get the token queries and keys
-            tq = tok_queries[d]  # [L, k]
-            tk = tok_keys[d]  # [L, k]
+            tq = tok_queries[b]  # [L, k]
+            tk = tok_keys[b]  # [L, k]
 
             # Ensure indices are within bounds
             tq = torch.clamp(tq, 0, I_z - 1)
             tk = torch.clamp(tk, 0, I_z2 - 1)
 
             # Apply the double token indexing like standard implementation
-            Z_pairs_processed[d] = Z_processed[tq, tk]  # [L, k, c_atompair]
+            Z_pairs_processed[b] = Z_processed[tq, tk]  # [L, k, c_atompair]
 
         P_LL_sparse += Z_pairs_processed
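Review note: the renamed per-batch loop relies on paired advanced indexing: with `tq` and `tk` both `[L, k]`, `Z_processed[tq, tk]` pairs the two index tensors elementwise and returns `[L, k, c_atompair]`. A small demonstration:

```python
import torch

I, c = 6, 4
Z_processed = torch.randn(I, I, c)
L, k = 3, 5
tq = torch.randint(0, I, (L, k))
tk = torch.randint(0, I, (L, k))

out = Z_processed[tq, tk]  # selects Z_processed[tq[l, j], tk[l, j], :] per (l, j)
assert out.shape == (L, k, c)
assert torch.equal(out[1, 2], Z_processed[tq[1, 2], tk[1, 2]])
```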
@@ -60,13 +60,22 @@ class AddSymmetryFeats(Transform):
         )
         TIDs = torch.from_numpy(atom_array.get_annotation("sym_transform_id"))
 
-        Oris = torch.unique_consecutive(Oris, dim=0)
-        Xs = torch.unique_consecutive(Xs, dim=0)
-        Ys = torch.unique_consecutive(Ys, dim=0)
-        TIDs = torch.unique_consecutive(TIDs, dim=0)
-        # the case in which there is only rotation (no translation), Ori = [0,0,0]
-        if len(Oris) == 1 and (Oris == 0).all():
-            Oris = Oris.repeat(len(Xs), 1)
+        # Get unique transforms by TID (more robust than unique_consecutive on each array)
+        unique_TIDs, inverse_indices = torch.unique(TIDs, return_inverse=True)
+
+        # Get the first occurrence of each unique TID
+        first_occurrence = torch.zeros(len(unique_TIDs), dtype=torch.long)
+        for i in range(len(TIDs)):
+            tid_idx = inverse_indices[i]
+            if first_occurrence[tid_idx] == 0 or i < first_occurrence[tid_idx]:
+                first_occurrence[tid_idx] = i
+
+        # Extract Ori, X, Y for each unique transform
+        Oris = Oris[first_occurrence]
+        Xs = Xs[first_occurrence]
+        Ys = Ys[first_occurrence]
+        TIDs = unique_TIDs
+
         Rs, Ts = framecoords_to_RTs(Oris, Xs, Ys)
 
         for R, T, transform_id in zip(Rs, Ts, TIDs):
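Review note: the rewrite keys deduplication on `sym_transform_id`, keeping `Oris`, `Xs`, `Ys`, and `TIDs` in lockstep instead of running `unique_consecutive` on each array independently (which also required the special case for rotation-only transforms). One caveat worth flagging: `first_occurrence` is zero-initialized, so the `== 0` test cannot distinguish "not yet set" from "first seen at index 0". On PyTorch >= 1.12 the same first-occurrence indices can be computed loop-free and without that sentinel via a scatter-min; this is a hedged alternative sketch, not the shipped code:

```python
import torch

TIDs = torch.tensor([0, 0, 1, 1, 1, 2, 0])
unique_TIDs, inverse_indices = torch.unique(TIDs, return_inverse=True)

# Min position per unique TID: initialize above any valid index, then amin.
n = len(TIDs)
first_occurrence = torch.full((len(unique_TIDs),), n, dtype=torch.long)
first_occurrence.scatter_reduce_(0, inverse_indices, torch.arange(n), reduce="amin")
print(first_occurrence)  # tensor([0, 2, 5])
```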
rfd3/utils/inference.py CHANGED
@@ -3,7 +3,6 @@ Utilities for inference input preparation
 """
 
 import logging
-import os
 from os import PathLike
 from typing import Dict
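(Review note: the `import os` removal pairs with the deletion of `ensure_input_is_abspath` below, which appears to have been the module's only consumer of `os`.)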
@@ -365,32 +364,6 @@ def inference_load_(
     return data
 
 
-def ensure_input_is_abspath(args: dict, path: PathLike | None):
-    """
-    Ensures the input source is an absolute path if exists, if not it will convert
-
-    args:
-        spec: Inference specification for atom array
-        path: None or file to which the input is relative to.
-    """
-    if isinstance(args, str):
-        raise ValueError(
-            "Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
-                args
-            )
-        )
-    if "input" not in args or not exists(args["input"]):
-        return args
-    input = args["input"]
-    if not os.path.isabs(input):
-        input = os.path.abspath(os.path.join(os.path.dirname(path), input))
-        ranked_logger.info(
-            f"Input source path is relative, converted to absolute path: {input}"
-        )
-    args["input"] = input
-    return args
-
-
 def ensure_inference_sampler_matches_design_spec(
     design_spec: dict, inference_sampler: dict | None = None
 ):
@@ -401,7 +374,10 @@ def ensure_inference_sampler_matches_design_spec(
         inference_sampler: Inference sampler dictionary
     """
     has_symmetry_specification = [
-        True if "symmetry" in item.keys() else False for item in design_spec.values()
+        True
+        if "symmetry" in item.keys() and item.get("symmetry") is not None
+        else False
+        for item in design_spec.values()
     ]
     if any(has_symmetry_specification):
         if (
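Review note: the fix makes an explicit `"symmetry": null` in the design-spec JSON behave like an absent key. Since the predicate is key membership plus a `None` test, it is equivalent to a single `dict.get` check; the sketch below is a reading aid, not the package's code:

```python
design_spec = {
    "a": {"symmetry": "C3"},
    "b": {"symmetry": None},  # explicit null in the input JSON
    "c": {},                  # key absent
}
has_symmetry_specification = [
    item.get("symmetry") is not None for item in design_spec.values()
]
print(has_symmetry_specification)  # [True, False, False]
```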