fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/clip_classification.py +26 -6
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/serialization.py +40 -83
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +10 -4
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +36 -11
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/lazy_state_dict.py +85 -19
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/rich_utils.py +7 -3
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/dummy.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
|
1
1
|
"""
|
|
2
|
-
|
|
2
|
+
Dummy task pool implementation for debugging and testing purposes.
|
|
3
|
+
|
|
4
|
+
This module provides a minimal task pool implementation that can be used for
|
|
5
|
+
debugging model fusion workflows, testing infrastructure, and validating model
|
|
6
|
+
architectures without running expensive evaluation procedures. It's particularly
|
|
7
|
+
useful during development and prototyping phases.
|
|
3
8
|
"""
|
|
4
9
|
|
|
5
10
|
from typing import Optional
|
|
@@ -14,14 +19,41 @@ from fusion_bench.utils.parameters import count_parameters, print_parameters
|
|
|
14
19
|
|
|
15
20
|
|
|
16
21
|
def get_model_summary(model: nn.Module) -> dict:
|
|
17
|
-
"""
|
|
18
|
-
|
|
22
|
+
"""Generate a comprehensive summary report for a PyTorch model.
|
|
23
|
+
|
|
24
|
+
Analyzes the given model to extract key information about its architecture,
|
|
25
|
+
parameter count, and training characteristics. This function is useful for
|
|
26
|
+
model introspection and comparative analysis during model fusion workflows.
|
|
27
|
+
|
|
28
|
+
The summary includes both trainable and total parameter counts, which helps
|
|
29
|
+
in understanding model complexity and memory requirements. The trainable
|
|
30
|
+
percentage is particularly useful for identifying models with frozen layers
|
|
31
|
+
or parameter-efficient fine-tuning setups.
|
|
19
32
|
|
|
20
33
|
Args:
|
|
21
|
-
model: The model to
|
|
34
|
+
model: The PyTorch model to analyze. Can be any nn.Module instance
|
|
35
|
+
including complex models, fusion models, or pre-trained models.
|
|
22
36
|
|
|
23
37
|
Returns:
|
|
24
|
-
dict:
|
|
38
|
+
dict: A structured report containing model information:
|
|
39
|
+
- model_info: Dictionary with parameter statistics
|
|
40
|
+
- trainable_params: Number of trainable parameters
|
|
41
|
+
- all_params: Total number of parameters (trainable + frozen)
|
|
42
|
+
- trainable_percentage: Ratio of trainable to total parameters
|
|
43
|
+
|
|
44
|
+
Example:
|
|
45
|
+
```python
|
|
46
|
+
>>> model = MyModel()
|
|
47
|
+
>>> summary = get_model_summary(model)
|
|
48
|
+
>>> print(summary)
|
|
49
|
+
{
|
|
50
|
+
"model_info": {
|
|
51
|
+
"trainable_params": 1234567,
|
|
52
|
+
"all_params": 1234567,
|
|
53
|
+
"trainable_percentage": 1.0
|
|
54
|
+
}
|
|
55
|
+
}
|
|
56
|
+
```
|
|
25
57
|
"""
|
|
26
58
|
report = {}
|
|
27
59
|
training_params, all_params = count_parameters(model)
|
|
@@ -34,21 +66,77 @@ def get_model_summary(model: nn.Module) -> dict:
|
|
|
34
66
|
|
|
35
67
|
|
|
36
68
|
class DummyTaskPool(BaseTaskPool):
|
|
69
|
+
"""A lightweight task pool implementation for debugging and development workflows.
|
|
70
|
+
|
|
71
|
+
This dummy task pool provides a minimal evaluation interface that focuses on
|
|
72
|
+
model introspection rather than task-specific performance evaluation. It's
|
|
73
|
+
designed for development scenarios where you need to test model fusion
|
|
74
|
+
pipelines, validate architectures, or debug workflows without the overhead
|
|
75
|
+
of running actual evaluation tasks.
|
|
76
|
+
|
|
77
|
+
The task pool is particularly useful when:
|
|
78
|
+
- You want to verify model fusion works correctly
|
|
79
|
+
- You need to check parameter counts after fusion
|
|
80
|
+
- You're developing new fusion algorithms
|
|
81
|
+
- You want to test infrastructure without expensive evaluations
|
|
82
|
+
|
|
83
|
+
Example:
|
|
84
|
+
```python
|
|
85
|
+
>>> taskpool = DummyTaskPool(model_save_path="/tmp/fused_model")
|
|
86
|
+
>>> results = taskpool.evaluate(fused_model)
|
|
87
|
+
>>> print(f"Model has {results['model_info']['trainable_params']} parameters")
|
|
88
|
+
```
|
|
37
89
|
"""
|
|
38
|
-
This is a dummy task pool used for debugging purposes. It inherits from the base TaskPool class.
|
|
39
|
-
"""
|
|
40
90
|
|
|
41
|
-
def __init__(self, model_save_path: Optional[str] = None):
|
|
42
|
-
|
|
91
|
+
def __init__(self, model_save_path: Optional[str] = None, **kwargs):
|
|
92
|
+
"""Initialize the dummy task pool with optional model saving capability.
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
model_save_path: Optional path where the evaluated model should be saved.
|
|
96
|
+
If provided, the model will be serialized and saved to this location
|
|
97
|
+
after evaluation using the separate_save utility. If None, no model
|
|
98
|
+
saving will be performed.
|
|
99
|
+
|
|
100
|
+
Example:
|
|
101
|
+
```python
|
|
102
|
+
>>> # Create taskpool without saving
|
|
103
|
+
>>> taskpool = DummyTaskPool()
|
|
104
|
+
|
|
105
|
+
>>> # Create taskpool with model saving
|
|
106
|
+
>>> taskpool = DummyTaskPool(model_save_path="/path/to/save/model.pth")
|
|
107
|
+
```
|
|
108
|
+
"""
|
|
109
|
+
super().__init__(**kwargs)
|
|
43
110
|
self.model_save_path = model_save_path
|
|
44
111
|
|
|
45
112
|
def evaluate(self, model):
|
|
46
|
-
"""
|
|
47
|
-
|
|
48
|
-
This method
|
|
113
|
+
"""Perform lightweight evaluation and analysis of the given model.
|
|
114
|
+
|
|
115
|
+
This method provides a minimal evaluation that focuses on model introspection
|
|
116
|
+
rather than task-specific performance metrics. It performs parameter analysis,
|
|
117
|
+
optionally saves the model, and returns a summary report.
|
|
118
|
+
|
|
119
|
+
The evaluation process includes:
|
|
120
|
+
1. Printing human-readable parameter information (rank-zero only)
|
|
121
|
+
2. Optionally saving the model if a save path was configured
|
|
122
|
+
3. Generating and returning a model summary report
|
|
49
123
|
|
|
50
124
|
Args:
|
|
51
|
-
model: The model to evaluate.
|
|
125
|
+
model: The model to evaluate. Can be any PyTorch nn.Module including
|
|
126
|
+
fusion models, pre-trained models, or custom architectures.
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
dict: A model summary report containing parameter statistics and
|
|
130
|
+
architecture information. See get_model_summary() for detailed
|
|
131
|
+
format specification.
|
|
132
|
+
|
|
133
|
+
Example:
|
|
134
|
+
```python
|
|
135
|
+
>>> taskpool = DummyTaskPool(model_save_path="/tmp/model.pth")
|
|
136
|
+
>>> model = torch.nn.Linear(10, 5)
|
|
137
|
+
>>> results = taskpool.evaluate(model)
|
|
138
|
+
>>> print(f"Trainable params: {results['model_info']['trainable_params']}")
|
|
139
|
+
```
|
|
52
140
|
"""
|
|
53
141
|
if rank_zero_only.rank == 0:
|
|
54
142
|
print_parameters(model, is_human_readable=True)
|
|
@@ -16,6 +16,47 @@ log = logging.getLogger(__name__)
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class LMEvalHarnessTaskPool(BaseTaskPool, LightningFabricMixin):
|
|
19
|
+
"""A task pool implementation that interfaces with the LM Evaluation Harness framework.
|
|
20
|
+
|
|
21
|
+
This class provides a wrapper around the LM Evaluation Harness (lm-eval) library,
|
|
22
|
+
enabling evaluation of language models on various standardized benchmarks and tasks.
|
|
23
|
+
It inherits from BaseTaskPool and LightningFabricMixin to provide distributed
|
|
24
|
+
computing capabilities through PyTorch Lightning Fabric.
|
|
25
|
+
|
|
26
|
+
The task pool supports evaluation on multiple tasks simultaneously and provides
|
|
27
|
+
flexible configuration options for batch processing, output formatting, and
|
|
28
|
+
logging. It automatically handles model setup and wrapping for distributed
|
|
29
|
+
evaluation when using Lightning Fabric.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
tasks: A single task name or list of task names to evaluate on.
|
|
33
|
+
Examples: "hellaswag", ["arc_easy", "arc_challenge", "hellaswag"]
|
|
34
|
+
apply_chat_template: Whether to apply chat template formatting to inputs.
|
|
35
|
+
Useful for instruction-tuned or chat models.
|
|
36
|
+
include_path: Path to additional task definitions or custom tasks.
|
|
37
|
+
batch_size: Number of samples to process in each batch. Larger values
|
|
38
|
+
may improve throughput but require more memory.
|
|
39
|
+
metadata: Additional metadata to include in evaluation results.
|
|
40
|
+
verbosity: Logging verbosity level for the evaluation process.
|
|
41
|
+
output_path: Custom path for saving evaluation results. If None,
|
|
42
|
+
results are saved to the default log directory.
|
|
43
|
+
log_samples: Whether to log individual sample predictions and targets.
|
|
44
|
+
Useful for debugging but increases output size significantly.
|
|
45
|
+
_usage_: Internal usage tracking string.
|
|
46
|
+
_version_: Internal version tracking string.
|
|
47
|
+
**kwargs: Additional arguments passed to the LM Evaluation Harness.
|
|
48
|
+
|
|
49
|
+
Example:
|
|
50
|
+
```python
|
|
51
|
+
>>> taskpool = LMEvalHarnessTaskPool(
|
|
52
|
+
... tasks=["arc_easy", "hellaswag"],
|
|
53
|
+
... batch_size=8,
|
|
54
|
+
... verbosity="INFO"
|
|
55
|
+
... )
|
|
56
|
+
>>> results = taskpool.evaluate(model)
|
|
57
|
+
```
|
|
58
|
+
"""
|
|
59
|
+
|
|
19
60
|
def __init__(
|
|
20
61
|
self,
|
|
21
62
|
tasks: Union[str, List[str]],
|
|
@@ -44,6 +85,45 @@ class LMEvalHarnessTaskPool(BaseTaskPool, LightningFabricMixin):
|
|
|
44
85
|
self.log_samples = log_samples
|
|
45
86
|
|
|
46
87
|
def evaluate(self, model, *command_line_args, **kwargs):
|
|
88
|
+
"""Evaluate a language model on the configured tasks using LM Evaluation Harness.
|
|
89
|
+
|
|
90
|
+
This method wraps the model with the LM Evaluation Harness framework and
|
|
91
|
+
executes evaluation on all configured tasks. It automatically handles
|
|
92
|
+
command-line argument construction, model wrapping with Lightning Fabric
|
|
93
|
+
for distributed evaluation, and result logging.
|
|
94
|
+
|
|
95
|
+
The evaluation process includes:
|
|
96
|
+
1. Building command-line arguments from instance configuration
|
|
97
|
+
2. Setting up the LM Evaluation Harness argument parser
|
|
98
|
+
3. Wrapping the model with Lightning Fabric if not already wrapped
|
|
99
|
+
4. Creating an HFLM (Hugging Face Language Model) wrapper
|
|
100
|
+
5. Executing the evaluation through the LM-Eval CLI interface
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
model: The language model to evaluate. Can be a Hugging Face model,
|
|
104
|
+
PyTorch model, or any model compatible with the LM Evaluation Harness.
|
|
105
|
+
The model will be automatically wrapped with Lightning Fabric for
|
|
106
|
+
distributed evaluation if not already wrapped.
|
|
107
|
+
*command_line_args: Additional positional command-line arguments
|
|
108
|
+
(currently unused but preserved for interface compatibility).
|
|
109
|
+
**kwargs: Additional keyword arguments that will be converted to
|
|
110
|
+
command-line flags and passed to the LM Evaluation Harness.
|
|
111
|
+
Keys will be prefixed with '--' and values converted to strings.
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
None: Results are written to the configured output path and logged.
|
|
115
|
+
|
|
116
|
+
Example:
|
|
117
|
+
```python
|
|
118
|
+
>>> taskpool = LMEvalHarnessTaskPool(tasks=["arc_easy"])
|
|
119
|
+
>>> taskpool.evaluate(model, limit=100, device="cuda")
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
Note:
|
|
123
|
+
The method leverages the LM Evaluation Harness's command-line interface
|
|
124
|
+
internally, which provides standardized evaluation procedures and
|
|
125
|
+
ensures compatibility with the broader evaluation ecosystem.
|
|
126
|
+
"""
|
|
47
127
|
command_line_args = []
|
|
48
128
|
if self.include_path is not None:
|
|
49
129
|
command_line_args.extend(["--include_path", self.include_path])
|
|
@@ -15,9 +15,37 @@ log = logging.getLogger(__name__)
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
class NYUv2TaskPool(TaskPool):
|
|
18
|
+
"""Task pool for multi-task learning evaluation on the NYUv2 dataset.
|
|
19
|
+
|
|
20
|
+
This task pool provides evaluation capabilities for multi-task learning models
|
|
21
|
+
on the NYU Depth V2 (NYUv2) dataset, which is a popular benchmark for indoor
|
|
22
|
+
scene understanding. The dataset supports multiple computer vision tasks
|
|
23
|
+
including semantic segmentation, depth estimation, and surface normal prediction.
|
|
24
|
+
|
|
25
|
+
The task pool is designed to work with encoder-decoder architectures where
|
|
26
|
+
a shared encoder processes input images and task-specific decoders generate
|
|
27
|
+
predictions for different tasks. It integrates with PyTorch Lightning for
|
|
28
|
+
streamlined training and evaluation workflows.
|
|
29
|
+
|
|
30
|
+
Supported Tasks:
|
|
31
|
+
- Semantic segmentation
|
|
32
|
+
- Depth estimation
|
|
33
|
+
- Surface normal prediction
|
|
34
|
+
"""
|
|
35
|
+
|
|
18
36
|
_trainer: L.Trainer = None
|
|
19
37
|
|
|
20
38
|
def __init__(self, taskpool_config: DictConfig):
|
|
39
|
+
"""Initialize the NYUv2 task pool with configuration settings.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
taskpool_config: Configuration object containing all necessary
|
|
43
|
+
parameters for the task pool, including:
|
|
44
|
+
- data_dir: Path to the directory containing NYUv2 dataset
|
|
45
|
+
- tasks: List of tasks to evaluate (e.g., ["semantic", "depth"])
|
|
46
|
+
- batch_size: Batch size for evaluation data loader
|
|
47
|
+
- num_workers: Number of worker processes for data loading
|
|
48
|
+
"""
|
|
21
49
|
self.config = taskpool_config
|
|
22
50
|
|
|
23
51
|
def load_datasets(self):
|
fusion_bench/utils/__init__.py
CHANGED
fusion_bench/utils/data.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import pickle
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Literal, Optional, Union
|
|
3
|
+
from typing import Any, Literal, Optional, Tuple, Union
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import torch
|
|
@@ -37,7 +37,9 @@ class InfiniteDataLoader:
|
|
|
37
37
|
return data
|
|
38
38
|
|
|
39
39
|
|
|
40
|
-
def load_tensor_from_file(
|
|
40
|
+
def load_tensor_from_file(
|
|
41
|
+
file_path: Union[str, Path], device: Optional[Union[str, torch.device]] = None
|
|
42
|
+
) -> torch.Tensor:
|
|
41
43
|
"""
|
|
42
44
|
Loads a tensor from a file, which can be either a .pt, .pth or .np file.
|
|
43
45
|
If the file is not one of these formats, it will try to load it as a pickle file.
|
|
@@ -72,7 +74,7 @@ def train_validation_split(
|
|
|
72
74
|
validation_size: Optional[int] = None,
|
|
73
75
|
random_seed: Optional[int] = None,
|
|
74
76
|
return_split: Literal["all", "train", "val"] = "both",
|
|
75
|
-
):
|
|
77
|
+
) -> Union[Tuple[Dataset, Dataset], Dataset]:
|
|
76
78
|
"""
|
|
77
79
|
Split a dataset into a training and validation set.
|
|
78
80
|
|
|
@@ -134,7 +136,7 @@ def train_validation_test_split(
|
|
|
134
136
|
test_fraction: float,
|
|
135
137
|
random_seed: Optional[int] = None,
|
|
136
138
|
return_spilt: Literal["all", "train", "val", "test"] = "all",
|
|
137
|
-
):
|
|
139
|
+
) -> Union[Tuple[Dataset, Dataset, Dataset], Dataset]:
|
|
138
140
|
"""
|
|
139
141
|
Split a dataset into a training, validation and test set.
|
|
140
142
|
|
fusion_bench/utils/devices.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import gc
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
-
from typing import List, Optional, Union
|
|
4
|
+
from typing import Any, List, Optional, Union
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
from transformers.utils import (
|
|
@@ -12,6 +12,8 @@ from transformers.utils import (
|
|
|
12
12
|
is_torch_xpu_available,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
|
+
from .type import T
|
|
16
|
+
|
|
15
17
|
__all__ = [
|
|
16
18
|
"clear_cuda_cache",
|
|
17
19
|
"to_device",
|
|
@@ -37,7 +39,12 @@ def clear_cuda_cache():
|
|
|
37
39
|
log.warning("CUDA is not available. No cache to clear.")
|
|
38
40
|
|
|
39
41
|
|
|
40
|
-
def to_device(
|
|
42
|
+
def to_device(
|
|
43
|
+
obj: T,
|
|
44
|
+
device: Optional[torch.device],
|
|
45
|
+
copy_on_move: bool = False,
|
|
46
|
+
**kwargs: Any,
|
|
47
|
+
) -> T:
|
|
41
48
|
"""
|
|
42
49
|
Move a given object to the specified device.
|
|
43
50
|
|
|
@@ -47,12 +54,20 @@ def to_device(obj, device: Optional[torch.device], **kwargs):
|
|
|
47
54
|
Args:
|
|
48
55
|
obj: The object to be moved to the device. This can be a torch.Tensor, torch.nn.Module, list, tuple, or dict.
|
|
49
56
|
device (torch.device): The target device to move the object to. This can be `None`.
|
|
50
|
-
|
|
57
|
+
copy_on_move (bool, optional): Whether to force a copy operation when moving tensors to a different device.
|
|
58
|
+
If True, tensors will be copied when moved to a different device (copy=True is passed to tensor.to()).
|
|
59
|
+
If False (default), tensors are moved without forcing a copy operation, allowing PyTorch to optimize
|
|
60
|
+
the operation. This parameter only affects torch.Tensor objects; modules and other types are unaffected.
|
|
61
|
+
Defaults to False.
|
|
62
|
+
**kwargs: Additional keyword arguments to be passed to the `to` method of torch.Tensor or torch.nn.Module.
|
|
63
|
+
For example, `non_blocking=True`, `dtype=torch.float16`. Note that if `copy_on_move=True`, the `copy`
|
|
64
|
+
keyword argument will be automatically set and should not be provided manually.
|
|
51
65
|
|
|
52
66
|
Returns:
|
|
53
67
|
The object moved to the specified device. The type of the returned object matches the type of the input object.
|
|
54
68
|
|
|
55
69
|
Examples:
|
|
70
|
+
```python
|
|
56
71
|
>>> tensor = torch.tensor([1, 2, 3])
|
|
57
72
|
>>> to_device(tensor, torch.device('cuda'))
|
|
58
73
|
tensor([1, 2, 3], device='cuda:0')
|
|
@@ -64,17 +79,26 @@ def to_device(obj, device: Optional[torch.device], **kwargs):
|
|
|
64
79
|
>>> data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
|
|
65
80
|
>>> to_device(data, torch.device('cuda'))
|
|
66
81
|
[tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')]
|
|
82
|
+
|
|
83
|
+
>>> # Force copy when moving to different device
|
|
84
|
+
>>> tensor = torch.tensor([1, 2, 3], device='cpu')
|
|
85
|
+
>>> copied_tensor = to_device(tensor, torch.device('cuda'), copy_on_move=True)
|
|
86
|
+
>>> # tensor and copied_tensor will have different memory locations
|
|
87
|
+
```
|
|
67
88
|
"""
|
|
68
|
-
if isinstance(obj,
|
|
89
|
+
if isinstance(obj, torch.Tensor):
|
|
90
|
+
if copy_on_move:
|
|
91
|
+
if obj.device != torch.device(device):
|
|
92
|
+
kwargs["copy"] = True
|
|
93
|
+
return obj.to(device, **kwargs)
|
|
94
|
+
elif isinstance(obj, torch.nn.Module):
|
|
69
95
|
return obj.to(device, **kwargs)
|
|
70
96
|
elif isinstance(obj, list):
|
|
71
|
-
return [to_device(o, device) for o in obj]
|
|
97
|
+
return [to_device(o, device, **kwargs) for o in obj]
|
|
72
98
|
elif isinstance(obj, tuple):
|
|
73
|
-
return tuple(to_device(o, device) for o in obj)
|
|
99
|
+
return tuple(to_device(o, device, **kwargs) for o in obj)
|
|
74
100
|
elif isinstance(obj, dict):
|
|
75
|
-
for key in obj
|
|
76
|
-
obj[key] = to_device(obj[key], device)
|
|
77
|
-
return obj
|
|
101
|
+
return {key: to_device(value, device, **kwargs) for key, value in obj.items()}
|
|
78
102
|
else:
|
|
79
103
|
# the default behavior is to return the object as is
|
|
80
104
|
return obj
|
|
@@ -102,7 +126,7 @@ def num_devices(devices: Union[int, List[int], str]) -> int:
|
|
|
102
126
|
)
|
|
103
127
|
|
|
104
128
|
|
|
105
|
-
def get_device(obj) -> torch.device:
|
|
129
|
+
def get_device(obj: Any) -> torch.device:
|
|
106
130
|
"""
|
|
107
131
|
Get the device of a given object.
|
|
108
132
|
|
|
@@ -151,6 +175,7 @@ def get_current_device() -> torch.device:
|
|
|
151
175
|
If not set, it defaults to "0".
|
|
152
176
|
|
|
153
177
|
Example:
|
|
178
|
+
|
|
154
179
|
>>> device = get_current_device()
|
|
155
180
|
>>> print(device)
|
|
156
181
|
xpu:0 # or npu:0, mps:0, cuda:0, cpu depending on availability
|
|
@@ -241,7 +266,7 @@ def cleanup_cuda():
|
|
|
241
266
|
torch.cuda.reset_peak_memory_stats()
|
|
242
267
|
|
|
243
268
|
|
|
244
|
-
def print_memory_usage(print_fn=print):
|
|
269
|
+
def print_memory_usage(print_fn=print) -> str:
|
|
245
270
|
"""
|
|
246
271
|
Print the current GPU memory usage.
|
|
247
272
|
|
fusion_bench/utils/dtype.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import contextlib
|
|
2
|
-
from typing import Dict, Generator, Iterable, Optional, Tuple
|
|
2
|
+
from typing import Dict, Generator, Iterable, Optional, Tuple, Union
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from transformers.utils import (
|
|
@@ -25,7 +25,7 @@ PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
|
|
|
25
25
|
}
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
def parse_dtype(dtype: Optional[str]):
|
|
28
|
+
def parse_dtype(dtype: Optional[str]) -> Optional[torch.dtype]:
|
|
29
29
|
"""
|
|
30
30
|
Parses a string representation of a data type and returns the corresponding torch.dtype.
|
|
31
31
|
|
|
@@ -92,6 +92,7 @@ def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
|
|
|
92
92
|
ContextManager: context manager for setting default dtype.
|
|
93
93
|
|
|
94
94
|
Example:
|
|
95
|
+
|
|
95
96
|
>>> with set_default_dtype(torch.bfloat16):
|
|
96
97
|
>>> x = torch.tensor([1, 2, 3])
|
|
97
98
|
>>> x.dtype
|
|
@@ -2,7 +2,18 @@ import json
|
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
from copy import deepcopy
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import (
|
|
6
|
+
TYPE_CHECKING,
|
|
7
|
+
Dict,
|
|
8
|
+
Generic,
|
|
9
|
+
Iterator,
|
|
10
|
+
List,
|
|
11
|
+
Mapping,
|
|
12
|
+
Optional,
|
|
13
|
+
Tuple,
|
|
14
|
+
Type,
|
|
15
|
+
Union,
|
|
16
|
+
)
|
|
6
17
|
|
|
7
18
|
import torch
|
|
8
19
|
from accelerate import init_empty_weights
|
|
@@ -11,10 +22,12 @@ from huggingface_hub import snapshot_download
|
|
|
11
22
|
from safetensors import safe_open
|
|
12
23
|
from safetensors.torch import load_file
|
|
13
24
|
from torch import nn
|
|
25
|
+
from torch.nn.modules.module import _IncompatibleKeys
|
|
14
26
|
from transformers import AutoConfig
|
|
15
27
|
|
|
16
28
|
from fusion_bench.utils.dtype import parse_dtype
|
|
17
29
|
from fusion_bench.utils.packages import import_object
|
|
30
|
+
from fusion_bench.utils.type import TorchModelType
|
|
18
31
|
|
|
19
32
|
if TYPE_CHECKING:
|
|
20
33
|
from transformers import PretrainedConfig
|
|
@@ -49,7 +62,7 @@ def resolve_checkpoint_path(
|
|
|
49
62
|
)
|
|
50
63
|
|
|
51
64
|
|
|
52
|
-
class LazyStateDict(Mapping[str, torch.Tensor]):
|
|
65
|
+
class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
53
66
|
"""
|
|
54
67
|
Dictionary-like object that lazily loads a state dict from a checkpoint path.
|
|
55
68
|
"""
|
|
@@ -63,11 +76,14 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
|
|
|
63
76
|
_index: Optional[Dict[str, str]]
|
|
64
77
|
"""Mapping of parameter names to checkpoint files."""
|
|
65
78
|
|
|
79
|
+
meta_module: TorchModelType = None
|
|
80
|
+
meta_module_class: Optional[Type[TorchModelType]] = None
|
|
81
|
+
|
|
66
82
|
def __init__(
|
|
67
83
|
self,
|
|
68
84
|
checkpoint: str,
|
|
69
|
-
meta_module_class: Optional[Type[
|
|
70
|
-
meta_module: Optional[
|
|
85
|
+
meta_module_class: Optional[Type[TorchModelType]] = None,
|
|
86
|
+
meta_module: Optional[TorchModelType] = None,
|
|
71
87
|
cache_state_dict: bool = False,
|
|
72
88
|
torch_dtype: Optional[torch.dtype] = None,
|
|
73
89
|
device: str = "cpu",
|
|
@@ -88,15 +104,19 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
|
|
|
88
104
|
hf_proxies (Dict, optional): Proxies to use for downloading from Hugging Face Hub.
|
|
89
105
|
"""
|
|
90
106
|
self.cache_state_dict = cache_state_dict
|
|
107
|
+
|
|
108
|
+
# Validate that both meta_module_class and meta_module are not provided
|
|
109
|
+
if meta_module_class is not None and meta_module is not None:
|
|
110
|
+
raise ValueError(
|
|
111
|
+
"Cannot provide both meta_module_class and meta_module, please provide only one."
|
|
112
|
+
)
|
|
113
|
+
|
|
91
114
|
self.meta_module_class = meta_module_class
|
|
92
115
|
if isinstance(self.meta_module_class, str):
|
|
93
116
|
self.meta_module_class = import_object(self.meta_module_class)
|
|
94
117
|
self.meta_module = meta_module
|
|
118
|
+
|
|
95
119
|
if self.meta_module_class is not None:
|
|
96
|
-
if self.meta_module is not None:
|
|
97
|
-
raise ValueError(
|
|
98
|
-
"Cannot provide both meta_module_class and meta_module, please provide only one."
|
|
99
|
-
)
|
|
100
120
|
with init_empty_weights():
|
|
101
121
|
self.meta_module = self.meta_module_class.from_pretrained(
|
|
102
122
|
checkpoint,
|
|
@@ -173,9 +193,13 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
|
|
|
173
193
|
"""
|
|
174
194
|
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
|
175
195
|
"""
|
|
196
|
+
if hasattr(self, "_cached_dtype"):
|
|
197
|
+
return self._cached_dtype
|
|
198
|
+
|
|
176
199
|
first_key = next(iter(self.keys()))
|
|
177
200
|
first_param = self[first_key]
|
|
178
|
-
|
|
201
|
+
self._cached_dtype = first_param.dtype
|
|
202
|
+
return self._cached_dtype
|
|
179
203
|
|
|
180
204
|
def state_dict(self, keep_vars: bool = False) -> "LazyStateDict":
|
|
181
205
|
"""
|
|
@@ -321,9 +345,7 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
|
|
|
321
345
|
if self._state_dict_cache is not None:
|
|
322
346
|
self._state_dict_cache[key] = value
|
|
323
347
|
else:
|
|
324
|
-
log.warning(
|
|
325
|
-
"State dict cache is disabled, setting a tensor will not update the cache."
|
|
326
|
-
)
|
|
348
|
+
log.warning("State dict cache is disabled, initializing the cache.")
|
|
327
349
|
self._state_dict_cache = {key: value}
|
|
328
350
|
|
|
329
351
|
def __contains__(self, key: str) -> bool:
|
|
@@ -339,7 +361,7 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
|
|
|
339
361
|
self._checkpoint_files[0], key, update_cache=False
|
|
340
362
|
)
|
|
341
363
|
return tensor is not None
|
|
342
|
-
except
|
|
364
|
+
except (KeyError, FileNotFoundError, RuntimeError, EOFError):
|
|
343
365
|
return False
|
|
344
366
|
return False
|
|
345
367
|
|
|
@@ -409,8 +431,8 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
|
|
|
409
431
|
)
|
|
410
432
|
|
|
411
433
|
def load_state_dict(
|
|
412
|
-
self, state_dict:
|
|
413
|
-
) ->
|
|
434
|
+
self, state_dict: Mapping[str, torch.Tensor], strict: bool = True
|
|
435
|
+
) -> _IncompatibleKeys:
|
|
414
436
|
"""
|
|
415
437
|
Load a state dict into this LazyStateDict.
|
|
416
438
|
This method is only for compatibility with nn.Module and it overrides the cache of LazyStateDict.
|
|
@@ -419,16 +441,60 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
|
|
|
419
441
|
state_dict (Dict[str, torch.Tensor]): The state dict to load.
|
|
420
442
|
strict (bool): Whether to enforce that all keys in the state dict are present in this LazyStateDict.
|
|
421
443
|
"""
|
|
444
|
+
if not isinstance(state_dict, Mapping):
|
|
445
|
+
raise TypeError(
|
|
446
|
+
f"Expected state_dict to be dict-like, got {type(state_dict)}."
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
missing_keys: list[str] = []
|
|
450
|
+
unexpected_keys: list[str] = []
|
|
451
|
+
error_msgs: list[str] = []
|
|
452
|
+
|
|
422
453
|
log.warning(
|
|
423
454
|
"Loading state dict into LazyStateDict is not recommended, as it may lead to unexpected behavior. "
|
|
424
455
|
"Use with caution."
|
|
425
456
|
)
|
|
457
|
+
|
|
458
|
+
# Check for unexpected keys in the provided state_dict
|
|
459
|
+
for key in state_dict:
|
|
460
|
+
if key not in self:
|
|
461
|
+
unexpected_keys.append(key)
|
|
462
|
+
|
|
463
|
+
# Check for missing keys that are expected in this LazyStateDict
|
|
464
|
+
for key in self.keys():
|
|
465
|
+
if key not in state_dict:
|
|
466
|
+
missing_keys.append(key)
|
|
467
|
+
|
|
468
|
+
# Handle strict mode
|
|
426
469
|
if strict:
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
470
|
+
if len(unexpected_keys) > 0:
|
|
471
|
+
error_msgs.insert(
|
|
472
|
+
0,
|
|
473
|
+
"Unexpected key(s) in state_dict: {}. ".format(
|
|
474
|
+
", ".join(f'"{k}"' for k in unexpected_keys)
|
|
475
|
+
),
|
|
476
|
+
)
|
|
477
|
+
if len(missing_keys) > 0:
|
|
478
|
+
error_msgs.insert(
|
|
479
|
+
0,
|
|
480
|
+
"Missing key(s) in state_dict: {}. ".format(
|
|
481
|
+
", ".join(f'"{k}"' for k in missing_keys)
|
|
482
|
+
),
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
if len(error_msgs) > 0:
|
|
486
|
+
raise RuntimeError(
|
|
487
|
+
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
|
488
|
+
self.__class__.__name__, "\n\t".join(error_msgs)
|
|
489
|
+
)
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
# Load the state dict values
|
|
430
493
|
for key, value in state_dict.items():
|
|
431
|
-
|
|
494
|
+
if key in self: # Only set keys that exist in this LazyStateDict
|
|
495
|
+
self[key] = value
|
|
496
|
+
|
|
497
|
+
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
|
432
498
|
|
|
433
499
|
def __getattr__(self, name: str):
|
|
434
500
|
if "meta_module" in self.__dict__:
|
fusion_bench/utils/packages.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import importlib.metadata
|
|
2
2
|
import importlib.util
|
|
3
3
|
from functools import lru_cache
|
|
4
|
-
from typing import TYPE_CHECKING
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
5
|
|
|
6
6
|
from packaging import version
|
|
7
7
|
|
|
@@ -69,7 +69,7 @@ def is_vllm_available():
|
|
|
69
69
|
return _is_package_available("vllm")
|
|
70
70
|
|
|
71
71
|
|
|
72
|
-
def import_object(abs_obj_name: str):
|
|
72
|
+
def import_object(abs_obj_name: str) -> Any:
|
|
73
73
|
"""
|
|
74
74
|
Imports a class from a module given the absolute class name.
|
|
75
75
|
|
|
@@ -84,7 +84,7 @@ def import_object(abs_obj_name: str):
|
|
|
84
84
|
return getattr(module, obj_name)
|
|
85
85
|
|
|
86
86
|
|
|
87
|
-
def compare_versions(v1, v2):
|
|
87
|
+
def compare_versions(v1: str, v2: str) -> int:
|
|
88
88
|
"""Compare two version strings.
|
|
89
89
|
Returns -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2"""
|
|
90
90
|
|