omnigenome 0.3.1a0__py3-none-any.whl → 1.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. omnigenome/__init__.py +26 -266
  2. {omnigenome-0.3.1a0.dist-info → omnigenome-1.0.0b0.dist-info}/METADATA +8 -9
  3. omnigenome-1.0.0b0.dist-info/RECORD +6 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -11
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -494
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -12
  11. omnigenome/auto/auto_train/auto_train.py +0 -429
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -11
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -12
  16. omnigenome/cli/commands/__init__.py +0 -12
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -12
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -12
  21. omnigenome/cli/commands/rna/rna_design.py +0 -177
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -11
  24. omnigenome/src/abc/__init__.py +0 -11
  25. omnigenome/src/abc/abstract_dataset.py +0 -641
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -690
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -269
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -437
  31. omnigenome/src/lora/__init__.py +0 -12
  32. omnigenome/src/lora/lora_model.py +0 -300
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -503
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -11
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -11
  44. omnigenome/src/model/classification/model.py +0 -638
  45. omnigenome/src/model/embedding/__init__.py +0 -11
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -11
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -11
  51. omnigenome/src/model/regression/model.py +0 -781
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -11
  54. omnigenome/src/model/rna_design/model.py +0 -476
  55. omnigenome/src/model/seq2seq/__init__.py +0 -11
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -747
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -591
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -12
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -11
  71. omnigenome/utility/model_hub/model_hub.py +0 -232
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -11
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.1a0.dist-info/RECORD +0 -78
  76. omnigenome-0.3.1a0.dist-info/entry_points.txt +0 -3
  77. {omnigenome-0.3.1a0.dist-info → omnigenome-1.0.0b0.dist-info}/WHEEL +0 -0
  78. {omnigenome-0.3.1a0.dist-info → omnigenome-1.0.0b0.dist-info}/licenses/LICENSE +0 -0
  79. {omnigenome-0.3.1a0.dist-info → omnigenome-1.0.0b0.dist-info}/top_level.txt +0 -0
@@ -1,690 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: omnigenome_model.py
- # time: 18:36 06/04/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
- import json
- import os
- import shutil
- import warnings
- import inspect
- from importlib import import_module
-
- import dill
- import findfile
- import torch
- from transformers import AutoModel, AutoConfig, AutoTokenizer, BatchEncoding, PretrainedConfig
-
- from ..misc.utils import fprint, env_meta_info
-
- warnings.filterwarnings("once")
-
-
- def count_parameters(model):
-     """
-     Counts the number of trainable parameters in a model.
-
-     This function iterates through all parameters of a PyTorch model and counts
-     only those that require gradients (i.e., trainable parameters).
-
-     Args:
-         model (torch.nn.Module): A PyTorch model.
-
-     Returns:
-         int: The total number of trainable parameters.
-
-     Example:
-         >>> model = OmniModelForSequenceClassification(config, tokenizer)
-         >>> num_params = count_parameters(model)
-         >>> print(f"Model has {num_params} trainable parameters")
-     """
-     return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
- class OmniModel(torch.nn.Module):
-     """
-     Abstract base class for all models in OmniGenome.
-
-     This class provides a unified interface for all genomic models in the OmniGenome
-     framework. It handles model initialization, forward passes, loss computation,
-     prediction, inference, and model persistence.
-
-     The class is designed to work with various types of genomic data and tasks,
-     including sequence classification, token classification, regression, and more.
-
-     Attributes:
-         model (torch.nn.Module): The underlying PyTorch model.
-         config: The model configuration.
-         tokenizer: The tokenizer associated with the model.
-         metadata (dict): Metadata about the model including version info.
-         loss_fn: The loss function for training.
-         dropout (torch.nn.Dropout): Dropout layer for regularization.
-         activation (torch.nn.Tanh): Activation function.
-         pad_token_id (int): ID of the padding token.
-     """
-
-     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-         """
-         Initializes the model.
-
-         This method handles different types of model initialization:
-         - From a pre-trained model path (string)
-         - From a PyTorch model instance
-         - From a configuration object
-
-         Args:
-             config_or_model: A model configuration, a pre-trained model path (str),
-                 or a `torch.nn.Module` instance.
-             tokenizer: The tokenizer associated with the model.
-             *args: Additional positional arguments.
-             **kwargs: Additional keyword arguments.
-                 - label2id (dict): Mapping from class labels to IDs.
-                 - num_labels (int): The number of labels.
-                 - trust_remote_code (bool): Whether to trust remote code when loading
-                   from Hugging Face Hub. Defaults to True.
-                 - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
-                   when loading pre-trained weights. Defaults to False.
-                 - dropout (float): Dropout rate. Defaults to 0.0.
-
-         Raises:
-             ValueError: If config_or_model is not a valid type or if required
-                 configuration is missing.
-             RuntimeError: If the hidden size cannot be determined from the config.
-
-         Example:
-             >>> # Initialize from a pre-trained model
-             >>> model = OmniModelForSequenceClassification("model_path", tokenizer)
-
-             >>> # Initialize from a configuration
-             >>> config = AutoConfig.from_pretrained("model_path")
-             >>> model = OmniModelForSequenceClassification(config, tokenizer)
-         """
-         self.loss_fn = None
-
-         label2id = kwargs.pop("label2id", None)
-         trust_remote_code = kwargs.pop("trust_remote_code", True)
-         num_labels = kwargs.pop("num_labels", None)
-         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
-
-         if label2id is not None and num_labels is None:
-             num_labels = len(label2id)
-         elif num_labels is not None and label2id is None:
-             label2id = {str(i): i for i in range(num_labels)}
-
-         # do not change the order of the following lines
-         super().__init__(*args, **kwargs)
-
-         if isinstance(config_or_model, str):
-             config = AutoConfig.from_pretrained(
-                 config_or_model,
-                 num_labels=num_labels,
-                 label2id=label2id,
-                 trust_remote_code=trust_remote_code,
-             )
-             # Load the model from either `architectures` or `auto_map`
-             if hasattr(config, "auto_map") and config.auto_map:
-                 architectures = list(set(config.auto_map.keys()) - set(["AutoConfig"]))
-                 if architectures:
-                     model_cls_name = (
-                         "AutoModel"
-                         if "AutoModel" in architectures
-                         else architectures[-1]
-                     )
-                     model_cls = getattr(import_module("transformers"), model_cls_name)
-                     model = model_cls.from_pretrained(
-                         config_or_model,
-                         config=config,
-                         trust_remote_code=trust_remote_code,
-                         ignore_mismatched_sizes=ignore_mismatched_sizes,
-                     ).base_model
-                 else:
-                     raise ValueError(
-                         f"The model cannot be instantiated from {config_or_model}. "
-                         f"Please check that the model configuration contains `architectures` or `auto_map`."
-                     )
-             elif hasattr(config, "architectures") and config.architectures:
-                 model_cls_name = (
-                     "AutoModel"
-                     if "AutoModel" in config.architectures
-                     else config.architectures[-1]
-                 )
-                 model_cls = getattr(import_module("transformers"), model_cls_name)
-                 model = model_cls.from_pretrained(
-                     config_or_model,
-                     config=config,
-                     trust_remote_code=trust_remote_code,
-                     ignore_mismatched_sizes=ignore_mismatched_sizes,
-                 ).base_model
-             else:
-                 raise ValueError(
-                     "Neither `architectures` nor `auto_map` is defined in the config."
-                 )
-             self.model = model
-             self.model.config = config
-             del model_cls
-         elif isinstance(config_or_model, torch.nn.Module):
-             self.model = config_or_model
-             self.model.config.num_labels = (
-                 num_labels if len(label2id) == num_labels else len(label2id)
-             )
-             self.model.config.label2id = label2id
-         elif isinstance(config_or_model, PretrainedConfig):
-             # PretrainedConfig, not AutoConfig: loaded configs are concrete
-             # PretrainedConfig subclasses, never AutoConfig instances
-             config = config_or_model
-             config.num_labels = (
-                 num_labels if len(label2id) == num_labels else len(label2id)
-             )
-             config.label2id = label2id
-             self.model = AutoModel.from_config(config)
-             self.model.config = config
-         else:
-             raise ValueError(
-                 "The config_or_model should be either a string, a torch.nn.Module, or a PretrainedConfig object."
-             )
-
-         # Update the config
-         self.config = self.model.config
-         if isinstance(label2id, dict):
-             self.config.label2id = label2id
-             self.config.id2label = {v: k for k, v in label2id.items()}
-             if (
-                 not hasattr(self.config, "num_labels")
-                 or len(self.config.id2label) != self.config.num_labels
-             ):
-                 fprint(
-                     "Warning: The number of labels in the config is not equal to the number of labels in the label2id dictionary."
-                 )
-                 fprint(
-                     "Please check the label2id dictionary and the num_labels parameter in the config."
-                 )
-                 self.config.num_labels = len(self.config.id2label)
-
-         assert (
-             len(self.config.label2id) == num_labels
-         ), f"Expected {num_labels} labels, but got {len(self.config.label2id)} in the label2id dictionary."
-
-         # The metadata of the model
-         self.metadata = env_meta_info()
-         self.metadata["model_cls"] = self.__class__.__name__
-
-         # The config of the model
-         if hasattr(self.config, "n_embd") and self.config.n_embd:
-             self.config.hidden_size = self.config.n_embd
-         elif hasattr(self.config, "d_model") and self.config.d_model:
-             self.config.hidden_size = self.config.d_model
-         elif hasattr(self.config, "hidden_size") and self.config.hidden_size:
-             self.config.hidden_size = self.config.hidden_size
-         else:
-             raise RuntimeError(
-                 "The hidden size of the model is not found in the config."
-             )
-
-         # The tokenizer of the model
-         self.tokenizer = tokenizer
-         self.metadata["tokenizer_cls"] = self.tokenizer.__class__.__name__
-         if hasattr(self.tokenizer, "base_tokenizer"):
-             self.pad_token_id = self.tokenizer.base_tokenizer.pad_token_id
-         else:
-             self.pad_token_id = self.tokenizer.pad_token_id
-
-         self.dropout = torch.nn.Dropout(kwargs.get("dropout", 0.0))
-         self.activation = torch.nn.Tanh()
-
-     def last_hidden_state_forward(self, **inputs):
-         """
-         Performs a forward pass to get the last hidden state from the base model.
-
-         This method handles the forward pass through the underlying model and
-         returns the last hidden state. It also handles compatibility with different
-         model architectures by mapping input parameters appropriately.
-
-         Args:
-             **inputs: The inputs to the model, compatible with the base model's
-                 forward method. Typically includes 'input_ids', 'attention_mask',
-                 and other model-specific parameters.
-
-         Returns:
-             torch.Tensor: The last hidden state tensor.
-
-         Example:
-             >>> inputs = {
-             ...     'input_ids': torch.tensor([[1, 2, 3, 4]]),
-             ...     'attention_mask': torch.tensor([[1, 1, 1, 1]])
-             ... }
-             >>> hidden_states = model.last_hidden_state_forward(**inputs)
-         """
-         model = self.model
-         input_mapping = {}
-         inputs["output_hidden_states"] = True
-
-         if "strippedhyena" in model.__class__.__name__.lower():
-             inputs["x"] = inputs["input_ids"]  # For compatibility with Evo models
-         if isinstance(inputs, BatchEncoding) or isinstance(inputs, dict):
-             # Determine the input parameter names of the model's forward method
-             forward_params = inspect.signature(model.forward).parameters
-             # Map the inputs to the forward method parameters
-             for param in forward_params:
-                 if param in inputs:
-                     input_mapping[param] = inputs[param]
-             # Warn about input keys that are not declared in the model's forward signature
-             ignored_keys = set(inputs.keys()) - set(input_mapping.keys())
-             if ignored_keys:
-                 warnings.warn(f"Ignored keys in inputs: {ignored_keys}")
-
-             inputs = input_mapping
-         elif isinstance(inputs, tuple):
-             input_ids = inputs[0]
-             attention_mask = inputs[1] if len(inputs) > 1 else None
-             inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-         elif isinstance(inputs, torch.Tensor):
-             shape = inputs.shape
-             try:
-                 if len(shape) == 3:
-                     if shape[1] == 2:
-                         input_ids = inputs[:, 0]
-                         attention_mask = inputs[:, 1]
-                     else:
-                         input_ids = inputs[0]
-                         attention_mask = inputs[1] if len(inputs) > 1 else None
-                 elif len(shape) == 2:
-                     input_ids = inputs
-                     attention_mask = None
-                 else:
-                     raise ValueError(
-                         f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}."
-                     )
-             except Exception:
-                 raise ValueError(
-                     f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}."
-                 )
-             inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-         else:
-             raise ValueError(
-                 f"The inputs should be a tuple, BatchEncoding or a dictionary-like object, got {type(inputs)}."
-             )
-
-         # Run the model
-         outputs = model(**inputs)
-
-         if not hasattr(outputs, "last_hidden_state"):
-             warnings.warn(
-                 f"last_hidden_state not found in the outputs from the {model.__class__.__name__} model."
-             )
-
-         if hasattr(outputs, "last_hidden_state"):
-             last_hidden_state = outputs.last_hidden_state
-         elif isinstance(outputs, dict) and "last_hidden_state" in outputs:
-             last_hidden_state = outputs["last_hidden_state"]
-         elif hasattr(outputs, "hidden_states"):
-             last_hidden_state = outputs.hidden_states[-1]
-         elif isinstance(outputs, (list, tuple, torch.Tensor)):
-             if len(outputs) <= 2:
-                 # For Evo models that return a tuple of (last_hidden_state, logits)
-                 last_hidden_state = outputs[0]
-             elif len(outputs) >= 3:
-                 last_hidden_state = outputs[-1]
-         else:
-             raise ValueError(
-                 f"Cannot find the last hidden state in the outputs from the {model.__class__.__name__} model, "
-                 f"please check the model architecture."
-             )
-
-         return last_hidden_state
-
-     def loss_function(self, logits, labels):
-         """
-         Calculates the loss. Must be implemented by subclasses.
-
-         This method should be implemented by concrete model classes to define
-         how the loss is calculated for their specific task (classification,
-         regression, etc.).
-
-         Args:
-             logits (torch.Tensor): The model's output logits.
-             labels (torch.Tensor): The ground truth labels.
-
-         Returns:
-             torch.Tensor: The calculated loss.
-
-         Raises:
-             NotImplementedError: If the method is not implemented by the subclass.
-
-         Example:
-             >>> # In a classification model
-             >>> loss = model.loss_function(logits, labels)
-         """
-         raise NotImplementedError(
-             "The loss_function() method should be implemented for your model."
-         )
-
-     def set_loss_fn(self, loss_function):
-         """
-         Sets a custom loss function for the model.
-
-         This method allows setting a custom loss function that will be used
-         during training. The loss function should be compatible with the
-         model's output format.
-
-         Args:
-             loss_function (callable): A callable loss function that takes
-                 logits and labels as arguments.
-
-         Example:
-             >>> import torch.nn as nn
-             >>> model.set_loss_fn(nn.CrossEntropyLoss())
-         """
-         self.loss_fn = loss_function
-
-     def predict(self, sequence_or_inputs, **kwargs):
-         """
-         Performs prediction on raw inputs. Returns raw model outputs.
-
-         This method takes raw sequences or tokenized inputs and returns
-         the raw model outputs (logits, hidden states, etc.) without
-         post-processing. It's useful for getting the model's direct
-         predictions for further processing.
-
-         Args:
-             sequence_or_inputs: A sequence (str), list of sequences, or
-                 tokenized inputs (dict/tuple).
-             **kwargs: Additional arguments for tokenization and inference.
-
-         Returns:
-             dict: A dictionary containing the raw model outputs, typically
-                 including 'logits', 'last_hidden_state', and other
-                 model-specific outputs.
-
-         Example:
-             >>> # Predict on a single sequence
-             >>> outputs = model.predict("ATCGATCG")
-
-             >>> # Predict on multiple sequences
-             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
-         """
-         # Subclasses may override predict() for task-specific post-processing
-         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-         return raw_outputs
-
-     def inference(self, sequence_or_inputs, **kwargs):
-         """
-         Performs inference on raw inputs. Returns processed, human-readable predictions.
-
-         This method takes raw sequences or tokenized inputs and returns
-         processed predictions that are ready for human consumption. It
-         typically includes post-processing steps like converting logits
-         to class labels or probabilities.
-
-         Args:
-             sequence_or_inputs: A sequence (str), list of sequences, or
-                 tokenized inputs (dict/tuple).
-             **kwargs: Additional arguments for tokenization and inference.
-
-         Returns:
-             dict: A dictionary containing the processed predictions, typically
-                 including 'predictions', 'confidence', and other
-                 human-readable outputs.
-
-         Example:
-             >>> # Inference on a single sequence
-             >>> results = model.inference("ATCGATCG")
-             >>> print(results['predictions'])  # Class labels
-
-             >>> # Inference on multiple sequences
-             >>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
-         """
-         # Subclasses may override inference() for task-specific post-processing
-         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-         return raw_outputs
-
-     def __call__(self, **inputs):
-         """
-         The main forward pass of the model, suitable for training loops.
-
-         This method is the primary interface for model forward passes during
-         training. It handles both tokenized inputs and raw sequences,
-         calculates loss if labels are provided, and returns a comprehensive
-         output dictionary.
-
-         Args:
-             **inputs: A dictionary of tokenized inputs, potentially including
-                 labels. Can also handle raw sequences that will be
-                 tokenized automatically.
-
-         Returns:
-             dict: A dictionary containing logits, last_hidden_state, labels,
-                 and loss (if labels were provided).
-
-         Example:
-             >>> # Training forward pass
-             >>> outputs = model(
-             ...     input_ids=torch.tensor([[1, 2, 3, 4]]),
-             ...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
-             ...     labels=torch.tensor([0])
-             ... )
-             >>> loss = outputs['loss']
-         """
-         # For HuggingFace Trainer integration, the tokenized inputs arrive nested
-         # under the "inputs" key; the native trainer passes them in directly.
-         labels = inputs.pop("labels", None)
-         inputs = inputs.pop("inputs", inputs)
-         inputs["labels"] = labels
-         if isinstance(inputs, dict):
-             labels = inputs.get("labels", None)
-             label = inputs.get("label", None)
-             labels = labels if labels is not None else label
-             # if labels is None:
-             #     warnings.warn(
-             #         "No labels are provided in the inputs, the model will not calculate the loss."
-             #     )
-         elif isinstance(inputs, tuple):
-             labels = inputs[1]
-             inputs = inputs[0]
-         elif labels is not None:
-             labels = labels
-         outputs = self.forward(**inputs)
-
-         if labels is not None:
-             outputs["loss"] = self._calculate_loss(outputs, labels)
-         else:
-             outputs["loss"] = None
-         return outputs
-
-     def _calculate_loss(self, outputs, labels):
-         """
-         Internal method to calculate loss if not already present in outputs.
-
-         :param outputs: The dictionary of model outputs.
-         :param labels: The ground truth labels.
-         :return: The calculated loss.
-         """
-         loss = outputs.get("loss", None)
-         if loss is not None:
-             return loss
-
-         logits = outputs["logits"]
-         if logits is not None and labels is not None:
-             loss = self.loss_function(logits, labels)
-             return loss
-         else:
-             raise RuntimeError(
-                 "The output of the forward() function should be a dictionary-like object"
-                 " and have either 'loss', or 'logits' and 'labels' attributes."
-             )
-
-     def save(self, path, overwrite=False, dtype=torch.float16, **kwargs):
-         """
-         Saves the model, tokenizer, and metadata to a directory.
-
-         :param path: The directory to save the model to.
-         :param overwrite: Whether to overwrite the directory if it exists.
-         :param dtype: The data type to save the model weights in.
-         :param kwargs: Additional arguments.
-         """
-         self.eval()
-
-         if os.path.exists(path) and not overwrite:
-             raise FileExistsError(
-                 f"The path {path} already exists, please set overwrite=True to overwrite it."
-             )
-
-         if not os.path.exists(path):
-             os.makedirs(path)
-
-         for file in findfile.find_files(
-             self.config.name_or_path,
-             or_key=["bin", "json", "txt", "py"],
-             exclude_key=["pytorch_model.bin", "model.safetensors"],
-         ):
-             shutil.copyfile(file, f"{path}/{os.path.basename(file)}")
-
-         _device = self.model.device
-         _dtype = self.model.dtype
-         self.model.to(dtype).to("cpu")
-         self.tokenizer.save_pretrained(path)
-
-         # Save metadata including information about the loss function
-         metadata = self.metadata.copy()
-         if self.loss_fn is not None:
-             metadata["loss_fn_class"] = self.loss_fn.__class__.__name__
-             metadata["loss_fn_module"] = self.loss_fn.__class__.__module__
-
-         with open(f"{path}/metadata.json", "w", encoding="utf8") as f:
-             json.dump(metadata, f)
-         with open(f"{path}/tokenizer.bin", "wb") as f:
-             dill.dump(self.tokenizer, f)
-         self.model.save_pretrained(
-             f"{path}", safe_serialization=False
-         )  # do not remove this line, used to save customized model scripts
-
-         # Save complete state dict including all components
-         with open(f"{path}/pytorch_model.bin", "wb") as f:
-             torch.save(self.state_dict(), f)
-
-         self.model.to(_dtype).to(_device)
-         fprint(f"The model is saved to {path}.")
-
-     def load(self, path, **kwargs):
-         """
-         Loads the model, tokenizer, and metadata from a directory.
-
-         :param path: The directory to load the model from.
-         :param kwargs: Additional arguments.
-         :return: The loaded model instance.
-         """
-         with open(f"{path}/metadata.json", "r", encoding="utf8") as f:
-             metadata = json.load(f)
-
-         if metadata["model_cls"] != self.__class__.__name__:  # Check the model class
-             raise ValueError(
-                 f"The model class in the loaded model is {metadata['model_cls']}, "
-                 f"but the current model class is {self.__class__.__name__}."
-             )
-         config = AutoConfig.from_pretrained(path, trust_remote_code=True, **kwargs)
-
-         for key, value in config.__dict__.items():
-             if key not in self.config.__dict__ or self.config.__dict__[key] != value:
-                 fprint(
-                     f"Warning: The value of the key {key} in the loaded model is {value}, "
-                     f"but the current value is {self.config.__dict__.get(key, None)}."
-                 )
-
-         # Attempt to restore any saved loss function
-         if "loss_fn_class" in metadata and "loss_fn_module" in metadata:
-             try:
-                 loss_module = import_module(metadata["loss_fn_module"])
-                 loss_class = getattr(loss_module, metadata["loss_fn_class"])
-                 # Initialize loss function if possible (parameters will be loaded with state dict)
-                 self.loss_fn = loss_class()
-                 fprint(
-                     f"Restored loss function: {metadata['loss_fn_class']} from {metadata['loss_fn_module']}"
-                 )
-             except (ImportError, AttributeError) as e:
-                 warnings.warn(f"Could not restore loss function: {e}")
-
-         with open(f"{path}/pytorch_model.bin", "rb") as f:
-             loaded_state_dict = torch.load(f, map_location=kwargs.get("device", "cpu"))
-
-         # Check if keys match between current and loaded state dict
-         current_keys = set(self.state_dict().keys())
-         loaded_keys = set(loaded_state_dict.keys())
-         missing_keys = current_keys - loaded_keys
-         unexpected_keys = loaded_keys - current_keys
-
-         if missing_keys:
-             warnings.warn(f"Missing keys in loaded weights: {missing_keys}")
-         if unexpected_keys:
-             warnings.warn(f"Unexpected keys in loaded weights: {unexpected_keys}")
-
-         self.load_state_dict(loaded_state_dict, strict=False)
-         # Load the tokenizer
-         if os.path.exists(f"{path}/tokenizer.bin"):
-             with open(f"{path}/tokenizer.bin", "rb") as f:
-                 self.tokenizer = dill.load(f)
-
-         return self
-
-     def _forward_from_raw_input(self, sequence_or_inputs, **kwargs):
-         """
-         Tokenizes raw input and performs a forward pass in no_grad mode.
-
-         :param sequence_or_inputs: A sequence, list of sequences, or tokenized inputs.
-         :param kwargs: Additional arguments for tokenization.
-         :return: A dictionary containing the raw model outputs and the tokenized inputs.
-         """
-         if not isinstance(sequence_or_inputs, BatchEncoding) and not isinstance(
-             sequence_or_inputs, dict
-         ):
-             inputs = self.tokenizer(
-                 sequence_or_inputs,
-                 padding=kwargs.pop("padding", True),
-                 max_length=kwargs.pop("max_length", 1024),
-                 truncation=kwargs.pop("truncation", True),
-                 return_tensors=kwargs.pop("return_tensors", "pt"),
-                 **kwargs,
-             )
-         else:
-             inputs = sequence_or_inputs
-         inputs = inputs.to(self.model.device)
-         with torch.no_grad():
-             raw_outputs = self(**inputs)
-             raw_outputs["inputs"] = inputs
-         return raw_outputs
-
-     @staticmethod
-     def from_pretrained(model_name_or_path, tokenizer, *args, **kwargs):
-         """
-         Loads a pre-trained model and tokenizer.
-
-         :param model_name_or_path: The name or path of the pre-trained model.
-         :param tokenizer: The tokenizer to use.
-         :param args: Additional positional arguments.
-         :param kwargs: Additional keyword arguments.
-         :return: An instance of `OmniModel`.
-         """
-         config = kwargs.pop("config", None)
-         if config is None:
-             config = AutoConfig.from_pretrained(model_name_or_path, **kwargs)
-         base_model = AutoModel.from_pretrained(model_name_or_path, config=config, **kwargs)
-         if tokenizer is None:
-             tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
-         # OmniModel.__init__ takes (config_or_model, tokenizer, ...), so pass the
-         # loaded base model as the first positional argument
-         return OmniModel(base_model, tokenizer, *args, **kwargs)
-
-     def model_info(self):
-         """
-         Prints and returns detailed information about the model.
-
-         :return: A string containing the model information.
-         """
-         info = f"Model Name: {self.__class__.__name__}\n"
-         info += f"Model Metadata: {self.metadata}\n"
-         info += f"Base Model Name: {self.config.name_or_path}\n"
-         info += f"Model Type: {self.config.model_type}\n"
-         info += f"Model Architecture: {self.config.architectures}\n"
-         info += f"Model Parameters: {count_parameters(self.model) / 1e6} M\n"
-         info += f"Model Config: {self.config}\n"
-         fprint(info)
-         return info
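
The hunk above removes OmniModel, the abstract base class behind every task model in 0.3.1a0 (the concrete subclasses under omnigenome/src/model/* are deleted in the same release). For reference, a minimal usage sketch of the retired API, reconstructed from the signatures in the deleted file: the checkpoint path and label count are placeholders, and the root-level import of OmniModelForSequenceClassification is an assumption based on the subclass name referenced in the docstrings above.

# Sketch only: assumes a 0.3.1a0 installation; "model_path" is hypothetical.
from transformers import AutoTokenizer
from omnigenome import OmniModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("model_path")
# label2id is derived from num_labels when omitted (see __init__ above)
model = OmniModelForSequenceClassification("model_path", tokenizer, num_labels=2)

results = model.inference("ATCGATCG")           # processed, human-readable predictions
raw = model.predict(["ATCGATCG", "GCTAGCTA"])   # raw logits / hidden states

model.save("./ckpt", overwrite=True)  # writes metadata.json, tokenizer.bin, pytorch_model.bin
model = model.load("./ckpt")          # restores weights, tokenizer, and loss_fn metadata

Note that save() serializes the full state dict alongside the Hugging Face artifacts, so load() restores the classification head and any custom loss function, not just the base transformer weights.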