fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (50)
  1. fusion_bench/constants/__init__.py +5 -1
  2. fusion_bench/constants/runtime.py +111 -7
  3. fusion_bench/dataset/gsm8k.py +6 -2
  4. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  5. fusion_bench/method/__init__.py +10 -2
  6. fusion_bench/method/base_algorithm.py +29 -19
  7. fusion_bench/method/classification/image_classification_finetune.py +1 -2
  8. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  9. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  10. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  11. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  12. fusion_bench/metrics/model_kinship/utility.py +184 -0
  13. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  14. fusion_bench/metrics/nyuv2/depth.py +30 -0
  15. fusion_bench/metrics/nyuv2/loss.py +40 -0
  16. fusion_bench/metrics/nyuv2/noise.py +24 -0
  17. fusion_bench/metrics/nyuv2/normal.py +34 -1
  18. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  19. fusion_bench/mixins/clip_classification.py +30 -2
  20. fusion_bench/mixins/lightning_fabric.py +46 -5
  21. fusion_bench/mixins/rich_live.py +76 -0
  22. fusion_bench/modelpool/base_pool.py +86 -5
  23. fusion_bench/models/masks/mask_model.py +8 -2
  24. fusion_bench/models/open_clip/modeling.py +7 -0
  25. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  26. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  27. fusion_bench/scripts/cli.py +14 -0
  28. fusion_bench/scripts/webui.py +250 -17
  29. fusion_bench/utils/__init__.py +14 -0
  30. fusion_bench/utils/data.py +100 -9
  31. fusion_bench/utils/devices.py +3 -1
  32. fusion_bench/utils/fabric.py +185 -4
  33. fusion_bench/utils/instantiate_utils.py +29 -18
  34. fusion_bench/utils/json.py +6 -0
  35. fusion_bench/utils/misc.py +16 -0
  36. fusion_bench/utils/rich_utils.py +123 -6
  37. fusion_bench/utils/validation.py +197 -0
  38. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
  39. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
  40. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  41. fusion_bench_config/llama_full_finetune.yaml +4 -16
  42. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  43. fusion_bench_config/nyuv2_config.yaml +4 -13
  44. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  45. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  46. fusion_bench/utils/auto.py +0 -31
  47. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
  48. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
  49. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
  50. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/clip_classification.py
@@ -59,6 +59,15 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     @property
     def clip_processor(self):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
         if self._clip_processor is None:
             assert self.modelpool is not None, "Model pool is not set"
             self._clip_processor = self.modelpool.load_processor()
@@ -125,6 +134,11 @@ class CLIPClassificationMixin(LightningFabricMixin):
             clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
             task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
         """
+        # make sure the task names are equal across all processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
         self.whether_setup_zero_shot_classification_head = True
         # load clip model if not provided
         if clip_model is None:
@@ -147,7 +161,10 @@ class CLIPClassificationMixin(LightningFabricMixin):
         self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
         @cache_with_joblib()
-        def construct_classification_head(task: str):
+        def construct_classification_head(task: str, model_name: str):
+            log.info(
+                f"Constructing zero-shot classification head for task: {task} using model: {model_name}"
+            )
             nonlocal clip_classifier
 
             classnames, templates = get_classnames_and_templates(task)
@@ -163,7 +180,18 @@ class CLIPClassificationMixin(LightningFabricMixin):
         ):
             zeroshot_weights = None
             if self.fabric.is_global_zero:
-                zeroshot_weights = construct_classification_head(task)
+                if hasattr(clip_model, "config") and hasattr(
+                    clip_model.config, "_name_or_path"
+                ):
+                    model_name = clip_model.config._name_or_path
+                else:
+                    model_name = "unknown_model"
+                    log.warning(
+                        "CLIP model config does not have `_name_or_path` attribute. Using 'unknown_model' as model name."
+                    )
+                zeroshot_weights = construct_classification_head(
+                    task, model_name=model_name
+                )
 
             self.fabric.barrier()
             self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
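Note on the extra `model_name` argument: joblib-style caches key results on the function's arguments, so a head cached for one CLIP backbone could presumably be served for a different backbone before this change. A minimal sketch of the same keying behavior using plain `joblib.Memory` (the cache directory and function body below are illustrative, not fusion-bench's code):

```python
from joblib import Memory

memory = Memory("./.cache_sketch", verbose=0)

@memory.cache
def construct_classification_head(task: str, model_name: str):
    # stand-in for encoding class-name prompts with the named CLIP model
    return f"zeroshot weights for {task} @ {model_name}"

# distinct model names now map to distinct cache entries for the same task
construct_classification_head("stanford-cars", "openai/clip-vit-base-patch32")
construct_classification_head("stanford-cars", "openai/clip-vit-large-patch14")
```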
fusion_bench/mixins/lightning_fabric.py
@@ -1,7 +1,7 @@
 import functools
 import logging
 import os
-from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar
 
 import lightning as L
 import torch
@@ -96,12 +96,24 @@ class LightningFabricMixin:
 
     @property
     def fabric(self):
+        """
+        Get the Lightning Fabric instance, initializing it if necessary.
+
+        Returns:
+            L.Fabric: The Lightning Fabric instance for distributed computing.
+        """
         if self._fabric_instance is None:
             self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
         return self._fabric_instance
 
     @fabric.setter
     def fabric(self, instance: L.Fabric):
+        """
+        Set the Lightning Fabric instance.
+
+        Args:
+            instance: The Lightning Fabric instance to use.
+        """
         self._fabric_instance = instance
 
     @property
@@ -172,6 +184,15 @@
     def tensorboard_summarywriter(
         self,
     ) -> "lightning.fabric.loggers.tensorboard.SummaryWriter":
+        """
+        Get the TensorBoard SummaryWriter for detailed logging.
+
+        Returns:
+            SummaryWriter: The TensorBoard SummaryWriter instance.
+
+        Raises:
+            AttributeError: If the logger is not a TensorBoardLogger.
+        """
         if isinstance(self.fabric.logger, TensorBoardLogger):
             return self.fabric.logger.experiment
         else:
@@ -179,6 +200,12 @@
 
     @property
     def is_debug_mode(self):
+        """
+        Check if the program is running in debug mode (fast_dev_run).
+
+        Returns:
+            bool: True if fast_dev_run is enabled, False otherwise.
+        """
         if hasattr(self, "config") and self.config.get("fast_dev_run", False):
             return True
         elif hasattr(self, "_program") and self._program.config.get(
@@ -190,13 +217,22 @@
 
     def log(self, name: str, value: Any, step: Optional[int] = None):
         """
-        Logs the metric to the fabric's logger.
+        Logs a single metric to the fabric's logger.
+
+        Args:
+            name: The name of the metric to log.
+            value: The value of the metric.
+            step: Optional step number for the metric.
         """
         self.fabric.log(name, value, step=step)
 
-    def log_dict(self, metrics: dict, step: Optional[int] = None):
+    def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None):
         """
-        Logs the metrics to the fabric's logger.
+        Logs multiple metrics to the fabric's logger.
+
+        Args:
+            metrics: Dictionary of metric names and values.
+            step: Optional step number for the metrics.
         """
         self.fabric.log_dict(metrics, step=step)
 
@@ -207,7 +243,12 @@
         name_template: str = "train/lr_group_{0}",
     ):
         """
-        Logs the learning rate of the optimizer to the fabric's logger.
+        Logs the learning rate of each parameter group in the optimizer.
+
+        Args:
+            optimizer: The optimizer whose learning rates should be logged.
+            step: Optional step number for the log entry.
+            name_template: Template string for the log name. Use {0} as placeholder for group index.
        """
         for i, param_group in enumerate(optimizer.param_groups):
             self.fabric.log(name_template.format(i), param_group["lr"], step=step)
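These helpers are thin wrappers over Lightning Fabric's own logging calls; with the widened `Mapping[str, Any]` annotation, any mapping (not just a plain `dict`) is accepted. A standalone sketch of the underlying Fabric calls the mixin delegates to (the logger directory is illustrative):

```python
import lightning as L
from lightning.fabric.loggers import TensorBoardLogger

fabric = L.Fabric(loggers=TensorBoardLogger("./tb_logs_sketch"))
fabric.launch()

# what LightningFabricMixin.log / log_dict forward to
fabric.log("train/loss", 0.42, step=0)
fabric.log_dict({"train/acc": 0.91, "val/acc": 0.88}, step=0)
```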
fusion_bench/mixins/rich_live.py
@@ -2,20 +2,96 @@ from rich.live import Live
 
 
 class RichLiveMixin:
+    """
+    A mixin class that provides Rich Live display capabilities.
+
+    This mixin integrates Rich's Live display functionality, allowing for
+    dynamic, auto-refreshing console output. It's particularly useful for
+    displaying real-time updates, progress information, or continuously
+    changing data without cluttering the terminal.
+
+    Attributes:
+        _rich_live (Live): The internal Rich Live instance for live display updates.
+
+    Example:
+        ```python
+        class MyTask(RichLiveMixin):
+            def run(self):
+                self.start_rich_live()
+                for i in range(100):
+                    self.rich_live_print(f"Processing item {i}")
+                    time.sleep(0.1)
+                self.stop_rich_live()
+        ```
+    """
+
     _rich_live: Live = None
 
     @property
     def rich_live(self) -> Live:
+        """
+        Get the Rich Live instance, creating it if necessary.
+
+        Returns:
+            Live: The Rich Live instance for dynamic console output.
+        """
         if self._rich_live is None:
             self._rich_live = Live()
         return self._rich_live
 
     def start_rich_live(self):
+        """
+        Start the Rich Live display context.
+
+        This method enters the Rich Live context, enabling dynamic console output.
+        Must be paired with stop_rich_live() to properly clean up resources.
+
+        Returns:
+            The Rich Live instance in its started state.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            # Display dynamic content
+            self.rich_live_print("Dynamic output")
+            self.stop_rich_live()
+            ```
+        """
         return self.rich_live.__enter__()
 
     def stop_rich_live(self):
+        """
+        Stop the Rich Live display context and clean up resources.
+
+        This method exits the Rich Live context and resets the internal Live instance.
+        Should be called after start_rich_live() when dynamic display is complete.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            # ... display content ...
+            self.stop_rich_live()
+            ```
+        """
         self.rich_live.__exit__(None, None, None)
         self._rich_live = None
 
     def rich_live_print(self, msg):
+        """
+        Print a message to the Rich Live console.
+
+        This method displays the given message through the Rich Live console,
+        allowing for formatted, dynamic output that updates in place.
+
+        Args:
+            msg: The message to display. Can be a string or any Rich renderable object.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            self.rich_live_print("[bold green]Success![/bold green]")
+            self.rich_live_print(Panel("Status: Running"))
+            self.stop_rich_live()
+            ```
+        """
         self.rich_live.console.print(msg)
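The mixin's start/stop methods drive `rich.live.Live`'s context-manager protocol directly (`__enter__`/`__exit__`), and `rich_live_print` writes through the live console. A minimal sketch of the equivalent raw Rich calls:

```python
import time

from rich.live import Live

with Live() as live:  # what start_rich_live()/stop_rich_live() wrap
    for i in range(3):
        live.console.print(f"processing item {i}")  # what rich_live_print() delegates to
        time.sleep(0.1)
```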
fusion_bench/modelpool/base_pool.py
@@ -1,6 +1,6 @@
 import logging
 from copy import deepcopy
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
@@ -8,7 +8,12 @@ from torch import nn
 from torch.utils.data import Dataset
 
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
-from fusion_bench.utils import instantiate, timeit_context
+from fusion_bench.utils import (
+    ValidationError,
+    instantiate,
+    timeit_context,
+    validate_model_name,
+)
 
 __all__ = ["BaseModelPool"]
 
@@ -59,6 +64,16 @@ class BaseModelPool(
         except UnsupportedValueType:
             pass
 
+        if not models:
+            log.warning("Initialized BaseModelPool with empty models dictionary.")
+        else:
+            # Validate model names
+            for model_name in models.keys():
+                try:
+                    validate_model_name(model_name, allow_special=True)
+                except ValidationError as e:
+                    log.warning(f"Invalid model name '{model_name}': {e}")
+
         self._models = models
         self._train_datasets = train_datasets
         self._val_datasets = val_datasets
@@ -147,7 +162,9 @@ class BaseModelPool(
         """
         return model_name.startswith("_") and model_name.endswith("_")
 
-    def get_model_config(self, model_name: str, return_copy: bool = True) -> DictConfig:
+    def get_model_config(
+        self, model_name: str, return_copy: bool = True
+    ) -> Union[DictConfig, str, Any]:
         """
         Get the configuration for the specified model.
 
@@ -155,10 +172,36 @@ class BaseModelPool(
            model_name (str): The name of the model.
 
         Returns:
-            DictConfig: The configuration for the specified model.
+            Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, string path, or other type.
+
+        Raises:
+            ValidationError: If model_name is invalid.
+            KeyError: If model_name is not found in the pool.
         """
+        # Validate model name
+        validate_model_name(model_name, allow_special=True)
+
+        # raise friendly error if model not found in the pool
+        if model_name not in self._models:
+            available_models = list(self._models.keys())
+            raise KeyError(
+                f"Model '{model_name}' not found in model pool. "
+                f"Available models: {available_models}"
+            )
+
         model_config = self._models[model_name]
+        if isinstance(model_config, nn.Module):
+            log.warning(
+                f"Model configuration for '{model_name}' is a pre-instantiated model. "
+                "Returning the model instance instead of configuration."
+            )
+
         if return_copy:
+            if isinstance(model_config, nn.Module):
+                # raise performance warning
+                log.warning(
+                    f"Furthermore, returning a copy of the pre-instantiated model '{model_name}' may be inefficient."
+                )
            model_config = deepcopy(model_config)
         return model_config
 
@@ -171,12 +214,28 @@ class BaseModelPool(
 
         Returns:
             str: The path for the specified model.
+
+        Raises:
+            ValidationError: If model_name is invalid.
+            KeyError: If model_name is not found in the pool.
+            ValueError: If model configuration is not a string path.
         """
+        # Validate model name
+        validate_model_name(model_name, allow_special=True)
+
+        if model_name not in self._models:
+            available_models = list(self._models.keys())
+            raise KeyError(
+                f"Model '{model_name}' not found in model pool. "
+                f"Available models: {available_models}"
+            )
+
         if isinstance(self._models[model_name], str):
             return self._models[model_name]
         else:
             raise ValueError(
-                "Model path is not a string. Try to override this method in derived modelpool class."
+                f"Model configuration for '{model_name}' is not a string path. "
+                "Try to override this method in derived modelpool class."
             )
 
     def load_model(
@@ -357,3 +416,25 @@ class BaseModelPool(
         """
         with timeit_context(f"Saving the state dict of model to {path}"):
             torch.save(model.state_dict(), path)
+
+    def __contains__(self, model_name: str) -> bool:
+        """
+        Check if a model with the given name exists in the model pool.
+
+        Examples:
+            >>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
+            >>> "modelA" in modelpool
+            True
+            >>> "modelC" in modelpool
+            False
+
+        Args:
+            model_name (str): The name of the model to check.
+
+        Returns:
+            bool: True if the model exists, False otherwise.
+        """
+        if self._models is None:
+            raise RuntimeError("Model pool is not initialized")
+        validate_model_name(model_name, allow_special=True)
+        return model_name in self._models
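Taken together, the pool now validates names on the way in, reports missing models along with the list of available ones, and supports `in` membership tests. A hedged usage sketch (import path inferred from the file location; the model path value is illustrative):

```python
from fusion_bench.modelpool.base_pool import BaseModelPool

pool = BaseModelPool(models={"_pretrained_": "openai/clip-vit-base-patch32"})

print("_pretrained_" in pool)               # True, via the new __contains__
print(pool.get_model_path("_pretrained_"))  # the plain string path

try:
    pool.get_model_config("missing_model")
except KeyError as err:
    print(err)  # "... not found in model pool. Available models: ['_pretrained_']"
```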
fusion_bench/models/masks/mask_model.py
@@ -113,21 +113,27 @@ class MaskModel(ParameterDictModel):
     def get_distribution(
         self,
         mask_type: Literal["discrete", "continuous"],
+        temperature: float = 0.5,
         **kwargs,
     ):
         return {
-            name: self._param_to_distribution(param, mask_type=mask_type, **kwargs)
+            name: self._param_to_distribution(
+                param, mask_type=mask_type, temperature=temperature, **kwargs
+            )
             for name, param in self.named_parameters()
         }
 
     def sample_mask(
         self,
         mask_type: Literal["discrete", "continuous"] = "discrete",
+        temperature: float = 0.5,
         **kwargs,
     ):
         mask = {}
         for name, param in self.named_parameters():
-            dist = self._param_to_distribution(param, mask_type, **kwargs)
+            dist = self._param_to_distribution(
+                param, mask_type, temperature=temperature, **kwargs
+            )
             if mask_type == "discrete":
                 mask[name] = dist.sample()
             elif mask_type == "continuous":
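The new `temperature` argument is threaded through to `_param_to_distribution`, which this diff does not show. A temperature of this kind conventionally parameterizes a relaxed (Gumbel/concrete) Bernoulli for the continuous mask case; the following is an illustrative standalone sketch of that convention, not fusion-bench's implementation:

```python
import torch
from torch.distributions import Bernoulli, RelaxedBernoulli

logits = torch.zeros(5)  # stand-in for one mask parameter
dist = RelaxedBernoulli(temperature=torch.tensor(0.5), logits=logits)

soft_mask = dist.rsample()                # values in (0, 1); closer to 0/1 as temperature -> 0
hard_mask = Bernoulli(logits=logits).sample()  # the discrete analogue
```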
fusion_bench/models/open_clip/modeling.py
@@ -1,3 +1,10 @@
+from fusion_bench.utils.packages import is_open_clip_available
+
+if not is_open_clip_available():
+    raise ImportError(
+        "open_clip is not installed. Please install it with `pip install open_clip_torch`."
+    )
+
 from typing import Callable, List
 
 import open_clip
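This guard fails fast with an actionable message instead of a bare `ModuleNotFoundError` deeper in the import chain. The helper itself is not shown in this diff; a common way to implement such a check uses `importlib` (a sketch, not fusion-bench's code):

```python
import importlib.util

def is_open_clip_available() -> bool:
    # True if the open_clip package can be found without importing it
    return importlib.util.find_spec("open_clip") is not None
```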
fusion_bench/models/wrappers/layer_wise_fusion.py
@@ -173,6 +173,24 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
 
     @property
     def forward_model(self):
+        """
+        Get a functional model with merged parameters.
+
+        Returns a partial function that applies the pretrained model with the current
+        merged state dictionary. This allows for efficient forward passes without
+        modifying the original model's parameters.
+
+        Returns:
+            Callable: A partial function that can be called with (args, kwargs) to
+                perform forward pass with merged parameters.
+
+        Example:
+            ```python
+            # Internal usage during forward pass
+            forward_fn = merged_model.forward_model
+            output = forward_fn(args=(x,), kwargs={})
+            ```
+        """
         return functools.partial(
             functional_call,
             self.pretrained_model,
@@ -181,10 +199,30 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
             strict=self.strict,
         )
 
-    def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
+    def merge_and_unload(
+        self,
+        task_vector_mask: Optional[Dict[str, Tensor]] = None,
+        copy: bool = False,
+    ) -> TorchModelType:
+        """
+        Merge models and return the final merged model.
+
+        Args:
+            task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
+                for selective parameter merging. Defaults to None.
+            copy (bool, optional): Whether to return a deep copy of the pretrained model.
+                Defaults to False. If True, the original pretrained model remains unchanged.
+
+        Returns:
+            TorchModelType: The pretrained model with merged parameters loaded.
+        """
         self.merge_weights(task_vector_mask=task_vector_mask)
-        self.pretrained_model.load_state_dict(self._merged_state_dict)
-        return self.pretrained_model
+        if copy:
+            model = deepcopy(self.pretrained_model)
+        else:
+            model = self.pretrained_model
+        model.load_state_dict(self._merged_state_dict)
+        return model
 
     def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
         """
fusion_bench/models/wrappers/task_wise_fusion.py
@@ -16,6 +16,7 @@ outputs = merged_model(inputs)
 
 import functools
 import logging
+from copy import deepcopy
 from typing import Any, Callable, Dict, Generic, Iterator, List, Optional  # noqa: F401
 
 import torch
@@ -327,7 +328,11 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
         self._merged_state_dict = state_dict
         return state_dict
 
-    def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
+    def merge_and_unload(
+        self,
+        task_vector_mask: Optional[Dict[str, Tensor]] = None,
+        copy: bool = False,
+    ) -> TorchModelType:
         """
         Merge models and return the final merged model.
 
@@ -338,6 +343,8 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
         Args:
             task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
                 for selective parameter merging. Defaults to None.
+            copy (bool, optional): Whether to return a deep copy of the pretrained model.
+                Defaults to False. If True, the original pretrained model remains unchanged.
 
         Returns:
             TorchModelType: The pretrained model with merged parameters loaded.
@@ -363,8 +370,12 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
             The original pretrained model parameters will be lost.
         """
         self.merge_weights(task_vector_mask=task_vector_mask)
-        self.pretrained_model.load_state_dict(self._merged_state_dict)
-        return self.pretrained_model
+        if copy:
+            model = deepcopy(self.pretrained_model)
+        else:
+            model = self.pretrained_model
+        model.load_state_dict(self._merged_state_dict)
+        return model
 
     def forward(self, *args, **kwargs):
         """
fusion_bench/scripts/cli.py
@@ -69,6 +69,20 @@ def main(cfg: DictConfig) -> None:
     """
     OmegaConf.resolve(cfg)
     program: BaseHydraProgram = instantiate(cfg)
+
+    # Validate that instantiation succeeded and returned an object with 'run' method
+    if not hasattr(program, "run") or not callable(getattr(program, "run")):
+        err_msg = (
+            f"Expected an object with a callable 'run' method, but got {type(program).__name__}. "
+            "Ensure that the configuration specifies a concrete program class with '_target_'."
+        )
+        if "_target_" not in cfg:
+            err_msg += "\nThe '_target_' field is missing from the root configuration."
+        else:
+            err_msg += f"\nFound '_target_': {cfg._target_}"
+        err_msg += f"\n\nConfiguration content:\n{cfg}"
+        raise TypeError(err_msg)
+
     program.run()
 
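This guard catches the common misconfiguration where the root config lacks `_target_`, in which case `instantiate` hands back a config-like object with no `run` method. A minimal reproduction of the check's logic with a stand-in object (not Hydra itself; the config contents are illustrative):

```python
from omegaconf import OmegaConf

program = OmegaConf.create({"method": "simple_average"})  # no program class here

ok = hasattr(program, "run") and callable(getattr(program, "run", None))
print(ok)  # False: a plain config is not a runnable program
```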