fusion-bench 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +4 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  12. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  14. fusion_bench/method/model_stock/__init__.py +1 -0
  15. fusion_bench/method/model_stock/model_stock.py +309 -0
  16. fusion_bench/method/regmean/clip_regmean.py +3 -6
  17. fusion_bench/method/regmean/regmean.py +27 -56
  18. fusion_bench/method/regmean/utils.py +56 -0
  19. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  20. fusion_bench/method/slerp/__init__.py +1 -1
  21. fusion_bench/method/slerp/slerp.py +110 -14
  22. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  23. fusion_bench/mixins/clip_classification.py +26 -6
  24. fusion_bench/mixins/serialization.py +25 -15
  25. fusion_bench/modelpool/base_pool.py +1 -1
  26. fusion_bench/modelpool/causal_lm/causal_lm.py +262 -43
  27. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  28. fusion_bench/models/hf_utils.py +9 -4
  29. fusion_bench/models/linearized/vision_model.py +6 -6
  30. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  31. fusion_bench/models/we_moe.py +8 -8
  32. fusion_bench/taskpool/base_pool.py +99 -17
  33. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  34. fusion_bench/taskpool/dummy.py +101 -13
  35. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  36. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  37. fusion_bench/utils/__init__.py +1 -0
  38. fusion_bench/utils/data.py +6 -4
  39. fusion_bench/utils/devices.py +7 -4
  40. fusion_bench/utils/dtype.py +3 -2
  41. fusion_bench/utils/lazy_state_dict.py +82 -19
  42. fusion_bench/utils/packages.py +3 -3
  43. fusion_bench/utils/parameters.py +0 -2
  44. fusion_bench/utils/timer.py +92 -10
  45. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -1
  46. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +53 -47
  47. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  48. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  49. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  50. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  51. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  52. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  53. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
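Among the larger changes in this release, `fusion_bench/method/slerp/slerp.py` grows by roughly a hundred lines and two new configs (`fusion_bench_config/method/slerp/slerp_lm.yaml`, `fusion_bench_config/_get_started/llm_slerp.yaml`) target language-model merging. For orientation only, the sketch below is a textbook spherical linear interpolation of two flat weight tensors; it is not the implementation shipped in this package.

```python
import torch


def slerp(a: torch.Tensor, b: torch.Tensor, t: float, eps: float = 1e-8) -> torch.Tensor:
    """Spherically interpolate between two tensors, treating them as flat vectors."""
    a_flat, b_flat = a.flatten().float(), b.flatten().float()
    # Angle between the two parameter vectors.
    cos_omega = torch.clamp(
        torch.dot(a_flat, b_flat) / (a_flat.norm() * b_flat.norm() + eps), -1.0, 1.0
    )
    omega = torch.acos(cos_omega)
    if omega.abs() < eps:
        # Nearly parallel vectors: fall back to plain linear interpolation.
        return (1 - t) * a + t * b
    sin_omega = torch.sin(omega)
    merged = (torch.sin((1 - t) * omega) / sin_omega) * a_flat + (
        torch.sin(t * omega) / sin_omega
    ) * b_flat
    return merged.reshape(a.shape).to(a.dtype)


merged_weight = slerp(torch.randn(4, 4), torch.randn(4, 4), t=0.5)
```

SLERP follows the great circle between the two parameter vectors, which keeps interpolated weights at a comparable norm even when the endpoints are far apart, unlike plain linear interpolation.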
fusion_bench/models/we_moe.py
@@ -1,6 +1,6 @@
 import functools
 import logging
-from typing import List
+from typing import Generic, List

 import torch
 import torch.func
@@ -9,7 +9,7 @@ from torch.func import functional_call
 from torch.nn import functional as F

 from fusion_bench.models.utils import del_attr, get_attr, set_attr
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType

 log = logging.getLogger(__name__)

@@ -76,15 +76,15 @@ def construct_weight_ensembling_gate(
     return gate


-class WeightEnsemblingMoE(nn.Module):
+class WeightEnsemblingMoE(nn.Module, Generic[TorchModelType]):
     # variable to store the merged state dict temporarily
     _merged_state_dict: StateDictType = None

     def __init__(
         self,
         hidden_size: int,
-        base_model: nn.Module,
-        expert_models: List[nn.Module],
+        base_model: TorchModelType,
+        expert_models: List[TorchModelType],
         init_lambda: float = 0.2,
         batch_first: bool = False,
         router_hidden_layers: int = 2,
@@ -101,8 +101,8 @@ class WeightEnsemblingMoE(nn.Module):
         Args:

             hidden_size (int): The size of the hidden layer in the models.
-            base_model (nn.Module): The base model that will be used as a reference for the expert models.
-            expert_models (List[nn.Module]): A list of expert models that will be combined.
+            base_model (TorchModelType): The base model that will be used as a reference for the expert models.
+            expert_models (List[TorchModelType]): A list of expert models that will be combined.
             init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
             batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
             router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
@@ -145,7 +145,7 @@ class WeightEnsemblingMoE(nn.Module):
             self._merged_state_dict,
         )

-    def merge_weights(self, expert_weights):
+    def merge_weights(self, expert_weights) -> StateDictType:
         state_dict = self.base_model.state_dict(keep_vars=True)
         for weight, task_vector in zip(expert_weights, self.task_vectors):
             for name, param in task_vector.named_parameters():
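The `Generic[TorchModelType]` change above is a typing-only refinement: `TorchModelType` is a `TypeVar` bound to `nn.Module`, so static checkers can carry the concrete model class through the wrapper instead of erasing it to bare `nn.Module`. A minimal standalone sketch of the pattern; the `ModelWrapper` class here is hypothetical and not part of fusion_bench:

```python
from typing import Generic, List, TypeVar

import torch.nn as nn

# Mirrors fusion_bench.utils.type.TorchModelType: a TypeVar bound to nn.Module.
TorchModelType = TypeVar("TorchModelType", bound=nn.Module)


class ModelWrapper(nn.Module, Generic[TorchModelType]):
    """Hypothetical wrapper that keeps the concrete model type visible to type checkers."""

    def __init__(self, base_model: TorchModelType, expert_models: List[TorchModelType]):
        super().__init__()
        self.base_model = base_model
        self.expert_models = nn.ModuleList(expert_models)


wrapper = ModelWrapper(nn.Linear(4, 4), [nn.Linear(4, 4)])
base = wrapper.base_model  # a type checker sees nn.Linear here, not bare nn.Module
```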
fusion_bench/taskpool/base_pool.py
@@ -5,33 +5,115 @@ from fusion_bench.mixins import BaseYAMLSerializable


 class BaseTaskPool(BaseYAMLSerializable):
+    """Abstract base class for task pools in the FusionBench framework.
+
+    A task pool represents a collection of evaluation tasks that can be used to
+    assess model performance across multiple benchmarks or datasets. This base
+    class defines the common interface that all task pool implementations must
+    follow, ensuring consistency across different task types and evaluation
+    scenarios.
+
+    Task pools are designed to be configurable through YAML files and can be
+    used in various model fusion and evaluation workflows. They provide a
+    standardized way to evaluate models on multiple tasks and aggregate results.
+
+    The class inherits from BaseYAMLSerializable to support configuration
+    management and serialization capabilities.
+
+    Attributes:
+        _program: Optional program reference for execution context.
+        _config_key: Configuration key used for YAML configuration ("taskpool").
+
+    Abstract Methods:
+        evaluate: Must be implemented by subclasses to define task-specific
+            evaluation logic.
+
+    Example:
+        Implementing a custom task pool:
+
+        ```python
+        class MyTaskPool(BaseTaskPool):
+
+
+            def evaluate(self, model, **kwargs):
+                results = {}
+                for task_name in self.tasks:
+                    # Implement task-specific evaluation
+                    results[task_name] = self._evaluate_task(model, task_name)
+                return results
+        ```
+    """
+
     _program = None
     _config_key = "taskpool"

     @abstractmethod
     def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
-        """
-        Evaluate the model on all tasks in the task pool, and return a report.
+        """Evaluate a model on all tasks in the task pool and return aggregated results.

-        Take image classification as an example, the report will look like:
+        This abstract method defines the core evaluation interface that all task pool
+        implementations must provide. It should evaluate the given model on all tasks
+        managed by the pool and return a structured report of the results.

-        ```python
-        {
-            "mnist": {
-                "accuracy": 0.8,
-                "loss": 0.2,
-            },
-            <task_name>: {
-                <metric_name>: <metric_value>,
-                ...
-            },
-        }
-        ```
+        The evaluation process typically involves:
+        1. Iterating through all tasks in the pool
+        2. Running model inference on each task's dataset
+        3. Computing task-specific metrics
+        4. Aggregating results into a standardized report format

         Args:
-            model: The model to evaluate.
+            model: The model to evaluate. Can be any model type (PyTorch model,
+                Hugging Face model, etc.) that is compatible with the specific
+                task pool implementation.
+            *args: Additional positional arguments that may be needed for
+                task-specific evaluation procedures.
+            **kwargs: Additional keyword arguments for evaluation configuration,
+                such as batch_size, device, evaluation metrics, etc.

         Returns:
-            report (dict): A dictionary containing the results of the evaluation for each task.
+            Dict[str, Any]: A dictionary containing evaluation results for each task.
+            The structure follows the pattern:
+
+            ```python
+            {
+                "task_name_1": {
+                    "metric_1": value,
+                    "metric_2": value,
+                    ...
+                },
+                "task_name_2": {
+                    "metric_1": value,
+                    "metric_2": value,
+                    ...
+                },
+                ...
+            }
+            ```
+
+        Example:
+            For an image classification task pool:
+
+            ```python
+            results = task_pool.evaluate(model)
+            # Returns:
+            # {
+            #     "mnist": {
+            #         "accuracy": 0.95,
+            #         "loss": 0.15,
+            #     },
+            #     "cifar10": {
+            #         "accuracy": 0.87,
+            #         "loss": 0.42,
+            #     }
+            # }
+            ```
+
+        Raises:
+            NotImplementedError: This method must be implemented by subclasses.
+
+        Note:
+            Implementations should ensure that the returned dictionary structure
+            is consistent and that metric names are standardized across similar
+            task types to enable meaningful comparison and aggregation.
         """
         pass
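The `MyTaskPool` snippet in the new docstring above is schematic (it assumes `self.tasks` and `self._evaluate_task` already exist). A self-contained variant, assuming `BaseTaskPool` is importable from `fusion_bench.taskpool.base_pool` and that its constructor accepts no required arguments (as the original `DummyTaskPool`'s bare `super().__init__()` call suggests), might look like:

```python
from typing import Any, Dict, Iterable

from fusion_bench.taskpool.base_pool import BaseTaskPool


class ConstantTaskPool(BaseTaskPool):
    """Toy task pool that reports a fixed metric per task (illustration only)."""

    def __init__(self, tasks: Iterable[str]):
        super().__init__()
        self.tasks = list(tasks)

    def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # A real task pool would run inference here; we just emit placeholders.
        return {name: {"accuracy": 0.0} for name in self.tasks}


pool = ConstantTaskPool(tasks=["mnist", "cifar10"])
print(pool.evaluate(model=None))
```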
fusion_bench/taskpool/clip_vision/taskpool.py
@@ -309,7 +309,7 @@ class CLIPVisionModelTaskPool(
         self.setup()

         report = {}
-        # CLIPVisionModel works the same with CLIPVisonTransformer, so we can use it directly
+        # CLIPVisionModel works the same with CLIPVisionTransformer, so we can use it directly
         if hasattr(model, "is_surgery_model") and model.is_surgery_model:
             log.info("running evaluation on a surgery model.")
             model: "SurgeryModelWrapper" = model
fusion_bench/taskpool/dummy.py
@@ -1,5 +1,10 @@
 """
-This is the dummy task pool that is used for debugging purposes.
+Dummy task pool implementation for debugging and testing purposes.
+
+This module provides a minimal task pool implementation that can be used for
+debugging model fusion workflows, testing infrastructure, and validating model
+architectures without running expensive evaluation procedures. It's particularly
+useful during development and prototyping phases.
 """

 from typing import Optional
@@ -14,14 +19,41 @@ from fusion_bench.utils.parameters import count_parameters, print_parameters


 def get_model_summary(model: nn.Module) -> dict:
-    """
-    Generate a report for the given model.
+    """Generate a comprehensive summary report for a PyTorch model.
+
+    Analyzes the given model to extract key information about its architecture,
+    parameter count, and training characteristics. This function is useful for
+    model introspection and comparative analysis during model fusion workflows.
+
+    The summary includes both trainable and total parameter counts, which helps
+    in understanding model complexity and memory requirements. The trainable
+    percentage is particularly useful for identifying models with frozen layers
+    or parameter-efficient fine-tuning setups.

     Args:
-        model: The model to generate the report for.
+        model: The PyTorch model to analyze. Can be any nn.Module instance
+            including complex models, fusion models, or pre-trained models.

     Returns:
-        dict: The generated report.
+        dict: A structured report containing model information:
+            - model_info: Dictionary with parameter statistics
+                - trainable_params: Number of trainable parameters
+                - all_params: Total number of parameters (trainable + frozen)
+                - trainable_percentage: Ratio of trainable to total parameters
+
+    Example:
+        ```python
+        >>> model = MyModel()
+        >>> summary = get_model_summary(model)
+        >>> print(summary)
+        {
+            "model_info": {
+                "trainable_params": 1234567,
+                "all_params": 1234567,
+                "trainable_percentage": 1.0
+            }
+        }
+        ```
     """
     report = {}
     training_params, all_params = count_parameters(model)
@@ -34,21 +66,77 @@ def get_model_summary(model: nn.Module) -> dict:


 class DummyTaskPool(BaseTaskPool):
+    """A lightweight task pool implementation for debugging and development workflows.
+
+    This dummy task pool provides a minimal evaluation interface that focuses on
+    model introspection rather than task-specific performance evaluation. It's
+    designed for development scenarios where you need to test model fusion
+    pipelines, validate architectures, or debug workflows without the overhead
+    of running actual evaluation tasks.
+
+    The task pool is particularly useful when:
+    - You want to verify model fusion works correctly
+    - You need to check parameter counts after fusion
+    - You're developing new fusion algorithms
+    - You want to test infrastructure without expensive evaluations
+
+    Example:
+        ```python
+        >>> taskpool = DummyTaskPool(model_save_path="/tmp/fused_model")
+        >>> results = taskpool.evaluate(fused_model)
+        >>> print(f"Model has {results['model_info']['trainable_params']} parameters")
+        ```
     """
-    This is a dummy task pool used for debugging purposes. It inherits from the base TaskPool class.
-    """

-    def __init__(self, model_save_path: Optional[str] = None):
-        super().__init__()
+    def __init__(self, model_save_path: Optional[str] = None, **kwargs):
+        """Initialize the dummy task pool with optional model saving capability.
+
+        Args:
+            model_save_path: Optional path where the evaluated model should be saved.
+                If provided, the model will be serialized and saved to this location
+                after evaluation using the separate_save utility. If None, no model
+                saving will be performed.
+
+        Example:
+            ```python
+            >>> # Create taskpool without saving
+            >>> taskpool = DummyTaskPool()
+
+            >>> # Create taskpool with model saving
+            >>> taskpool = DummyTaskPool(model_save_path="/path/to/save/model.pth")
+            ```
+        """
+        super().__init__(**kwargs)
         self.model_save_path = model_save_path

     def evaluate(self, model):
-        """
-        Evaluate the given model.
-        This method does nothing but print the parameters of the model in a human-readable format.
+        """Perform lightweight evaluation and analysis of the given model.
+
+        This method provides a minimal evaluation that focuses on model introspection
+        rather than task-specific performance metrics. It performs parameter analysis,
+        optionally saves the model, and returns a summary report.
+
+        The evaluation process includes:
+        1. Printing human-readable parameter information (rank-zero only)
+        2. Optionally saving the model if a save path was configured
+        3. Generating and returning a model summary report

         Args:
-            model: The model to evaluate.
+            model: The model to evaluate. Can be any PyTorch nn.Module including
+                fusion models, pre-trained models, or custom architectures.
+
+        Returns:
+            dict: A model summary report containing parameter statistics and
+                architecture information. See get_model_summary() for detailed
+                format specification.
+
+        Example:
+            ```python
+            >>> taskpool = DummyTaskPool(model_save_path="/tmp/model.pth")
+            >>> model = torch.nn.Linear(10, 5)
+            >>> results = taskpool.evaluate(model)
+            >>> print(f"Trainable params: {results['model_info']['trainable_params']}")
+            ```
         """
         if rank_zero_only.rank == 0:
             print_parameters(model, is_human_readable=True)
fusion_bench/taskpool/lm_eval_harness/taskpool.py
@@ -16,6 +16,47 @@ log = logging.getLogger(__name__)


 class LMEvalHarnessTaskPool(BaseTaskPool, LightningFabricMixin):
+    """A task pool implementation that interfaces with the LM Evaluation Harness framework.
+
+    This class provides a wrapper around the LM Evaluation Harness (lm-eval) library,
+    enabling evaluation of language models on various standardized benchmarks and tasks.
+    It inherits from BaseTaskPool and LightningFabricMixin to provide distributed
+    computing capabilities through PyTorch Lightning Fabric.
+
+    The task pool supports evaluation on multiple tasks simultaneously and provides
+    flexible configuration options for batch processing, output formatting, and
+    logging. It automatically handles model setup and wrapping for distributed
+    evaluation when using Lightning Fabric.
+
+    Args:
+        tasks: A single task name or list of task names to evaluate on.
+            Examples: "hellaswag", ["arc_easy", "arc_challenge", "hellaswag"]
+        apply_chat_template: Whether to apply chat template formatting to inputs.
+            Useful for instruction-tuned or chat models.
+        include_path: Path to additional task definitions or custom tasks.
+        batch_size: Number of samples to process in each batch. Larger values
+            may improve throughput but require more memory.
+        metadata: Additional metadata to include in evaluation results.
+        verbosity: Logging verbosity level for the evaluation process.
+        output_path: Custom path for saving evaluation results. If None,
+            results are saved to the default log directory.
+        log_samples: Whether to log individual sample predictions and targets.
+            Useful for debugging but increases output size significantly.
+        _usage_: Internal usage tracking string.
+        _version_: Internal version tracking string.
+        **kwargs: Additional arguments passed to the LM Evaluation Harness.
+
+    Example:
+        ```python
+        >>> taskpool = LMEvalHarnessTaskPool(
+        ...     tasks=["arc_easy", "hellaswag"],
+        ...     batch_size=8,
+        ...     verbosity="INFO"
+        ... )
+        >>> results = taskpool.evaluate(model)
+        ```
+    """
+
     def __init__(
         self,
         tasks: Union[str, List[str]],
@@ -44,6 +85,45 @@ class LMEvalHarnessTaskPool(BaseTaskPool, LightningFabricMixin):
         self.log_samples = log_samples

     def evaluate(self, model, *command_line_args, **kwargs):
+        """Evaluate a language model on the configured tasks using LM Evaluation Harness.
+
+        This method wraps the model with the LM Evaluation Harness framework and
+        executes evaluation on all configured tasks. It automatically handles
+        command-line argument construction, model wrapping with Lightning Fabric
+        for distributed evaluation, and result logging.
+
+        The evaluation process includes:
+        1. Building command-line arguments from instance configuration
+        2. Setting up the LM Evaluation Harness argument parser
+        3. Wrapping the model with Lightning Fabric if not already wrapped
+        4. Creating an HFLM (Hugging Face Language Model) wrapper
+        5. Executing the evaluation through the LM-Eval CLI interface
+
+        Args:
+            model: The language model to evaluate. Can be a Hugging Face model,
+                PyTorch model, or any model compatible with the LM Evaluation Harness.
+                The model will be automatically wrapped with Lightning Fabric for
+                distributed evaluation if not already wrapped.
+            *command_line_args: Additional positional command-line arguments
+                (currently unused but preserved for interface compatibility).
+            **kwargs: Additional keyword arguments that will be converted to
+                command-line flags and passed to the LM Evaluation Harness.
+                Keys will be prefixed with '--' and values converted to strings.
+
+        Returns:
+            None: Results are written to the configured output path and logged.
+
+        Example:
+            ```python
+            >>> taskpool = LMEvalHarnessTaskPool(tasks=["arc_easy"])
+            >>> taskpool.evaluate(model, limit=100, device="cuda")
+            ```
+
+        Note:
+            The method leverages the LM Evaluation Harness's command-line interface
+            internally, which provides standardized evaluation procedures and
+            ensures compatibility with the broader evaluation ecosystem.
+        """
         command_line_args = []
         if self.include_path is not None:
             command_line_args.extend(["--include_path", self.include_path])
fusion_bench/taskpool/nyuv2_taskpool.py
@@ -15,9 +15,37 @@ log = logging.getLogger(__name__)


 class NYUv2TaskPool(TaskPool):
+    """Task pool for multi-task learning evaluation on the NYUv2 dataset.
+
+    This task pool provides evaluation capabilities for multi-task learning models
+    on the NYU Depth V2 (NYUv2) dataset, which is a popular benchmark for indoor
+    scene understanding. The dataset supports multiple computer vision tasks
+    including semantic segmentation, depth estimation, and surface normal prediction.
+
+    The task pool is designed to work with encoder-decoder architectures where
+    a shared encoder processes input images and task-specific decoders generate
+    predictions for different tasks. It integrates with PyTorch Lightning for
+    streamlined training and evaluation workflows.
+
+    Supported Tasks:
+        - Semantic segmentation
+        - Depth estimation
+        - Surface normal prediction
+    """
+
     _trainer: L.Trainer = None

     def __init__(self, taskpool_config: DictConfig):
+        """Initialize the NYUv2 task pool with configuration settings.
+
+        Args:
+            taskpool_config: Configuration object containing all necessary
+                parameters for the task pool, including:
+                - data_dir: Path to the directory containing NYUv2 dataset
+                - tasks: List of tasks to evaluate (e.g., ["semantic", "depth"])
+                - batch_size: Batch size for evaluation data loader
+                - num_workers: Number of worker processes for data loading
+        """
         self.config = taskpool_config

     def load_datasets(self):
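Based on the configuration fields listed in the new docstring above, constructing the task pool might look like the following sketch; the keys here mirror the documented fields and are otherwise hypothetical, so the actual YAML shipped in `fusion_bench_config` may differ.

```python
from omegaconf import DictConfig, OmegaConf

from fusion_bench.taskpool.nyuv2_taskpool import NYUv2TaskPool

# Hypothetical configuration mirroring the documented fields.
cfg: DictConfig = OmegaConf.create(
    {
        "data_dir": "/data/nyuv2",
        "tasks": ["semantic", "depth"],
        "batch_size": 8,
        "num_workers": 4,
    }
)
taskpool = NYUv2TaskPool(cfg)
```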
fusion_bench/utils/__init__.py
@@ -20,3 +20,4 @@ from .packages import import_object
 from .parameters import *
 from .pylogger import get_rankzero_logger
 from .timer import timeit_context
+from .type import BoolStateDictType, StateDictType, TorchModelType
fusion_bench/utils/data.py
@@ -1,6 +1,6 @@
 import pickle
 from pathlib import Path
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -37,7 +37,9 @@ class InfiniteDataLoader:
         return data


-def load_tensor_from_file(file_path: Union[str, Path], device=None) -> torch.Tensor:
+def load_tensor_from_file(
+    file_path: Union[str, Path], device: Optional[Union[str, torch.device]] = None
+) -> torch.Tensor:
     """
     Loads a tensor from a file, which can be either a .pt, .pth or .np file.
     If the file is not one of these formats, it will try to load it as a pickle file.
@@ -72,7 +74,7 @@ def train_validation_split(
     validation_size: Optional[int] = None,
     random_seed: Optional[int] = None,
     return_split: Literal["all", "train", "val"] = "both",
-):
+) -> Union[Tuple[Dataset, Dataset], Dataset]:
     """
     Split a dataset into a training and validation set.

@@ -134,7 +136,7 @@ def train_validation_test_split(
     test_fraction: float,
     random_seed: Optional[int] = None,
     return_spilt: Literal["all", "train", "val", "test"] = "all",
-):
+) -> Union[Tuple[Dataset, Dataset, Dataset], Dataset]:
     """
     Split a dataset into a training, validation and test set.

fusion_bench/utils/devices.py
@@ -1,7 +1,7 @@
 import gc
 import logging
 import os
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import torch
 from transformers.utils import (
@@ -12,6 +12,8 @@ from transformers.utils import (
     is_torch_xpu_available,
 )

+from .type import T
+
 __all__ = [
     "clear_cuda_cache",
     "to_device",
@@ -37,7 +39,7 @@ def clear_cuda_cache():
         log.warning("CUDA is not available. No cache to clear.")


-def to_device(obj, device: Optional[torch.device], **kwargs):
+def to_device(obj: T, device: Optional[torch.device], **kwargs: Any) -> T:
     """
     Move a given object to the specified device.

@@ -102,7 +104,7 @@ def num_devices(devices: Union[int, List[int], str]) -> int:
     )


-def get_device(obj) -> torch.device:
+def get_device(obj: Any) -> torch.device:
     """
     Get the device of a given object.

@@ -151,6 +153,7 @@ def get_current_device() -> torch.device:
        If not set, it defaults to "0".

     Example:
+
         >>> device = get_current_device()
         >>> print(device)
         xpu:0 # or npu:0, mps:0, cuda:0, cpu depending on availability
@@ -241,7 +244,7 @@ def cleanup_cuda():
         torch.cuda.reset_peak_memory_stats()


-def print_memory_usage(print_fn=print):
+def print_memory_usage(print_fn=print) -> str:
     """
     Print the current GPU memory usage.

fusion_bench/utils/dtype.py
@@ -1,5 +1,5 @@
 import contextlib
-from typing import Dict, Generator, Iterable, Optional, Tuple
+from typing import Dict, Generator, Iterable, Optional, Tuple, Union

 import torch
 from transformers.utils import (
@@ -25,7 +25,7 @@ PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
 }


-def parse_dtype(dtype: Optional[str]):
+def parse_dtype(dtype: Optional[str]) -> Optional[torch.dtype]:
     """
     Parses a string representation of a data type and returns the corresponding torch.dtype.

@@ -92,6 +92,7 @@ def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
         ContextManager: context manager for setting default dtype.

     Example:
+
         >>> with set_default_dtype(torch.bfloat16):
         >>> x = torch.tensor([1, 2, 3])
         >>> x.dtype
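The remaining `fusion_bench/utils` changes (`to_device`, `get_device`, `parse_dtype`, `print_memory_usage`) are annotation-only refinements with no runtime effect. A hedged usage sketch of the `to_device`/`get_device` pair, assuming both tensors and modules are accepted as the docstrings suggest:

```python
import torch
import torch.nn as nn

from fusion_bench.utils.devices import get_device, to_device

model = nn.Linear(2, 2)
# to_device is now annotated as (obj: T, ...) -> T, so `moved` keeps the
# nn.Linear type for static checkers; behaviour at runtime is unchanged.
moved = to_device(model, torch.device("cpu"))
# get_device is assumed to report the device of the module's parameters.
print(get_device(moved))  # cpu
```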