omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
@@ -20,20 +20,20 @@ from ..abc.abstract_metric import OmniMetric
20
20
  def mcrmse(y_true, y_pred):
21
21
  """
22
22
  Compute Mean Column Root Mean Square Error (MCRMSE).
23
-
23
+
24
24
  MCRMSE is a multi-target regression metric that computes the RMSE for each target
25
25
  column and then takes the mean across all targets.
26
-
26
+
27
27
  Args:
28
28
  y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
29
29
  y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
30
-
30
+
31
31
  Returns:
32
32
  float: Mean Column Root Mean Square Error
33
-
33
+
34
34
  Raises:
35
35
  ValueError: If y_true and y_pred have different shapes
36
-
36
+
37
37
  Example:
38
38
  >>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
39
39
  >>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
@@ -56,18 +56,18 @@ setattr(metrics, "mcrmse", mcrmse)
56
56
  class RegressionMetric(OmniMetric):
57
57
  """
58
58
  A specialized metric class for regression tasks and evaluation.
59
-
59
+
60
60
  This class provides access to regression-specific metrics from scikit-learn
61
61
  and handles different input formats including HuggingFace trainer outputs.
62
62
  It dynamically wraps scikit-learn metrics and provides a unified interface
63
63
  for computing various regression evaluation metrics.
64
-
64
+
65
65
  Attributes:
66
66
  metric_func: Custom metric function if provided
67
67
  ignore_y: Value to ignore in predictions and true values
68
68
  kwargs: Additional keyword arguments for metric computation
69
69
  metrics: Dictionary of available metrics including custom ones
70
-
70
+
71
71
  Example:
72
72
  >>> from omnigenome.src.metric import RegressionMetric
73
73
  >>> metric = RegressionMetric(ignore_y=-100)
@@ -81,7 +81,7 @@ class RegressionMetric(OmniMetric):
81
81
  def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
82
82
  """
83
83
  Initialize the RegressionMetric class.
84
-
84
+
85
85
  Args:
86
86
  metric_func (callable, optional): Custom metric function to use
87
87
  ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
@@ -97,14 +97,14 @@ class RegressionMetric(OmniMetric):
97
97
  def __getattribute__(self, name):
98
98
  """
99
99
  Dynamically create regression metric computation methods.
100
-
100
+
101
101
  This method intercepts attribute access and creates wrapper functions
102
102
  for scikit-learn regression metrics, handling different input formats and
103
103
  preprocessing the data appropriately.
104
-
104
+
105
105
  Args:
106
106
  name (str): Name of the regression metric to access
107
-
107
+
108
108
  Returns:
109
109
  callable: Wrapper function for the requested regression metric
110
110
  """
@@ -118,17 +118,17 @@ class RegressionMetric(OmniMetric):
118
118
  def wrapper(y_true=None, y_score=None, *args, **kwargs):
119
119
  """
120
120
  Compute the regression metric, based on the true and predicted values.
121
-
121
+
122
122
  This wrapper handles different input formats including HuggingFace
123
123
  trainer outputs and performs necessary preprocessing for regression tasks.
124
-
124
+
125
125
  Args:
126
126
  y_true: The true values or HuggingFace EvalPrediction object
127
127
  y_score: The predicted values
128
128
  ignore_y: The value to ignore in the predictions and true values in corresponding positions
129
129
  *args: Additional positional arguments for the metric
130
130
  **kwargs: Additional keyword arguments for the metric
131
-
131
+
132
132
  Returns:
133
133
  dict: Dictionary containing the metric name and computed value
134
134
  """
@@ -168,16 +168,16 @@ class RegressionMetric(OmniMetric):
168
168
  def compute(self, y_true, y_score, *args, **kwargs):
169
169
  """
170
170
  Compute the regression metric, based on the true and predicted values.
171
-
171
+
172
172
  Args:
173
173
  y_true: The true values
174
174
  y_score: The predicted values
175
175
  *args: Additional positional arguments for the metric
176
176
  **kwargs: Additional keyword arguments for the metric
177
-
177
+
178
178
  Returns:
179
179
  The computed regression metric value
180
-
180
+
181
181
  Raises:
182
182
  NotImplementedError: If no metric function is provided and compute is not implemented
183
183
  """
@@ -12,6 +12,7 @@ import pickle
12
12
  import sys
13
13
  import tempfile
14
14
  import time
15
+ import warnings
15
16
 
16
17
  import ViennaRNA as RNA
17
18
  import findfile
@@ -24,13 +25,13 @@ default_omnigenome_repo = (
24
25
  def seed_everything(seed=42):
25
26
  """
26
27
  Sets random seeds for reproducibility across all random number generators.
27
-
28
+
28
29
  This function sets seeds for Python's random module, NumPy, PyTorch (CPU and CUDA),
29
30
  and sets the PYTHONHASHSEED environment variable to ensure reproducible results
30
31
  across different runs.
31
-
32
+
32
33
  Args:
33
- seed (int): The seed value to use for all random number generators.
34
+ seed (int): The seed value to use for all random number generators.
34
35
  Defaults to 42.
35
36
 
36
37
  Example:
@@ -48,58 +49,50 @@ def seed_everything(seed=42):
48
49
  torch.manual_seed(seed)
49
50
  torch.cuda.manual_seed(seed)
50
51
  torch.backends.cudnn.deterministic = True
52
+ torch.backends.cudnn.benchmark = False
51
53
 
52
54
 
53
55
  class RNA2StructureCache(dict):
54
56
  """
55
- A cache for RNA sequence to structure predictions using ViennaRNA.
56
-
57
- This class provides a dictionary-like interface for caching RNA secondary
58
- structure predictions. It uses ViennaRNA for structure prediction and
59
- supports both single sequences and batches of sequences.
60
-
61
- The cache can be persisted to disk and loaded back, making it useful for
62
- avoiding redundant structure predictions across multiple runs.
63
-
57
+ A cache for RNA secondary structure predictions using ViennaRNA.
58
+
59
+ This class provides a caching mechanism for RNA secondary structure predictions
60
+ to avoid redundant computations. It supports both single sequence and batch
61
+ processing with optional multiprocessing for improved performance.
62
+
64
63
  Attributes:
65
- cache_file (str): Path to the cache file on disk.
66
- cache (dict): The in-memory cache dictionary.
67
- queue_num (int): Counter for tracking cache updates.
64
+ cache (dict): Dictionary storing sequence-structure mappings
65
+ cache_file (str): Path to the cache file on disk
66
+ queue_num (int): Counter for tracking cache updates
68
67
  """
69
68
 
70
69
  def __init__(self, cache_file=None, *args, **kwargs):
71
70
  """
72
- Initializes the RNA structure cache.
71
+ Initialize the RNA structure cache.
73
72
 
74
73
  Args:
75
74
  cache_file (str, optional): Path to the cache file. If None, uses
76
- a default path in `__OMNIGENOME_DATA__`.
77
- *args: Additional arguments passed to dict constructor.
78
- **kwargs: Additional keyword arguments passed to dict constructor.
79
-
80
- Example:
81
- >>> # Initialize with default cache file
82
- >>> cache = RNA2StructureCache()
83
-
84
- >>> # Initialize with custom cache file
85
- >>> cache = RNA2StructureCache("my_cache.pkl")
75
+ a default temporary file.
76
+ *args: Additional positional arguments for dict initialization
77
+ **kwargs: Additional keyword arguments for dict initialization
86
78
  """
87
79
  super().__init__(*args, **kwargs)
88
-
89
- if not cache_file:
90
- self.cache_file = "__OMNIGENOME_DATA__/rna2structure.cache.pkl"
91
- else:
92
- self.cache_file = cache_file
93
-
94
- if self.cache_file is None or not os.path.exists(self.cache_file):
95
- self.cache = {}
96
- else:
97
- fprint(f"Initialize sequence to structure cache from {self.cache_file}...")
98
- with open(self.cache_file, "rb") as f:
99
- self.cache = pickle.load(f)
100
-
80
+ self.cache = dict(*args, **kwargs)
81
+ self.cache_file = (
82
+ cache_file
83
+ if cache_file is not None
84
+ else os.path.join(tempfile.gettempdir(), "rna_structure_cache.pkl")
85
+ )
101
86
  self.queue_num = 0
102
87
 
88
+ # Load existing cache if available
89
+ if os.path.exists(self.cache_file):
90
+ try:
91
+ with open(self.cache_file, "rb") as f:
92
+ self.cache.update(pickle.load(f))
93
+ except Exception as e:
94
+ warnings.warn(f"Failed to load cache file: {e}")
95
+
103
96
  def __getitem__(self, key):
104
97
  """Gets a cached structure prediction."""
105
98
  return self.cache[key]
@@ -116,15 +109,31 @@ class RNA2StructureCache(dict):
116
109
  """String representation of the cache."""
117
110
  return str(self.cache)
118
111
 
112
+ def _fold_single_sequence(self, sequence):
113
+ """
114
+ Predict structure for a single sequence (worker function for multiprocessing).
115
+
116
+ Args:
117
+ sequence (str): RNA sequence to fold
118
+
119
+ Returns:
120
+ tuple: (structure, mfe) tuple
121
+ """
122
+ try:
123
+ return RNA.fold(sequence)
124
+ except Exception as e:
125
+ warnings.warn(f"Failed to fold sequence {sequence}: {e}")
126
+ return ("." * len(sequence), 0.0)
127
+
119
128
  def fold(self, sequence, return_mfe=False, num_workers=1):
120
129
  """
121
130
  Predicts RNA secondary structure for given sequences.
122
-
131
+
123
132
  This method predicts RNA secondary structures using ViennaRNA. It supports
124
133
  both single sequences and batches of sequences. The method uses caching
125
134
  to avoid redundant predictions and supports multiprocessing for batch
126
135
  processing on non-Windows systems.
127
-
136
+
128
137
  Args:
129
138
  sequence (str or list): A single RNA sequence or a list of sequences.
130
139
  return_mfe (bool): Whether to return minimum free energy along with
@@ -141,7 +150,7 @@ class RNA2StructureCache(dict):
141
150
  >>> # Predict structure for a single sequence
142
151
  >>> structure = cache.fold("GGGAAAUCC")
143
152
  >>> print(structure) # "(((...)))"
144
-
153
+
145
154
  >>> # Predict structures for multiple sequences
146
155
  >>> structures = cache.fold(["GGGAAAUCC", "AUUGCUAA"])
147
156
  >>> print(structures) # ["(((...)))", "........"]
@@ -151,39 +160,62 @@ class RNA2StructureCache(dict):
151
160
  else:
152
161
  sequences = sequence
153
162
 
154
- if (
155
- os.name != "nt" and len(sequences) > 1
156
- ): # multiprocessing is not working on Windows in my case
157
- num_workers = min(os.cpu_count(), len(sequences))
163
+ # Determine if we should use multiprocessing
164
+ use_multiprocessing = (
165
+ os.name != "nt" # Not Windows
166
+ and len(sequences) > 1 # Multiple sequences
167
+ and num_workers > 1 # Multiple workers requested
168
+ )
158
169
 
159
- structures = []
170
+ # Find sequences that need prediction
171
+ sequences_to_predict = [seq for seq in sequences if seq not in self.cache]
160
172
 
161
- if not all([seq in self.cache for seq in sequences]):
162
- if num_workers == 1:
163
- for seq in sequences:
164
- if seq not in self.cache:
165
- self.queue_num += 1
166
- self.cache[seq] = RNA.fold(seq)
167
- else:
173
+ if sequences_to_predict:
174
+ if use_multiprocessing:
175
+ # Use multiprocessing for batch prediction
168
176
  if num_workers is None:
169
- num_workers = min(os.cpu_count(), len(sequences))
170
-
171
- with multiprocessing.Pool(num_workers) as pool:
172
- for seq in sequences:
173
- if seq not in self.cache:
177
+ num_workers = min(os.cpu_count(), len(sequences_to_predict))
178
+
179
+ try:
180
+ # Set multiprocessing start method to 'spawn' for better compatibility
181
+ if multiprocessing.get_start_method(allow_none=True) != "spawn":
182
+ multiprocessing.set_start_method("spawn", force=True)
183
+
184
+ with multiprocessing.Pool(num_workers) as pool:
185
+ # Use map instead of apply_async for better error handling
186
+ results = pool.map(
187
+ self._fold_single_sequence, sequences_to_predict
188
+ )
189
+
190
+ # Update cache with results
191
+ for seq, result in zip(sequences_to_predict, results):
192
+ self.cache[seq] = result
174
193
  self.queue_num += 1
175
- async_result = pool.apply_async(RNA.fold, args=(seq,))
176
- structures.append((seq, async_result))
177
194
 
178
- for seq, result in structures:
179
- self.cache[seq] = result.get() # result is a tuple
195
+ except Exception as e:
196
+ warnings.warn(
197
+ f"Multiprocessing failed, falling back to sequential: {e}"
198
+ )
199
+ # Fallback to sequential processing
200
+ for seq in sequences_to_predict:
201
+ self.cache[seq] = self._fold_single_sequence(seq)
202
+ self.queue_num += 1
203
+ else:
204
+ # Sequential processing
205
+ for seq in sequences_to_predict:
206
+ self.cache[seq] = self._fold_single_sequence(seq)
207
+ self.queue_num += 1
180
208
 
209
+ # Prepare output
181
210
  if return_mfe:
182
211
  structures = [self.cache[seq] for seq in sequences]
183
212
  else:
184
213
  structures = [self.cache[seq][0] for seq in sequences]
214
+
215
+ # Update cache file periodically
185
216
  self.update_cache_file(self.cache_file)
186
217
 
218
+ # Return single result or list
187
219
  if len(structures) == 1:
188
220
  return structures[0]
189
221
  else:
@@ -192,10 +224,10 @@ class RNA2StructureCache(dict):
192
224
  def update_cache_file(self, cache_file=None):
193
225
  """
194
226
  Updates the cache file on disk.
195
-
227
+
196
228
  This method saves the in-memory cache to disk. It only saves when
197
229
  the queue_num reaches 100 to avoid excessive disk I/O.
198
-
230
+
199
231
  Args:
200
232
  cache_file (str, optional): Path to the cache file. If None, uses
201
233
  the instance's cache_file.
@@ -209,24 +241,26 @@ class RNA2StructureCache(dict):
209
241
  if cache_file is None:
210
242
  cache_file = self.cache_file
211
243
 
212
- if not os.path.exists(os.path.dirname(cache_file)):
213
- os.makedirs(os.path.dirname(cache_file))
244
+ try:
245
+ if not os.path.exists(os.path.dirname(cache_file)):
246
+ os.makedirs(os.path.dirname(cache_file))
214
247
 
215
- # print(f"Updating cache file {cache_file}...")
216
- with open(cache_file, "wb") as f:
217
- pickle.dump(self.cache, f)
248
+ with open(cache_file, "wb") as f:
249
+ pickle.dump(self.cache, f)
218
250
 
219
- self.queue_num = 0
251
+ self.queue_num = 0
252
+ except Exception as e:
253
+ warnings.warn(f"Failed to update cache file: {e}")
220
254
 
221
255
 
222
256
  def env_meta_info():
223
257
  """
224
258
  Collects metadata about the current environment and library versions.
225
-
259
+
226
260
  This function gathers information about the current Python environment,
227
261
  including versions of key libraries like PyTorch and Transformers,
228
262
  as well as OmniGenome version information.
229
-
263
+
230
264
  Returns:
231
265
  dict: A dictionary containing environment metadata including:
232
266
  - library_name: Name of the OmniGenome library
@@ -256,7 +290,7 @@ def env_meta_info():
256
290
  def naive_secondary_structure_repair(sequence, structure):
257
291
  """
258
292
  Repair the secondary structure of a sequence.
259
-
293
+
260
294
  This function attempts to repair malformed RNA secondary structure
261
295
  representations by ensuring proper bracket matching. It handles
262
296
  common issues like unmatched brackets by converting them to dots.
@@ -294,7 +328,7 @@ def naive_secondary_structure_repair(sequence, structure):
294
328
  def save_args(config, save_path):
295
329
  """
296
330
  Save arguments to a file.
297
-
331
+
298
332
  This function saves the arguments from a configuration object to a text file.
299
333
  It's useful for logging experiment parameters and configurations.
300
334
 
@@ -317,7 +351,7 @@ def save_args(config, save_path):
317
351
  def print_args(config, logger=None):
318
352
  """
319
353
  Print the arguments to the console.
320
-
354
+
321
355
  This function prints the arguments from a configuration object to the console
322
356
  or a logger. It's useful for debugging and logging experiment parameters.
323
357
 
@@ -330,110 +364,140 @@ def print_args(config, logger=None):
330
364
  >>> config = Namespace(learning_rate=0.001, batch_size=32)
331
365
  >>> print_args(config)
332
366
  """
333
- args = [key for key in sorted(config.args.keys())]
334
- if logger:
335
- logger.info(args)
367
+ if logger is None:
368
+ for arg in config.args:
369
+ if config.args_call_count[arg]:
370
+ print("{}: {}".format(arg, config.args[arg]))
336
371
  else:
337
- fprint(args)
372
+ for arg in config.args:
373
+ if config.args_call_count[arg]:
374
+ logger.info("{}: {}".format(arg, config.args[arg]))
338
375
 
339
376
 
340
377
  def fprint(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
341
378
  """
342
- Custom print function that adds a timestamp and the pyabsa version before the printed message.
379
+ Enhanced print function with automatic flushing.
380
+
381
+ This function provides a print-like interface with automatic flushing
382
+ to ensure output is displayed immediately. It's useful for real-time
383
+ logging and progress tracking.
343
384
 
344
385
  Args:
345
- *objects: Any number of objects to be printed
346
- sep (str, optional): Separator between objects. Defaults to " ".
347
- end (str, optional): Ending character after all objects are printed. Defaults to "\n".
348
- file (io.TextIOWrapper, optional): Text file to write printed output to. Defaults to sys.stdout.
349
- flush (bool, optional): Whether to flush output buffer after printing. Defaults to False.
386
+ *objects: Objects to print
387
+ sep (str): Separator between objects (default: " ")
388
+ end (str): String appended after the last value (default: "\n")
389
+ file: File-like object to write to (default: sys.stdout)
390
+ flush (bool): Whether to flush the stream (default: False)
391
+
392
+ Example:
393
+ >>> fprint("Training started...", flush=True)
394
+ >>> fprint("Epoch 1/10", "Loss: 0.5", sep=" | ")
350
395
  """
351
- from omnigenome import __version__
352
- from omnigenome import __name__
353
-
354
- print(
355
- time.strftime(
356
- "[%Y-%m-%d %H:%M:%S] [{} {}] ".format(__name__, __version__),
357
- time.localtime(time.time()),
358
- ),
359
- *objects,
360
- sep=sep,
361
- end=end,
362
- file=file,
363
- flush=flush,
364
- )
396
+ print(*objects, sep=sep, end=end, file=file, flush=True)
365
397
 
366
398
 
367
399
  def clean_temp_checkpoint(days_threshold=7):
368
400
  """
369
- 删除超过指定时间的 checkpoint 文件。
401
+ Clean up temporary checkpoint files older than specified days.
402
+
403
+ This function removes temporary checkpoint files that are older than
404
+ the specified threshold to free up disk space.
405
+
406
+ Args:
407
+ days_threshold (int): Number of days after which files are considered old.
408
+ Defaults to 7.
370
409
 
371
- 参数:
372
- - directory (str): 文件所在的目录路径。
373
- - file_extension (str): checkpoint 文件的扩展名,默认是 ".ckpt"。
374
- - days_threshold (int): 超过多少天的文件将被删除,默认是 7 天。
410
+ Example:
411
+ >>> clean_temp_checkpoint(3) # Remove files older than 3 days
375
412
  """
376
- # 获取当前时间
377
- import os
378
- from datetime import datetime, timedelta
379
-
380
- current_time = datetime.now()
381
- ckpt_files = findfile.find_cwd_files(["tmp_ckpt", ".pt"])
382
- # 遍历目录中的所有文件
383
- for file_path in ckpt_files:
384
- # 获取文件的最后修改时间
385
- file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
386
-
387
- # 计算文件是否超过指定的时间阈值
388
- if current_time - file_mod_time > timedelta(days=days_threshold):
413
+ import glob
414
+ import time
415
+
416
+ temp_patterns = [
417
+ "temp_checkpoint_*",
418
+ "checkpoint_*",
419
+ "*.tmp",
420
+ "*.temp",
421
+ ]
422
+
423
+ current_time = time.time()
424
+ threshold_time = current_time - (days_threshold * 24 * 60 * 60)
425
+
426
+ for pattern in temp_patterns:
427
+ for file_path in glob.glob(pattern):
389
428
  try:
390
- # 删除文件
391
- os.remove(file_path)
392
- print(f"Deleted: {file_path}")
393
- except Exception as e:
394
- print(f"Error deleting {file_path}: {e}")
429
+ if os.path.getmtime(file_path) < threshold_time:
430
+ os.remove(file_path)
431
+ except Exception:
432
+ pass
395
433
 
396
434
 
397
435
  def load_module_from_path(module_name, file_path):
398
- import importlib
436
+ """
437
+ Load a Python module from a file path.
438
+
439
+ This function dynamically loads a Python module from a file path,
440
+ useful for loading configuration files or custom modules.
441
+
442
+ Args:
443
+ module_name (str): Name to assign to the loaded module
444
+ file_path (str): Path to the Python file to load
445
+
446
+ Returns:
447
+ module: The loaded module object
448
+
449
+ Example:
450
+ >>> config = load_module_from_path("config", "config.py")
451
+ >>> print(config.some_variable)
452
+ """
453
+ import importlib.util
399
454
 
400
455
  spec = importlib.util.spec_from_file_location(module_name, file_path)
401
456
  module = importlib.util.module_from_spec(spec)
402
- try:
403
- spec.loader.exec_module(module)
404
- except FileNotFoundError:
405
- raise ImportError(f"Cannot find the module {module_name} from {file_path}.")
457
+ spec.loader.exec_module(module)
406
458
  return module
407
459
 
408
460
 
409
461
  def check_bench_version(bench_version, omnigenome_version):
410
- assert (
411
- bench_version is not None
412
- ), "Benchmark metadata does not contain a valid __omnigenome__ version."
462
+ """
463
+ Check if benchmark version is compatible with OmniGenome version.
413
464
 
414
- if not isinstance(bench_version, (int, float, str)):
415
- raise TypeError(
416
- f"Invalid type for benchmark version. Expected int, float, or str but got {type(bench_version).__name__}."
417
- )
465
+ This function compares the benchmark version with the OmniGenome version
466
+ to ensure compatibility and warns if there are potential issues.
418
467
 
419
- assert (
420
- omnigenome_version is not None
421
- ), "AutoBench is missing a valid omnigenome version."
468
+ Args:
469
+ bench_version (str): Version of the benchmark
470
+ omnigenome_version (str): Version of OmniGenome
422
471
 
423
- if bench_version > omnigenome_version:
424
- raise ValueError(
425
- f"AutoBench version {omnigenome_version} is not compatible with the benchmark version "
426
- f"{bench_version}. Please update the benchmark or AutoBench."
472
+ Example:
473
+ >>> check_bench_version("0.2.0", "0.3.0")
474
+ """
475
+ if bench_version != omnigenome_version:
476
+ warnings.warn(
477
+ f"Benchmark version ({bench_version}) differs from "
478
+ f"OmniGenome version ({omnigenome_version}). "
479
+ f"This may cause compatibility issues."
427
480
  )
428
481
 
429
482
 
430
483
  def clean_temp_dir_pt_files():
431
- tmp_dir = tempfile.gettempdir()
432
- for f in os.listdir(tmp_dir):
433
- if f.endswith(".pt") and f.startswith("tmp_ckpt"):
434
- path = os.path.join(tmp_dir, f)
484
+ """
485
+ Clean up temporary PyTorch files in the current directory.
486
+
487
+ This function removes temporary PyTorch files (like .pt, .pth files)
488
+ that may be left over from previous runs.
489
+
490
+ Example:
491
+ >>> clean_temp_dir_pt_files()
492
+ """
493
+ import glob
494
+
495
+ temp_patterns = ["*.pt", "*.pth", "temp_*", "checkpoint_*"]
496
+
497
+ for pattern in temp_patterns:
498
+ for file_path in glob.glob(pattern):
435
499
  try:
436
- os.remove(path)
437
- print(f"Removed: {path}")
438
- except Exception as e:
439
- print(f"Failed to remove {path}: {e}")
500
+ if os.path.isfile(file_path):
501
+ os.remove(file_path)
502
+ except Exception:
503
+ pass
@@ -9,4 +9,3 @@
9
9
  """
10
10
  This package contains modules for data augmentation.
11
11
  """
12
-