omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
@@ -56,7 +56,7 @@ def covert_input_to_tensor(data):
56
56
  class OmniGenomeDict(dict):
57
57
  """
58
58
  A dictionary subclass that allows moving all tensor values to a specified device.
59
-
59
+
60
60
  This class extends the standard Python dictionary to provide a convenient
61
61
  method for moving all tensor values to a specific device (CPU/GPU).
62
62
  """
@@ -87,14 +87,14 @@ class OmniGenomeDict(dict):
87
87
  class OmniDataset(torch.utils.data.Dataset):
88
88
  """
89
89
  Abstract base class for all datasets in OmniGenome.
90
-
90
+
91
91
  This class provides a unified interface for genomic datasets in the OmniGenome
92
92
  framework. It handles data loading, preprocessing, tokenization, and provides
93
93
  a PyTorch-compatible dataset interface.
94
-
94
+
95
95
  The class supports various data formats and can handle different types of
96
96
  genomic tasks including classification, regression, and token-level tasks.
97
-
97
+
98
98
  Attributes:
99
99
  tokenizer: The tokenizer to use for processing sequences.
100
100
  max_length (int): The maximum sequence length for tokenization.
@@ -118,17 +118,17 @@ class OmniDataset(torch.utils.data.Dataset):
118
118
  **kwargs: Additional keyword arguments.
119
119
  - label2id (dict): A mapping from labels to integer IDs.
120
120
  - shuffle (bool): Whether to shuffle the data. Defaults to True.
121
- - structure_in (bool): Whether to include secondary structure
121
+ - structure_in (bool): Whether to include secondary structure
122
122
  information. Defaults to False.
123
- - drop_long_seq (bool): Whether to drop sequences longer than
123
+ - drop_long_seq (bool): Whether to drop sequences longer than
124
124
  max_length. Defaults to False.
125
125
 
126
126
  Example:
127
127
  >>> # Initialize with a single data file
128
128
  >>> dataset = OmniDataset("data.json", tokenizer, max_length=512)
129
-
129
+
130
130
  >>> # Initialize with label mapping
131
- >>> dataset = OmniDataset("data.json", tokenizer,
131
+ >>> dataset = OmniDataset("data.json", tokenizer,
132
132
  ... label2id={"A": 0, "B": 1})
133
133
  """
134
134
  super(OmniDataset, self).__init__()
@@ -158,9 +158,7 @@ class OmniDataset(torch.utils.data.Dataset):
158
158
  )
159
159
  self.max_length = self.tokenizer.max_length
160
160
  else:
161
- fprint(
162
- f"No max_length detected, using default max_length=512."
163
- )
161
+ fprint(f"No max_length detected, using default max_length=512.")
164
162
  self.max_length = 512
165
163
 
166
164
  self.tokenizer.max_length = self.max_length
@@ -417,23 +415,44 @@ class OmniDataset(torch.utils.data.Dataset):
417
415
  lines = f.readlines()
418
416
  for line in lines:
419
417
  examples.append({"text": line.strip()})
420
- elif data_source.endswith(('.fasta', '.fa', '.fna', '.ffn', '.faa', '.frn')):
418
+ elif data_source.endswith(
419
+ (".fasta", ".fa", ".fna", ".ffn", ".faa", ".frn")
420
+ ):
421
421
  try:
422
422
  from Bio import SeqIO
423
423
  except ImportError:
424
- raise ImportError("Biopython is required for FASTA parsing. Please install with 'pip install biopython'.")
424
+ raise ImportError(
425
+ "Biopython is required for FASTA parsing. Please install with 'pip install biopython'."
426
+ )
425
427
  for record in SeqIO.parse(data_source, "fasta"):
426
- examples.append({"id": record.id, "sequence": str(record.seq), "description": record.description})
427
- elif data_source.endswith(('.fastq', '.fq')):
428
+ examples.append(
429
+ {
430
+ "id": record.id,
431
+ "sequence": str(record.seq),
432
+ "description": record.description,
433
+ }
434
+ )
435
+ elif data_source.endswith((".fastq", ".fq")):
428
436
  try:
429
437
  from Bio import SeqIO
430
438
  except ImportError:
431
- raise ImportError("Biopython is required for FASTQ parsing. Please install with 'pip install biopython'.")
439
+ raise ImportError(
440
+ "Biopython is required for FASTQ parsing. Please install with 'pip install biopython'."
441
+ )
432
442
  for record in SeqIO.parse(data_source, "fastq"):
433
- examples.append({"id": record.id, "sequence": str(record.seq), "quality": record.letter_annotations.get("phred_quality", [])})
434
- elif data_source.endswith('.bed'):
443
+ examples.append(
444
+ {
445
+ "id": record.id,
446
+ "sequence": str(record.seq),
447
+ "quality": record.letter_annotations.get(
448
+ "phred_quality", []
449
+ ),
450
+ }
451
+ )
452
+ elif data_source.endswith(".bed"):
435
453
  import pandas as pd
436
- df = pd.read_csv(data_source, sep='\t', comment='#')
454
+
455
+ df = pd.read_csv(data_source, sep="\t", comment="#")
437
456
  # Assign column names for standard BED fields
438
457
  for _, row in df.iterrows():
439
458
  examples.append(row.to_dict())
@@ -15,17 +15,17 @@ from ..misc.utils import env_meta_info
15
15
  class OmniMetric:
16
16
  """
17
17
  Abstract base class for all metrics in OmniGenome, based on scikit-learn.
18
-
18
+
19
19
  This class provides a unified interface for evaluation metrics in the OmniGenome
20
20
  framework. It integrates with scikit-learn's metric functions and provides
21
21
  additional functionality for handling genomic data evaluation.
22
-
22
+
23
23
  The class automatically exposes all scikit-learn metrics as attributes,
24
24
  making them easily accessible for evaluation tasks.
25
-
25
+
26
26
  Attributes:
27
27
  metric_func (callable): A callable metric function from `sklearn.metrics`.
28
- ignore_y (any): A value in the ground truth labels to be ignored during
28
+ ignore_y (any): A value in the ground truth labels to be ignored during
29
29
  metric computation.
30
30
  metadata (dict): Metadata about the metric including version info.
31
31
  """
@@ -35,10 +35,10 @@ class OmniMetric:
35
35
  Initializes the metric.
36
36
 
37
37
  Args:
38
- metric_func (callable, optional): A callable metric function from
38
+ metric_func (callable, optional): A callable metric function from
39
39
  `sklearn.metrics`. If None, subclasses
40
40
  should implement their own compute method.
41
- ignore_y (any, optional): A value in the ground truth labels to be
41
+ ignore_y (any, optional): A value in the ground truth labels to be
42
42
  ignored during metric computation.
43
43
  *args: Additional positional arguments.
44
44
  **kwargs: Additional keyword arguments.
@@ -46,7 +46,7 @@ class OmniMetric:
46
46
  Example:
47
47
  >>> # Initialize with a specific metric function
48
48
  >>> metric = OmniMetric(metrics.accuracy_score)
49
-
49
+
50
50
  >>> # Initialize with ignore value
51
51
  >>> metric = OmniMetric(ignore_y=-100)
52
52
  """
@@ -47,14 +47,14 @@ def count_parameters(model):
47
47
  class OmniModel(torch.nn.Module):
48
48
  """
49
49
  Abstract base class for all models in OmniGenome.
50
-
50
+
51
51
  This class provides a unified interface for all genomic models in the OmniGenome
52
52
  framework. It handles model initialization, forward passes, loss computation,
53
53
  prediction, inference, and model persistence.
54
-
54
+
55
55
  The class is designed to work with various types of genomic data and tasks,
56
56
  including sequence classification, token classification, regression, and more.
57
-
57
+
58
58
  Attributes:
59
59
  model (torch.nn.Module): The underlying PyTorch model.
60
60
  config: The model configuration.
@@ -76,16 +76,16 @@ class OmniModel(torch.nn.Module):
76
76
  - From a configuration object
77
77
 
78
78
  Args:
79
- config_or_model: A model configuration, a pre-trained model path (str),
79
+ config_or_model: A model configuration, a pre-trained model path (str),
80
80
  or a `torch.nn.Module` instance.
81
81
  tokenizer: The tokenizer associated with the model.
82
82
  *args: Additional positional arguments.
83
83
  **kwargs: Additional keyword arguments.
84
84
  - label2id (dict): Mapping from class labels to IDs.
85
85
  - num_labels (int): The number of labels.
86
- - trust_remote_code (bool): Whether to trust remote code when loading
86
+ - trust_remote_code (bool): Whether to trust remote code when loading
87
87
  from Hugging Face Hub. Defaults to True.
88
- - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
88
+ - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
89
89
  when loading pre-trained weights. Defaults to False.
90
90
  - dropout (float): Dropout rate. Defaults to 0.0.
91
91
 
@@ -97,7 +97,7 @@ class OmniModel(torch.nn.Module):
97
97
  Example:
98
98
  >>> # Initialize from a pre-trained model
99
99
  >>> model = OmniModelForSequenceClassification("model_path", tokenizer)
100
-
100
+
101
101
  >>> # Initialize from a configuration
102
102
  >>> config = AutoConfig.from_pretrained("model_path")
103
103
  >>> model = OmniModelForSequenceClassification(config, tokenizer)
@@ -202,7 +202,9 @@ class OmniModel(torch.nn.Module):
202
202
  )
203
203
  self.config.num_labels = len(self.config.id2label)
204
204
 
205
- assert len(self.config.label2id) == num_labels, f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."
205
+ assert (
206
+ len(self.config.label2id) == num_labels
207
+ ), f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."
206
208
 
207
209
  # The metadata of the model
208
210
  self.metadata = env_meta_info()
@@ -240,7 +242,7 @@ class OmniModel(torch.nn.Module):
240
242
  model architectures by mapping input parameters appropriately.
241
243
 
242
244
  Args:
243
- **inputs: The inputs to the model, compatible with the base model's
245
+ **inputs: The inputs to the model, compatible with the base model's
244
246
  forward method. Typically includes 'input_ids', 'attention_mask',
245
247
  and other model-specific parameters.
246
248
 
@@ -386,7 +388,7 @@ class OmniModel(torch.nn.Module):
386
388
  predictions for further processing.
387
389
 
388
390
  Args:
389
- sequence_or_inputs: A sequence (str), list of sequences, or
391
+ sequence_or_inputs: A sequence (str), list of sequences, or
390
392
  tokenized inputs (dict/tuple).
391
393
  **kwargs: Additional arguments for tokenization and inference.
392
394
 
@@ -398,7 +400,7 @@ class OmniModel(torch.nn.Module):
398
400
  Example:
399
401
  >>> # Predict on a single sequence
400
402
  >>> outputs = model.predict("ATCGATCG")
401
-
403
+
402
404
  >>> # Predict on multiple sequences
403
405
  >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
404
406
  """
@@ -416,7 +418,7 @@ class OmniModel(torch.nn.Module):
416
418
  to class labels or probabilities.
417
419
 
418
420
  Args:
419
- sequence_or_inputs: A sequence (str), list of sequences, or
421
+ sequence_or_inputs: A sequence (str), list of sequences, or
420
422
  tokenized inputs (dict/tuple).
421
423
  **kwargs: Additional arguments for tokenization and inference.
422
424
 
@@ -429,7 +431,7 @@ class OmniModel(torch.nn.Module):
429
431
  >>> # Inference on a single sequence
430
432
  >>> results = model.inference("ATCGATCG")
431
433
  >>> print(results['predictions']) # Class labels
432
-
434
+
433
435
  >>> # Inference on multiple sequences
434
436
  >>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
435
437
  """
@@ -686,4 +688,3 @@ class OmniModel(torch.nn.Module):
686
688
  info += f"Model Config: {self.config}\n"
687
689
  fprint(info)
688
690
  return info
689
-
@@ -16,15 +16,15 @@ from ..misc.utils import env_meta_info, load_module_from_path
16
16
  class OmniTokenizer:
17
17
  """
18
18
  A wrapper class for tokenizers to provide a consistent interface within OmniGenome.
19
-
19
+
20
20
  This class provides a unified interface for tokenizers in the OmniGenome framework.
21
21
  It wraps underlying tokenizers (typically from Hugging Face) and provides
22
22
  additional functionality for genomic sequence processing.
23
-
23
+
24
24
  The class handles various tokenization strategies and provides compatibility
25
25
  with different model architectures. It also supports custom tokenizer wrappers
26
26
  for specialized genomic tasks.
27
-
27
+
28
28
  Attributes:
29
29
  base_tokenizer: The underlying tokenizer instance (e.g., from Hugging Face).
30
30
  max_length (int): The default maximum sequence length.
@@ -52,7 +52,7 @@ class OmniTokenizer:
52
52
  >>> from transformers import AutoTokenizer
53
53
  >>> base_tokenizer = AutoTokenizer.from_pretrained("model_name")
54
54
  >>> tokenizer = OmniTokenizer(base_tokenizer, max_length=512)
55
-
55
+
56
56
  >>> # Initialize with sequence conversion
57
57
  >>> tokenizer = OmniTokenizer(base_tokenizer, u2t=True)
58
58
  """
@@ -87,9 +87,9 @@ class OmniTokenizer:
87
87
  Example:
88
88
  >>> # Load from a pre-trained model
89
89
  >>> tokenizer = OmniTokenizer.from_pretrained("model_name")
90
-
90
+
91
91
  >>> # Load with custom parameters
92
- >>> tokenizer = OmniTokenizer.from_pretrained("model_name",
92
+ >>> tokenizer = OmniTokenizer.from_pretrained("model_name",
93
93
  ... trust_remote_code=True)
94
94
  """
95
95
  wrapper_path = f"{model_name_or_path.rstrip('/')}/omnigenome_wrapper.py"
@@ -104,7 +104,9 @@ class OmniTokenizer:
104
104
  warnings.warn(
105
105
  f"No tokenizer wrapper found in {wrapper_path} -> Exception: {e}"
106
106
  )
107
- kwargs.pop("num_labels", None) # Remove num_labels if it exists, as it may not be applicable
107
+ kwargs.pop(
108
+ "num_labels", None
109
+ ) # Remove num_labels if it exists, as it may not be applicable
108
110
 
109
111
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
110
112
 
@@ -27,11 +27,11 @@ from ... import __name__, __version__
27
27
  class OmniDatasetForTokenClassification(OmniDataset):
28
28
  """
29
29
  Dataset class specifically designed for token classification tasks in genomics.
30
-
30
+
31
31
  This class extends `OmniDataset` to provide functionalities for preparing input sequences
32
32
  and their corresponding token-level labels. It's designed for tasks where each token
33
33
  in a sequence needs to be classified independently.
34
-
34
+
35
35
  Attributes:
36
36
  metadata: Dictionary containing dataset metadata including library information
37
37
  label2id: Mapping from label strings to integer IDs
@@ -68,7 +68,7 @@ class OmniDatasetForTokenClassification(OmniDataset):
68
68
  def prepare_input(self, instance, **kwargs):
69
69
  """
70
70
  Prepare a single data instance for token classification.
71
-
71
+
72
72
  This method handles both string sequences and dictionary instances
73
73
  containing sequence and label information. It tokenizes the input
74
74
  sequence and prepares token-level labels for classification.
@@ -138,11 +138,11 @@ class OmniDatasetForTokenClassification(OmniDataset):
138
138
  class OmniDatasetForSequenceClassification(OmniDataset):
139
139
  """
140
140
  Dataset class for sequence classification tasks in genomics.
141
-
141
+
142
142
  This class extends `OmniDataset` to prepare input sequences and their corresponding
143
143
  sequence-level labels. It's designed for tasks where the entire sequence needs
144
144
  to be classified into one of several categories.
145
-
145
+
146
146
  Attributes:
147
147
  metadata: Dictionary containing dataset metadata including library information
148
148
  label2id: Mapping from label strings to integer IDs
@@ -179,7 +179,7 @@ class OmniDatasetForSequenceClassification(OmniDataset):
179
179
  def prepare_input(self, instance, **kwargs):
180
180
  """
181
181
  Prepare a single data instance for sequence classification.
182
-
182
+
183
183
  This method handles both string sequences and dictionary instances
184
184
  containing sequence and label information. It tokenizes the input
185
185
  sequence and prepares sequence-level labels for classification.
@@ -238,11 +238,11 @@ class OmniDatasetForSequenceClassification(OmniDataset):
238
238
  class OmniDatasetForTokenRegression(OmniDataset):
239
239
  """
240
240
  Dataset class for token regression tasks in genomics.
241
-
241
+
242
242
  This class extends `OmniDataset` to prepare input sequences and their corresponding
243
243
  token-level regression targets. It's designed for tasks where each token in a
244
244
  sequence needs to be assigned a continuous value.
245
-
245
+
246
246
  Attributes:
247
247
  metadata: Dictionary containing dataset metadata including library information
248
248
  """
@@ -278,7 +278,7 @@ class OmniDatasetForTokenRegression(OmniDataset):
278
278
  def prepare_input(self, instance, **kwargs):
279
279
  """
280
280
  Prepare a single data instance for token regression.
281
-
281
+
282
282
  This method handles both string sequences and dictionary instances
283
283
  containing sequence and regression target information. It tokenizes
284
284
  the input sequence and prepares token-level regression targets.
@@ -330,7 +330,9 @@ class OmniDatasetForTokenRegression(OmniDataset):
330
330
  # Handle token-level regression labels
331
331
  if isinstance(labels, (list, tuple)):
332
332
  # Ensure labels match sequence length
333
- labels = list(labels)[:self.max_length - 2] # Account for special tokens
333
+ labels = list(labels)[
334
+ : self.max_length - 2
335
+ ] # Account for special tokens
334
336
  labels = [-100] + labels + [-100] # Add padding for special tokens
335
337
  else:
336
338
  # Single value for the entire sequence
@@ -343,11 +345,11 @@ class OmniDatasetForTokenRegression(OmniDataset):
343
345
  class OmniDatasetForSequenceRegression(OmniDataset):
344
346
  """
345
347
  Dataset class for sequence regression tasks in genomics.
346
-
348
+
347
349
  This class extends `OmniDataset` to prepare input sequences and their corresponding
348
350
  sequence-level regression targets. It's designed for tasks where the entire
349
351
  sequence needs to be assigned a continuous value.
350
-
352
+
351
353
  Attributes:
352
354
  metadata: Dictionary containing dataset metadata including library information
353
355
  """
@@ -383,7 +385,7 @@ class OmniDatasetForSequenceRegression(OmniDataset):
383
385
  def prepare_input(self, instance, **kwargs):
384
386
  """
385
387
  Prepare a single data instance for sequence regression.
386
-
388
+
387
389
  This method handles both string sequences and dictionary instances
388
390
  containing sequence and regression target information. It tokenizes
389
391
  the input sequence and prepares sequence-level regression targets.
@@ -432,4 +434,4 @@ class OmniDatasetForSequenceRegression(OmniDataset):
432
434
  labels = float(labels)
433
435
 
434
436
  tokenized_inputs["labels"] = torch.tensor(labels, dtype=torch.float32)
435
- return tokenized_inputs
437
+ return tokenized_inputs
@@ -10,4 +10,3 @@
10
10
  """
11
11
  This package contains modules for LoRA (Low-Rank Adaptation) fine-tuning.
12
12
  """
13
-