fusion-bench 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (727)
  1. fusion_bench/__init__.py +20 -0
  2. fusion_bench/__main__.py +4 -0
  3. fusion_bench/compat/__init__.py +0 -0
  4. fusion_bench/compat/method/__init__.py +109 -0
  5. fusion_bench/compat/method/base_algorithm.py +58 -0
  6. fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
  7. fusion_bench/compat/modelpool/__init__.py +116 -0
  8. fusion_bench/compat/modelpool/base_pool.py +328 -0
  9. fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
  10. fusion_bench/compat/taskpool/__init__.py +95 -0
  11. fusion_bench/compat/taskpool/base_pool.py +111 -0
  12. fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
  13. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
  14. fusion_bench/constants/__init__.py +2 -0
  15. fusion_bench/constants/paths.py +18 -0
  16. fusion_bench/dataset/__init__.py +29 -0
  17. fusion_bench/dataset/arc_agi/__init__.py +6 -0
  18. fusion_bench/dataset/arc_agi/arc.py +308 -0
  19. fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
  20. fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
  21. fusion_bench/dataset/arc_agi/messagers.py +1355 -0
  22. fusion_bench/dataset/arc_agi/np_cache.py +168 -0
  23. fusion_bench/dataset/arc_agi/preprocess.py +298 -0
  24. fusion_bench/dataset/arc_agi/representers.py +1019 -0
  25. fusion_bench/dataset/clip_dataset.py +71 -0
  26. fusion_bench/dataset/fer2013.py +12 -0
  27. fusion_bench/dataset/gpt2_glue.py +300 -0
  28. fusion_bench/dataset/gsm8k.py +60 -0
  29. fusion_bench/dataset/image_dataset.py +55 -0
  30. fusion_bench/dataset/imdb.py +11 -0
  31. fusion_bench/dataset/llama/__init__.py +1 -0
  32. fusion_bench/dataset/llama/alpaca.py +232 -0
  33. fusion_bench/dataset/llama/collate.py +120 -0
  34. fusion_bench/dataset/llama/metamathqa.py +50 -0
  35. fusion_bench/dataset/llama/openai.py +160 -0
  36. fusion_bench/dataset/llama/preference_700k.py +70 -0
  37. fusion_bench/dataset/llama/sharegpt.py +141 -0
  38. fusion_bench/dataset/llama/squad.py +125 -0
  39. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  40. fusion_bench/dataset/llama/ultrachat.py +58 -0
  41. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  42. fusion_bench/dataset/llama/wikitext.py +89 -0
  43. fusion_bench/dataset/nyuv2.py +119 -0
  44. fusion_bench/method/__init__.py +177 -0
  45. fusion_bench/method/ada_svd/__init__.py +2 -0
  46. fusion_bench/method/ada_svd/clip_vision.py +319 -0
  47. fusion_bench/method/adamerging/__init__.py +6 -0
  48. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
  49. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
  50. fusion_bench/method/adamerging/entropy_loss.py +25 -0
  51. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
  52. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
  53. fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
  54. fusion_bench/method/adamerging/llama_adamerging.py +335 -0
  55. fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
  56. fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
  57. fusion_bench/method/adamerging/utils.py +15 -0
  58. fusion_bench/method/analysis/__init__.py +2 -0
  59. fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
  60. fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
  61. fusion_bench/method/base_algorithm.py +44 -0
  62. fusion_bench/method/classification/__init__.py +3 -0
  63. fusion_bench/method/classification/clip_finetune.py +444 -0
  64. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  65. fusion_bench/method/concrete_subspace/__init__.py +6 -0
  66. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
  67. fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
  68. fusion_bench/method/dare/__init__.py +4 -0
  69. fusion_bench/method/dare/simple_average.py +31 -0
  70. fusion_bench/method/dare/task_arithmetic.py +82 -0
  71. fusion_bench/method/dare/ties_merging.py +100 -0
  72. fusion_bench/method/dare/utils.py +87 -0
  73. fusion_bench/method/dawe/__init__.py +2 -0
  74. fusion_bench/method/dawe/dawe_for_clip.py +274 -0
  75. fusion_bench/method/dawe/warppers/__init__.py +13 -0
  76. fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
  77. fusion_bench/method/depth_upscaling/__init__.py +3 -0
  78. fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
  79. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
  80. fusion_bench/method/dummy.py +35 -0
  81. fusion_bench/method/ensemble.py +98 -0
  82. fusion_bench/method/fisher_merging/__init__.py +4 -0
  83. fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
  84. fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
  85. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
  86. fusion_bench/method/linear/__init__.py +6 -0
  87. fusion_bench/method/linear/expo.py +118 -0
  88. fusion_bench/method/linear/linear_interpolation.py +60 -0
  89. fusion_bench/method/linear/llama_expo.py +229 -0
  90. fusion_bench/method/linear/simple_average_for_llama.py +54 -0
  91. fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
  92. fusion_bench/method/lm_finetune/__init__.py +3 -0
  93. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  94. fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
  95. fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
  96. fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
  97. fusion_bench/method/mixture_of_experts/__init__.py +7 -0
  98. fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
  99. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
  100. fusion_bench/method/model_recombination.py +121 -0
  101. fusion_bench/method/opcm/__init__.py +4 -0
  102. fusion_bench/method/opcm/opcm.py +277 -0
  103. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  104. fusion_bench/method/opcm/ties_merging.py +156 -0
  105. fusion_bench/method/opcm/utils.py +73 -0
  106. fusion_bench/method/opcm/weight_average.py +120 -0
  107. fusion_bench/method/pruning/__init__.py +5 -0
  108. fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
  109. fusion_bench/method/pruning/llama_random_prune.py +143 -0
  110. fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
  111. fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
  112. fusion_bench/method/pruning/prune_utils.py +165 -0
  113. fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
  114. fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
  115. fusion_bench/method/pruning/wanda_utils/data.py +135 -0
  116. fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
  117. fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
  118. fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
  119. fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
  120. fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
  121. fusion_bench/method/pwe_moe/__init__.py +5 -0
  122. fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
  123. fusion_bench/method/pwe_moe/module.py +316 -0
  124. fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
  125. fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
  126. fusion_bench/method/pwe_moe/utils.py +43 -0
  127. fusion_bench/method/rankone_moe/__init__.py +3 -0
  128. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  129. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  130. fusion_bench/method/regmean/__init__.py +4 -0
  131. fusion_bench/method/regmean/clip_regmean.py +131 -0
  132. fusion_bench/method/regmean/gpt2_regmean.py +147 -0
  133. fusion_bench/method/regmean/regmean.py +375 -0
  134. fusion_bench/method/simple_average.py +112 -0
  135. fusion_bench/method/slerp/__init__.py +2 -0
  136. fusion_bench/method/slerp/slerp.py +101 -0
  137. fusion_bench/method/slerp/slerp_utils.py +107 -0
  138. fusion_bench/method/smile_upscaling/__init__.py +3 -0
  139. fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
  140. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
  141. fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
  142. fusion_bench/method/sparse_we_moe/__init__.py +2 -0
  143. fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
  144. fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
  145. fusion_bench/method/sparselo/__init__.py +2 -0
  146. fusion_bench/method/sparselo/sparselo.py +955 -0
  147. fusion_bench/method/surgery/__init__.py +1 -0
  148. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  149. fusion_bench/method/tall_mask/__init__.py +0 -0
  150. fusion_bench/method/tall_mask/utils.py +234 -0
  151. fusion_bench/method/task_arithmetic/__init__.py +2 -0
  152. fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
  153. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  154. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  155. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  156. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  157. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
  158. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  159. fusion_bench/method/ties_merging/__init__.py +2 -0
  160. fusion_bench/method/ties_merging/ties_merging.py +117 -0
  161. fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
  162. fusion_bench/method/trust_region/__init__.py +2 -0
  163. fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
  164. fusion_bench/method/trust_region/utils.py +58 -0
  165. fusion_bench/method/we_moe/__init__.py +2 -0
  166. fusion_bench/method/we_moe/clip_we_moe.py +161 -0
  167. fusion_bench/method/we_moe/we_moe.py +247 -0
  168. fusion_bench/method/weighted_average/__init__.py +3 -0
  169. fusion_bench/method/weighted_average/llama.py +113 -0
  170. fusion_bench/method/weighted_average/weighted_average.py +102 -0
  171. fusion_bench/metrics/__init__.py +0 -0
  172. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  173. fusion_bench/metrics/nyuv2/__init__.py +11 -0
  174. fusion_bench/metrics/nyuv2/depth.py +45 -0
  175. fusion_bench/metrics/nyuv2/loss.py +31 -0
  176. fusion_bench/metrics/nyuv2/noise.py +16 -0
  177. fusion_bench/metrics/nyuv2/normal.py +48 -0
  178. fusion_bench/metrics/nyuv2/segmentation.py +43 -0
  179. fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
  180. fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
  181. fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
  182. fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
  183. fusion_bench/mixins/__init__.py +28 -0
  184. fusion_bench/mixins/clip_classification.py +252 -0
  185. fusion_bench/mixins/fabric_training.py +320 -0
  186. fusion_bench/mixins/lightning_fabric.py +174 -0
  187. fusion_bench/mixins/optim/__init__.py +0 -0
  188. fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
  189. fusion_bench/mixins/rich_live.py +21 -0
  190. fusion_bench/mixins/serialization.py +132 -0
  191. fusion_bench/mixins/simple_profiler.py +79 -0
  192. fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
  193. fusion_bench/modelpool/__init__.py +42 -0
  194. fusion_bench/modelpool/base_pool.py +268 -0
  195. fusion_bench/modelpool/causal_lm/__init__.py +2 -0
  196. fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
  197. fusion_bench/modelpool/clip_vision/__init__.py +1 -0
  198. fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
  199. fusion_bench/modelpool/huggingface_automodel.py +20 -0
  200. fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
  201. fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
  202. fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
  203. fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
  204. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  205. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  206. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  207. fusion_bench/models/__init__.py +3 -0
  208. fusion_bench/models/chat_templates/__init__.py +1 -0
  209. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  210. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  211. fusion_bench/models/hf_clip.py +199 -0
  212. fusion_bench/models/linearized/__init__.py +0 -0
  213. fusion_bench/models/linearized/linearized_model_utils.py +91 -0
  214. fusion_bench/models/linearized/vision_model.py +122 -0
  215. fusion_bench/models/llama/__init__.py +16 -0
  216. fusion_bench/models/llama/model_utils/__init__.py +0 -0
  217. fusion_bench/models/llama/model_utils/embedding.py +87 -0
  218. fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
  219. fusion_bench/models/llama/model_utils/misc.py +112 -0
  220. fusion_bench/models/llama/model_utils/mod.py +52 -0
  221. fusion_bench/models/llama/model_utils/visual.py +241 -0
  222. fusion_bench/models/llama/patcher.py +78 -0
  223. fusion_bench/models/llama/tokenizer_loader.py +153 -0
  224. fusion_bench/models/masks/__init__.py +2 -0
  225. fusion_bench/models/masks/mask_model.py +160 -0
  226. fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
  227. fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
  228. fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
  229. fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
  230. fusion_bench/models/modeling_losparse_llama/register.py +8 -0
  231. fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
  232. fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
  233. fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
  234. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
  235. fusion_bench/models/modeling_smile_mistral/register.py +8 -0
  236. fusion_bench/models/nyuv2/__init__.py +0 -0
  237. fusion_bench/models/nyuv2/aspp.py +82 -0
  238. fusion_bench/models/nyuv2/lightning_module.py +176 -0
  239. fusion_bench/models/nyuv2/resnet.py +405 -0
  240. fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
  241. fusion_bench/models/parameter_dict.py +75 -0
  242. fusion_bench/models/rankone_moe.py +410 -0
  243. fusion_bench/models/separate_io.py +105 -0
  244. fusion_bench/models/smile_moe/__init__.py +0 -0
  245. fusion_bench/models/smile_moe/linear.py +256 -0
  246. fusion_bench/models/sparse_we_moe.py +459 -0
  247. fusion_bench/models/surgery/__init__.py +1 -0
  248. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  249. fusion_bench/models/utils.py +80 -0
  250. fusion_bench/models/we_moe.py +247 -0
  251. fusion_bench/models/wrappers/__init__.py +0 -0
  252. fusion_bench/models/wrappers/ensemble.py +183 -0
  253. fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
  254. fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
  255. fusion_bench/optim/__init__.py +2 -0
  256. fusion_bench/optim/exception.py +47 -0
  257. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  258. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  259. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  260. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  261. fusion_bench/optim/mezo.py +118 -0
  262. fusion_bench/programs/__init__.py +20 -0
  263. fusion_bench/programs/base_program.py +9 -0
  264. fusion_bench/programs/fabric_fusion_program.py +299 -0
  265. fusion_bench/scripts/__init__.py +0 -0
  266. fusion_bench/scripts/cli.py +43 -0
  267. fusion_bench/scripts/clip/__init__.py +0 -0
  268. fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
  269. fusion_bench/scripts/imgui.py +218 -0
  270. fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
  271. fusion_bench/scripts/webui.py +405 -0
  272. fusion_bench/taskpool/__init__.py +39 -0
  273. fusion_bench/taskpool/base_pool.py +35 -0
  274. fusion_bench/taskpool/clip_vision/__init__.py +4 -0
  275. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  276. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
  277. fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
  278. fusion_bench/taskpool/dummy.py +58 -0
  279. fusion_bench/taskpool/gpt2_text_classification.py +149 -0
  280. fusion_bench/taskpool/llama/__init__.py +1 -0
  281. fusion_bench/taskpool/llama/reward_model.py +157 -0
  282. fusion_bench/taskpool/llama/test_generation.py +185 -0
  283. fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
  284. fusion_bench/tasks/__init__.py +2 -0
  285. fusion_bench/tasks/base_task.py +18 -0
  286. fusion_bench/tasks/classification.py +75 -0
  287. fusion_bench/tasks/clip_classification/__init__.py +183 -0
  288. fusion_bench/tasks/clip_classification/cifar10.py +33 -0
  289. fusion_bench/tasks/clip_classification/cifar100.py +146 -0
  290. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
  291. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  292. fusion_bench/tasks/clip_classification/dtd.py +60 -0
  293. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  294. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  295. fusion_bench/tasks/clip_classification/eurosat.py +18 -0
  296. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  297. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  298. fusion_bench/tasks/clip_classification/flower102.py +106 -0
  299. fusion_bench/tasks/clip_classification/food101.py +105 -0
  300. fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
  301. fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
  302. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  303. fusion_bench/tasks/clip_classification/mnist.py +5 -0
  304. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  305. fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
  306. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  307. fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
  308. fusion_bench/tasks/clip_classification/resisc45.py +68 -0
  309. fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
  310. fusion_bench/tasks/clip_classification/stl10.py +17 -0
  311. fusion_bench/tasks/clip_classification/sun397.py +404 -0
  312. fusion_bench/tasks/clip_classification/svhn.py +5 -0
  313. fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
  314. fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
  315. fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
  316. fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
  317. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
  318. fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
  319. fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
  320. fusion_bench/utils/__init__.py +14 -0
  321. fusion_bench/utils/auto.py +31 -0
  322. fusion_bench/utils/cache_utils.py +58 -0
  323. fusion_bench/utils/data.py +165 -0
  324. fusion_bench/utils/devices.py +231 -0
  325. fusion_bench/utils/dict.py +43 -0
  326. fusion_bench/utils/dtype.py +146 -0
  327. fusion_bench/utils/expr.py +90 -0
  328. fusion_bench/utils/fabric.py +17 -0
  329. fusion_bench/utils/functools.py +37 -0
  330. fusion_bench/utils/hydra_utils.py +28 -0
  331. fusion_bench/utils/instantiate.py +450 -0
  332. fusion_bench/utils/json.py +93 -0
  333. fusion_bench/utils/lazy_imports.py +74 -0
  334. fusion_bench/utils/misc.py +18 -0
  335. fusion_bench/utils/packages.py +84 -0
  336. fusion_bench/utils/parameters.py +323 -0
  337. fusion_bench/utils/path.py +22 -0
  338. fusion_bench/utils/plot/__init__.py +0 -0
  339. fusion_bench/utils/plot/color_data.py +1726 -0
  340. fusion_bench/utils/plot/token.py +52 -0
  341. fusion_bench/utils/plot/token_notebook.py +127 -0
  342. fusion_bench/utils/pylogger.py +55 -0
  343. fusion_bench/utils/rich_utils.py +201 -0
  344. fusion_bench/utils/set.py +8 -0
  345. fusion_bench/utils/state_dict_arithmetic.py +297 -0
  346. fusion_bench/utils/strenum/__init__.py +326 -0
  347. fusion_bench/utils/strenum/_name_mangler.py +127 -0
  348. fusion_bench/utils/strenum/_version.py +556 -0
  349. fusion_bench/utils/tensorboard.py +51 -0
  350. fusion_bench/utils/timer.py +49 -0
  351. fusion_bench/utils/type.py +34 -0
  352. fusion_bench-0.2.9.dist-info/LICENSE +21 -0
  353. fusion_bench-0.2.9.dist-info/METADATA +258 -0
  354. fusion_bench-0.2.9.dist-info/RECORD +727 -0
  355. fusion_bench-0.2.9.dist-info/WHEEL +5 -0
  356. fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
  357. fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
  358. fusion_bench_config/README.md +12 -0
  359. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
  360. fusion_bench_config/dataset/image_classification/README.md +6 -0
  361. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  362. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  363. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
  364. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
  365. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  366. fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
  367. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  368. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  369. fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
  370. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  371. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  372. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  373. fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
  374. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  375. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  376. fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
  377. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  378. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  379. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  380. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  381. fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
  382. fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
  383. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  384. fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
  385. fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
  386. fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
  387. fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
  388. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  389. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  390. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
  391. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
  392. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  393. fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
  394. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  395. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  396. fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
  397. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  398. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  399. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  400. fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
  401. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  402. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  403. fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
  404. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  405. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  406. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  407. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  408. fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
  409. fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
  410. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  411. fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
  412. fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
  413. fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
  414. fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
  415. fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
  416. fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
  417. fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
  418. fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
  419. fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
  420. fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
  421. fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
  422. fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
  423. fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
  424. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  425. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  426. fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
  427. fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
  428. fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
  429. fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
  430. fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
  431. fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
  432. fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
  433. fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
  434. fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
  435. fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
  436. fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
  437. fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
  438. fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
  439. fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
  440. fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
  441. fusion_bench_config/fabric/auto.yaml +16 -0
  442. fusion_bench_config/fabric/llama_ddp.yaml +18 -0
  443. fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
  444. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  445. fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
  446. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
  447. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  448. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  449. fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
  450. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  451. fusion_bench_config/fabric_model_fusion.yaml +20 -0
  452. fusion_bench_config/hydra/default.yaml +8 -0
  453. fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
  454. fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
  455. fusion_bench_config/llama_full_finetune.yaml +19 -0
  456. fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
  457. fusion_bench_config/llama_model_fusion.yaml +17 -0
  458. fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
  459. fusion_bench_config/method/adamerging/clip.yaml +23 -0
  460. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
  461. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
  462. fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
  463. fusion_bench_config/method/adamerging.yaml +23 -0
  464. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
  465. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
  466. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  467. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  468. fusion_bench_config/method/clip_finetune.yaml +26 -0
  469. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
  470. fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
  471. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
  472. fusion_bench_config/method/dare/simple_average.yaml +5 -0
  473. fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
  474. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  475. fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
  476. fusion_bench_config/method/depth_upscaling.yaml +5 -0
  477. fusion_bench_config/method/dummy.yaml +1 -0
  478. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
  479. fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
  480. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
  481. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
  482. fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
  483. fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
  484. fusion_bench_config/method/linear/expo.yaml +8 -0
  485. fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
  486. fusion_bench_config/method/linear/llama_expo.yaml +19 -0
  487. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
  488. fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
  489. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
  490. fusion_bench_config/method/linear/weighted_average.yaml +6 -0
  491. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
  492. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  493. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
  494. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
  495. fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
  496. fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
  497. fusion_bench_config/method/model_recombination.yaml +4 -0
  498. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  499. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  500. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  501. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  502. fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
  503. fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
  504. fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
  505. fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
  506. fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
  507. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  508. fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
  509. fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
  510. fusion_bench_config/method/regmean/regmean.yaml +4 -0
  511. fusion_bench_config/method/simple_average.yaml +1 -0
  512. fusion_bench_config/method/slerp/slerp.yaml +6 -0
  513. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
  514. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
  515. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
  516. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
  517. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
  518. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
  519. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  520. fusion_bench_config/method/task_arithmetic.yaml +2 -0
  521. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  522. fusion_bench_config/method/ties_merging.yaml +8 -0
  523. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
  524. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
  525. fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
  526. fusion_bench_config/model/clip-vit/README.md +38 -0
  527. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
  528. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  529. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  530. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  531. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  532. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
  533. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
  534. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  535. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
  536. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  537. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  538. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  539. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
  540. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  541. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
  542. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  543. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  544. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  545. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  546. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
  547. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
  548. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  549. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
  550. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
  551. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
  552. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  553. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  554. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  555. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  556. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
  557. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
  558. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  559. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
  560. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  561. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  562. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  563. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
  564. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  565. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
  566. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  567. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  568. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  569. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  570. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
  571. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
  572. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  573. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
  574. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
  575. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
  576. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  577. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  578. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  579. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  580. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
  581. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
  582. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  583. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
  584. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  585. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  586. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  587. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
  588. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  589. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
  590. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  591. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  592. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  593. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  594. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
  595. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
  596. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  597. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
  598. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
  599. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  600. fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
  601. fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
  602. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
  603. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
  604. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
  605. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
  606. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
  607. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
  608. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
  609. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
  610. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
  611. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
  612. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
  613. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
  614. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
  615. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
  616. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
  617. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
  618. fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
  619. fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
  620. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
  621. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
  622. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
  623. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
  624. fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
  625. fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
  626. fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
  627. fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
  628. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
  629. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
  630. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
  631. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  632. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  633. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  634. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  635. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  636. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
  637. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
  638. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
  639. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
  640. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
  641. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  642. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  643. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  644. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  645. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
  646. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
  647. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
  648. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
  649. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
  650. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
  651. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
  652. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  653. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
  654. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  655. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
  656. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
  657. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  658. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  659. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  660. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  661. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
  662. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  663. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  664. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
  665. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  666. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  667. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
  668. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
  669. fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
  670. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
  671. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
  672. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
  673. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
  674. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
  675. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  676. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  677. fusion_bench_config/modelpool/automodelpool.yaml +12 -0
  678. fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
  679. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
  680. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
  681. fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
  682. fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
  683. fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
  684. fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
  685. fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
  686. fusion_bench_config/nyuv2_config.yaml +17 -0
  687. fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
  688. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
  689. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  690. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
  691. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
  692. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
  693. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
  694. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
  695. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  696. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  697. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  698. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  699. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  700. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  701. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  702. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  703. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  704. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  705. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  706. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  707. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  708. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  709. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  710. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  711. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  712. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  713. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  714. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  715. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  716. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  717. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  718. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  719. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
  720. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
  721. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  722. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
  723. fusion_bench_config/taskpool/dummy.yaml +2 -0
  724. fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
  725. fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
  726. fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
  727. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
@@ -0,0 +1,1019 @@
1
+ """
2
+ This module contains classes for representing ARC tasks, examples, and grids in different formats.
3
+ """
4
+
5
+ import re
6
+ from abc import ABC, abstractmethod
7
+ from io import BytesIO
8
+ from types import WrapperDescriptorType
9
+ from typing import List, Optional, Text, Tuple, Union
10
+
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ from matplotlib.colors import ListedColormap, Normalize
14
+ from scipy.ndimage import generate_binary_structure, label
15
+
16
+ from .arc import Example, Grid, Task
17
+ from .np_cache import np_lru_cache
18
+
19
# =============== CONSTANTS ===============
# Default separators/headers for rendering ARC tasks as plain text.
# None of these are consumed in this chunk — presumably the task/example
# representers elsewhere in the package use them; verify against callers.
COLUMN_SEP = " "  # between cells within a grid row (matches DelimitedGridRepresenter's default)
ROW_SEP = "\n"  # between grid rows
IO_SEP = "\n\n"  # between an example's input and output sections
EXAMPLE_SEP = "\n\n"  # between consecutive examples
TRAIN_HEADER = "==TRAIN==\n"  # marks the start of the training-examples section
TRAIN_TEST_SEP = "\n\n"  # between the train and test sections
TEST_HEADER = "==TEST==\n"  # marks the start of the test-examples section
+
28
+ # =============== UTILS ===============
29
+
30
+
31
def parse_numpy_from_str(array_str: str) -> Optional[np.ndarray]:
    """
    Parses a string representation of a 2D array into a NumPy ndarray.

    Accepts both bracketed layouts (numpy's own ``str()`` output, e.g.
    ``"[[1 2]\\n [3 4]]"``) and plain whitespace-separated layouts
    (``"1 2\\n3 4"``); rows are separated by newlines.

    Parameters:
    - array_str (str): A string representation of a 2D array, where rows are separated by newlines.

    Returns:
    - Optional[np.ndarray]: An int8 array on success, or ``None`` when the
      string cannot be parsed (the error and the offending string are
      printed for debugging).
    """
    try:
        # Strip brackets so only numbers and whitespace remain
        clean_str = array_str.replace("[", "").replace("]", "")

        # Every whitespace-separated token must be an integer
        elements = list(map(int, clean_str.split()))

        # One row per newline; column count inferred from the element count
        rows = array_str.count("\n") + 1
        cols = len(elements) // rows

        # reshape raises when rows * cols != len(elements) (ragged input),
        # which lands in the except branch below
        array = np.array(elements).reshape((rows, cols)).astype(np.int8)

        return array
    except Exception as e:
        # Best-effort parser: report the failure and signal it with None
        # instead of raising, so callers can handle malformed model output
        print(e)
        print(array_str)
        return None
65
+
66
+
67
+ # =============== INTERFACE ===============
68
+
69
+
70
class GridRepresenter(ABC):
    """Abstract interface for converting grids to and from text."""

    @abstractmethod
    def encode(self, grid: Grid) -> str:
        """Serialize *grid* into its textual representation."""
        pass

    @abstractmethod
    def decode(self, encoded_str: str, **kwargs) -> Grid:
        """Reconstruct a grid from its textual representation."""
        pass

    def display(self, encoded_str: str):
        # Decode and print the grid for quick visual inspection.
        decoded = self.decode(encoded_str)
        print(decoded)
81
+
82
+
83
class ExampleRepresenter(ABC):
    """Abstract interface for representing one input/output example as text."""

    # Grid-level representer that concrete subclasses delegate to.
    grid_representer: GridRepresenter

    @abstractmethod
    def encode(self, example: Example, **kwargs) -> Union[str, Tuple[str, str]]:
        pass

    @abstractmethod
    def decode(self, encoded: Tuple[str, str], **kwargs) -> Example:
        pass

    def display(self, encoded: Union[str, Tuple[str, str]]):
        # A plain string prints as-is; a pair prints each part on its own line.
        text = encoded if isinstance(encoded, str) else "\n".join(encoded)
        print(text)
99
+
100
+
101
class TaskRepresenter(ABC):
    """Abstract interface for representing a whole task as text."""

    # Example-level representer that concrete subclasses delegate to.
    example_representer: ExampleRepresenter

    @abstractmethod
    def encode(self, task: Task, **kwargs) -> Union[Tuple[str, str, str], str]:
        pass

    @abstractmethod
    def decode(self, encoded: Tuple[str, str], **kwargs) -> Task:
        pass

    def display(self, encoded: Union[str, Tuple[str, str]]):
        # A plain string prints as-is; a tuple prints each part on its own line.
        text = encoded if isinstance(encoded, str) else "\n".join(encoded)
        print(text)
117
+
118
+
119
+ # =============== GRID REPRESENTATION ===============
120
class DelimitedGridRepresenter(GridRepresenter):
    """
    Render a grid as delimiter-separated integers.

    Cells within a row are joined by ``column_sep`` and rows by ``row_sep``.
    """

    def __init__(self, column_sep: str = " ", row_sep: str = "\n"):
        self.column_sep: str = column_sep
        self.row_sep: str = row_sep

    def encode(self, grid: Grid) -> str:
        """
        Serialize *grid* (a 2D integer array) into delimited text.

        Built with str.join so multi-character separators work correctly;
        the previous implementation trimmed trailing separators by slicing
        off a fixed number of characters, which assumed ``column_sep`` was
        exactly one character long.
        """
        rows = []
        for i in range(grid.shape[0]):
            cells = (str(grid[i][j]) for j in range(grid.shape[1]))
            rows.append(self.column_sep.join(cells))
        return self.row_sep.join(rows)

    def decode(self, encoded_str: str) -> Grid:
        """Parse delimited text back into a 2D integer numpy array."""
        rows = encoded_str.strip().split(self.row_sep)
        grid = [list(map(int, row.split(self.column_sep))) for row in rows]
        return np.array(grid)

    def __repr__(self) -> str:
        return f"DelimitedGridRepresenter(column_sep={self.column_sep!r}, row_sep={self.row_sep!r})"

    def __str__(self) -> str:
        return repr(self)
143
+
144
+
145
class PythonListGridRepresenter(GridRepresenter):
    """Represent a grid using numpy's default textual form (rows separated by newlines)."""

    def encode(self, grid: Grid) -> str:
        # numpy arrays stringify to a bracketed, space-separated layout.
        return str(grid)

    def decode(self, encoded_str: str) -> Grid:
        # Delegate to the module-level parser, which tolerates brackets
        # and returns None on malformed input.
        return parse_numpy_from_str(encoded_str)

    def __repr__(self) -> str:
        return "PythonListGridRepresenter()"

    def __str__(self) -> str:
        return repr(self)
157
+
158
+
159
+ # Used in BARC
160
# Used in BARC
class WordGridRepresenter(GridRepresenter):
    """Grid representer that spells each cell value as a color name."""

    # Digit -> color-name mapping shared by all instances.
    color_map = {
        0: "Black",
        1: "Blue",
        2: "Red",
        3: "Green",
        4: "Yellow",
        5: "Gray",
        6: "Pink",
        7: "Orange",
        8: "Purple",
        9: "Brown",
    }

    def __init__(self):
        # Reverse lookup (color name -> digit), built once per instance.
        self.inv_map = {name: digit for digit, name in self.color_map.items()}

    def encode(self, grid: Grid) -> str:
        lines = []
        for r in range(grid.shape[0]):
            names = [self.color_map[grid[r][c]] for c in range(grid.shape[1])]
            lines.append(" ".join(names))
        return "\n".join(lines)

    def decode(self, encoded_str: str) -> Grid:
        parsed = [
            [self.inv_map[word] for word in line.split()]
            for line in encoded_str.strip().split("\n")
        ]
        return np.array(parsed)

    def __str__(self) -> str:
        return "WordGridRepresenter()"

    def __repr__(self) -> str:
        return "WordGridRepresenter()"
195
+
196
+
197
+ # This is adapted from Greenblat 2024
198
# This is adapted from Greenblat 2024
class ConnectedComponentRepresenter(GridRepresenter):
    """
    Encode a grid as per-color connected components in spreadsheet notation.

    Each color 0-9 is listed with its contiguous shapes; a cell is written
    as a spreadsheet coordinate (column letter + 1-based row number, e.g.
    "B3"), shapes are separated by " | ", and long horizontal runs are
    abbreviated as "start ... end".
    """

    # Class-level defaults; __init__ overrides a subset per instance.
    normalized: bool = True  # shift each shape so its top-left corner becomes A1
    max_token_per_color: Optional[int] = None  # replace a color's listing beyond this many tokens
    disable_absolute: bool = False  # when True, omit the "[Abs. ...]" anchor prefix
    dotsafter: int = 4  # horizontal runs longer than this are abbreviated with "..."
    connected_component: int = 4  # default connectivity (4 or 8) used by encode()
    sort_by_count: bool = False  # order shapes (and colors) by size instead of index
    # Column labels A..Z then AA..AF: supports columns 0..31.
    spreadsheet_col_labels: List[str] = [
        "A",
        "B",
        "C",
        "D",
        "E",
        "F",
        "G",
        "H",
        "I",
        "J",
        "K",
        "L",
        "M",
        "N",
        "O",
        "P",
        "Q",
        "R",
        "S",
        "T",
        "U",
        "V",
        "W",
        "X",
        "Y",
        "Z",
        "AA",
        "AB",
        "AC",
        "AD",
        "AE",
        "AF",
    ]

    def __init__(
        self,
        normalized: bool = True,
        max_token_per_color: Optional[int] = None,
        disable_absolute: bool = False,
        sort_by_count: bool = False,
    ):
        self.normalized = normalized
        self.max_token_per_color = max_token_per_color
        self.disable_absolute = disable_absolute
        self.sort_by_count = sort_by_count

    def to_spreadsheet(self, i: int, j: int) -> str:
        """Convert (row i, column j) to spreadsheet notation, e.g. (2, 1) -> "B3"."""
        try:
            out = f"{self.spreadsheet_col_labels[j]}{i+1}"
        except IndexError:
            # Column index beyond AA..AF range: report the coordinates
            # before re-raising so the failing grid can be identified.
            print(i, j)
            raise
        return out

    def to_spreadsheet_with_dots(self, rows_cols: List[Tuple[int, int]]) -> str:
        """
        Render a list of (row, col) cells as space-separated spreadsheet
        coordinates, abbreviating horizontal runs longer than ``dotsafter``
        as "start ... end".  The result starts with a leading space.
        """
        # Sort row-major so consecutive same-row cells are adjacent.
        row_cols_v = np.array(sorted(rows_cols, key=lambda x: (x[0], x[1])))
        running_str = ""
        idx = 0
        while idx < len(row_cols_v):
            r, c = row_cols_v[idx]
            # Count how many cells continue the run (r, c), (r, c+1), ...
            count_in_a_row = 0
            for checking_idx, (n_r, n_c) in enumerate(row_cols_v[idx:]):
                if n_r == r and n_c == c + checking_idx:
                    count_in_a_row += 1
                else:
                    break
            if count_in_a_row > self.dotsafter:
                # Abbreviate the run; sanity-check that the run really ends
                # where the count says it does.
                start = self.to_spreadsheet(r, c)
                c_end = c + count_in_a_row - 1
                assert np.array_equal(
                    row_cols_v[idx + count_in_a_row - 1], (r, c_end)
                ), (
                    row_cols_v[idx + count_in_a_row - 1],
                    (r, c_end),
                )
                end = self.to_spreadsheet(r, c_end)
                running_str += f" {start} ... {end}"
                idx += count_in_a_row
            else:
                # Short run: emit cells one at a time.
                running_str += " " + self.to_spreadsheet(r, c)
                idx += 1
        return running_str

    def find_contiguous_shapes(self, grid: Grid, color: int) -> List[np.ndarray]:
        """Return the 4-connected components of ``color`` as arrays of (row, col) cells."""
        labeled_array, num_features = label(grid == color)
        shapes = []
        for i in range(1, num_features + 1):
            shapes.append(np.argwhere(labeled_array == i))
        if self.sort_by_count:
            # Largest shapes first.
            shapes = sorted(shapes, key=lambda x: len(x), reverse=True)
        return shapes

    def find_contiguous_shapes_moore(self, grid: Grid, color: int) -> List[np.ndarray]:
        """Like find_contiguous_shapes, but with 8-connectivity (Moore neighborhood)."""
        s = generate_binary_structure(2, 2)
        labeled_array, num_features = label(grid == color, structure=s)
        shapes = []
        for i in range(1, num_features + 1):
            shapes.append(np.argwhere(labeled_array == i))
        if self.sort_by_count:
            shapes = sorted(shapes, key=lambda x: len(x), reverse=True)
        return shapes

    def encode(self, grid: Grid, connected_component: Optional[int] = None) -> str:
        """
        Encode ``grid`` as "[[(4CC)\\n<color>: <shapes>\\n...]]".

        connected_component: 4 or 8; defaults to the instance/class setting.
        Colors with no cells are skipped.
        """
        out = "[["
        if connected_component is None:
            connected_component = self.connected_component
        # Tag the output with the connectivity used.
        if connected_component == 4:
            out += "(4CC)\n"
        elif connected_component == 8:
            out += "(8CC)\n"

        color_shapes = []

        # NOTE(review): a connectivity other than 4/8 would leave
        # contiguous_shapes unbound here (NameError).
        for color in range(10):
            if connected_component == 4:
                contiguous_shapes = self.find_contiguous_shapes(grid, color)
            elif connected_component == 8:
                contiguous_shapes = self.find_contiguous_shapes_moore(grid, color)
            color_shapes.append(contiguous_shapes)

        if self.sort_by_count:
            # Colors ordered by total cell count, ascending.
            sorted_index = np.argsort(
                [
                    sum([len(shape) for shape in contiguous_shapes])
                    for contiguous_shapes in color_shapes
                ]
            )
        else:
            sorted_index = np.arange(10)

        for color in sorted_index:
            contiguous_shapes = color_shapes[color]
            if len(contiguous_shapes) == 0:
                continue
            shape_strings = []
            for shape in contiguous_shapes:
                if self.normalized:
                    # Shift the shape so its bounding-box corner is (0, 0),
                    # and record that corner as the absolute anchor.
                    min_i = min(i for i, j in shape)
                    min_j = min(j for i, j in shape)
                    normalized = [
                        (i - min_i, j - min_j)
                        for i, j in sorted(shape, key=lambda x: (int(x[0]), int(x[1])))
                    ]
                    basic_shape_str = self.to_spreadsheet_with_dots(normalized)
                    if not self.disable_absolute:
                        shape_str = (
                            "[Abs. "
                            + self.to_spreadsheet(min_i, min_j)
                            + "]"
                            + basic_shape_str
                        )
                    else:
                        shape_str = basic_shape_str
                else:
                    # Absolute coordinates, row-major order.
                    shape = [
                        (i, j)
                        for i, j in sorted(shape, key=lambda x: (int(x[0]), int(x[1])))
                    ]
                    shape_str = self.to_spreadsheet_with_dots(shape)
                shape_strings.append(shape_str)

            full_str = " | ".join(shape_strings)
            # Token budget is counted as space-separated pieces of the line.
            if self.max_token_per_color and self.max_token_per_color < len(
                full_str.split(" ")
            ):
                color_str = " [OMITTED DUE TO EXCESSIVE LENGTH]"
            else:
                color_str = full_str

            out += f"{color}: {color_str}\n"

        return out + "]]"

    def parse_position(self, pos):
        """Parse a spreadsheet coordinate like "B3" into a 0-based (row, column)."""
        # find the letter part
        letter = re.findall(r"[A-Z]+", pos)[0]
        # find the number part
        number = re.findall(r"[0-9]+", pos)[0]
        row = int(number) - 1
        column = self.spreadsheet_col_labels.index(letter)
        return row, column

    def decode(self, encoded_str: str) -> Grid:
        """
        Inverse of encode(): rebuild a grid from the component listing.

        The grid is painted onto a 30x30 canvas of -1, then cropped to
        rows/columns that received any color; leftover -1 cells become 0.

        NOTE(review): the first token of each component is always treated
        as the "[Abs. ...]" anchor, so this appears to assume output
        produced with normalized=True and disable_absolute=False — confirm.
        """
        encoded_str = encoded_str.replace("[[", "").replace("]]", "")
        encoded_str = encoded_str.replace("(8CC)\n", "").replace("(4CC)\n", "")
        max_row, max_col = (
            30,
            30,
        )  # Adjusted for the given example, can be adjusted if needed
        grid = np.full(
            (max_row, max_col), -1
        )  # Initialize with -1 to indicate empty cells

        # Process each color and its components
        # Each line looks like "3: [Abs. B2] A1 A2 ... A6 | [Abs. D5] A1".
        for line in encoded_str.strip().split("\n"):
            color, components = line.split(": ")
            color = int(color)
            components = components.split(" | ")
            for component in components:
                component = component.replace("[Abs. ", "").replace("]", "")
                # Collapse " ... " so a run parses as a single token.
                component = component.replace(" ... ", "...")
                abs_pos, *rel_positions = component.split()
                abs_pos = self.parse_position(abs_pos)
                abs_row, abs_col = abs_pos
                for rel_position in rel_positions:
                    if "..." in rel_position:
                        # Abbreviated run: fill the inclusive span.
                        start_pos, end_pos = rel_position.split("...")
                        start_row, start_col = self.parse_position(start_pos)
                        end_row, end_col = self.parse_position(end_pos)
                        for row in range(start_row, end_row + 1):
                            grid[abs_row + row][
                                abs_col + start_col : abs_col + end_col + 1
                            ] = color
                    else:
                        # Single cell relative to the anchor.
                        rel_row, rel_col = self.parse_position(rel_position)
                        grid[abs_row + rel_row][abs_col + rel_col] = color

        # crop the grid from -1s
        rows = np.all(grid == -1, axis=1)
        cols = np.all(grid == -1, axis=0)
        grid = grid[~rows]
        grid = grid[:, ~cols]
        # replace remaining -1s with 0
        grid = np.where(grid == -1, 0, grid)
        return grid
431
+
432
+
433
class CompositeRepresenter(GridRepresenter):
    """Concatenates the encodings of several grid representers.

    Each selected representer encodes the grid independently; the results
    are joined with newlines.
    """

    connected_component: int = 4

    def __init__(self, representers: List[GridRepresenter]):
        self.representers = representers

    def encode(
        self,
        grid: Grid,
        actives: Optional[List[int]] = None,
        connected_component: Optional[int] = None,
    ) -> str:
        """Encode ``grid`` with every representer (or only those in ``actives``)."""
        if actives is None:
            selected = self.representers
        else:
            selected = [self.representers[i] for i in actives]

        pieces = []
        for rep in selected:
            # Only connected-component representers accept the connectivity kwarg.
            if isinstance(rep, ConnectedComponentRepresenter):
                pieces.append(rep.encode(grid, connected_component=connected_component))
            else:
                pieces.append(rep.encode(grid))
        return "\n".join(pieces).strip()

    def decode(self, encoded_str: str, actives: Optional[List[int]] = None) -> Grid:
        # Decoding logic for CompositeRepresenter is complex and depends on the specific encoding format.
        # This is a placeholder for the actual implementation.
        raise NotImplementedError(
            "Decoding for CompositeRepresenter is not implemented."
        )
466
+
467
+
468
class ImageGridRepresenter(GridRepresenter):
    """Renders a grid as a matplotlib image, returned as a base64 PNG string."""

    # One hex color per grid value 0-9.
    cmap: ListedColormap = ListedColormap(
        [
            "#000000",
            "#0074D9",
            "#FF4136",
            "#2ECC40",
            "#FFDC00",
            "#AAAAAA",
            "#F012BE",
            "#FF851B",
            "#7FDBFF",
            "#870C25",
        ]
    )

    # Human-readable names matching ``cmap`` positions.
    cnames: List[str] = [
        "black",
        "blue",
        "red",
        "green",
        "yellow",
        "gray",
        "magenta",
        "orange",
        "lightblue",
        "brown",
    ]

    @np_lru_cache(maxsize=8096)
    def encode(self, grid: Grid) -> str:
        """Return a base64-encoded PNG rendering of ``grid``."""
        n_rows, n_cols = len(grid), len(grid[0])
        # Figure size tracks the grid so each cell keeps a fixed on-screen area.
        fig, ax = plt.subplots(figsize=(n_cols / 2, n_rows / 2))
        ax.imshow(grid, cmap=self.cmap, norm=Normalize(vmin=0, vmax=9))
        ax.grid(True, which="both", color="lightgrey", linewidth=0.5)
        # Half-offset ticks draw the grid lines on cell borders; labels hidden.
        ax.set_yticks([r - 0.5 for r in range(n_rows + 1)])
        ax.set_xticks([c - 0.5 for c in range(n_cols + 1)])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        plt.tight_layout()

        # Render into memory instead of a file.
        buffer = BytesIO()
        plt.savefig(buffer, format="png", bbox_inches="tight")
        plt.close(fig)
        buffer.seek(0)
        return base64.b64encode(buffer.read()).decode("utf-8")

    def display(self, encoded_str: str):
        """Decode the base64 PNG and show it with matplotlib."""
        raw = base64.b64decode(encoded_str)
        plt.imshow(plt.imread(BytesIO(raw)))
        plt.axis("off")
        plt.show()

    def decode(self, encoded_str: str, **kwargs) -> Grid:
        raise NotImplementedError(
            "Decoding for ImageGridRepresenter is not implemented."
        )
529
+
530
+
531
class ConnectedComponentRepresenterV2(GridRepresenter):
    """Describes a grid as per-color connected components with bounding-box grids.

    Each non-background component is reported as
    ``Shape(color=..., pos=(row,col), grid=[[...]])`` where ``grid`` is the
    component's bounding-box window of the original grid.
    """

    def __init__(self, sort_by_count: bool = False, connected_component: int = 4):
        self.sort_by_count = sort_by_count
        self.connected_component = connected_component

    def find_contiguous_shapes(
        self, grid: Grid, color: int, include_diagonals=False
    ) -> List[np.ndarray]:
        """Return one cell-index array per connected region of ``color``."""
        structure = generate_binary_structure(2, 2) if include_diagonals else None
        labeled, count = label(grid == color, structure=structure)
        regions = [np.argwhere(labeled == idx) for idx in range(1, count + 1)]
        if self.sort_by_count:
            # Largest components first.
            regions.sort(key=len, reverse=True)
        return regions

    def encode(self, grid: Grid, connected_component: Optional[int] = None) -> str:
        """Encode ``grid``; ``connected_component`` (4 or 8) overrides the default."""
        if connected_component is None:
            connected_component = self.connected_component
        diagonal = connected_component == 8

        color_shapes = [
            self.find_contiguous_shapes(grid, color, include_diagonals=diagonal)
            for color in range(10)
        ]

        if self.sort_by_count:
            # Colors ordered by total covered area (ascending).
            order = np.argsort(
                [sum(len(s) for s in shapes) for shapes in color_shapes]
            )
        else:
            order = np.arange(10)

        # specify the shape
        output = f"(height={grid.shape[0]}, width={grid.shape[1]})\n"
        for k, color in enumerate(order):
            # Background (color 0) is implicit and never listed.
            if color == 0:
                continue
            shapes = color_shapes[color]
            if len(shapes) == 0:
                continue

            descriptions = []
            for shape in shapes:
                min_i, min_j = np.min(shape, axis=0)
                max_i, max_j = np.max(shape, axis=0)
                window = grid[min_i : max_i + 1, min_j : max_j + 1]
                window_str = PythonListGridRepresenter().encode(window)
                descriptions.append(
                    f"Shape(color={color}, pos=({min_i},{min_j}), grid={window_str})"
                )

            output += "- " + "\n- ".join(descriptions)
            if k != len(order) - 1:
                output += "\n"
        return output

    def decode(self, encoded_str: str) -> Grid:
        # Not implemented; callers receive None.
        return None

    def __repr__(self) -> str:
        return f"ConnectedComponentRepresenterV2(sort_by_count={self.sort_by_count}, connected_component={self.connected_component})"
604
+
605
+
606
+ # =============== Example REPRESENTATION ===============
607
+
608
+
609
class TextExampleRepresenter(ExampleRepresenter):
    """Formats a single input/output example as a (prompt, completion) pair.

    The prompt is ``input_header + encoded_input + io_sep + output_header``;
    the completion is ``encoded_output + output_footer``.
    """

    def __init__(
        self,
        io_sep: str = " -> ",
        input_header: str = "",
        output_header: str = "",
        output_footer="",
        grid_representer: GridRepresenter = PythonListGridRepresenter(),
    ):
        self.io_sep = io_sep
        self.input_header = input_header
        self.output_header = output_header
        self.output_footer = output_footer
        self.grid_representer = grid_representer

    def encode(self, example: Example, **kwargs) -> Tuple[str, str]:
        """Return (prompt, completion) strings for ``example``."""
        input_str = self.grid_representer.encode(example.input, **kwargs)
        # Falsy headers (empty string / None) collapse to "".
        input_header = self.input_header or ""
        output_str = self.grid_representer.encode(example.output, **kwargs)
        output_header = self.output_header or ""
        return (
            f"{input_header}{input_str}{self.io_sep}{output_header}",
            f"{output_str}{self.output_footer}",
        )

    def decode(self, encoded: Tuple[str, str], **kwargs) -> Example:
        """Inverse of :meth:`encode`: strip markers and decode both grids.

        Note: ``str.replace("", "")`` on empty headers is a no-op, so empty
        markers are safe.
        """
        input_str, output_str = encoded
        input_str = input_str.replace(self.input_header, "").replace(
            self.output_header, ""
        )
        if self.io_sep != "\n":
            input_str = input_str.replace(self.io_sep, "")
        # BUG FIX: the original called ``input_str.strip()`` and discarded the
        # result; assign it so surrounding whitespace is actually removed.
        input_str = input_str.strip()

        output_str = (
            output_str.replace(self.input_header, "")
            .replace(self.output_header, "")
            .strip()
        )

        input_grid = self.grid_representer.decode(input_str, **kwargs)
        output_grid = self.grid_representer.decode(output_str, **kwargs)
        return Example(input=input_grid, output=output_grid)

    def __repr__(self) -> str:
        return f"TextExampleRepresenter(io_sep={self.io_sep!r}, input_header={self.input_header!r}, output_header={self.output_header!r}, output_footer={self.output_footer!r}, grid_representer={repr(self.grid_representer)})"

    def __str__(self) -> str:
        return repr(self)
666
+
667
+
668
class ImageExampleRepresenter(ExampleRepresenter):
    """Renders an example's input and output grids side by side as a base64 PNG."""

    def __init__(self, grid_representer: ImageGridRepresenter = ImageGridRepresenter()):
        self.grid_representer = grid_representer

    def _draw_grid(self, ax, grid, title, norm):
        # Draw a single grid with cell-border lines and hidden tick labels.
        ax.imshow(grid, cmap=self.grid_representer.cmap, norm=norm)
        ax.set_title(title)
        ax.grid(True, which="both", color="lightgrey", linewidth=0.5)
        ax.set_yticks([y - 0.5 for y in range(1 + len(grid))])
        ax.set_xticks([x - 0.5 for x in range(1 + len(grid[0]))])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    def encode(self, example: Example, **kwargs) -> str:
        """Return a base64-encoded PNG with Input and Output panels."""
        input_grid = example.input
        output_grid = example.output

        # Panel dimensions follow the larger of the two grids.
        height = max(len(input_grid), len(output_grid))
        width = max(len(input_grid[0]), len(output_grid[0]))
        # NOTE(review): matplotlib's figsize is (width, height); the original
        # passed (height, width) — kept as-is to preserve identical output.
        fig, axes = plt.subplots(1, 2, figsize=(height, width))

        norm = Normalize(vmin=0, vmax=9)
        self._draw_grid(axes[0], input_grid, "Input", norm)
        self._draw_grid(axes[1], output_grid, "Output", norm)

        plt.tight_layout()

        # Render into memory and base64-encode.
        buf = BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight")
        plt.close(fig)
        buf.seek(0)
        return base64.b64encode(buf.read()).decode("utf-8")

    def display(self, encoded_str: str):
        """Decode the base64 PNG and show it with matplotlib."""
        raw = base64.b64decode(encoded_str)
        plt.imshow(plt.imread(BytesIO(raw)))
        plt.axis("off")
        plt.show()

    def decode(self, encoded_str: str, **kwargs) -> Example:
        raise NotImplementedError(
            "Decoding for ImageExampleRepresenter is not implemented."
        )
728
+
729
+
730
class DiffExampleRepresenter(ExampleRepresenter):
    """Encodes an example as its input plus an output-minus-input diff grid.

    The completion is ``diff_sep + diff`` optionally followed by
    ``diff_output_sep + output`` when ``use_output`` is True.
    """

    def __init__(
        self,
        grid_representer: GridRepresenter = PythonListGridRepresenter(),
        io_sep: str = "\n\n",
        input_header: str = "INPUT:\n",
        output_header: str = "OUTPUT:\n",
        output_footer: str = "",
        use_output: bool = True,
        diff_sep: str = " + ",
        diff_output_sep: str = ",",
    ):
        self.io_sep = io_sep
        self.input_header = input_header
        self.output_header = output_header
        self.diff_output_sep = diff_output_sep
        self.diff_sep = diff_sep
        self.use_output = use_output
        self.output_footer = output_footer
        self.grid_representer = grid_representer

    def encode(self, example: Example, **kwargs) -> Tuple[str, str]:
        """Return (prompt, completion), or (None, None) when shapes differ."""
        # A cell-wise diff is only defined for same-shaped grids.
        if np.shape(example.input) != np.shape(example.output):
            return None, None

        diff_str = self.grid_representer.encode(example.output - example.input)
        # Collapse numpy-style alignment padding in the diff text.
        diff_str = diff_str.replace("  ", " ").replace("[[ ", "[[").replace("[ ", "[")
        input_str = self.grid_representer.encode(example.input, **kwargs)
        prompt = f"{self.input_header}{input_str}{self.io_sep}{self.output_header}"

        if self.use_output:
            output_str = self.grid_representer.encode(example.output, **kwargs)
            completion = f"{self.diff_sep}{diff_str}{self.diff_output_sep}{output_str}{self.output_footer}"
        else:
            completion = f"{self.diff_sep}{diff_str}{self.output_footer}"
        return prompt, completion

    def decode(self, encoded: Tuple[str, str], **kwargs) -> Example:
        """Inverse of :meth:`encode`; reconstructs the output from the diff
        when ``use_output`` is False."""
        prompt, completion = encoded
        prompt = (
            prompt.replace(self.input_header, "")
            .replace(self.output_header, "")
            .replace(self.io_sep, "")
            .strip()
        )
        completion = (
            completion.replace(self.diff_sep.strip(), "")
            .replace(self.output_footer, "")
            .strip()
        )
        input_grid = self.grid_representer.decode(prompt, **kwargs)
        if self.use_output:
            # NOTE(review): this split misbehaves if the grid text itself
            # contains ``diff_output_sep`` (the default "," appears inside
            # Python-list-style grids) — confirm against the representers used.
            diff_str, output_str = completion.split(self.diff_output_sep)
            output_grid = self.grid_representer.decode(output_str, **kwargs)
        else:
            diff_grid = self.grid_representer.decode(completion, **kwargs)
            output_grid = input_grid + diff_grid
        return Example(input=input_grid, output=output_grid)

    def __repr__(self) -> str:
        return f"DiffExampleRepresenter(io_sep={self.io_sep!r}, input_header={self.input_header!r}, output_header={self.output_header!r}, output_footer={self.output_footer!r}, use_output={self.use_output}, diff_sep={self.diff_sep!r}, diff_output_sep={self.diff_output_sep!r}, grid_representer={repr(self.grid_representer)})"

    def __str__(self) -> str:
        return repr(self)
798
+
799
+
800
+ # =============== TASK REPRESENTATION ===============
801
+
802
+
803
class TextTaskRepresenter(TaskRepresenter):
    """Serializes a task into (demonstrations, test_query, test_answer) text."""

    def __init__(
        self,
        train_header: str = "==TRAIN==\n",
        train_test_sep: str = "\n\n",
        test_header: str = "==TEST==\n",
        example_sep: str = "\n\n",
        example_representer: ExampleRepresenter = TextExampleRepresenter(),
    ):
        self.train_header = train_header
        self.train_test_sep = train_test_sep
        self.test_header = test_header
        self.example_sep = example_sep
        self.example_representer = example_representer

    def encode(self, task: Task, **kwargs) -> Tuple[str, str, str]:
        """Return (demonstrations, test_query, test_answer) for ``task``."""
        demonstrations = self.train_header
        for train_example in task.train_examples:
            prompt, answer = self.example_representer.encode(train_example, **kwargs)
            demonstrations += prompt + answer + self.example_sep
        # Drop the separator appended after the last example.
        # NOTE(review): with zero train examples this trims the header instead.
        demonstrations = demonstrations[: -len(self.example_sep)]

        query, output = self.example_representer.encode(task.test_example, **kwargs)
        test = self.test_header + query
        return demonstrations, test, output

    def decode(self, encoded: Tuple[str, str, str], **kwargs) -> Task:
        """Inverse of :meth:`encode`; relies on the example representer's markers
        to split the demonstration string back into examples."""
        demonstrations, test, encoded_output = encoded
        rep = self.example_representer

        train_examples = []
        body = demonstrations.replace(self.train_header, "").strip()
        for chunk in body.split(self.example_sep + rep.input_header):
            prompt_part, answer_part = chunk.split(rep.io_sep + rep.output_header)
            train_examples.append(rep.decode((prompt_part, answer_part), **kwargs))

        test_input = test.replace(self.test_header, "")
        test_example = rep.decode((test_input, encoded_output), **kwargs)

        return Task(train_examples=train_examples, test_example=test_example)

    def __repr__(self) -> str:
        return f"TextTaskRepresenter(train_header={self.train_header!r}, train_test_sep={self.train_test_sep!r}, test_header={self.test_header!r}, example_sep={self.example_sep!r}, example_representer={repr(self.example_representer)})"

    def __str__(self) -> str:
        return repr(self)
863
+
864
+
865
class ImageTaskRepresenter(TaskRepresenter):
    """Renders a whole task (train pairs + test pair) as one base64 PNG image."""

    # Shared class-level default; overridden per-instance when one is passed
    # to __init__.
    example_representer: ImageExampleRepresenter = ImageExampleRepresenter()

    def __init__(
        self, example_representer: Optional[ImageExampleRepresenter] = None
    ) -> None:
        if example_representer is not None:
            self.example_representer = example_representer

    def _draw_cell(self, ax, grid, title, norm):
        """Draw one grid into ``ax`` with cell borders and no tick labels."""
        ax.imshow(grid, cmap=self.example_representer.grid_representer.cmap, norm=norm)
        ax.set_title(title)
        ax.grid(True, which="both", color="lightgrey", linewidth=0.5)
        ax.set_yticks([y - 0.5 for y in range(1 + len(grid))])
        ax.set_xticks([x - 0.5 for x in range(1 + len(grid[0]))])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    def encode(self, task: Task, show_test_output=False, **kwargs) -> str:
        """Return a base64-encoded PNG of all examples, one row per example.

        The test output is replaced by a black 1x1 "(Hidden)" cell unless
        ``show_test_output`` is True.

        Fixes vs. original: removed dead code (``task.max_height()`` /
        ``task.max_width()`` results were immediately overwritten) and a
        duplicated ``ax = axes[k, 1]`` assignment; behavior is unchanged.
        """
        examples = task.train_examples + [task.test_example]
        # NOTE(review): matplotlib figsize is (width, height); the original
        # passed (3 * n_examples, 3 * 2) — preserved to keep output identical.
        fig, axes = plt.subplots(len(examples), 2, figsize=(3 * len(examples), 3 * 2))
        norm = Normalize(vmin=0, vmax=9)

        for k, example in enumerate(examples):
            # Left column: input grid.
            self._draw_cell(axes[k, 0], example.input, "Input", norm)

            # Right column: output grid, hidden for the test row unless asked.
            if not show_test_output and k == len(examples) - 1:
                self._draw_cell(axes[k, 1], np.zeros((1, 1)), "(Hidden)", norm)
            else:
                self._draw_cell(axes[k, 1], example.output, "Output", norm)

        plt.tight_layout()

        # Render into memory and base64-encode.
        buf = BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight")
        plt.close(fig)
        buf.seek(0)
        return base64.b64encode(buf.read()).decode("utf-8")

    def display(self, encoded_str: str):
        """Decode the base64 PNG and show it with matplotlib."""
        imgdata = base64.b64decode(encoded_str)
        plt.imshow(plt.imread(BytesIO(imgdata)))
        plt.axis("off")
        plt.show()

    def decode(self, encoded_str: str, **kwargs) -> Example:
        # Message fixed: it previously named ImageExampleRepresenter.
        raise NotImplementedError(
            "Decoding for ImageTaskRepresenter is not implemented."
        )
947
+
948
+
949
if __name__ == "__main__":
    # Smoke-test / demo: round-trip each representer on a tiny 3x3 grid and
    # dump the image-based encodings to PNG files in the working directory.

    example_representer = TextExampleRepresenter(
        grid_representer=WordGridRepresenter(),
        input_header="Input:\n",
        output_header="\nOutput:\n",
        io_sep="\n",
    )

    grid = np.array([[1, 1, 1], [0, 0, 0], [1, 1, 1]])

    # Round-trip a word-grid example encoding.
    example = Example(input=grid, output=grid)
    print(example_representer.encode(example))
    assert example == example_representer.decode(example_representer.encode(example))

    # Round-trip the grid-level representers (4- and 8-connectivity).
    grid = np.array([[1, 1, 1], [0, 0, 0], [1, 1, 1]])
    representer = ConnectedComponentRepresenter()
    print(representer.encode(grid))
    assert np.array_equal(grid, representer.decode(representer.encode(grid)))
    representer.connected_component = 8
    print(representer.encode(grid))
    assert np.array_equal(grid, representer.decode(representer.encode(grid)))
    representer = DelimitedGridRepresenter()
    print(representer.encode(grid))
    assert np.array_equal(grid, representer.decode(representer.encode(grid)))
    representer = PythonListGridRepresenter()
    print(representer.encode(grid))
    assert np.array_equal(grid, representer.decode(representer.encode(grid)))
    # CompositeRepresenter is encode-only (decode raises NotImplementedError).
    representer = CompositeRepresenter(
        [DelimitedGridRepresenter(), ConnectedComponentRepresenter()]
    )
    print(representer.encode(grid))
    example = Example(input=grid, output=grid)
    representer = TextExampleRepresenter()
    print(representer.encode(example))
    assert example == representer.decode(representer.encode(example))
    # Round-trip a full task.
    task = Task(train_examples=[example, example], test_example=example)
    representer = TextTaskRepresenter()
    assert task == representer.decode(representer.encode(task))
    print(representer.encode(task))
    # NOTE(review): TextTaskRepresenter defines no display() in this file —
    # presumably inherited from TaskRepresenter; confirm before relying on it.
    print(representer.display(representer.encode(task)))

    # image
    representer = ImageGridRepresenter()
    print(representer.encode(grid))
    # save as png
    base64_img = representer.encode(grid)
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(base64_img))

    representer = ImageExampleRepresenter()
    print(representer.encode(example))
    base64_img = representer.encode(Example(input=grid, output=grid))
    with open("example.png", "wb") as f:
        f.write(base64.b64decode(base64_img))

    representer = ImageTaskRepresenter()
    print(representer.encode(task))
    base64_img = representer.encode(task)
    with open("task.png", "wb") as f:
        f.write(base64.b64decode(base64_img))

    # V2 is encode-only (decode returns None).
    grid = np.array([[1, 1, 1], [0, 0, 0], [1, 1, 1]])
    representer = ConnectedComponentRepresenterV2()
    print(representer.encode(grid))

    # Diff representer round-trip: output is reconstructed from input + diff.
    representer = DiffExampleRepresenter(use_output=False)
    example = Example(input=grid, output=2 * grid)
    assert example == representer.decode(representer.encode(example))
    representer_str = repr(representer)
    print(representer_str)