fusion-bench 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. fusion_bench/compat/method/__init__.py +5 -0
  2. fusion_bench/dataset/fer2013.py +0 -1
  3. fusion_bench/method/__init__.py +10 -0
  4. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  5. fusion_bench/method/concrete_subspace/__init__.py +8 -0
  6. fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
  7. fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
  8. fusion_bench/method/doge_ta/__init__.py +2 -0
  9. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +46 -0
  10. fusion_bench/method/doge_ta/doge_ta.py +364 -0
  11. fusion_bench/method/doge_ta/layer_wise_adamerging.py +250 -0
  12. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  13. fusion_bench/method/isotropic_merging/iso.py +2 -2
  14. fusion_bench/method/opcm/opcm.py +93 -84
  15. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  16. fusion_bench/method/opcm/ties_merging.py +71 -52
  17. fusion_bench/method/task_singular_vector/TSVM.py +3 -3
  18. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  19. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +416 -0
  20. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/METADATA +15 -2
  21. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/RECORD +32 -19
  22. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/WHEEL +1 -1
  23. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
  24. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
  25. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
  26. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
  27. fusion_bench_config/method/doge_ta/doge_ta.yaml +4 -0
  28. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -8
  29. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +68 -0
  30. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/entry_points.txt +0 -0
  31. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info/licenses}/LICENSE +0 -0
  32. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/top_level.txt +0 -0
fusion_bench/method/opcm/opcm.py
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPVisionModel

 from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.json import load_from_json, save_to_json
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
 class OPCMForCLIP(
     BaseAlgorithm,
     LightningFabricMixin,
+    SimpleProfilerMixin,
 ):
     def __init__(
         self,
@@ -64,7 +65,8 @@ class OPCMForCLIP(
         L.seed_everything(self.seed)
         accelerator = self.fabric.device

-        pretrained_model = modelpool.load_pretrained_model()
+        with self.profile("loading model"):
+            pretrained_model = modelpool.load_pretrained_model()

         model_names = modelpool.model_names
         if self.shuffle_order:
@@ -83,15 +85,17 @@ class OPCMForCLIP(
             )

         # get the average model
-        merged_model = modelpool.load_model(model_names[0])
+        with self.profile("loading model"):
+            merged_model = modelpool.load_model(model_names[0])

         if self.evaluate_on_every_step:
-            self.taskpool._is_setup = False
-            self.taskpool._test_datasets = DictConfig(
-                {model_names[0]: self._test_datasets[model_names[0]]}
-            )
-            report = self.taskpool.evaluate(deepcopy(merged_model))
-            save_to_json(report, Path(self.log_dir) / "report_0.json")
+            with self.profile("evaluating model"):
+                self.taskpool._is_setup = False
+                self.taskpool._test_datasets = DictConfig(
+                    {model_names[0]: self._test_datasets[model_names[0]]}
+                )
+                report = self.taskpool.evaluate(deepcopy(merged_model))
+                save_to_json(report, Path(self.log_dir) / "report_0.json")

         self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
         self.all_task_vector_norm = [self.avg_task_vector_norm]
@@ -113,90 +117,95 @@ class OPCMForCLIP(
             enumerate(model_names[1:]), desc="Processing models"
         ):
             model_idx += 1
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)

-            self.all_task_vector_norm.append(
-                get_task_vector_norm(task_model, pretrained_model)
-            )
-            self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
-            self.fabric.log(
-                "model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
-            )
-            self.fabric.log(
-                "model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
-            )
+            with self.profile("merging model"):
+                self.all_task_vector_norm.append(
+                    get_task_vector_norm(task_model, pretrained_model)
+                )
+                self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
+                self.fabric.log(
+                    "model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
+                )
+                self.fabric.log(
+                    "model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
+                )

-            self.lambda_t = 1  # temporary value
-
-            for module_name, module in tqdm(
-                list(merged_model.named_modules()),
-                desc=f"Processing {model_name}",
-                leave=False,
-            ):
-                if not is_leaf_module(module):
-                    continue
-
-                if isinstance(module, nn.Linear):
-                    module.weight.data = self.merge_linear_weights(
-                        module.weight,
-                        pretrained_model.get_submodule(module_name).weight,
-                        task_model.get_submodule(module_name).weight,
-                        param_name=".".join([module_name, "weight"]),
-                        alpha=self.alpha,
-                        accelerator=accelerator,
-                    )
-                    if module.bias is not None:
-                        module.bias.data = self.merge_other_parameters(
-                            module.bias,
-                            pretrained_model.get_submodule(module_name).bias,
-                            task_model.get_submodule(module_name).bias,
-                            param_name=".".join([module_name, "bias"]),
+                self.lambda_t = 1  # temporary value
+
+                for module_name, module in tqdm(
+                    list(merged_model.named_modules()),
+                    desc=f"Processing {model_name}",
+                    leave=False,
+                ):
+                    if not is_leaf_module(module):
+                        continue
+
+                    if isinstance(module, nn.Linear):
+                        module.weight.data = self.merge_linear_weights(
+                            module.weight,
+                            pretrained_model.get_submodule(module_name).weight,
+                            task_model.get_submodule(module_name).weight,
+                            param_name=".".join([module_name, "weight"]),
+                            alpha=self.alpha,
                             accelerator=accelerator,
                         )
-                else:
-                    for param_name, param in module.named_parameters():
-                        param.data = self.merge_other_parameters(
-                            merged_W=param,
-                            pretrained_W=pretrained_model.get_submodule(
-                                module_name
-                            ).get_parameter(param_name),
-                            task_W=task_model.get_submodule(module_name).get_parameter(
-                                param_name
-                            ),
-                            param_name=".".join([module_name, param_name]),
-                            accelerator=accelerator,
-                        )
-
-            task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
-            self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
-            for param_name, param in merged_model.named_parameters():
-                param.data = pretrained_model.get_parameter(param_name) + (
-                    param - pretrained_model.get_parameter(param_name)
-                ) * (self.avg_task_vector_norm / task_vector_norm)
-            self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
-            self.fabric.log(
-                "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
-            )
-            self.previous_lambda_t = self.lambda_t
-            self.lambda_t = None
+                        if module.bias is not None:
+                            module.bias.data = self.merge_other_parameters(
+                                module.bias,
+                                pretrained_model.get_submodule(module_name).bias,
+                                task_model.get_submodule(module_name).bias,
+                                param_name=".".join([module_name, "bias"]),
+                                accelerator=accelerator,
+                            )
+                    else:
+                        for param_name, param in module.named_parameters():
+                            param.data = self.merge_other_parameters(
+                                merged_W=param,
+                                pretrained_W=pretrained_model.get_submodule(
+                                    module_name
+                                ).get_parameter(param_name),
+                                task_W=task_model.get_submodule(module_name).get_parameter(
+                                    param_name
+                                ),
+                                param_name=".".join([module_name, param_name]),
+                                accelerator=accelerator,
+                            )
+
+                task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
+                self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
+                for param_name, param in merged_model.named_parameters():
+                    param.data = pretrained_model.get_parameter(param_name) + (
+                        param - pretrained_model.get_parameter(param_name)
+                    ) * (self.avg_task_vector_norm / task_vector_norm)
+                self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
+                self.fabric.log(
+                    "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
+                )
+                self.previous_lambda_t = self.lambda_t
+                self.lambda_t = None

-            self.fabric.log(
-                "model/merged_task_vector_norm",
-                get_task_vector_norm(merged_model, pretrained_model),
-                step=model_idx,
-            )
+                self.fabric.log(
+                    "model/merged_task_vector_norm",
+                    get_task_vector_norm(merged_model, pretrained_model),
+                    step=model_idx,
+                )

             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)

             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")

+        self.print_profile_summary()
         return merged_model

     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
@@ -227,7 +236,7 @@ class OPCMForCLIP(
         split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()

         projected_task_tv = u.T @ task_tv @ v
-        projected_task_tv.diag().fill_(0)
+        projected_task_tv.diagonal().fill_(0)

         projected_task_tv[:split_rank, :split_rank] = 0
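The `.diag()` → `.diagonal()` change in the hunk above is a behavioral fix, not a rename: `Tensor.diag()` returns a new tensor holding a copy of the diagonal, so the in-place `fill_(0)` never touches `projected_task_tv`, whereas `Tensor.diagonal()` returns a view, so the fill actually zeroes the diagonal. A minimal standalone sketch (not fusion-bench code) of the difference:

```python
import torch

x = torch.ones(3, 3)

# Tensor.diag() copies the diagonal into a new 1-D tensor,
# so filling that copy leaves x unchanged.
x.diag().fill_(0)
print(x[0, 0])  # tensor(1.)

# Tensor.diagonal() returns a view into x's storage,
# so the in-place fill zeroes the diagonal of x itself.
x.diagonal().fill_(0)
print(x[0, 0])  # tensor(0.)
```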
 
fusion_bench/method/opcm/task_arithmetic.py
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPVisionModel

 from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils.json import load_from_json, save_to_json
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
@@ -24,7 +24,11 @@ if TYPE_CHECKING:
     from torch.utils.tensorboard import SummaryWriter


-class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
+class ContinualTaskArithmeticForCLIP(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
     def __init__(
         self,
         scaling_factor: float,
@@ -79,32 +83,42 @@ class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
         for model_idx, model_name in tqdm(
             enumerate(model_names), desc="Processing models"
         ):
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)

-            for param_name, param in task_model.named_parameters():
-                if not param.requires_grad:
-                    continue
+            with self.profile("merging model"):
+                for param_name, param in task_model.named_parameters():
+                    if not param.requires_grad:
+                        continue

-                task_param = param
-                merged_param = merged_model.get_parameter(param_name)
-                pretrained_param = pretrained_model.get_parameter(param_name)
+                    task_param = param
+                    merged_param = merged_model.get_parameter(param_name)
+                    pretrained_param = pretrained_model.get_parameter(param_name)

-                new_param = merged_param + self.scaling_factor * (
-                    task_param - pretrained_param
-                )
-                merged_model.get_parameter(param_name).data = new_param
+                    new_param = merged_param + self.scaling_factor * (
+                        task_param - pretrained_param
+                    )
+                    merged_model.get_parameter(param_name).data = new_param

             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)

             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
-
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {
+                            n: self._test_datasets[n]
+                            for n in model_names[: model_idx + 1]
+                        }
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(
+                        report, Path(self.log_dir) / f"report_{model_idx}.json"
+                    )
+
+        self.print_profile_summary()
         return merged_model

     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
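The recurring change across the OPCM variants above is the addition of `SimpleProfilerMixin`: model loading, merging, saving, and evaluation are wrapped in `with self.profile(...)` blocks and a summary is printed before `run()` returns. A minimal sketch of that usage pattern, using only the two mixin calls that appear in these diffs (the mixin's internals and the format of the timing output are assumed):

```python
from fusion_bench.mixins import SimpleProfilerMixin


class ProfiledMergeSketch(SimpleProfilerMixin):
    """Hypothetical illustration of the profiling pattern added in 0.2.12."""

    def run(self, modelpool):
        with self.profile("loading model"):
            merged_model = modelpool.load_pretrained_model()

        for model_name in modelpool.model_names:
            with self.profile("loading model"):
                task_model = modelpool.load_model(model_name)
            with self.profile("merging model"):
                ...  # merge task_model into merged_model (method-specific)

        # report accumulated time per label ("loading model", "merging model", ...)
        self.print_profile_summary()
        return merged_model
```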
fusion_bench/method/opcm/ties_merging.py
@@ -20,7 +20,7 @@ from fusion_bench.method.ties_merging.ties_merging_utils import (
     ties_merging,
     vector_to_state_dict,
 )
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils.json import load_from_json, save_to_json
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
@@ -29,7 +29,11 @@ if TYPE_CHECKING:
     from torch.utils.tensorboard import SummaryWriter


-class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
+class ContinualTiesMergingForCLIP(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
     def __init__(
         self,
         scaling_factor: float,
@@ -84,68 +88,83 @@ class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
             )

         # get the average model
-        pretrained_model = modelpool.load_pretrained_model()
+        with self.profile("loading model"):
+            pretrained_model = modelpool.load_pretrained_model()
         merged_model = deepcopy(pretrained_model)

         for model_idx, model_name in tqdm(
             enumerate(model_names), desc="Processing models"
         ):
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)

-            task_vector = state_dict_sub(
-                task_model.state_dict(),
-                pretrained_model.state_dict(),
-            )
-            if model_idx == 0:
-                # if is the first model, the merged task vector is equal to the task vector
-                ties_merging_state_dict = task_vector
-            else:
-                # if is not the first model, we need to merge the task vector with the previous merged task vector
-                merged_tv = state_dict_sub(
-                    merged_model.state_dict(),
+            with self.profile("merging model"):
+                task_vector = state_dict_sub(
+                    task_model.state_dict(),
                     pretrained_model.state_dict(),
                 )
-                tv_flat_checks = torch.vstack(
-                    [
-                        state_dict_to_vector(merged_tv, remove_keys=self.remove_keys),
-                        state_dict_to_vector(task_vector, remove_keys=self.remove_keys),
-                    ]
-                )
-                # perform the TIES merging
-                ties_merging_tv = ties_merging(
-                    tv_flat_checks,
-                    reset_thresh=self.threshold,
-                    merge_func=self.merge_func,
-                )
-                # convert the merged task vector back to a state dict
-                ties_merging_state_dict = vector_to_state_dict(
-                    ties_merging_tv,
-                    merged_model.state_dict(),
-                    remove_keys=self.remove_keys,
-                )
-
-            for param_name, param in task_model.named_parameters():
-                if not param.requires_grad:
-                    continue
-
-                merged_param = merged_model.get_parameter(param_name)
-                new_param = (
-                    merged_param
-                    + self.scaling_factor * ties_merging_state_dict[param_name]
-                )
-                merged_model.get_parameter(param_name).data = new_param
+                if model_idx == 0:
+                    # if is the first model, the merged task vector is equal to the task vector
+                    ties_merging_state_dict = task_vector
+                else:
+                    # if is not the first model, we need to merge the task vector with the previous merged task vector
+                    merged_tv = state_dict_sub(
+                        merged_model.state_dict(),
+                        pretrained_model.state_dict(),
+                    )
+                    tv_flat_checks = torch.vstack(
+                        [
+                            state_dict_to_vector(
+                                merged_tv, remove_keys=self.remove_keys
+                            ),
+                            state_dict_to_vector(
+                                task_vector, remove_keys=self.remove_keys
+                            ),
+                        ]
+                    )
+                    # perform the TIES merging
+                    ties_merging_tv = ties_merging(
+                        tv_flat_checks,
+                        reset_thresh=self.threshold,
+                        merge_func=self.merge_func,
+                    )
+                    # convert the merged task vector back to a state dict
+                    ties_merging_state_dict = vector_to_state_dict(
+                        ties_merging_tv,
+                        merged_model.state_dict(),
+                        remove_keys=self.remove_keys,
+                    )
+
+                for param_name, param in task_model.named_parameters():
+                    if not param.requires_grad:
+                        continue
+
+                    merged_param = merged_model.get_parameter(param_name)
+                    new_param = (
+                        merged_param
+                        + self.scaling_factor * ties_merging_state_dict[param_name]
+                    )
+                    merged_model.get_parameter(param_name).data = new_param

             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)

             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
-
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {
+                            n: self._test_datasets[n]
+                            for n in model_names[: model_idx + 1]
+                        }
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(
+                        report, Path(self.log_dir) / f"report_{model_idx}.json"
+                    )
+
+        self.print_profile_summary()
         return merged_model

     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
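In the continual TIES variant above, each step flattens the current merged task vector and the incoming task vector, TIES-merges the pair, and adds the result back onto the running merged model scaled by `scaling_factor`. For readers unfamiliar with what `ties_merging` does, here is a minimal sketch of the standard TIES steps (trim, elect sign, disjoint mean) on a stack of flattened task vectors; the top-k fraction and merge function are assumptions, and fusion_bench's implementation may differ in details such as the meaning of `reset_thresh` and the available `merge_func` options:

```python
import torch


def ties_merge_sketch(task_vectors: torch.Tensor, keep_frac: float = 0.2) -> torch.Tensor:
    """Standard TIES merging sketch for a (num_models, num_params) stack."""
    num_params = task_vectors.shape[1]
    n_keep = max(1, int(keep_frac * num_params))

    # 1. Trim: per model, zero out all but the largest-magnitude entries.
    kth = num_params - n_keep + 1
    thresholds = task_vectors.abs().kthvalue(kth, dim=1, keepdim=True).values
    trimmed = torch.where(
        task_vectors.abs() >= thresholds, task_vectors, torch.zeros_like(task_vectors)
    )

    # 2. Elect sign: per parameter, keep the sign with the larger total mass.
    elected_sign = torch.sign(trimmed.sum(dim=0))

    # 3. Disjoint mean: average only entries whose sign agrees with the elected one.
    agree = torch.sign(trimmed) == elected_sign
    merged = (trimmed * agree).sum(dim=0) / agree.sum(dim=0).clamp(min=1)
    return merged
```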
fusion_bench/method/task_singular_vector/TSVM.py
@@ -9,19 +9,19 @@ fusion_bench \
 ```
 """

-from typing import List, Optional, Union, Iterable
+from typing import Iterable, List, Optional, Union

 import torch
-from torch import Tensor, nn
 from omegaconf import ListConfig
+from torch import Tensor, nn

 from fusion_bench import BaseAlgorithm
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
-    state_dict_sub,
     state_dict_mul,
+    state_dict_sub,
 )
 from fusion_bench.utils.type import StateDictType
fusion_bench/models/wrappers/layer_wise_fusion.py
@@ -16,6 +16,7 @@ import torch
 from torch import Tensor, nn
 from torch.func import functional_call

+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.type import StateDictType, TorchModelType

 __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
@@ -23,52 +24,6 @@ __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
23
24
  log = logging.getLogger(__name__)
24
25
 
25
26
 
26
- def del_attr(obj, names: List[str]):
27
- """
28
- Deletes an attribute from an object recursively.
29
-
30
- Args:
31
- obj (object): Object to delete attribute from.
32
- names (list): List of attribute names to delete recursively.
33
- """
34
- if len(names) == 1:
35
- delattr(obj, names[0])
36
- else:
37
- del_attr(getattr(obj, names[0]), names[1:])
38
-
39
-
40
- def set_attr(obj, names: List[str], val):
41
- """
42
- Sets an attribute of an object recursively.
43
-
44
- Args:
45
- obj (object): Object to set attribute of.
46
- names (list): List of attribute names to set recursively.
47
- val (object): Value to set the attribute to.
48
- """
49
- if len(names) == 1:
50
- setattr(obj, names[0], val)
51
- else:
52
- set_attr(getattr(obj, names[0]), names[1:], val)
53
-
54
-
55
- def get_attr(obj, names: List[str]):
56
- """
57
- Gets an attribute of an object recursively.
58
-
59
- Args:
60
- obj (object): Object to get attribute of.
61
- names (list): List of attribute names to get recursively.
62
-
63
- Returns:
64
- object: The attribute of the object.
65
- """
66
- if len(names) == 1:
67
- return getattr(obj, names[0])
68
- else:
69
- return get_attr(getattr(obj, names[0]), names[1:])
70
-
71
-
72
27
  def get_layer_wise_weights(
73
28
  num_models: int,
74
29
  num_layers: int,
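The three recursive attribute helpers removed above are now imported from `fusion_bench.models.utils` (the `+` import in the previous hunk) instead of being defined locally; their behavior is unchanged. A hypothetical usage sketch, with the dotted-name splitting chosen purely for illustration:

```python
import torch
from torch import nn

from fusion_bench.models.utils import del_attr, get_attr, set_attr

model = nn.Sequential(nn.Linear(4, 4))

# Dotted parameter names are split into the list form the helpers expect.
names = "0.weight".split(".")

weight = get_attr(model, names)       # same object as model[0].weight
del_attr(model, names)                # drop the nn.Parameter from the module
set_attr(model, names, torch.zeros(4, 4))  # attach a plain (e.g. merged) tensor instead
```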