fusion_bench-0.2.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (727)
  1. fusion_bench/__init__.py +20 -0
  2. fusion_bench/__main__.py +4 -0
  3. fusion_bench/compat/__init__.py +0 -0
  4. fusion_bench/compat/method/__init__.py +109 -0
  5. fusion_bench/compat/method/base_algorithm.py +58 -0
  6. fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
  7. fusion_bench/compat/modelpool/__init__.py +116 -0
  8. fusion_bench/compat/modelpool/base_pool.py +328 -0
  9. fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
  10. fusion_bench/compat/taskpool/__init__.py +95 -0
  11. fusion_bench/compat/taskpool/base_pool.py +111 -0
  12. fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
  13. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
  14. fusion_bench/constants/__init__.py +2 -0
  15. fusion_bench/constants/paths.py +18 -0
  16. fusion_bench/dataset/__init__.py +29 -0
  17. fusion_bench/dataset/arc_agi/__init__.py +6 -0
  18. fusion_bench/dataset/arc_agi/arc.py +308 -0
  19. fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
  20. fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
  21. fusion_bench/dataset/arc_agi/messagers.py +1355 -0
  22. fusion_bench/dataset/arc_agi/np_cache.py +168 -0
  23. fusion_bench/dataset/arc_agi/preprocess.py +298 -0
  24. fusion_bench/dataset/arc_agi/representers.py +1019 -0
  25. fusion_bench/dataset/clip_dataset.py +71 -0
  26. fusion_bench/dataset/fer2013.py +12 -0
  27. fusion_bench/dataset/gpt2_glue.py +300 -0
  28. fusion_bench/dataset/gsm8k.py +60 -0
  29. fusion_bench/dataset/image_dataset.py +55 -0
  30. fusion_bench/dataset/imdb.py +11 -0
  31. fusion_bench/dataset/llama/__init__.py +1 -0
  32. fusion_bench/dataset/llama/alpaca.py +232 -0
  33. fusion_bench/dataset/llama/collate.py +120 -0
  34. fusion_bench/dataset/llama/metamathqa.py +50 -0
  35. fusion_bench/dataset/llama/openai.py +160 -0
  36. fusion_bench/dataset/llama/preference_700k.py +70 -0
  37. fusion_bench/dataset/llama/sharegpt.py +141 -0
  38. fusion_bench/dataset/llama/squad.py +125 -0
  39. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  40. fusion_bench/dataset/llama/ultrachat.py +58 -0
  41. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  42. fusion_bench/dataset/llama/wikitext.py +89 -0
  43. fusion_bench/dataset/nyuv2.py +119 -0
  44. fusion_bench/method/__init__.py +177 -0
  45. fusion_bench/method/ada_svd/__init__.py +2 -0
  46. fusion_bench/method/ada_svd/clip_vision.py +319 -0
  47. fusion_bench/method/adamerging/__init__.py +6 -0
  48. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
  49. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
  50. fusion_bench/method/adamerging/entropy_loss.py +25 -0
  51. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
  52. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
  53. fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
  54. fusion_bench/method/adamerging/llama_adamerging.py +335 -0
  55. fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
  56. fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
  57. fusion_bench/method/adamerging/utils.py +15 -0
  58. fusion_bench/method/analysis/__init__.py +2 -0
  59. fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
  60. fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
  61. fusion_bench/method/base_algorithm.py +44 -0
  62. fusion_bench/method/classification/__init__.py +3 -0
  63. fusion_bench/method/classification/clip_finetune.py +444 -0
  64. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  65. fusion_bench/method/concrete_subspace/__init__.py +6 -0
  66. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
  67. fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
  68. fusion_bench/method/dare/__init__.py +4 -0
  69. fusion_bench/method/dare/simple_average.py +31 -0
  70. fusion_bench/method/dare/task_arithmetic.py +82 -0
  71. fusion_bench/method/dare/ties_merging.py +100 -0
  72. fusion_bench/method/dare/utils.py +87 -0
  73. fusion_bench/method/dawe/__init__.py +2 -0
  74. fusion_bench/method/dawe/dawe_for_clip.py +274 -0
  75. fusion_bench/method/dawe/warppers/__init__.py +13 -0
  76. fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
  77. fusion_bench/method/depth_upscaling/__init__.py +3 -0
  78. fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
  79. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
  80. fusion_bench/method/dummy.py +35 -0
  81. fusion_bench/method/ensemble.py +98 -0
  82. fusion_bench/method/fisher_merging/__init__.py +4 -0
  83. fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
  84. fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
  85. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
  86. fusion_bench/method/linear/__init__.py +6 -0
  87. fusion_bench/method/linear/expo.py +118 -0
  88. fusion_bench/method/linear/linear_interpolation.py +60 -0
  89. fusion_bench/method/linear/llama_expo.py +229 -0
  90. fusion_bench/method/linear/simple_average_for_llama.py +54 -0
  91. fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
  92. fusion_bench/method/lm_finetune/__init__.py +3 -0
  93. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  94. fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
  95. fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
  96. fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
  97. fusion_bench/method/mixture_of_experts/__init__.py +7 -0
  98. fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
  99. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
  100. fusion_bench/method/model_recombination.py +121 -0
  101. fusion_bench/method/opcm/__init__.py +4 -0
  102. fusion_bench/method/opcm/opcm.py +277 -0
  103. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  104. fusion_bench/method/opcm/ties_merging.py +156 -0
  105. fusion_bench/method/opcm/utils.py +73 -0
  106. fusion_bench/method/opcm/weight_average.py +120 -0
  107. fusion_bench/method/pruning/__init__.py +5 -0
  108. fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
  109. fusion_bench/method/pruning/llama_random_prune.py +143 -0
  110. fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
  111. fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
  112. fusion_bench/method/pruning/prune_utils.py +165 -0
  113. fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
  114. fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
  115. fusion_bench/method/pruning/wanda_utils/data.py +135 -0
  116. fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
  117. fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
  118. fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
  119. fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
  120. fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
  121. fusion_bench/method/pwe_moe/__init__.py +5 -0
  122. fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
  123. fusion_bench/method/pwe_moe/module.py +316 -0
  124. fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
  125. fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
  126. fusion_bench/method/pwe_moe/utils.py +43 -0
  127. fusion_bench/method/rankone_moe/__init__.py +3 -0
  128. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  129. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  130. fusion_bench/method/regmean/__init__.py +4 -0
  131. fusion_bench/method/regmean/clip_regmean.py +131 -0
  132. fusion_bench/method/regmean/gpt2_regmean.py +147 -0
  133. fusion_bench/method/regmean/regmean.py +375 -0
  134. fusion_bench/method/simple_average.py +112 -0
  135. fusion_bench/method/slerp/__init__.py +2 -0
  136. fusion_bench/method/slerp/slerp.py +101 -0
  137. fusion_bench/method/slerp/slerp_utils.py +107 -0
  138. fusion_bench/method/smile_upscaling/__init__.py +3 -0
  139. fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
  140. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
  141. fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
  142. fusion_bench/method/sparse_we_moe/__init__.py +2 -0
  143. fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
  144. fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
  145. fusion_bench/method/sparselo/__init__.py +2 -0
  146. fusion_bench/method/sparselo/sparselo.py +955 -0
  147. fusion_bench/method/surgery/__init__.py +1 -0
  148. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  149. fusion_bench/method/tall_mask/__init__.py +0 -0
  150. fusion_bench/method/tall_mask/utils.py +234 -0
  151. fusion_bench/method/task_arithmetic/__init__.py +2 -0
  152. fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
  153. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  154. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  155. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  156. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  157. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
  158. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  159. fusion_bench/method/ties_merging/__init__.py +2 -0
  160. fusion_bench/method/ties_merging/ties_merging.py +117 -0
  161. fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
  162. fusion_bench/method/trust_region/__init__.py +2 -0
  163. fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
  164. fusion_bench/method/trust_region/utils.py +58 -0
  165. fusion_bench/method/we_moe/__init__.py +2 -0
  166. fusion_bench/method/we_moe/clip_we_moe.py +161 -0
  167. fusion_bench/method/we_moe/we_moe.py +247 -0
  168. fusion_bench/method/weighted_average/__init__.py +3 -0
  169. fusion_bench/method/weighted_average/llama.py +113 -0
  170. fusion_bench/method/weighted_average/weighted_average.py +102 -0
  171. fusion_bench/metrics/__init__.py +0 -0
  172. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  173. fusion_bench/metrics/nyuv2/__init__.py +11 -0
  174. fusion_bench/metrics/nyuv2/depth.py +45 -0
  175. fusion_bench/metrics/nyuv2/loss.py +31 -0
  176. fusion_bench/metrics/nyuv2/noise.py +16 -0
  177. fusion_bench/metrics/nyuv2/normal.py +48 -0
  178. fusion_bench/metrics/nyuv2/segmentation.py +43 -0
  179. fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
  180. fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
  181. fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
  182. fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
  183. fusion_bench/mixins/__init__.py +28 -0
  184. fusion_bench/mixins/clip_classification.py +252 -0
  185. fusion_bench/mixins/fabric_training.py +320 -0
  186. fusion_bench/mixins/lightning_fabric.py +174 -0
  187. fusion_bench/mixins/optim/__init__.py +0 -0
  188. fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
  189. fusion_bench/mixins/rich_live.py +21 -0
  190. fusion_bench/mixins/serialization.py +132 -0
  191. fusion_bench/mixins/simple_profiler.py +79 -0
  192. fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
  193. fusion_bench/modelpool/__init__.py +42 -0
  194. fusion_bench/modelpool/base_pool.py +268 -0
  195. fusion_bench/modelpool/causal_lm/__init__.py +2 -0
  196. fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
  197. fusion_bench/modelpool/clip_vision/__init__.py +1 -0
  198. fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
  199. fusion_bench/modelpool/huggingface_automodel.py +20 -0
  200. fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
  201. fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
  202. fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
  203. fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
  204. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  205. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  206. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  207. fusion_bench/models/__init__.py +3 -0
  208. fusion_bench/models/chat_templates/__init__.py +1 -0
  209. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  210. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  211. fusion_bench/models/hf_clip.py +199 -0
  212. fusion_bench/models/linearized/__init__.py +0 -0
  213. fusion_bench/models/linearized/linearized_model_utils.py +91 -0
  214. fusion_bench/models/linearized/vision_model.py +122 -0
  215. fusion_bench/models/llama/__init__.py +16 -0
  216. fusion_bench/models/llama/model_utils/__init__.py +0 -0
  217. fusion_bench/models/llama/model_utils/embedding.py +87 -0
  218. fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
  219. fusion_bench/models/llama/model_utils/misc.py +112 -0
  220. fusion_bench/models/llama/model_utils/mod.py +52 -0
  221. fusion_bench/models/llama/model_utils/visual.py +241 -0
  222. fusion_bench/models/llama/patcher.py +78 -0
  223. fusion_bench/models/llama/tokenizer_loader.py +153 -0
  224. fusion_bench/models/masks/__init__.py +2 -0
  225. fusion_bench/models/masks/mask_model.py +160 -0
  226. fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
  227. fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
  228. fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
  229. fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
  230. fusion_bench/models/modeling_losparse_llama/register.py +8 -0
  231. fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
  232. fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
  233. fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
  234. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
  235. fusion_bench/models/modeling_smile_mistral/register.py +8 -0
  236. fusion_bench/models/nyuv2/__init__.py +0 -0
  237. fusion_bench/models/nyuv2/aspp.py +82 -0
  238. fusion_bench/models/nyuv2/lightning_module.py +176 -0
  239. fusion_bench/models/nyuv2/resnet.py +405 -0
  240. fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
  241. fusion_bench/models/parameter_dict.py +75 -0
  242. fusion_bench/models/rankone_moe.py +410 -0
  243. fusion_bench/models/separate_io.py +105 -0
  244. fusion_bench/models/smile_moe/__init__.py +0 -0
  245. fusion_bench/models/smile_moe/linear.py +256 -0
  246. fusion_bench/models/sparse_we_moe.py +459 -0
  247. fusion_bench/models/surgery/__init__.py +1 -0
  248. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  249. fusion_bench/models/utils.py +80 -0
  250. fusion_bench/models/we_moe.py +247 -0
  251. fusion_bench/models/wrappers/__init__.py +0 -0
  252. fusion_bench/models/wrappers/ensemble.py +183 -0
  253. fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
  254. fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
  255. fusion_bench/optim/__init__.py +2 -0
  256. fusion_bench/optim/exception.py +47 -0
  257. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  258. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  259. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  260. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  261. fusion_bench/optim/mezo.py +118 -0
  262. fusion_bench/programs/__init__.py +20 -0
  263. fusion_bench/programs/base_program.py +9 -0
  264. fusion_bench/programs/fabric_fusion_program.py +299 -0
  265. fusion_bench/scripts/__init__.py +0 -0
  266. fusion_bench/scripts/cli.py +43 -0
  267. fusion_bench/scripts/clip/__init__.py +0 -0
  268. fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
  269. fusion_bench/scripts/imgui.py +218 -0
  270. fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
  271. fusion_bench/scripts/webui.py +405 -0
  272. fusion_bench/taskpool/__init__.py +39 -0
  273. fusion_bench/taskpool/base_pool.py +35 -0
  274. fusion_bench/taskpool/clip_vision/__init__.py +4 -0
  275. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  276. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
  277. fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
  278. fusion_bench/taskpool/dummy.py +58 -0
  279. fusion_bench/taskpool/gpt2_text_classification.py +149 -0
  280. fusion_bench/taskpool/llama/__init__.py +1 -0
  281. fusion_bench/taskpool/llama/reward_model.py +157 -0
  282. fusion_bench/taskpool/llama/test_generation.py +185 -0
  283. fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
  284. fusion_bench/tasks/__init__.py +2 -0
  285. fusion_bench/tasks/base_task.py +18 -0
  286. fusion_bench/tasks/classification.py +75 -0
  287. fusion_bench/tasks/clip_classification/__init__.py +183 -0
  288. fusion_bench/tasks/clip_classification/cifar10.py +33 -0
  289. fusion_bench/tasks/clip_classification/cifar100.py +146 -0
  290. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
  291. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  292. fusion_bench/tasks/clip_classification/dtd.py +60 -0
  293. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  294. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  295. fusion_bench/tasks/clip_classification/eurosat.py +18 -0
  296. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  297. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  298. fusion_bench/tasks/clip_classification/flower102.py +106 -0
  299. fusion_bench/tasks/clip_classification/food101.py +105 -0
  300. fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
  301. fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
  302. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  303. fusion_bench/tasks/clip_classification/mnist.py +5 -0
  304. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  305. fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
  306. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  307. fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
  308. fusion_bench/tasks/clip_classification/resisc45.py +68 -0
  309. fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
  310. fusion_bench/tasks/clip_classification/stl10.py +17 -0
  311. fusion_bench/tasks/clip_classification/sun397.py +404 -0
  312. fusion_bench/tasks/clip_classification/svhn.py +5 -0
  313. fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
  314. fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
  315. fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
  316. fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
  317. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
  318. fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
  319. fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
  320. fusion_bench/utils/__init__.py +14 -0
  321. fusion_bench/utils/auto.py +31 -0
  322. fusion_bench/utils/cache_utils.py +58 -0
  323. fusion_bench/utils/data.py +165 -0
  324. fusion_bench/utils/devices.py +231 -0
  325. fusion_bench/utils/dict.py +43 -0
  326. fusion_bench/utils/dtype.py +146 -0
  327. fusion_bench/utils/expr.py +90 -0
  328. fusion_bench/utils/fabric.py +17 -0
  329. fusion_bench/utils/functools.py +37 -0
  330. fusion_bench/utils/hydra_utils.py +28 -0
  331. fusion_bench/utils/instantiate.py +450 -0
  332. fusion_bench/utils/json.py +93 -0
  333. fusion_bench/utils/lazy_imports.py +74 -0
  334. fusion_bench/utils/misc.py +18 -0
  335. fusion_bench/utils/packages.py +84 -0
  336. fusion_bench/utils/parameters.py +323 -0
  337. fusion_bench/utils/path.py +22 -0
  338. fusion_bench/utils/plot/__init__.py +0 -0
  339. fusion_bench/utils/plot/color_data.py +1726 -0
  340. fusion_bench/utils/plot/token.py +52 -0
  341. fusion_bench/utils/plot/token_notebook.py +127 -0
  342. fusion_bench/utils/pylogger.py +55 -0
  343. fusion_bench/utils/rich_utils.py +201 -0
  344. fusion_bench/utils/set.py +8 -0
  345. fusion_bench/utils/state_dict_arithmetic.py +297 -0
  346. fusion_bench/utils/strenum/__init__.py +326 -0
  347. fusion_bench/utils/strenum/_name_mangler.py +127 -0
  348. fusion_bench/utils/strenum/_version.py +556 -0
  349. fusion_bench/utils/tensorboard.py +51 -0
  350. fusion_bench/utils/timer.py +49 -0
  351. fusion_bench/utils/type.py +34 -0
  352. fusion_bench-0.2.9.dist-info/LICENSE +21 -0
  353. fusion_bench-0.2.9.dist-info/METADATA +258 -0
  354. fusion_bench-0.2.9.dist-info/RECORD +727 -0
  355. fusion_bench-0.2.9.dist-info/WHEEL +5 -0
  356. fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
  357. fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
  358. fusion_bench_config/README.md +12 -0
  359. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
  360. fusion_bench_config/dataset/image_classification/README.md +6 -0
  361. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  362. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  363. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
  364. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
  365. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  366. fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
  367. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  368. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  369. fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
  370. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  371. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  372. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  373. fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
  374. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  375. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  376. fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
  377. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  378. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  379. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  380. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  381. fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
  382. fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
  383. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  384. fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
  385. fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
  386. fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
  387. fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
  388. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  389. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  390. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
  391. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
  392. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  393. fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
  394. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  395. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  396. fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
  397. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  398. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  399. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  400. fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
  401. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  402. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  403. fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
  404. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  405. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  406. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  407. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  408. fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
  409. fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
  410. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  411. fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
  412. fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
  413. fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
  414. fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
  415. fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
  416. fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
  417. fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
  418. fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
  419. fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
  420. fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
  421. fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
  422. fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
  423. fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
  424. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  425. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  426. fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
  427. fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
  428. fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
  429. fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
  430. fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
  431. fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
  432. fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
  433. fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
  434. fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
  435. fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
  436. fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
  437. fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
  438. fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
  439. fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
  440. fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
  441. fusion_bench_config/fabric/auto.yaml +16 -0
  442. fusion_bench_config/fabric/llama_ddp.yaml +18 -0
  443. fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
  444. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  445. fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
  446. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
  447. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  448. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  449. fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
  450. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  451. fusion_bench_config/fabric_model_fusion.yaml +20 -0
  452. fusion_bench_config/hydra/default.yaml +8 -0
  453. fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
  454. fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
  455. fusion_bench_config/llama_full_finetune.yaml +19 -0
  456. fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
  457. fusion_bench_config/llama_model_fusion.yaml +17 -0
  458. fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
  459. fusion_bench_config/method/adamerging/clip.yaml +23 -0
  460. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
  461. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
  462. fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
  463. fusion_bench_config/method/adamerging.yaml +23 -0
  464. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
  465. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
  466. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  467. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  468. fusion_bench_config/method/clip_finetune.yaml +26 -0
  469. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
  470. fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
  471. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
  472. fusion_bench_config/method/dare/simple_average.yaml +5 -0
  473. fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
  474. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  475. fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
  476. fusion_bench_config/method/depth_upscaling.yaml +5 -0
  477. fusion_bench_config/method/dummy.yaml +1 -0
  478. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
  479. fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
  480. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
  481. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
  482. fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
  483. fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
  484. fusion_bench_config/method/linear/expo.yaml +8 -0
  485. fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
  486. fusion_bench_config/method/linear/llama_expo.yaml +19 -0
  487. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
  488. fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
  489. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
  490. fusion_bench_config/method/linear/weighted_average.yaml +6 -0
  491. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
  492. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  493. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
  494. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
  495. fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
  496. fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
  497. fusion_bench_config/method/model_recombination.yaml +4 -0
  498. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  499. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  500. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  501. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  502. fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
  503. fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
  504. fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
  505. fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
  506. fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
  507. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  508. fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
  509. fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
  510. fusion_bench_config/method/regmean/regmean.yaml +4 -0
  511. fusion_bench_config/method/simple_average.yaml +1 -0
  512. fusion_bench_config/method/slerp/slerp.yaml +6 -0
  513. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
  514. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
  515. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
  516. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
  517. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
  518. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
  519. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  520. fusion_bench_config/method/task_arithmetic.yaml +2 -0
  521. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  522. fusion_bench_config/method/ties_merging.yaml +8 -0
  523. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
  524. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
  525. fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
  526. fusion_bench_config/model/clip-vit/README.md +38 -0
  527. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
  528. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  529. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  530. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  531. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  532. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
  533. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
  534. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  535. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
  536. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  537. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  538. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  539. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
  540. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  541. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
  542. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  543. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  544. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  545. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  546. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
  547. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
  548. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  549. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
  550. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
  551. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
  552. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  553. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  554. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  555. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  556. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
  557. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
  558. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  559. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
  560. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  561. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  562. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  563. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
  564. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  565. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
  566. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  567. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  568. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  569. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  570. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
  571. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
  572. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  573. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
  574. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
  575. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
  576. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  577. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  578. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  579. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  580. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
  581. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
  582. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  583. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
  584. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  585. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  586. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  587. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
  588. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  589. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
  590. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  591. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  592. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  593. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  594. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
  595. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
  596. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  597. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
  598. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
  599. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  600. fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
  601. fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
  602. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
  603. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
  604. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
  605. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
  606. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
  607. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
  608. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
  609. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
  610. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
  611. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
  612. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
  613. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
  614. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
  615. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
  616. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
  617. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
  618. fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
  619. fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
  620. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
  621. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
  622. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
  623. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
  624. fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
  625. fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
  626. fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
  627. fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
  628. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
  629. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
  630. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
  631. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  632. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  633. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  634. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  635. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  636. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
  637. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
  638. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
  639. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
  640. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
  641. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  642. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  643. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  644. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  645. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
  646. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
  647. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
  648. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
  649. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
  650. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
  651. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
  652. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  653. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
  654. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  655. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
  656. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
  657. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  658. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  659. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  660. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  661. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
  662. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  663. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  664. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
  665. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  666. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  667. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
  668. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
  669. fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
  670. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
  671. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
  672. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
  673. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
  674. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
  675. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  676. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  677. fusion_bench_config/modelpool/automodelpool.yaml +12 -0
  678. fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
  679. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
  680. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
  681. fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
  682. fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
  683. fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
  684. fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
  685. fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
  686. fusion_bench_config/nyuv2_config.yaml +17 -0
  687. fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
  688. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
  689. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  690. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
  691. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
  692. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
  693. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
  694. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
  695. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  696. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  697. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  698. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  699. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  700. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  701. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  702. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  703. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  704. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  705. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  706. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  707. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  708. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  709. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  710. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  711. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  712. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  713. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  714. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  715. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  716. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  717. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  718. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  719. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
  720. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
  721. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  722. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
  723. fusion_bench_config/taskpool/dummy.yaml +2 -0
  724. fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
  725. fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
  726. fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
  727. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
--- /dev/null
+++ b/fusion_bench/dataset/arc_agi/messagers.py
@@ -0,0 +1,1355 @@
1
+ """
2
+ This module contains classes for representing tasks and examples as messages for chat-based interfaces.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from html import escape
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+
11
+ from .arc import Example, Task
12
+ from .representers import (
13
+ CompositeRepresenter,
14
+ ConnectedComponentRepresenter,
15
+ DelimitedGridRepresenter,
16
+ DiffExampleRepresenter,
17
+ GridRepresenter,
18
+ ImageTaskRepresenter,
19
+ PythonListGridRepresenter,
20
+ TaskRepresenter,
21
+ TextExampleRepresenter,
22
+ TextTaskRepresenter,
23
+ WordGridRepresenter,
24
+ )
25
+
26
+ MESSAGE = Dict[str, Union[str, Dict]]
27
+ MESSAGES = List[MESSAGE]
28
+
29
+
30
+ def display_messages(messages: MESSAGES):
31
+ html_output = """<!DOCTYPE html>
32
+ <html>
33
+ <head>
34
+ <meta charset="UTF-8">
35
+ <title>Chat View</title>
36
+ <style>
37
+ /* CSS styling for chat interface */
38
+ body {
39
+ font-family: Arial, sans-serif;
40
+ background-color: #f5f5f5;
41
+ }
42
+ .chat-container {
43
+ width: 80%;
44
+ max-width: 800px;
45
+ margin: 0 auto;
46
+ margin-top: 50px;
47
+ }
48
+ .message {
49
+ display: block;
50
+ clear: both;
51
+ margin-bottom: 15px;
52
+ }
53
+ .message.user {
54
+ text-align: right;
55
+ }
56
+ .message.assistant {
57
+ text-align: left;
58
+ }
59
+ .message.system {
60
+ text-align: left;
61
+ }
62
+ .message .bubble {
63
+ display: inline-block;
64
+ padding: 10px 15px;
65
+ border-radius: 15px;
66
+ max-width: 70%;
67
+ position: relative;
68
+ }
69
+ .message.user .bubble {
70
+ background-color: #0084ff;
71
+ color: white;
72
+ }
73
+ .message.assistant .bubble {
74
+ background-color: #e5e5ea;
75
+ color: black;
76
+ }
77
+ .message.system .bubble {
78
+ background-color: #e5e5ea;
79
+ color: black;
80
+ }
81
+ .message .bubble img {
82
+ max-width: 100%;
83
+ border-radius: 10px;
84
+ }
85
+ .message .role {
86
+ font-size: 0.8em;
87
+ color: black;
88
+ margin-bottom: 5px;
89
+ }
90
+ </style>
91
+ </head>
92
+ <body>
93
+ <div class="chat-container">
94
+ """
95
+
96
+ # Loop through messages
97
+ for message in messages:
98
+ role = message.get("role", "user")
99
+ content_list = message.get("content", [])
100
+ if not content_list:
101
+ continue # Skip if no content
102
+ if isinstance(content_list, str):
103
+ content_list = [{"type": "text", "text": content_list}]
104
+
105
+ # Start message div
106
+ html_output += f'<div class="message {role}">\n'
107
+ # Start bubble div
108
+ html_output += '<div class="bubble">\n'
109
+ # Add role label inside the bubble
110
+ html_output += f'<div class="role">{role.capitalize()}</div>\n'
111
+
112
+ # Process content items
113
+ for content in content_list:
114
+ content_type = content.get("type")
115
+ if content_type == "text":
116
+ text = content.get("text", "")
117
+ # Escape HTML entities in text
118
+ safe_text = escape(text)
119
+ # Replace newlines with <br>
120
+ safe_text = safe_text.replace("\n", "<br>")
121
+ html_output += f"<p>{safe_text}</p>\n"
122
+ elif content_type == "image_url":
123
+ image_url = content["image_url"].get("url", {})
124
+ if image_url:
125
+ html_output += f'<img src="{image_url}" alt="Image">\n'
126
+ else:
127
+ # Handle other content types if necessary
128
+ pass
129
+
130
+ # Close bubble and message divs
131
+ html_output += "</div>\n</div>\n"
132
+
133
+ # Close chat-container and body tags
134
+ html_output += """
135
+ </div>
136
+ </body>
137
+ </html>"""
138
+
139
+ return html_output
140
+
141
+
142
+ class MessageRepresenter(ABC):
143
+ task_representer: TaskRepresenter
144
+
145
+ @abstractmethod
146
+ def encode(self, task: Task, **kwargs) -> Tuple[MESSAGES, MESSAGE]:
147
+ pass
148
+
149
+ def display(self, messages: MESSAGES):
150
+ return display_messages(messages)
151
+
152
+
153
+ # =============== MESSAGE REPRESENTATION ===============
154
+
155
+
156
+ class GPTTextMessagerepresenter(MessageRepresenter):
+     def __init__(
+         self,
+         prompt: Optional[
+             str
+         ] = "Figure out the pattern in the following examples and apply it to the test case. {description}Your answer must follow the format of the examples. \n",
+         task_representer: TaskRepresenter = TextTaskRepresenter(),
+     ):
+         self.prompt = prompt
+         self.task_representer = task_representer
+
+     def encode(self, task: Task, **kwargs) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         if hasattr(task, "description"):
+             description = "Here is a description of the task: \n\n{description}\n".format(
+                 description=task.description
+             )
+             prompt = self.prompt.format(description=description)
+         else:
+             prompt = self.prompt.format(description="")
+
+         input_data.append({"role": "system", "content": prompt})
+
+         for example in task.train_examples:
+             query, output = self.task_representer.example_representer.encode(
+                 example, **kwargs
+             )
+             input_data.append({"role": "system", "content": query + output})
+
+         query, output = self.task_representer.example_representer.encode(
+             task.test_example, **kwargs
+         )
+
+         input_data.append({"role": "user", "content": query})
+
+         output_data = {"role": "assistant", "content": output}
+
+         return input_data, output_data
+
+     def decode(self, input_data: MESSAGES, output_data: MESSAGE, **kwargs) -> Task:
+         raise NotImplementedError(
+             "Decoding for GPTTextMessagerepresenter is not implemented."
+         )
+
+
+ class GPTTextMessageRepresenterV2(MessageRepresenter):
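+     """Text-only messager that packs the instruction and all training examples
+     into a single system message, prefixed with a format note matching the
+     configured grid representer (diff, connected-component, python-list, or
+     composite)."""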
+     def __init__(
+         self,
+         prompt: Optional[
+             str
+         ] = "Figure out the underlying transformation in the following examples and apply it to the test case. {description}Here are some examples from this transformation, your answer must follow the format.\n",
+         task_representer: TaskRepresenter = TextTaskRepresenter(),
+     ):
+         self.prompt = prompt
+         self.task_representer = task_representer
+         # if example_representer is not None:
+         #     self.task_representer.example_representer = example_representer(
+         #         io_sep=" -> ",
+         #         input_header="",
+         #         output_header="",
+         #         grid_representer=PythonListGridRepresenter
+         #     )
+
+     def encode(self, task: Task, **kwargs) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         if hasattr(task, "description"):
+             description = task.description
+             description = f"\n\n A possible description of the transformation: \n\n{description}\n"
+             prompt = self.prompt.format(description=description)
+         else:
+             prompt = self.prompt.format(description="")
+
+         if isinstance(
+             self.task_representer.example_representer, DiffExampleRepresenter
+         ):
+             if self.task_representer.example_representer.use_output:
+                 prompt += "The input-diff-output grids are provided as python arrays where the diff is simply the output minus input:\n"
+             else:
+                 prompt += "The input-diff grids are provided as python arrays:\n"
+         elif isinstance(
+             self.task_representer.example_representer.grid_representer,
+             ConnectedComponentRepresenter,
+         ):
+             connected_component = kwargs.get(
+                 "connected_component",
+                 self.task_representer.example_representer.grid_representer.connected_component,
+             )
+             connected_component = (
+                 "including diagonals"
+                 if connected_component == 8
+                 else "excluding diagonals"
+             )
+             prompt += f"The input-output grids are provided with indices of connected shapes ({connected_component}) of the same color:\n"
+         elif isinstance(
+             self.task_representer.example_representer.grid_representer,
+             PythonListGridRepresenter,
+         ):
+             prompt += "The input-output grids are provided as python arrays:\n"
+         elif isinstance(
+             self.task_representer.example_representer.grid_representer,
+             CompositeRepresenter,
+         ):
+             connected_component = kwargs.get(
+                 "connected_component",
+                 self.task_representer.example_representer.grid_representer.connected_component,
+             )
+             connected_component = (
+                 "including diagonals"
+                 if connected_component == 8
+                 else "excluding diagonals"
+             )
+             prompt += f"The input-output grids are provided as both python arrays and indices of connected shapes ({connected_component}) of the same color:\n"
+
+         for example in task.train_examples:
+             query, output = self.task_representer.example_representer.encode(
+                 example, **kwargs
+             )
+             if query is None or output is None:
+                 return None, None
+             prompt += query + output + "\n"
+
+         input_data.append({"role": "system", "content": prompt})
+
+         query, output = self.task_representer.example_representer.encode(
+             task.test_example, **kwargs
+         )
+         if query is None or output is None:
+             return None, None
+
+         input_data.append({"role": "user", "content": query})
+
+         output_data = {"role": "assistant", "content": output}
+
+         return input_data, output_data
+
+     def decode(self, input_data: MESSAGES, output_data: MESSAGE, **kwargs) -> Task:
+         raise NotImplementedError(
+             "Decoding for GPTTextMessageRepresenterV2 is not implemented."
+         )
+
+     def __repr__(self) -> str:
+         return f"GPTTextMessageRepresenterV2(prompt={self.prompt!r}, task_representer={repr(self.task_representer)})"
+
+
+ class GPTTextMessageRepresenterV2CoT(MessageRepresenter):
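+     """Variant of GPTTextMessageRepresenterV2 whose target assistant message
+     is a chain of intermediate grids (`task.test_example.cot`) rendered step
+     by step, ending with the final grid."""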
+     def __init__(
+         self,
+         prompt: Optional[str] = None,
+         task_representer: TaskRepresenter = TextTaskRepresenter(),
+     ):
+         if prompt:
+             self.prompt = prompt
+         else:
+             self.prompt = "Figure out the underlying transformation in the following examples and apply it to the test case. {description}Here are some examples from this transformation, your answer must follow the format.\n"
+
+         self.task_representer = task_representer
+
+     def encode(self, task: Task) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         if hasattr(task, "description"):
+             description = task.description
+             description = f"\n\n A possible description of the transformation: \n\n{description}\n"
+             prompt = self.prompt.format(description=description)
+         else:
+             prompt = self.prompt.format(description="")
+
+         if isinstance(
+             self.task_representer.example_representer.grid_representer,
+             ConnectedComponentRepresenter,
+         ):
+             prompt += "The input-output grids are provided with indices of connected shapes of the same color:\n"
+         elif isinstance(
+             self.task_representer.example_representer.grid_representer,
+             PythonListGridRepresenter,
+         ):
+             prompt += "The input-output grids are provided as python arrays:\n"
+         elif isinstance(
+             self.task_representer.example_representer.grid_representer,
+             CompositeRepresenter,
+         ):
+             prompt += "The input-output grids are provided as both python arrays and indices of connected shapes of the same color:\n"
+
+         for example in task.train_examples:
+             query, output = self.task_representer.example_representer.encode(example)
+             prompt += query + output + "\n"
+
+         input_data.append({"role": "system", "content": prompt})
+
+         query, output = self.task_representer.example_representer.encode(
+             task.test_example
+         )
+
+         input_data.append(
+             {"role": "user", "content": query + ". Let's think step by step:"}
+         )
+
+         cot_strs = ""
+         for i, cot in enumerate(task.test_example.cot[:-1]):
+             if -1 in cot:
+                 cot = np.where(cot == -1, 0, cot)
+             cot_str = self.task_representer.example_representer.grid_representer.encode(
+                 cot
+             )
+             cot_str = "Step-" + str(i + 1) + ":\n" + cot_str
+             cot_strs += cot_str + "\n"
+
+         cot_strs += (
+             "Final Step:\n"
+             + self.task_representer.example_representer.grid_representer.encode(
+                 task.test_example.cot[-1]
+             )
+         )
+
+         output_data = {"role": "assistant", "content": cot_strs}
+
+         return input_data, output_data
+
+     def decode(self, input_data: MESSAGES, output_data: MESSAGE, **kwargs) -> Task:
+         raise NotImplementedError(
+             "Decoding for GPTTextMessageRepresenterV2CoT is not implemented."
+         )
+
+
+ class DataToCodeTextrepresenter(MessageRepresenter):
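+     """Encodes the training examples as a single user message and the given
+     `code` string as the target assistant message (data-to-code direction)."""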
+     def __init__(
+         self,
+         task_representer: TaskRepresenter = TextTaskRepresenter(),
+         prompt: Optional[
+             str
+         ] = "Figure out the underlying code that produces the following input-output grids:\n",
+     ):
+         self.prompt = prompt
+         self.task_representer = task_representer
+
+     def encode(self, task: Task, code: str) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         prompt = self.prompt
+
+         input_data.append({"role": "system", "content": prompt})
+
+         data_points = ""
+
+         for example in task.train_examples:
+             query, output = self.task_representer.example_representer.encode(example)
+             data_points += query + output + "\n"
+
+         input_data.append({"role": "user", "content": data_points})
+
+         output_data = {"role": "assistant", "content": code}
+
+         return input_data, output_data
+
+     def decode(self, input_data: MESSAGES, output_data: MESSAGE, **kwargs) -> Task:
+         # Decoding logic for DataToCodeTextrepresenter is complex and depends on the specific encoding format.
+         # This is a placeholder for the actual implementation.
+         raise NotImplementedError(
+             "Decoding for DataToCodeTextrepresenter is not implemented."
+         )
+
+
+ class GPTTextMessageRepresenterFewShot(MessageRepresenter):
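+     """Few-shot text messager: demonstration tasks with at least three usable
+     examples are serialized into the system message between "== START OF
+     TASK ==" and "== END OF TASK ==" markers, followed by the query task."""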
+     def __init__(
+         self,
+         task_representer: TaskRepresenter = TextTaskRepresenter(),
+         prompt: Optional[
+             str
+         ] = "Figure out the underlying transformations in each task and complete the examples. You must follow the format.\n\n",
+     ):
+         self.prompt = prompt
+         self.task_representer = task_representer
+
+     def encode(
+         self, task: Task, examples: List[Task], num_demonstrations: List[int]
+     ) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         prompts = []
+         for i, demo_task in enumerate(examples):
+             k = num_demonstrations[i]
+             if k >= 3:
+                 prompt = "== START OF TASK ==\n\n"
+                 demonstrations = demo_task.train_examples + [demo_task.test_example]
+                 for j in range(k):
+                     example = demonstrations[j]
+                     query, output = self.task_representer.example_representer.encode(
+                         example
+                     )
+                     prompt += query + output + "\n\n"
+                 prompt += "== END OF TASK ==\n\n"
+                 prompts.append(prompt)
+
+         prompts = "".join(prompts)
+
+         input_data.append({"role": "system", "content": self.prompt + prompts})
+
+         prompt = ""
+         for example in task.train_examples:
+             query, output = self.task_representer.example_representer.encode(example)
+             prompt += query + output + "\n\n"
+
+         query, output = self.task_representer.example_representer.encode(
+             task.test_example
+         )
+
+         input_data.append({"role": "user", "content": prompt + query})
+
+         output_data = {"role": "assistant", "content": output}
+
+         return input_data, output_data
+
+     def decode(self, input_data: MESSAGES, output_data: MESSAGE, **kwargs) -> Task:
+         # Decoding logic for GPTTextMessageRepresenterFewShot is complex and depends on the specific encoding format.
+         # This is a placeholder for the actual implementation.
+         raise NotImplementedError(
+             "Decoding for GPTTextMessageRepresenterFewShot is not implemented."
+         )
+
+
+ class GPTTextImageMessagerepresenter(MessageRepresenter):
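+     """Multimodal messager that pairs the textual grid encoding with base64
+     JPEG images for every input and output grid; the test output is returned
+     as the target assistant message."""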
+     def __init__(
+         self,
+         text_representer: TextTaskRepresenter = TextTaskRepresenter(),
+         image_representer: ImageTaskRepresenter = ImageTaskRepresenter(),
+         prompt: Optional[
+             str
+         ] = "Figure out the underlying transformation in the following examples and apply it to the test case. {description}Here are some examples from this transformation, your answer must follow the format.\n",
+     ):
+         self.prompt = prompt
+         self.text_representer = text_representer
+         self.image_representer = image_representer
+
+     def encode(self, task: Task, **kwargs) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         if hasattr(task, "description"):
+             description = task.description
+             description = f"\n\n A possible description of the transformation: \n\n{description}\n"
+             prompt = self.prompt.format(description=description)
+         else:
+             prompt = self.prompt.format(description="")
+
+         if isinstance(
+             self.text_representer.example_representer.grid_representer,
+             ConnectedComponentRepresenter,
+         ):
+             connected_component = kwargs.get(
+                 "connected_component",
+                 self.text_representer.example_representer.grid_representer.connected_component,
+             )
+             connected_component = (
+                 "including diagonals"
+                 if connected_component == 8
+                 else "excluding diagonals"
+             )
+             prompt += f"The input-output grids are provided both as images and as indices of connected shapes ({connected_component}) of the same color."
+         elif isinstance(
+             self.text_representer.example_representer.grid_representer,
+             PythonListGridRepresenter,
+         ):
+             prompt += "The input-output grids are provided both as images and as python arrays:\n"
+         elif isinstance(
+             self.text_representer.example_representer.grid_representer,
+             CompositeRepresenter,
+         ):
+             connected_component = kwargs.get(
+                 "connected_component",
+                 self.text_representer.example_representer.grid_representer.connected_component,
+             )
+             connected_component = (
+                 "including diagonals"
+                 if connected_component == 8
+                 else "excluding diagonals"
+             )
+             prompt += f"The input-output grids are provided both as python arrays and as indices of connected shapes ({connected_component}) of the same color."
+
+         input_data.append({"role": "system", "content": prompt})
+
+         for j, example in enumerate(task.train_examples + [task.test_example]):
+             content = []
+             query, output = self.text_representer.example_representer.encode(
+                 example, **kwargs
+             )
+
+             content.append(
+                 {
+                     "type": "text",
+                     "text": query.replace("\nOUTPUT:\n", ""),
+                 }
+             )
+
+             input_image = (
+                 self.image_representer.example_representer.grid_representer.encode(
+                     example.input, **kwargs
+                 )
+             )
+             content.append(
+                 {
+                     "type": "image_url",
+                     "image_url": {"url": f"data:image/jpeg;base64,{input_image}"},
+                 }
+             )
+             if j != len(task.train_examples):
+                 output_image = (
+                     self.image_representer.example_representer.grid_representer.encode(
+                         example.output, **kwargs
+                     )
+                 )
+                 content.append({"type": "text", "text": "\nOUTPUT:\n" + output})
+
+                 content.append(
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{output_image}"},
+                     }
+                 )
+             else:
+                 test_content = []
+                 output_image = (
+                     self.image_representer.example_representer.grid_representer.encode(
+                         example.output, **kwargs
+                     )
+                 )
+                 test_content.append({"type": "text", "text": "\nOUTPUT:\n" + output})
+
+             input_data.append({"role": "user", "content": content})
+
+         output_data = {
+             "role": "assistant",
+             "content": test_content,
+         }
+
+         return input_data, output_data
+
+     def decode(self, input_data: MESSAGES, output_data: MESSAGE, **kwargs) -> Task:
+         raise NotImplementedError(
+             "Decoding for GPTTextImageMessagerepresenter is not implemented."
+         )
+
+
+ class GPTTextImageMessageRepresenterFewShot(MessageRepresenter):
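+     """Few-shot multimodal messager; the text and image channels can be
+     disabled independently, and same-shape pairs additionally get a DIFF grid.
+     The target assistant message is left as a placeholder ([{}])."""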
599
+ def __init__(
600
+ self,
601
+ text_representer: TextTaskRepresenter = TextTaskRepresenter(),
602
+ image_representer: ImageTaskRepresenter = ImageTaskRepresenter(),
603
+ diff_representer: Optional[GridRepresenter] = DelimitedGridRepresenter(),
604
+ prompt: Optional[
605
+ str
606
+ ] = "Figure out the underlying transformations in each task and complete the examples. You must follow the format.\n\n",
607
+ disable_image: Optional[bool] = False,
608
+ disable_text: Optional[bool] = False,
609
+ ):
610
+ self.prompt = prompt
611
+ self.disable_image = disable_image
612
+ self.disable_text = disable_text
613
+ self.text_representer = text_representer
614
+ self.image_representer = image_representer
615
+ self.diff_representer = diff_representer
616
+
617
+ def encode(
618
+ self, task: Task, examples: List[Tuple[Task, str]]
619
+ ) -> Tuple[MESSAGES, MESSAGE]:
620
+ input_data = []
621
+
622
+ # if hasattr(task, "description"):
623
+ # description = task.description
624
+ # description = f"\n\n A possible description of the transformation: \n\n{description}\n"
625
+ # prompt = self.prompt.format(description=description)
626
+ # else:
627
+ # prompt = self.prompt.format(description="")
628
+
629
+ # if isinstance(self.text_representer.example_representer.grid_representer, ConnectedComponentRepresenter):
630
+ # connected_component = kwargs.get("connected_component", self.text_representer.example_representer.grid_representer.connected_component)
631
+ # connected_component = "including diagonals" if connected_component == 8 else "excluding diagonals"
632
+ # prompt += f"The input-output grids are provided with both as image and as indices of connected shapes ({connected_component}) of the same color."
633
+ # elif isinstance(self.text_representer.example_representer.grid_representer, PythonListGridRepresenter):
634
+ # prompt += "The input-output grids are provided both as image and as python arrays:\n"
635
+ # elif isinstance(self.text_representer.example_representer.grid_representer, CompositeRepresenter):
636
+ # connected_component = kwargs.get("connected_component", self.text_representer.example_representer.grid_representer.connected_component)
637
+ # connected_component = "including diagonals" if connected_component == 8 else "excluding diagonals"
638
+ # prompt += f"The input-output grids are provided as both python arrays and as indices of connected shapes ({connected_component}) of the same color."
639
+ prompt = self.prompt
640
+ input_data.append({"role": "system", "content": prompt})
641
+ # Iterate over the examples provided for few-shot learning
642
+ for example_task, example_output in examples:
643
+ content = []
644
+ for j, example in enumerate(
645
+ example_task.train_examples + [example_task.test_example]
646
+ ):
647
+ query, output = self.text_representer.example_representer.encode(
648
+ example
649
+ )
650
+ if not self.disable_text:
651
+ content.append(
652
+ {
653
+ "type": "text",
654
+ "text": query.replace("\nOUTPUT:\n", ""),
655
+ }
656
+ )
657
+ else:
658
+ content.append(
659
+ {
660
+ "type": "text",
661
+ "text": "\nINPUT:\n",
662
+ }
663
+ )
664
+
665
+ if not self.disable_image:
666
+ input_image = self.image_representer.example_representer.grid_representer.encode(
667
+ example.input
668
+ )
669
+ content.append(
670
+ {
671
+ "type": "image_url",
672
+ "image_url": {
673
+ "url": f"data:image/jpeg;base64,{input_image}"
674
+ },
675
+ }
676
+ )
677
+ if j != len(example_task.train_examples):
678
+ if not self.disable_text:
679
+ content.append({"type": "text", "text": "\nOUTPUT:\n" + output})
680
+ else:
681
+ content.append({"type": "text", "text": "\nOUTPUT:\n"})
682
+ if not self.disable_image:
683
+ output_image = self.image_representer.example_representer.grid_representer.encode(
684
+ example.output
685
+ )
686
+ content.append(
687
+ {
688
+ "type": "image_url",
689
+ "image_url": {
690
+ "url": f"data:image/jpeg;base64,{output_image}"
691
+ },
692
+ }
693
+ )
694
+
695
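+                     # The DIFF grid keeps the output value wherever a cell changed
+                     # and 0 elsewhere (output - input is zero for unchanged cells).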
+                     if np.shape(example.input) == np.shape(example.output):
+                         diff = example.output - example.input
+                         diff = np.where(diff != 0, example.output, diff)
+                         encoded_diff = self.diff_representer.encode(diff)
+                         if not self.disable_text:
+                             content.append(
+                                 {"type": "text", "text": "\nDIFF:\n" + encoded_diff}
+                             )
+                         else:
+                             content.append({"type": "text", "text": "\nDIFF:\n"})
+
+                         if not self.disable_image:
+                             diff_image = self.image_representer.example_representer.grid_representer.encode(
+                                 diff
+                             )
+                             content.append(
+                                 {
+                                     "type": "image_url",
+                                     "image_url": {
+                                         "url": f"data:image/jpeg;base64,{diff_image}"
+                                     },
+                                 }
+                             )
+
+             input_data.append({"role": "user", "content": content})
+             # Assistant turn carrying this demonstration's reference reasoning/output
+             input_data.append({"role": "assistant", "content": example_output})
+
+         content = []
+         for j, example in enumerate(task.train_examples + [task.test_example]):
+             query, output = self.text_representer.example_representer.encode(example)
+             if not self.disable_text:
+                 content.append(
+                     {
+                         "type": "text",
+                         "text": query.replace("\nOUTPUT:\n", ""),
+                     }
+                 )
+             else:
+                 content.append(
+                     {
+                         "type": "text",
+                         "text": "\nINPUT:\n",
+                     }
+                 )
+             if not self.disable_image:
+                 input_image = (
+                     self.image_representer.example_representer.grid_representer.encode(
+                         example.input
+                     )
+                 )
+                 content.append(
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{input_image}"},
+                     }
+                 )
+
+             if j != len(task.train_examples):
+                 if not self.disable_text:
+                     content.append({"type": "text", "text": "\nOUTPUT:\n" + output})
+                 else:
+                     content.append({"type": "text", "text": "\nOUTPUT:\n"})
+
+                 if not self.disable_image:
+                     output_image = self.image_representer.example_representer.grid_representer.encode(
+                         example.output
+                     )
+
+                     content.append(
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{output_image}"
+                             },
+                         }
+                     )
+
+                 if np.shape(example.input) == np.shape(example.output):
+                     diff = example.output - example.input
+                     diff = np.where(diff != 0, example.output, diff)
+                     encoded_diff = self.diff_representer.encode(diff)
+                     if not self.disable_text:
+                         content.append(
+                             {"type": "text", "text": "\nDIFF:\n" + encoded_diff}
+                         )
+                     else:
+                         content.append({"type": "text", "text": "\nDIFF:\n"})
+
+                     if not self.disable_image:
+                         diff_image = self.image_representer.example_representer.grid_representer.encode(
+                             diff
+                         )
+                         content.append(
+                             {
+                                 "type": "image_url",
+                                 "image_url": {
+                                     "url": f"data:image/jpeg;base64,{diff_image}"
+                                 },
+                             }
+                         )
+
+         input_data.append({"role": "user", "content": content})
+
+         output_data = [{}]
+
+         return input_data, output_data
+
+
+ class TextMessageRepresenterFewShot(MessageRepresenter):
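+     """Text-only few-shot messager: each demonstration task becomes a user
+     turn followed by an assistant turn with its reference reasoning; the
+     target assistant message is left as a placeholder ([{}])."""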
+     def __init__(
+         self,
+         text_representer: TextTaskRepresenter = TextTaskRepresenter(),
+         image_representer: ImageTaskRepresenter = ImageTaskRepresenter(),
+         prompt: Optional[
+             str
+         ] = "Figure out the underlying transformations in each task and complete the examples. You must follow the format.\n\n",
+     ):
+         self.prompt = prompt
+         self.text_representer = text_representer
+         self.image_representer = image_representer
+
+     def encode(
+         self, task: Task, examples: List[Tuple[Task, str]], **kwargs
+     ) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         if hasattr(task, "description"):
+             description = task.description
+             description = f"\n\n A possible description of the transformation: \n\n{description}\n"
+             prompt = self.prompt.format(description=description)
+         else:
+             prompt = self.prompt.format(description="")
+
+         if isinstance(
+             self.text_representer.example_representer.grid_representer,
+             ConnectedComponentRepresenter,
+         ):
+             connected_component = kwargs.get(
+                 "connected_component",
+                 self.text_representer.example_representer.grid_representer.connected_component,
+             )
+             connected_component = (
+                 "including diagonals"
+                 if connected_component == 8
+                 else "excluding diagonals"
+             )
+             prompt += f"The input-output grids are provided both as images and as indices of connected shapes ({connected_component}) of the same color."
+         elif isinstance(
+             self.text_representer.example_representer.grid_representer,
+             PythonListGridRepresenter,
+         ):
+             prompt += "The input-output grids are provided both as images and as python arrays:\n"
+         elif isinstance(
+             self.text_representer.example_representer.grid_representer,
+             CompositeRepresenter,
+         ):
+             connected_component = kwargs.get(
+                 "connected_component",
+                 self.text_representer.example_representer.grid_representer.connected_component,
+             )
+             connected_component = (
+                 "including diagonals"
+                 if connected_component == 8
+                 else "excluding diagonals"
+             )
+             prompt += f"The input-output grids are provided both as python arrays and as indices of connected shapes ({connected_component}) of the same color."
+
+         input_data.append({"role": "system", "content": prompt})
+         # Iterate over the examples provided for few-shot learning
+         for example_task, example_output in examples:
+             content = []
+             for j, example in enumerate(
+                 example_task.train_examples + [example_task.test_example]
+             ):
+                 query, output = self.text_representer.example_representer.encode(
+                     example
+                 )
+
+                 content.append(
+                     {
+                         "type": "text",
+                         "text": query.replace("\nOUTPUT:\n", ""),
+                     }
+                 )
+
+                 if j != len(example_task.train_examples):
+                     content.append({"type": "text", "text": "\nOUTPUT:\n" + output})
+
+             input_data.append({"role": "user", "content": content})
+             # Assistant turn carrying this demonstration's reference reasoning/output
+             input_data.append({"role": "assistant", "content": example_output})
+
+         content = []
+         for j, example in enumerate(task.train_examples + [task.test_example]):
+             query, output = self.text_representer.example_representer.encode(example)
+
+             content.append(
+                 {
+                     "type": "text",
+                     "text": query.replace("\nOUTPUT:\n", ""),
+                 }
+             )
+
+             if j != len(task.train_examples):
+                 content.append({"type": "text", "text": "\nOUTPUT:\n" + output})
+
+         input_data.append({"role": "user", "content": content})
+
+         output_data = [{}]
+
+         return input_data, output_data
+
+
+ class GPTImageMessageRepresenterFewShot(MessageRepresenter):
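+     """Image-only few-shot messager: grids are sent exclusively as base64 JPEG
+     images, with no textual grid encoding; the target assistant message is
+     left as a placeholder ([{}])."""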
+     def __init__(
+         self,
+         text_representer: TextTaskRepresenter = TextTaskRepresenter(),
+         image_representer: ImageTaskRepresenter = ImageTaskRepresenter(),
+         prompt: Optional[
+             str
+         ] = "Figure out the underlying transformations in each task and complete the examples. You must follow the format.\n\n",
+     ):
+         self.prompt = prompt
+         self.text_representer = text_representer
+         self.image_representer = image_representer
+
+     def encode(
+         self, task: Task, examples: List[Tuple[Task, str]], **kwargs
+     ) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         if hasattr(task, "description"):
+             description = task.description
+             description = f"\n\n A possible description of the transformation: \n\n{description}\n"
+             prompt = self.prompt.format(description=description)
+         else:
+             prompt = self.prompt.format(description="")
+
+         if isinstance(
+             self.text_representer.example_representer.grid_representer,
+             ConnectedComponentRepresenter,
+         ):
+             connected_component = kwargs.get(
+                 "connected_component",
+                 self.text_representer.example_representer.grid_representer.connected_component,
+             )
+             connected_component = (
+                 "including diagonals"
+                 if connected_component == 8
+                 else "excluding diagonals"
+             )
+             prompt += f"The input-output grids are provided both as images and as indices of connected shapes ({connected_component}) of the same color."
+         elif isinstance(
+             self.text_representer.example_representer.grid_representer,
+             PythonListGridRepresenter,
+         ):
+             prompt += "The input-output grids are provided both as images and as python arrays:\n"
+         elif isinstance(
+             self.text_representer.example_representer.grid_representer,
+             CompositeRepresenter,
+         ):
+             connected_component = kwargs.get(
+                 "connected_component",
+                 self.text_representer.example_representer.grid_representer.connected_component,
+             )
+             connected_component = (
+                 "including diagonals"
+                 if connected_component == 8
+                 else "excluding diagonals"
+             )
+             prompt += f"The input-output grids are provided both as python arrays and as indices of connected shapes ({connected_component}) of the same color."
+
+         input_data.append({"role": "system", "content": prompt})
+         # Iterate over the examples provided for few-shot learning
+         for example_task, example_output in examples:
+             content = []
+             for j, example in enumerate(
+                 example_task.train_examples + [example_task.test_example]
+             ):
+                 query, output = self.text_representer.example_representer.encode(
+                     example
+                 )
+
+                 input_image = (
+                     self.image_representer.example_representer.grid_representer.encode(
+                         example.input
+                     )
+                 )
+                 content.append(
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{input_image}"},
+                     }
+                 )
+                 if j != len(example_task.train_examples):
+                     output_image = self.image_representer.example_representer.grid_representer.encode(
+                         example.output
+                     )
+                     content.append(
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{output_image}"
+                             },
+                         }
+                     )
+
+             input_data.append({"role": "user", "content": content})
+             # Assistant turn carrying this demonstration's reference reasoning/output
+             input_data.append({"role": "assistant", "content": example_output})
+
+         content = []
+         for j, example in enumerate(task.train_examples + [task.test_example]):
+             query, output = self.text_representer.example_representer.encode(example)
+
+             input_image = (
+                 self.image_representer.example_representer.grid_representer.encode(
+                     example.input
+                 )
+             )
+             content.append(
+                 {
+                     "type": "image_url",
+                     "image_url": {"url": f"data:image/jpeg;base64,{input_image}"},
+                 }
+             )
+
+             if j != len(task.train_examples):
+                 output_image = (
+                     self.image_representer.example_representer.grid_representer.encode(
+                         example.output
+                     )
+                 )
+
+                 content.append(
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{output_image}"},
+                     }
+                 )
+
+         input_data.append({"role": "user", "content": content})
+
+         output_data = [{}]
+
+         return input_data, output_data
+
+
+ class GPTTextImageCodeMessageRepresenterFewShot(MessageRepresenter):
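+     """Few-shot messager for code induction: each demonstration pairs the
+     grids (and, unless disabled, images) with a "REASONING FOR CODE" section,
+     and the assistant turn holds that demonstration's code; the target
+     assistant message is left as a placeholder ([{}])."""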
+     def __init__(
+         self,
+         text_representer: TextTaskRepresenter = TextTaskRepresenter(),
+         image_representer: ImageTaskRepresenter = ImageTaskRepresenter(),
+         prompt: Optional[
+             str
+         ] = "Figure out the underlying transformations in each task and complete the examples. You must follow the format.\n\n",
+         disable_image: Optional[bool] = False,
+     ):
+         self.prompt = prompt
+         self.disable_image = disable_image
+         print("disable_image", self.disable_image)
+         self.text_representer = text_representer
+         self.image_representer = image_representer
+
+     def encode(
+         self, task: Task, task_reasoning: str, examples: List[Tuple[Task, str, str]]
+     ) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         # if hasattr(task, "description"):
+         #     description = task.description
+         #     description = f"\n\n A possible description of the transformation: \n\n{description}\n"
+         #     prompt = self.prompt.format(description=description)
+         # else:
+         #     prompt = self.prompt.format(description="")
+
+         # if isinstance(self.text_representer.example_representer.grid_representer, ConnectedComponentRepresenter):
+         #     connected_component = kwargs.get("connected_component", self.text_representer.example_representer.grid_representer.connected_component)
+         #     connected_component = "including diagonals" if connected_component == 8 else "excluding diagonals"
+         #     prompt += f"The input-output grids are provided with both as image and as indices of connected shapes ({connected_component}) of the same color."
+         # elif isinstance(self.text_representer.example_representer.grid_representer, PythonListGridRepresenter):
+         #     prompt += "The input-output grids are provided both as image and as python arrays:\n"
+         # elif isinstance(self.text_representer.example_representer.grid_representer, CompositeRepresenter):
+         #     connected_component = kwargs.get("connected_component", self.text_representer.example_representer.grid_representer.connected_component)
+         #     connected_component = "including diagonals" if connected_component == 8 else "excluding diagonals"
+         #     prompt += f"The input-output grids are provided as both python arrays and as indices of connected shapes ({connected_component}) of the same color."
+
+         prompt = self.prompt
+
+         input_data.append({"role": "system", "content": prompt})
+         # Iterate over the examples provided for few-shot learning
+         for example_task, reasoning, example_output in examples:
+             content = []
+             for j, example in enumerate(
+                 example_task.train_examples + [example_task.test_example]
+             ):
+                 query, output = self.text_representer.example_representer.encode(
+                     example
+                 )
+
+                 content.append(
+                     {
+                         "type": "text",
+                         "text": query.replace("\nOUTPUT:\n", ""),
+                     }
+                 )
+                 if not self.disable_image:
+                     input_image = self.image_representer.example_representer.grid_representer.encode(
+                         example.input
+                     )
+                     content.append(
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{input_image}"
+                             },
+                         }
+                     )
+                 if j != len(example_task.train_examples):
+                     content.append({"type": "text", "text": "\nOUTPUT:\n" + output})
+
+                     if not self.disable_image:
+                         output_image = self.image_representer.example_representer.grid_representer.encode(
+                             example.output
+                         )
+
+                         content.append(
+                             {
+                                 "type": "image_url",
+                                 "image_url": {
+                                     "url": f"data:image/jpeg;base64,{output_image}"
+                                 },
+                             }
+                         )
+
+             input_data.append(
+                 {
+                     "role": "user",
+                     "content": content
+                     + [
+                         {
+                             "type": "text",
+                             "text": "\n\n====REASONING FOR CODE=====\n\n" + reasoning,
+                         }
+                     ],
+                 }
+             )
+             # Assistant turn carrying this demonstration's code
+             input_data.append({"role": "assistant", "content": example_output})
+
+         content = []
+         for j, example in enumerate(task.train_examples + [task.test_example]):
+             query, output = self.text_representer.example_representer.encode(example)
+
+             content.append(
+                 {
+                     "type": "text",
+                     "text": query.replace("\nOUTPUT:\n", ""),
+                 }
+             )
+             if not self.disable_image:
+                 input_image = (
+                     self.image_representer.example_representer.grid_representer.encode(
+                         example.input
+                     )
+                 )
+                 content.append(
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{input_image}"},
+                     }
+                 )
+
+             if j != len(task.train_examples):
+                 content.append({"type": "text", "text": "\nOUTPUT:\n" + output})
+
+                 if not self.disable_image:
+                     output_image = self.image_representer.example_representer.grid_representer.encode(
+                         example.output
+                     )
+                     content.append(
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{output_image}"
+                             },
+                         }
+                     )
+
+         input_data.append(
+             {
+                 "role": "user",
+                 "content": content
+                 + [
+                     {
+                         "type": "text",
+                         "text": "\n\n====REASONING FOR CODE=====\n\n" + task_reasoning,
+                     }
+                 ],
+             }
+         )
+
+         output_data = [{}]
+
+         return input_data, output_data
+
+
+ class GPTCodeDebuggerMessager(MessageRepresenter):
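+     """Builds a single-turn debugging request: the task's grids, the prior
+     reasoning, the failing code, and the error message, asking the model for
+     a corrected implementation."""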
+     def __init__(
+         self,
+         text_representer: TextTaskRepresenter = TextTaskRepresenter(),
+         prompt: str = "You are a debugging assistant. Please debug the code provided below.",
+     ):
+         self.prompt = prompt
+         self.text_representer = text_representer
+
+     def encode(self, task: Task, reasoning: str, code: str, error_message: str):
+         # Prepare the input messages for the model, starting with the system prompt
+         input_messages = [{"role": "system", "content": self.prompt}]
+
+         demonstrations, query, output = self.text_representer.encode(task)
+         query = demonstrations + "\n" + query.replace("\nOUTPUT:\n", "")
+
+         content = []
+         content.append(
+             {
+                 "type": "text",
+                 "text": query
+                 + "\n\n====REASONING FOR CODE=====\n\n"
+                 + reasoning
+                 + "\n\n"
+                 + "```python\n"
+                 + code
+                 + "\n```",
+             }
+         )
+
+         content.append(
+             {
+                 "type": "text",
+                 "text": "\n\n Here is the error message:\n\n"
+                 + error_message
+                 + "\n\n Can you now give the debugged version of the code? Remember, the implementation must contain the ExampleRepresenter() class as that is used for testing. Do not make up your own Class names for the representer.",
+             }
+         )
+         input_messages.append({"role": "user", "content": content})
+
+         return input_messages, None
+
+
+ class GPTTextMessageRepresenterForBarc(MessageRepresenter):
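+     """Messager following the BARC-style puzzle-solver prompt: reference
+     input-output pairs plus the test input in a single user message, with the
+     target answer wrapped in a fenced code block."""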
+     def __init__(
+         self,
+         prompt: Optional[
+             str
+         ] = "You are a world-class puzzle solver with exceptional pattern recognition skills. Your task is to analyze puzzles, spot patterns, and provide direct solutions.",
+         task_representer: TaskRepresenter = TextTaskRepresenter(),
+     ):
+         self.prompt = prompt
+         self.task_representer = task_representer
+         # if example_representer is not None:
+         #     self.task_representer.example_representer = example_representer(
+         #         io_sep=" -> ",
+         #         input_header="",
+         #         output_header="",
+         #         grid_representer=PythonListGridRepresenter
+         #     )
+
+     def encode(self, task: Task, **kwargs) -> Tuple[MESSAGES, MESSAGE]:
+         input_data = []
+
+         input_data.append({"role": "system", "content": self.prompt})
+
+         content = "Given input-output grid pairs as reference examples, carefully observe the patterns to predict the output grid for new test input. Each pair follows the same transformation rule. Grids are 2D arrays represented as strings, with cells (colors) separated by spaces and rows by newlines.\nHere are the input and output grids for the reference examples:\n"
+
+         for i, example in enumerate(task.train_examples):
+             content += f"Example {i + 1}:\n"
+             query, output = self.task_representer.example_representer.encode(
+                 example, **kwargs
+             )
+             if query is None or output is None:
+                 return None, None
+             content += query + output + "\n"
+
+         content += "\n\nHere is the input grid for the test example:\n"
+
+         query, output = self.task_representer.example_representer.encode(
+             task.test_example, **kwargs
+         )
+
+         query = query.replace("Output:", "")
+
+         content += (
+             query
+             + "Directly provide the output grids corresponding to the given test input grids, based on the patterns observed in the reference examples."
+         )
+
+         input_data.append({"role": "user", "content": content})
+
+         output = f"The output grid for the test input grid is:\n\n```\n{output}\n```"
+
+         output_data = {"role": "assistant", "content": output}
+
+         return input_data, output_data
+
+     def decode(self, input_data: MESSAGES, output_data: MESSAGE, **kwargs) -> Task:
+         raise NotImplementedError(
+             "Decoding for GPTTextMessageRepresenterForBarc is not implemented."
+         )
+
+     def __repr__(self) -> str:
+         return f"GPTTextMessageRepresenterForBarc(prompt={self.prompt!r}, task_representer={repr(self.task_representer)})"
+
+     def __str__(self) -> str:
+         return repr(self)
+
+
+ if __name__ == "__main__":
+     print("Running tests")
+     grid = np.array([[1, 1, 1], [0, 0, 0], [1, 1, 1]])
+     example = Example(input=grid, output=grid)
+     task = Task(test_example=example, train_examples=[example])
+
+     representer = GPTTextMessageRepresenterForBarc(
+         task_representer=TextTaskRepresenter(
+             example_representer=TextExampleRepresenter(
+                 grid_representer=WordGridRepresenter(),
+                 input_header="Input:\n",
+                 output_header="\nOutput:\n",
+                 io_sep="\n",
+             )
+         )
+     )
+
+     input, output = representer.encode(task)
+
+     representer = GPTTextMessageRepresenterV2()
+     input, output = representer.encode(task)
+     print(input)
+     html_output = representer.display(input)
+     # Write to an HTML file
+     with open("chat_view.html", "w", encoding="utf-8") as file:
+         file.write(html_output)
+
+     # representer = GPTTextMessageRepresenterV2(task_representer=TextTaskRepresenter(example_representer=TextExampleRepresenter(grid_representer=ConnectedComponentRepresenter())))
+
+     representer = GPTTextImageMessagerepresenter()
+     input, output = representer.encode(task)
+     html_output = representer.display(input + [output])
+     # Write to an HTML file
+     with open("chat_view_w_image.html", "w", encoding="utf-8") as file:
+         file.write(html_output)