omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
@@ -18,22 +18,23 @@ import torch
18
18
  from torch import nn
19
19
  from omnigenome.src.misc.utils import fprint
20
20
 
21
+
21
22
  def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
22
23
  """
23
24
  Find linear modules in a model that can be targeted for LoRA adaptation.
24
-
25
+
25
26
  This function searches through a model's modules to identify linear layers
26
27
  that can be adapted using LoRA. It supports filtering by keyword patterns
27
28
  to target specific types of layers.
28
-
29
+
29
30
  Args:
30
31
  model: The model to search for linear modules
31
32
  keyword_filter (str, list, tuple, optional): Keywords to filter modules by name
32
33
  use_full_path (bool): Whether to return full module paths or just names (default: True)
33
-
34
+
34
35
  Returns:
35
36
  list: Sorted list of linear module names that can be targeted for LoRA
36
-
37
+
37
38
  Raises:
38
39
  TypeError: If keyword_filter is not None, str, or a list/tuple of str
39
40
  """
@@ -46,31 +47,32 @@ def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
46
47
  elif not isinstance(keyword_filter, (list, tuple)):
47
48
  raise TypeError("keyword_filter must be None, str, or a list/tuple of str")
48
49
 
49
- pattern = '|'.join(map(re.escape, keyword_filter))
50
+ pattern = "|".join(map(re.escape, keyword_filter))
50
51
 
51
52
  linear_modules = set()
52
53
  for name, module in model.named_modules():
53
54
  if isinstance(module, nn.Linear):
54
55
  if keyword_filter is None or re.search(pattern, name, re.IGNORECASE):
55
- linear_modules.add(name if use_full_path else name.split('.')[-1])
56
+ linear_modules.add(name if use_full_path else name.split(".")[-1])
56
57
 
57
58
  return sorted(linear_modules)
58
59
 
60
+
59
61
  def auto_lora_model(model, **kwargs):
60
62
  """
61
63
  Automatically create a LoRA-adapted model.
62
-
64
+
63
65
  This function automatically identifies suitable target modules and creates
64
66
  a LoRA-adapted version of the input model. It handles configuration
65
67
  setup and parameter freezing for efficient fine-tuning.
66
-
68
+
67
69
  Args:
68
70
  model: The base model to adapt with LoRA
69
71
  **kwargs: Additional LoRA configuration parameters
70
-
72
+
71
73
  Returns:
72
74
  The LoRA-adapted model
73
-
75
+
74
76
  Raises:
75
77
  AssertionError: If no target modules are found for LoRA injection
76
78
  """
@@ -79,8 +81,8 @@ def auto_lora_model(model, **kwargs):
79
81
 
80
82
  # A bad case for the EVO-1 model, which has a custom config class
81
83
  ######################
82
- if hasattr(model, 'config') and not isinstance(model.config, PretrainedConfig):
83
- delattr(model.config, 'Loader')
84
+ if hasattr(model, "config") and not isinstance(model.config, PretrainedConfig):
85
+ delattr(model.config, "Loader")
84
86
  model.config = PretrainedConfig.from_dict(dict(model.config))
85
87
  #######################
86
88
 
@@ -92,7 +94,9 @@ def auto_lora_model(model, **kwargs):
92
94
  lora_dropout = kwargs.pop("lora_dropout", 0.1)
93
95
 
94
96
  if target_modules is None:
95
- target_modules = find_linear_target_modules(model, keyword_filter=kwargs.get("keyword_filter", None))
97
+ target_modules = find_linear_target_modules(
98
+ model, keyword_filter=kwargs.get("keyword_filter", None)
99
+ )
96
100
  assert target_modules is not None, "No target modules found for LoRA injection."
97
101
  config = LoraConfig(
98
102
  target_modules=target_modules,
@@ -115,29 +119,30 @@ def auto_lora_model(model, **kwargs):
115
119
  )
116
120
  return lora_model
117
121
 
122
+
118
123
  class OmniLoraModel(nn.Module):
119
124
  """
120
125
  LoRA-adapted model for OmniGenome.
121
-
126
+
122
127
  This class provides a wrapper around LoRA-adapted models, enabling
123
128
  efficient fine-tuning of large genomic language models while maintaining
124
129
  compatibility with the OmniGenome framework.
125
-
130
+
126
131
  Attributes:
127
132
  lora_model: The underlying LoRA-adapted model
128
133
  config: Model configuration
129
134
  device: Device the model is running on
130
135
  dtype: Data type of the model parameters
131
136
  """
132
-
137
+
133
138
  def __init__(self, model, **kwargs):
134
139
  """
135
140
  Initialize the LoRA-adapted model.
136
-
141
+
137
142
  Args:
138
143
  model: The base model to adapt with LoRA
139
144
  **kwargs: LoRA configuration parameters
140
-
145
+
141
146
  Raises:
142
147
  ValueError: If no target modules are specified for LoRA injection
143
148
  """
@@ -147,7 +152,8 @@ class OmniLoraModel(nn.Module):
147
152
  raise ValueError(
148
153
  "No target modules found for LoRA injection. To perform LoRA adaptation fine-tuning, "
149
154
  "please specify the target modules using the 'target_modules' argument. "
150
- "The target modules depend on the model architecture, such as 'query', 'value', etc. ")
155
+ "The target modules depend on the model architecture, such as 'query', 'value', etc. "
156
+ )
151
157
 
152
158
  self.lora_model = auto_lora_model(model, **kwargs)
153
159
 
@@ -159,23 +165,23 @@ class OmniLoraModel(nn.Module):
159
165
  )
160
166
 
161
167
  self.config = model.config
162
- self.to('cpu') # Move the model to CPU initially
168
+ self.to("cpu") # Move the model to CPU initially
163
169
  fprint(
164
170
  "LoRA model initialized with the following configuration:\n",
165
- self.lora_model
171
+ self.lora_model,
166
172
  )
167
173
 
168
174
  def to(self, *args, **kwargs):
169
175
  """
170
176
  Move the model to a specific device and data type.
171
-
177
+
172
178
  This method overrides the default to() method to ensure the LoRA model
173
179
  and its components are properly moved to the target device and dtype.
174
-
180
+
175
181
  Args:
176
182
  *args: Device specification (e.g., 'cuda', 'cpu')
177
183
  **kwargs: Additional arguments including dtype
178
-
184
+
179
185
  Returns:
180
186
  self: The model instance
181
187
  """
@@ -188,20 +194,20 @@ class OmniLoraModel(nn.Module):
188
194
  break
189
195
  for module in self.lora_model.modules():
190
196
  module.device = self.device
191
- if hasattr(module, 'dtype'):
197
+ if hasattr(module, "dtype"):
192
198
  module.dtype = self.dtype
193
199
  except Exception as e:
194
- pass # Ignore errors if parameters are not available
200
+ pass # Ignore errors if parameters are not available
195
201
  return self
196
202
 
197
203
  def forward(self, *args, **kwargs):
198
204
  """
199
205
  Forward pass through the LoRA model.
200
-
206
+
201
207
  Args:
202
208
  *args: Positional arguments for the forward pass
203
209
  **kwargs: Keyword arguments for the forward pass
204
-
210
+
205
211
  Returns:
206
212
  The output from the LoRA model
207
213
  """
@@ -210,11 +216,11 @@ class OmniLoraModel(nn.Module):
210
216
  def predict(self, *args, **kwargs):
211
217
  """
212
218
  Generate predictions using the LoRA model.
213
-
219
+
214
220
  Args:
215
221
  *args: Positional arguments for prediction
216
222
  **kwargs: Keyword arguments for prediction
217
-
223
+
218
224
  Returns:
219
225
  Model predictions
220
226
  """
@@ -223,11 +229,11 @@ class OmniLoraModel(nn.Module):
223
229
  def save(self, *args, **kwargs):
224
230
  """
225
231
  Save the LoRA model.
226
-
232
+
227
233
  Args:
228
234
  *args: Positional arguments for saving
229
235
  **kwargs: Keyword arguments for saving
230
-
236
+
231
237
  Returns:
232
238
  Result of the save operation
233
239
  """
@@ -236,7 +242,7 @@ class OmniLoraModel(nn.Module):
236
242
  def model_info(self):
237
243
  """
238
244
  Get information about the LoRA model.
239
-
245
+
240
246
  Returns:
241
247
  Model information from the base model
242
248
  """
@@ -245,10 +251,10 @@ class OmniLoraModel(nn.Module):
245
251
  def set_loss_fn(self, fn):
246
252
  """
247
253
  Set the loss function for the LoRA model.
248
-
254
+
249
255
  Args:
250
256
  fn: Loss function to set
251
-
257
+
252
258
  Returns:
253
259
  Result of setting the loss function
254
260
  """
@@ -257,10 +263,10 @@ class OmniLoraModel(nn.Module):
257
263
  def last_hidden_state_forward(self, **kwargs):
258
264
  """
259
265
  Forward pass to get the last hidden state.
260
-
266
+
261
267
  Args:
262
268
  **kwargs: Keyword arguments for the forward pass
263
-
269
+
264
270
  Returns:
265
271
  Last hidden state from the base model
266
272
  """
@@ -269,7 +275,7 @@ class OmniLoraModel(nn.Module):
269
275
  def tokenizer(self):
270
276
  """
271
277
  Get the tokenizer from the base model.
272
-
278
+
273
279
  Returns:
274
280
  The tokenizer from the base model
275
281
  """
@@ -278,7 +284,7 @@ class OmniLoraModel(nn.Module):
278
284
  def config(self):
279
285
  """
280
286
  Get the configuration from the base model.
281
-
287
+
282
288
  Returns:
283
289
  The configuration from the base model
284
290
  """
@@ -287,8 +293,8 @@ class OmniLoraModel(nn.Module):
287
293
  def model(self):
288
294
  """
289
295
  Get the base model.
290
-
296
+
291
297
  Returns:
292
298
  The base model
293
299
  """
294
- return self.lora_model.base_model.model
300
+ return self.lora_model.base_model.model
@@ -19,17 +19,17 @@ from ..abc.abstract_metric import OmniMetric
19
19
  class ClassificationMetric(OmniMetric):
20
20
  """
21
21
  Classification metric class for evaluating classification models.
22
-
22
+
23
23
  This class provides a comprehensive interface for classification metrics
24
24
  in the OmniGenome framework. It integrates with scikit-learn's classification
25
25
  metrics and provides additional functionality for handling genomic classification
26
26
  tasks.
27
-
27
+
28
28
  The class automatically exposes all scikit-learn classification metrics as
29
29
  callable attributes, making them easily accessible for evaluation. It also
30
30
  handles special cases like Hugging Face's EvalPrediction objects and
31
31
  provides proper handling of ignored labels.
32
-
32
+
33
33
  Attributes:
34
34
  metric_func (callable): A callable metric function from sklearn.metrics.
35
35
  ignore_y (any): A value in the ground truth labels to be ignored during
@@ -42,10 +42,10 @@ class ClassificationMetric(OmniMetric):
42
42
  Initializes the classification metric.
43
43
 
44
44
  Args:
45
- metric_func (callable, optional): A callable metric function from
45
+ metric_func (callable, optional): A callable metric function from
46
46
  sklearn.metrics. If None, subclasses
47
47
  should implement their own compute method.
48
- ignore_y (any, optional): A value in the ground truth labels to be
48
+ ignore_y (any, optional): A value in the ground truth labels to be
49
49
  ignored during metric computation. Defaults to -100.
50
50
  *args: Additional positional arguments.
51
51
  **kwargs: Additional keyword arguments.
@@ -53,7 +53,7 @@ class ClassificationMetric(OmniMetric):
53
53
  Example:
54
54
  >>> # Initialize with a specific metric function
55
55
  >>> metric = ClassificationMetric(metrics.accuracy_score)
56
-
56
+
57
57
  >>> # Initialize with ignore value
58
58
  >>> metric = ClassificationMetric(ignore_y=-100)
59
59
  """
@@ -64,7 +64,7 @@ class ClassificationMetric(OmniMetric):
64
64
  def __getattribute__(self, name):
65
65
  """
66
66
  Custom attribute getter that provides dynamic access to scikit-learn metrics.
67
-
67
+
68
68
  This method provides transparent access to all scikit-learn classification
69
69
  metrics. When a metric function is accessed, it returns a callable wrapper
70
70
  that handles the metric computation with proper preprocessing.
@@ -91,7 +91,7 @@ class ClassificationMetric(OmniMetric):
91
91
  def wrapper(y_true=None, y_pred=None, *args, **kwargs):
92
92
  """
93
93
  Compute the metric, based on the true and predicted values.
94
-
94
+
95
95
  This wrapper function handles various input formats including
96
96
  Hugging Face's EvalPrediction objects and provides proper
97
97
  preprocessing for metric computation.
@@ -99,7 +99,7 @@ class ClassificationMetric(OmniMetric):
99
99
  Args:
100
100
  y_true: The true values (ground truth labels).
101
101
  y_pred: The predicted values (model predictions).
102
- ignore_y: The value to ignore in the predictions and true
102
+ ignore_y: The value to ignore in the predictions and true
103
103
  values in corresponding positions.
104
104
  *args: Additional positional arguments for the metric function.
105
105
  **kwargs: Additional keyword arguments for the metric function.
@@ -111,7 +111,7 @@ class ClassificationMetric(OmniMetric):
111
111
  >>> # Standard usage
112
112
  >>> result = accuracy_fn(y_true, y_pred)
113
113
  >>> print(result) # {'accuracy_score': 0.85}
114
-
114
+
115
115
  >>> # With Hugging Face EvalPrediction
116
116
  >>> result = accuracy_fn(eval_prediction)
117
117
  >>> print(result) # {'accuracy_score': 0.85}
@@ -152,7 +152,7 @@ class ClassificationMetric(OmniMetric):
152
152
  def compute(self, y_true, y_pred, *args, **kwargs):
153
153
  """
154
154
  Compute the metric, based on the true and predicted values.
155
-
155
+
156
156
  This method computes the classification metric using the provided
157
157
  metric function. It handles preprocessing and applies any additional
158
158
  keyword arguments.
@@ -20,20 +20,20 @@ from ..abc.abstract_metric import OmniMetric
20
20
  def mcrmse(y_true, y_pred):
21
21
  """
22
22
  Compute Mean Column Root Mean Square Error (MCRMSE).
23
-
23
+
24
24
  MCRMSE is a multi-target regression metric that computes the RMSE for each target
25
25
  column and then takes the mean across all targets.
26
-
26
+
27
27
  Args:
28
28
  y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
29
29
  y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
30
-
30
+
31
31
  Returns:
32
32
  float: Mean Column Root Mean Square Error
33
-
33
+
34
34
  Raises:
35
35
  ValueError: If y_true and y_pred have different shapes
36
-
36
+
37
37
  Example:
38
38
  >>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
39
39
  >>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
@@ -57,18 +57,18 @@ class Metric(OmniMetric):
57
57
  """
58
58
  A flexible metric class that provides access to all scikit-learn metrics
59
59
  and custom metrics for evaluation.
60
-
60
+
61
61
  This class dynamically wraps scikit-learn metrics and provides a unified
62
62
  interface for computing various evaluation metrics. It handles different
63
63
  input formats including HuggingFace trainer outputs and supports
64
64
  custom metric functions.
65
-
65
+
66
66
  Attributes:
67
67
  metric_func: Custom metric function if provided
68
68
  ignore_y: Value to ignore in predictions and true values
69
69
  kwargs: Additional keyword arguments for metric computation
70
70
  metrics: Dictionary of available metrics including custom ones
71
-
71
+
72
72
  Example:
73
73
  >>> from omnigenome.src.metric import Metric
74
74
  >>> metric = Metric(ignore_y=-100)
@@ -82,7 +82,7 @@ class Metric(OmniMetric):
82
82
  def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
83
83
  """
84
84
  Initialize the Metric class.
85
-
85
+
86
86
  Args:
87
87
  metric_func (callable, optional): Custom metric function to use
88
88
  ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
@@ -98,14 +98,14 @@ class Metric(OmniMetric):
98
98
  def __getattribute__(self, name):
99
99
  """
100
100
  Dynamically create metric computation methods.
101
-
101
+
102
102
  This method intercepts attribute access and creates wrapper functions
103
103
  for scikit-learn metrics, handling different input formats and
104
104
  preprocessing the data appropriately.
105
-
105
+
106
106
  Args:
107
107
  name (str): Name of the metric to access
108
-
108
+
109
109
  Returns:
110
110
  callable: Wrapper function for the requested metric
111
111
  """
@@ -119,20 +119,20 @@ class Metric(OmniMetric):
119
119
  def wrapper(y_true=None, y_score=None, *args, **kwargs):
120
120
  """
121
121
  Compute the metric, based on the true and predicted values.
122
-
122
+
123
123
  This wrapper handles different input formats including HuggingFace
124
124
  trainer outputs and performs necessary preprocessing.
125
-
125
+
126
126
  Args:
127
127
  y_true: The true values or HuggingFace EvalPrediction object
128
128
  y_score: The predicted values
129
129
  ignore_y: The value to ignore in the predictions and true values in corresponding positions
130
130
  *args: Additional positional arguments for the metric
131
131
  **kwargs: Additional keyword arguments for the metric
132
-
132
+
133
133
  Returns:
134
134
  dict: Dictionary containing the metric name and computed value
135
-
135
+
136
136
  Raises:
137
137
  ValueError: If neither y_true nor y_score is provided
138
138
  """
@@ -176,16 +176,16 @@ class Metric(OmniMetric):
176
176
  def compute(self, y_true, y_score, *args, **kwargs):
177
177
  """
178
178
  Compute the metric, based on the true and predicted values.
179
-
179
+
180
180
  Args:
181
181
  y_true: The true values
182
182
  y_score: The predicted values
183
183
  *args: Additional positional arguments for the metric
184
184
  **kwargs: Additional keyword arguments for the metric
185
-
185
+
186
186
  Returns:
187
187
  The computed metric value
188
-
188
+
189
189
  Raises:
190
190
  NotImplementedError: If no metric function is provided and compute is not implemented
191
191
  """
@@ -20,16 +20,16 @@ from ..abc.abstract_metric import OmniMetric
20
20
  class RankingMetric(OmniMetric):
21
21
  """
22
22
  A specialized metric class for ranking tasks and evaluation.
23
-
23
+
24
24
  This class provides access to ranking-specific metrics from scikit-learn
25
25
  and handles different input formats including HuggingFace trainer outputs.
26
26
  It dynamically wraps scikit-learn metrics and provides a unified interface
27
27
  for computing various ranking evaluation metrics.
28
-
28
+
29
29
  Attributes:
30
30
  metric_func: Custom metric function if provided
31
31
  ignore_y: Value to ignore in predictions and true values
32
-
32
+
33
33
  Example:
34
34
  >>> from omnigenome.src.metric import RankingMetric
35
35
  >>> metric = RankingMetric(ignore_y=-100)
@@ -43,7 +43,7 @@ class RankingMetric(OmniMetric):
43
43
  def __init__(self, *args, **kwargs):
44
44
  """
45
45
  Initialize the RankingMetric class.
46
-
46
+
47
47
  Args:
48
48
  *args: Additional positional arguments passed to parent class
49
49
  **kwargs: Additional keyword arguments passed to parent class
@@ -53,17 +53,17 @@ class RankingMetric(OmniMetric):
53
53
  def __getattr__(self, name):
54
54
  """
55
55
  Dynamically create ranking metric computation methods.
56
-
56
+
57
57
  This method intercepts attribute access and creates wrapper functions
58
58
  for scikit-learn ranking metrics, handling different input formats and
59
59
  preprocessing the data appropriately.
60
-
60
+
61
61
  Args:
62
62
  name (str): Name of the ranking metric to access
63
-
63
+
64
64
  Returns:
65
65
  callable: Wrapper function for the requested ranking metric
66
-
66
+
67
67
  Raises:
68
68
  AttributeError: If the requested metric is not found
69
69
  """
@@ -74,17 +74,17 @@ class RankingMetric(OmniMetric):
74
74
  def wrapper(y_true=None, y_score=None, *args, **kwargs):
75
75
  """
76
76
  Compute the ranking metric, based on the true and predicted values.
77
-
77
+
78
78
  This wrapper handles different input formats including HuggingFace
79
79
  trainer outputs and performs necessary preprocessing for ranking tasks.
80
-
80
+
81
81
  Args:
82
82
  y_true: The true values or HuggingFace EvalPrediction object
83
83
  y_score: The predicted values (scores for ranking)
84
84
  ignore_y: The value to ignore in the predictions and true values in corresponding positions
85
85
  *args: Additional positional arguments for the metric
86
86
  **kwargs: Additional keyword arguments for the metric
87
-
87
+
88
88
  Returns:
89
89
  dict: Dictionary containing the metric name and computed value
90
90
  """
@@ -121,19 +121,19 @@ class RankingMetric(OmniMetric):
121
121
  def compute(self, y_true, y_score, *args, **kwargs):
122
122
  """
123
123
  Compute the ranking metric, based on the true and predicted values.
124
-
124
+
125
125
  This method should be implemented by subclasses to provide specific
126
126
  ranking metric computation logic.
127
-
127
+
128
128
  Args:
129
129
  y_true: The true values
130
130
  y_score: The predicted values (scores for ranking)
131
131
  *args: Additional positional arguments for the metric
132
132
  **kwargs: Additional keyword arguments for the metric
133
-
133
+
134
134
  Returns:
135
135
  The computed ranking metric value
136
-
136
+
137
137
  Raises:
138
138
  NotImplementedError: If compute method is not implemented in the child class
139
139
  """