fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -2
  3. fusion_bench/compat/modelpool/__init__.py +3 -2
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  6. fusion_bench/dataset/arc_agi/arc.py +26 -7
  7. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  8. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  9. fusion_bench/dataset/arc_agi/preprocess.py +51 -9
  10. fusion_bench/dataset/llama/__init__.py +1 -0
  11. fusion_bench/dataset/llama/alpaca.py +93 -3
  12. fusion_bench/dataset/llama/collate.py +72 -5
  13. fusion_bench/dataset/llama/metamathqa.py +50 -0
  14. fusion_bench/dataset/llama/preference_700k.py +70 -0
  15. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  16. fusion_bench/dataset/llama/ultrachat.py +58 -0
  17. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  18. fusion_bench/method/__init__.py +4 -1
  19. fusion_bench/method/adamerging/__init__.py +1 -1
  20. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  21. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  22. fusion_bench/method/linear/expo.py +39 -0
  23. fusion_bench/method/lm_finetune/__init__.py +1 -0
  24. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  25. fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
  26. fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
  27. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  28. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  29. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  30. fusion_bench/method/rankone_moe/__init__.py +3 -0
  31. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  32. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  33. fusion_bench/method/simple_average.py +1 -1
  34. fusion_bench/method/surgery/__init__.py +3 -0
  35. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  36. fusion_bench/mixins/__init__.py +2 -0
  37. fusion_bench/mixins/clip_classification.py +60 -12
  38. fusion_bench/mixins/fabric_training.py +320 -0
  39. fusion_bench/mixins/lightning_fabric.py +11 -2
  40. fusion_bench/modelpool/__init__.py +2 -0
  41. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  42. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  43. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  44. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  45. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  46. fusion_bench/models/chat_templates/__init__.py +1 -0
  47. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  48. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  49. fusion_bench/models/hf_clip.py +50 -9
  50. fusion_bench/models/rankone_moe.py +410 -0
  51. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  52. fusion_bench/models/utils.py +8 -0
  53. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  54. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  55. fusion_bench/optim/__init__.py +2 -0
  56. fusion_bench/optim/exception.py +47 -0
  57. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  58. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  59. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  60. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  61. fusion_bench/optim/mezo.py +0 -2
  62. fusion_bench/programs/fabric_fusion_program.py +5 -1
  63. fusion_bench/taskpool/__init__.py +10 -2
  64. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  65. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  66. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  67. fusion_bench/taskpool/llama/reward_model.py +157 -0
  68. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  69. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  70. fusion_bench/utils/hydra_utils.py +22 -0
  71. fusion_bench/utils/plot/__init__.py +0 -0
  72. fusion_bench/utils/plot/token.py +52 -0
  73. fusion_bench/utils/plot/token_notebook.py +127 -0
  74. fusion_bench/utils/type.py +5 -3
  75. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  76. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
  77. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  78. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  79. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  80. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  81. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  84. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  85. fusion_bench_config/llama_full_finetune.yaml +19 -0
  86. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  87. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
  88. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
  89. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  90. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  91. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  92. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  93. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  94. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  95. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  96. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  97. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  98. fusion_bench_config/nyuv2_config.yaml +5 -1
  99. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  100. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  101. fusion_bench_config/llama_weighted_average.yaml +0 -26
  102. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  103. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  104. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  105. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,7 @@ class AlgorithmFactory:
21
21
  "clip_task_wise_adamerging": ".adamerging.clip_task_wise_adamerging.CLIPTaskWiseAdaMergingAlgorithm",
22
22
  "clip_layer_wise_adamerging": ".adamerging.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
23
23
  "singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
24
+ "clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
24
25
  # plug-and-play model merging methods
25
26
  "clip_concrete_task_arithmetic": ".concrete_subspace.clip_concrete_task_arithmetic.ConcreteTaskArithmeticAlgorithmForCLIP",
26
27
  "clip_concrete_task_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteTaskWiseAdaMergingForCLIP",
@@ -29,6 +30,7 @@ class AlgorithmFactory:
29
30
  "clip_weight_ensembling_moe": ".we_moe.clip_we_moe.CLIPWeightEnsemblingMoEAlgorithm",
30
31
  "sparse_clip_weight_ensembling_moe": "fusion_bench.method.SparseCLIPWeightEnsemblingMoEAlgorithm",
31
32
  "smile_mistral_upscaling": ".smile_upscaling.smile_mistral_upscaling.SmileMistralUpscalingAlgorithm",
33
+ "rankone_moe": ".rankone_moe.clip_rankone_moe.CLIPRankOneMoEAlgorithm",
32
34
  }
33
35
 
34
36
  @staticmethod
@@ -1,8 +1,11 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Optional
2
+ from typing import Optional, TYPE_CHECKING
3
3
 
4
4
  from omegaconf import DictConfig
5
5
 
6
+ if TYPE_CHECKING:
7
+ from fusion_bench.programs.base_program import BaseHydraProgram
8
+
6
9
  __all__ = ["ModelFusionAlgorithm"]
7
10
 
8
11
 
@@ -18,6 +21,9 @@ class ModelFusionAlgorithm(ABC):
18
21
  config (DictConfig): Configuration for the algorithm.
19
22
  """
20
23
 
24
+ _program: "BaseHydraProgram" = None
25
+ """A reference to the program that is running the algorithm."""
26
+
21
27
  def __init__(self, algorithm_config: Optional[DictConfig] = None):
22
28
  """
23
29
  Initialize the model fusion algorithm with the given configuration.
@@ -26,7 +32,6 @@ class ModelFusionAlgorithm(ABC):
26
32
  algorithm_config (Optional[DictConfig]): Configuration for the algorithm. Defaults to an empty configuration if not provided.
27
33
  Get access to the configuration using `self.config`.
28
34
  """
29
- super().__init__()
30
35
  if algorithm_config is None:
31
36
  algorithm_config = DictConfig({})
32
37
  self.config = algorithm_config
@@ -1,4 +1,6 @@
1
1
  # flake8: noqa F401
2
+ import warnings
3
+
2
4
  from omegaconf import DictConfig
3
5
 
4
6
  from fusion_bench.modelpool.huggingface_gpt2_classification import (
@@ -9,7 +11,6 @@ from fusion_bench.modelpool.PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPo
9
11
  from .AutoModelForSeq2SeqLM import AutoModelForSeq2SeqLMPool
10
12
  from .base_pool import DictModelPool, ListModelPool, ModelPool, to_modelpool
11
13
  from .huggingface_clip_vision import HuggingFaceClipVisionPool
12
- import warnings
13
14
 
14
15
 
15
16
  class ModelPoolFactory:
@@ -21,7 +22,7 @@ class ModelPoolFactory:
21
22
  """
22
23
 
23
24
  _modelpool = {
24
- "NYUv2ModelPool": ".nyuv2_modelpool.NYUv2ModelPool",
25
+ "NYUv2ModelPool": "fusion_bench.modelpool.nyuv2_modelpool.NYUv2ModelPool",
25
26
  "huggingface_clip_vision": HuggingFaceClipVisionPool,
26
27
  "HF_GPT2ForSequenceClassification": GPT2ForSequenceClassificationPool,
27
28
  "AutoModelPool": ".huggingface_automodel.AutoModelPool",
@@ -20,7 +20,7 @@ class TaskPoolFactory:
20
20
  "dummy": DummyTaskPool,
21
21
  "clip_vit_classification": ".clip_image_classification.CLIPImageClassificationTaskPool",
22
22
  "FlanT5GLUETextGenerationTaskPool": ".flan_t5_glue_text_generation.FlanT5GLUETextGenerationTaskPool",
23
- "NYUv2TaskPool": ".nyuv2_taskpool.NYUv2TaskPool",
23
+ "NYUv2TaskPool": "fusion_bench.taskpool.nyuv2_taskpool.NYUv2TaskPool",
24
24
  }
25
25
 
26
26
  @staticmethod
@@ -1 +1,6 @@
1
- from .arc_agi import load_tokenized_arc_agi_dataset
1
+ from .arc_agi import (
2
+ load_tokenized_arc_agi_dataset,
3
+ load_tokenized_arc_agi_dataset_for_ttt,
4
+ process_task,
5
+ process_task_for_ttt,
6
+ )
@@ -7,6 +7,7 @@ Task: a class to represent a task (task.test_example and task.train_examples are
7
7
  read_from_single_file: a function to read challenge problems and solutions from a single file
8
8
  make_submission: a function to create a submission file
9
9
  """
10
+
10
11
  import dataclasses
11
12
  import glob
12
13
  import json
@@ -15,7 +16,6 @@ from typing import List, Optional
15
16
 
16
17
  import numpy as np
17
18
 
18
-
19
19
  Grid = np.ndarray
20
20
 
21
21
 
@@ -66,7 +66,9 @@ class Example:
66
66
  def __eq__(self, other: object) -> bool:
67
67
  if not isinstance(other, Example):
68
68
  return NotImplemented
69
- return np.array_equal(self.input, other.input) and np.array_equal(self.output, other.output)
69
+ return np.array_equal(self.input, other.input) and np.array_equal(
70
+ self.output, other.output
71
+ )
70
72
 
71
73
  @classmethod
72
74
  def deserialize(cls, data: dict, test: bool = False) -> "Example":
@@ -150,7 +152,16 @@ class Task:
150
152
  tasks = []
151
153
  for test_data in data["test"]:
152
154
  task = cls.deserialize(
153
- {"train": data["train"], "test": [test_data], "name": data.get("name", "")},
155
+ {
156
+ "train": data["train"],
157
+ "test": [test_data],
158
+ "name": data.get("name", ""),
159
+ },
160
+ {
161
+ "train": data["train"],
162
+ "test": [test_data],
163
+ "name": data.get("name", ""),
164
+ },
154
165
  test=test,
155
166
  )
156
167
  tasks.append(task)
@@ -245,7 +256,9 @@ def make_submission(
245
256
  """
246
257
  Make a submission
247
258
  """
248
- assert len(tasks) == len(predictions), "Number of tasks and predictions should be the same"
259
+ assert len(tasks) == len(
260
+ predictions
261
+ ), "Number of tasks and predictions should be the same"
249
262
 
250
263
  # sort by task_name alphabetically to ensure order of subtasks
251
264
  indices = np.argsort([task.name for task in tasks])
@@ -259,8 +272,12 @@ def make_submission(
259
272
  if task_name not in submissions:
260
273
  submissions[task_name] = []
261
274
 
262
- assert len(prediction) == number_of_attempts, "Number of attempts should be the same"
263
- attempts = {f"attempt_{j+1}": to_list(pred) for j, pred in enumerate(prediction)}
275
+ assert (
276
+ len(prediction) == number_of_attempts
277
+ ), "Number of attempts should be the same"
278
+ attempts = {
279
+ f"attempt_{j+1}": to_list(pred) for j, pred in enumerate(prediction)
280
+ }
264
281
  while len(submissions[task_name]) <= task_no:
265
282
  submissions[task_name].append({"attempt_1": [[0]], "attempt_2": [[0]]})
266
283
 
@@ -277,7 +294,9 @@ if __name__ == "__main__":
277
294
  arc_path = "/kaggle/input/arc-prize-2024/"
278
295
  tasks = read_tasks_from_single_file(arc_path + "arc-agi_training_challenges.json")
279
296
  print(tasks[0])
280
- tasks = read_tasks_from_single_file(arc_path + "arc-agi_evaluation_challenges.json", test=True)
297
+ tasks = read_tasks_from_single_file(
298
+ arc_path + "arc-agi_evaluation_challenges.json", test=True
299
+ )
281
300
  print(tasks[0])
282
301
 
283
302
  tasks = read_tasks_from_single_file(
@@ -5,15 +5,16 @@ import sys
5
5
  from multiprocessing import Pool
6
6
  from typing import Any, Dict, List, Literal, Optional
7
7
 
8
- import fusion_bench
9
8
  import numpy as np
10
9
  from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
11
10
  from lightning.fabric.utilities import rank_zero_only
12
11
  from tqdm.auto import tqdm
13
12
  from typing_extensions import TYPE_CHECKING
14
13
 
14
+ import fusion_bench
15
+
15
16
  from .arc import Example, Task
16
- from .preprocess import get_augmenters, process_task
17
+ from .preprocess import get_augmenters, process_task, process_task_for_ttt
17
18
 
18
19
  if TYPE_CHECKING:
19
20
  from transformers import PreTrainedTokenizer
@@ -65,7 +66,7 @@ def _join_list(lists: List[List[Any]]) -> List[Any]:
65
66
  return ans
66
67
 
67
68
 
68
- def _to_task(
69
+ def _to_tasks(
69
70
  train_data: List[Dict[str, Any]],
70
71
  test_data: List[Dict[str, Any]],
71
72
  name: str,
@@ -87,7 +88,7 @@ def _to_task(
87
88
  return tasks
88
89
 
89
90
 
90
- def _tokenizer_tasks(
91
+ def tokenizer_tasks_for_ttt(
91
92
  tasks: List[Task],
92
93
  tokenizer: "PreTrainedTokenizer",
93
94
  use_data_augmentation: bool = True,
@@ -106,7 +107,7 @@ def _tokenizer_tasks(
106
107
 
107
108
  formatter = _get_formatter("new")
108
109
  processor = functools.partial(
109
- process_task,
110
+ process_task_for_ttt,
110
111
  augmenters=augmenters_to_apply,
111
112
  formatter=formatter,
112
113
  tokenizer=tokenizer,
@@ -133,7 +134,34 @@ def _tokenizer_tasks(
133
134
  return dataset
134
135
 
135
136
 
136
- def load_tokenized_arc_agi_dataset(
137
+ def tokenizer_tasks(
138
+ tasks: List[Task],
139
+ tokenizer: "PreTrainedTokenizer",
140
+ ):
141
+ formatter = _get_formatter("new")
142
+ processor = functools.partial(
143
+ process_task, formatter=formatter, tokenizer=tokenizer
144
+ )
145
+
146
+ # with Pool(multiprocessing.cpu_count()) as p:
147
+ # data = p.map(processor, tasks)
148
+ data = _join_list(
149
+ [
150
+ processor(task)
151
+ for task in tqdm(
152
+ tasks,
153
+ desc="Processing tasks",
154
+ dynamic_ncols=True,
155
+ leave=False,
156
+ disable=not rank_zero_only.rank == 0,
157
+ )
158
+ ]
159
+ )
160
+ dataset = Dataset.from_list(data)
161
+ return dataset
162
+
163
+
164
+ def load_tokenized_arc_agi_dataset_for_ttt(
137
165
  tokenizer: Optional["PreTrainedTokenizer"],
138
166
  path: str = "dataartist/arc-agi",
139
167
  split: Optional[str] = None,
@@ -144,47 +172,47 @@ def load_tokenized_arc_agi_dataset(
144
172
  max_num_tasks: Optional[int] = None,
145
173
  ):
146
174
  # regularize split
147
- if split.lower() == "train":
148
- split = "training"
149
- if split.lower() == "test":
150
- split = "evaluation"
175
+ split = split.lower() if split is not None else split
151
176
 
152
177
  # load cached dataset if available
153
178
  if cache_path is not None and fusion_bench.utils.path.path_is_dir_and_not_empty(
154
179
  cache_path
155
180
  ):
156
181
  datasets = load_from_disk(cache_path)
157
- if split is None:
158
- return datasets
159
- else:
182
+ if split is None and split in datasets.column_names:
160
183
  return datasets[split]
184
+ else:
185
+ return datasets
161
186
  else:
162
187
  assert (
163
188
  tokenizer is not None
164
189
  ), "Cached dataset not found. Need tokenizer to process the raw data."
165
190
 
166
191
  # load raw dataset
167
- datasets = load_dataset(path, split=split)
192
+ datasets = load_dataset(path)
193
+ datasets = DatasetDict(
194
+ {"train": datasets["training"], "test": datasets["evaluation"]}
195
+ )
168
196
  if split is None:
169
- converted_datasets = {
197
+ converted_datasets: Dict[str, List[Task]] = {
170
198
  "train": _join_list(
171
199
  [
172
- _to_task(
200
+ _to_tasks(
173
201
  task["train"],
174
202
  task["test"],
175
203
  task["id"],
176
204
  )
177
- for task in datasets["training"]
205
+ for task in datasets["train"]
178
206
  ]
179
207
  ),
180
208
  "test": _join_list(
181
209
  [
182
- _to_task(
210
+ _to_tasks(
183
211
  task["train"],
184
212
  task["test"],
185
213
  task["id"],
186
214
  )
187
- for task in datasets["evaluation"]
215
+ for task in datasets["test"]
188
216
  ]
189
217
  ),
190
218
  }
@@ -195,7 +223,7 @@ def load_tokenized_arc_agi_dataset(
195
223
  for split in converted_datasets
196
224
  }
197
225
  converted_datasets = {
198
- split: _tokenizer_tasks(
226
+ split: tokenizer_tasks_for_ttt(
199
227
  converted_datasets[split],
200
228
  tokenizer,
201
229
  use_data_augmentation,
@@ -210,25 +238,128 @@ def load_tokenized_arc_agi_dataset(
210
238
  )
211
239
  }
212
240
  converted_datasets = DatasetDict(converted_datasets)
213
- else:
241
+ else: # split is not None
214
242
  converted_datasets = _join_list(
215
243
  [
216
- _to_task(
244
+ _to_tasks(
217
245
  task["train"],
218
246
  task["test"],
219
247
  task["id"],
220
248
  )
221
- for task in datasets
249
+ for task in datasets[split]
222
250
  ]
223
251
  )
224
252
  if max_num_tasks is not None:
225
253
  # limit the number of tasks, useful for debugging
226
254
  converted_datasets = converted_datasets[:max_num_tasks]
227
- converted_datasets = _tokenizer_tasks(
255
+ converted_datasets = tokenizer_tasks_for_ttt(
228
256
  converted_datasets, tokenizer, use_data_augmentation, permute_n, seed
229
257
  )
230
258
 
231
- if cache_path is not None:
259
+ if cache_path is not None and rank_zero_only.rank == 0:
260
+ os.makedirs(cache_path, exist_ok=True)
261
+ converted_datasets.save_to_disk(cache_path)
262
+ return converted_datasets
263
+
264
+
265
+ def load_tokenized_arc_agi_dataset(
266
+ tokenizer: Optional["PreTrainedTokenizer"],
267
+ path: str = "dataartist/arc-agi",
268
+ split: Optional[str] = None,
269
+ cache_path: Optional[str] = None,
270
+ max_num_tasks: Optional[int] = None,
271
+ ):
272
+ """
273
+ Loads and tokenizes the ARC-AGI dataset.
274
+
275
+ Args:
276
+ tokenizer (Optional[PreTrainedTokenizer]): The tokenizer to use for tokenizing the dataset.
277
+ path (str, optional): The path to the dataset. Defaults to "dataartist/arc-agi".
278
+ split (Optional[str], optional): The dataset split to load (e.g., "train", "test"). Defaults to None.
279
+ cache_path (Optional[str], optional): The path to cache the processed dataset. Defaults to None.
280
+ max_num_tasks (Optional[int], optional): The maximum number of tasks to load. Useful for debugging. Defaults to None.
281
+
282
+ Returns:
283
+ DatasetDict or Dataset: The tokenized dataset, either as a DatasetDict if split is None, or as a Dataset if a specific split is specified.
284
+ """
285
+ # regularize split
286
+ split = split.lower() if split is not None else split
287
+
288
+ # load cached dataset if available
289
+ if cache_path is not None and fusion_bench.utils.path.path_is_dir_and_not_empty(
290
+ cache_path
291
+ ):
292
+ datasets = load_from_disk(cache_path)
293
+ if split is None and split in datasets.column_names:
294
+ return datasets[split]
295
+ else:
296
+ return datasets
297
+ else:
298
+ assert (
299
+ tokenizer is not None
300
+ ), "Cached dataset not found. Need tokenizer to process the raw data."
301
+
302
+ # load raw dataset
303
+ datasets = load_dataset(path)
304
+ datasets = DatasetDict(
305
+ {"train": datasets["training"], "test": datasets["evaluation"]}
306
+ )
307
+ if split is None:
308
+ converted_datasets: Dict[str, List[Task]] = {
309
+ "train": _join_list(
310
+ [
311
+ _to_tasks(
312
+ task["train"],
313
+ task["test"],
314
+ task["id"],
315
+ )
316
+ for task in datasets["train"]
317
+ ]
318
+ ),
319
+ "test": _join_list(
320
+ [
321
+ _to_tasks(
322
+ task["train"],
323
+ task["test"],
324
+ task["id"],
325
+ )
326
+ for task in datasets["test"]
327
+ ]
328
+ ),
329
+ }
330
+ if max_num_tasks is not None:
331
+ # limit the number of tasks, useful for debugging
332
+ converted_datasets = {
333
+ split: converted_datasets[split][:max_num_tasks]
334
+ for split in converted_datasets
335
+ }
336
+ converted_datasets = {
337
+ split: tokenizer_tasks(converted_datasets[split], tokenizer)
338
+ for split in tqdm(
339
+ converted_datasets,
340
+ desc="Processing splits",
341
+ dynamic_ncols=True,
342
+ disable=not rank_zero_only.rank == 0,
343
+ )
344
+ }
345
+ converted_datasets = DatasetDict(converted_datasets)
346
+ else: # split is not None
347
+ converted_datasets = _join_list(
348
+ [
349
+ _to_tasks(
350
+ task["train"],
351
+ task["test"],
352
+ task["id"],
353
+ )
354
+ for task in datasets[split]
355
+ ]
356
+ )
357
+ if max_num_tasks is not None:
358
+ # limit the number of tasks, useful for debugging
359
+ converted_datasets = converted_datasets[:max_num_tasks]
360
+ converted_datasets = tokenizer_tasks(converted_datasets, tokenizer)
361
+
362
+ if cache_path is not None and rank_zero_only.rank == 0:
232
363
  os.makedirs(cache_path, exist_ok=True)
233
364
  converted_datasets.save_to_disk(cache_path)
234
365
  return converted_datasets
@@ -6,7 +6,6 @@ from typing import Callable, Optional, TypeVar, cast
6
6
  import numpy as np
7
7
  from xxhash import xxh3_64_hexdigest
8
8
 
9
-
10
9
  __all__ = ["np_lru_cache"]
11
10
 
12
11
  TCallable = TypeVar("TCallable", bound=Callable)
@@ -112,19 +112,41 @@ def get_augmenters(
112
112
  def format_and_filter(
113
113
  formatter: MessageRepresenter,
114
114
  tokenizer: "PreTrainedTokenizer",
115
- task,
115
+ task: Task,
116
116
  ):
117
+ """
118
+ Formats and filters a task for model input.
119
+
120
+ Args:
121
+ formatter (MessageRepresenter): The formatter to encode the task.
122
+ tokenizer (PreTrainedTokenizer): The tokenizer to tokenize the conversation.
123
+ task: The task to be formatted and filtered.
124
+
125
+ Returns:
126
+ Dict[str, Any]: A dictionary containing the formatted data with keys:
127
+ - "input_ids": The tokenized input IDs.
128
+ - "attention_mask": The attention mask for the input IDs.
129
+ - "labels": The labels for the input IDs.
130
+ - "task_id": The task ID.
131
+ - "num_prompt_tokens": The number of prompt tokens.
132
+ - "num_output_tokens": The number of output tokens.
133
+ """
134
+ task_id = task.name
117
135
  task = formatter.encode(task)
118
136
  conversation = task[0] + [task[1]]
119
137
  assert conversation[-1]["role"] == "assistant", "Last message should be assistant"
120
138
  prompt_tokens = tokenizer.apply_chat_template(
121
139
  conversation[:-1], tokenize=True, add_generation_prompt=True
122
140
  )
123
- output_tokens = tokenizer.encode(conversation[-1]["content"] + tokenizer.eos_token)
141
+ generation_tokens = tokenizer.apply_chat_template(conversation, tokenize=True)
142
+ output_tokens = generation_tokens[len(prompt_tokens) :]
124
143
  data = {
125
144
  "input_ids": prompt_tokens + output_tokens,
126
145
  "attention_mask": [1] * len(prompt_tokens) + [1] * len(output_tokens),
127
- "labels": [-100] * len(prompt_tokens) + output_tokens,
146
+ "labels": prompt_tokens + output_tokens,
147
+ "task_id": task_id,
148
+ "num_prompt_tokens": len(prompt_tokens),
149
+ "num_output_tokens": len(output_tokens),
128
150
  }
129
151
  return data
130
152
 
@@ -136,6 +158,19 @@ def get_test_time_train_data(
136
158
  permute_n: int = 1,
137
159
  seed: int = 0,
138
160
  ) -> List[Task]:
161
+ """
162
+ Generates augmented training data for test-time training.
163
+
164
+ Args:
165
+ original_task (Task): The original task containing training examples.
166
+ augmenters (List[Augmenter]): A list of augmenters to apply to the tasks.
167
+ n (int, optional): The number of examples to leave out for testing. Defaults to 1.
168
+ permute_n (int, optional): The number of times to permute the augmented tasks. Defaults to 1.
169
+ seed (int, optional): The random seed for reproducibility. Defaults to 0.
170
+
171
+ Returns:
172
+ List[Task]: A list of augmented tasks.
173
+ """
139
174
  rng = np.random.RandomState(seed)
140
175
  train_examples = original_task.train_examples.copy()
141
176
  initial_tasks = []
@@ -150,7 +185,7 @@ def get_test_time_train_data(
150
185
  for comb in combs:
151
186
  initial_tasks.append(
152
187
  Task(
153
- name="",
188
+ name=original_task.name,
154
189
  train_examples=[examples[j] for j in comb],
155
190
  test_example=examples[i],
156
191
  )
@@ -183,7 +218,6 @@ def get_test_time_train_data(
183
218
  color_and_permute_augmented_tasks.append(new_task)
184
219
 
185
220
  augmented_tasks = color_and_permute_augmented_tasks + augmented_tasks
186
-
187
221
  augmented_tasks = list(set(augmented_tasks))
188
222
 
189
223
  return augmented_tasks
@@ -193,13 +227,12 @@ def get_formatted_data(
193
227
  task: Task,
194
228
  augmenters: List[Augmenter],
195
229
  formatter: MessageRepresenter,
196
- tokenizer,
230
+ tokenizer: "PreTrainedTokenizer",
197
231
  leave_n: int = 1,
198
232
  permute_n: int = 1,
199
233
  seed: int = 0,
200
234
  max_tokens: int = 8192,
201
235
  ):
202
-
203
236
  train_data = get_test_time_train_data(
204
237
  task, augmenters, n=leave_n, permute_n=permute_n, seed=seed
205
238
  )
@@ -213,11 +246,11 @@ def get_formatted_data(
213
246
  return formatted_data
214
247
 
215
248
 
216
- def process_task(
249
+ def process_task_for_ttt(
217
250
  task: Task,
218
251
  augmenters: List[Augmenter],
219
252
  formatter: MessageRepresenter,
220
- tokenizer,
253
+ tokenizer: "PreTrainedTokenizer",
221
254
  permute_n: int = 1,
222
255
  Nmax: int = 250,
223
256
  seed: int = 0,
@@ -254,3 +287,12 @@ def process_task(
254
287
  train = train[:Nmax]
255
288
 
256
289
  return train
290
+
291
+
292
+ def process_task(
293
+ task: Task,
294
+ formatter: MessageRepresenter,
295
+ tokenizer: "PreTrainedTokenizer",
296
+ ):
297
+ formatted = format_and_filter(formatter, tokenizer, task)
298
+ return [formatted]
@@ -0,0 +1 @@
1
+ from . import collate