fusion-bench 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +0 -1
- fusion_bench/compat/modelpool/__init__.py +2 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +21 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +50 -8
- fusion_bench/dataset/llama/collate.py +10 -3
- fusion_bench/method/__init__.py +3 -0
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +47 -5
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +58 -23
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/mixins/clip_classification.py +2 -7
- fusion_bench/mixins/lightning_fabric.py +2 -2
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/RECORD +36 -29
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +4 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +13 -7
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/top_level.txt +0 -0
@@ -29,6 +29,7 @@ class AlgorithmFactory:
         "clip_weight_ensembling_moe": ".we_moe.clip_we_moe.CLIPWeightEnsemblingMoEAlgorithm",
         "sparse_clip_weight_ensembling_moe": "fusion_bench.method.SparseCLIPWeightEnsemblingMoEAlgorithm",
         "smile_mistral_upscaling": ".smile_upscaling.smile_mistral_upscaling.SmileMistralUpscalingAlgorithm",
+        "rankone_moe": ".rankone_moe.clip_rankone_moe.CLIPRankOneMoEAlgorithm",
     }
 
     @staticmethod
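The entry added above registers the new RankOne-MoE algorithm under the key `rankone_moe` as a dotted import path. As a rough illustration of how such a string-keyed registry can be resolved to a class at runtime (a generic sketch, not `AlgorithmFactory`'s actual implementation; the `_registry` and `resolve` names below are made up, and a stdlib class is used so the snippet runs without fusion_bench installed):

```python
import importlib

# Hypothetical registry in the same style as the mapping above.
_registry = {"ordered_dict": "collections.OrderedDict"}


def resolve(name: str):
    """Split the dotted path into module and attribute, import the module, return the class."""
    module_path, _, attr = _registry[name].rpartition(".")
    return getattr(importlib.import_module(module_path), attr)


print(resolve("ordered_dict"))  # <class 'collections.OrderedDict'>
```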
@@ -26,7 +26,6 @@ class ModelFusionAlgorithm(ABC):
             algorithm_config (Optional[DictConfig]): Configuration for the algorithm. Defaults to an empty configuration if not provided.
                 Get access to the configuration using `self.config`.
         """
-        super().__init__()
         if algorithm_config is None:
             algorithm_config = DictConfig({})
         self.config = algorithm_config
@@ -1,4 +1,6 @@
 # flake8: noqa F401
+import warnings
+
 from omegaconf import DictConfig
 
 from fusion_bench.modelpool.huggingface_gpt2_classification import (
@@ -9,7 +11,6 @@ from fusion_bench.modelpool.PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPo
 from .AutoModelForSeq2SeqLM import AutoModelForSeq2SeqLMPool
 from .base_pool import DictModelPool, ListModelPool, ModelPool, to_modelpool
 from .huggingface_clip_vision import HuggingFaceClipVisionPool
-import warnings
 
 
 class ModelPoolFactory:
@@ -7,6 +7,7 @@ Task: a class to represent a task (task.test_example and task.train_examples are
 read_from_single_file: a function to read challenge problems and solutions from a single file
 make_submission: a function to create a submission file
 """
+
 import dataclasses
 import glob
 import json
@@ -15,7 +16,6 @@ from typing import List, Optional
 
 import numpy as np
 
-
 Grid = np.ndarray
 
 
@@ -66,7 +66,9 @@ class Example:
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, Example):
            return NotImplemented
-        return np.array_equal(self.input, other.input) and np.array_equal(self.output, other.output)
+        return np.array_equal(self.input, other.input) and np.array_equal(
+            self.output, other.output
+        )
 
     @classmethod
     def deserialize(cls, data: dict, test: bool = False) -> "Example":
@@ -150,7 +152,11 @@ class Task:
         tasks = []
         for test_data in data["test"]:
             task = cls.deserialize(
-                {"train": data["train"], "test": [test_data], "name": data.get("name", "")},
+                {
+                    "train": data["train"],
+                    "test": [test_data],
+                    "name": data.get("name", ""),
+                },
                 test=test,
             )
             tasks.append(task)
@@ -245,7 +251,9 @@ def make_submission(
     """
     Make a submission
     """
-    assert len(tasks) == len(predictions), "Number of tasks and predictions should be the same"
+    assert len(tasks) == len(
+        predictions
+    ), "Number of tasks and predictions should be the same"
 
     # sort by task_name alphabetically to ensure order of subtasks
     indices = np.argsort([task.name for task in tasks])
@@ -259,8 +267,12 @@ def make_submission(
         if task_name not in submissions:
             submissions[task_name] = []
 
-        assert len(prediction) == number_of_attempts, "Number of attempts should be the same"
-        attempts = {f"attempt_{j+1}": to_list(pred) for j, pred in enumerate(prediction)}
+        assert (
+            len(prediction) == number_of_attempts
+        ), "Number of attempts should be the same"
+        attempts = {
+            f"attempt_{j+1}": to_list(pred) for j, pred in enumerate(prediction)
+        }
         while len(submissions[task_name]) <= task_no:
             submissions[task_name].append({"attempt_1": [[0]], "attempt_2": [[0]]})
 
@@ -277,7 +289,9 @@ if __name__ == "__main__":
     arc_path = "/kaggle/input/arc-prize-2024/"
     tasks = read_tasks_from_single_file(arc_path + "arc-agi_training_challenges.json")
     print(tasks[0])
-    tasks = read_tasks_from_single_file(arc_path + "arc-agi_evaluation_challenges.json", test=True)
+    tasks = read_tasks_from_single_file(
+        arc_path + "arc-agi_evaluation_challenges.json", test=True
+    )
     print(tasks[0])
 
     tasks = read_tasks_from_single_file(
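The reformatted `make_submission` lines above show the expected shape of an ARC-Prize submission: each task name maps to a list of per-subtask dicts keyed `attempt_1`, `attempt_2`, ..., with `[[0]]` used as a placeholder grid. A minimal, self-contained sketch of that structure (the task id, grids, and output file name are made-up example values, not taken from the package):

```python
import json

# One task with a single subtask and two attempted output grids (toy values).
predictions = {"00576224": [([[1, 2], [3, 4]], [[0]])]}  # task_id -> [(attempt_1, attempt_2)]

submission = {}
for task_name, subtasks in predictions.items():
    submission[task_name] = [
        {f"attempt_{j + 1}": grid for j, grid in enumerate(attempts)}
        for attempts in subtasks
    ]

with open("submission.json", "w") as f:
    json.dump(submission, f)
# -> {"00576224": [{"attempt_1": [[1, 2], [3, 4]], "attempt_2": [[0]]}]}
```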
@@ -5,15 +5,16 @@ import sys
 from multiprocessing import Pool
 from typing import Any, Dict, List, Literal, Optional
 
-import fusion_bench
 import numpy as np
 from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from lightning.fabric.utilities import rank_zero_only
 from tqdm.auto import tqdm
 from typing_extensions import TYPE_CHECKING
 
+import fusion_bench
+
 from .arc import Example, Task
-from .preprocess import get_augmenters, process_task
+from .preprocess import get_augmenters, process_task, process_task_for_ttt
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -65,7 +66,7 @@ def _join_list(lists: List[List[Any]]) -> List[Any]:
     return ans
 
 
-def _to_task(
+def _to_tasks(
     train_data: List[Dict[str, Any]],
     test_data: List[Dict[str, Any]],
     name: str,
@@ -87,7 +88,7 @@ def _to_task(
     return tasks
 
 
-def _tokenizer_tasks(
+def tokenizer_tasks_for_ttt(
     tasks: List[Task],
     tokenizer: "PreTrainedTokenizer",
     use_data_augmentation: bool = True,
@@ -106,7 +107,7 @@ def _tokenizer_tasks(
 
     formatter = _get_formatter("new")
     processor = functools.partial(
-        process_task,
+        process_task_for_ttt,
         augmenters=augmenters_to_apply,
         formatter=formatter,
         tokenizer=tokenizer,
@@ -133,7 +134,34 @@ def _tokenizer_tasks(
     return dataset
 
 
-def load_tokenized_arc_agi_dataset(
+def tokenizer_tasks(
+    tasks: List[Task],
+    tokenizer: "PreTrainedTokenizer",
+):
+    formatter = _get_formatter("new")
+    processor = functools.partial(
+        process_task, formatter=formatter, tokenizer=tokenizer
+    )
+
+    # with Pool(multiprocessing.cpu_count()) as p:
+    # data = p.map(processor, tasks)
+    data = _join_list(
+        [
+            processor(task)
+            for task in tqdm(
+                tasks,
+                desc="Processing tasks",
+                dynamic_ncols=True,
+                leave=False,
+                disable=not rank_zero_only.rank == 0,
+            )
+        ]
+    )
+    dataset = Dataset.from_list(data)
+    return dataset
+
+
+def load_tokenized_arc_agi_dataset_for_ttt(
     tokenizer: Optional["PreTrainedTokenizer"],
     path: str = "dataartist/arc-agi",
     split: Optional[str] = None,
@@ -144,47 +172,47 @@ def load_tokenized_arc_agi_dataset(
     max_num_tasks: Optional[int] = None,
 ):
     # regularize split
-
-        split = "training"
-    if split.lower() == "test":
-        split = "evaluation"
+    split = split.lower() if split is not None else split
 
     # load cached dataset if available
     if cache_path is not None and fusion_bench.utils.path.path_is_dir_and_not_empty(
         cache_path
     ):
         datasets = load_from_disk(cache_path)
-        if split is None:
-            return datasets
-        else:
+        if split is None and split in datasets.column_names:
             return datasets[split]
+        else:
+            return datasets
     else:
         assert (
             tokenizer is not None
         ), "Cached dataset not found. Need tokenizer to process the raw data."
 
         # load raw dataset
-        datasets = load_dataset(path
+        datasets = load_dataset(path)
+        datasets = DatasetDict(
+            {"train": datasets["training"], "test": datasets["evaluation"]}
+        )
         if split is None:
-            converted_datasets = {
+            converted_datasets: Dict[str, List[Task]] = {
                 "train": _join_list(
                     [
-                        _to_task(
+                        _to_tasks(
                             task["train"],
                             task["test"],
                             task["id"],
                         )
-                        for task in datasets["training"]
+                        for task in datasets["train"]
                     ]
                 ),
                 "test": _join_list(
                     [
-                        _to_task(
+                        _to_tasks(
                             task["train"],
                             task["test"],
                             task["id"],
                         )
-                        for task in datasets["evaluation"]
+                        for task in datasets["test"]
                     ]
                 ),
             }
@@ -195,7 +223,7 @@ def load_tokenized_arc_agi_dataset(
                 for split in converted_datasets
             }
             converted_datasets = {
-                split: _tokenizer_tasks(
+                split: tokenizer_tasks_for_ttt(
                     converted_datasets[split],
                     tokenizer,
                     use_data_augmentation,
@@ -210,25 +238,128 @@ def load_tokenized_arc_agi_dataset(
                 )
             }
             converted_datasets = DatasetDict(converted_datasets)
-        else:
+        else:  # split is not None
             converted_datasets = _join_list(
                 [
-                    _to_task(
+                    _to_tasks(
                         task["train"],
                         task["test"],
                         task["id"],
                     )
-                    for task in datasets
+                    for task in datasets[split]
                 ]
             )
             if max_num_tasks is not None:
                 # limit the number of tasks, useful for debugging
                 converted_datasets = converted_datasets[:max_num_tasks]
-            converted_datasets = _tokenizer_tasks(
+            converted_datasets = tokenizer_tasks_for_ttt(
                 converted_datasets, tokenizer, use_data_augmentation, permute_n, seed
             )
 
-    if cache_path is not None:
+    if cache_path is not None and rank_zero_only.rank == 0:
+        os.makedirs(cache_path, exist_ok=True)
+        converted_datasets.save_to_disk(cache_path)
+    return converted_datasets
+
+
+def load_tokenized_arc_agi_dataset(
+    tokenizer: Optional["PreTrainedTokenizer"],
+    path: str = "dataartist/arc-agi",
+    split: Optional[str] = None,
+    cache_path: Optional[str] = None,
+    max_num_tasks: Optional[int] = None,
+):
+    """
+    Loads and tokenizes the ARC-AGI dataset.
+
+    Args:
+        tokenizer (Optional[PreTrainedTokenizer]): The tokenizer to use for tokenizing the dataset.
+        path (str, optional): The path to the dataset. Defaults to "dataartist/arc-agi".
+        split (Optional[str], optional): The dataset split to load (e.g., "train", "test"). Defaults to None.
+        cache_path (Optional[str], optional): The path to cache the processed dataset. Defaults to None.
+        max_num_tasks (Optional[int], optional): The maximum number of tasks to load. Useful for debugging. Defaults to None.
+
+    Returns:
+        DatasetDict or Dataset: The tokenized dataset, either as a DatasetDict if split is None, or as a Dataset if a specific split is specified.
+    """
+    # regularize split
+    split = split.lower() if split is not None else split
+
+    # load cached dataset if available
+    if cache_path is not None and fusion_bench.utils.path.path_is_dir_and_not_empty(
+        cache_path
+    ):
+        datasets = load_from_disk(cache_path)
+        if split is None and split in datasets.column_names:
+            return datasets[split]
+        else:
+            return datasets
+    else:
+        assert (
+            tokenizer is not None
+        ), "Cached dataset not found. Need tokenizer to process the raw data."
+
+        # load raw dataset
+        datasets = load_dataset(path)
+        datasets = DatasetDict(
+            {"train": datasets["training"], "test": datasets["evaluation"]}
+        )
+        if split is None:
+            converted_datasets: Dict[str, List[Task]] = {
+                "train": _join_list(
+                    [
+                        _to_tasks(
+                            task["train"],
+                            task["test"],
+                            task["id"],
+                        )
+                        for task in datasets["train"]
+                    ]
+                ),
+                "test": _join_list(
+                    [
+                        _to_tasks(
+                            task["train"],
+                            task["test"],
+                            task["id"],
+                        )
+                        for task in datasets["test"]
+                    ]
+                ),
+            }
+            if max_num_tasks is not None:
+                # limit the number of tasks, useful for debugging
+                converted_datasets = {
+                    split: converted_datasets[split][:max_num_tasks]
+                    for split in converted_datasets
+                }
+            converted_datasets = {
+                split: tokenizer_tasks(converted_datasets[split], tokenizer)
+                for split in tqdm(
+                    converted_datasets,
+                    desc="Processing splits",
+                    dynamic_ncols=True,
+                    disable=not rank_zero_only.rank == 0,
+                )
+            }
+            converted_datasets = DatasetDict(converted_datasets)
+        else:  # split is not None
+            converted_datasets = _join_list(
+                [
+                    _to_tasks(
+                        task["train"],
+                        task["test"],
+                        task["id"],
+                    )
+                    for task in datasets[split]
+                ]
+            )
+            if max_num_tasks is not None:
+                # limit the number of tasks, useful for debugging
+                converted_datasets = converted_datasets[:max_num_tasks]
+            converted_datasets = tokenizer_tasks(converted_datasets, tokenizer)
+
+    if cache_path is not None and rank_zero_only.rank == 0:
         os.makedirs(cache_path, exist_ok=True)
         converted_datasets.save_to_disk(cache_path)
     return converted_datasets
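For context, a hedged usage sketch of the new `load_tokenized_arc_agi_dataset` shown above. The function, its module path, and its defaults come from this diff; the tokenizer checkpoint and cache directory below are arbitrary examples, and any tokenizer that provides a chat template should work since the preprocessing relies on `apply_chat_template`:

```python
from transformers import AutoTokenizer

from fusion_bench.dataset.arc_agi.arc_agi import load_tokenized_arc_agi_dataset

# Example values: any chat-template tokenizer and any writable cache directory.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
train_ds = load_tokenized_arc_agi_dataset(
    tokenizer,
    split="train",
    cache_path="outputs/arc_agi_tokenized",  # processed dataset is cached and reused later
)
print(train_ds)
```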
@@ -112,19 +112,41 @@ def get_augmenters(
 def format_and_filter(
     formatter: MessageRepresenter,
     tokenizer: "PreTrainedTokenizer",
-    task,
+    task: Task,
 ):
+    """
+    Formats and filters a task for model input.
+
+    Args:
+        formatter (MessageRepresenter): The formatter to encode the task.
+        tokenizer (PreTrainedTokenizer): The tokenizer to tokenize the conversation.
+        task: The task to be formatted and filtered.
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the formatted data with keys:
+            - "input_ids": The tokenized input IDs.
+            - "attention_mask": The attention mask for the input IDs.
+            - "labels": The labels for the input IDs.
+            - "task_id": The task ID.
+            - "num_prompt_tokens": The number of prompt tokens.
+            - "num_output_tokens": The number of output tokens.
+    """
+    task_id = task.name
     task = formatter.encode(task)
     conversation = task[0] + [task[1]]
     assert conversation[-1]["role"] == "assistant", "Last message should be assistant"
     prompt_tokens = tokenizer.apply_chat_template(
         conversation[:-1], tokenize=True, add_generation_prompt=True
     )
-
+    generation_tokens = tokenizer.apply_chat_template(conversation, tokenize=True)
+    output_tokens = generation_tokens[len(prompt_tokens) :]
     data = {
         "input_ids": prompt_tokens + output_tokens,
         "attention_mask": [1] * len(prompt_tokens) + [1] * len(output_tokens),
         "labels": [-100] * len(prompt_tokens) + output_tokens,
+        "task_id": task_id,
+        "num_prompt_tokens": len(prompt_tokens),
+        "num_output_tokens": len(output_tokens),
     }
     return data
 
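The `format_and_filter` changes above make the prompt/completion split explicit: the prompt portion of `labels` is filled with -100 so that only the completion tokens contribute to the loss. A tiny self-contained illustration of that masking pattern, using made-up token ids rather than real tokenizer output:

```python
# Toy token ids standing in for tokenizer output; -100 is ignored by cross-entropy.
prompt_tokens = [101, 2023, 2003, 1037]
output_tokens = [3437, 2003, 102]

input_ids = prompt_tokens + output_tokens
attention_mask = [1] * len(input_ids)
labels = [-100] * len(prompt_tokens) + output_tokens

assert len(input_ids) == len(labels) == len(attention_mask)
print(labels)  # [-100, -100, -100, -100, 3437, 2003, 102]
```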
@@ -136,6 +158,19 @@ def get_test_time_train_data(
     permute_n: int = 1,
     seed: int = 0,
 ) -> List[Task]:
+    """
+    Generates augmented training data for test-time training.
+
+    Args:
+        original_task (Task): The original task containing training examples.
+        augmenters (List[Augmenter]): A list of augmenters to apply to the tasks.
+        n (int, optional): The number of examples to leave out for testing. Defaults to 1.
+        permute_n (int, optional): The number of times to permute the augmented tasks. Defaults to 1.
+        seed (int, optional): The random seed for reproducibility. Defaults to 0.
+
+    Returns:
+        List[Task]: A list of augmented tasks.
+    """
     rng = np.random.RandomState(seed)
     train_examples = original_task.train_examples.copy()
     initial_tasks = []
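The docstring added above describes the leave-n-out construction used for test-time training: each training example takes a turn as the held-out test example while the remaining examples form a smaller training set. A toy sketch of the idea with plain strings instead of `Example`/`Task` objects (leave_n = 1 for simplicity; larger values would take `itertools.combinations` of the remaining examples):

```python
examples = ["ex0", "ex1", "ex2", "ex3"]

ttt_splits = []
for i, held_out in enumerate(examples):
    train_set = examples[:i] + examples[i + 1 :]  # leave one example out
    ttt_splits.append({"train": train_set, "test": held_out})

for split in ttt_splits:
    print(split)
```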
@@ -150,7 +185,7 @@ def get_test_time_train_data(
     for comb in combs:
         initial_tasks.append(
             Task(
-                name=
+                name=original_task.name,
                 train_examples=[examples[j] for j in comb],
                 test_example=examples[i],
             )
@@ -183,7 +218,6 @@ def get_test_time_train_data(
             color_and_permute_augmented_tasks.append(new_task)
 
     augmented_tasks = color_and_permute_augmented_tasks + augmented_tasks
-
     augmented_tasks = list(set(augmented_tasks))
 
     return augmented_tasks
@@ -193,13 +227,12 @@ def get_formatted_data(
     task: Task,
     augmenters: List[Augmenter],
     formatter: MessageRepresenter,
-    tokenizer,
+    tokenizer: "PreTrainedTokenizer",
     leave_n: int = 1,
     permute_n: int = 1,
     seed: int = 0,
     max_tokens: int = 8192,
 ):
-
     train_data = get_test_time_train_data(
         task, augmenters, n=leave_n, permute_n=permute_n, seed=seed
     )
@@ -213,11 +246,11 @@ def get_formatted_data(
     return formatted_data
 
 
-def process_task(
+def process_task_for_ttt(
     task: Task,
     augmenters: List[Augmenter],
     formatter: MessageRepresenter,
-    tokenizer,
+    tokenizer: "PreTrainedTokenizer",
     permute_n: int = 1,
     Nmax: int = 250,
     seed: int = 0,
@@ -254,3 +287,12 @@ def process_task(
         train = train[:Nmax]
 
     return train
+
+
+def process_task(
+    task: Task,
+    formatter: MessageRepresenter,
+    tokenizer: "PreTrainedTokenizer",
+):
+    formatted = format_and_filter(formatter, tokenizer, task)
+    return [formatted]
@@ -13,7 +13,8 @@ def padded_collate_sft(
     labels_key: Optional[str] = "labels",
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
-    """Pad (right) a batch of sequences to the longest sequence length in the batch, and
+    """
+    Pad (right) a batch of sequences to the longest sequence length in the batch, and
     convert integer lists to tensors.
 
     Args:
@@ -44,10 +45,16 @@ def padded_collate_sft(
     )
 
     if attention_mask is not None:
-        return {
+        collated_batch = {
             input_ids_key: input_ids,
             attention_mask_key: attention_mask,
             labels_key: labels,
         }
     else:
-        return {input_ids_key: input_ids, labels_key: labels}
+        collated_batch = {input_ids_key: input_ids, labels_key: labels}
+
+    for key in batch[0]:
+        if key not in [input_ids_key, attention_mask_key, labels_key]:
+            collated_batch[key] = [x[key] for x in batch]
+
+    return collated_batch
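The collate change above stops returning the dict directly and instead builds `collated_batch`, then passes any extra per-sample fields (such as the new `task_id`, `num_prompt_tokens`, and `num_output_tokens`) through as plain lists. A rough stand-alone equivalent in plain PyTorch, not fusion_bench's exact implementation (`toy_padded_collate` and its defaults are made up for illustration):

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def toy_padded_collate(batch, pad_id=0, ignore_idx=-100):
    # Right-pad input_ids with the pad id and labels with the ignore index.
    input_ids = pad_sequence(
        [torch.tensor(x["input_ids"]) for x in batch], batch_first=True, padding_value=pad_id
    )
    labels = pad_sequence(
        [torch.tensor(x["labels"]) for x in batch], batch_first=True, padding_value=ignore_idx
    )
    out = {"input_ids": input_ids, "labels": labels}
    # Carry any extra per-sample fields through as plain lists, as in the diff above.
    for key in batch[0]:
        if key not in ("input_ids", "labels"):
            out[key] = [x[key] for x in batch]
    return out


batch = [
    {"input_ids": [1, 2, 3], "labels": [-100, 2, 3], "task_id": "a"},
    {"input_ids": [4, 5], "labels": [-100, 5], "task_id": "b"},
]
print(toy_padded_collate(batch))
```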
fusion_bench/method/__init__.py (CHANGED)

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
+    # --------------
     "base_algorithm": ["BaseModelFusionAlgorithm", "BaseAlgorithm"],
     "dummy": ["DummyAlgorithm"],
     # single task learning (fine-tuning)
@@ -64,6 +65,7 @@ _import_structure = {
     ],
     "dawe": ["DataAdaptiveWeightEnsemblingForCLIP"],
     "we_moe": ["CLIPWeightEnsemblingMoEAlgorithm"],
+    "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
     "sparse_we_moe": [
         "SparseWeightEnsemblingMoEAlgorithm",
         "SparseCLIPWeightEnsemblingMoEAlgorithm",
@@ -134,6 +136,7 @@ if TYPE_CHECKING:
         PWEMoELinearScalarizationForCLIP,
         PWEMoExactParetoOptimalForCLIP,
     )
+    from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
     from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
     from .simple_average import SimpleAverageAlgorithm
     from .smile_upscaling import (
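The `_import_structure` and `TYPE_CHECKING` edits above extend fusion_bench's lazy-import table with the new RankOne-MoE entries. For readers unfamiliar with the pattern, here is a generic sketch of how such a table is typically consumed through a module-level `__getattr__` (PEP 562); this is illustrative only, not `LazyImporter`'s actual code, and the `collections`/`OrderedDict` entry is a stand-in:

```python
import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "collections": ["OrderedDict"],  # module -> exported names, mirroring the table above
}

if TYPE_CHECKING:
    # Static type checkers and IDEs see the real imports without paying the runtime cost.
    from collections import OrderedDict


def __getattr__(name):
    """Import the owning module only when one of its names is first accessed."""
    for module_name, names in _import_structure.items():
        if name in names:
            return getattr(importlib.import_module(module_name), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```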
@@ -1,6 +1,6 @@
 # flake8: noqa F401
 from .clip_layer_wise_adamerging import CLIPLayerWiseAdaMergingAlgorithm
 from .clip_task_wise_adamerging import CLIPTaskWiseAdaMergingAlgorithm
+from .flan_t5_layer_wise_adamerging import FlanT5LayerWiseAdaMergingAlgorithm
 from .gpt2_layer_wise_adamerging import GPT2LayerWiseAdaMergingAlgorithm
 from .llama_adamerging import LayerWiseAdaMergingForLlamaSFT
-from .flan_t5_layer_wise_adamerging import FlanT5LayerWiseAdaMergingAlgorithm