fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -3
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +2 -3
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +5 -9
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +5 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/ensemble.py
CHANGED
@@ -1,11 +1,12 @@
 import logging
-from typing import List, Mapping, Union  # noqa: F401
+from typing import List, Mapping, Optional, Union  # noqa: F401
 
 import numpy as np
 import torch
 from torch import nn
 
 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.models.wrappers.ensemble import (
     EnsembleModule,
@@ -18,7 +19,7 @@ log = logging.getLogger(__name__)
 
 class SimpleEnsembleAlgorithm(BaseAlgorithm):
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | List[nn.Module]):
+    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
         """
         Run the simple ensemble algorithm on the given model pool.
 
@@ -35,20 +36,19 @@
         return ensemble
 
 
+@auto_register_config
 class WeightedEnsembleAlgorithm(BaseAlgorithm):
 
-
-
-
-
-
-
-        self.normalize = normalize
-        self.weights = weights
+    def __init__(
+        self,
+        normalize: bool = True,
+        weights: Optional[List[float]] = None,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | List[nn.Module]):
+    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> WeightedEnsembleModule:
         """
         Run the weighted ensemble algorithm on the given model pool.
 
@@ -78,7 +78,7 @@
 
 class MaxModelPredictorAlgorithm(BaseAlgorithm):
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | List[nn.Module]):
+    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> MaxModelPredictor:
         """
         Run the max model predictor algorithm on the given model pool.
 
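A pattern that recurs throughout 0.2.21 is visible here: hand-written `_config_mapping` tables and `self.x = x` assignments are replaced by an `@auto_register_config` decorator from `fusion_bench.mixins` (implemented in `fusion_bench/mixins/serialization.py`, which grew by +265 -48 in this release). The decorator's source is not shown in this diff; the following is a minimal sketch, assuming it binds each named `__init__` parameter to an attribute of the same name, not the actual fusion_bench implementation:

```python
import functools
import inspect


def auto_register_config(cls):
    """Hypothetical sketch: assign every named __init__ parameter to self."""
    original_init = cls.__init__
    signature = inspect.signature(original_init)

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            kind = signature.parameters[name].kind
            # skip self and the catch-all *args/**kwargs parameters
            if name == "self" or kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue
            setattr(self, name, value)
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls
```

Under that assumption, `WeightedEnsembleAlgorithm(normalize=True, weights=[0.3, 0.7])` would get `self.normalize` and `self.weights` without the explicit assignments the old code carried.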
fusion_bench/method/expert_sparsity/utils/calibration_data.py
CHANGED
@@ -12,9 +12,9 @@ import os
 import torch
 import transformers
 from datasets import load_dataset
+from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizer, default_data_collator
 from transformers.testing_utils import CaptureLogger
-from huggingface_hub import hf_hub_download
 
 logger = logging.getLogger(__name__)
 
fusion_bench/method/fisher_merging/clip_fisher_merging.py
CHANGED
@@ -65,7 +65,7 @@ class FisherMergingForCLIPVisionModel(
             minimal_fisher_weight=minimal_fisher_weight,
             num_fisher_examples=num_fisher_examples,
         )
-        self.
+        self.dataloader_kwargs = dataloader_kwargs
         self.zeroshot_weights_cache_dir = zeroshot_weights_cache_dir
         for key, value in kwargs.items():
             log.warning(f"Unused argument: {key}={value}")
@@ -127,7 +127,7 @@ class FisherMergingForCLIPVisionModel(
         """
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
-        train_dataloader = DataLoader(train_dataset, **self.
+        train_dataloader = DataLoader(train_dataset, **self.dataloader_kwargs)
         if self.fabric is not None:
             train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
             model = self.fabric.setup(model)
fusion_bench/method/fisher_merging/fisher_merging.py
CHANGED
@@ -5,14 +5,14 @@ This implementation is largely based on the implementation from https://github.
 import logging
 import re
 from collections import defaultdict
-from typing import Dict, List
+from typing import Any, Dict, List
 
 import torch
 from torch import Tensor, nn
 from tqdm.autonotebook import tqdm
 
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 
 log = logging.getLogger(__name__)
@@ -353,6 +353,7 @@ def filter_state_dict(
     return filtered_state_dict
 
 
+@auto_register_config
 class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     """
     Implements the Fisher Merging Algorithm.
@@ -365,13 +366,6 @@ class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         Executes the Fisher merging process on the model pool and returns the merged model.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "exclude_param_names_regex": "exclude_param_names_regex",
-        "normalize_fisher_weight": "normalize_fisher_weight",
-        "minimal_fisher_weight": "minimal_fisher_weight",
-        "num_fisher_examples": "num_fisher_examples",
-    }
-
     def __init__(
         self,
         *,
@@ -379,12 +373,9 @@ class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         normalize_fisher_weight: bool,
         minimal_fisher_weight: float,
        num_fisher_examples: int,
+        **kwargs,
     ):
-        super().__init__()
-        self.exclude_param_names_regex = exclude_param_names_regex
-        self.normalize_fisher_weight = normalize_fisher_weight
-        self.minimal_fisher_weight = minimal_fisher_weight
-        self.num_fisher_examples = num_fisher_examples
+        super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
@@ -469,7 +460,7 @@ class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         self,
         model_name: str,
         model: nn.Module,
-        train_dataset,
+        train_dataset: Any,
         param_names_to_merge: List[str],
     ) -> Dict[str, Tensor]:
         """
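Only the configuration plumbing changed here; the merging math is untouched. For orientation, a hedged sketch of the per-parameter update that Fisher merging performs (the standard formulation, not fusion_bench's exact code): each model's parameter tensor is averaged with weights given by its diagonal Fisher information, and `minimal_fisher_weight` keeps the denominator away from zero.

```python
import torch


def fisher_merge_param(params, fishers, minimal_fisher_weight: float = 1e-6):
    """Fisher-weighted average of same-shaped tensors, one per model."""
    fishers = [f + minimal_fisher_weight for f in fishers]
    return sum(p * f for p, f in zip(params, fishers)) / sum(fishers)


# Two models, one scalar parameter each: the better-determined value dominates.
merged = fisher_merge_param(
    params=[torch.tensor(1.0), torch.tensor(3.0)],
    fishers=[torch.tensor(9.0), torch.tensor(1.0)],
)
print(merged)  # tensor(1.2000), i.e. (9*1 + 1*3) / 10
```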
fusion_bench/method/fisher_merging/gpt2_fisher_merging.py
CHANGED
@@ -18,13 +18,14 @@ from transformers.models.gpt2.modeling_gpt2 import Conv1D
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
 from fusion_bench.utils import timeit_context
-
+from fusion_bench.mixins import auto_register_config
 from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients
 
 
+@auto_register_config
 class FisherMergingAlgorithmForGPT2(
-    FisherMergingAlgorithm,
     LightningFabricMixin,
+    FisherMergingAlgorithm,
 ):
     """
     Implements the Fisher Merging Algorithm for GPT-2 models on text classification tasks.
@@ -42,11 +43,6 @@ class FisherMergingAlgorithmForGPT2(
 
     classifiers = {}
     modelpool: GPT2ForSequenceClassificationPool = None
-    _config_mapping = FisherMergingAlgorithm._config_mapping | {
-        "cache_dir": "cache_dir",
-        "batch_size": "batch_size",
-        "num_workers": "num_workers",
-    }
 
     def __init__(
         self,
@@ -64,9 +60,6 @@ class FisherMergingAlgorithmForGPT2(
             num_workers (int): Number of workers for data loading.
             **kwargs: Additional keyword arguments.
         """
-        self.cache_dir = cache_dir
-        self.batch_size = batch_size
-        self.num_workers = num_workers
         super().__init__(**kwargs)
 
     def on_fisher_merging_start(self):
fusion_bench/method/fw_merging/fw_hard.py
CHANGED
@@ -223,7 +223,7 @@ class FrankWolfeHardAlgorithm(
     def get_shuffled_loader_iter(self, task: str):
         if self.loss_fn == "cross_entropy":
             # get dataloader kwargs
-            dataloader_kwargs = self.
+            dataloader_kwargs = self.dataloader_kwargs.copy()
             dataloader_kwargs["shuffle"] = True
             dataloader_kwargs["batch_size"] = 1
 
fusion_bench/method/fw_merging/fw_soft.py
CHANGED
@@ -193,7 +193,7 @@ class FrankWolfeSoftAlgorithm(
     @functools.cache
     def get_shuffled_train_loader_iter(self, task: str, batch_size: int = 1):
         # get dataloader kwargs
-        dataloader_kwargs = self.
+        dataloader_kwargs = self.dataloader_kwargs.copy()
         dataloader_kwargs["shuffle"] = True
         dataloader_kwargs["batch_size"] = batch_size
 
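In both Frank-Wolfe variants the new code copies `self.dataloader_kwargs` before overriding per-task fields, so the shared configuration dict is not mutated in place. A small self-contained illustration of why the `.copy()` matters:

```python
# Shared dataloader configuration, reused by every task.
dataloader_kwargs = {"batch_size": 32, "num_workers": 4}

# Per-task overrides are applied to a shallow copy only.
per_task_kwargs = dataloader_kwargs.copy()
per_task_kwargs["shuffle"] = True
per_task_kwargs["batch_size"] = 1

# The shared dict is untouched; without .copy() it would now have batch_size=1.
assert dataloader_kwargs == {"batch_size": 32, "num_workers": 4}
```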
fusion_bench/method/gossip/clip_layer_wise_gossip.py
CHANGED
@@ -3,13 +3,12 @@ Example Usage:
 
 ```bash
 fusion_bench \
-
+    path.log_dir=outputs/ViT-B-32/gossip_layer_wise_adamerging_adam \
+    method=adamerging/clip \
     method.name=clip_layer_wise_adamerging \
     method.save_merging_weights=merging_weights.pt \
-    modelpool=clip-vit-base-patch32_TA8 \
-    taskpool=clip-vit-classification_TA8
-    fabric_logger.root_dir=outputs/logs/ViT-B-32 \
-    fabric_logger.name=clip_layer_wise_adamerging_adam
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
 ```
 """
fusion_bench/method/linear/expo.py
CHANGED
@@ -7,6 +7,7 @@ Reference:
 
 import logging
 from copy import deepcopy
+from typing import Union
 
 import torch
 from torch import nn
@@ -79,7 +80,7 @@ class ExPOAlgorithm(BaseAlgorithm):
         self.extrapolation_factor = extrapolation_factor
         super().__init__(**kwargs)
 
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: Union[BaseModelPool, list]) -> nn.Module:
         """
         Run the ExPO merge algorithm.
 
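For context on what `ExPOAlgorithm` computes: ExPO extrapolates past an aligned checkpoint along the direction from the SFT model to the aligned model, scaled by `extrapolation_factor`. A hedged per-tensor sketch, assuming the usual ExPO update rule from the paper referenced in the module docstring (the diff itself only touches type annotations):

```python
from torch import Tensor


def expo_param(theta_sft: Tensor, theta_aligned: Tensor, alpha: float) -> Tensor:
    # theta = theta_aligned + alpha * (theta_aligned - theta_sft),
    # i.e. keep moving in the direction the alignment step already took.
    return theta_aligned + alpha * (theta_aligned - theta_sft)
```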
fusion_bench/method/linear/linear_interpolation.py
CHANGED
@@ -1,6 +1,8 @@
 import logging
+from typing import Any
 
 import torch
+from torch import nn
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import state_dict_weighted_sum
@@ -10,7 +12,7 @@ log = logging.getLogger(__name__)
 
 class LinearInterpolationAlgorithm(BaseAlgorithm):
     R"""
-    LinearInterpolationAlgorithm performs linear interpolation between two models.
+    `LinearInterpolationAlgorithm` performs linear interpolation between two models.
     Returns a model with the state dict that is a linear interpolation of the state dicts of the two models.
     $\theta = (1-t) \theta_1 + t \theta_2$
     """
@@ -19,9 +21,9 @@ class LinearInterpolationAlgorithm(BaseAlgorithm):
         "t": "t",
     }
 
-    def __init__(self, t: float, **kwargs):
+    def __init__(self, t: float, **kwargs: Any):
         """
-        Initialize the LinearInterpolationAlgorithm with the given interpolation parameter.
+        Initialize the `LinearInterpolationAlgorithm` with the given interpolation parameter.
 
         Args:
             t (float): The interpolation parameter, should be in the range [0, 1].
@@ -31,7 +33,7 @@ class LinearInterpolationAlgorithm(BaseAlgorithm):
         self.t = t
         super().__init__(**kwargs)
 
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Run the linear interpolation algorithm on the given model pool.
 
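The docstring's rule $\theta = (1-t) \theta_1 + t \theta_2$ is applied key-by-key to the two state dicts (presumably via the imported `state_dict_weighted_sum`). A minimal worked sketch:

```python
import torch


def interpolate_state_dicts(sd1, sd2, t: float):
    assert 0.0 <= t <= 1.0, "t should be in the range [0, 1]"
    return {key: (1 - t) * sd1[key] + t * sd2[key] for key in sd1}


sd1 = {"w": torch.tensor([0.0, 2.0])}
sd2 = {"w": torch.tensor([4.0, 6.0])}
print(interpolate_state_dicts(sd1, sd2, t=0.25)["w"])  # tensor([1., 3.])
```

At `t=0` this returns the first model's weights, at `t=1` the second's.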
fusion_bench/method/linear/simple_average_for_llama.py
CHANGED
@@ -1,15 +1,15 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING, Optional
 
+from omegaconf import flag_override
 from typing_extensions import override
 
 from fusion_bench import timeit_context
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.method.simple_average import SimpleAverageAlgorithm
 from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
-from fusion_bench.utils.pylogger import getRankZeroLogger
-from omegaconf import flag_override
 from fusion_bench.utils import instantiate
+from fusion_bench.utils.pylogger import getRankZeroLogger
 
 log = getRankZeroLogger(__name__)
 
@@ -19,7 +19,6 @@ class SimpleAverageForLlama(BaseAlgorithm):
     A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
 
     Examples:
-
     The following example demonstrates how to use the `SimpleAverageForLlama` algorithm to merge Mistral models.
 
     ```bash
fusion_bench/method/lm_finetune/bradley_terry_rm.py
CHANGED
@@ -31,7 +31,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins import FabricTrainingMixin
-from fusion_bench.modelpool import
+from fusion_bench.modelpool import SequenceClassificationModelPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import get_dtype
 
@@ -121,7 +121,7 @@ class BradleyTerryRewardModeling(BaseAlgorithm, FabricTrainingMixin):
         self.fix_token_embedding = fix_token_embedding
         super().__init__(**kwargs)
 
-    def run(self, modelpool:
+    def run(self, modelpool: SequenceClassificationModelPool):
         self.modelpool = modelpool
         self.setup()
         self.train(self.model, self.optimizer, self.lr_scheduler)
fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from tqdm.autonotebook import tqdm
@@ -23,8 +23,7 @@ from transformers.models.mixtral.modeling_mixtral import (
 )
 from transformers.utils import ContextManagers
 
-from fusion_bench
-from fusion_bench.modelpool import BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 
 log = logging.getLogger(__name__)
 
@@ -114,7 +113,7 @@ def _upscale_decoder_layer(
 
 def upscale_to_mixtral_model(
     input_model: LlamaModel | MistralModel, output_model: MixtralModel
-):
+) -> None:
     """
     A helper function.
 
@@ -140,7 +139,7 @@ def upscale_to_mixtral_model(
 
 def upscale_to_mixtral_for_causal_lm(
     input_model: LlamaForCausalLM | MistralForCausalLM, output_model: MixtralForCausalLM
-):
+) -> None:
     """
     A helper function.
 
@@ -157,24 +156,19 @@ def upscale_to_mixtral_for_causal_lm(
     upscale_to_mixtral_model(input_model.model, output_model.model)
 
 
+@auto_register_config
 class MixtralUpscalingAlgorithm(BaseAlgorithm):
     """
     This class is responsible for upscaling a model to a MixtralModel.
     It inherits from the ModelFusionAlgorithm class.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "num_experts": "num_experts",
-        "experts_per_token": "experts_per_token",
-        "save_checkpoint": "save_checkpoint",
-    }
-
     def __init__(
         self,
         num_experts: int,
         experts_per_token: int,
         save_checkpoint: str,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Initialize the MixtralUpscalingAlgorithm.
@@ -185,9 +179,6 @@ class MixtralUpscalingAlgorithm(BaseAlgorithm):
             save_checkpoint (str): The path to save the checkpoint.
            **kwargs: Additional keyword arguments.
         """
-        self.num_experts = num_experts
-        self.experts_per_token = experts_per_token
-        self.save_checkpoint = save_checkpoint
         super().__init__(**kwargs)
 
     @torch.no_grad()
@@ -242,24 +233,19 @@ class MixtralUpscalingAlgorithm(BaseAlgorithm):
         return mixtral_model
 
 
+@auto_register_config
 class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
     """
     This class is responsible for upscaling a model to a MixtralForCausalLM.
     It inherits from the ModelFusionAlgorithm class.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "num_experts": "num_experts",
-        "experts_per_token": "experts_per_token",
-        "save_checkpoint": "save_checkpoint",
-    }
-
     def __init__(
         self,
         num_experts: int,
         experts_per_token: int,
         save_checkpoint: str,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Initialize the MixtralForCausalLMUpscalingAlgorithm.
@@ -270,9 +256,6 @@ class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
             save_checkpoint (str): The path to save the checkpoint.
             **kwargs: Additional keyword arguments.
         """
-        self.num_experts = num_experts
-        self.experts_per_token = experts_per_token
-        self.save_checkpoint = save_checkpoint
         super().__init__(**kwargs)
 
     @torch.no_grad()
@@ -302,7 +285,7 @@ class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
             self.config.experts_per_token,
         )
 
-        with ContextManagers([no_init_weights(
+        with ContextManagers([no_init_weights()]):
             for _ in tqdm(range(1), desc="Initializing Mixtral model"):
                 mixtral_model = MixtralForCausalLM(mixtral_config)
         upscale_to_mixtral_for_causal_lm(pretrained_model, mixtral_model)
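For context, the two helper functions implement sparse upcycling: every expert in the Mixtral MoE layers is initialized from the dense source model's MLP, so the upscaled model starts functionally close to the original before any further training. A hedged sketch of the idea (`dense_mlp` and `moe_layer` are illustrative names, not fusion_bench or transformers APIs):

```python
def upcycle_experts(dense_mlp, moe_layer):
    # Initialize each expert as a copy of the dense model's MLP weights;
    # load_state_dict copies the tensors, so experts do not share storage.
    state = dense_mlp.state_dict()
    for expert in moe_layer.experts:
        expert.load_state_dict(state)
```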
fusion_bench/method/model_recombination.py
CHANGED
@@ -5,6 +5,7 @@ from typing import List, Mapping, Union  # noqa: F401
 import torch
 from torch import nn
 
+from fusion_bench import auto_register_config
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.modelpool import BaseModelPool
 
@@ -52,17 +53,13 @@ def recombine_state_dict(models: List[nn.Module]):
     return models
 
 
+@auto_register_config
 class ModelRecombinationAlgorithm(BaseAlgorithm):
     """
     Model recombination recombinates the layers of the given models, to create a new set of models.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "return_modelpool": "return_modelpool",
-    }
-
     def __init__(self, return_modelpool: bool, **kwargs):
-        self.return_modelpool = return_modelpool
         super().__init__(**kwargs)
 
     @torch.no_grad()
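`recombine_state_dict` reshuffles which source model contributes each layer, producing a new set of models. A heavily hedged sketch of layer-wise recombination (illustrative only; the actual implementation may differ):

```python
import random


def recombine_layers(models_layers):
    """models_layers[m][l] is layer l of model m; shuffle each layer slot."""
    num_models, num_layers = len(models_layers), len(models_layers[0])
    for layer_idx in range(num_layers):
        column = [models_layers[m][layer_idx] for m in range(num_models)]
        random.shuffle(column)
        for m in range(num_models):
            models_layers[m][layer_idx] = column[m]
    return models_layers
```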
fusion_bench/method/moe_pruner/utils/data.py
CHANGED
@@ -1,8 +1,9 @@
 # Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
 
+import os
 import random
 from typing import List, Optional, Tuple, cast  # noqa: F401
-
+
 from datasets import load_dataset
 from torch import Tensor
 from tqdm.auto import tqdm
fusion_bench/method/moe_pruner/utils/prune.py
CHANGED
@@ -107,7 +107,12 @@ def prepare_calibration_input(
         device=device,
         requires_grad=False,
     )
-    cache = {
+    cache = {
+        "i": 0,
+        "attention_mask": None,
+        "position_ids": None,
+        "position_embeddings": None,
+    }
 
     class Catcher(nn.Module):
         def __init__(self, module):
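The `Catcher` class visible in the context lines is the usual SparseGPT/Wanda-style trick for collecting calibration inputs: a wrapper around the first decoder layer records the hidden states and attention metadata into the `cache` dict (now including `position_embeddings`), then aborts the forward pass. A hedged sketch of the pattern, not the exact fusion_bench code:

```python
import torch
from torch import nn


class Catcher(nn.Module):
    """Sketch of the input-capturing wrapper used during calibration."""

    def __init__(self, module, inps, cache):
        super().__init__()
        self.module = module
        self.inps = inps    # preallocated buffer for layer inputs
        self.cache = cache  # {"i": 0, "attention_mask": None, ...}

    def forward(self, hidden_states, **kwargs):
        self.inps[self.cache["i"]] = hidden_states
        self.cache["i"] += 1
        self.cache["attention_mask"] = kwargs.get("attention_mask")
        self.cache["position_ids"] = kwargs.get("position_ids")
        self.cache["position_embeddings"] = kwargs.get("position_embeddings")
        raise ValueError  # stop the forward pass; only the inputs are needed
```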
fusion_bench/method/pruning/llama_magnitude_prune.py
CHANGED
@@ -167,7 +167,7 @@ class MagnitudePruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
         super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: CausalLMPool):
+    def run(self, modelpool: CausalLMPool) -> LlamaForCausalLM:
         """
         Execute the pruning process on the first model from the given model pool.
 
fusion_bench/method/pruning/wanda_utils/data.py
CHANGED
@@ -4,12 +4,11 @@ import os
 import random
 from typing import List, Optional, Tuple, cast  # noqa: F401
 
+from datasets import load_dataset
 from torch import Tensor
 from tqdm.auto import tqdm
 from transformers import PreTrainedTokenizer
 
-from datasets import load_dataset
-
 
 # Wrapper for tokenized input IDs
 class TokenizerWrapper:
fusion_bench/method/pwe_moe/clip_pwe_moe.py
CHANGED
@@ -16,14 +16,18 @@ from transformers import CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPEncoderLayer
 from typing_extensions import override
 
-from fusion_bench
+from fusion_bench import (
+    BaseAlgorithm,
+    auto_register_config,
+    print_parameters,
+    timeit_context,
+)
+from fusion_bench.dataset import CLIPDataset
 from fusion_bench.method.task_arithmetic import task_arithmetic_merge
 from fusion_bench.mixins.clip_classification import CLIPClassificationMixin
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
-from fusion_bench.utils import timeit_context
 from fusion_bench.utils.data import InfiniteDataLoader
-from fusion_bench.utils.parameters import print_parameters
 
 from .module import ParetoWeightEnsemblingModule
 from .utils import generate_simplex_grid
@@ -31,27 +35,13 @@ from .utils import generate_simplex_grid
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class PWEMoEAlgorithmForCLIP(
     BaseAlgorithm,
     SimpleProfilerMixin,
     CLIPClassificationMixin,
 ):
     modelpool: CLIPVisionModelPool = None
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "upscale_mlp": "upscale_mlp",
-        "upscale_attn": "upscale_attn",
-        "init_lambda": "init_lambda",
-        "router_hidden_layers": "router_hidden_layers",
-        "lr": "lr",
-        "num_steps": "num_steps",
-        "save_interval": "save_interval",
-        "alpha": "alpha",
-        "checkpoint_path": "checkpoint_path",
-        "eval_grid": "eval_grid",
-        "eval_grid_n": "eval_grid_n",
-        "eval_grid_m": "eval_grid_m",
-        "_dataloader_kwargs": "dataloader_kwargs",
-    }
 
     def __init__(
         self,
@@ -72,19 +62,6 @@ class PWEMoEAlgorithmForCLIP(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.upscale_mlp = upscale_mlp
-        self.upscale_attn = upscale_attn
-        self.init_lambda = init_lambda
-        self.router_hidden_layers = router_hidden_layers
-        self.lr = lr
-        self.num_steps = num_steps
-        self.save_interval = save_interval
-        self.alpha = alpha
-        self.checkpoint_path = checkpoint_path
-        self.eval_grid = eval_grid
-        self.eval_grid_n = eval_grid_n
-        self.eval_grid_m = eval_grid_m
-        self._dataloader_kwargs = dataloader_kwargs
 
     @override
     def run(self, modelpool: CLIPVisionModelPool):
@@ -193,13 +170,14 @@ class PWEMoEAlgorithmForCLIP(
         Loads the datasets specified in the configuration.
         """
         train_datasets = {
-            dataset_name:
-                dataset_name,
+            dataset_name: CLIPDataset(
+                self.modelpool.load_train_dataset(dataset_name),
+                processor=self.clip_processor,
             )
             for dataset_name in self.modelpool.model_names
         }
         train_loaders = {
-            dataset_name: DataLoader(dataset, shuffle=True, **self.
+            dataset_name: DataLoader(dataset, shuffle=True, **self.dataloader_kwargs)
             for dataset_name, dataset in train_datasets.items()
         }
         train_loaders = {
fusion_bench/method/randes/modelsoup.py
CHANGED
@@ -5,9 +5,7 @@ import torch
 
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.parameters import count_parameters
-from fusion_bench.utils.state_dict_arithmetic import
-    state_dict_mul,
-)
+from fusion_bench.utils.state_dict_arithmetic import state_dict_mul
 
 from .base_algorithm import SuperposedAlgorithmBase, compare_models
 
fusion_bench/method/regmean/clip_regmean.py
CHANGED
@@ -27,7 +27,7 @@ class RegMeanAlgorithmForCLIP(
 
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self.
+        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
@@ -60,7 +60,7 @@ class RegMeanAlgorithmForCLIP(
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
         train_dataloader = DataLoader(
-            train_dataset, shuffle=True, **self.
+            train_dataset, shuffle=True, **self.dataloader_kwargs
         )
         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
         model = self.fabric.setup(model)
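For orientation on what `RegMeanAlgorithmForCLIP` computes with those dataloaders: RegMean (Jin et al., 2023) merges each linear layer in closed form using per-model Gram matrices of the layer inputs, $W^* = (\sum_i G_i)^{-1} \sum_i G_i W_i$ with $G_i = X_i^\top X_i$. A hedged sketch of the core formula, not fusion_bench's exact code:

```python
import torch


def regmean_merge(grams, weights):
    """grams[i]: (d, d) Gram matrix of layer inputs; weights[i]: (d, k) weight."""
    lhs = sum(grams)                                  # sum_i G_i
    rhs = sum(g @ w for g, w in zip(grams, weights))  # sum_i G_i W_i
    return torch.linalg.solve(lhs, rhs)               # (sum G_i)^-1 (sum G_i W_i)
```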