fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +7 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0

fusion_bench/models/modeling_smile_qwen2/register.py (new file)

```diff
@@ -0,0 +1,11 @@
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from .configuration_smile_qwen2 import SmileQwen2Config
+from .modeling_smile_qwen2 import (
+    SmileQwen2ForCausalLM,
+    SmileQwen2Model,
+)
+
+AutoConfig.register("smile_qwen2", SmileQwen2Config)
+AutoModel.register(SmileQwen2Config, SmileQwen2Model)
+AutoModelForCausalLM.register(SmileQwen2Config, SmileQwen2ForCausalLM)
```
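
These `register` calls hook the new SMILE-Qwen2 classes into the Transformers auto-class machinery, so a checkpoint whose `config.json` declares `model_type: smile_qwen2` can be loaded through the usual entry points. A minimal sketch under that assumption (the checkpoint directory below is hypothetical):

```python
from transformers import AutoModelForCausalLM

# Importing this module executes the AutoConfig/AutoModel register() calls above.
import fusion_bench.models.modeling_smile_qwen2.register  # noqa: F401

# Hypothetical directory previously written with model.save_pretrained(...);
# because its config.json carries model_type "smile_qwen2", the auto classes
# now resolve it to SmileQwen2Config / SmileQwen2ForCausalLM instead of failing.
model = AutoModelForCausalLM.from_pretrained("outputs/smile_qwen2_upscaled")
print(type(model).__name__)  # SmileQwen2ForCausalLM
```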

fusion_bench/models/open_clip/modeling.py (new file)

```diff
@@ -0,0 +1,176 @@
+from typing import Callable, List
+
+import open_clip
+import torch
+from torch import Tensor
+
+from . import utils
+from .variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR
+
+
+class ImageEncoder(torch.nn.Module):
+    R"""
+    Examples:
+
+        load the image encoder for a given model name
+
+        >>> from fusion_bench.models.open_clip import ImageEncoder
+        >>> image_encoder = ImageEncoder(model_name="ViT-B-32")
+    """
+
+    def __init__(self, model_name: str, keep_lang=False):
+        super().__init__()
+        assert (
+            model_name in MODELS
+        ), f"Invalid model name: {model_name}. Valid models are: {MODELS}"
+
+        if "__pretrained__" in model_name:
+            name, pretrained = model_name.split("__pretrained__")
+        elif "__init__" in model_name:
+            print("Using random initialization.")
+            name, pretrained = model_name.split("__init__")[0], None
+        else:
+            name = model_name
+            pretrained = "openai"
+        (
+            self.model,
+            self.train_preprocess,
+            self.val_preprocess,
+        ) = open_clip.create_model_and_transforms(
+            name, pretrained=pretrained, cache_dir=OPENCLIP_CACHEDIR
+        )
+
+        self.cache_dir = CACHEDIR
+
+        if not keep_lang and hasattr(self.model, "transformer"):
+            delattr(self.model, "transformer")
+
+    def forward(self, images):
+        assert self.model is not None
+        return self.model.encode_image(images)
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving image encoder to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, model_name, filename):
+        print(f"Loading image encoder from {filename}")
+
+        state_dict = torch.load(filename, map_location="cpu")
+
+        model = cls(model_name)
+        model.load_state_dict(state_dict)
+        return model
+
+
+class ClassificationHead(torch.nn.Linear):
+    def __init__(
+        self,
+        normalize: bool,
+        weights: Tensor,
+        biases: Tensor = None,
+    ):
+        output_size, input_size = weights.shape
+        super().__init__(input_size, output_size)
+        self.normalize = normalize
+        if weights is not None:
+            self.weight = torch.nn.Parameter(weights.clone())
+        if biases is not None:
+            self.bias = torch.nn.Parameter(biases.clone())
+        else:
+            self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
+
+    def forward(self, inputs: Tensor):
+        if self.normalize:
+            inputs = inputs / inputs.norm(dim=-1, keepdim=True)
+        return super().forward(inputs)
+
+    def __call__(self, inputs: Tensor):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving classification head to {filename}")
+        utils.torch_save(self, filename, save_state_dict=False)
+
+    @classmethod
+    def load(cls, filename):
+        # print(f"Loading classification head from {filename}")
+        return utils.torch_load(filename)
+
+
+class ImageClassifier(torch.nn.Module):
+    train_preprocess: Callable
+    val_preprocess: Callable
+
+    def __init__(
+        self,
+        image_encoder: ImageEncoder,
+        classification_head: ClassificationHead,
+    ):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.classification_head = classification_head
+        if self.image_encoder is not None:
+            self.train_preprocess = self.image_encoder.train_preprocess
+            self.val_preprocess = self.image_encoder.val_preprocess
+
+    def freeze_head(self):
+        self.classification_head.weight.requires_grad_(False)
+        self.classification_head.bias.requires_grad_(False)
+
+    def forward(self, inputs: Tensor):
+        features = self.image_encoder(inputs)
+        outputs = self.classification_head(features)
+        return outputs
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving image classifier to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f"Loading image classifier from {filename}")
+        return utils.torch_load(filename)
+
+
+class MultiHeadImageClassifier(torch.nn.Module):
+    def __init__(
+        self,
+        image_encoder: ImageEncoder,
+        classification_heads: List[ClassificationHead],
+    ):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.classification_heads = torch.nn.ModuleList(classification_heads)
+        if self.image_encoder is not None:
+            self.train_preprocess = self.image_encoder.train_preprocess
+            self.val_preprocess = self.image_encoder.val_preprocess
+
+    def freeze_head(self):
+        for idx in range(len(self.classification_heads)):
+            self.classification_heads[idx].weight.requires_grad_(False)
+            self.classification_heads[idx].bias.requires_grad_(False)
+
+    def forward(self, inputs, head_idx):
+        features = self.image_encoder(inputs)
+        outputs = self.classification_heads[head_idx](features)
+        return outputs
+
+    def __call__(self, inputs, head_idx):
+        return self.forward(inputs, head_idx)
+
+    def save(self, filename):
+        print(f"Saving image classifier to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f"Loading image classifier from {filename}")
+        return utils.torch_load(filename)
```
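
For orientation, here is a small usage sketch of the classes above: build an open_clip ViT-B-32 encoder, attach a classification head, and run a batch through it. The 10-class random head and the random batch are placeholders (real heads are typically built from CLIP text embeddings of the class names), and constructing the encoder downloads the pretrained weights on first use.

```python
import torch

from fusion_bench.models.open_clip.modeling import (
    ClassificationHead,
    ImageClassifier,
    ImageEncoder,
)

# Downloads the OpenAI-pretrained ViT-B-32 weights into OPENCLIP_CACHEDIR on first use.
encoder = ImageEncoder(model_name="ViT-B-32")

# Placeholder head: 10 classes over the 512-dim ViT-B-32 image features.
head = ClassificationHead(normalize=True, weights=torch.randn(10, 512))
classifier = ImageClassifier(encoder, head)

# Dummy batch; real images should go through classifier.val_preprocess first.
images = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    logits = classifier(images)
print(logits.shape)  # torch.Size([4, 10])
```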

fusion_bench/models/open_clip/utils.py (new file)

```diff
@@ -0,0 +1,311 @@
+import copy
+import os
+import pickle
+from collections import OrderedDict
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+def compute_l1_norm(
+    model1: nn.Module, model2: nn.Module
+) -> Tuple[torch.Tensor, Dict[str, float]]:
+    """
+    Computes the L1 norm between the parameters of two models.
+
+    Args:
+        model1 (nn.Module): The first model.
+        model2 (nn.Module): The second model.
+
+    Returns:
+        Tuple[torch.Tensor, Dict[str, float]]: A tuple containing the total L1 norm and a dictionary
+            with the L1 norm for each layer.
+
+    """
+    norms = dict()
+    l1_norm = 0.0
+    for (n, p1), p2 in zip(model1.named_parameters(), model2.parameters()):
+        layer_l1_norm = torch.norm(p1 - p2, 1)
+        l1_norm += layer_l1_norm
+        norms[n] = layer_l1_norm.item()
+
+    return l1_norm, norms
+
+
+def assign_learning_rate(param_group, new_lr):
+    param_group["lr"] = new_lr
+
+
+def _warmup_lr(base_lr, warmup_length, step):
+    return base_lr * (step + 1) / warmup_length
+
+
+def cosine_lr(optimizer, base_lrs, warmup_length, steps):
+    if not isinstance(base_lrs, list):
+        base_lrs = [base_lrs for _ in optimizer.param_groups]
+    assert len(base_lrs) == len(optimizer.param_groups)
+
+    def _lr_adjuster(step):
+        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
+            if step < warmup_length:
+                lr = _warmup_lr(base_lr, warmup_length, step)
+            else:
+                e = step - warmup_length
+                es = steps - warmup_length
+                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+            assign_learning_rate(param_group, lr)
+
+    return _lr_adjuster
+
+
+def accuracy(output: torch.Tensor, target: torch.Tensor, topk: List[int] = (1,)):
+    pred = output.topk(max(topk), 1, True, True)[1].t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+    return [
+        float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
+        for k in topk
+    ]
+
+
+def torch_load_old(save_path: str, device=None):
+    with open(save_path, "rb") as f:
+        classifier = pickle.load(f)
+    if device is not None:
+        classifier = classifier.to(device)
+    return classifier
+
+
+def torch_save(model, save_path, save_state_dict=True):
+    # TODO: hacky way to save state dict
+    if save_state_dict and isinstance(model, torch.nn.Module):
+        model = model.state_dict()
+    if os.path.dirname(save_path) != "":
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    torch.save(model, save_path)
+
+
+def torch_load(save_path, device=None):
+    model = torch.load(save_path, map_location="cpu")
+    if device is not None:
+        model = model.to(device)
+    return model
+
+
+def get_logits(inputs, classifier):
+    assert callable(classifier)
+    if hasattr(classifier, "to"):
+        classifier = classifier.to(inputs.device)
+    return classifier(inputs)
+
+
+def get_probs(inputs, classifier):
+    if hasattr(classifier, "predict_proba"):
+        probs = classifier.predict_proba(inputs.detach().cpu().numpy())
+        return torch.from_numpy(probs)
+    logits = get_logits(inputs, classifier)
+    return logits.softmax(dim=1)
+
+
+class LabelSmoothing(torch.nn.Module):
+    def __init__(self, smoothing=0.0):
+        super(LabelSmoothing, self).__init__()
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+
+    def forward(self, x, target):
+        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+
+        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+        nll_loss = nll_loss.squeeze(1)
+        smooth_loss = -logprobs.mean(dim=-1)
+        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+        return loss.mean()
+
+
+class DotDict(dict):
+    """dot.notation access to dictionary attributes"""
+
+    __getattr__ = dict.get
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+def find_optimal_coef(
+    results: Dict[str, Any],
+    metric: str = "avg_normalized_top1",
+    minimize: bool = False,
+    control_metric: Optional[str] = None,
+    control_metric_threshold: float = 0.0,
+) -> float:
+    """
+    Finds the optimal coefficient based on the given results and metric.
+
+    Args:
+        results (Dict[str, Any]): A dictionary containing the results for different scaling coefficients.
+        metric (str, optional): The metric to optimize. Defaults to "avg_normalized_top1".
+        minimize (bool, optional): Whether to minimize the metric. Defaults to False.
+        control_metric (str, optional): The control metric to check against. Defaults to None.
+        control_metric_threshold (float, optional): The threshold value for the control metric. Defaults to 0.0.
+
+    Returns:
+        The optimal coefficient based on the given results and metric.
+    """
+    best_coef = None
+    if minimize:
+        best_metric = 1
+    else:
+        best_metric = 0
+    for scaling_coef in results.keys():
+        if control_metric is not None:
+            if results[scaling_coef][control_metric] < control_metric_threshold:
+                print(f"Control metric fell below {control_metric_threshold} threshold")
+                continue
+        if minimize:
+            if results[scaling_coef][metric] < best_metric:
+                best_metric = results[scaling_coef][metric]
+                best_coef = scaling_coef
+        else:
+            if results[scaling_coef][metric] > best_metric:
+                best_metric = results[scaling_coef][metric]
+                best_coef = scaling_coef
+    return best_coef
+
+
+def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes):
+    """Computes the normalized non-linear advantage of a finetuned model.
+
+    The nonlinear_advantage is defined as:
+        error_rate(linear_model) - error_rate(nonlinear_model) / (1 - 1 / num_classes)
+    and takes values between [-1, 1]. A value of 0 indicates that the nonlinear
+    model is no better than the linear one. Meanwhile, a value of 1 indicates
+    that the nonlinear model is perfect and the linear trivial, and a value of
+    -1 indicates the opposite.
+    """
+    return (nonlinear_acc - linear_acc) / (1.0 - 1.0 / num_classes)
+
+
+def to_cuda(input_dict):
+    cuda_dict = {}
+    for key, value in input_dict.items():
+        cuda_dict[key] = value.to("cuda")
+    return cuda_dict
+
+
+def state_dict_to_vector(state_dict, remove_keys=[]):
+    shared_state_dict = copy.deepcopy(state_dict)
+    for key in remove_keys:
+        if key in shared_state_dict:
+            del shared_state_dict[key]
+    sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
+    return torch.nn.utils.parameters_to_vector(
+        [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
+    )
+
+
+def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+    # create a reference dict to define the order of the vector
+    reference_dict = copy.deepcopy(state_dict)
+    for key in remove_keys:
+        if key in reference_dict:
+            del reference_dict[key]
+    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
+
+    # create a shared state dict using the reference dict
+    torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
+
+    # add back the encoder and decoder embedding weights.
+    if "transformer.shared.weight" in sorted_reference_dict:
+        for key in remove_keys:
+            sorted_reference_dict[key] = sorted_reference_dict[
+                "transformer.shared.weight"
+            ]
+    return sorted_reference_dict
+
+
+def add_ptm_to_tv(tv_dict, ptm_dict):
+    assert set(tv_dict.keys()) == set(
+        ptm_dict.keys()
+    ), "Differing parameter names in models."
+    final_dict = copy.deepcopy(tv_dict)
+    for k, v in ptm_dict.items():
+        final_dict[k] = tv_dict[k] + v
+    return final_dict
+
+
+def check_parameterNamesMatch(checkpoints):
+    parameter_names = set(checkpoints[0].keys())
+
+    if len(checkpoints) >= 2:
+        # raise ValueError("Number of models is less than 2.")
+        for checkpoint in checkpoints[1:]:
+            current_parameterNames = set(checkpoint.keys())
+            if current_parameterNames != parameter_names:
+                raise ValueError(
+                    "Differing parameter names in models. "
+                    f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                )
+
+
+def check_state_dicts_equal(state_dict1, state_dict2):
+    if set(state_dict1.keys()) != set(state_dict2.keys()):
+        return False
+
+    for key in state_dict1.keys():
+        if not torch.equal(state_dict1[key], state_dict2[key]):
+            return False
+
+    return True
+
+
+def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False):
+    if K == 100:
+        # print("Not applying mask")
+        if return_mask:
+            return M, torch.ones_like(M), None
+        else:
+            return M, torch.ones_like(M)
+
+    if K >= 1:
+        K /= 100
+
+    original_shape = M.shape
+    if M.dim() == 1:
+        M = M.unsqueeze(0)
+
+    n, d = M.shape
+    k = int(d * K)
+    k = d - k  # Keep top k elements instead of bottom k elements
+
+    # Find the k-th smallest element by magnitude for each row
+    kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
+    # Create a mask tensor with True for the top k elements in each row
+    mask = M.abs() >= kth_values
+    final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
+
+    if reshape_mask:
+        final_mask = final_mask.reshape(M.shape)
+
+    if return_mask:
+        return M * final_mask, final_mask.float().mean(dim=1), final_mask
+    else:
+        return M * final_mask, final_mask.float().mean(dim=1)
+
+
+def cleanup_linear(state_dict):
+    # The linear model also has keys for the reference point $\theta_0$ in the state dict with the prefix `params0`.
+    state_dict = {k: v for k, v in state_dict.items() if "params." in k}
+    return state_dict
+
+
+def get_ptm_linear(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    # rename keys so that they match afterwards
+    state_dict_new = {
+        k.replace("params0", "params"): v
+        for k, v in state_dict.items()
+        if "params0." in k
+    }
+    state_dict_remaining = {k: v for k, v in state_dict.items() if "params." not in k}
+
+    return state_dict_new, state_dict_remaining
```
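
Several of these helpers (`state_dict_to_vector`, `topk_values_mask`, `vector_to_state_dict`) are the usual building blocks of task-vector style merging. A self-contained toy sketch of how they compose, with tiny `nn.Linear` models standing in for real checkpoints (the mean-merge at the end is only illustrative, not one of the library's merging algorithms):

```python
import torch
import torch.nn as nn

from fusion_bench.models.open_clip.utils import (
    state_dict_to_vector,
    topk_values_mask,
    vector_to_state_dict,
)

# Toy stand-ins for a pretrained model and two fine-tuned variants.
torch.manual_seed(0)
base = nn.Linear(4, 2)
finetuned = [nn.Linear(4, 2), nn.Linear(4, 2)]

# Flatten each state dict into a 1-D vector and build the task vectors.
base_vec = state_dict_to_vector(base.state_dict())
task_vectors = torch.stack(
    [state_dict_to_vector(m.state_dict()) - base_vec for m in finetuned]
)  # shape: (num_tasks, num_params)

# Keep only the largest-magnitude entries of each task vector
# (K is a fraction, or a percentage when >= 1).
trimmed, kept_fraction = topk_values_mask(task_vectors, K=0.7)
print(kept_fraction)  # fraction of entries kept per task vector

# Illustrative merge: add the mean trimmed task vector back onto the base
# weights and recover a state dict with the original parameter names.
merged_vec = base_vec + trimmed.mean(dim=0)
base.load_state_dict(vector_to_state_dict(merged_vec, base.state_dict()))
```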

fusion_bench/models/open_clip/variables_and_paths.py (new file)

```diff
@@ -0,0 +1,56 @@
+from pathlib import Path
+from typing import Literal
+
+TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}"
+MODELS = ["ViT-B-32", "ViT-B-16", "ViT-L-14"]
+OPENCLIP_CACHEDIR = Path(Path.home(), "openclip-cachedir", "open_clip").as_posix()
+CACHEDIR = None
+
+ALL_DATASETS = [
+    "Cars",
+    "DTD",
+    "EuroSAT",
+    "GTSRB",
+    "MNIST",
+    "RESISC45",
+    "SVHN",
+    "SUN397",
+    "STL10",
+    "OxfordIIITPet",
+    "Flowers102",
+    "CIFAR100",
+    "PCAM",
+    "FER2013",
+    "CIFAR10",
+    "Food101",
+    "FashionMNIST",
+    "RenderedSST2",
+    "EMNIST",
+    "KMNIST",
+]
+
+DATASETS_8 = ALL_DATASETS[:8]
+DATASETS_14 = ALL_DATASETS[:14]
+DATASETS_20 = ALL_DATASETS[:20]
+
+
+def cleanup_dataset_name(dataset_name: str):
+    return dataset_name.replace("Val", "") + "Val"
+
+
+def get_zeroshot_path(root, dataset, model):
+    return Path(
+        root, model, cleanup_dataset_name(dataset), f"nonlinear_zeroshot.pt"
+    ).as_posix()
+
+
+def get_finetuned_path(root, dataset, model):
+    return Path(
+        root, model, cleanup_dataset_name(dataset), f"nonlinear_finetuned.pt"
+    ).as_posix()
+
+
+def get_single_task_accuracies_path(model):
+    return Path(
+        "results/single_task", model, f"nonlinear_ft_accuracies.json"
+    ).as_posix()
```
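
The constants and path helpers above encode the directory layout used for the open_clip zero-shot and fine-tuned checkpoints. A quick sketch of how they resolve (the root directory below is hypothetical):

```python
from fusion_bench.models.open_clip.variables_and_paths import (
    DATASETS_8,
    get_finetuned_path,
    get_zeroshot_path,
)

# The standard 8-task benchmark, taken from the head of ALL_DATASETS.
print(DATASETS_8)
# ['Cars', 'DTD', 'EuroSAT', 'GTSRB', 'MNIST', 'RESISC45', 'SVHN', 'SUN397']

# "checkpoints/open_clip" is a placeholder root for downloaded checkpoints.
root = "checkpoints/open_clip"
print(get_zeroshot_path(root, "Cars", "ViT-B-32"))
# checkpoints/open_clip/ViT-B-32/CarsVal/nonlinear_zeroshot.pt
print(get_finetuned_path(root, "DTDVal", "ViT-B-32"))
# checkpoints/open_clip/ViT-B-32/DTDVal/nonlinear_finetuned.pt  (the "Val" suffix is normalized)
```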

fusion_bench/models/parameter_dict.py

```diff
@@ -1,4 +1,4 @@
-from typing import List, Mapping
+from typing import List, Mapping, Optional, Tuple
 
 import torch
 from torch import nn
@@ -6,7 +6,13 @@ from torch import nn
 __all__ = "ParamterDictModel"
 
 
-def set_attr(obj, names: List[str], val, check_parent: bool = False):
+def _set_attr(
+    obj,
+    names: List[str],
+    val,
+    check_parent: bool = False,
+    parent_builder=nn.Module,
+):
     """
     Sets an attribute of an object recursively.
 
@@ -20,8 +26,14 @@ def set_attr(obj, names: List[str], val, check_parent: bool = False):
         setattr(obj, names[0], val)
     else:
         if check_parent and not hasattr(obj, names[0]):
-            setattr(obj, names[0],
-
+            setattr(obj, names[0], parent_builder())
+        _set_attr(
+            getattr(obj, names[0]),
+            names[1:],
+            val,
+            check_parent=check_parent,
+            parent_builder=parent_builder,
+        )
 
 
 def has_attr(obj, names: List[str]):
@@ -49,17 +61,19 @@ class ParameterDictModel(nn.Module):
 
     def __init__(
         self,
-        parameters: Mapping[str, nn.Parameter],
+        parameters: Optional[Mapping[str, nn.Parameter]] = None,
     ):
         super().__init__()
-
-
-
-
-
-
-
-
+        if parameters is not None:
+            for name, param in parameters.items():
+                assert isinstance(param, nn.Parameter), f"{name} is not a nn.Parameter"
+                _set_attr(
+                    self,
+                    name.split("."),
+                    param,
+                    check_parent=True,
+                    parent_builder=self.__class__,
+                )
 
     def __repr__(self):
         """
@@ -73,3 +87,30 @@ class ParameterDictModel(nn.Module):
             param_repr = f"{name}: {param.size()}"
             param_reprs.append(param_repr)
         return f"{self.__class__.__name__}({', '.join(param_reprs)})"
+
+    def __getitem__(self, key: str):
+        if not has_attr(self, key.split(".")):
+            raise KeyError(f"Key {key} not found in {self}")
+        key = key.split(".")
+        obj = self
+        for k in key:
+            obj = getattr(obj, k)
+        return obj
+
+    def __setitem__(self, key: str, value: nn.Parameter):
+        if not has_attr(self, key.split(".")):
+            _set_attr(self, key.split("."), value, check_parent=True)
+        else:
+            _set_attr(self, key.split("."), value, check_parent=False)
+
+    def __contains__(self, key: str):
+        return has_attr(self, key.split("."))
+
+    def keys(self):
+        return [name for name, _ in self.named_parameters()]
+
+    def items(self) -> List[Tuple[str, nn.Parameter]]:
+        return [(name, self[name]) for name in self.keys()]
+
+    def values(self) -> List[nn.Parameter]:
+        return [self[name] for name in self.keys()]
```
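
With the new `__getitem__`, `__setitem__`, `__contains__`, `keys`, `items`, and `values` methods, `ParameterDictModel` behaves like a dictionary keyed by dotted parameter names while remaining a regular `nn.Module`. A brief sketch of the intended use (the parameter names and shapes are made up, and `has_attr` is assumed to walk dotted names the same way `_set_attr` does):

```python
import torch
from torch import nn

from fusion_bench.models.parameter_dict import ParameterDictModel

model = ParameterDictModel(
    {
        "encoder.layer1.weight": nn.Parameter(torch.randn(4, 4)),
        "encoder.layer1.bias": nn.Parameter(torch.zeros(4)),
    }
)

# Dict-style access over dotted keys, backed by nested submodules.
print("encoder.layer1.weight" in model)    # True
print(model["encoder.layer1.bias"].shape)  # torch.Size([4])

# __setitem__ creates missing parent modules on the fly.
model["head.weight"] = nn.Parameter(torch.randn(2, 4))

for name, param in model.items():
    print(name, tuple(param.shape))
```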