fusion-bench 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fusion_bench/__init__.py CHANGED
@@ -41,6 +41,8 @@ _import_structure = {
         "CausalLMBackbonePool",
         "CausalLMPool",
         "CLIPVisionModelPool",
+        "ConvNextForImageClassificationPool",
+        "Dinov2ForImageClassificationPool",
         "GPT2ForSequenceClassificationPool",
         "HuggingFaceGPT2ClassificationPool",
         "NYUv2ModelPool",
@@ -107,6 +109,8 @@ if TYPE_CHECKING:
         CausalLMBackbonePool,
         CausalLMPool,
         CLIPVisionModelPool,
+        ConvNextForImageClassificationPool,
+        Dinov2ForImageClassificationPool,
         GPT2ForSequenceClassificationPool,
         HuggingFaceGPT2ClassificationPool,
         NYUv2ModelPool,
fusion_bench/method/classification/image_classification_finetune.py CHANGED
@@ -34,6 +34,13 @@ from torch.utils.data import random_split
 log = get_rankzero_logger(__name__)
 
 
+def _get_base_model_name(model) -> Optional[str]:
+    if hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
+        return model.config._name_or_path
+    else:
+        return None
+
+
 @auto_register_config
 class ImageClassificationFineTuning(BaseAlgorithm):
     """Fine-tuning algorithm for image classification models.
@@ -107,6 +114,8 @@ class ImageClassificationFineTuning(BaseAlgorithm):
         """
         # load model and dataset
         model = modelpool.load_pretrained_or_first_model()
+        base_model_name = _get_base_model_name(model)
+
         assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
 
         assert (
@@ -135,6 +144,8 @@ class ImageClassificationFineTuning(BaseAlgorithm):
                 val_dataset, processor=modelpool.load_processor(stage="val")
             )
             val_loader = self.get_dataloader(val_dataset, stage="val")
+        else:
+            val_loader = None
 
         # configure optimizer
         optimizer = instantiate(self.optimizer, params=model.parameters())
@@ -210,6 +221,7 @@ class ImageClassificationFineTuning(BaseAlgorithm):
             ),
             algorithm_config=self.config,
             description=f"Fine-tuned ResNet model on dataset {dataset_name}.",
+            base_model=base_model_name,
         )
         return model
 
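For context: Hugging Face models record the checkpoint they were loaded from in `config._name_or_path`, which is what the new `_get_base_model_name` helper reads so that `save_model` can tag the exported model card with its base model. A minimal illustration (the checkpoint name here is only an example):

from transformers import AutoModelForImageClassification

# `_name_or_path` is populated by `from_pretrained`, so a fine-tuned model
# can report which base checkpoint it started from.
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
print(model.config._name_or_path)  # -> "microsoft/resnet-50"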
fusion_bench/method/task_arithmetic/task_arithmetic.py CHANGED
@@ -149,7 +149,10 @@ class TaskArithmeticAlgorithm(
             )
         with self.profile("merge weights"):
             # scale the task vector
-            task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
+            # keep the dtype when all elements of a value are zero, to avoid dtype mismatch
+            task_vector = state_dict_mul(
+                task_vector, self.config.scaling_factor, keep_dtype_when_zero=True
+            )
             # add the task vector to the pretrained model
             state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
 
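The `keep_dtype_when_zero` flag itself is not shown in this diff; a plausible reading of the comment is sketched below. A task vector can contain all-zero integer tensors (e.g., buffers that cancel out when two checkpoints are subtracted), and multiplying them by a float `scaling_factor` would promote them to float, so they could no longer be added back onto the integer tensors of the pretrained state dict. A hypothetical reimplementation of the idea, not fusion_bench's code:

import torch

def state_dict_mul_sketch(state_dict, scalar, keep_dtype_when_zero=False):
    # Hypothetical sketch: scale every tensor, but leave all-zero non-float
    # tensors untouched so their dtype survives a float scaling factor.
    out = {}
    for key, value in state_dict.items():
        if (
            keep_dtype_when_zero
            and not torch.is_floating_point(value)
            and not torch.any(value)
        ):
            out[key] = value.clone()
        else:
            out[key] = value * scalar
    return out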
fusion_bench/modelpool/__init__.py CHANGED
@@ -8,6 +8,14 @@ _import_structure = {
     "base_pool": ["BaseModelPool"],
     "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
     "clip_vision": ["CLIPVisionModelPool"],
+    "convnext_for_image_classification": [
+        "ConvNextForImageClassificationPool",
+        "load_transformers_convnext",
+    ],
+    "dinov2_for_image_classification": [
+        "Dinov2ForImageClassificationPool",
+        "load_transformers_dinov2",
+    ],
     "nyuv2_modelpool": ["NYUv2ModelPool"],
     "huggingface_automodel": ["AutoModelPool"],
     "seq2seq_lm": ["Seq2SeqLMPool"],
@@ -18,7 +26,10 @@ _import_structure = {
         "GPT2ForSequenceClassificationPool",
     ],
     "seq_classification_lm": ["SequenceClassificationModelPool"],
-    "resnet_for_image_classification": ["ResNetForImageClassificationPool"],
+    "resnet_for_image_classification": [
+        "ResNetForImageClassificationPool",
+        "load_transformers_resnet",
+    ],
 }
 
 
@@ -26,6 +37,14 @@ if TYPE_CHECKING:
     from .base_pool import BaseModelPool
     from .causal_lm import CausalLMBackbonePool, CausalLMPool
     from .clip_vision import CLIPVisionModelPool
+    from .convnext_for_image_classification import (
+        ConvNextForImageClassificationPool,
+        load_transformers_convnext,
+    )
+    from .dinov2_for_image_classification import (
+        Dinov2ForImageClassificationPool,
+        load_transformers_dinov2,
+    )
     from .huggingface_automodel import AutoModelPool
     from .huggingface_gpt2_classification import (
         GPT2ForSequenceClassificationPool,
@@ -34,7 +53,10 @@ if TYPE_CHECKING:
     from .nyuv2_modelpool import NYUv2ModelPool
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
-    from .resnet_for_image_classification import ResNetForImageClassificationPool
+    from .resnet_for_image_classification import (
+        ResNetForImageClassificationPool,
+        load_transformers_resnet,
+    )
     from .seq2seq_lm import Seq2SeqLMPool
     from .seq_classification_lm import SequenceClassificationModelPool
 
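For readers unfamiliar with the `_import_structure`/`TYPE_CHECKING` pattern touched above: it implements lazy imports, so heavy submodules are only loaded when one of their names is first accessed. A simplified sketch of the idiom (illustrative, not fusion_bench's exact wiring):

import importlib
from types import ModuleType

class _LazyModule(ModuleType):
    # Attributes listed in _import_structure are imported from their
    # submodule only on first access.
    def __init__(self, name, import_structure):
        super().__init__(name)
        self._attr_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._attr_to_module:
            raise AttributeError(attr)
        module = importlib.import_module(
            f".{self._attr_to_module[attr]}", self.__name__
        )
        return getattr(module, attr)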
fusion_bench/modelpool/base_pool.py CHANGED
@@ -3,7 +3,7 @@ from copy import deepcopy
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
@@ -52,6 +52,13 @@ class BaseModelPool(
     ):
         if isinstance(models, List):
             models = {str(model_idx): model for model_idx, model in enumerate(models)}
+
+        if isinstance(models, dict):
+            try:  # try to convert to DictConfig
+                models = OmegaConf.create(models)
+            except UnsupportedValueType:
+                pass
+
         self._models = models
         self._train_datasets = train_datasets
         self._val_datasets = val_datasets
fusion_bench/modelpool/convnext_for_image_classification.py ADDED
@@ -0,0 +1,198 @@
+"""
+Hugging Face ConvNeXt image classification model pool.
+
+This module provides a `BaseModelPool` implementation that loads and saves
+ConvNeXt models for image classification via `transformers`. It optionally
+reconfigures the classification head to match a dataset's class names and
+overrides `forward` to return logits only for simpler downstream usage.
+
+See also: `fusion_bench.modelpool.resnet_for_image_classification` for a
+parallel implementation for ResNet-based classifiers.
+"""
+
+import os
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Literal,
+    Optional,
+    TypeVar,
+    Union,
+    override,
+)
+
+import torch
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import DictConfig
+from torch import nn
+
+from fusion_bench import BaseModelPool, auto_register_config, get_rankzero_logger
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+
+log = get_rankzero_logger(__name__)
+
+
+def load_transformers_convnext(
+    config_path: str, pretrained: bool, dataset_name: Optional[str]
+):
+    """Create a ConvNeXt image classification model from a config or checkpoint.
+
+    Args:
+        config_path: A model identifier or local path understood by
+            `transformers.AutoConfig/AutoModel` (e.g., "facebook/convnext-base-224").
+        pretrained: If True, load weights via `from_pretrained`; otherwise, build
+            the model from config only.
+        dataset_name: Optional dataset key used by FusionBench to derive class
+            names via `get_classnames`. When provided, the model's id/label maps
+            are updated and the classifier head is resized accordingly.
+
+    Returns:
+        ConvNextForImageClassification: A `transformers.ConvNextForImageClassification`
+            instance. If `dataset_name` is set, the classifier head is adapted to
+            the number of classes. The model's `config.id2label` and
+            `config.label2id` are also populated.
+
+    Notes:
+        The overall structure mirrors the ResNet implementation in
+        `fusion_bench.modelpool.resnet_for_image_classification`.
+    """
+    from transformers import AutoConfig, ConvNextForImageClassification
+
+    if pretrained:
+        model = ConvNextForImageClassification.from_pretrained(config_path)
+    else:
+        config = AutoConfig.from_pretrained(config_path)
+        model = ConvNextForImageClassification(config)
+
+    if dataset_name is None:
+        return model
+
+    classnames = get_classnames(dataset_name)
+    id2label = {i: c for i, c in enumerate(classnames)}
+    label2id = {c: i for i, c in enumerate(classnames)}
+    model.config.id2label = id2label
+    model.config.label2id = label2id
+    model.num_labels = model.config.num_labels
+
+    model.classifier = (
+        nn.Linear(
+            model.classifier.in_features,
+            len(classnames),
+            device=model.classifier.weight.device,
+            dtype=model.classifier.weight.dtype,
+        )
+        if model.config.num_labels > 0
+        else nn.Identity()
+    )
+    return model
+
+
+@auto_register_config
+class ConvNextForImageClassificationPool(BaseModelPool):
+    """Model pool for ConvNeXt image classification models (HF Transformers).
+
+    Responsibilities:
+    - Load an `AutoImageProcessor` compatible with the configured ConvNeXt model.
+    - Load ConvNeXt models either from a pretrained checkpoint or from config.
+    - Optionally adapt the classifier head to match dataset classnames.
+    - Override `forward` to return logits for consistent interfaces within
+      FusionBench.
+
+    See `fusion_bench.modelpool.resnet_for_image_classification` for a closely
+    related ResNet-based pool with analogous behavior.
+    """
+
+    def load_processor(self, *args, **kwargs):
+        from transformers import AutoImageProcessor
+
+        if self.has_pretrained:
+            config_path = self._models["_pretrained_"].config_path
+        else:
+            for model_cfg in self._models.values():
+                if isinstance(model_cfg, str):
+                    config_path = model_cfg
+                    break
+                if "config_path" in model_cfg:
+                    config_path = model_cfg["config_path"]
+                    break
+        return AutoImageProcessor.from_pretrained(config_path)
+
+    @override
+    def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
+        """Load a ConvNeXt model described by a name, path, or DictConfig.
+
+        Accepts either a string (pretrained identifier or local path) or a
+        config mapping with keys: `config_path`, optional `pretrained` (bool),
+        and optional `dataset_name` to resize the classifier.
+
+        Returns:
+            A model whose `forward` is wrapped to return only logits to align
+            with FusionBench expectations.
+        """
+        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model_name_or_config = self._models[model_name_or_config]
+
+        match model_name_or_config:
+            case str() as model_path:
+                from transformers import AutoModelForImageClassification
+
+                model = AutoModelForImageClassification.from_pretrained(model_path)
+            case dict() | DictConfig() as model_config:
+                model = load_transformers_convnext(
+                    model_config["config_path"],
+                    pretrained=model_config.get("pretrained", True),
+                    dataset_name=model_config.get("dataset_name", None),
+                )
+            case _:
+                raise ValueError(
+                    f"Unsupported model_name_or_config type: {type(model_name_or_config)}"
+                )
+
+        # override forward to return logits only
+        original_forward = model.forward
+        model.forward = lambda pixel_values, **kwargs: original_forward(
+            pixel_values=pixel_values, **kwargs
+        ).logits
+        model.original_forward = original_forward
+
+        return model
+
+    @override
+    def save_model(
+        self,
+        model,
+        path,
+        algorithm_config: Optional[DictConfig] = None,
+        description: Optional[str] = None,
+        base_model: Optional[str] = None,
+        *args,
+        **kwargs,
+    ):
+        """Save the model, processor, and an optional model card to disk.
+
+        Artifacts written to `path`:
+        - The ConvNeXt model via `model.save_pretrained`.
+        - The paired image processor via `AutoImageProcessor.save_pretrained`.
+        - If `algorithm_config` is provided and on rank zero, a README model card
+          documenting the FusionBench configuration.
+        """
+        model.save_pretrained(path)
+        self.load_processor().save_pretrained(path)
+
+        if algorithm_config is not None and rank_zero_only.rank == 0:
+            from fusion_bench.models.hf_utils import create_default_model_card
+
+            model_card_str = create_default_model_card(
+                algorithm_config=algorithm_config,
+                description=description,
+                modelpool_config=self.config,
+                base_model=base_model,
+            )
+            with open(os.path.join(path, "README.md"), "w") as f:
+                f.write(model_card_str)
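A minimal usage sketch of the new pool; the constructor keywords mirror the YAML config added at the bottom of this diff, but the snippet is illustrative rather than taken from the package's documentation:

from fusion_bench.modelpool import ConvNextForImageClassificationPool

pool = ConvNextForImageClassificationPool(
    models={
        "_pretrained_": {
            "config_path": "facebook/convnext-base-224",
            "pretrained": True,
            "dataset_name": None,
        }
    },
)
model = pool.load_model("_pretrained_")  # forward() now returns logits only
processor = pool.load_processor()        # paired AutoImageProcessor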
fusion_bench/modelpool/dinov2_for_image_classification.py ADDED
@@ -0,0 +1,197 @@
+"""
+Hugging Face DINOv2 image classification model pool.
+
+This module provides a `BaseModelPool` implementation that loads and saves
+DINOv2 models for image classification via `transformers`. It optionally
+reconfigures the classification head to match a dataset's class names and
+overrides `forward` to return logits only for simpler downstream usage.
+
+See also: `fusion_bench.modelpool.convnext_for_image_classification` for a
+parallel implementation for ConvNeXt-based classifiers.
+"""
+
+import os
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Literal,
+    Optional,
+    TypeVar,
+    Union,
+    override,
+)
+
+import torch
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import DictConfig
+from torch import nn
+
+from fusion_bench import BaseModelPool, auto_register_config, get_rankzero_logger
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+
+log = get_rankzero_logger(__name__)
+
+
+def load_transformers_dinov2(
+    config_path: str, pretrained: bool, dataset_name: Optional[str]
+):
+    """Create a DINOv2 image classification model from a config or checkpoint.
+
+    Args:
+        config_path: A model identifier or local path understood by
+            `transformers.AutoConfig/AutoModel` (e.g., "facebook/dinov2-base").
+        pretrained: If True, load weights via `from_pretrained`; otherwise, build
+            the model from config only.
+        dataset_name: Optional dataset key used by FusionBench to derive class
+            names via `get_classnames`. When provided, the model's id/label maps
+            are updated and the classifier head is resized accordingly.
+
+    Returns:
+        Dinov2ForImageClassification: A `transformers.Dinov2ForImageClassification`
+            instance. If `dataset_name` is set, the classifier head is adapted to
+            the number of classes. The model's `config.id2label` and
+            `config.label2id` are also populated.
+
+    Notes:
+        The overall structure mirrors the ConvNeXt implementation in
+        `fusion_bench.modelpool.convnext_for_image_classification`.
+    """
+    from transformers import AutoConfig, Dinov2ForImageClassification
+
+    if pretrained:
+        model = Dinov2ForImageClassification.from_pretrained(config_path)
+    else:
+        config = AutoConfig.from_pretrained(config_path)
+        model = Dinov2ForImageClassification(config)
+
+    if dataset_name is None:
+        return model
+
+    classnames = get_classnames(dataset_name)
+    id2label = {i: c for i, c in enumerate(classnames)}
+    label2id = {c: i for i, c in enumerate(classnames)}
+    model.config.id2label = id2label
+    model.config.label2id = label2id
+    model.num_labels = model.config.num_labels
+
+    # If the model is configured with a positive number of labels, resize the
+    # classifier to match the dataset classes; otherwise leave it as identity.
+    model.classifier = (
+        nn.Linear(
+            model.classifier.in_features,
+            len(classnames),
+            device=model.classifier.weight.device,
+            dtype=model.classifier.weight.dtype,
+        )
+        if model.config.num_labels > 0
+        else nn.Identity()
+    )
+    return model
+
+
+@auto_register_config
+class Dinov2ForImageClassificationPool(BaseModelPool):
+    """Model pool for DINOv2 image classification models (HF Transformers)."""
+
+    def load_processor(self, *args, **kwargs):
+        """Load the paired image processor for this model pool.
+
+        Uses the configured model's identifier or config path to retrieve the
+        appropriate `transformers.AutoImageProcessor` instance. If a pretrained
+        model entry exists in the pool configuration, it is preferred for
+        deriving the processor, to ensure preprocessing/normalization parity.
+        """
+        from transformers import AutoImageProcessor
+
+        if self.has_pretrained:
+            config_path = self._models["_pretrained_"].config_path
+        else:
+            for model_cfg in self._models.values():
+                if isinstance(model_cfg, str):
+                    config_path = model_cfg
+                    break
+                if "config_path" in model_cfg:
+                    config_path = model_cfg["config_path"]
+                    break
+        return AutoImageProcessor.from_pretrained(config_path)
+
+    @override
+    def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
+        """Load a DINOv2 model described by a name, path, or DictConfig.
+
+        Accepts either a string (pretrained identifier or local path) or a
+        config mapping with keys: `config_path`, optional `pretrained` (bool),
+        and optional `dataset_name` to resize the classifier.
+
+        Returns:
+            A model whose `forward` is wrapped to return only logits to align
+            with FusionBench expectations.
+        """
+        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model_name_or_config = self._models[model_name_or_config]
+
+        match model_name_or_config:
+            case str() as model_path:
+                from transformers import AutoModelForImageClassification
+
+                model = AutoModelForImageClassification.from_pretrained(model_path)
+            case dict() | DictConfig() as model_config:
+                model = load_transformers_dinov2(
+                    model_config["config_path"],
+                    pretrained=model_config.get("pretrained", True),
+                    dataset_name=model_config.get("dataset_name", None),
+                )
+            case _:
+                raise ValueError(
+                    f"Unsupported model_name_or_config type: {type(model_name_or_config)}"
+                )
+
+        # Override forward to return logits only, to unify the interface across
+        # FusionBench model pools and simplify downstream usage.
+        original_forward = model.forward
+        model.forward = lambda pixel_values, **kwargs: original_forward(
+            pixel_values=pixel_values, **kwargs
+        ).logits
+        model.original_forward = original_forward
+
+        return model
+
+    @override
+    def save_model(
+        self,
+        model,
+        path,
+        algorithm_config: Optional[DictConfig] = None,
+        description: Optional[str] = None,
+        base_model: Optional[str] = None,
+        *args,
+        **kwargs,
+    ):
+        """Save the model, processor, and an optional model card to disk.
+
+        Artifacts written to `path`:
+        - The DINOv2 model via `model.save_pretrained`.
+        - The paired image processor via `AutoImageProcessor.save_pretrained`.
+        - If `algorithm_config` is provided and on rank zero, a README model card
+          documenting the FusionBench configuration.
+        """
+        model.save_pretrained(path)
+        self.load_processor().save_pretrained(path)
+
+        if algorithm_config is not None and rank_zero_only.rank == 0:
+            from fusion_bench.models.hf_utils import create_default_model_card
+
+            model_card_str = create_default_model_card(
+                algorithm_config=algorithm_config,
+                description=description,
+                modelpool_config=self.config,
+                base_model=base_model,
+            )
+            with open(os.path.join(path, "README.md"), "w") as f:
+                f.write(model_card_str)
fusion_bench/modelpool/resnet_for_image_classification.py CHANGED
@@ -138,6 +138,7 @@ def load_transformers_resnet(
     label2id = {c: i for i, c in enumerate(classnames)}
     model.config.id2label = id2label
     model.config.label2id = label2id
+    model.num_labels = model.config.num_labels
 
     model.classifier[1] = (
         nn.Linear(
@@ -407,7 +408,7 @@ class ResNetForImageClassificationPool(BaseModelPool):
 
                 model = load_transformers_resnet(
                     config_path=model_config["config_path"],
-                    pretrained=model_config.get("pretrained", False),
+                    pretrained=model_config.get("pretrained", True),
                     dataset_name=model_config.get("dataset_name", None),
                 )
             case _:
@@ -432,6 +433,7 @@ class ResNetForImageClassificationPool(BaseModelPool):
         path,
         algorithm_config: Optional[DictConfig] = None,
         description: Optional[str] = None,
+        base_model: Optional[str] = None,
        *args,
        **kwargs,
    ):
@@ -479,6 +481,7 @@ class ResNetForImageClassificationPool(BaseModelPool):
             from fusion_bench.models.hf_utils import create_default_model_card
 
             model_card_str = create_default_model_card(
+                base_model=base_model,
                 algorithm_config=algorithm_config,
                 description=description,
                 modelpool_config=self.config,
fusion_bench/models/model_card_templates/default.md CHANGED
@@ -1,6 +1,6 @@
 ---
-base_model:
 {%- if base_model is not none %}
+base_model:
 - {{ base_model }}
 {%- endif %}
 {%- for model in models %}
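The effect of moving `base_model:` inside the conditional: a card rendered without a base model previously left a dangling empty `base_model:` key in the YAML front matter; now the key is omitted entirely. A quick Jinja2 check of the fixed fragment:

from jinja2 import Template

tmpl = Template(
    "---\n"
    "{%- if base_model is not none %}\n"
    "base_model:\n"
    "- {{ base_model }}\n"
    "{%- endif %}\n"
)
print(tmpl.render(base_model="facebook/dinov2-base"))  # key present
print(tmpl.render(base_model=None))                    # just "---", no empty key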
fusion_bench/utils/json.py CHANGED
@@ -1,31 +1,72 @@
 import json
 from pathlib import Path
-from typing import Any, Union
+from typing import TYPE_CHECKING, Any, Union
 
+if TYPE_CHECKING:
+    from pyarrow.fs import FileSystem
 
-def save_to_json(obj, path: Union[str, Path]):
+
+def save_to_json(obj, path: Union[str, Path], filesystem: "FileSystem" = None):
     """
     save an object to a json file
 
     Args:
         obj (Any): the object to save
         path (Union[str, Path]): the path to save the object
+        filesystem (FileSystem, optional): PyArrow FileSystem to use for writing.
+            If None, uses the local filesystem via standard Python open().
+            Can also be an s3fs.S3FileSystem or other fsspec filesystem.
     """
-    with open(path, "w") as f:
-        json.dump(obj, f)
+    if filesystem is not None:
+        json_str = json.dumps(obj)
+        # Check if it's an fsspec-based filesystem (like s3fs)
+        if hasattr(filesystem, "open"):
+            # Direct fsspec/s3fs usage - more reliable for some endpoints
+            path_str = str(path)
+            with filesystem.open(path_str, "w") as f:
+                f.write(json_str)
+        else:
+            # Use the PyArrow filesystem API
+            path_str = str(path)
+            with filesystem.open_output_stream(path_str) as f:
+                f.write(json_str.encode("utf-8"))
+    else:
+        # Use standard Python file operations
+        with open(path, "w") as f:
+            json.dump(obj, f)
 
 
-def load_from_json(path: Union[str, Path]) -> Union[dict, list]:
+def load_from_json(
+    path: Union[str, Path], filesystem: "FileSystem" = None
+) -> Union[dict, list]:
     """load an object from a json file
 
     Args:
         path (Union[str, Path]): the path to load the object
+        filesystem (FileSystem, optional): PyArrow FileSystem to use for reading.
+            If None, uses the local filesystem via standard Python open().
+            Can also be an s3fs.S3FileSystem or other fsspec filesystem.
 
     Returns:
-        dict: the loaded object
+        Union[dict, list]: the loaded object
     """
-    with open(path, "r") as f:
-        return json.load(f)
+    if filesystem is not None:
+        # Check if it's an fsspec-based filesystem (like s3fs)
+        if hasattr(filesystem, "open"):
+            # Direct fsspec/s3fs usage
+            path_str = str(path)
+            with filesystem.open(path_str, "r") as f:
+                return json.load(f)
+        else:
+            # Use the PyArrow filesystem API
+            path_str = str(path)
+            with filesystem.open_input_stream(path_str) as f:
+                json_data = f.read().decode("utf-8")
+                return json.loads(json_data)
+    else:
+        # Use standard Python file operations
+        with open(path, "r") as f:
+            return json.load(f)
 
 
 def _is_list_of_dict(obj) -> bool:
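A usage sketch for the new `filesystem` parameter; it assumes the optional `pyarrow` (and, for the commented branch, `s3fs`) dependencies, which this diff does not declare:

from pyarrow.fs import LocalFileSystem

from fusion_bench.utils.json import load_from_json, save_to_json

# Default: plain local I/O, unchanged behavior.
save_to_json({"accuracy": 0.91}, "/tmp/report.json")

# PyArrow filesystems expose no `open` attribute, so the
# open_output_stream / open_input_stream branch runs:
save_to_json({"accuracy": 0.91}, "/tmp/report_arrow.json", filesystem=LocalFileSystem())
assert load_from_json("/tmp/report_arrow.json", filesystem=LocalFileSystem()) == {"accuracy": 0.91}

# fsspec-style filesystems (e.g. s3fs.S3FileSystem) do expose `open`,
# so the first branch runs instead:
#   fs = s3fs.S3FileSystem()
#   save_to_json(obj, "bucket/key.json", filesystem=fs)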
fusion_bench-0.2.27.dist-info/METADATA → fusion_bench-0.2.28.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion-bench
-Version: 0.2.27
+Version: 0.2.28
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 Project-URL: Repository, https://github.com/tanganke/fusion_bench
fusion_bench-0.2.27.dist-info/RECORD → fusion_bench-0.2.28.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
-fusion_bench/__init__.py,sha256=Rw9sT2ZegKMxZAG7FBDgqVOqBGlJ-43C5p_EarRHd1M,5816
+fusion_bench/__init__.py,sha256=C-0-HgZFdRjscXqpfNsz7iGUijUeSoP4GFRnFxuxQ7M,5992
 fusion_bench/__main__.py,sha256=weUjxpP3ULnDgUxCehdbmoCM9cqfkhDhGB85tAF5qoE,81
 fusion_bench/_get_started/__init__.py,sha256=Ht6OK6Luei2kdY9jRZzRQfzBlm3Yfm64BkXxpzeRg9Q,40
 fusion_bench/_get_started/greeting_program.py,sha256=wvVsPa7Djwx5Z5spAI6F9Kvv9KwfNkjIgJVH8oXR3Bo,1233
@@ -80,7 +80,7 @@ fusion_bench/method/bitdelta/bitdelta_utils/diff.py,sha256=o3ib5sgGDYLgnL8YTfX0Y
 fusion_bench/method/classification/__init__.py,sha256=byVJ574JQ_DUvsDv8S6ZM6BKAv4ZZ964Ej4btm0aC7k,867
 fusion_bench/method/classification/clip_finetune.py,sha256=5q5Sr3eVVh8DfYdeSoGjwaKDksC8F2dY2r8Dl-wRaDg,15844
 fusion_bench/method/classification/continual_clip_finetune.py,sha256=OLhZKS-6aCnafevZkZYcNMKTWDDj3DATB27eZl_i8EY,11530
-fusion_bench/method/classification/image_classification_finetune.py,sha256=ExUwsBsDHX6Kq1G9arapgf3xQZJLBcNoRfCIXqIsbD0,14967
+fusion_bench/method/classification/image_classification_finetune.py,sha256=xWSspEuiyM9mz7nTFCLMbJMvkuD-k3B7mx-KMvq7nEU,15310
 fusion_bench/method/concrete_subspace/__init__.py,sha256=jJoFcjnQe-jvccsm9DuCXna378m9XBT9vV1fEZbdfR0,464
 fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py,sha256=UkLOkaa_Dzlb4Q5ES69Y9GV1bodTnD7DzZFreykt65s,24706
 fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py,sha256=Nx-3AiAeIt5zmcC21Ta2_-4cAQg9hOWvThurXNZzA-w,10580
@@ -229,7 +229,7 @@ fusion_bench/method/tall_mask/__init__.py,sha256=XINPP8PqGQ01he9p2RyHaKGyrcYoJuY
 fusion_bench/method/tall_mask/task_arithmetic.py,sha256=RX_JgEPwG52EPYGXWYGuq0LBeyJHMbVZn7Qy_4QmSsQ,4373
 fusion_bench/method/tall_mask/utils.py,sha256=Wlp8WcPwR_lCaBIZ9rgG6ewLfSzz3G7kPk9yj13pvls,8817
 fusion_bench/method/task_arithmetic/__init__.py,sha256=pSx_NV5Ra_6UXpyYWCi6ANQoAnEtymZt_X1dDN9wT4Y,96
-fusion_bench/method/task_arithmetic/task_arithmetic.py,sha256=KsSBshf04MUwIjoc0HAAmY6cWMqjZwZOYXbUuU4EaL0,6320
+fusion_bench/method/task_arithmetic/task_arithmetic.py,sha256=yGMWk2--VlXTcQjDjnPdiug1q_rpjzu5SFvgCYDfTQ0,6479
 fusion_bench/method/task_singular_vector/TSVC.py,sha256=yn4SrZNvtA6PoGYJmbmtNeDyDbGnRCgfZ7ZCg914AZU,410
 fusion_bench/method/task_singular_vector/TSVM.py,sha256=Sdgoi8xT0Hl19pmGdIuUS3D1DsVqSVD-Hipp-Sj_HoA,13652
 fusion_bench/method/task_singular_vector/__init__.py,sha256=WMucyl9pu_Ev2kcdrfT4moqMMbzD7hHQVFME5Su5jMA,298
@@ -280,13 +280,15 @@ fusion_bench/mixins/simple_profiler.py,sha256=QA4fZhD-uL06fZaoqBQowI0c_qrAUhWszF
 fusion_bench/mixins/optim/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 fusion_bench/mixins/optim/adamw_with_warmup.py,sha256=qTnRl8GVVIfaplOFBHnJFuZUbxPZRWRGHGNzm_EDhDE,1421
 fusion_bench/modelpool/PeftModelForSeq2SeqLM.py,sha256=rxPKTTWno3KAcTTEfydPpXx1b0EJa8PLbqrberweFF8,2108
-fusion_bench/modelpool/__init__.py,sha256=wKAkEgit_1ZtDAOKOntzrUKdCjOFIxnPMYN02B970Wg,1671
-fusion_bench/modelpool/base_pool.py,sha256=5snzTmqn1Xs_dy0Ws5QWxs9uCAXMwIuclrwfikKPh9o,12298
+fusion_bench/modelpool/__init__.py,sha256=qDlBPrWFW-Z-LByzmfqP1ozYhWx2lYAEjhqjKF4EAbY,2307
+fusion_bench/modelpool/base_pool.py,sha256=7v01al93RjJ5CynUM-HnM6mCgCX9asUmaqPNmxioNoA,12531
+fusion_bench/modelpool/convnext_for_image_classification.py,sha256=m9MxFgfzNjGnHOU6gufaTPgkk67lifNNwW03nHUxXKo,7377
+fusion_bench/modelpool/dinov2_for_image_classification.py,sha256=Wd60J5Ji4KwXUYTPcYYXuYWrcpDlh7pjGZ-zjjRqYio,7496
 fusion_bench/modelpool/huggingface_automodel.py,sha256=OJ6EyYyjNv1_Bhjn-zli-e__BJ0xVa4Fx9lhXVb-DJo,552
 fusion_bench/modelpool/huggingface_gpt2_classification.py,sha256=j8nicVwtoLXY4RPE2dcepeEB3agBKkkH-xA3yMj1czw,2014
 fusion_bench/modelpool/lazy_state_dict_pool.py,sha256=HtEA85rqSCHfsIddI5sKDcZf5kSuHNwrb8fF1TUSTr0,652
 fusion_bench/modelpool/nyuv2_modelpool.py,sha256=btuXmYxwfjI6MnGakhoOf53Iyb9fxYH20CavGTrTcnA,1375
-fusion_bench/modelpool/resnet_for_image_classification.py,sha256=1Q79oj3FIBQBOr13zCvIcscBKLA0PHbPmTarwVlhIww,19873
+fusion_bench/modelpool/resnet_for_image_classification.py,sha256=drSQt6xMZnag2drrjepCu8jpORF_ui8MJj_CipqoRCU,20004
 fusion_bench/modelpool/causal_lm/__init__.py,sha256=F432-aDIgAbUITj4GNZS9dgUKKhaDMCbTeHB-9MecaQ,99
 fusion_bench/modelpool/causal_lm/causal_lm.py,sha256=FbatPI6aAJbaT5qa4Get2I0i8fxmbq0N6xwajolXpdg,19993
 fusion_bench/modelpool/clip_vision/__init__.py,sha256=3b9gN2bWUsoA1EmpitnIMnIlX7nklxbkn4WJ0QJtS2c,43
@@ -329,7 +331,7 @@ fusion_bench/models/llama/model_utils/mod.py,sha256=xzNOgTRfOK9q8kml4Q2nmSOl23f3
 fusion_bench/models/llama/model_utils/visual.py,sha256=wpqWqEASyA7WhJLCfC26h0Cdn5CXnwC1qPJUlSXggo4,8310
 fusion_bench/models/masks/__init__.py,sha256=vXG6jrBkDbPsnrX6nMEYAW1rQuGEWDgdjID7cKzXvrs,69
 fusion_bench/models/masks/mask_model.py,sha256=YXNZ_CGp6VPshZH__Znh6Z07BqOK53G-Ltc1LVy1E3I,5502
-fusion_bench/models/model_card_templates/default.md,sha256=DJXwDODCsqIOhkgP57-iCShxLYK_jnsDsJYH1GfbBY8,1028
+fusion_bench/models/model_card_templates/default.md,sha256=OoU83l1hip1gKsoA08hoKx-nCrOYbKaVTVCjK0pt9WY,1028
 fusion_bench/models/modeling_deepseek_v2/__init__.py,sha256=trXrhtKb_gIxXVo7wSZ-il5sLJtDTiNZezRrEt3M8zM,505
 fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py,sha256=TblFOCfNwaXUnXnD-sxFhSn5Df-_yy2LMcrth-sBPFI,10301
 fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py,sha256=PtfkfPrfmQVoLiVhgqlp5toJAnCinPWfeZYeJJtWWBs,78676
@@ -464,7 +466,7 @@ fusion_bench/utils/fabric.py,sha256=NxquO_rVJyE2w4V3raMElNMr1-wT01QZWPuIfL2rgdQ,
 fusion_bench/utils/functools.py,sha256=7_tYJ2WD88_2DDuOOj5aZz3cYuslYH5tsVyIgCeLtmk,1318
 fusion_bench/utils/hydra_utils.py,sha256=TklUDKDEZlg4keI-TEZiqh4gFjr9-61Rt1RMlqkoSGk,1174
 fusion_bench/utils/instantiate_utils.py,sha256=OXkfhq_o3Sgy5n3Psf-HI-dIfbK9oD2GBdfcx3gT63Q,17526
-fusion_bench/utils/json.py,sha256=sVCqbm9mmyHybiui-O57KFt_ULrjLtN2wipSo6VDvqE,2533
+fusion_bench/utils/json.py,sha256=LXmlqdUxgBepaFjf2JoLrOHQ7CdFAcKLzHL8LaSkPog,4359
 fusion_bench/utils/lazy_imports.py,sha256=s-1ABhPyyHs7gW4aodCzu3NySzILzTL7kVNZ0DZRXJA,6156
 fusion_bench/utils/lazy_state_dict.py,sha256=mJaiAtKB1vlNUAoQILnnCmU80FGJ8MSwmdPpmdhOyDE,22206
 fusion_bench/utils/misc.py,sha256=_7BaS9dNKyySGU0qmTmE0Tk8WK82TEm7IBJxVRkuEAw,5315
@@ -486,7 +488,7 @@ fusion_bench/utils/plot/token_notebook.py,sha256=bsntXf46Zz_RavTxNiB9c3-KvHw7LFw
 fusion_bench/utils/strenum/__init__.py,sha256=id9ORi1uXrDxhbmVxitJ1KDwLS4H3AAwFpaK5h1cQzw,8531
 fusion_bench/utils/strenum/_name_mangler.py,sha256=o11M5-bURW2RBvRTYXFQIPNeqLzburdoWLIqk8X3ydw,3397
 fusion_bench/utils/strenum/_version.py,sha256=6JQRo9LcvODbCOeVFYQb9HNJ_J9XiG_Zbn8ws2A3BV8,18466
-fusion_bench-0.2.27.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
+fusion_bench-0.2.28.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
 fusion_bench_config/README.md,sha256=Lc8YSBJ5oxf9KV5kKDivJ9LRyGuraGQPmBbgbdVA-j4,703
 fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml,sha256=7IxLQoLRz-sRWyV8Vqc5kQcmYE_9YQz2_77pmvAkum8,1207
 fusion_bench_config/fabric_model_fusion.yaml,sha256=kSQbhBsKypVFA3rmkdhY9BITnZWDXJof-I35t473_U0,2646
@@ -887,6 +889,8 @@ fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml,sha256=
 fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml,sha256=SODG0kcnAP6yC0_J_SpSVMRV-v5qGV22gcWdiBaZo1I,368
 fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml,sha256=zwInWJS8yrhch4vOL1ypRKNWWpJKlhQsyY0Ln14CC-M,389
 fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml,sha256=ufmu4b3lyxn2XLDMVYxP-bKwYaGTjB5-JoYXLG8v8tY,368
+fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml,sha256=gcXV5WIYe9Ep-54fjgT9HqbCBY7UiqbqkHvoNCQx62Y,259
+fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml,sha256=jxe6rvV37FBGsV-Pdnyxe-G-Vw-HzOXuT2NMHKBSBOU,270
 fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md,sha256=DC0HF-isCHshipHTC0Rof6GvjTUa0i2DVQZKrklQQlU,2416
 fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml,sha256=jbJqqciORJQknpSzh2zKiFm6VKDOsmaSk9XfPCVmHGg,1220
 fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml,sha256=q2_E2R1wIOdxd-AF-wjXkPO64gJgD27YXsZ8FFLWUIo,1607
@@ -1011,8 +1015,8 @@ fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml,sha256=3q-KMuFaM
 fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml,sha256=GjpiiRownrBCpl-TNwWRW2PYePbF-Cl99jlLNPrK5T4,1017
 fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml,sha256=WwiYMQKehtJixDPnu5o3vcWe4yJksXTWRqOzm3uVWXQ,1017
 fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml,sha256=xGRt0J9joXTzWUew6DvoYprAWlPXhaVFw5AX4im5VQw,1017
-fusion_bench-0.2.27.dist-info/METADATA,sha256=TnLxGqALTnvyF-GXwk-iGvl-eNvBjNvZzkDODdkVLVo,24307
-fusion_bench-0.2.27.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-fusion_bench-0.2.27.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
-fusion_bench-0.2.27.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
-fusion_bench-0.2.27.dist-info/RECORD,,
+fusion_bench-0.2.28.dist-info/METADATA,sha256=2m3tF3J5gbcupGjZt_0Md77Tb7h3oDxwwp_Q_sZsdIM,24307
+fusion_bench-0.2.28.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+fusion_bench-0.2.28.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
+fusion_bench-0.2.28.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
+fusion_bench-0.2.28.dist-info/RECORD,,
fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml ADDED
@@ -0,0 +1,10 @@
+_target_: fusion_bench.modelpool.ConvNextForImageClassificationPool
+_recursive_: False
+models:
+  _pretrained_:
+    config_path: facebook/convnext-base-224
+    pretrained: true
+    dataset_name: null
+train_datasets: null
+val_datasets: null
+test_datasets: null
fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml ADDED
@@ -0,0 +1,10 @@
+_target_: fusion_bench.modelpool.Dinov2ForImageClassificationPool
+_recursive_: False
+models:
+  _pretrained_:
+    config_path: facebook/dinov2-base-imagenet1k-1-layer
+    pretrained: true
+    dataset_name: null
+train_datasets: null
+val_datasets: null
+test_datasets: null
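Both added configs follow the Hydra `_target_` convention, so they can be instantiated directly; a sketch, with the relative config path assumed from the RECORD entries above:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load(
    "fusion_bench_config/modelpool/Dinov2ForImageClassification/"
    "dinov2-base-imagenet1k-1-layer.yaml"
)
pool = instantiate(cfg)  # honors _target_ and _recursive_: False
model = pool.load_model("_pretrained_")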