fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/dataset/__init__.py +2 -0
  4. fusion_bench/dataset/clip_dataset.py +4 -72
  5. fusion_bench/dataset/image_dataset.py +44 -18
  6. fusion_bench/method/base_algorithm.py +4 -0
  7. fusion_bench/method/dop/dop.py +0 -22
  8. fusion_bench/method/dop/dop_general.py +489 -0
  9. fusion_bench/method/dop/utils.py +24 -4
  10. fusion_bench/method/emr_merging/__init__.py +1 -0
  11. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  12. fusion_bench/method/emr_merging/utils.py +162 -0
  13. fusion_bench/method/opcm/opcm.py +6 -2
  14. fusion_bench/method/opcm/opcm_general.py +356 -0
  15. fusion_bench/method/opcm/utils.py +1 -4
  16. fusion_bench/method/simple_average.py +52 -18
  17. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  18. fusion_bench/mixins/lightning_fabric.py +108 -3
  19. fusion_bench/mixins/serialization.py +1 -1
  20. fusion_bench/modelpool/base_pool.py +37 -1
  21. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  22. fusion_bench/models/hf_clip.py +20 -0
  23. fusion_bench/models/modulator/__init__.py +1 -0
  24. fusion_bench/models/modulator/base.py +123 -0
  25. fusion_bench/models/parameter_dict.py +119 -29
  26. fusion_bench/models/utils.py +190 -2
  27. fusion_bench/models/wrappers/switch.py +90 -0
  28. fusion_bench/programs/base_program.py +6 -0
  29. fusion_bench/programs/fabric_fusion_program.py +4 -0
  30. fusion_bench/scripts/cli.py +19 -8
  31. fusion_bench/taskpool/image_classification.py +270 -0
  32. fusion_bench/utils/__init__.py +18 -1
  33. fusion_bench/utils/data.py +1 -1
  34. fusion_bench/utils/dict.py +19 -0
  35. fusion_bench/utils/dtype.py +19 -0
  36. fusion_bench/utils/misc.py +1 -0
  37. fusion_bench/utils/packages.py +4 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  39. fusion_bench/utils/tensorboard.py +21 -3
  40. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  41. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
  42. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  43. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  44. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  45. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  46. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  47. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  48. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  49. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  50. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  51. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
--- a/fusion_bench/method/simple_average.py
+++ b/fusion_bench/method/simple_average.py
@@ -3,11 +3,16 @@ from copy import deepcopy
 from typing import Dict, List, Mapping, Optional, Union
 
 import torch
-from torch import nn
+from torch import Tensor, nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.utils import (
+    get_target_state_dict,
+    load_state_dict_into_target_modules,
+    validate_target_modules_equal,
+)
 from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -21,21 +26,22 @@ log = logging.getLogger(__name__)
 
 
 def simple_average(
-    modules: List[Union[nn.Module, StateDictType]],
-    base_module: Optional[nn.Module] = None,
+    modules: List[Union[nn.Module, StateDictType, Tensor]],
+    base_module: Optional[Union[nn.Module, StateDictType, Tensor]] = None,
 ):
     R"""
     Averages the parameters of a list of PyTorch modules or state dictionaries.
 
     This function takes a list of PyTorch modules or state dictionaries and returns a new module with the averaged parameters, or a new state dictionary with the averaged parameters.
 
+    If `_fusion_bench_target_modules` attribute is set on the modules, only the parameters of the specified target submodules will be averaged.
+
     Args:
-        modules (List[Union[nn.Module, StateDictType]]): A list of PyTorch modules or state dictionaries.
-        base_module (Optional[nn.Module]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module. If not provided, a new module will be created by copying the first module in the list.
+        modules (List[Union[nn.Module, StateDictType, Tensor]]): A list of PyTorch modules or state dictionaries.
+        base_module (Optional[Union[nn.Module, StateDictType, Tensor]]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module. If not provided, a new module will be created by copying the first module in the list.
 
     Returns:
-        module_or_state_dict (Union[nn.Module, StateDictType]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
-
+        module_or_state_dict (Union[nn.Module, StateDictType, Tensor]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
     Examples:
         >>> import torch.nn as nn
         >>> model1 = nn.Linear(10, 10)
@@ -47,23 +53,42 @@ def simple_average(
         >>> averaged_state_dict = simple_average([state_dict1, state_dict2])
     """
     assert len(modules) > 0, "modules must be a non-empty list"
+    validate_target_modules_equal(modules)
+
     if isinstance(modules[0], nn.Module):
         if base_module is None:
             new_module = deepcopy(modules[0])
         else:
             new_module = base_module
-        state_dict = state_dict_avg([module.state_dict() for module in modules])
-        new_module.load_state_dict(state_dict)
+        state_dict = state_dict_avg(
+            [get_target_state_dict(module) for module in modules]
+        )
+        load_state_dict_into_target_modules(new_module, state_dict)
         return new_module
     elif isinstance(modules[0], Mapping):
-        return state_dict_avg(modules)
+        # if the modules are state dicts
+        # compute the average state dict
+        avg_state_dict = state_dict_avg(modules)
+        # load into base_module if provided
+        if base_module is not None:
+            for k in avg_state_dict:
+                base_module[k] = avg_state_dict[k]
+            return base_module
+        else:
+            return avg_state_dict
+    elif isinstance(modules[0], Tensor):
+        mean_tensor = torch.stack(modules, dim=0).mean(dim=0)
+        if base_module is not None:
+            base_module.data = mean_tensor
+            return base_module
+        else:
+            return mean_tensor
+    else:
+        raise ValueError(f"Unsupported type: {type(modules[0])}")
 
 
 @auto_register_config
-class SimpleAverageAlgorithm(
-    SimpleProfilerMixin,
-    BaseAlgorithm,
-):
+class SimpleAverageAlgorithm(SimpleProfilerMixin, BaseAlgorithm):
     def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
         """
         Args:
@@ -87,13 +112,20 @@ class SimpleAverageAlgorithm(
         Returns:
             The fused model obtained by simple averaging.
         """
-        if isinstance(modelpool, dict):
+        if not isinstance(modelpool, BaseModelPool):
             modelpool = BaseModelPool(modelpool)
 
         log.info(
            f"Fusing models using simple average on {len(modelpool.model_names)} models. "
            f"models: {modelpool.model_names}"
        )
+        if modelpool.has_instance_models and self.inplace:
+            log.warning(
+                "The model pool contains instance models, and inplace is set to True. "
+                "Therefore, the weights of the first model will be overwritten. "
+                "If this is desired behavior, this warning can be ignored."
+            )
+
        sd: Optional[StateDictType] = None
        forward_model = None
        merged_model_names = []
@@ -106,12 +138,12 @@ class SimpleAverageAlgorithm(
             with self.profile("merge weights"):
                 if sd is None:
                     # Initialize the state dictionary with the first model's state dictionary
-                    sd = model.state_dict()
+                    sd = get_target_state_dict(model)
                     forward_model = model if self.inplace else deepcopy(model)
                 else:
                     # Add the current model's state dictionary to the accumulated state dictionary
                     sd = state_dict_add(
-                        sd, model.state_dict(), show_pbar=self.show_pbar
+                        sd, get_target_state_dict(model), show_pbar=self.show_pbar
                     )
         with self.profile("merge weights"):
             # Divide the accumulated state dictionary by the number of models to get the average
@@ -124,11 +156,13 @@ class SimpleAverageAlgorithm(
             forward_model = deepcopy(forward_model.meta_module).to_empty(
                 device=forward_model._device
             )
-        result = forward_model.load_state_dict(sd, strict=False)
+
+        result = load_state_dict_into_target_modules(forward_model, sd, strict=False)
         if result.unexpected_keys:
             raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
         if result.missing_keys:
             log.warning(f"Missing keys in state dict: {result.missing_keys}")
+
         # print profile report and log the merged models
         self.print_profile_summary()
         log.info(f"merged {len(merged_model_names)} models:")
--- a/fusion_bench/method/task_arithmetic/task_arithmetic.py
+++ b/fusion_bench/method/task_arithmetic/task_arithmetic.py
@@ -50,7 +50,7 @@ def task_arithmetic_merge(
         finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated.
         scaling_factor (float): A factor by which the task vectors will be scaled before merging.
         inplace (bool, optional): If True, the pre-trained model will be modified in place.
-        If False, a copy of the pre-trained model will be modified. Defaults to True.
+            If False, a copy of the pre-trained model will be modified. Defaults to True.
 
     Returns:
         nn.Module: The pre-trained model with the merged task vectors.
--- a/fusion_bench/mixins/lightning_fabric.py
+++ b/fusion_bench/mixins/lightning_fabric.py
@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+import sys
 from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar
 
 import lightning as L
@@ -12,17 +13,32 @@ from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.constants import RuntimeConstants
 from fusion_bench.utils import import_object
+from fusion_bench.utils.hydra_utils import get_hydra_output_dir
 from fusion_bench.utils.instantiate_utils import instantiate
 
 if TYPE_CHECKING:
     import lightning.fabric.loggers.tensorboard
     from lightning.fabric.strategies import FSDPStrategy
+    from lightning.pytorch.loggers import MLFlowLogger
+    from mlflow.tracking.client import MlflowClient
 
 log = logging.getLogger(__name__)
 
 TensorOrModule = TypeVar("TensorOrModule", torch.Tensor, torch.nn.Module, Any)
 
 
+def _fabric_has_logger(fabric: L.Fabric) -> bool:
+    """
+    Check if the fabric has a logger.
+
+    Args:
+        fabric (L.Fabric): The Lightning Fabric instance.
+    Returns:
+        bool: True if the fabric has a logger, False otherwise.
+    """
+    return fabric._loggers is not None and len(fabric._loggers) > 0
+
+
 def get_policy(*args: str) -> set:
     """
     Get the policy from the provided list of policy names.
@@ -43,6 +59,21 @@ def get_size_based_auto_wrap_policy(*args, **kwargs):
     return policy
 
 
+def _is_mlflow_logger(fabric: L.Fabric) -> bool:
+    """
+    Check if the fabric's logger is an instance of MLFlowLogger.
+
+    Args:
+        fabric (L.Fabric): The Lightning Fabric instance.
+
+    Returns:
+        bool: True if the logger is an instance of MLFlowLogger, False otherwise.
+    """
+    if not _fabric_has_logger(fabric):
+        return False
+    return fabric.logger.__class__.__name__ == "MLFlowLogger"
+
+
 class LightningFabricMixin:
     """
     A mixin class for integrating Lightning Fabric into a project.
@@ -79,8 +110,8 @@ class LightningFabricMixin:
         """
         if self._fabric_instance is None:
             if config.get("fabric", None) is None:
-                log.warning("No fabric configuration found. use default settings.")
-                self._fabric_instance = L.Fabric()
+                log.warning("No fabric configuration found. use default settings. By default, use 1 device.")
+                self._fabric_instance = L.Fabric(devices=1)
             else:
                 self._fabric_instance = instantiate(config.fabric)
             if not _is_using_cli():  # if not using cli, launch the fabric
@@ -123,7 +154,10 @@ class LightningFabricMixin:
         Retrieves the log directory from the fabric's logger.
         """
         if self.fabric is not None and len(self.fabric._loggers) > 0:
-            log_dir = self.fabric.logger.log_dir
+            if hasattr(self.fabric.logger, "log_dir"):
+                log_dir = self.fabric.logger.log_dir
+            else:
+                log_dir = None
 
             # Special handling for SwanLabLogger to get the correct log directory
             if (
@@ -132,6 +166,20 @@
             ):
                 log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir
 
+            if (
+                log_dir is None
+                and self.fabric.logger.__class__.__name__ == "MLFlowLogger"
+            ):
+                log_dir = self.fabric.logger.save_dir
+                if log_dir is None:
+                    try:
+                        log_dir = self._program.config.path.log_dir
+                    except Exception:
+                        log.error(
+                            "Failed to get log_dir from program config for MLFlowLogger."
+                        )
+                        log_dir = "outputs"
+
             assert log_dir is not None, "log_dir should not be None"
             if self.fabric.is_global_zero and not os.path.exists(log_dir):
                 os.makedirs(log_dir, exist_ok=True)
@@ -246,3 +294,60 @@
         """
         for i, param_group in enumerate(optimizer.param_groups):
             self.fabric.log(name_template.format(i), param_group["lr"], step=step)
+
+    def log_artifact(self, local_path: str, artifact_path: str | None = None):
+        """
+        Logs a file as an artifact to the fabric's logger.
+
+        Args:
+            local_dir: The path to the directory to log as an artifact.
+            artifact_path: The directory within the logger's artifact storage to save the file.
+        """
+        if _is_mlflow_logger(self.fabric):
+            logger: "MLFlowLogger" = self.fabric.logger
+            experiment: "MlflowClient" = logger.experiment
+            experiment.log_artifact(
+                logger.run_id,
+                local_path=local_path,
+                artifact_path=(artifact_path),
+            )
+
+    def log_artifacts(self, local_dir: str, artifact_path: str | None = None):
+        """
+        Logs a directory as artifacts to the fabric's logger.
+
+        Args:
+            local_dir: The path to the directory to log as artifacts.
+            artifact_path: The directory within the logger's artifact storage to save the files.
+        """
+        if _is_mlflow_logger(self.fabric):
+            logger: "MLFlowLogger" = self.fabric.logger
+            experiment: "MlflowClient" = logger.experiment
+            experiment.log_artifacts(
+                logger.run_id,
+                local_dir=local_dir,
+                artifact_path=artifact_path,
+            )
+
+    def finalize(self):
+        """
+        Destructor to ensure proper cleanup of the Lightning Fabric instance.
+        """
+        if self._fabric_instance is None:
+            return
+
+        if _fabric_has_logger(self.fabric) and _is_mlflow_logger(self.fabric):
+            if sys.exc_info()[0] is None:
+                status = "success"
+            else:
+                status = "failed"
+            self.fabric.logger.finalize(status)
+
+        del self._fabric_instance
+        self._fabric_instance = None
+
+    def __del__(self):
+        """
+        Destructor to ensure proper cleanup of the Lightning Fabric instance.
+        """
+        self.finalize()
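A hedged sketch of how the new artifact helpers might be called from a program class that uses the mixin. Only `log_artifact`, `log_artifacts`, and `finalize` come from the hunks above; `DemoProgram`, its `run` body, the output directory, and the presence of a configured MLflow fabric logger are assumptions for illustration.

```python
import json
import os

from fusion_bench.mixins.lightning_fabric import LightningFabricMixin  # module shown in this diff


class DemoProgram(LightningFabricMixin):
    """Illustrative program; assumes a fabric config with an MLFlowLogger is provided elsewhere."""

    def run(self):
        # ... merging / evaluation work happens here ...
        out_dir = "outputs"  # stand-in for the mixin's resolved log directory
        os.makedirs(out_dir, exist_ok=True)
        report_path = os.path.join(out_dir, "report.json")
        with open(report_path, "w") as f:
            json.dump({"status": "ok"}, f)

        # Both calls are no-ops unless the active fabric logger is an MLFlowLogger.
        self.log_artifact(report_path, artifact_path="reports")
        self.log_artifacts(out_dir, artifact_path="run_outputs")

        # Marks the MLflow run as "success" (or "failed" if an exception is pending)
        # and drops the cached fabric instance.
        self.finalize()
```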
--- a/fusion_bench/mixins/serialization.py
+++ b/fusion_bench/mixins/serialization.py
@@ -68,7 +68,7 @@ def auto_register_config(cls):
 
     Behavior:
     - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
-    from the __init__ method are automatically added to _config_mapping
+      from the __init__ method are automatically added to _config_mapping
     - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
     - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
     - **Default Values**: Applied when parameters are not provided via arguments
--- a/fusion_bench/modelpool/base_pool.py
+++ b/fusion_bench/modelpool/base_pool.py
@@ -7,11 +7,12 @@ from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
-from fusion_bench import TorchModelType
+from fusion_bench import StateDictType, TorchModelType
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import (
     ValidationError,
     instantiate,
+    state_dict_sub,
     timeit_context,
     validate_model_name,
 )
@@ -57,6 +58,10 @@ class BaseModelPool(
         **kwargs,
     ):
         if isinstance(models, List):
+            log.debug(
+                "Initializing BaseModelPool with a list of models. "
+                "Converting to a dictionary with integer string keys."
+            )
             models = {str(model_idx): model for model_idx, model in enumerate(models)}
 
         if isinstance(models, dict):
@@ -81,6 +86,22 @@ class BaseModelPool(
         self._test_datasets = test_datasets
         super().__init__(**kwargs)
 
+    @property
+    def has_instance_models(self) -> bool:
+        """
+        Check if the model pool contains any pre-instantiated models.
+
+        Attention:
+            Some algorithms may modify the models in-place if they are pre-instantiated.
+
+        Returns:
+            bool: True if there are pre-instantiated models, False otherwise.
+        """
+        for model_cfg in self._models.values():
+            if isinstance(model_cfg, nn.Module):
+                return True
+        return False
+
     @property
     def has_pretrained(self) -> bool:
         """
@@ -329,6 +350,21 @@ class BaseModelPool(
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)
 
+    def load_pretrained_model_and_task_vectors(
+        self,
+    ) -> Tuple[TorchModelType, List[StateDictType]]:
+        pretrained_model = self.load_pretrained_model()
+
+        task_vectors = []
+        for model_name in self.model_names:
+            finetuned_model = self.load_model(model_name)
+            task_vector = state_dict_sub(
+                finetuned_model.state_dict(), pretrained_model.state_dict()
+            )
+            task_vectors.append(task_vector)
+
+        return pretrained_model, task_vectors
+
     @property
     def has_train_dataset(self) -> bool:
         """
--- a/fusion_bench/modelpool/convnext_for_image_classification.py
+++ b/fusion_bench/modelpool/convnext_for_image_classification.py
@@ -98,7 +98,7 @@ class ConvNextForImageClassificationPool(BaseModelPool):
     - Load ConvNeXt models either from a pretrained checkpoint or from config.
     - Optionally adapt the classifier head to match dataset classnames.
     - Override `forward` to return logits for consistent interfaces within
-    FusionBench.
+      FusionBench.
 
     See `fusion_bench.modelpool.resnet_for_image_classification` for a closely
     related ResNet-based pool with analogous behavior.
@@ -161,6 +161,9 @@
             ).logits
         model.original_forward = original_forward
 
+        # Mark ConvNeXt layers for FusionBench fusion
+        model._fusion_bench_target_modules = ["convnext"]
+
         return model
 
     @override
@@ -180,7 +183,7 @@
         - The ConvNeXt model via `model.save_pretrained`.
         - The paired image processor via `AutoImageProcessor.save_pretrained`.
         - If `algorithm_config` is provided and on rank-zero, a README model card
-        documenting the FusionBench configuration.
+          documenting the FusionBench configuration.
         """
         model.save_pretrained(path)
         self.load_processor().save_pretrained(path)
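The `_fusion_bench_target_modules = ["convnext"]` marking is what lets the target-module-aware `simple_average` above skip the task-specific classifier heads. A hedged sketch with a toy stand-in for the ConvNeXt classifier (the `ToyClassifier` class is invented for illustration):

```python
import torch.nn as nn

from fusion_bench.method.simple_average import simple_average  # module shown in this diff


class ToyClassifier(nn.Module):
    """Stand-in for ConvNextForImageClassification: a shared backbone plus a task-specific head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.convnext = nn.Linear(8, 8)          # plays the role of the ConvNeXt backbone
        self.classifier = nn.Linear(8, num_classes)


model_a, model_b = ToyClassifier(num_classes=10), ToyClassifier(num_classes=100)
for model in (model_a, model_b):
    # Same attribute the pool sets on real ConvNeXt models in this release.
    model._fusion_bench_target_modules = ["convnext"]

# Only `convnext.*` parameters are averaged; the classifier heads (10 vs. 100 classes)
# are never touched, so their shape mismatch does not matter.
merged = simple_average([model_a, model_b])
```

Based on the `simple_average` implementation above, the merged model is a deep copy of `model_a` whose backbone holds the element-wise mean of both backbones, while its classifier head is left as-is.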
--- a/fusion_bench/models/hf_clip.py
+++ b/fusion_bench/models/hf_clip.py
@@ -62,16 +62,36 @@ class HFCLIPClassifier(nn.Module):
             persistent=False,
         )
 
+    # NOTE:
+    # The property setters seems not to work properly with `nn.Module` attributes.
+    # So avoid using them in practice.
+    # To set the text or vision model, directly access the attributes.
+    # For example:
+    #     classifier.clip_model.text_model = new_text_model
+    # or
+    #     classifier.clip_model.vision_model = new_vision_model
+    # reference: https://github.com/pytorch/pytorch/issues/52664
+
     @property
     def text_model(self):
         """Get the text model component of CLIP."""
         return self.clip_model.text_model
 
+    @text_model.setter
+    def text_model(self, model: nn.Module):
+        """Set the text model component of CLIP."""
+        self.clip_model.text_model = model
+
     @property
     def vision_model(self):
         """Get the vision model component of CLIP."""
         return self.clip_model.vision_model
 
+    @vision_model.setter
+    def vision_model(self, model: nn.Module):
+        """Set the vision model component of CLIP."""
+        self.clip_model.vision_model = model
+
     def set_classification_task(
         self,
         classnames: List[str],
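The NOTE above refers to a long-standing PyTorch behavior (pytorch/pytorch#52664): `nn.Module.__setattr__` registers an assigned `nn.Module` value as a child module directly, so a property setter defined on the class is never invoked. A minimal, self-contained reproduction in plain PyTorch, with no fusion_bench code involved:

```python
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Module()
        self.inner.core = nn.Linear(4, 4)

    @property
    def core(self) -> nn.Module:
        return self.inner.core

    @core.setter
    def core(self, module: nn.Module):
        # Intended to forward the assignment to the wrapped module...
        self.inner.core = module


w = Wrapper()
w.core = nn.Linear(4, 8)  # nn.Module.__setattr__ stores this under w._modules["core"] instead

print(dict(w.named_children()).keys())  # contains "core": the setter was bypassed
print(w.core)                           # still Linear(4, 4), returned via the property getter
```

Hence the recommendation in the NOTE to assign through `classifier.clip_model.text_model` / `classifier.clip_model.vision_model` directly.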
--- /dev/null
+++ b/fusion_bench/models/modulator/__init__.py
@@ -0,0 +1 @@
+from .base import ModulatedModel, TaskModulator
--- /dev/null
+++ b/fusion_bench/models/modulator/base.py
@@ -0,0 +1,123 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, Optional
+
+import torch
+from torch import nn
+
+from fusion_bench import TorchModelType
+
+log = logging.getLogger(__name__)
+
+
+class ModulatedModel(nn.Module, Generic[TorchModelType]):
+    """
+    A model wrapper that uses task-specific modulators to adapt a shared backbone
+    for different tasks.
+
+    The model maintains a shared backbone and task-specific modulators. During forward pass,
+    the appropriate modulator is applied based on the current task.
+    """
+
+    _current_task: Optional[str] = None
+
+    def __init__(
+        self,
+        backbone: TorchModelType,
+        modulators: Dict[str, "TaskModulator[TorchModelType]"],
+    ):
+        super().__init__()
+        self.backbone = backbone
+        self.modulators = nn.ModuleDict(modulators)
+
+    def add_modulator(self, task_name: str, modulator: "TaskModulator[TorchModelType]"):
+        """Add a new task-specific modulator."""
+        if task_name in self.modulators:
+            raise ValueError(f"Modulator for task '{task_name}' already exists.")
+        self.modulators[task_name] = modulator
+
+    def remove_modulator(self, task_name: str):
+        """Remove an existing task-specific modulator."""
+        if task_name not in self.modulators:
+            raise ValueError(f"Modulator for task '{task_name}' does not exist.")
+        if self._current_task == task_name:
+            log.warning(
+                f"Removing modulator for current task '{task_name}'. "
+                "This will make unset the current task unpredictable."
+            )
+        del self.modulators[task_name]
+
+    def set_task(self, task_name: str):
+        """Set the current task for inference."""
+        if task_name not in self.modulators:
+            raise ValueError(
+                f"Task '{task_name}' not found in modulators. Available tasks: {list(self.modulators.keys())}"
+            )
+        if self._current_task == task_name:
+            return
+
+        # unset previous task
+        if self._current_task is not None:
+            self.modulators[self._current_task].remove(self)
+            assert (
+                self._current_task is None
+            ), "Current task should be None after removal."
+
+        # set new task
+        self.modulators[task_name].apply(self)
+        self._current_task = task_name
+
+    @property
+    def current_task(self) -> Optional[str]:
+        """Get the current task name."""
+        return self._current_task
+
+    def forward(self, *args, **kwargs) -> Any:
+        """
+        Forward pass with task-specific modulation.
+
+        Args:
+            *args: Positional arguments for the backbone model
+            **kwargs: Keyword arguments for the backbone model
+
+        Returns:
+            Model output after applying task-specific modulation
+        """
+        if self._current_task is None:
+            raise ValueError(
+                "No task specified. Set current_task or provide 'task' argument."
+            )
+
+        return self.backbone(*args, **kwargs)
+
+
+class TaskModulator(nn.Module, Generic[TorchModelType], ABC):
+    """
+    Lightweight, task-specific parameterization that modulates
+    a shared representation.
+
+    This is the base class for all task modulators. Subclasses should implement
+    the `apply` method to define how the modulator adapts the backbone model
+    for a specific task.
+    """
+
+    @abstractmethod
+    def apply(self, modulated_model: "ModulatedModel[TorchModelType]"):
+        """
+        Apply task-specific modulation to the backbone model.
+
+        Args:
+            modulated_model: The modulated model
+        """
+        raise NotImplementedError("Subclasses must implement the apply method.")
+
+    @abstractmethod
+    def remove(self, modulated_model: "ModulatedModel[TorchModelType]"):
+        """
+        Remove task-specific modulation from the backbone model.
+        This is called when switching tasks.
+
+        Args:
+            modulated_model: The modulated model
+        """
+        raise NotImplementedError("Subclasses must implement the remove method.")
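To make the new base classes concrete, here is a hedged sketch of a minimal modulator. `BiasOffsetModulator` is invented for illustration, and clearing `_current_task` inside `remove` is an inferred contract, matching the assertion `ModulatedModel.set_task` makes after removing the previous task.

```python
import torch
from torch import nn

from fusion_bench.models.modulator import ModulatedModel, TaskModulator  # exports per this diff


class BiasOffsetModulator(TaskModulator):
    """Illustrative modulator: shifts the backbone's output bias by a task-specific offset."""

    def __init__(self, dim: int):
        super().__init__()
        self.offset = nn.Parameter(torch.zeros(dim))

    def apply(self, modulated_model: ModulatedModel):
        # Remember the shared bias, then shift it for this task.
        self._saved_bias = modulated_model.backbone.bias.data.clone()
        modulated_model.backbone.bias.data += self.offset.data

    def remove(self, modulated_model: ModulatedModel):
        # Restore the shared backbone and clear the current task,
        # as asserted by ModulatedModel.set_task after removal.
        modulated_model.backbone.bias.data.copy_(self._saved_bias)
        modulated_model._current_task = None


backbone = nn.Linear(16, 16)
model = ModulatedModel(
    backbone,
    modulators={
        "task_a": BiasOffsetModulator(16),
        "task_b": BiasOffsetModulator(16),
    },
)
model.set_task("task_a")
out = model(torch.randn(2, 16))   # forward runs the shared backbone with task_a's modulation
model.set_task("task_b")          # task_a's offset is removed before task_b's is applied
```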