omnigenome-0.3.0a1-py3-none-any.whl → omnigenome-0.3.3a0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of omnigenome might be problematic.

Files changed (79)
  1. omnigenome/__init__.py +252 -258
  2. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/METADATA +10 -10
  3. omnigenome-0.3.3a0.dist-info/RECORD +7 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -12
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -484
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -13
  11. omnigenome/auto/auto_train/auto_train.py +0 -430
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -12
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -13
  16. omnigenome/cli/commands/__init__.py +0 -13
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -13
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -13
  21. omnigenome/cli/commands/rna/rna_design.py +0 -178
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -12
  24. omnigenome/src/abc/__init__.py +0 -12
  25. omnigenome/src/abc/abstract_dataset.py +0 -622
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -689
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -267
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -435
  31. omnigenome/src/lora/__init__.py +0 -13
  32. omnigenome/src/lora/lora_model.py +0 -294
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -499
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -12
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -12
  44. omnigenome/src/model/classification/model.py +0 -642
  45. omnigenome/src/model/embedding/__init__.py +0 -12
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -12
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -12
  51. omnigenome/src/model/regression/model.py +0 -786
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -12
  54. omnigenome/src/model/rna_design/model.py +0 -469
  55. omnigenome/src/model/seq2seq/__init__.py +0 -12
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -739
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -579
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -13
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -12
  71. omnigenome/utility/model_hub/model_hub.py +0 -231
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -12
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.0a1.dist-info/RECORD +0 -78
  76. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/WHEEL +0 -0
  77. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/entry_points.txt +0 -0
  78. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/licenses/LICENSE +0 -0
  79. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/top_level.txt +0 -0
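
To reproduce this file-level comparison locally, the sketch below lists which archive members were added or removed between the two wheels (a wheel is a zip archive). This is an illustration rather than part of either release, and it assumes both wheels have already been fetched, for example with pip download omnigenome==0.3.0a1 --no-deps and pip download omnigenome==0.3.3a0 --no-deps.

# Sketch: compare the member lists of the two wheels (zip archives).
# Assumes omnigenome-0.3.0a1-py3-none-any.whl and omnigenome-0.3.3a0-py3-none-any.whl
# sit in the current directory, fetched beforehand with "pip download ... --no-deps".
import zipfile

old = set(zipfile.ZipFile("omnigenome-0.3.0a1-py3-none-any.whl").namelist())
new = set(zipfile.ZipFile("omnigenome-0.3.3a0-py3-none-any.whl").namelist())

print("removed:", sorted(old - new))  # should match the omnigenome/src/... deletions above
print("added:", sorted(new - old))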
--- a/omnigenome/src/abc/abstract_model.py
+++ /dev/null
@@ -1,689 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: omnigenome_model.py
- # time: 18:36 06/04/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
- import json
- import os
- import shutil
- import warnings
- import inspect
- from importlib import import_module
-
- import dill
- import findfile
- import torch
- from transformers import AutoModel, AutoConfig, AutoTokenizer, BatchEncoding
-
- from ..misc.utils import fprint, env_meta_info
-
- warnings.filterwarnings("once")
-
-
- def count_parameters(model):
-     """
-     Counts the number of trainable parameters in a model.
-
-     This function iterates through all parameters of a PyTorch model and counts
-     only those that require gradients (i.e., trainable parameters).
-
-     Args:
-         model (torch.nn.Module): A PyTorch model.
-
-     Returns:
-         int: The total number of trainable parameters.
-
-     Example:
-         >>> model = OmniModelForSequenceClassification(config, tokenizer)
-         >>> num_params = count_parameters(model)
-         >>> print(f"Model has {num_params} trainable parameters")
-     """
-     return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
- class OmniModel(torch.nn.Module):
-     """
-     Abstract base class for all models in OmniGenome.
-
-     This class provides a unified interface for all genomic models in the OmniGenome
-     framework. It handles model initialization, forward passes, loss computation,
-     prediction, inference, and model persistence.
-
-     The class is designed to work with various types of genomic data and tasks,
-     including sequence classification, token classification, regression, and more.
-
-     Attributes:
-         model (torch.nn.Module): The underlying PyTorch model.
-         config: The model configuration.
-         tokenizer: The tokenizer associated with the model.
-         metadata (dict): Metadata about the model including version info.
-         loss_fn: The loss function for training.
-         dropout (torch.nn.Dropout): Dropout layer for regularization.
-         activation (torch.nn.Tanh): Activation function.
-         pad_token_id (int): ID of the padding token.
-     """
-
-     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-         """
-         Initializes the model.
-
-         This method handles different types of model initialization:
-         - From a pre-trained model path (string)
-         - From a PyTorch model instance
-         - From a configuration object
-
-         Args:
-             config_or_model: A model configuration, a pre-trained model path (str),
-                 or a `torch.nn.Module` instance.
-             tokenizer: The tokenizer associated with the model.
-             *args: Additional positional arguments.
-             **kwargs: Additional keyword arguments.
-                 - label2id (dict): Mapping from class labels to IDs.
-                 - num_labels (int): The number of labels.
-                 - trust_remote_code (bool): Whether to trust remote code when loading
-                   from Hugging Face Hub. Defaults to True.
-                 - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
-                   when loading pre-trained weights. Defaults to False.
-                 - dropout (float): Dropout rate. Defaults to 0.0.
-
-         Raises:
-             ValueError: If config_or_model is not a valid type or if required
-                 configuration is missing.
-             RuntimeError: If the hidden size cannot be determined from the config.
-
-         Example:
-             >>> # Initialize from a pre-trained model
-             >>> model = OmniModelForSequenceClassification("model_path", tokenizer)
-
-             >>> # Initialize from a configuration
-             >>> config = AutoConfig.from_pretrained("model_path")
-             >>> model = OmniModelForSequenceClassification(config, tokenizer)
-         """
-         self.loss_fn = None
-
-         label2id = kwargs.pop("label2id", None)
-         trust_remote_code = kwargs.pop("trust_remote_code", True)
-         num_labels = kwargs.pop("num_labels", None)
-         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
-
-         if label2id is not None and num_labels is None:
-             num_labels = len(label2id)
-         elif num_labels is not None and label2id is None:
-             label2id = {str(i): i for i in range(num_labels)}
-
-         # do not change the order of the following lines
-         super().__init__(*args, **kwargs)
-
-         if isinstance(config_or_model, str):
-             config = AutoConfig.from_pretrained(
-                 config_or_model,
-                 num_labels=num_labels,
-                 label2id=label2id,
-                 trust_remote_code=trust_remote_code,
-             )
-             # Load the model from either `architectures` or `auto_map`
-             if hasattr(config, "auto_map") and config.auto_map:
-                 architectures = list(set(config.auto_map.keys()) - set(["AutoConfig"]))
-                 if architectures:
-                     model_cls_name = (
-                         "AutoModel"
-                         if "AutoModel" in architectures
-                         else architectures[-1]
-                     )
-                     model_cls = getattr(import_module("transformers"), model_cls_name)
-
-                     model = model_cls.from_pretrained(
-                         config_or_model,
-                         config=config,
-                         trust_remote_code=trust_remote_code,
-                         ignore_mismatched_sizes=ignore_mismatched_sizes,
-                     ).base_model
-                 else:
-                     raise ValueError(
-                         f"The model cannot be instantiated from {config_or_model}. "
-                         f"Please check that the model configuration contains `architectures` or `auto_map`."
-                     )
-             elif hasattr(config, "architectures") and config.architectures:
-                 model_cls_name = (
-                     "AutoModel"
-                     if "AutoModel" in config.architectures
-                     else config.architectures[-1]
-                 )
-                 model_cls = getattr(import_module("transformers"), model_cls_name)
-                 model = model_cls.from_pretrained(
-                     config_or_model,
-                     config=config,
-                     trust_remote_code=trust_remote_code,
-                     ignore_mismatched_sizes=ignore_mismatched_sizes,
-                 ).base_model
-             else:
-                 raise ValueError(
-                     "Neither `architectures` nor `auto_map` is defined in the config."
-                 )
-             self.model = model
-             self.model.config = config
-             del model_cls
-         elif isinstance(config_or_model, torch.nn.Module):
-             self.model = config_or_model
-             self.model.config.num_labels = (
-                 num_labels if len(label2id) == num_labels else len(label2id)
-             )
-             self.model.config.label2id = label2id
-         elif isinstance(config_or_model, AutoConfig):
-             config = config_or_model
-             config.num_labels = (
-                 num_labels if len(label2id) == num_labels else len(label2id)
-             )
-             config.label2id = label2id
-             self.model = AutoModel.from_config(config)
-             self.model.config = config
-         else:
-             raise ValueError(
-                 "The config_or_model should be either a string, a torch.nn.Module or an AutoConfig object."
-             )
-
-         # Update the config
-         self.config = self.model.config
-         if isinstance(label2id, dict):
-             self.config.label2id = label2id
-             self.config.id2label = {v: k for k, v in label2id.items()}
-             if (
-                 not hasattr(self.config, "num_labels")
-                 or len(self.config.id2label) != self.config.num_labels
-             ):
-                 fprint(
-                     "Warning: The number of labels in the config is not equal to the number of labels in the label2id dictionary. "
-                 )
-                 fprint(
-                     "Please check the label2id dictionary and the num_labels parameter in the config."
-                 )
-                 self.config.num_labels = len(self.config.id2label)
-
-         assert len(self.config.label2id) == num_labels, f"Expected {num_labels} labels, but got {len(self.config.label2id)} in the label2id dictionary."
-
-         # The metadata of the model
-         self.metadata = env_meta_info()
-         self.metadata["model_cls"] = self.__class__.__name__
-
-         # The config of the model
-         if hasattr(self.config, "n_embd") and self.config.n_embd:
-             self.config.hidden_size = self.config.n_embd
-         elif hasattr(self.config, "d_model") and self.config.d_model:
-             self.config.hidden_size = self.config.d_model
-         elif hasattr(self.config, "hidden_size") and self.config.hidden_size:
-             self.config.hidden_size = self.config.hidden_size
-         else:
-             raise RuntimeError(
-                 "The hidden size of the model is not found in the config."
-             )
-
-         # The tokenizer of the model
-         self.tokenizer = tokenizer
-         self.metadata["tokenizer_cls"] = self.tokenizer.__class__.__name__
-         if hasattr(self.tokenizer, "base_tokenizer"):
-             self.pad_token_id = self.tokenizer.base_tokenizer.pad_token_id
-         else:
-             self.pad_token_id = self.tokenizer.pad_token_id
-
-         self.dropout = torch.nn.Dropout(kwargs.get("dropout", 0.0))
-         self.activation = torch.nn.Tanh()
-
-     def last_hidden_state_forward(self, **inputs):
-         """
-         Performs a forward pass to get the last hidden state from the base model.
-
-         This method handles the forward pass through the underlying model and
-         returns the last hidden state. It also handles compatibility with different
-         model architectures by mapping input parameters appropriately.
-
-         Args:
-             **inputs: The inputs to the model, compatible with the base model's
-                 forward method. Typically includes 'input_ids', 'attention_mask',
-                 and other model-specific parameters.
-
-         Returns:
-             torch.Tensor: The last hidden state tensor.
-
-         Example:
-             >>> inputs = {
-             ...     'input_ids': torch.tensor([[1, 2, 3, 4]]),
-             ...     'attention_mask': torch.tensor([[1, 1, 1, 1]])
-             ... }
-             >>> hidden_states = model.last_hidden_state_forward(**inputs)
-         """
-         model = self.model
-         input_mapping = {}
-         inputs["output_hidden_states"] = True
-
-         if "strippedhyena" in model.__class__.__name__.lower():
-             inputs["x"] = inputs["input_ids"]  # For compatibility with Evo models
-         if isinstance(inputs, BatchEncoding) or isinstance(inputs, dict):
-             # Determine the input parameter names of the model's forward method
-             forward_params = inspect.signature(model.forward).parameters
-             # Map the inputs to the forward method parameters
-             for param in forward_params:
-                 if param in inputs:
-                     input_mapping[param] = inputs[param]
-             # Warn about input keys that are not declared in the model's forward signature
-             ignored_keys = set(inputs.keys()) - set(input_mapping.keys())
-             if ignored_keys:
-                 warnings.warn(f"Warning: Ignored keys in inputs: {ignored_keys}")
-
-             inputs = input_mapping
-         elif isinstance(inputs, tuple):
-             input_ids = inputs[0]
-             attention_mask = inputs[1] if len(inputs) > 1 else None
-             inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-         elif isinstance(inputs, torch.Tensor):
-             shape = inputs.shape
-             try:
-                 if len(shape) == 3:
-                     if shape[1] == 2:
-                         input_ids = inputs[:, 0]
-                         attention_mask = inputs[:, 1]
-                     else:
-                         input_ids = inputs[0]
-                         attention_mask = inputs[1] if len(inputs) > 1 else None
-                 elif len(shape) == 2:
-                     input_ids = inputs
-                     attention_mask = None
-                 else:
-                     raise ValueError(
-                         f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}."
-                     )
-             except Exception:
-                 raise ValueError(
-                     f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}."
-                 )
-             inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-         else:
-             raise ValueError(
-                 f"The inputs should be a tuple, BatchEncoding or a dictionary-like object, got {type(inputs)}."
-             )
-
-         # Run the forward pass of the underlying model
-         outputs = model(**inputs)
-
-         if not hasattr(outputs, "last_hidden_state"):
-             warnings.warn(
-                 f"last_hidden_state not found in the outputs from the {model.__class__.__name__} model."
-             )
-
-         if hasattr(outputs, "last_hidden_state"):
-             last_hidden_state = outputs.last_hidden_state
-         elif isinstance(outputs, dict) and "last_hidden_state" in outputs:
-             last_hidden_state = outputs["last_hidden_state"]
-         elif hasattr(outputs, "hidden_states"):
-             last_hidden_state = outputs.hidden_states[-1]
-         elif isinstance(outputs, (list, tuple, torch.Tensor)):
-             if len(outputs) <= 2:
-                 # For Evo models that return a tuple of (last_hidden_state, logits)
-                 last_hidden_state = outputs[0]
-             elif len(outputs) >= 3:
-                 last_hidden_state = outputs[-1]
-         else:
-             raise ValueError(
-                 f"Cannot find the last hidden state in the outputs from the {model.__class__.__name__} model, "
-                 f"please check the model architecture."
-             )
-
-         return last_hidden_state
-
-     def loss_function(self, logits, labels):
-         """
-         Calculates the loss. Must be implemented by subclasses.
-
-         This method should be implemented by concrete model classes to define
-         how the loss is calculated for their specific task (classification,
-         regression, etc.).
-
-         Args:
-             logits (torch.Tensor): The model's output logits.
-             labels (torch.Tensor): The ground truth labels.
-
-         Returns:
-             torch.Tensor: The calculated loss.
-
-         Raises:
-             NotImplementedError: If the method is not implemented by the subclass.
-
-         Example:
-             >>> # In a classification model
-             >>> loss = model.loss_function(logits, labels)
-         """
-         raise NotImplementedError(
-             "The loss_function() function should be implemented for your model."
-         )
-
-     def set_loss_fn(self, loss_function):
-         """
-         Sets a custom loss function for the model.
-
-         This method allows setting a custom loss function that will be used
-         during training. The loss function should be compatible with the
-         model's output format.
-
-         Args:
-             loss_function (callable): A callable loss function that takes
-                 logits and labels as arguments.
-
-         Example:
-             >>> import torch.nn as nn
-             >>> model.set_loss_fn(nn.CrossEntropyLoss())
-         """
-         self.loss_fn = loss_function
-
-     def predict(self, sequence_or_inputs, **kwargs):
-         """
-         Performs prediction on raw inputs. Returns raw model outputs.
-
-         This method takes raw sequences or tokenized inputs and returns
-         the raw model outputs (logits, hidden states, etc.) without
-         post-processing. It's useful for getting the model's direct
-         predictions for further processing.
-
-         Args:
-             sequence_or_inputs: A sequence (str), list of sequences, or
-                 tokenized inputs (dict/tuple).
-             **kwargs: Additional arguments for tokenization and inference.
-
-         Returns:
-             dict: A dictionary containing the raw model outputs, typically
-                 including 'logits', 'last_hidden_state', and other
-                 model-specific outputs.
-
-         Example:
-             >>> # Predict on a single sequence
-             >>> outputs = model.predict("ATCGATCG")
-
-             >>> # Predict on multiple sequences
-             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
-         """
-         # Subclasses may override predict() with task-specific post-processing
-         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-         return raw_outputs
-
-     def inference(self, sequence_or_inputs, **kwargs):
-         """
-         Performs inference on raw inputs. Returns processed, human-readable predictions.
-
-         This method takes raw sequences or tokenized inputs and returns
-         processed predictions that are ready for human consumption. It
-         typically includes post-processing steps like converting logits
-         to class labels or probabilities.
-
-         Args:
-             sequence_or_inputs: A sequence (str), list of sequences, or
-                 tokenized inputs (dict/tuple).
-             **kwargs: Additional arguments for tokenization and inference.
-
-         Returns:
-             dict: A dictionary containing the processed predictions, typically
-                 including 'predictions', 'confidence', and other
-                 human-readable outputs.
-
-         Example:
-             >>> # Inference on a single sequence
-             >>> results = model.inference("ATCGATCG")
-             >>> print(results['predictions'])  # Class labels
-
-             >>> # Inference on multiple sequences
-             >>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
-         """
-         # Subclasses may override inference() with task-specific post-processing
-         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-         return raw_outputs
-
-     def __call__(self, **inputs):
-         """
-         The main forward pass of the model, suitable for training loops.
-
-         This method is the primary interface for model forward passes during
-         training. It handles both tokenized inputs and raw sequences,
-         calculates loss if labels are provided, and returns a comprehensive
-         output dictionary.
-
-         Args:
-             **inputs: A dictionary of tokenized inputs, potentially including
-                 labels. Can also handle raw sequences that will be
-                 tokenized automatically.
-
-         Returns:
-             dict: A dictionary containing logits, last_hidden_state, labels,
-                 and loss (if labels were provided).
-
-         Example:
-             >>> # Training forward pass
-             >>> outputs = model(
-             ...     input_ids=torch.tensor([[1, 2, 3, 4]]),
-             ...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
-             ...     labels=torch.tensor([0])
-             ... )
-             >>> loss = outputs['loss']
-         """
-         # For the transformers Trainer integration, the tokenized inputs arrive nested under the "inputs" key;
-         # for the native trainer, the inputs are already a tokenized inputs object.
-         labels = inputs.pop("labels", None)
-         inputs = inputs.pop("inputs", inputs)
-         inputs["labels"] = labels
-         if isinstance(inputs, dict):
-
-             labels = inputs.get("labels", None)
-             label = inputs.get("label", None)
-             labels = labels if labels is not None else label
-             # if labels is None:
-             #     warnings.warn(
-             #         "No labels are provided in the inputs, the model will not calculate the loss."
-             #     )
-         elif isinstance(inputs, tuple):
-             labels = inputs[1]
-             inputs = inputs[0]
-         elif labels is not None:
-             labels = labels
-         outputs = self.forward(**inputs)
-
-         if labels is not None:
-             outputs["loss"] = self._calculate_loss(outputs, labels)
-         else:
-             outputs["loss"] = None
-         return outputs
-
-     def _calculate_loss(self, outputs, labels):
-         """
-         Internal method to calculate loss if not already present in outputs.
-
-         :param outputs: The dictionary of model outputs.
-         :param labels: The ground truth labels.
-         :return: The calculated loss.
-         """
-         loss = outputs.get("loss", None)
-         if loss is not None:
-             return loss
-
-         logits = outputs["logits"]
-         if logits is not None and labels is not None:
-             loss = self.loss_function(logits, labels)
-             return loss
-         else:
-             raise RuntimeError(
-                 "The output of the forward() function should be a dictionary-like object"
-                 " and have either 'loss', or 'logits' and 'labels' attribute."
-             )
-
-     def save(self, path, overwrite=False, dtype=torch.float16, **kwargs):
-         """
-         Saves the model, tokenizer, and metadata to a directory.
-
-         :param path: The directory to save the model to.
-         :param overwrite: Whether to overwrite the directory if it exists.
-         :param dtype: The data type to save the model weights in.
-         :param kwargs: Additional arguments.
-         """
-         self.eval()
-
-         if os.path.exists(path) and not overwrite:
-             raise FileExistsError(
-                 f"The path {path} already exists, please set overwrite=True to overwrite it."
-             )
-
-         if not os.path.exists(path):
-             os.makedirs(path)
-
-         for file in findfile.find_files(
-             self.config.name_or_path,
-             or_key=["bin", "json", "txt", "py"],
-             exclude_key=["pytorch_model.bin", "model.safetensors"],
-         ):
-             shutil.copyfile(file, f"{path}/{os.path.basename(file)}")
-
-         _device = self.model.device
-         _dtype = self.model.dtype
-         self.model.to(dtype).to("cpu")
-         self.tokenizer.save_pretrained(path)
-
-         # Save metadata including information about the loss function
-         metadata = self.metadata.copy()
-         if self.loss_fn is not None:
-             metadata["loss_fn_class"] = self.loss_fn.__class__.__name__
-             metadata["loss_fn_module"] = self.loss_fn.__class__.__module__
-
-         with open(f"{path}/metadata.json", "w", encoding="utf8") as f:
-             json.dump(metadata, f)
-         with open(f"{path}/tokenizer.bin", "wb") as f:
-             dill.dump(self.tokenizer, f)
-         self.model.save_pretrained(
-             f"{path}", safe_serialization=False
-         )  # do not remove this line, used to save customized model scripts
-
-         # Save complete state dict including all components
-         with open(f"{path}/pytorch_model.bin", "wb") as f:
-             torch.save(self.state_dict(), f)
-
-         self.model.to(_dtype).to(_device)
-         fprint(f"The model is saved to {path}.")
-
-
-     def load(self, path, **kwargs):
-         """
-         Loads the model, tokenizer, and metadata from a directory.
-
-         :param path: The directory to load the model from.
-         :param kwargs: Additional arguments.
-         :return: The loaded model instance.
-         """
-         with open(f"{path}/metadata.json", "r", encoding="utf8") as f:
-             metadata = json.load(f)
-
-         if metadata["model_cls"] != self.__class__.__name__:  # Check the model class
-             raise ValueError(
-                 f"The model class in the loaded model is {metadata['model_cls']}, "
-                 f"but the current model class is {self.__class__.__name__}."
-             )
-         config = AutoConfig.from_pretrained(path, trust_remote_code=True, **kwargs)
-
-         for key, value in config.__dict__.items():
-             if key not in self.config.__dict__ or self.config.__dict__[key] != value:
-                 fprint(
-                     f"Warning: The value of the key {key} in the loaded model is {value}, "
-                     f"but the current value is {self.config.__dict__.get(key, None)}."
-                 )
-
-         # Attempt to restore any saved loss function
-         if "loss_fn_class" in metadata and "loss_fn_module" in metadata:
-             try:
-                 loss_module = import_module(metadata["loss_fn_module"])
-                 loss_class = getattr(loss_module, metadata["loss_fn_class"])
-                 # Initialize loss function if possible (parameters will be loaded with state dict)
-                 self.loss_fn = loss_class()
-                 fprint(
-                     f"Restored loss function: {metadata['loss_fn_class']} from {metadata['loss_fn_module']}"
-                 )
-             except (ImportError, AttributeError) as e:
-                 warnings.warn(f"Could not restore loss function: {e}")
-
-         with open(f"{path}/pytorch_model.bin", "rb") as f:
-             loaded_state_dict = torch.load(f, map_location=kwargs.get("device", "cpu"))
-
-         # Check if keys match between current and loaded state dict
-         current_keys = set(self.state_dict().keys())
-         loaded_keys = set(loaded_state_dict.keys())
-         missing_keys = current_keys - loaded_keys
-         unexpected_keys = loaded_keys - current_keys
-
-         if missing_keys:
-             warnings.warn(f"Missing keys in loaded weights: {missing_keys}")
-         if unexpected_keys:
-             warnings.warn(f"Unexpected keys in loaded weights: {unexpected_keys}")
-
-         self.load_state_dict(loaded_state_dict, strict=False)
-         # Load the tokenizer
-         if os.path.exists(f"{path}/tokenizer.bin"):
-             with open(f"{path}/tokenizer.bin", "rb") as f:
-                 self.tokenizer = dill.load(f)
-
-         return self
-
-     def _forward_from_raw_input(self, sequence_or_inputs, **kwargs):
-         """
-         Tokenizes raw input and performs a forward pass in no_grad mode.
-
-         :param sequence_or_inputs: A sequence, list of sequences, or tokenized inputs.
-         :param kwargs: Additional arguments for tokenization.
-         :return: A dictionary containing the raw model outputs and the tokenized inputs.
-         """
-         if not isinstance(sequence_or_inputs, BatchEncoding) and not isinstance(
-             sequence_or_inputs, dict
-         ):
-             inputs = self.tokenizer(
-                 sequence_or_inputs,
-                 padding=kwargs.pop("padding", True),
-                 max_length=kwargs.pop("max_length", 1024),
-                 truncation=kwargs.pop("truncation", True),
-                 return_tensors=kwargs.pop("return_tensors", "pt"),
-                 **kwargs,
-             )
-         else:
-             inputs = sequence_or_inputs
-         inputs = inputs.to(self.model.device)
-         with torch.no_grad():
-             raw_outputs = self(**inputs)
-             raw_outputs["inputs"] = inputs
-         return raw_outputs
-
-     @staticmethod
-     def from_pretrained(model_name_or_path, tokenizer, *args, **kwargs):
-         """
-         Loads a pre-trained model and tokenizer.
-
-         :param model_name_or_path: The name or path of the pre-trained model.
-         :param tokenizer: The tokenizer to use.
-         :param args: Additional positional arguments.
-         :param kwargs: Additional keyword arguments.
-         :return: An instance of `OmniModel`.
-         """
-         config = kwargs.pop("config", None)
-         if config is None:
-             config = AutoConfig.from_pretrained(model_name_or_path, **kwargs)
-         base_model = AutoModel.from_pretrained(model_name_or_path, config=config, **kwargs)
-         if tokenizer is None:
-             tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
-         return OmniModel(base_model, tokenizer, *args, **kwargs)
-
-     def model_info(self):
-         """
-         Prints and returns detailed information about the model.
-
-         :return: A string containing the model information.
-         """
-         info = f"Model Name: {self.__class__.__name__}\n"
-         info += f"Model Metadata: {self.metadata}\n"
-         info += f"Base Model Name: {self.config.name_or_path}\n"
-         info += f"Model Type: {self.config.model_type}\n"
-         info += f"Model Architecture: {self.config.architectures}\n"
-         info += f"Model Parameters: {count_parameters(self.model) / 1e6} M\n"
-         info += f"Model Config: {self.config}\n"
-         fprint(info)
-         return info
-
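
The hunk above deletes the entire 689-line OmniModel base class. For context, the docstrings preserved in the diff imply usage along the lines sketched below under 0.3.0a1. This is a hypothetical reconstruction from those docstrings, not a verified example: OmniModelForSequenceClassification is named in the docstring examples, while OmniSingleNucleotideTokenizer and the yangheng/OmniGenome-52M checkpoint are assumptions, and none of it will run against 0.3.3a0, where these modules are gone.

# Hypothetical sketch of the 0.3.0a1 API removed by this diff, pieced together
# from the docstrings above. The tokenizer class and checkpoint name are
# assumptions rather than details taken from this diff.
from omnigenome import OmniModelForSequenceClassification, OmniSingleNucleotideTokenizer

tokenizer = OmniSingleNucleotideTokenizer.from_pretrained("yangheng/OmniGenome-52M")
model = OmniModelForSequenceClassification(
    "yangheng/OmniGenome-52M",  # config_or_model: a pre-trained model path
    tokenizer,
    num_labels=2,               # label2id is then derived as {"0": 0, "1": 1}
)

print(model.model_info())                            # architecture, parameter count, config
results = model.inference(["ATCGATCG", "GCTAGCTA"])  # processed predictions
model.save("./ckpt", overwrite=True)  # writes metadata.json, tokenizer.bin, pytorch_model.bin
model = model.load("./ckpt")          # restores state dict, tokenizer, and loss_fn class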