omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
omnigenome/src/model/rna_design/model.py
@@ -18,9 +18,11 @@ import numpy as np
 import torch
 import autocuda
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, as_completed
 import ViennaRNA
 from scipy.spatial.distance import hamming
+import warnings
+import os
 
 from omnigenome.src.misc.utils import fprint
 
@@ -28,19 +30,19 @@ from omnigenome.src.misc.utils import fprint
 class OmniModelForRNADesign(torch.nn.Module):
     """
     RNA design model using masked language modeling and evolutionary algorithms.
-
+
     This model combines a pre-trained masked language model with evolutionary
     algorithms to design RNA sequences that fold into specific target structures.
     It uses a multi-objective optimization approach to balance structure similarity
     and thermodynamic stability.
-
+
     Attributes:
         device: Device to run the model on (CPU or GPU)
         parallel: Whether to use parallel processing for structure prediction
         tokenizer: Tokenizer for processing RNA sequences
         model: Pre-trained masked language model
     """
-
+
     def __init__(
         self,
         model="yangheng/OmniGenome-186M",
@@ -51,7 +53,7 @@ class OmniModelForRNADesign(torch.nn.Module):
     ):
         """
         Initialize the RNA design model.
-
+
         Args:
             model (str): Model name or path for the pre-trained MLM model
             device: Device to run the model on (default: None, auto-detect)
@@ -70,164 +72,216 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _random_bp_span(bp_span=None):
         """
         Generate a random base pair span.
-
+
         Args:
-            bp_span (int, optional): Base pair span to center around (default: None)
-
+            bp_span (int, optional): Fixed base pair span. If None, generates random.
+
         Returns:
-            int: Random base pair span within ±50 of the input span
+            int: Base pair span value
         """
-        return random.choice(range(max(0, bp_span - 50), min(bp_span + 50, 400)))
+        if bp_span is None:
+            return random.randint(1, 10)
+        return bp_span
 
     @staticmethod
     def _longest_bp_span(structure):
         """
-        Compute the longest base-pair span from RNA structure.
-
+        Find the longest base pair span in the structure.
+
         Args:
             structure (str): RNA structure in dot-bracket notation
-
+
         Returns:
-            int: Length of the longest base-pair span
+            int: Length of the longest base pair span
         """
-        stack = []
         max_span = 0
-        for i, char in enumerate(structure):
+        current_span = 0
+
+        for char in structure:
             if char == "(":
-                stack.append(i)
-            elif char == ")" and stack:
-                left_index = stack.pop()
-                max_span = max(max_span, i - left_index)
+                current_span += 1
+                max_span = max(max_span, current_span)
+            elif char == ")":
+                current_span = max(0, current_span - 1)
+            else:
+                current_span = 0
+
         return max_span
 
     @staticmethod
     def _predict_structure_single(sequence, bp_span=-1):
         """
-        Predict the RNA structure and minimum free energy (MFE) for a single sequence.
-
+        Predict structure for a single sequence (worker function for multiprocessing).
+
         Args:
-            sequence (str): RNA sequence
-            bp_span (int): Maximum base pair span for folding (default: -1, no limit)
-
+            sequence (str): RNA sequence to fold
+            bp_span (int): Base pair span parameter
+
         Returns:
-            tuple: (structure, mfe) where structure is in dot-bracket notation
+            tuple: (structure, mfe) tuple
         """
-        md = ViennaRNA.md()
-        md.max_bp_span = bp_span
-        fc = ViennaRNA.fold_compound(sequence, md)
-        return fc.mfe()
+        try:
+            return ViennaRNA.fold(sequence)
+        except Exception as e:
+            warnings.warn(f"Failed to fold sequence {sequence}: {e}")
+            return ("." * len(sequence), 0.0)
 
     def _predict_structure(self, sequences, bp_span=-1):
         """
-        Predict RNA structures for multiple sequences.
-
+        Predict structures for multiple sequences.
+
         Args:
             sequences (list): List of RNA sequences
-            bp_span (int): Maximum base pair span for folding (default: -1, no limit)
-
+            bp_span (int): Base pair span parameter
+
         Returns:
             list: List of (structure, mfe) tuples
         """
-        return [self._predict_structure_single(seq, bp_span) for seq in sequences]
+        if not self.parallel or len(sequences) <= 1:
+            # Sequential processing
+            return [self._predict_structure_single(seq, bp_span) for seq in sequences]
+
+        # Parallel processing with improved error handling
+        try:
+            # Determine number of workers
+            max_workers = min(os.cpu_count(), len(sequences), 8)  # Limit to 8 workers
+
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                # Submit all tasks
+                future_to_seq = {
+                    executor.submit(self._predict_structure_single, seq, bp_span): seq
+                    for seq in sequences
+                }
+
+                # Collect results
+                results = []
+                for future in as_completed(future_to_seq):
+                    try:
+                        result = future.result()
+                        results.append(result)
+                    except Exception as e:
+                        seq = future_to_seq[future]
+                        warnings.warn(f"Failed to process sequence {seq}: {e}")
+                        # Fallback to dot structure
+                        results.append(("." * len(seq), 0.0))
+
+                return results
+
+        except Exception as e:
+            warnings.warn(
+                f"Parallel processing failed, falling back to sequential: {e}"
+            )
+            # Fallback to sequential processing
+            return [self._predict_structure_single(seq, bp_span) for seq in sequences]
 
     def _init_population(self, structure, num_population):
         """
-        Initialize the population with masked sequences.
-
+        Initialize the population with random sequences.
+
         Args:
-            structure (str): Target RNA structure in dot-bracket notation
-            num_population (int): Number of individuals in the population
-
+            structure (str): Target RNA structure
+            num_population (int): Population size
+
         Returns:
-            list: List of (sequence, bp_span) tuples representing the initial population
+            list: List of (sequence, bp_span) tuples
         """
         population = []
-        mlm_inputs = []
-        for _ in range(num_population):
-            masked_sequence = "".join(
-                [random.choice(["G", "C", "<mask>"]) for _ in structure]
-            )
-            mlm_inputs.append(f"{masked_sequence}<eos>{structure}")
-
-        outputs = self._mlm_predict(mlm_inputs, structure)
+        bp_span = self._longest_bp_span(structure)
 
-        for i, output in enumerate(outputs):
-            sequence = self.tokenizer.convert_ids_to_tokens(output.tolist())
-            fixed_sequence = [
-                x if x in "AGCT" else random.choice(["A", "T", "G", "C"])
-                for x in sequence
-            ]
-            bp_span = self._random_bp_span(len(structure))
-            population.append(("".join(fixed_sequence), bp_span))
+        for _ in range(num_population):
+            # Generate random sequence
+            sequence = "".join(random.choice("ACGU") for _ in range(len(structure)))
+            population.append((sequence, bp_span))
 
         return population
 
     def _mlm_mutate(self, population, structure, mutation_ratio):
         """
-        Apply mutation to the population using the masked language model (MLM).
-
+        Mutate population using masked language modeling.
+
         Args:
-            population (list): Current population of (sequence, bp_span) tuples
+            population (list): Current population
             structure (str): Target RNA structure
             mutation_ratio (float): Ratio of tokens to mutate
-
+
         Returns:
-            list: Mutated population of (sequence, bp_span) tuples
+            list: Mutated population
         """
 
         def mutate(sequence, mutation_rate):
-            sequence = np.array(list(sequence))
-            masked_indices = np.random.rand(len(sequence)) < mutation_rate
-            sequence[masked_indices] = "$"
-            return "".join(sequence).replace("$", "<mask>")
+            # Create masked sequence
+            masked_sequence = list(sequence)
+            num_mutations = int(len(sequence) * mutation_rate)
+            mutation_positions = random.sample(range(len(sequence)), num_mutations)
 
-        mlm_inputs = []
-        for sequence, bp_span in population:
-            masked_sequence = mutate(sequence, mutation_ratio)
-            mlm_inputs.append(f"{masked_sequence}<eos>{structure}")
+            for pos in mutation_positions:
+                masked_sequence[pos] = self.tokenizer.mask_token
 
-        outputs = self._mlm_predict(mlm_inputs, structure)
+            return "".join(masked_sequence)
 
-        mut_population = []
-        for i, (seq, bp_span) in enumerate(population):
-            sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
-            fixed_sequence = [
-                x if x in "AGCT" else random.choice(["A", "T", "G", "C"])
-                for x in sequence
-            ]
-            bp_span = self._random_bp_span(bp_span)
-            mut_population.append(("".join(fixed_sequence), bp_span))
+        # Prepare inputs for MLM
+        mlm_inputs = []
+        for sequence, bp_span in population:
+            masked_seq = mutate(sequence, mutation_ratio)
+            mlm_inputs.append(masked_seq)
+
+        # Get predictions from MLM
+        predicted_tokens = self._mlm_predict(mlm_inputs, structure)
+
+        # Convert predictions back to sequences
+        mutated_population = []
+        for i, (sequence, bp_span) in enumerate(population):
+            # Convert token IDs back to nucleotides
+            new_sequence = self.tokenizer.decode(
+                predicted_tokens[i], skip_special_tokens=True
+            )
+            # Ensure the sequence has the correct length
+            if len(new_sequence) != len(structure):
+                new_sequence = new_sequence[: len(structure)].ljust(len(structure), "A")
+            mutated_population.append((new_sequence, bp_span))
 
-        return mut_population
+        return mutated_population
 
     def _crossover(self, population, num_points=3):
         """
-        Perform crossover operation to create offspring.
-
+        Perform crossover operation on the population.
+
         Args:
-            population (list): Current population of (sequence, bp_span) tuples
-            num_points (int): Number of crossover points (default: 3)
-
+            population (list): Current population
+            num_points (int): Number of crossover points
+
         Returns:
-            list: Offspring population after crossover
+            list: Population after crossover
         """
-        population_size = len(population)
-        sequence_length = len(population[0][0])
+        if len(population) < 2:
+            return population
+
+        # Create crossover masks
+        num_sequences = len(population)
+        masks = np.zeros((num_sequences, len(population[0][0])), dtype=bool)
 
-        parent_indices = np.random.choice(population_size // 10, (population_size, 2))
-        crossover_points = np.sort(
-            np.random.randint(1, sequence_length, size=(population_size, num_points)),
-            axis=1,
+        # Generate random crossover points
+        crossover_points = np.random.randint(
+            0, len(population[0][0]), (num_sequences, num_points)
         )
 
-        masks = np.zeros((population_size, sequence_length), dtype=bool)
-        for i in range(population_size):
-            last_point = 0
+        # Create parent indices
+        parent_indices = np.random.randint(0, num_sequences, (num_sequences, 2))
+
+        # Generate crossover masks
+        for i in range(num_sequences):
             for j in range(num_points):
-                masks[i, last_point : crossover_points[i, j]] = j % 2 == 0
-                last_point = crossover_points[i, j]
+                if j == 0:
+                    masks[i, : crossover_points[i, j]] = True
+                else:
+                    last_point = crossover_points[i, j - 1]
+                    masks[i, last_point : crossover_points[i, j]] = j % 2 == 0
+
+            # Handle the last segment
+            last_point = crossover_points[i, -1]
             masks[i, last_point:] = num_points % 2 == 0
 
+        # Perform crossover
         population_array = np.array([list(seq[0]) for seq in population])
        child1_array = np.where(
            masks,
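The core functional change in this hunk is `_predict_structure`, which swaps a plain list comprehension for a bounded process pool. A minimal self-contained sketch of the same submit-and-collect pattern, with a toy `fold` standing in for `ViennaRNA.fold` so it runs without ViennaRNA installed:

```python
import os
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed

def fold(seq):
    # Toy stand-in for ViennaRNA.fold: returns (dot-bracket structure, MFE).
    return ("." * len(seq), 0.0)

def predict_structures(sequences, hard_cap=8):
    # Worker count capped by CPUs, task count, and a hard limit, as in the diff.
    max_workers = min(os.cpu_count(), len(sequences), hard_cap)
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_seq = {executor.submit(fold, s): s for s in sequences}
        results = []
        for future in as_completed(future_to_seq):  # completion order, not input order
            try:
                results.append(future.result())
            except Exception as e:
                seq = future_to_seq[future]
                warnings.warn(f"Failed to fold {seq}: {e}")
                results.append(("." * len(seq), 0.0))  # neutral fallback
    return results

if __name__ == "__main__":  # required for process pools on spawn platforms
    print(predict_structures(["ACGUACGU", "GGGCAAACGCCC"]))
```

One caveat worth noting: `as_completed` yields futures in completion order, not submission order, so downstream code that zips the results against the input list (as `_evaluate_structure_fitness` does in the next hunk) depends on the caller realigning them.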
@@ -251,23 +305,19 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _evaluate_structure_fitness(self, sequences, structure):
         """
         Evaluate the fitness of the RNA structure by comparing with the target structure.
-
+
         Args:
             sequences (list): List of (sequence, bp_span) tuples to evaluate
             structure (str): Target RNA structure
-
+
         Returns:
             list: Sorted population with fitness scores and MFE values
         """
-        if self.parallel:
-            with ProcessPoolExecutor() as executor:
-                structures_mfe = list(
-                    executor.map(
-                        self._predict_structure_single, [seq for seq, _ in sequences]
-                    )
-                )
-        else:
-            structures_mfe = self._predict_structure([seq for seq, _ in sequences])
+        # Get sequences for structure prediction
+        seq_list = [seq for seq, _ in sequences]
+
+        # Predict structures (with improved multiprocessing)
+        structures_mfe = self._predict_structure(seq_list)
 
         sorted_population = []
         for (seq, bp_span), (ss, mfe) in zip(sequences, structures_mfe):
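The loop closing this hunk pairs each candidate `(seq, bp_span)` with its predicted `(ss, mfe)`; the scoring line itself falls outside the hunk. A sketch of how a structure-similarity score can be derived from the `scipy.spatial.distance.hamming` import at the top of the file — the exact formula used by the package is an assumption here:

```python
from scipy.spatial.distance import hamming

def structure_score(predicted, target):
    # hamming() returns the fraction of mismatched positions, so
    # 1 - hamming is the fraction of positions matching the target.
    return 1 - hamming(list(predicted), list(target))

# One mismatch out of nine positions -> 0.888...
print(structure_score("(((....))", "(((...)))"))
```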
@@ -283,11 +333,11 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _non_dominated_sorting(scores, mfe_values):
         """
         Perform non-dominated sorting for multi-objective optimization.
-
+
         Args:
             scores (list): Structure similarity scores
             mfe_values (list): Minimum free energy values
-
+
         Returns:
             list: List of fronts (Pareto fronts)
         """
@@ -326,11 +376,11 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _select_next_generation(next_generation, fronts):
         """
         Select the next generation based on Pareto fronts.
-
+
         Args:
             next_generation (list): Current population with fitness scores
             fronts (list): Pareto fronts
-
+
         Returns:
             list: Selected population for the next generation
         """
@@ -346,11 +396,11 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _mlm_predict(self, mlm_inputs, structure):
         """
         Perform masked language model prediction.
-
+
         Args:
             mlm_inputs (list): List of masked input sequences
             structure (str): Target RNA structure
-
+
         Returns:
             list: Predicted token IDs for each input
         """
@@ -360,7 +410,7 @@ class OmniModelForRNADesign(torch.nn.Module):
         with torch.no_grad():
             for i in range(0, len(mlm_inputs), batch_size):
                 inputs = self.tokenizer(
-                    mlm_inputs[i: i + batch_size],
+                    mlm_inputs[i : i + batch_size],
                     padding=False,
                     max_length=1024,
                     truncation=True,
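Apart from the slice reformatting, this is the batched inference loop that backs `_mlm_mutate`. A generic sketch of the same pattern with the transformers API; the checkpoint is the constructor default seen earlier, and loading it may need extra kwargs (e.g. `trust_remote_code`) that this sketch omits:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

NAME = "yangheng/OmniGenome-186M"
tokenizer = AutoTokenizer.from_pretrained(NAME)
model = AutoModelForMaskedLM.from_pretrained(NAME)

def mlm_predict(mlm_inputs, batch_size=8, max_length=1024):
    predictions = []
    with torch.no_grad():
        for i in range(0, len(mlm_inputs), batch_size):
            # padding=True for generality; the diff can use padding=False
            # because every masked input shares the target structure's length.
            inputs = tokenizer(
                mlm_inputs[i : i + batch_size],
                padding=True,
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            logits = model(**inputs).logits
            # Greedy decoding: most likely token at each position.
            predictions.extend(logits.argmax(dim=-1))
    return predictions
```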
@@ -379,13 +429,13 @@ class OmniModelForRNADesign(torch.nn.Module):
     ):
         """
         Design RNA sequences for a target structure using evolutionary algorithms.
-
+
         Args:
             structure (str): Target RNA structure in dot-bracket notation
             mutation_ratio (float): Ratio of tokens to mutate (default: 0.5)
             num_population (int): Population size (default: 100)
             num_generation (int): Number of generations (default: 100)
-
+
         Returns:
             list: List of designed RNA sequences with their fitness scores
         """
omnigenome/src/model/seq2seq/__init__.py
@@ -9,4 +9,3 @@
 """
 This package contains modules for sequence-to-sequence models.
 """
-
omnigenome/src/model/seq2seq/model.py
@@ -21,20 +21,20 @@ from ...abc.abstract_model import OmniModel
 class OmniModelForSeq2Seq(OmniModel):
     """
     Sequence-to-sequence model for genomic sequences.
-
+
     This model implements a sequence-to-sequence architecture for genomic
     sequences, where the input is one sequence and the output is another
     sequence. It's useful for tasks like sequence translation, structure
     prediction, or sequence transformation.
-
+
     The model can be extended to implement specific seq2seq tasks by
     overriding the forward, predict, and inference methods.
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the sequence-to-sequence model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
@@ -17,17 +17,17 @@ warnings.filterwarnings("once")
17
17
  def is_bpe_tokenization(tokens, threshold=0.1):
18
18
  """
19
19
  Check if the tokenization is BPE-based by analyzing token characteristics.
20
-
20
+
21
21
  This function examines the tokens to determine if they follow BPE tokenization
22
22
  patterns by analyzing token length distributions and special token patterns.
23
-
23
+
24
24
  Args:
25
25
  tokens (list): List of tokens to analyze
26
26
  threshold (float, optional): Threshold for determining BPE tokenization. Defaults to 0.1
27
-
27
+
28
28
  Returns:
29
29
  bool: True if tokens appear to be BPE-based, False otherwise
30
-
30
+
31
31
  Example:
32
32
  >>> tokens = ["▁hello", "▁world", "▁how", "▁are", "▁you"]
33
33
  >>> is_bpe = is_bpe_tokenization(tokens)
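The function body is outside this hunk, so the following is a purely hypothetical illustration of the heuristic the docstring describes (token length distributions plus special-token markers such as the `▁` prefix in the example):

```python
def looks_like_bpe(tokens, threshold=0.1):
    # Hypothetical sketch: count tokens with BPE/SentencePiece traits,
    # i.e. a "▁" word-boundary marker or multi-character subwords.
    if not tokens:
        return False
    bpe_like = sum(1 for t in tokens if t.startswith("▁") or len(t) > 1)
    return bpe_like / len(tokens) >= threshold

print(looks_like_bpe(["▁hello", "▁world", "▁how", "▁are", "▁you"]))  # True
```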
@@ -52,15 +52,15 @@ def is_bpe_tokenization(tokens, threshold=0.1):
 class OmniBPETokenizer(OmniTokenizer):
     """
     A Byte Pair Encoding (BPE) tokenizer for genomic sequences.
-
+
     This tokenizer uses BPE tokenization for genomic sequences and provides
     validation to ensure the base tokenizer is BPE-based. It supports sequence
     preprocessing and handles various input formats.
-
+
     Attributes:
         base_tokenizer: The underlying BPE tokenizer
         metadata: Dictionary containing tokenizer metadata
-
+
     Example:
         >>> from omnigenome.src.tokenizer import OmniBPETokenizer
         >>> from transformers import AutoTokenizer
@@ -75,7 +75,7 @@ class OmniBPETokenizer(OmniTokenizer):
     def __init__(self, base_tokenizer=None, **kwargs):
         """
         Initialize the OmniBPETokenizer.
-
+
         Args:
             base_tokenizer: The base BPE tokenizer
             **kwargs: Additional keyword arguments passed to parent class
@@ -86,21 +86,21 @@ class OmniBPETokenizer(OmniTokenizer):
     def __call__(self, sequence, **kwargs):
         """
         Tokenize a sequence using BPE tokenization.
-
+
         This method processes the input sequence using BPE tokenization,
         handles sequence preprocessing (U/T conversion, whitespace addition),
         and validates that the tokenization is BPE-based.
-
+
         Args:
             sequence (str): Input sequence to tokenize
             **kwargs: Additional keyword arguments including max_length
-
+
         Returns:
             dict: Dictionary containing tokenized inputs with keys 'input_ids' and 'attention_mask'
-
+
         Raises:
             ValueError: If the tokenizer is not BPE-based
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> tokenized = tokenizer(sequence)
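The preprocessing steps named in the docstring (U/T conversion, whitespace addition) are implemented outside this hunk; a hypothetical sketch of what they amount to, with the conversion direction (U to T) being an assumption:

```python
def preprocess(sequence):
    # Hypothetical sketch: map U -> T and space-separate residues.
    sequence = sequence.replace("U", "T")
    return " ".join(sequence)

print(preprocess("ACGUAGGUAUCGUAGA"))  # "A C G T A G G T A T C G T A G A"
```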
@@ -136,14 +136,14 @@ class OmniBPETokenizer(OmniTokenizer):
     def from_pretrained(model_name_or_path, **kwargs):
         """
         Create a BPE tokenizer from a pre-trained model.
-
+
         Args:
             model_name_or_path (str): Name or path of the pre-trained model
             **kwargs: Additional keyword arguments
-
+
         Returns:
             OmniBPETokenizer: Initialized BPE tokenizer
-
+
         Example:
             >>> tokenizer = OmniBPETokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
             >>> print(type(tokenizer))
@@ -159,14 +159,14 @@ class OmniBPETokenizer(OmniTokenizer):
     def tokenize(self, sequence, **kwargs):
         """
         Tokenize a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence (str): Input sequence to tokenize
             **kwargs: Additional keyword arguments
-
+
         Returns:
             list: List of tokens
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> tokens = tokenizer.tokenize(sequence)
@@ -178,17 +178,17 @@ class OmniBPETokenizer(OmniTokenizer):
     def encode(self, sequence, **kwargs):
         """
         Encode a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence (str): Input sequence to encode
             **kwargs: Additional keyword arguments
-
+
         Returns:
             list: List of token IDs
-
+
         Raises:
             AssertionError: If the base tokenizer is not BPE-based
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> token_ids = tokenizer.encode(sequence)
@@ -203,17 +203,17 @@ class OmniBPETokenizer(OmniTokenizer):
     def decode(self, sequence, **kwargs):
         """
         Decode a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence: Input sequence to decode (can be token IDs or tokens)
             **kwargs: Additional keyword arguments
-
+
         Returns:
             str: Decoded sequence
-
+
         Raises:
             AssertionError: If the base tokenizer is not BPE-based
-
+
         Example:
             >>> token_ids = [1, 2, 3, 4, 5]
             >>> sequence = tokenizer.decode(token_ids)
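The tokenize/encode/decode trio documented above composes into the usual round trip. A usage sketch assembled from the docstring examples in this file (both the import path and the ESM-2 checkpoint appear in them):

```python
from omnigenome.src.tokenizer import OmniBPETokenizer

tokenizer = OmniBPETokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

sequence = "ACGUAGGUAUCGUAGA"
tokens = tokenizer.tokenize(sequence)   # list of BPE tokens
token_ids = tokenizer.encode(sequence)  # list of token IDs
decoded = tokenizer.decode(token_ids)   # back to a string

print(tokens)
print(token_ids)
print(decoded)
```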