fusion-bench 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +0 -1
- fusion_bench/compat/modelpool/__init__.py +2 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +21 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +50 -8
- fusion_bench/dataset/llama/collate.py +10 -3
- fusion_bench/method/__init__.py +3 -0
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +47 -5
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +58 -23
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/mixins/clip_classification.py +2 -7
- fusion_bench/mixins/lightning_fabric.py +2 -2
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/RECORD +36 -29
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +4 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +13 -7
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/top_level.txt +0 -0

fusion_bench/method/lm_finetune/fullfinetune_sft.py
@@ -13,6 +13,7 @@ from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import ConcatDataset, DataLoader
 from tqdm.auto import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from typing_extensions import TYPE_CHECKING, override

 from fusion_bench import BaseAlgorithm, BaseModelPool
@@ -117,6 +118,9 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             self.model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": True}
             )
+            self.use_cache = False
+        else:
+            self.use_cache = True
         self.model_dtype = get_dtype(self.model)

     def configure_optimizer(self):
@@ -215,7 +219,12 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
-                output = self.model(
+                output = self.model(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"],
+                    use_cache=self.use_cache,
+                )
                 loss = output["loss"]

                 fabric.backward(loss)
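Context for the two changes above (this note is not part of the diff): Hugging Face causal LMs cannot build the generation KV cache inside a gradient-checkpointed forward pass, so the trainer now remembers whether checkpointing is on (`self.use_cache`) and passes the flag explicitly instead of relying on the model default. A minimal, self-contained sketch of the same interaction using a tiny randomly initialized GPT-2; the model size and batch are illustrative only:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny randomly initialized causal LM, enough to show the interaction.
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=128))

gradient_checkpointing = True
if gradient_checkpointing:
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": True}
    )
    use_cache = False  # the KV cache cannot be reused across checkpointed segments
else:
    use_cache = True

model.train()
input_ids = torch.randint(0, 128, (2, 16))
output = model(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    labels=input_ids,
    use_cache=use_cache,
)
output.loss.backward()
print(float(output.loss))
```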
@@ -252,7 +261,7 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             ):
                 break
             # break if max_steps is set, and exit training
-            if self.max_steps > 0 and self.global_step_idx >= self.max_steps:
+            if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                 self.is_training = False
                 break

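The stopping condition moves from `>= self.max_steps` to `>= self.max_steps - 1`. Assuming the check follows the step's work, as the surrounding context suggests, the old condition executed one extra optimizer step because `global_step_idx` is zero-based. A standalone illustration of the counting (the loop body stands in for a training step):

```python
def steps_executed(max_steps: int, threshold: int) -> int:
    executed = 0
    for global_step_idx in range(10_000):
        executed += 1  # the training step happens before the check
        if max_steps > 0 and global_step_idx >= threshold:
            break
    return executed

print(steps_executed(10, threshold=10))      # old check: 11 steps
print(steps_executed(10, threshold=10 - 1))  # new check: 10 steps
```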
@@ -328,14 +337,15 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                         "checkpoints",
                         "latest_model.ckpt",
                     ),
-                    os.path.join(
+                    dst := os.path.join(
                         self.log_dir,
                         "checkpoints",
                         f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
                     ),
+                    target_is_directory=os.path.isdir(dst),
                 )
             except Exception as e:
-
+                log.error(f"Failed to create symlink: {e}")
         else:
             raise ValueError(
                 f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
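The `latest_model.ckpt` symlink now records its target with the walrus operator so `target_is_directory` can be set from `os.path.isdir(dst)` (the flag only matters on Windows, and sharded FSDP checkpoints can be directories rather than single files). The same pattern in isolation, written with `os.symlink`'s (target, link) argument order; the paths are made up for the example:

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as log_dir:
    ckpt_dir = os.path.join(log_dir, "checkpoints")
    # Pretend the checkpoint is a directory, as sharded FSDP checkpoints are.
    os.makedirs(os.path.join(ckpt_dir, "epoch=0_step=42.ckpt"))

    os.symlink(
        # := names the target so we can inspect it in the next argument
        dst := os.path.join(ckpt_dir, "epoch=0_step=42.ckpt"),
        os.path.join(ckpt_dir, "latest_model.ckpt"),
        target_is_directory=os.path.isdir(dst),
    )
    print(os.readlink(os.path.join(ckpt_dir, "latest_model.ckpt")))
```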
@@ -364,8 +374,15 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 }
             )

+        trainable_param_names = set(
+            name
+            for name, param in self.model.state_dict(keep_vars=True).items()
+            if param.requires_grad
+        )
         filter = (
-            None
+            None
+            if self.save_full_model
+            else {"model": lambda k, p: k in trainable_param_names}
         )

         fabric.save(path, state=state, filter=filter)
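The `filter` argument of `lightning.fabric.Fabric.save` (available in recent Lightning releases) maps a state key to a predicate over (parameter name, value); here it keeps only parameters with `requires_grad=True` unless `save_full_model` is set. A minimal sketch with a toy module; the module, file name, and frozen layer are illustrative:

```python
import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].requires_grad_(False)  # pretend the first layer is frozen

trainable_param_names = {
    name
    for name, param in model.state_dict(keep_vars=True).items()
    if param.requires_grad
}
fabric.save(
    "trainable_only.ckpt",
    {"model": model},
    filter={"model": lambda k, p: k in trainable_param_names},
)
# Only the unfrozen layer's tensors are written out.
print(torch.load("trainable_only.ckpt")["model"].keys())
```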
@@ -401,3 +418,28 @@ def load_checkpoint(
     state = {"model": model}
     state.update(state_components)
     fabric.load(ckpt_path, state=state, strict=strict)
+
+
+if __name__ == "__main__":
+    # convert a checkpoint to hf format
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base_model_path", type=str)
+    parser.add_argument("--ckpt_path", type=str)
+    parser.add_argument("--output_path", type=str)
+
+    args = parser.parse_args()
+
+    fabric = L.Fabric(devices=1, strategy="fsdp")
+    fabric.launch()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
+    tokenizer.save_pretrained(args.output_path)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.base_model_path, torch_dtype=torch.bfloat16
+    )
+    model = fabric.setup_module(model)
+    load_checkpoint(fabric, args.ckpt_path, model=model, strict=True)
+    model.save_pretrained(args.output_path)
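The new `__main__` block turns this module into a small CLI for exporting a Fabric/FSDP checkpoint back to a Hugging Face directory: it saves the tokenizer, re-creates the base model in bfloat16, restores the trained weights with `load_checkpoint`, and writes everything out with `save_pretrained`. It relies on `torch` and `lightning` (imported as `L`) already being available at module scope, and would be invoked roughly as `python -m fusion_bench.method.lm_finetune.fullfinetune_sft --base_model_path ... --ckpt_path ... --output_path ...` (this invocation is an assumption, not documented in the diff).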

fusion_bench/method/lm_finetune/peftfinetune_sft.py
@@ -10,10 +10,10 @@ import peft
 import torch
 from lightning.fabric.strategies.fsdp import FSDPStrategy
 from lightning.fabric.utilities import rank_zero_only
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from peft import PeftModel, get_peft_config, get_peft_model
 from torch import nn
-from torch.utils.data import
+from torch.utils.data import ConcatDataset, DataLoader
 from tqdm.auto import tqdm
 from typing_extensions import TYPE_CHECKING, override

@@ -65,7 +65,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         gradient_clip_algorithm: Literal["value", "norm"] = "norm",
         save_optimizer_state: bool = False,
         save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "peft"] = "peft",
         ckpt_path: Optional[str] = None,
+        max_length: int = 6150,
         **kwargs,
     ):
         """
@@ -90,6 +92,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
             save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
             save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type(str): Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.
             ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
         """
         self._optimizer = optimizer
@@ -110,7 +113,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         self.gradient_clip_algorithm = gradient_clip_algorithm
         self.save_optimizer_state = save_optimizer_state
         self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
         self.ckpt_path = ckpt_path
+        self.max_length = max_length
         super().__init__(**kwargs)

     def run(self, modelpool: CausalLMPool):
@@ -126,7 +131,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         model = self.modelpool.load_pretrained_model()

         # get the PEFT model
-        peft_config =
+        peft_config = instantiate(self._peft_config, _convert_="all")
         peft_model = get_peft_model(model, peft_config, self.adapter_name)
         peft_model.print_trainable_parameters()

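`instantiate(self._peft_config, _convert_="all")` follows the Hydra convention: the config node names a `_target_` class plus its keyword arguments, and `_convert_="all"` turns nested OmegaConf containers into plain Python objects before the call. An illustrative sketch only; this LoRA config is made up (the real one lives under fusion_bench_config/method/lm_finetune/), and `instantiate` is assumed to be Hydra-compatible:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical PEFT config node for illustration.
cfg = OmegaConf.create(
    {
        "_target_": "peft.LoraConfig",
        "r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "target_modules": ["q_proj", "v_proj"],
        "task_type": "CAUSAL_LM",
    }
)
peft_config = instantiate(cfg, _convert_="all")
print(type(peft_config).__name__, peft_config.r)
```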
@@ -139,6 +144,10 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             self.model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": True}
             )
+            self.use_cache = False
+        else:
+            self.use_cache = True
+
         self.model_dtype = get_dtype(self.model)

     def configure_optimizer(self):
@@ -234,10 +243,22 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         ):
             is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
+                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
+                batch["labels"] = batch["labels"][:, : self.max_length]
             # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
-                output = self.model(
+                output = self.model(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"],
+                    use_cache=self.use_cache,
+                )
                 loss = output["loss"]

                 fabric.backward(loss)
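The new `max_length` guard (default 6150) truncates over-long batches, presumably to bound activation memory; `input_ids`, `attention_mask`, and `labels` are sliced together so they stay aligned. The same slicing on a toy batch:

```python
import torch

max_length = 8
batch = {
    "input_ids": torch.arange(24).reshape(2, 12),
    "attention_mask": torch.ones(2, 12, dtype=torch.long),
    "labels": torch.arange(24).reshape(2, 12),
}
if max_length > 0 and batch["input_ids"].shape[1] > max_length:
    for key in ("input_ids", "attention_mask", "labels"):
        batch[key] = batch[key][:, :max_length]
print({k: tuple(v.shape) for k, v in batch.items()})  # every tensor is now (2, 8)
```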
@@ -274,7 +295,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             ):
                 break
             # break if max_steps is set, and exit training
-            if self.max_steps > 0 and self.global_step_idx >= self.max_steps:
+            if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                 self.is_training = False
                 break

@@ -350,14 +371,15 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                         "checkpoints",
                         "latest_model.ckpt",
                     ),
-                    os.path.join(
+                    dst := os.path.join(
                         self.log_dir,
                         "checkpoints",
                         f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
                     ),
+                    target_is_directory=os.path.isdir(dst),
                 )
             except Exception as e:
-
+                log.error(f"Failed to create symlink: {e}")
         else:
             raise ValueError(
                 f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
@@ -373,24 +395,37 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

         fabric = self.fabric
-
-
-
-
-
-
-
-
-
-
-
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
             )

-
-
-
-
-
+            fabric.save(path, state=state, filter=filter)
+        elif self.save_ckpt_type == "peft":
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
+        else:
+            raise ValueError(
+                f"Unknown save_ckpt_type: {self.save_ckpt_type}. Available options: 'lightning', 'peft'"
+            )
         self._latest_saved_checkpoint_global_step = self.global_step_idx

     def load_checkpoint(self, path: Union[str, Path]):
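With `save_ckpt_type`, the trainer can either keep the previous Fabric checkpoint format ('lightning', reloadable via `fabric.load`) or write a standard PEFT adapter directory via `save_pretrained` ('peft', the new default). A hedged sketch of how a 'peft'-style checkpoint is typically reloaded; the paths are placeholders, not from the diff:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base_model")         # placeholder
model = PeftModel.from_pretrained(base, "path/to/saved_peft_checkpoint")  # placeholder
model = model.merge_and_unload()  # optionally fold the adapter into the base weights
```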

fusion_bench/method/pruning/magnitude_diff_pruning.py
@@ -1,3 +1,4 @@
+import functools
 import logging
 import re
 from copy import deepcopy
@@ -10,7 +11,7 @@ from tqdm.auto import tqdm
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
-
+
 from .prune_utils import unstructured_magnitude_prune_

 log = logging.getLogger(__name__)

fusion_bench/method/rankone_moe/clip_rankone_moe.py (new file)
@@ -0,0 +1,160 @@
+import functools
+import logging
+import os
+from copy import deepcopy
+
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+from transformers.models.clip.modeling_clip import CLIPEncoder
+
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins import CLIPClassificationMixin
+from fusion_bench.modelpool import CLIPVisionModelPool
+from fusion_bench.models.rankone_moe import RankOneMoE
+from fusion_bench.utils.data import InfiniteDataLoader
+
+from .rankone_moe import RankOneMoEAlgorithm
+
+log = logging.getLogger(__name__)
+
+
+class CLIPRankOneMoEAlgorithm(
+    RankOneMoEAlgorithm,
+    CLIPClassificationMixin,
+):
+    """
+    CLIPRankOneMoEAlgorithm is a class that implements the RankOneMoEAlgorithm (https://github.com/EnnengYang/RankOne-MoE)
+    for CLIP models. It extends the RankOneMoEAlgorithm and CLIPClassificationMixin classes.
+
+    Attributes:
+        modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
+    """
+
+    modelpool: CLIPVisionModelPool = None
+
+    def load_checkpoint(self, model, checkpoint):
+        """
+        Load the checkpoint file.
+
+        Args:
+            model: The model to load the checkpoint into.
+            checkpoint: The path to the checkpoint file.
+        """
+        state = {"model": model}
+        self._fabric.load(checkpoint, state)
+
+    def save_checkpoint(self, model, checkpoint):
+        """
+        Save the checkpoint file.
+
+        Args:
+            model: The model to save the checkpoint from.
+            checkpoint: The path to the checkpoint file.
+        """
+        self._fabric.save(checkpoint, {"model": model})
+
+    def construct_moe_model(self) -> RankOneMoE:
+        """
+        Construct the RankOne-MoE model using the models in the model pool.
+
+        Returns:
+            RankOne-MoE: The constructed MoE model.
+        """
+        base_model = self.modelpool.load_model("_pretrained_")
+        expert_models = [
+            self.modelpool.load_model(m) for m in self.modelpool.model_names
+        ]
+
+        # Merge the models using task arithmetic
+        moe_model = task_arithmetic_merge(
+            # This function modifies the model in place, so we need to pass a deepcopy
+            deepcopy(base_model),
+            expert_models,
+            scaling_factor=self.config.init_lambda,
+        ).requires_grad_(False)
+
+        # Up-scale MLP modules
+        base_encoder: CLIPEncoder = base_model.vision_model.encoder
+        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
+        expert_encoders = [m.vision_model.encoder for m in expert_models]
+
+        num_layers = len(base_encoder.layers)
+        for layer_idx in range(num_layers):
+            base_mlp = base_encoder.layers[layer_idx].mlp
+            expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]
+
+            moe_encoder.layers[layer_idx].mlp = RankOneMoE(
+                hidden_size=base_encoder.config.hidden_size,
+                base_model=base_mlp,
+                expert_models=expert_mlps,
+                init_lambda=self.config.init_lambda,
+                batch_first=True,  # For open_clip models this is False
+                router_hidden_layers=self.config.router_hidden_layers,
+                batch_reduce=self.config.batch_reduce,
+                svd_accelerator=self.config.svd_accelerator,
+                rank_k=self.config.rank_k,
+                select_k=self.config.select_k,
+            )
+
+        return moe_model
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+        """
+        Get an iterator for the shuffled test data loader.
+
+        Args:
+            tta_dataset (str): The name of the test-time adaptation dataset.
+
+        Returns:
+            Iterator: An iterator for the shuffled test data loader.
+        """
+        dataset = self.modelpool.load_test_dataset(tta_dataset)
+        dataset = CLIPDataset(dataset, processor=self.clip_processor)
+        log.info("get_shuffled_test_loader_iter")
+        loader = DataLoader(
+            dataset,
+            batch_size=self.config.batch_size,
+            shuffle=True,
+            num_workers=self.config.num_workers,
+            pin_memory=True,
+        )
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def on_test_time_adaptation_start(self):
+        """
+        Load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        self.setup_zero_shot_classification_head()
+
+    def compute_logits(self, module, batch, task) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        Args:
+            module: The model module.
+            batch: The input batch.
+            task: The task name.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # Normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # Cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image