omnigenome 0.3.0a1__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. omnigenome/__init__.py +16 -8
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +40 -36
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +65 -58
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. omnigenome-0.3.0a1.dist-info/RECORD +0 -78
  63. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  64. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  65. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
  66. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
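For anyone applying this upgrade, a minimal check that the new build is the one installed; the package name and version string come from the wheel filenames above, and the pip command is illustrative rather than part of the diff:

```python
# Post-upgrade sanity check; the version string is taken from the wheel name
# above, and the install command is illustrative:
#   pip install --upgrade omnigenome==0.3.1a0
import omnigenome

print(omnigenome.__version__)  # expected: 0.3.1a0
```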
omnigenome/src/abc/abstract_model.py

@@ -47,14 +47,14 @@ def count_parameters(model):
  class OmniModel(torch.nn.Module):
  """
  Abstract base class for all models in OmniGenome.
-
+
  This class provides a unified interface for all genomic models in the OmniGenome
  framework. It handles model initialization, forward passes, loss computation,
  prediction, inference, and model persistence.
-
+
  The class is designed to work with various types of genomic data and tasks,
  including sequence classification, token classification, regression, and more.
-
+
  Attributes:
  model (torch.nn.Module): The underlying PyTorch model.
  config: The model configuration.
@@ -76,16 +76,16 @@ class OmniModel(torch.nn.Module):
  - From a configuration object

  Args:
- config_or_model: A model configuration, a pre-trained model path (str),
+ config_or_model: A model configuration, a pre-trained model path (str),
  or a `torch.nn.Module` instance.
  tokenizer: The tokenizer associated with the model.
  *args: Additional positional arguments.
  **kwargs: Additional keyword arguments.
  - label2id (dict): Mapping from class labels to IDs.
  - num_labels (int): The number of labels.
- - trust_remote_code (bool): Whether to trust remote code when loading
+ - trust_remote_code (bool): Whether to trust remote code when loading
  from Hugging Face Hub. Defaults to True.
- - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
+ - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
  when loading pre-trained weights. Defaults to False.
  - dropout (float): Dropout rate. Defaults to 0.0.

@@ -97,7 +97,7 @@ class OmniModel(torch.nn.Module):
  Example:
  >>> # Initialize from a pre-trained model
  >>> model = OmniModelForSequenceClassification("model_path", tokenizer)
-
+
  >>> # Initialize from a configuration
  >>> config = AutoConfig.from_pretrained("model_path")
  >>> model = OmniModelForSequenceClassification(config, tokenizer)
@@ -202,7 +202,9 @@ class OmniModel(torch.nn.Module):
  )
  self.config.num_labels = len(self.config.id2label)

- assert len(self.config.label2id) == num_labels, f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."
+ assert (
+ len(self.config.label2id) == num_labels
+ ), f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."

  # The metadata of the model
  self.metadata = env_meta_info()
@@ -240,7 +242,7 @@ class OmniModel(torch.nn.Module):
  model architectures by mapping input parameters appropriately.

  Args:
- **inputs: The inputs to the model, compatible with the base model's
+ **inputs: The inputs to the model, compatible with the base model's
  forward method. Typically includes 'input_ids', 'attention_mask',
  and other model-specific parameters.

@@ -386,7 +388,7 @@ class OmniModel(torch.nn.Module):
  predictions for further processing.

  Args:
- sequence_or_inputs: A sequence (str), list of sequences, or
+ sequence_or_inputs: A sequence (str), list of sequences, or
  tokenized inputs (dict/tuple).
  **kwargs: Additional arguments for tokenization and inference.

@@ -398,7 +400,7 @@ class OmniModel(torch.nn.Module):
  Example:
  >>> # Predict on a single sequence
  >>> outputs = model.predict("ATCGATCG")
-
+
  >>> # Predict on multiple sequences
  >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
  """
@@ -416,7 +418,7 @@ class OmniModel(torch.nn.Module):
  to class labels or probabilities.

  Args:
- sequence_or_inputs: A sequence (str), list of sequences, or
+ sequence_or_inputs: A sequence (str), list of sequences, or
  tokenized inputs (dict/tuple).
  **kwargs: Additional arguments for tokenization and inference.

@@ -429,7 +431,7 @@ class OmniModel(torch.nn.Module):
  >>> # Inference on a single sequence
  >>> results = model.inference("ATCGATCG")
  >>> print(results['predictions']) # Class labels
-
+
  >>> # Inference on multiple sequences
  >>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
  """
@@ -686,4 +688,3 @@ class OmniModel(torch.nn.Module):
  info += f"Model Config: {self.config}\n"
  fprint(info)
  return info
-
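The docstring hunks above spell out the OmniModel initialization and prediction interface. A minimal usage sketch assembled only from those docstring examples; the checkpoint path and label mapping are placeholders, and the package-level import path is an assumption:

```python
# Sketch based on the docstrings shown above; "path/to/model" and label2id are
# placeholders, and the top-level import path is an assumption.
from omnigenome import OmniModelForSequenceClassification, OmniTokenizer

tokenizer = OmniTokenizer.from_pretrained("path/to/model", trust_remote_code=True)
model = OmniModelForSequenceClassification(
    "path/to/model",
    tokenizer,
    label2id={"negative": 0, "positive": 1},  # mapping from class labels to IDs
    trust_remote_code=True,
)

outputs = model.predict("ATCGATCG")                  # raw predictions
results = model.inference(["ATCGATCG", "GCTAGCTA"])  # mapped to class labels
print(results["predictions"])
```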
omnigenome/src/abc/abstract_tokenizer.py

@@ -16,15 +16,15 @@ from ..misc.utils import env_meta_info, load_module_from_path
  class OmniTokenizer:
  """
  A wrapper class for tokenizers to provide a consistent interface within OmniGenome.
-
+
  This class provides a unified interface for tokenizers in the OmniGenome framework.
  It wraps underlying tokenizers (typically from Hugging Face) and provides
  additional functionality for genomic sequence processing.
-
+
  The class handles various tokenization strategies and provides compatibility
  with different model architectures. It also supports custom tokenizer wrappers
  for specialized genomic tasks.
-
+
  Attributes:
  base_tokenizer: The underlying tokenizer instance (e.g., from Hugging Face).
  max_length (int): The default maximum sequence length.
@@ -52,7 +52,7 @@ class OmniTokenizer:
  >>> from transformers import AutoTokenizer
  >>> base_tokenizer = AutoTokenizer.from_pretrained("model_name")
  >>> tokenizer = OmniTokenizer(base_tokenizer, max_length=512)
-
+
  >>> # Initialize with sequence conversion
  >>> tokenizer = OmniTokenizer(base_tokenizer, u2t=True)
  """
@@ -87,9 +87,9 @@ class OmniTokenizer:
  Example:
  >>> # Load from a pre-trained model
  >>> tokenizer = OmniTokenizer.from_pretrained("model_name")
-
+
  >>> # Load with custom parameters
- >>> tokenizer = OmniTokenizer.from_pretrained("model_name",
+ >>> tokenizer = OmniTokenizer.from_pretrained("model_name",
  ... trust_remote_code=True)
  """
  wrapper_path = f"{model_name_or_path.rstrip('/')}/omnigenome_wrapper.py"
@@ -104,7 +104,9 @@ class OmniTokenizer:
  warnings.warn(
  f"No tokenizer wrapper found in {wrapper_path} -> Exception: {e}"
  )
- kwargs.pop("num_labels", None) # Remove num_labels if it exists, as it may not be applicable
+ kwargs.pop(
+ "num_labels", None
+ ) # Remove num_labels if it exists, as it may not be applicable

  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)

omnigenome/src/dataset/omni_dataset.py

@@ -27,11 +27,11 @@ from ... import __name__, __version__
  class OmniDatasetForTokenClassification(OmniDataset):
  """
  Dataset class specifically designed for token classification tasks in genomics.
-
+
  This class extends `OmniDataset` to provide functionalities for preparing input sequences
  and their corresponding token-level labels. It's designed for tasks where each token
  in a sequence needs to be classified independently.
-
+
  Attributes:
  metadata: Dictionary containing dataset metadata including library information
  label2id: Mapping from label strings to integer IDs
@@ -68,7 +68,7 @@ class OmniDatasetForTokenClassification(OmniDataset):
  def prepare_input(self, instance, **kwargs):
  """
  Prepare a single data instance for token classification.
-
+
  This method handles both string sequences and dictionary instances
  containing sequence and label information. It tokenizes the input
  sequence and prepares token-level labels for classification.
@@ -138,11 +138,11 @@ class OmniDatasetForTokenClassification(OmniDataset):
  class OmniDatasetForSequenceClassification(OmniDataset):
  """
  Dataset class for sequence classification tasks in genomics.
-
+
  This class extends `OmniDataset` to prepare input sequences and their corresponding
  sequence-level labels. It's designed for tasks where the entire sequence needs
  to be classified into one of several categories.
-
+
  Attributes:
  metadata: Dictionary containing dataset metadata including library information
  label2id: Mapping from label strings to integer IDs
@@ -179,7 +179,7 @@ class OmniDatasetForSequenceClassification(OmniDataset):
  def prepare_input(self, instance, **kwargs):
  """
  Prepare a single data instance for sequence classification.
-
+
  This method handles both string sequences and dictionary instances
  containing sequence and label information. It tokenizes the input
  sequence and prepares sequence-level labels for classification.
@@ -238,11 +238,11 @@ class OmniDatasetForSequenceClassification(OmniDataset):
  class OmniDatasetForTokenRegression(OmniDataset):
  """
  Dataset class for token regression tasks in genomics.
-
+
  This class extends `OmniDataset` to prepare input sequences and their corresponding
  token-level regression targets. It's designed for tasks where each token in a
  sequence needs to be assigned a continuous value.
-
+
  Attributes:
  metadata: Dictionary containing dataset metadata including library information
  """
@@ -278,7 +278,7 @@ class OmniDatasetForTokenRegression(OmniDataset):
  def prepare_input(self, instance, **kwargs):
  """
  Prepare a single data instance for token regression.
-
+
  This method handles both string sequences and dictionary instances
  containing sequence and regression target information. It tokenizes
  the input sequence and prepares token-level regression targets.
@@ -330,7 +330,9 @@ class OmniDatasetForTokenRegression(OmniDataset):
  # Handle token-level regression labels
  if isinstance(labels, (list, tuple)):
  # Ensure labels match sequence length
- labels = list(labels)[:self.max_length - 2] # Account for special tokens
+ labels = list(labels)[
+ : self.max_length - 2
+ ] # Account for special tokens
  labels = [-100] + labels + [-100] # Add padding for special tokens
  else:
  # Single value for the entire sequence
@@ -343,11 +345,11 @@
  class OmniDatasetForSequenceRegression(OmniDataset):
  """
  Dataset class for sequence regression tasks in genomics.
-
+
  This class extends `OmniDataset` to prepare input sequences and their corresponding
  sequence-level regression targets. It's designed for tasks where the entire
  sequence needs to be assigned a continuous value.
-
+
  Attributes:
  metadata: Dictionary containing dataset metadata including library information
  """
@@ -383,7 +385,7 @@ class OmniDatasetForSequenceRegression(OmniDataset):
  def prepare_input(self, instance, **kwargs):
  """
  Prepare a single data instance for sequence regression.
-
+
  This method handles both string sequences and dictionary instances
  containing sequence and regression target information. It tokenizes
  the input sequence and prepares sequence-level regression targets.
@@ -432,4 +434,4 @@ class OmniDatasetForSequenceRegression(OmniDataset):
  labels = float(labels)

  tokenized_inputs["labels"] = torch.tensor(labels, dtype=torch.float32)
- return tokenized_inputs
+ return tokenized_inputs
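The token-regression hunk above (old line 333) only reformats the label-alignment logic, but the behaviour is worth spelling out. A standalone sketch of that alignment, with illustrative values for max_length and the targets:

```python
# Standalone illustration of the label alignment reformatted above: token-level
# targets are truncated to max_length - 2 (reserving room for special tokens)
# and padded with -100 so those positions are ignored by the loss.
import torch

max_length = 8
labels = [0.1, 0.5, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6]  # illustrative targets

labels = list(labels)[: max_length - 2]  # account for special tokens
labels = [-100] + labels + [-100]        # mask special-token positions
labels = torch.tensor(labels, dtype=torch.float32)
print(labels.shape)  # torch.Size([8])
```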
omnigenome/src/lora/__init__.py

@@ -10,4 +10,3 @@
  """
  This package contains modules for LoRA (Low-Rank Adaptation) fine-tuning.
  """
-
omnigenome/src/lora/lora_model.py

@@ -18,22 +18,23 @@ import torch
  from torch import nn
  from omnigenome.src.misc.utils import fprint

+
  def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
  """
  Find linear modules in a model that can be targeted for LoRA adaptation.
-
+
  This function searches through a model's modules to identify linear layers
  that can be adapted using LoRA. It supports filtering by keyword patterns
  to target specific types of layers.
-
+
  Args:
  model: The model to search for linear modules
  keyword_filter (str, list, tuple, optional): Keywords to filter modules by name
  use_full_path (bool): Whether to return full module paths or just names (default: True)
-
+
  Returns:
  list: Sorted list of linear module names that can be targeted for LoRA
-
+
  Raises:
  TypeError: If keyword_filter is not None, str, or a list/tuple of str
  """
@@ -46,31 +47,32 @@ def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
  elif not isinstance(keyword_filter, (list, tuple)):
  raise TypeError("keyword_filter must be None, str, or a list/tuple of str")

- pattern = '|'.join(map(re.escape, keyword_filter))
+ pattern = "|".join(map(re.escape, keyword_filter))

  linear_modules = set()
  for name, module in model.named_modules():
  if isinstance(module, nn.Linear):
  if keyword_filter is None or re.search(pattern, name, re.IGNORECASE):
- linear_modules.add(name if use_full_path else name.split('.')[-1])
+ linear_modules.add(name if use_full_path else name.split(".")[-1])

  return sorted(linear_modules)

+
  def auto_lora_model(model, **kwargs):
  """
  Automatically create a LoRA-adapted model.
-
+
  This function automatically identifies suitable target modules and creates
  a LoRA-adapted version of the input model. It handles configuration
  setup and parameter freezing for efficient fine-tuning.
-
+
  Args:
  model: The base model to adapt with LoRA
  **kwargs: Additional LoRA configuration parameters
-
+
  Returns:
  The LoRA-adapted model
-
+
  Raises:
  AssertionError: If no target modules are found for LoRA injection
  """
@@ -79,8 +81,8 @@ def auto_lora_model(model, **kwargs):

  # A bad case for the EVO-1 model, which has a custom config class
  ######################
- if hasattr(model, 'config') and not isinstance(model.config, PretrainedConfig):
- delattr(model.config, 'Loader')
+ if hasattr(model, "config") and not isinstance(model.config, PretrainedConfig):
+ delattr(model.config, "Loader")
  model.config = PretrainedConfig.from_dict(dict(model.config))
  #######################

@@ -92,7 +94,9 @@ def auto_lora_model(model, **kwargs):
  lora_dropout = kwargs.pop("lora_dropout", 0.1)

  if target_modules is None:
- target_modules = find_linear_target_modules(model, keyword_filter=kwargs.get("keyword_filter", None))
+ target_modules = find_linear_target_modules(
+ model, keyword_filter=kwargs.get("keyword_filter", None)
+ )
  assert target_modules is not None, "No target modules found for LoRA injection."
  config = LoraConfig(
  target_modules=target_modules,
@@ -115,29 +119,30 @@ def auto_lora_model(model, **kwargs):
  )
  return lora_model

+
  class OmniLoraModel(nn.Module):
  """
  LoRA-adapted model for OmniGenome.
-
+
  This class provides a wrapper around LoRA-adapted models, enabling
  efficient fine-tuning of large genomic language models while maintaining
  compatibility with the OmniGenome framework.
-
+
  Attributes:
  lora_model: The underlying LoRA-adapted model
  config: Model configuration
  device: Device the model is running on
  dtype: Data type of the model parameters
  """
-
+
  def __init__(self, model, **kwargs):
  """
  Initialize the LoRA-adapted model.
-
+
  Args:
  model: The base model to adapt with LoRA
  **kwargs: LoRA configuration parameters
-
+
  Raises:
  ValueError: If no target modules are specified for LoRA injection
  """
@@ -147,7 +152,8 @@ class OmniLoraModel(nn.Module):
  raise ValueError(
  "No target modules found for LoRA injection. To perform LoRA adaptation fine-tuning, "
  "please specify the target modules using the 'target_modules' argument. "
- "The target modules depend on the model architecture, such as 'query', 'value', etc. ")
+ "The target modules depend on the model architecture, such as 'query', 'value', etc. "
+ )

  self.lora_model = auto_lora_model(model, **kwargs)

@@ -159,23 +165,23 @@ class OmniLoraModel(nn.Module):
  )

  self.config = model.config
- self.to('cpu') # Move the model to CPU initially
+ self.to("cpu") # Move the model to CPU initially
  fprint(
  "LoRA model initialized with the following configuration:\n",
- self.lora_model
+ self.lora_model,
  )

  def to(self, *args, **kwargs):
  """
  Move the model to a specific device and data type.
-
+
  This method overrides the default to() method to ensure the LoRA model
  and its components are properly moved to the target device and dtype.
-
+
  Args:
  *args: Device specification (e.g., 'cuda', 'cpu')
  **kwargs: Additional arguments including dtype
-
+
  Returns:
  self: The model instance
  """
@@ -188,20 +194,20 @@ class OmniLoraModel(nn.Module):
  break
  for module in self.lora_model.modules():
  module.device = self.device
- if hasattr(module, 'dtype'):
+ if hasattr(module, "dtype"):
  module.dtype = self.dtype
  except Exception as e:
- pass # Ignore errors if parameters are not available
+ pass  # Ignore errors if parameters are not available
  return self

  def forward(self, *args, **kwargs):
  """
  Forward pass through the LoRA model.
-
+
  Args:
  *args: Positional arguments for the forward pass
  **kwargs: Keyword arguments for the forward pass
-
+
  Returns:
  The output from the LoRA model
  """
@@ -210,11 +216,11 @@ class OmniLoraModel(nn.Module):
  def predict(self, *args, **kwargs):
  """
  Generate predictions using the LoRA model.
-
+
  Args:
  *args: Positional arguments for prediction
  **kwargs: Keyword arguments for prediction
-
+
  Returns:
  Model predictions
  """
@@ -223,11 +229,11 @@ class OmniLoraModel(nn.Module):
  def save(self, *args, **kwargs):
  """
  Save the LoRA model.
-
+
  Args:
  *args: Positional arguments for saving
  **kwargs: Keyword arguments for saving
-
+
  Returns:
  Result of the save operation
  """
@@ -236,7 +242,7 @@ class OmniLoraModel(nn.Module):
  def model_info(self):
  """
  Get information about the LoRA model.
-
+
  Returns:
  Model information from the base model
  """
@@ -245,10 +251,10 @@ class OmniLoraModel(nn.Module):
  def set_loss_fn(self, fn):
  """
  Set the loss function for the LoRA model.
-
+
  Args:
  fn: Loss function to set
-
+
  Returns:
  Result of setting the loss function
  """
@@ -257,10 +263,10 @@ class OmniLoraModel(nn.Module):
  def last_hidden_state_forward(self, **kwargs):
  """
  Forward pass to get the last hidden state.
-
+
  Args:
  **kwargs: Keyword arguments for the forward pass
-
+
  Returns:
  Last hidden state from the base model
  """
@@ -269,7 +275,7 @@ class OmniLoraModel(nn.Module):
  def tokenizer(self):
  """
  Get the tokenizer from the base model.
-
+
  Returns:
  The tokenizer from the base model
  """
@@ -278,7 +284,7 @@ class OmniLoraModel(nn.Module):
  def config(self):
  """
  Get the configuration from the base model.
-
+
  Returns:
  The configuration from the base model
  """
@@ -287,8 +293,8 @@ class OmniLoraModel(nn.Module):
  def model(self):
  """
  Get the base model.
-
+
  Returns:
  The base model
  """
- return self.lora_model.base_model.model
+ return self.lora_model.base_model.model
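The lora_model.py hunks above reformat find_linear_target_modules and auto_lora_model. A usage sketch based only on the signatures and defaults visible in the diff; the base checkpoint, keyword filter, and hyperparameters are placeholders, and the import path is assumed from the file layout:

```python
# Sketch of the LoRA helpers shown above; the checkpoint, keyword filter, and
# hyperparameters are placeholders, and the import path is an assumption.
from transformers import AutoModel
from omnigenome.src.lora.lora_model import auto_lora_model, find_linear_target_modules

base = AutoModel.from_pretrained("path/to/base-model", trust_remote_code=True)

# Restrict injection to attention projections by keyword (optional: when
# target_modules is None, auto_lora_model searches for linear layers itself).
targets = find_linear_target_modules(base, keyword_filter=["query", "value"])

# lora_dropout defaults to 0.1 per the diff above; other LoRA kwargs pass through.
lora_model = auto_lora_model(base, target_modules=targets, lora_dropout=0.1)
```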
omnigenome/src/metric/classification_metric.py

@@ -19,17 +19,17 @@ from ..abc.abstract_metric import OmniMetric
  class ClassificationMetric(OmniMetric):
  """
  Classification metric class for evaluating classification models.
-
+
  This class provides a comprehensive interface for classification metrics
  in the OmniGenome framework. It integrates with scikit-learn's classification
  metrics and provides additional functionality for handling genomic classification
  tasks.
-
+
  The class automatically exposes all scikit-learn classification metrics as
  callable attributes, making them easily accessible for evaluation. It also
  handles special cases like Hugging Face's EvalPrediction objects and
  provides proper handling of ignored labels.
-
+
  Attributes:
  metric_func (callable): A callable metric function from sklearn.metrics.
  ignore_y (any): A value in the ground truth labels to be ignored during
@@ -42,10 +42,10 @@ class ClassificationMetric(OmniMetric):
  Initializes the classification metric.

  Args:
- metric_func (callable, optional): A callable metric function from
+ metric_func (callable, optional): A callable metric function from
  sklearn.metrics. If None, subclasses
  should implement their own compute method.
- ignore_y (any, optional): A value in the ground truth labels to be
+ ignore_y (any, optional): A value in the ground truth labels to be
  ignored during metric computation. Defaults to -100.
  *args: Additional positional arguments.
  **kwargs: Additional keyword arguments.
@@ -53,7 +53,7 @@ class ClassificationMetric(OmniMetric):
  Example:
  >>> # Initialize with a specific metric function
  >>> metric = ClassificationMetric(metrics.accuracy_score)
-
+
  >>> # Initialize with ignore value
  >>> metric = ClassificationMetric(ignore_y=-100)
  """
@@ -64,7 +64,7 @@ class ClassificationMetric(OmniMetric):
  def __getattribute__(self, name):
  """
  Custom attribute getter that provides dynamic access to scikit-learn metrics.
-
+
  This method provides transparent access to all scikit-learn classification
  metrics. When a metric function is accessed, it returns a callable wrapper
  that handles the metric computation with proper preprocessing.
@@ -91,7 +91,7 @@ class ClassificationMetric(OmniMetric):
  def wrapper(y_true=None, y_pred=None, *args, **kwargs):
  """
  Compute the metric, based on the true and predicted values.
-
+
  This wrapper function handles various input formats including
  Hugging Face's EvalPrediction objects and provides proper
  preprocessing for metric computation.
@@ -99,7 +99,7 @@ class ClassificationMetric(OmniMetric):
  Args:
  y_true: The true values (ground truth labels).
  y_pred: The predicted values (model predictions).
- ignore_y: The value to ignore in the predictions and true
+ ignore_y: The value to ignore in the predictions and true
  values in corresponding positions.
  *args: Additional positional arguments for the metric function.
  **kwargs: Additional keyword arguments for the metric function.
@@ -111,7 +111,7 @@ class ClassificationMetric(OmniMetric):
  >>> # Standard usage
  >>> result = accuracy_fn(y_true, y_pred)
  >>> print(result) # {'accuracy_score': 0.85}
-
+
  >>> # With Hugging Face EvalPrediction
  >>> result = accuracy_fn(eval_prediction)
  >>> print(result) # {'accuracy_score': 0.85}
@@ -152,7 +152,7 @@ class ClassificationMetric(OmniMetric):
  def compute(self, y_true, y_pred, *args, **kwargs):
  """
  Compute the metric, based on the true and predicted values.
-
+
  This method computes the classification metric using the provided
  metric function. It handles preprocessing and applies any additional
  keyword arguments.
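The ClassificationMetric docstrings above describe dynamic access to sklearn.metrics functions and the ignore_y handling. A short sketch of that interface; the import path is assumed from the file list, and the printed score is only indicative:

```python
# Sketch of the dynamic metric access described above; the import path is
# assumed from the file list, and the printed score is only indicative.
from omnigenome.src.metric.classification_metric import ClassificationMetric

metric = ClassificationMetric(ignore_y=-100)

y_true = [0, 1, 1, -100, 0]  # positions equal to ignore_y are dropped before scoring
y_pred = [0, 1, 0, 1, 0]

# Any sklearn.metrics classification function name resolves to a wrapper that
# returns a {metric_name: value} dict, per the docstring examples above.
print(metric.accuracy_score(y_true, y_pred))  # e.g. {'accuracy_score': 0.75}
```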