fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +18 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/ensemble.py +17 -2
  8. fusion_bench/method/linear/__init__.py +6 -2
  9. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  10. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  11. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/simple_average.py +2 -2
  15. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  16. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  17. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  18. fusion_bench/method/wudi/__init__.py +1 -0
  19. fusion_bench/method/wudi/wudi.py +105 -0
  20. fusion_bench/mixins/__init__.py +2 -0
  21. fusion_bench/mixins/lightning_fabric.py +4 -0
  22. fusion_bench/mixins/pyinstrument.py +174 -0
  23. fusion_bench/mixins/serialization.py +25 -78
  24. fusion_bench/mixins/simple_profiler.py +106 -23
  25. fusion_bench/modelpool/__init__.py +2 -0
  26. fusion_bench/modelpool/base_pool.py +77 -14
  27. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  28. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  29. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  30. fusion_bench/models/__init__.py +35 -9
  31. fusion_bench/models/hf_clip.py +4 -0
  32. fusion_bench/models/hf_utils.py +2 -1
  33. fusion_bench/models/model_card_templates/default.md +8 -1
  34. fusion_bench/models/wrappers/ensemble.py +136 -7
  35. fusion_bench/optim/__init__.py +40 -2
  36. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  37. fusion_bench/optim/muon.py +339 -0
  38. fusion_bench/programs/__init__.py +2 -0
  39. fusion_bench/programs/fabric_fusion_program.py +2 -2
  40. fusion_bench/programs/fusion_program.py +271 -0
  41. fusion_bench/scripts/cli.py +2 -2
  42. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  43. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  44. fusion_bench/utils/__init__.py +167 -21
  45. fusion_bench/utils/devices.py +30 -8
  46. fusion_bench/utils/lazy_imports.py +91 -12
  47. fusion_bench/utils/lazy_state_dict.py +58 -5
  48. fusion_bench/utils/misc.py +104 -13
  49. fusion_bench/utils/packages.py +4 -0
  50. fusion_bench/utils/path.py +7 -0
  51. fusion_bench/utils/pylogger.py +6 -0
  52. fusion_bench/utils/rich_utils.py +8 -3
  53. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  54. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
  55. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
  56. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  57. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  58. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  59. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  60. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  61. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  62. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  63. fusion_bench_config/model_fusion.yaml +45 -0
  64. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  65. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  66. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  73. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  74. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  75. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/resnet_for_image_classification.py
@@ -0,0 +1,208 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Literal,
+    Optional,
+    TypeVar,
+    Union,
+    override,
+)
+
+import torch
+from omegaconf import DictConfig
+from torch import nn
+
+from fusion_bench import BaseModelPool, auto_register_config, get_rankzero_logger
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+
+if TYPE_CHECKING:
+    from torchvision.models import ResNet as TorchVisionResNet
+
+log = get_rankzero_logger(__name__)
+
+
+def load_torchvision_resnet(
+    model_name: str, weights: Optional[str], num_classes: Optional[int]
+) -> "TorchVisionResNet":
+    import torchvision.models
+
+    model_fn = getattr(torchvision.models, model_name)
+    model: "TorchVisionResNet" = model_fn(weights=weights)
+
+    if num_classes is not None:
+        model.fc = nn.Linear(model.fc.in_features, num_classes)
+
+    return model
+
+
+def load_transformers_resnet(
+    config_path: str, pretrained: bool, dataset_name: Optional[str]
+):
+    from transformers import AutoConfig, ResNetForImageClassification
+
+    if pretrained:
+        model = ResNetForImageClassification.from_pretrained(config_path)
+    else:
+        config = AutoConfig.from_pretrained(config_path)
+        model = ResNetForImageClassification(config)
+
+    if dataset_name is None:
+        return model
+
+    classnames = get_classnames(dataset_name)
+    id2label = {i: c for i, c in enumerate(classnames)}
+    label2id = {c: i for i, c in enumerate(classnames)}
+    model.config.id2label = id2label
+    model.config.label2id = label2id
+
+    model.classifier[1] = (
+        nn.Linear(
+            model.classifier[1].in_features,
+            len(classnames),
+        )
+        if model.config.num_labels > 0
+        else nn.Identity()
+    )
+    return model
+
+
+@auto_register_config
+class ResNetForImageClassificationPool(BaseModelPool):
+    def __init__(self, type: str, **kwargs):
+        super().__init__(**kwargs)
+        assert type in ["torchvision", "transformers"]
+
+    def load_processor(
+        self, stage: Literal["train", "val", "test"] = "test", *args, **kwargs
+    ):
+        if self.type == "torchvision":
+            from torchvision import transforms
+
+            to_tensor = transforms.ToTensor()
+            normalize = transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+            )
+            if stage == "train":
+                train_transform = transforms.Compose(
+                    [
+                        transforms.RandomResizedCrop(224),
+                        transforms.RandomHorizontalFlip(),
+                        to_tensor,
+                        normalize,
+                    ]
+                )
+                return train_transform
+            else:
+                val_transform = transforms.Compose(
+                    [
+                        transforms.Resize(256),
+                        transforms.CenterCrop(224),
+                        to_tensor,
+                        normalize,
+                    ]
+                )
+                return val_transform
+
+        elif self.type == "transformers":
+            from transformers import AutoImageProcessor
+
+            if self.has_pretrained:
+                config_path = self._models["_pretrained_"].config_path
+            else:
+                for model_cfg in self._models.values():
+                    if isinstance(model_cfg, str):
+                        config_path = model_cfg
+                        break
+                    if "config_path" in model_cfg:
+                        config_path = model_cfg["config_path"]
+                        break
+            return AutoImageProcessor.from_pretrained(config_path)
+
+    @override
+    def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
+        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model_name_or_config = self._models[model_name_or_config]
+
+        if self.type == "torchvision":
+            from torchvision.models import (
+                resnet18,
+                resnet34,
+                resnet50,
+                resnet101,
+                resnet152,
+            )
+
+            match model_name_or_config:
+                case "resnet18":
+                    model = resnet18()
+                case "resnet34":
+                    model = resnet34()
+                case "resnet50":
+                    model = resnet50()
+                case "resnet101":
+                    model = resnet101()
+                case "resnet152":
+                    model = resnet152()
+                case dict() | DictConfig() as model_config:
+                    if "dataset_name" in model_config:
+                        num_classes = get_num_classes(model_config["dataset_name"])
+                        if "num_classes" in model_config:
+                            assert (
+                                num_classes == model_config["num_classes"]
+                            ), f"num_classes mismatch: {num_classes} vs {model_config['num_classes']}"
+                    elif "num_classes" in model_config:
+                        num_classes = model_config["num_classes"]
+                    else:
+                        num_classes = None
+                    model = load_torchvision_resnet(
+                        model_name=model_config["model_name"],
+                        weights=model_config.get("weights", None),
+                        num_classes=num_classes,
+                    )
+                case _:
+                    raise ValueError(
+                        f"Invalid model_name_or_config type: {type(model_name_or_config)}"
+                    )
+        elif self.type == "transformers":
+            match model_name_or_config:
+                case str() as model_path:
+                    from transformers import AutoModelForImageClassification
+
+                    model = AutoModelForImageClassification.from_pretrained(model_path)
+                case dict() | DictConfig() as model_config:
+
+                    model = load_transformers_resnet(
+                        config_path=model_config["config_path"],
+                        pretrained=model_config.get("pretrained", False),
+                        dataset_name=model_config.get("dataset_name", None),
+                    )
+                case _:
+                    raise ValueError(
+                        f"Invalid model_name_or_config type: {type(model_name_or_config)}"
+                    )
+
+            # override forward to return logits only
+            original_forward = model.forward
+            model.forward = lambda pixel_values, **kwargs: original_forward(
+                pixel_values=pixel_values, **kwargs
+            ).logits
+            model.original_forward = original_forward
+        else:
+            raise ValueError(f"Unknown model type: {self.type}")
+        return model
+
+    @override
+    def save_model(self, model, path, *args, **kwargs):
+        if self.type == "torchvision":
+            torch.save(model.state_dict(), path)
+        elif self.type == "transformers":
+            model.save_pretrained(path)
+            self.load_processor().save_pretrained(path)
+        else:
+            raise ValueError(f"Unknown model type: {self.type}")
fusion_bench/models/__init__.py
@@ -1,10 +1,36 @@
 # flake8: noqa F401
-from fusion_bench.utils import LazyStateDict
-
-from . import separate_io, utils
-from .hf_utils import (
-    create_default_model_card,
-    load_model_card_template,
-    save_pretrained_with_remote_code,
-)
-from .parameter_dict import ParameterDictModel
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+from . import utils
+
+_extra_objects = {
+    "utils": utils,
+}
+_import_structure = {
+    "hf_utils": [
+        "create_default_model_card",
+        "load_model_card_template",
+        "save_pretrained_with_remote_code",
+    ],
+    "parameter_dict": ["ParameterDictModel"],
+    "separate_io": ["separate_load", "separate_save"],
+}
+
+if TYPE_CHECKING:
+    from .hf_utils import (
+        create_default_model_card,
+        load_model_card_template,
+        save_pretrained_with_remote_code,
+    )
+    from .parameter_dict import ParameterDictModel
+    from .separate_io import separate_load, separate_save
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
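For reference, this idiom replaces the package's module object in sys.modules so that each submodule is imported only on first attribute access, while TYPE_CHECKING keeps static analyzers happy. A generic sketch of the pattern (not fusion-bench's actual LazyImporter) looks like this:

import importlib
import types


class _LazyModule(types.ModuleType):
    """Illustrative lazy module: defers submodule imports to attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it.
        self._attr_to_module = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module(
            f"{self.__name__}.{self._attr_to_module[attr]}"
        )
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so later lookups bypass __getattr__
        return value

# Usage inside a package __init__.py:
#   sys.modules[__name__] = _LazyModule(__name__, {"mezo": ["MeZO"]})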
fusion_bench/models/hf_clip.py
@@ -195,5 +195,9 @@ class HFCLIPClassifier(nn.Module):
            pass
        elif isinstance(image_embeds, BaseModelOutputWithPooling):
            image_embeds = image_embeds[1]
+        elif isinstance(image_embeds, dict) and "pooler_output" in image_embeds:
+            image_embeds = image_embeds["pooler_output"]
+        else:
+            raise ValueError("Unsupported output type from vision model outputs")
        image_embeds = self.clip_model.visual_projection(image_embeds)
        return image_embeds
fusion_bench/models/hf_utils.py
@@ -143,7 +143,7 @@ def save_pretrained_with_remote_code(

 def create_default_model_card(
     models: list[str],
-    *,
+    base_model: Optional[str] = None,
     title: str = "Deep Model Fusion",
     tags: list[str] = ["fusion-bench", "merge"],
     description=None,
@@ -154,6 +154,7 @@ def create_default_model_card(

     template: Template = Template(load_model_card_template("default.md"))
     card = template.render(
+        base_model=base_model,
         models=models,
         library_name="transformers",
         title=title,
fusion_bench/models/model_card_templates/default.md
@@ -1,5 +1,8 @@
 ---
 base_model:
+{%- if base_model is not none %}
+- {{ base_model }}
+{%- endif %}
 {%- for model in models %}
 - {{ model }}
 {%- endfor %}
@@ -18,7 +21,11 @@ tags:
 This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).

 The following models were included in the merge:
-{% for model in models %}
+
+{% if base_model is not none %}
+- base model: {{ base_model }}
+{%- endif %}
+{%- for model in models %}
 - {{ model }}
 {%- endfor %}

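A quick standalone check of the template logic above with jinja2 (the model names are illustrative):

from jinja2 import Template

tpl = Template(
    "base_model:\n"
    "{%- if base_model is not none %}\n- {{ base_model }}\n{%- endif %}\n"
    "{%- for model in models %}\n- {{ model }}\n{%- endfor %}"
)
print(tpl.render(base_model="Qwen/Qwen2.5-1.5B", models=["model-a", "model-b"]))
# base_model:
# - Qwen/Qwen2.5-1.5B
# - model-a
# - model-b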
fusion_bench/models/wrappers/ensemble.py
@@ -1,10 +1,17 @@
-from typing import Any, Callable, Dict, List, Union, cast
+import logging
+from typing import Any, Callable, Dict, Generic, List, Union, cast

 import numpy as np
 import torch
+import torch.futures
 from omegaconf import ListConfig
 from torch import Tensor, nn

+from fusion_bench.utils.devices import to_device
+from fusion_bench.utils.type import TorchModelType
+
+log = logging.getLogger(__name__)
+

 def aggregate_tensors(
     outputs: List[Any], aggregate_fn: Callable
@@ -58,12 +65,16 @@ def aggregate_tensors(
         raise ValueError("Unsupported type for outputs")


-class EnsembleModule(nn.Module):
+class EnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that averages the outputs of multiple models.
     """

-    def __init__(self, models: List[nn.Module]):
+    def __init__(
+        self,
+        models: List[TorchModelType],
+        device_map: Dict[int, Union[int, str]] | None = None,
+    ):
@@ -73,6 +84,16 @@ class EnsembleModule(nn.Module):
         super().__init__()
         # TODO: distribute models to devices
         self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )

     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
@@ -86,6 +107,49 @@ class EnsembleModule(nn.Module):
         """
         return torch.stack(outputs).mean(dim=0)

+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [
+            to_device(output, target_device, non_blocking=True) for output in outputs
+        ]
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by averaging the outputs of the models.
@@ -97,20 +161,25 @@ class EnsembleModule(nn.Module):
         Returns:
             Aggregated output from the ensemble of models.
         """
-        outputs = [model(*args, **kwargs) for model in self.model_list]
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)


-class WeightedEnsembleModule(nn.Module):
+class WeightedEnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that computes a weighted average of the outputs from multiple models.
     """

     def __init__(
         self,
-        models: List[nn.Module],
+        models: List[TorchModelType],
         weights: List[float] | Tensor | np.ndarray,
         normalize: bool = True,
+        device_map: Dict[int, Union[int, str]] | None = None,
     ):
         """
         Initializes the WeightedEnsembleModule with models and their corresponding weights.
@@ -119,9 +188,12 @@ class WeightedEnsembleModule(nn.Module):
             models (List[nn.Module]): List of models to ensemble.
             weights (List[float] | Tensor | np.ndarray): Weights for each model.
             normalize (bool, optional): If True, normalizes the weights. Defaults to True.
+            device_map (Dict[int, Union[int, str]] | None, optional): Device mapping for parallel execution. Defaults to None.
         """
         super().__init__()
         self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+
         if isinstance(weights, (list, tuple, ListConfig)):
             weights = torch.tensor(weights)
         elif isinstance(weights, Tensor):
@@ -139,6 +211,17 @@ class WeightedEnsembleModule(nn.Module):
             weights = weights / weights.sum()
         self.register_buffer("weights", weights)

+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        """Move models to their assigned devices according to device_map."""
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
+
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
         Aggregates a list of tensors using the provided weights.
@@ -152,6 +235,48 @@ class WeightedEnsembleModule(nn.Module):
         weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
         return (torch.stack(outputs) * weights).sum(dim=0)

+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [to_device(output, target_device) for output in outputs]
+
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by computing the weighted average of the models' outputs.
@@ -163,7 +288,11 @@ class WeightedEnsembleModule(nn.Module):
         Returns:
             Weighted aggregated output from the ensemble of models.
         """
-        outputs = [model(*args, **kwargs) for model in self.model_list]
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)


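A hedged usage sketch of the new device_map option (assumes two CUDA devices are available; with device_map unset, forward falls back to the old sequential loop):

import torch
from torch import nn

from fusion_bench.models.wrappers.ensemble import EnsembleModule

# Two small models pinned to different GPUs via the new device_map argument.
models = [nn.Linear(16, 4) for _ in range(2)]
ensemble = EnsembleModule(models, device_map={0: "cuda:0", 1: "cuda:1"})

x = torch.randn(8, 16)
y = ensemble(x)  # forwards run via torch.jit.fork; outputs are gathered on
                 # the device of model 0 and averaged
print(y.shape)   # torch.Size([8, 4])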
fusion_bench/optim/__init__.py
@@ -1,2 +1,40 @@
-from . import exception, lr_scheduler
-from .mezo import MeZO
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+from . import lr_scheduler
+
+_extra_objects = {
+    "lr_scheduler": lr_scheduler,
+}
+_import_structure = {
+    "exception": [
+        "NoClosureError",
+        "NoSparseGradientError",
+        "NegativeLRError",
+        "NegativeStepError",
+        "ZeroParameterSizeError",
+    ],
+    "mezo": ["MeZO"],
+    "muon": ["Muon"],
+}
+
+if TYPE_CHECKING:
+    from .exception import (
+        NegativeLRError,
+        NegativeStepError,
+        NoClosureError,
+        NoSparseGradientError,
+        ZeroParameterSizeError,
+    )
+    from .mezo import MeZO
+    from .muon import Muon
+
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
fusion_bench/optim/lr_scheduler/__init__.py
@@ -1 +1,27 @@
-from .linear_warmup import *
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+_import_structure = {
+    "linear_warmup": [
+        "BaseLinearWarmupScheduler",
+        "LinearWarmupScheduler",
+        "CosineDecayWithWarmup",
+        "PolySchedulerWithWarmup",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .linear_warmup import (
+        BaseLinearWarmupScheduler,
+        CosineDecayWithWarmup,
+        LinearWarmupScheduler,
+        PolySchedulerWithWarmup,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )
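With both lazy __init__ rewrites above, downstream imports are unchanged; the submodules are simply resolved on first access (sketch; assumes the package is installed):

# Importing a name triggers the import of its defining submodule, e.g.
# accessing Muon pulls in the new fusion_bench.optim.muon module.
from fusion_bench.optim import MeZO, Muon
from fusion_bench.optim.lr_scheduler import CosineDecayWithWarmup

optimizer_cls = Muon  # fusion_bench.optim.muon has been imported at this point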