omnigenome 0.3.0a1__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. omnigenome/__init__.py +16 -8
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +40 -36
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +65 -58
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. omnigenome-0.3.0a1.dist-info/RECORD +0 -78
  63. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  64. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  65. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
  66. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
@@ -23,14 +23,14 @@ from typing import Type, Callable, Union, List, Optional
23
23
  def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
24
24
  """
25
25
  3x3 convolution with padding.
26
-
26
+
27
27
  Args:
28
28
  in_planes (int): Number of input channels
29
29
  out_planes (int): Number of output channels
30
30
  stride (int): Stride for the convolution (default: 1)
31
31
  groups (int): Number of groups for grouped convolution (default: 1)
32
32
  dilation (int): Dilation factor for the convolution (default: 1)
33
-
33
+
34
34
  Returns:
35
35
  nn.Conv2d: 3x3 convolution layer
36
36
  """
@@ -49,12 +49,12 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
49
49
  def conv1x1(in_planes, out_planes, stride=1):
50
50
  """
51
51
  1x1 convolution.
52
-
52
+
53
53
  Args:
54
54
  in_planes (int): Number of input channels
55
55
  out_planes (int): Number of output channels
56
56
  stride (int): Stride for the convolution (default: 1)
57
-
57
+
58
58
  Returns:
59
59
  nn.Conv2d: 1x1 convolution layer
60
60
  """
@@ -64,14 +64,14 @@ def conv1x1(in_planes, out_planes, stride=1):
64
64
  def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
65
65
  """
66
66
  5x5 convolution with padding.
67
-
67
+
68
68
  Args:
69
69
  in_planes (int): Number of input channels
70
70
  out_planes (int): Number of output channels
71
71
  stride (int): Stride for the convolution (default: 1)
72
72
  groups (int): Number of groups for grouped convolution (default: 1)
73
73
  dilation (int): Dilation factor for the convolution (default: 1)
74
-
74
+
75
75
  Returns:
76
76
  nn.Conv2d: 5x5 convolution layer
77
77
  """
@@ -90,10 +90,10 @@ def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
90
90
  class BasicBlock(nn.Module):
91
91
  """
92
92
  Basic ResNet block for genomic sequence processing.
93
-
93
+
94
94
  This block implements a basic residual connection with two convolutions
95
95
  and is optimized for processing genomic sequence data with layer normalization.
96
-
96
+
97
97
  Attributes:
98
98
  expansion (int): Expansion factor for the block (default: 1)
99
99
  conv1: First 3x3 convolution layer
@@ -105,7 +105,7 @@ class BasicBlock(nn.Module):
105
105
  downsample: Downsampling layer for residual connection
106
106
  stride: Stride for the convolutions
107
107
  """
108
-
108
+
109
109
  expansion: int = 1
110
110
 
111
111
  def __init__(
@@ -121,7 +121,7 @@ class BasicBlock(nn.Module):
121
121
  ) -> None:
122
122
  """
123
123
  Initialize the BasicBlock.
124
-
124
+
125
125
  Args:
126
126
  inplanes (int): Number of input channels
127
127
  planes (int): Number of output channels
@@ -130,7 +130,7 @@ class BasicBlock(nn.Module):
130
130
  groups (int): Number of groups for grouped convolution (default: 1)
131
131
  dilation (int): Dilation factor for convolutions (default: 1)
132
132
  norm_layer: Normalization layer type (default: None, uses LayerNorm)
133
-
133
+
134
134
  Raises:
135
135
  NotImplementedError: If dilation > 1 is specified
136
136
  """
@@ -154,10 +154,10 @@ class BasicBlock(nn.Module):
154
154
  def forward(self, x: Tensor) -> Tensor:
155
155
  """
156
156
  Forward pass through the BasicBlock.
157
-
157
+
158
158
  Args:
159
159
  x (Tensor): Input tensor [batch_size, channels, height, width]
160
-
160
+
161
161
  Returns:
162
162
  Tensor: Output tensor with same shape as input
163
163
  """
@@ -188,11 +188,11 @@ class BasicBlock(nn.Module):
188
188
  class Bottleneck(nn.Module):
189
189
  """
190
190
  Bottleneck ResNet block for genomic sequence processing.
191
-
191
+
192
192
  This block implements a bottleneck residual connection with three convolutions
193
193
  (1x1, 3x3, 1x1) and is designed for deeper networks. It's adapted from
194
194
  the original ResNet V1.5 implementation.
195
-
195
+
196
196
  Attributes:
197
197
  expansion (int): Expansion factor for the block (default: 4)
198
198
  conv1: First 1x1 convolution layer
@@ -205,7 +205,7 @@ class Bottleneck(nn.Module):
205
205
  downsample: Downsampling layer for residual connection
206
206
  stride: Stride for the convolutions
207
207
  """
208
-
208
+
209
209
  # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
210
210
  # while original implementation places the stride at the first 1x1 convolution(self.conv1)
211
211
  # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
@@ -227,7 +227,7 @@ class Bottleneck(nn.Module):
227
227
  ) -> None:
228
228
  """
229
229
  Initialize the Bottleneck block.
230
-
230
+
231
231
  Args:
232
232
  inplanes (int): Number of input channels
233
233
  planes (int): Number of output channels
@@ -256,10 +256,10 @@ class Bottleneck(nn.Module):
256
256
  def forward(self, x: Tensor) -> Tensor:
257
257
  """
258
258
  Forward pass through the Bottleneck block.
259
-
259
+
260
260
  Args:
261
261
  x (Tensor): Input tensor [batch_size, channels, height, width]
262
-
262
+
263
263
  Returns:
264
264
  Tensor: Output tensor with same shape as input
265
265
  """
@@ -288,11 +288,11 @@ class Bottleneck(nn.Module):
288
288
  class ResNet(nn.Module):
289
289
  """
290
290
  ResNet architecture adapted for genomic sequence analysis.
291
-
291
+
292
292
  This ResNet implementation is specifically designed for processing genomic
293
293
  sequences and their structural representations. It uses layer normalization
294
294
  instead of batch normalization and is optimized for genomic data characteristics.
295
-
295
+
296
296
  Attributes:
297
297
  _norm_layer: Normalization layer type
298
298
  inplanes: Number of input channels for the first layer
@@ -319,7 +319,7 @@ class ResNet(nn.Module):
319
319
  ) -> None:
320
320
  """
321
321
  Initialize the ResNet architecture.
322
-
322
+
323
323
  Args:
324
324
  channels (int): Number of input channels
325
325
  block: Type of ResNet block (BasicBlock or Bottleneck)
@@ -329,7 +329,7 @@ class ResNet(nn.Module):
329
329
  width_per_group (int): Width per group for bottleneck blocks (default: 1)
330
330
  replace_stride_with_dilation: Whether to replace stride with dilation (default: None)
331
331
  norm_layer: Normalization layer type (default: None, uses LayerNorm)
332
-
332
+
333
333
  Raises:
334
334
  ValueError: If replace_stride_with_dilation is not None or a 3-element tuple
335
335
  """
@@ -379,14 +379,14 @@ class ResNet(nn.Module):
379
379
  ) -> nn.Sequential:
380
380
  """
381
381
  Create a layer of ResNet blocks.
382
-
382
+
383
383
  Args:
384
384
  block: Type of ResNet block to use
385
385
  planes (int): Number of output channels for the layer
386
386
  blocks (int): Number of blocks in the layer
387
387
  stride (int): Stride for the first block (default: 1)
388
388
  dilate (bool): Whether to use dilation (default: False)
389
-
389
+
390
390
  Returns:
391
391
  nn.Sequential: Sequential container of ResNet blocks
392
392
  """
@@ -433,10 +433,10 @@ class ResNet(nn.Module):
433
433
  def _forward_impl(self, x: Tensor) -> Tensor:
434
434
  """
435
435
  Forward pass implementation.
436
-
436
+
437
437
  Args:
438
438
  x (Tensor): Input tensor [batch_size, channels, height, width]
439
-
439
+
440
440
  Returns:
441
441
  Tensor: Output tensor after processing through ResNet
442
442
  """
@@ -456,10 +456,10 @@ class ResNet(nn.Module):
456
456
  def forward(self, x: Tensor) -> Tensor:
457
457
  """
458
458
  Forward pass through the ResNet.
459
-
459
+
460
460
  Args:
461
461
  x (Tensor): Input tensor [batch_size, channels, height, width]
462
-
462
+
463
463
  Returns:
464
464
  Tensor: Output tensor after processing through ResNet
465
465
  """
@@ -469,14 +469,14 @@ class ResNet(nn.Module):
469
469
  def resnet_b16(channels=128, bbn=16):
470
470
  """
471
471
  Create a ResNet-B16 model for genomic sequence analysis.
472
-
472
+
473
473
  This function creates a ResNet model with 16 basic blocks, optimized
474
474
  for processing genomic sequences and their structural representations.
475
-
475
+
476
476
  Args:
477
477
  channels (int): Number of input channels (default: 128)
478
478
  bbn (int): Number of basic blocks (default: 16)
479
-
479
+
480
480
  Returns:
481
481
  ResNet: Configured ResNet model
482
482
  """
@@ -9,4 +9,3 @@
9
9
  """
10
10
  This package contains modules for RNA design models.
11
11
  """
12
-
@@ -30,19 +30,19 @@ from omnigenome.src.misc.utils import fprint
30
30
  class OmniModelForRNADesign(torch.nn.Module):
31
31
  """
32
32
  RNA design model using masked language modeling and evolutionary algorithms.
33
-
33
+
34
34
  This model combines a pre-trained masked language model with evolutionary
35
35
  algorithms to design RNA sequences that fold into specific target structures.
36
36
  It uses a multi-objective optimization approach to balance structure similarity
37
37
  and thermodynamic stability.
38
-
38
+
39
39
  Attributes:
40
40
  device: Device to run the model on (CPU or GPU)
41
41
  parallel: Whether to use parallel processing for structure prediction
42
42
  tokenizer: Tokenizer for processing RNA sequences
43
43
  model: Pre-trained masked language model
44
44
  """
45
-
45
+
46
46
  def __init__(
47
47
  self,
48
48
  model="yangheng/OmniGenome-186M",
@@ -53,7 +53,7 @@ class OmniModelForRNADesign(torch.nn.Module):
53
53
  ):
54
54
  """
55
55
  Initialize the RNA design model.
56
-
56
+
57
57
  Args:
58
58
  model (str): Model name or path for the pre-trained MLM model
59
59
  device: Device to run the model on (default: None, auto-detect)
@@ -72,10 +72,10 @@ class OmniModelForRNADesign(torch.nn.Module):
72
72
  def _random_bp_span(bp_span=None):
73
73
  """
74
74
  Generate a random base pair span.
75
-
75
+
76
76
  Args:
77
77
  bp_span (int, optional): Fixed base pair span. If None, generates random.
78
-
78
+
79
79
  Returns:
80
80
  int: Base pair span value
81
81
  """
@@ -87,16 +87,16 @@ class OmniModelForRNADesign(torch.nn.Module):
87
87
  def _longest_bp_span(structure):
88
88
  """
89
89
  Find the longest base pair span in the structure.
90
-
90
+
91
91
  Args:
92
92
  structure (str): RNA structure in dot-bracket notation
93
-
93
+
94
94
  Returns:
95
95
  int: Length of the longest base pair span
96
96
  """
97
97
  max_span = 0
98
98
  current_span = 0
99
-
99
+
100
100
  for char in structure:
101
101
  if char == "(":
102
102
  current_span += 1
@@ -105,18 +105,18 @@ class OmniModelForRNADesign(torch.nn.Module):
105
105
  current_span = max(0, current_span - 1)
106
106
  else:
107
107
  current_span = 0
108
-
108
+
109
109
  return max_span
110
110
 
111
111
  @staticmethod
112
112
  def _predict_structure_single(sequence, bp_span=-1):
113
113
  """
114
114
  Predict structure for a single sequence (worker function for multiprocessing).
115
-
115
+
116
116
  Args:
117
117
  sequence (str): RNA sequence to fold
118
118
  bp_span (int): Base pair span parameter
119
-
119
+
120
120
  Returns:
121
121
  tuple: (structure, mfe) tuple
122
122
  """
@@ -129,30 +129,30 @@ class OmniModelForRNADesign(torch.nn.Module):
129
129
  def _predict_structure(self, sequences, bp_span=-1):
130
130
  """
131
131
  Predict structures for multiple sequences.
132
-
132
+
133
133
  Args:
134
134
  sequences (list): List of RNA sequences
135
135
  bp_span (int): Base pair span parameter
136
-
136
+
137
137
  Returns:
138
138
  list: List of (structure, mfe) tuples
139
139
  """
140
140
  if not self.parallel or len(sequences) <= 1:
141
141
  # Sequential processing
142
142
  return [self._predict_structure_single(seq, bp_span) for seq in sequences]
143
-
143
+
144
144
  # Parallel processing with improved error handling
145
145
  try:
146
146
  # Determine number of workers
147
147
  max_workers = min(os.cpu_count(), len(sequences), 8) # Limit to 8 workers
148
-
148
+
149
149
  with ProcessPoolExecutor(max_workers=max_workers) as executor:
150
150
  # Submit all tasks
151
151
  future_to_seq = {
152
- executor.submit(self._predict_structure_single, seq, bp_span): seq
152
+ executor.submit(self._predict_structure_single, seq, bp_span): seq
153
153
  for seq in sequences
154
154
  }
155
-
155
+
156
156
  # Collect results
157
157
  results = []
158
158
  for future in as_completed(future_to_seq):
@@ -164,112 +164,119 @@ class OmniModelForRNADesign(torch.nn.Module):
164
164
  warnings.warn(f"Failed to process sequence {seq}: {e}")
165
165
  # Fallback to dot structure
166
166
  results.append(("." * len(seq), 0.0))
167
-
167
+
168
168
  return results
169
-
169
+
170
170
  except Exception as e:
171
- warnings.warn(f"Parallel processing failed, falling back to sequential: {e}")
171
+ warnings.warn(
172
+ f"Parallel processing failed, falling back to sequential: {e}"
173
+ )
172
174
  # Fallback to sequential processing
173
175
  return [self._predict_structure_single(seq, bp_span) for seq in sequences]
174
176
 
175
177
  def _init_population(self, structure, num_population):
176
178
  """
177
179
  Initialize the population with random sequences.
178
-
180
+
179
181
  Args:
180
182
  structure (str): Target RNA structure
181
183
  num_population (int): Population size
182
-
184
+
183
185
  Returns:
184
186
  list: List of (sequence, bp_span) tuples
185
187
  """
186
188
  population = []
187
189
  bp_span = self._longest_bp_span(structure)
188
-
190
+
189
191
  for _ in range(num_population):
190
192
  # Generate random sequence
191
193
  sequence = "".join(random.choice("ACGU") for _ in range(len(structure)))
192
194
  population.append((sequence, bp_span))
193
-
195
+
194
196
  return population
195
197
 
196
198
  def _mlm_mutate(self, population, structure, mutation_ratio):
197
199
  """
198
200
  Mutate population using masked language modeling.
199
-
201
+
200
202
  Args:
201
203
  population (list): Current population
202
204
  structure (str): Target RNA structure
203
205
  mutation_ratio (float): Ratio of tokens to mutate
204
-
206
+
205
207
  Returns:
206
208
  list: Mutated population
207
209
  """
210
+
208
211
  def mutate(sequence, mutation_rate):
209
212
  # Create masked sequence
210
213
  masked_sequence = list(sequence)
211
214
  num_mutations = int(len(sequence) * mutation_rate)
212
215
  mutation_positions = random.sample(range(len(sequence)), num_mutations)
213
-
216
+
214
217
  for pos in mutation_positions:
215
218
  masked_sequence[pos] = self.tokenizer.mask_token
216
-
219
+
217
220
  return "".join(masked_sequence)
218
-
221
+
219
222
  # Prepare inputs for MLM
220
223
  mlm_inputs = []
221
224
  for sequence, bp_span in population:
222
225
  masked_seq = mutate(sequence, mutation_ratio)
223
226
  mlm_inputs.append(masked_seq)
224
-
227
+
225
228
  # Get predictions from MLM
226
229
  predicted_tokens = self._mlm_predict(mlm_inputs, structure)
227
-
230
+
228
231
  # Convert predictions back to sequences
229
232
  mutated_population = []
230
233
  for i, (sequence, bp_span) in enumerate(population):
231
234
  # Convert token IDs back to nucleotides
232
- new_sequence = self.tokenizer.decode(predicted_tokens[i], skip_special_tokens=True)
235
+ new_sequence = self.tokenizer.decode(
236
+ predicted_tokens[i], skip_special_tokens=True
237
+ )
233
238
  # Ensure the sequence has the correct length
234
239
  if len(new_sequence) != len(structure):
235
- new_sequence = new_sequence[:len(structure)].ljust(len(structure), "A")
240
+ new_sequence = new_sequence[: len(structure)].ljust(len(structure), "A")
236
241
  mutated_population.append((new_sequence, bp_span))
237
-
242
+
238
243
  return mutated_population
239
244
 
240
245
  def _crossover(self, population, num_points=3):
241
246
  """
242
247
  Perform crossover operation on the population.
243
-
248
+
244
249
  Args:
245
250
  population (list): Current population
246
251
  num_points (int): Number of crossover points
247
-
252
+
248
253
  Returns:
249
254
  list: Population after crossover
250
255
  """
251
256
  if len(population) < 2:
252
257
  return population
253
-
258
+
254
259
  # Create crossover masks
255
260
  num_sequences = len(population)
256
261
  masks = np.zeros((num_sequences, len(population[0][0])), dtype=bool)
257
-
262
+
258
263
  # Generate random crossover points
259
- crossover_points = np.random.randint(0, len(population[0][0]), (num_sequences, num_points))
260
-
264
+ crossover_points = np.random.randint(
265
+ 0, len(population[0][0]), (num_sequences, num_points)
266
+ )
267
+
261
268
  # Create parent indices
262
269
  parent_indices = np.random.randint(0, num_sequences, (num_sequences, 2))
263
-
270
+
264
271
  # Generate crossover masks
265
272
  for i in range(num_sequences):
266
273
  for j in range(num_points):
267
274
  if j == 0:
268
- masks[i, :crossover_points[i, j]] = True
275
+ masks[i, : crossover_points[i, j]] = True
269
276
  else:
270
- last_point = crossover_points[i, j-1]
271
- masks[i, last_point:crossover_points[i, j]] = j % 2 == 0
272
-
277
+ last_point = crossover_points[i, j - 1]
278
+ masks[i, last_point : crossover_points[i, j]] = j % 2 == 0
279
+
273
280
  # Handle the last segment
274
281
  last_point = crossover_points[i, -1]
275
282
  masks[i, last_point:] = num_points % 2 == 0
@@ -298,17 +305,17 @@ class OmniModelForRNADesign(torch.nn.Module):
298
305
  def _evaluate_structure_fitness(self, sequences, structure):
299
306
  """
300
307
  Evaluate the fitness of the RNA structure by comparing with the target structure.
301
-
308
+
302
309
  Args:
303
310
  sequences (list): List of (sequence, bp_span) tuples to evaluate
304
311
  structure (str): Target RNA structure
305
-
312
+
306
313
  Returns:
307
314
  list: Sorted population with fitness scores and MFE values
308
315
  """
309
316
  # Get sequences for structure prediction
310
317
  seq_list = [seq for seq, _ in sequences]
311
-
318
+
312
319
  # Predict structures (with improved multiprocessing)
313
320
  structures_mfe = self._predict_structure(seq_list)
314
321
 
@@ -326,11 +333,11 @@ class OmniModelForRNADesign(torch.nn.Module):
326
333
  def _non_dominated_sorting(scores, mfe_values):
327
334
  """
328
335
  Perform non-dominated sorting for multi-objective optimization.
329
-
336
+
330
337
  Args:
331
338
  scores (list): Structure similarity scores
332
339
  mfe_values (list): Minimum free energy values
333
-
340
+
334
341
  Returns:
335
342
  list: List of fronts (Pareto fronts)
336
343
  """
@@ -369,11 +376,11 @@ class OmniModelForRNADesign(torch.nn.Module):
369
376
  def _select_next_generation(next_generation, fronts):
370
377
  """
371
378
  Select the next generation based on Pareto fronts.
372
-
379
+
373
380
  Args:
374
381
  next_generation (list): Current population with fitness scores
375
382
  fronts (list): Pareto fronts
376
-
383
+
377
384
  Returns:
378
385
  list: Selected population for the next generation
379
386
  """
@@ -389,11 +396,11 @@ class OmniModelForRNADesign(torch.nn.Module):
389
396
  def _mlm_predict(self, mlm_inputs, structure):
390
397
  """
391
398
  Perform masked language model prediction.
392
-
399
+
393
400
  Args:
394
401
  mlm_inputs (list): List of masked input sequences
395
402
  structure (str): Target RNA structure
396
-
403
+
397
404
  Returns:
398
405
  list: Predicted token IDs for each input
399
406
  """
@@ -403,7 +410,7 @@ class OmniModelForRNADesign(torch.nn.Module):
403
410
  with torch.no_grad():
404
411
  for i in range(0, len(mlm_inputs), batch_size):
405
412
  inputs = self.tokenizer(
406
- mlm_inputs[i: i + batch_size],
413
+ mlm_inputs[i : i + batch_size],
407
414
  padding=False,
408
415
  max_length=1024,
409
416
  truncation=True,
@@ -422,13 +429,13 @@ class OmniModelForRNADesign(torch.nn.Module):
422
429
  ):
423
430
  """
424
431
  Design RNA sequences for a target structure using evolutionary algorithms.
425
-
432
+
426
433
  Args:
427
434
  structure (str): Target RNA structure in dot-bracket notation
428
435
  mutation_ratio (float): Ratio of tokens to mutate (default: 0.5)
429
436
  num_population (int): Population size (default: 100)
430
437
  num_generation (int): Number of generations (default: 100)
431
-
438
+
432
439
  Returns:
433
440
  list: List of designed RNA sequences with their fitness scores
434
441
  """
@@ -9,4 +9,3 @@
9
9
  """
10
10
  This package contains modules for sequence-to-sequence models.
11
11
  """
12
-
@@ -21,20 +21,20 @@ from ...abc.abstract_model import OmniModel
21
21
  class OmniModelForSeq2Seq(OmniModel):
22
22
  """
23
23
  Sequence-to-sequence model for genomic sequences.
24
-
24
+
25
25
  This model implements a sequence-to-sequence architecture for genomic
26
26
  sequences, where the input is one sequence and the output is another
27
27
  sequence. It's useful for tasks like sequence translation, structure
28
28
  prediction, or sequence transformation.
29
-
29
+
30
30
  The model can be extended to implement specific seq2seq tasks by
31
31
  overriding the forward, predict, and inference methods.
32
32
  """
33
-
33
+
34
34
  def __init__(self, config_or_model, tokenizer, *args, **kwargs):
35
35
  """
36
36
  Initialize the sequence-to-sequence model.
37
-
37
+
38
38
  Args:
39
39
  config_or_model: Model configuration or pre-trained model
40
40
  tokenizer: Tokenizer for processing input sequences