omnigenome-0.3.0a1-py3-none-any.whl → omnigenome-0.3.1a0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- omnigenome/__init__.py +16 -8
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +40 -36
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +65 -58
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
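Every hunk recovered below is a formatting-only change: trailing whitespace stripped from blank docstring lines, long statements re-wrapped, quote styles normalized, and a few clarifying comments added, consistent with an automated formatter pass (e.g., Black). To verify a release like this locally, the two wheels can be diffed with the standard library alone. A minimal sketch, assuming the wheel files (named after this page's title) sit in the working directory:

```python
import difflib
import zipfile

OLD = "omnigenome-0.3.0a1-py3-none-any.whl"
NEW = "omnigenome-0.3.1a0-py3-none-any.whl"

def python_sources(wheel_path):
    # Map each .py member of the wheel to its decoded source lines.
    with zipfile.ZipFile(wheel_path) as wheel:
        return {
            name: wheel.read(name).decode("utf-8", "replace").splitlines()
            for name in wheel.namelist()
            if name.endswith(".py")
        }

old, new = python_sources(OLD), python_sources(NEW)
for name in sorted(set(old) | set(new)):
    diff = difflib.unified_diff(
        old.get(name, []),
        new.get(name, []),
        fromfile=f"{OLD}/{name}",
        tofile=f"{NEW}/{name}",
        lineterm="",
    )
    print("\n".join(diff))
```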
omnigenome/src/abc/abstract_model.py
CHANGED

@@ -47,14 +47,14 @@ def count_parameters(model):
 class OmniModel(torch.nn.Module):
     """
     Abstract base class for all models in OmniGenome.
-
+
     This class provides a unified interface for all genomic models in the OmniGenome
     framework. It handles model initialization, forward passes, loss computation,
     prediction, inference, and model persistence.
-
+
     The class is designed to work with various types of genomic data and tasks,
     including sequence classification, token classification, regression, and more.
-
+
     Attributes:
         model (torch.nn.Module): The underlying PyTorch model.
         config: The model configuration.
@@ -76,16 +76,16 @@ class OmniModel(torch.nn.Module):
         - From a configuration object
 
         Args:
-            config_or_model: A model configuration, a pre-trained model path (str),
+            config_or_model: A model configuration, a pre-trained model path (str),
                 or a `torch.nn.Module` instance.
             tokenizer: The tokenizer associated with the model.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
                 - label2id (dict): Mapping from class labels to IDs.
                 - num_labels (int): The number of labels.
-                - trust_remote_code (bool): Whether to trust remote code when loading
+                - trust_remote_code (bool): Whether to trust remote code when loading
                     from Hugging Face Hub. Defaults to True.
-                - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
+                - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
                     when loading pre-trained weights. Defaults to False.
                 - dropout (float): Dropout rate. Defaults to 0.0.
 
@@ -97,7 +97,7 @@ class OmniModel(torch.nn.Module):
         Example:
             >>> # Initialize from a pre-trained model
            >>> model = OmniModelForSequenceClassification("model_path", tokenizer)
-
+
             >>> # Initialize from a configuration
             >>> config = AutoConfig.from_pretrained("model_path")
             >>> model = OmniModelForSequenceClassification(config, tokenizer)
@@ -202,7 +202,9 @@ class OmniModel(torch.nn.Module):
         )
         self.config.num_labels = len(self.config.id2label)
 
-        assert len(self.config.label2id) == num_labels, f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."
+        assert (
+            len(self.config.label2id) == num_labels
+        ), f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."
 
         # The metadata of the model
         self.metadata = env_meta_info()
@@ -240,7 +242,7 @@ class OmniModel(torch.nn.Module):
         model architectures by mapping input parameters appropriately.
 
         Args:
-            **inputs: The inputs to the model, compatible with the base model's
+            **inputs: The inputs to the model, compatible with the base model's
                 forward method. Typically includes 'input_ids', 'attention_mask',
                 and other model-specific parameters.
 
@@ -386,7 +388,7 @@ class OmniModel(torch.nn.Module):
         predictions for further processing.
 
         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.
 
@@ -398,7 +400,7 @@ class OmniModel(torch.nn.Module):
         Example:
             >>> # Predict on a single sequence
             >>> outputs = model.predict("ATCGATCG")
-
+
             >>> # Predict on multiple sequences
             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
         """
@@ -416,7 +418,7 @@ class OmniModel(torch.nn.Module):
         to class labels or probabilities.
 
         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.
 
@@ -429,7 +431,7 @@ class OmniModel(torch.nn.Module):
             >>> # Inference on a single sequence
             >>> results = model.inference("ATCGATCG")
             >>> print(results['predictions'])  # Class labels
-
+
             >>> # Inference on multiple sequences
             >>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
         """
@@ -686,4 +688,3 @@ class OmniModel(torch.nn.Module):
         info += f"Model Config: {self.config}\n"
         fprint(info)
         return info
-
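The -202/+205 hunk above is representative of the release: a long `assert` is re-wrapped across three lines with no semantic change. A quick way to confirm that such a re-wrap changes nothing, as a sketch using only the standard library (`label2id`/`num_labels` stand in for the real attributes):

```python
import ast

# One-line form (as in 0.3.0a1) and the Black-style re-wrap (as in 0.3.1a0).
before = 'assert len(label2id) == num_labels, "label2id size mismatch"'
after = 'assert (\n    len(label2id) == num_labels\n), "label2id size mismatch"'

# Parentheses and line breaks do not appear in the AST, so identical dumps
# mean identical runtime behavior.
assert ast.dump(ast.parse(before)) == ast.dump(ast.parse(after))
print("re-wrapped assert parses to the same AST")
```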
omnigenome/src/abc/abstract_tokenizer.py
CHANGED

@@ -16,15 +16,15 @@ from ..misc.utils import env_meta_info, load_module_from_path
 class OmniTokenizer:
     """
     A wrapper class for tokenizers to provide a consistent interface within OmniGenome.
-
+
     This class provides a unified interface for tokenizers in the OmniGenome framework.
     It wraps underlying tokenizers (typically from Hugging Face) and provides
     additional functionality for genomic sequence processing.
-
+
     The class handles various tokenization strategies and provides compatibility
     with different model architectures. It also supports custom tokenizer wrappers
     for specialized genomic tasks.
-
+
     Attributes:
         base_tokenizer: The underlying tokenizer instance (e.g., from Hugging Face).
         max_length (int): The default maximum sequence length.
@@ -52,7 +52,7 @@ class OmniTokenizer:
         >>> from transformers import AutoTokenizer
         >>> base_tokenizer = AutoTokenizer.from_pretrained("model_name")
         >>> tokenizer = OmniTokenizer(base_tokenizer, max_length=512)
-
+
         >>> # Initialize with sequence conversion
         >>> tokenizer = OmniTokenizer(base_tokenizer, u2t=True)
         """
@@ -87,9 +87,9 @@ class OmniTokenizer:
         Example:
             >>> # Load from a pre-trained model
             >>> tokenizer = OmniTokenizer.from_pretrained("model_name")
-
+
             >>> # Load with custom parameters
-            >>> tokenizer = OmniTokenizer.from_pretrained("model_name",
+            >>> tokenizer = OmniTokenizer.from_pretrained("model_name",
             ...                                           trust_remote_code=True)
         """
         wrapper_path = f"{model_name_or_path.rstrip('/')}/omnigenome_wrapper.py"
@@ -104,7 +104,9 @@ class OmniTokenizer:
             warnings.warn(
                 f"No tokenizer wrapper found in {wrapper_path} -> Exception: {e}"
             )
-        kwargs.pop("num_labels", None)  # Remove num_labels if it exists, as it may not be applicable
+        kwargs.pop(
+            "num_labels", None
+        )  # Remove num_labels if it exists, as it may not be applicable
 
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
 
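The `kwargs.pop("num_labels", None)` call that gets re-wrapped in the +107 hunk reflects a common pattern: model-only options travel in a shared `**kwargs` dict and must be stripped before the dict is forwarded to a tokenizer loader that does not accept them. A hedged sketch of the pattern (the function and names here are illustrative, not the library's API):

```python
def load_tokenizer_like(model_name_or_path, **kwargs):
    # num_labels configures a model head, not a tokenizer; pop it with a
    # default so the call is safe whether or not the key is present.
    kwargs.pop("num_labels", None)
    print(f"forwarding to the tokenizer loader: {kwargs}")

load_tokenizer_like("model_name", num_labels=3, trust_remote_code=True)
# forwarding to the tokenizer loader: {'trust_remote_code': True}
```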
omnigenome/src/dataset/omni_dataset.py
CHANGED

@@ -27,11 +27,11 @@ from ... import __name__, __version__
 class OmniDatasetForTokenClassification(OmniDataset):
     """
     Dataset class specifically designed for token classification tasks in genomics.
-
+
     This class extends `OmniDataset` to provide functionalities for preparing input sequences
     and their corresponding token-level labels. It's designed for tasks where each token
     in a sequence needs to be classified independently.
-
+
     Attributes:
         metadata: Dictionary containing dataset metadata including library information
         label2id: Mapping from label strings to integer IDs
@@ -68,7 +68,7 @@ class OmniDatasetForTokenClassification(OmniDataset):
     def prepare_input(self, instance, **kwargs):
         """
         Prepare a single data instance for token classification.
-
+
         This method handles both string sequences and dictionary instances
         containing sequence and label information. It tokenizes the input
         sequence and prepares token-level labels for classification.
@@ -138,11 +138,11 @@ class OmniDatasetForTokenClassification(OmniDataset):
 class OmniDatasetForSequenceClassification(OmniDataset):
     """
     Dataset class for sequence classification tasks in genomics.
-
+
     This class extends `OmniDataset` to prepare input sequences and their corresponding
     sequence-level labels. It's designed for tasks where the entire sequence needs
     to be classified into one of several categories.
-
+
     Attributes:
         metadata: Dictionary containing dataset metadata including library information
         label2id: Mapping from label strings to integer IDs
@@ -179,7 +179,7 @@ class OmniDatasetForSequenceClassification(OmniDataset):
     def prepare_input(self, instance, **kwargs):
         """
         Prepare a single data instance for sequence classification.
-
+
         This method handles both string sequences and dictionary instances
         containing sequence and label information. It tokenizes the input
         sequence and prepares sequence-level labels for classification.
@@ -238,11 +238,11 @@ class OmniDatasetForSequenceClassification(OmniDataset):
 class OmniDatasetForTokenRegression(OmniDataset):
     """
     Dataset class for token regression tasks in genomics.
-
+
     This class extends `OmniDataset` to prepare input sequences and their corresponding
     token-level regression targets. It's designed for tasks where each token in a
     sequence needs to be assigned a continuous value.
-
+
     Attributes:
         metadata: Dictionary containing dataset metadata including library information
     """
@@ -278,7 +278,7 @@ class OmniDatasetForTokenRegression(OmniDataset):
     def prepare_input(self, instance, **kwargs):
         """
         Prepare a single data instance for token regression.
-
+
         This method handles both string sequences and dictionary instances
         containing sequence and regression target information. It tokenizes
         the input sequence and prepares token-level regression targets.
@@ -330,7 +330,9 @@ class OmniDatasetForTokenRegression(OmniDataset):
             # Handle token-level regression labels
             if isinstance(labels, (list, tuple)):
                 # Ensure labels match sequence length
-                labels = list(labels)[: self.max_length - 2]  # Account for special tokens
+                labels = list(labels)[
+                    : self.max_length - 2
+                ]  # Account for special tokens
                 labels = [-100] + labels + [-100]  # Add padding for special tokens
             else:
                 # Single value for the entire sequence
@@ -343,11 +345,11 @@ class OmniDatasetForTokenRegression(OmniDataset):
 class OmniDatasetForSequenceRegression(OmniDataset):
     """
     Dataset class for sequence regression tasks in genomics.
-
+
     This class extends `OmniDataset` to prepare input sequences and their corresponding
     sequence-level regression targets. It's designed for tasks where the entire
     sequence needs to be assigned a continuous value.
-
+
     Attributes:
         metadata: Dictionary containing dataset metadata including library information
     """
@@ -383,7 +385,7 @@ class OmniDatasetForSequenceRegression(OmniDataset):
     def prepare_input(self, instance, **kwargs):
         """
         Prepare a single data instance for sequence regression.
-
+
         This method handles both string sequences and dictionary instances
         containing sequence and regression target information. It tokenizes
         the input sequence and prepares sequence-level regression targets.
@@ -432,4 +434,4 @@ class OmniDatasetForSequenceRegression(OmniDataset):
             labels = float(labels)
 
         tokenized_inputs["labels"] = torch.tensor(labels, dtype=torch.float32)
-        return tokenized_inputs
+        return tokenized_inputs
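The label handling re-wrapped in the -330/+332 hunk truncates token-level targets to fit the tokenizer budget and marks the special-token positions with -100, the index PyTorch losses ignore. A short sketch of that logic with illustrative values (a plain `max_length` replaces the dataset attribute):

```python
# Truncate token-level targets so the final sequence fits max_length, then
# pad with -100 at the two special-token positions ([CLS]/[SEP]-style).
max_length = 6
targets = [0.1, 0.5, 0.9, 0.2, 0.7, 0.3]

labels = list(targets)[: max_length - 2]  # reserve 2 slots for special tokens
labels = [-100] + labels + [-100]         # -100 is masked out by the loss
print(labels)  # [-100, 0.1, 0.5, 0.9, 0.2, -100]
```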
omnigenome/src/lora/__init__.py
CHANGED

omnigenome/src/lora/lora_model.py
CHANGED

@@ -18,22 +18,23 @@ import torch
 from torch import nn
 from omnigenome.src.misc.utils import fprint
 
+
 def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
     """
     Find linear modules in a model that can be targeted for LoRA adaptation.
-
+
     This function searches through a model's modules to identify linear layers
     that can be adapted using LoRA. It supports filtering by keyword patterns
     to target specific types of layers.
-
+
     Args:
         model: The model to search for linear modules
         keyword_filter (str, list, tuple, optional): Keywords to filter modules by name
         use_full_path (bool): Whether to return full module paths or just names (default: True)
-
+
     Returns:
         list: Sorted list of linear module names that can be targeted for LoRA
-
+
     Raises:
         TypeError: If keyword_filter is not None, str, or a list/tuple of str
     """
@@ -46,31 +47,32 @@ def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
     elif not isinstance(keyword_filter, (list, tuple)):
         raise TypeError("keyword_filter must be None, str, or a list/tuple of str")
 
-    pattern = '|'.join(map(re.escape, keyword_filter))
+    pattern = "|".join(map(re.escape, keyword_filter))
 
     linear_modules = set()
     for name, module in model.named_modules():
         if isinstance(module, nn.Linear):
             if keyword_filter is None or re.search(pattern, name, re.IGNORECASE):
-                linear_modules.add(name if use_full_path else name.split('.')[-1])
+                linear_modules.add(name if use_full_path else name.split(".")[-1])
 
     return sorted(linear_modules)
 
+
 def auto_lora_model(model, **kwargs):
     """
     Automatically create a LoRA-adapted model.
-
+
     This function automatically identifies suitable target modules and creates
     a LoRA-adapted version of the input model. It handles configuration
     setup and parameter freezing for efficient fine-tuning.
-
+
     Args:
         model: The base model to adapt with LoRA
         **kwargs: Additional LoRA configuration parameters
-
+
     Returns:
         The LoRA-adapted model
-
+
     Raises:
         AssertionError: If no target modules are found for LoRA injection
     """
@@ -79,8 +81,8 @@ def auto_lora_model(model, **kwargs):
 
     # A bad case for the EVO-1 model, which has a custom config class
     ######################
-    if hasattr(model, 'config') and not isinstance(model.config, PretrainedConfig):
-        delattr(model.config, 'Loader')
+    if hasattr(model, "config") and not isinstance(model.config, PretrainedConfig):
+        delattr(model.config, "Loader")
         model.config = PretrainedConfig.from_dict(dict(model.config))
     #######################
 
@@ -92,7 +94,9 @@ def auto_lora_model(model, **kwargs):
     lora_dropout = kwargs.pop("lora_dropout", 0.1)
 
     if target_modules is None:
-        target_modules = find_linear_target_modules(model, keyword_filter=kwargs.get("keyword_filter", None))
+        target_modules = find_linear_target_modules(
+            model, keyword_filter=kwargs.get("keyword_filter", None)
+        )
     assert target_modules is not None, "No target modules found for LoRA injection."
     config = LoraConfig(
         target_modules=target_modules,
@@ -115,29 +119,30 @@ def auto_lora_model(model, **kwargs):
     )
     return lora_model
 
+
 class OmniLoraModel(nn.Module):
     """
     LoRA-adapted model for OmniGenome.
-
+
     This class provides a wrapper around LoRA-adapted models, enabling
     efficient fine-tuning of large genomic language models while maintaining
     compatibility with the OmniGenome framework.
-
+
     Attributes:
         lora_model: The underlying LoRA-adapted model
         config: Model configuration
         device: Device the model is running on
         dtype: Data type of the model parameters
     """
-
+
     def __init__(self, model, **kwargs):
         """
         Initialize the LoRA-adapted model.
-
+
         Args:
             model: The base model to adapt with LoRA
            **kwargs: LoRA configuration parameters
-
+
         Raises:
             ValueError: If no target modules are specified for LoRA injection
         """
@@ -147,7 +152,8 @@ class OmniLoraModel(nn.Module):
             raise ValueError(
                 "No target modules found for LoRA injection. To perform LoRA adaptation fine-tuning, "
                 "please specify the target modules using the 'target_modules' argument. "
-                "The target modules depend on the model architecture, such as 'query', 'value', etc. ")
+                "The target modules depend on the model architecture, such as 'query', 'value', etc. "
+            )
 
         self.lora_model = auto_lora_model(model, **kwargs)
 
@@ -159,23 +165,23 @@ class OmniLoraModel(nn.Module):
         )
 
         self.config = model.config
-        self.to("cpu")
+        self.to("cpu")  # Move the model to CPU initially
         fprint(
             "LoRA model initialized with the following configuration:\n",
-            self.lora_model
+            self.lora_model,
         )
 
     def to(self, *args, **kwargs):
         """
         Move the model to a specific device and data type.
-
+
         This method overrides the default to() method to ensure the LoRA model
         and its components are properly moved to the target device and dtype.
-
+
         Args:
             *args: Device specification (e.g., 'cuda', 'cpu')
             **kwargs: Additional arguments including dtype
-
+
         Returns:
             self: The model instance
         """
@@ -188,20 +194,20 @@ class OmniLoraModel(nn.Module):
                     break
             for module in self.lora_model.modules():
                 module.device = self.device
-                if hasattr(module, 'dtype'):
+                if hasattr(module, "dtype"):
                     module.dtype = self.dtype
         except Exception as e:
-            pass
+            pass  # Ignore errors if parameters are not available
         return self
 
     def forward(self, *args, **kwargs):
         """
         Forward pass through the LoRA model.
-
+
         Args:
             *args: Positional arguments for the forward pass
             **kwargs: Keyword arguments for the forward pass
-
+
         Returns:
             The output from the LoRA model
         """
@@ -210,11 +216,11 @@ class OmniLoraModel(nn.Module):
     def predict(self, *args, **kwargs):
         """
         Generate predictions using the LoRA model.
-
+
         Args:
             *args: Positional arguments for prediction
             **kwargs: Keyword arguments for prediction
-
+
         Returns:
             Model predictions
         """
@@ -223,11 +229,11 @@ class OmniLoraModel(nn.Module):
     def save(self, *args, **kwargs):
         """
         Save the LoRA model.
-
+
         Args:
             *args: Positional arguments for saving
             **kwargs: Keyword arguments for saving
-
+
         Returns:
             Result of the save operation
         """
@@ -236,7 +242,7 @@ class OmniLoraModel(nn.Module):
     def model_info(self):
         """
         Get information about the LoRA model.
-
+
         Returns:
             Model information from the base model
         """
@@ -245,10 +251,10 @@ class OmniLoraModel(nn.Module):
     def set_loss_fn(self, fn):
         """
         Set the loss function for the LoRA model.
-
+
         Args:
             fn: Loss function to set
-
+
         Returns:
             Result of setting the loss function
         """
@@ -257,10 +263,10 @@ class OmniLoraModel(nn.Module):
     def last_hidden_state_forward(self, **kwargs):
         """
         Forward pass to get the last hidden state.
-
+
         Args:
             **kwargs: Keyword arguments for the forward pass
-
+
         Returns:
             Last hidden state from the base model
         """
@@ -269,7 +275,7 @@ class OmniLoraModel(nn.Module):
     def tokenizer(self):
         """
         Get the tokenizer from the base model.
-
+
         Returns:
             The tokenizer from the base model
         """
@@ -278,7 +284,7 @@ class OmniLoraModel(nn.Module):
     def config(self):
         """
         Get the configuration from the base model.
-
+
         Returns:
             The configuration from the base model
         """
@@ -287,8 +293,8 @@ class OmniLoraModel(nn.Module):
     def model(self):
         """
         Get the base model.
-
+
         Returns:
             The base model
         """
-        return self.lora_model.base_model.model
+        return self.lora_model.base_model.model
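As the hunks show, `find_linear_target_modules` boils down to a filtered walk over `named_modules()`. A self-contained sketch of that scan on a toy module (the layer names are invented for illustration):

```python
import re

from torch import nn

# Toy stand-in for a transformer block; the names are invented.
model = nn.ModuleDict(
    {
        "query": nn.Linear(8, 8),
        "value": nn.Linear(8, 8),
        "norm": nn.LayerNorm(8),
    }
)

keyword_filter = ["query", "value"]
pattern = "|".join(map(re.escape, keyword_filter))

# Keep nn.Linear layers whose qualified name matches the keyword pattern.
linear_modules = {
    name
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear) and re.search(pattern, name, re.IGNORECASE)
}
print(sorted(linear_modules))  # ['query', 'value']
```

Scanning by name rather than hard-coding layer paths is what lets `auto_lora_model` work across architectures whose attention projections are named differently.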
omnigenome/src/metric/classification_metric.py
CHANGED

@@ -19,17 +19,17 @@ from ..abc.abstract_metric import OmniMetric
 class ClassificationMetric(OmniMetric):
     """
     Classification metric class for evaluating classification models.
-
+
     This class provides a comprehensive interface for classification metrics
     in the OmniGenome framework. It integrates with scikit-learn's classification
     metrics and provides additional functionality for handling genomic classification
     tasks.
-
+
     The class automatically exposes all scikit-learn classification metrics as
     callable attributes, making them easily accessible for evaluation. It also
     handles special cases like Hugging Face's EvalPrediction objects and
     provides proper handling of ignored labels.
-
+
     Attributes:
         metric_func (callable): A callable metric function from sklearn.metrics.
         ignore_y (any): A value in the ground truth labels to be ignored during
@@ -42,10 +42,10 @@ class ClassificationMetric(OmniMetric):
         Initializes the classification metric.
 
         Args:
-            metric_func (callable, optional): A callable metric function from
+            metric_func (callable, optional): A callable metric function from
                 sklearn.metrics. If None, subclasses
                 should implement their own compute method.
-            ignore_y (any, optional): A value in the ground truth labels to be
+            ignore_y (any, optional): A value in the ground truth labels to be
                 ignored during metric computation. Defaults to -100.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
@@ -53,7 +53,7 @@ class ClassificationMetric(OmniMetric):
         Example:
             >>> # Initialize with a specific metric function
             >>> metric = ClassificationMetric(metrics.accuracy_score)
-
+
             >>> # Initialize with ignore value
             >>> metric = ClassificationMetric(ignore_y=-100)
         """
@@ -64,7 +64,7 @@ class ClassificationMetric(OmniMetric):
     def __getattribute__(self, name):
         """
         Custom attribute getter that provides dynamic access to scikit-learn metrics.
-
+
         This method provides transparent access to all scikit-learn classification
         metrics. When a metric function is accessed, it returns a callable wrapper
         that handles the metric computation with proper preprocessing.
@@ -91,7 +91,7 @@ class ClassificationMetric(OmniMetric):
         def wrapper(y_true=None, y_pred=None, *args, **kwargs):
             """
             Compute the metric, based on the true and predicted values.
-
+
             This wrapper function handles various input formats including
             Hugging Face's EvalPrediction objects and provides proper
             preprocessing for metric computation.
@@ -99,7 +99,7 @@ class ClassificationMetric(OmniMetric):
             Args:
                 y_true: The true values (ground truth labels).
                 y_pred: The predicted values (model predictions).
-                ignore_y: The value to ignore in the predictions and true
+                ignore_y: The value to ignore in the predictions and true
                     values in corresponding positions.
                 *args: Additional positional arguments for the metric function.
                 **kwargs: Additional keyword arguments for the metric function.
@@ -111,7 +111,7 @@ class ClassificationMetric(OmniMetric):
             >>> # Standard usage
             >>> result = accuracy_fn(y_true, y_pred)
             >>> print(result)  # {'accuracy_score': 0.85}
-
+
             >>> # With Hugging Face EvalPrediction
             >>> result = accuracy_fn(eval_prediction)
             >>> print(result)  # {'accuracy_score': 0.85}
@@ -152,7 +152,7 @@ class ClassificationMetric(OmniMetric):
     def compute(self, y_true, y_pred, *args, **kwargs):
         """
         Compute the metric, based on the true and predicted values.
-
+
         This method computes the classification metric using the provided
         metric function. It handles preprocessing and applies any additional
         keyword arguments.