fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/dataset/__init__.py +2 -0
  4. fusion_bench/dataset/clip_dataset.py +4 -72
  5. fusion_bench/dataset/image_dataset.py +44 -18
  6. fusion_bench/method/base_algorithm.py +4 -0
  7. fusion_bench/method/dop/dop.py +0 -22
  8. fusion_bench/method/dop/dop_general.py +489 -0
  9. fusion_bench/method/dop/utils.py +24 -4
  10. fusion_bench/method/emr_merging/__init__.py +1 -0
  11. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  12. fusion_bench/method/emr_merging/utils.py +162 -0
  13. fusion_bench/method/opcm/opcm.py +6 -2
  14. fusion_bench/method/opcm/opcm_general.py +356 -0
  15. fusion_bench/method/opcm/utils.py +1 -4
  16. fusion_bench/method/simple_average.py +52 -18
  17. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  18. fusion_bench/mixins/lightning_fabric.py +108 -3
  19. fusion_bench/mixins/serialization.py +1 -1
  20. fusion_bench/modelpool/base_pool.py +37 -1
  21. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  22. fusion_bench/models/hf_clip.py +20 -0
  23. fusion_bench/models/modulator/__init__.py +1 -0
  24. fusion_bench/models/modulator/base.py +123 -0
  25. fusion_bench/models/parameter_dict.py +119 -29
  26. fusion_bench/models/utils.py +190 -2
  27. fusion_bench/models/wrappers/switch.py +90 -0
  28. fusion_bench/programs/base_program.py +6 -0
  29. fusion_bench/programs/fabric_fusion_program.py +4 -0
  30. fusion_bench/scripts/cli.py +19 -8
  31. fusion_bench/taskpool/image_classification.py +270 -0
  32. fusion_bench/utils/__init__.py +18 -1
  33. fusion_bench/utils/data.py +1 -1
  34. fusion_bench/utils/dict.py +19 -0
  35. fusion_bench/utils/dtype.py +19 -0
  36. fusion_bench/utils/misc.py +1 -0
  37. fusion_bench/utils/packages.py +4 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  39. fusion_bench/utils/tensorboard.py +21 -3
  40. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  41. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
  42. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  43. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  44. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  45. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  46. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  47. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  48. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  49. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  50. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  51. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,270 @@
+import itertools
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics import Accuracy, MeanMetric
+from tqdm.auto import tqdm
+
+from fusion_bench import (
+    BaseTaskPool,
+    LightningFabricMixin,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.dataset import ImageClassificationDataset
+from fusion_bench.models.wrappers.switch import set_active_option
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+from fusion_bench.utils import count_parameters
+
+if TYPE_CHECKING:
+    from transformers import AutoModelForImageClassification
+
+log = get_rankzero_logger(__name__)
+
+
+def _get_logits_from_model_output(outputs) -> Tensor:
+    """Extract logits from model outputs."""
+    match outputs:
+        case Tensor():
+            logits = outputs
+        case dict() | DictConfig() if "logits" in outputs:
+            logits = outputs["logits"]
+            assert isinstance(
+                logits, Tensor
+            ), "The 'logits' in the model output dictionary is not a Tensor."
+        case _:
+            if hasattr(outputs, "logits"):
+                logits = outputs.logits
+                assert isinstance(
+                    logits, Tensor
+                ), "The 'logits' attribute of the model output is not a Tensor."
+            else:
+                raise ValueError(
+                    "Model output is not a Tensor and does not have 'logits' attribute."
+                )
+    return logits
+
+
+@auto_register_config
+class ImageClassificationTaskPool(LightningFabricMixin, BaseTaskPool):
+    _config_mapping = BaseTaskPool._config_mapping | {
+        "_test_datasets": "test_datasets",
+        "_processor": "processor",
+    }
+
+    _processor_instance = None
+    _is_setup: bool = False
+
+    def __init__(
+        self,
+        test_datasets: DictConfig | Dict[str, Any],
+        *,
+        processor: DictConfig | Any,
+        dataloader_kwargs: DictConfig,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # if the processor is given as a transformers processor instance, store it directly
+        if callable(processor):
+            self._processor_instance = processor
+
+    @property
+    def processor(self) -> Any:
+        if self._processor is None:
+            return None
+
+        if self._processor_instance is not None:
+            return self._processor_instance
+
+        match self._processor:
+            case dict() | DictConfig() if "_target_" in self._processor:
+                self._processor_instance = instantiate(self._processor)
+                return self._processor_instance
+            case str():
+                from transformers import AutoProcessor
+
+                self._processor_instance = AutoProcessor.from_pretrained(
+                    self._processor
+                )
+                return self._processor_instance
+
+        raise ValueError("Processor is not properly configured.")
+
+    def setup(self):
+        # Load test datasets
+        test_datasets = {
+            ds_name: ImageClassificationDataset(
+                self.load_test_dataset(ds_name), processor=self.processor
+            )
+            for ds_name in self._test_datasets
+        }
+        self.test_datasets = test_datasets
+        self.test_dataloaders = {
+            ds_name: self.fabric.setup_dataloaders(
+                self.get_dataloader(ds, stage="test")
+            )
+            for ds_name, ds in test_datasets.items()
+        }
+
+    def load_test_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
+        """
+        Load the testing dataset for the specified model.
+
+        Args:
+            dataset_name (str): The name of the model.
+
+        Returns:
+            Dataset: The instantiated testing dataset.
+        """
+        test_dataset = self._test_datasets[dataset_name]
+        if isinstance(test_dataset, (DictConfig, dict)):
+            return instantiate(test_dataset, *args, **kwargs)
+        else:
+            return test_dataset
+
+    def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and training stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations:
+        - Training stage: shuffling enabled by default
+        - Validation/test stages: shuffling disabled by default
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Training stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior.
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
+
+    @torch.no_grad()
+    def _evaluate(
+        self,
+        classifier,
+        test_loader,
+        num_classes: int,
+        task_name: str = None,
+    ):
+        classifier.eval()
+        accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+        loss_metric = MeanMetric()
+        if RuntimeConstants.debug:
+            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+            test_loader = itertools.islice(test_loader, 1)
+        else:
+            test_loader = test_loader
+
+        pbar = tqdm(
+            test_loader,
+            desc=f"Evaluating {task_name}" if task_name is not None else "Evaluating",
+            leave=False,
+            dynamic_ncols=True,
+        )
+        for batch in pbar:
+            inputs, targets = batch
+            outputs = classifier(inputs)
+            logits = _get_logits_from_model_output(outputs)
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
+
+            loss = F.cross_entropy(logits, targets)
+            loss_metric.update(loss.detach().cpu())
+            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+            pbar.set_postfix(
+                {
+                    "accuracy": accuracy.compute().item(),
+                    "loss": loss_metric.compute().item(),
+                }
+            )
+
+        acc = accuracy.compute().item()
+        loss = loss_metric.compute().item()
+        results = {"accuracy": acc, "loss": loss}
+        return results
+
+    def evaluate(
+        self,
+        model: Union["AutoModelForImageClassification", nn.Module],
+        name: str = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        assert isinstance(
+            model, nn.Module
+        ), f"Expected model to be an instance of nn.Module, but got {type(model)}"
+
+        if not self._is_setup:
+            self.setup()
+
+        classifier = self.fabric.to_device(model)
+        classifier.eval()
+        report = {}
+        # collect basic model information
+        training_params, all_params = count_parameters(model)
+        report["model_info"] = {
+            "trainable_params": training_params,
+            "all_params": all_params,
+            "trainable_percentage": training_params / all_params,
+        }
+        if name is not None:
+            report["model_info"]["name"] = name
+
+        # evaluate on each task
+        pbar = tqdm(
+            self.test_dataloaders.items(),
+            desc="Evaluating tasks",
+            total=len(self.test_dataloaders),
+        )
+        for task_name, test_dataloader in pbar:
+            set_active_option(classifier, task_name)
+            num_classes = get_num_classes(task_name)
+            result = self._evaluate(
+                classifier,
+                test_dataloader,
+                num_classes=num_classes,
+                task_name=task_name,
+            )
+            report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+            accuracies = [
+                value["accuracy"]
+                for key, value in report.items()
+                if "accuracy" in value
+            ]
+            if len(accuracies) > 0:
+                average_accuracy = sum(accuracies) / len(accuracies)
+                report["average"]["accuracy"] = average_accuracy
+            losses = [value["loss"] for key, value in report.items() if "loss" in value]
+            if len(losses) > 0:
+                average_loss = sum(losses) / len(losses)
+                report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
+        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+            save_path = os.path.join(self.log_dir, "report.json")
+            for version in itertools.count(1):
+                if not os.path.exists(save_path):
+                    break
+                # if the file already exists, increment the version to avoid overwriting
+                save_path = os.path.join(self.log_dir, f"report_{version}.json")
+            with open(save_path, "w") as fp:
+                json.dump(report, fp)
+            log.info(f"Evaluation report saved to {save_path}")
+        return report
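
An illustrative sketch (not part of the package diff) of how the new _get_logits_from_model_output helper in fusion_bench/taskpool/image_classification.py treats the three output shapes it handles: plain tensors, dicts carrying a "logits" key, and objects exposing a .logits attribute. The direct import of the private helper and the use of SimpleNamespace as a stand-in for a transformers-style output object are assumptions for illustration only.

    import torch
    from types import SimpleNamespace

    from fusion_bench.taskpool.image_classification import _get_logits_from_model_output

    logits = torch.randn(2, 10)

    # plain tensor outputs are returned unchanged
    assert _get_logits_from_model_output(logits) is logits
    # dict-like outputs are indexed by the "logits" key
    assert _get_logits_from_model_output({"logits": logits}) is logits
    # any other object is expected to expose a .logits attribute
    assert _get_logits_from_model_output(SimpleNamespace(logits=logits)) is logits
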
@@ -31,6 +31,11 @@ _import_structure = {
     ],
     "dtype": ["get_dtype", "parse_dtype"],
     "fabric": ["seed_everything_by_time"],
+    "hydra_utils": [
+        "initialize_hydra_config",
+        "get_default_config_path",
+        "get_hydra_output_dir",
+    ],
     "instantiate_utils": [
         "instantiate",
         "is_instantiable",
@@ -40,6 +45,7 @@ _import_structure = {
     "json": ["load_from_json", "save_to_json", "print_json"],
     "lazy_state_dict": ["LazyStateDict"],
     "misc": [
+        "DeprecationWarningMeta",
         "first",
         "has_length",
         "join_lists",
@@ -122,6 +128,11 @@ if TYPE_CHECKING:
     )
     from .dtype import get_dtype, parse_dtype
     from .fabric import seed_everything_by_time
+    from .hydra_utils import (
+        get_default_config_path,
+        get_hydra_output_dir,
+        initialize_hydra_config,
+    )
     from .instantiate_utils import (
         instantiate,
         is_instantiable,
@@ -130,7 +141,13 @@ if TYPE_CHECKING:
     )
     from .json import load_from_json, print_json, save_to_json
     from .lazy_state_dict import LazyStateDict
-    from .misc import first, has_length, join_lists, validate_and_suggest_corrections
+    from .misc import (
+        DeprecationWarningMeta,
+        first,
+        has_length,
+        join_lists,
+        validate_and_suggest_corrections,
+    )
     from .packages import compare_versions, import_object
     from .parameters import (
         check_parameters_all_equal,
@@ -95,7 +95,7 @@ class InfiniteDataLoader:
             f"Failed to retrieve data from data loader after {self.max_retries} attempts. "
             f"Last error: [{type(last_exception).__name__}]{last_exception}. "
             + (
-                f"The data loader appears to be empty."
+                f"The data loader may be empty."
                 if isinstance(last_exception, StopIteration)
                 else ""
             )
@@ -41,3 +41,22 @@ def dict_map(f, d: dict, *, max_level: int = -1, skip_levels=0, inplace=False):

     dict_map_impl(d, ans, 0)
     return ans
+
+
+def dict_merge(dicts: Iterable[dict], disjoint: bool = True) -> dict:
+    """Merge multiple dictionaries into one.
+
+    Args:
+        dicts (Iterable[dict]): iterable of dictionaries to merge
+        disjoint (bool, optional): if True, raises an error on key conflicts. Defaults to True.
+
+    Returns:
+        dict: merged dictionary
+    """
+    merged_dict = type(dicts[0])()
+    for d in dicts:
+        for k, v in d.items():
+            if disjoint and k in merged_dict:
+                raise ValueError(f"Key conflict when merging dictionaries: {k}")
+            merged_dict[k] = v
+    return merged_dict
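
A minimal usage sketch of the dict_merge helper added above, assuming it is imported directly from the patched fusion_bench.utils.dict module:

    from fusion_bench.utils.dict import dict_merge

    # disjoint key sets merge cleanly
    assert dict_merge([{"a": 1}, {"b": 2}]) == {"a": 1, "b": 2}

    # with the default disjoint=True, overlapping keys raise a ValueError;
    # with disjoint=False, later dictionaries overwrite earlier ones
    assert dict_merge([{"a": 1}, {"a": 2}], disjoint=False) == {"a": 2}
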
@@ -146,3 +146,22 @@ def validate_expected_param_dtype(
             raise ValueError(
                 f"Parameter {name} has dtype {param.dtype}, but expected {dtype}"
             )
+
+
+def dtype_support_svd(dtype: torch.dtype) -> bool:
+    """
+    Check if the given dtype is supported for SVD operation in PyTorch.
+
+    Args:
+        dtype (torch.dtype): The data type to check.
+
+    Returns:
+        bool: True if the dtype is supported for SVD, False otherwise.
+    """
+    supported_dtypes = {
+        torch.float32,
+        torch.float64,
+        torch.complex64,
+        torch.complex128,
+    }
+    return dtype in supported_dtypes
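
A quick sanity check of the new dtype_support_svd predicate, assuming a direct import from fusion_bench.utils.dtype:

    import torch

    from fusion_bench.utils.dtype import dtype_support_svd

    assert dtype_support_svd(torch.float32)
    # half-precision dtypes are not in the supported set
    assert not dtype_support_svd(torch.float16)
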
@@ -9,6 +9,7 @@ __all__ = [
     "join_lists",
     "attr_equal",
     "validate_and_suggest_corrections",
+    "DeprecationWarningMeta",
 ]


@@ -20,6 +20,10 @@ def _get_package_version(name: str) -> "Version":
         return version.parse("0.0.0")


+def is_ray_available():
+    return _is_package_available("ray")
+
+
 def is_pyav_available():
     return _is_package_available("av")

@@ -1,6 +1,16 @@
 from collections import OrderedDict
 from numbers import Number
-from typing import Callable, Dict, List, Literal, Optional, Union, cast
+from typing import (
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Union,
+    cast,
+)

 import torch
 from torch import Tensor
@@ -462,6 +472,118 @@ class ArithmeticStateDict(OrderedDict):
         return cls(result_dict)


+class LazyStateDictExpr(Mapping[str, torch.Tensor]):
+    """
+    A lazy, key-wise expression over state_dict-like objects.
+    """
+
+    # ---- core Mapping API ----
+    def __getitem__(self, key: str) -> torch.Tensor:
+        raise NotImplementedError
+
+    def __iter__(self) -> Iterator[str]:
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    # ---- arithmetic (build graph only) ----
+    def __add__(self, other):
+        return BinaryOp(torch.add, self, ensure_expr(other))
+
+    def __sub__(self, other):
+        return BinaryOp(torch.sub, self, ensure_expr(other))
+
+    def __mul__(self, scalar):
+        return UnaryOp(lambda x: x * scalar, self)
+
+    def __rmul__(self, scalar):
+        return self.__mul__(scalar)
+
+    def __truediv__(self, scalar):
+        return UnaryOp(lambda x: x / scalar, self)
+
+    # ---- eager escape hatch ----
+    def materialize(
+        self, device=None, dtype=None, non_blocking=False, copy=False
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Eagerly evaluate into an OrderedDict.
+        """
+        out = {}
+        for k in self:
+            v = self[k]
+            out[k] = v.to(
+                device=device,
+                dtype=dtype,
+                non_blocking=non_blocking,
+                copy=copy,
+            )
+        return out
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(lazy)"
+
+
+class StateDictLeaf(LazyStateDictExpr):
+    def __init__(self, state_dict: Mapping[str, torch.Tensor]):
+        self._sd = state_dict
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return self._sd[key]
+
+    def __iter__(self):
+        return iter(self._sd)
+
+    def __len__(self):
+        return len(self._sd)
+
+
+class UnaryOp(LazyStateDictExpr):
+    def __init__(self, op: Callable[[torch.Tensor], torch.Tensor], child):
+        self.op = op
+        self.child = child
+
+    def __getitem__(self, key: str):
+        return self.op(self.child[key])
+
+    def __iter__(self):
+        return iter(self.child)
+
+    def __len__(self):
+        return len(self.child)
+
+
+class BinaryOp(LazyStateDictExpr):
+    def __init__(
+        self,
+        op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        left,
+        right,
+    ):
+        self.op = op
+        self.left = left
+        self.right = right
+
+    def __getitem__(self, key: str):
+        return self.op(self.left[key], self.right[key])
+
+    def __iter__(self):
+        # assume key sets are aligned
+        return iter(self.left)
+
+    def __len__(self):
+        return len(self.left)
+
+
+def ensure_expr(x):
+    if isinstance(x, LazyStateDictExpr):
+        return x
+    if isinstance(x, Mapping):
+        return StateDictLeaf(x)
+    raise TypeError(f"Unsupported operand type: {type(x)}")
+
+
 def _validate_state_dict_list_not_empty(state_dicts: List[StateDictType]) -> None:
     """
     Validate that the list of state dicts is not empty and contains valid state dicts.
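
A small sketch of how the lazy expression classes above compose: the operator overloads only build a graph, and tensors are produced key-by-key on access or all at once via materialize(). The direct import from fusion_bench.utils.state_dict_arithmetic is assumed for illustration.

    import torch

    from fusion_bench.utils.state_dict_arithmetic import StateDictLeaf

    sd_a = {"w": torch.ones(2, 2)}
    sd_b = {"w": torch.full((2, 2), 3.0)}

    # task-arithmetic style expression a + 0.5 * (b - a), evaluated lazily per key
    expr = StateDictLeaf(sd_a) + 0.5 * (StateDictLeaf(sd_b) - StateDictLeaf(sd_a))
    assert torch.allclose(expr["w"], torch.full((2, 2), 2.0))

    # evaluate every key eagerly into a plain dict
    merged = expr.materialize()
    assert set(merged) == {"w"}
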
@@ -1228,3 +1350,63 @@ def state_dict_hadamard_product(a: StateDictType, b: StateDictType) -> StateDict
     """
     _validate_state_dict_same_keys([a, b])
     return OrderedDict((key, a[key] * b[key]) for key in a)
+
+
+def state_dict_max(
+    state_dicts: List[StateDictType],
+) -> StateDictType:
+    """
+    Compute the element-wise maximum across multiple state dicts.
+
+    Args:
+        state_dicts: List of state dicts to compute the maximum from.
+
+    Returns:
+        A state dict containing the element-wise maximums.
+    """
+    _validate_state_dict_list_not_empty(state_dicts)
+    _validate_state_dict_same_keys(state_dicts)
+
+    max_state_dict = OrderedDict()
+
+    for key in state_dicts[0]:
+        # Initialize with the first tensor
+        max_tensor = state_dicts[0][key].clone()
+
+        # Compute element-wise maximum
+        for state_dict in state_dicts[1:]:
+            max_tensor = torch.max(max_tensor, state_dict[key])
+
+        max_state_dict[key] = max_tensor
+
+    return max_state_dict
+
+
+def state_dict_max_abs(
+    state_dicts: List[StateDictType],
+) -> StateDictType:
+    """
+    Compute the element-wise maximum absolute value across multiple state dicts.
+
+    Args:
+        state_dicts: List of state dicts to compute the maximum absolute values from.
+
+    Returns:
+        A state dict containing the element-wise maximum absolute values.
+    """
+    _validate_state_dict_list_not_empty(state_dicts)
+    _validate_state_dict_same_keys(state_dicts)
+
+    max_abs_state_dict = OrderedDict()
+
+    for key in state_dicts[0]:
+        # Initialize with the absolute values of the first tensor
+        max_abs_tensor = state_dicts[0][key].abs()
+
+        # Compute element-wise maximum absolute value
+        for state_dict in state_dicts[1:]:
+            max_abs_tensor = torch.max(max_abs_tensor, state_dict[key].abs())
+
+        max_abs_state_dict[key] = max_abs_tensor
+
+    return max_abs_state_dict
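
A quick numeric check of the two new reductions, again assuming direct imports from fusion_bench.utils.state_dict_arithmetic:

    import torch

    from fusion_bench.utils.state_dict_arithmetic import state_dict_max, state_dict_max_abs

    sd_1 = {"w": torch.tensor([1.0, -4.0])}
    sd_2 = {"w": torch.tensor([3.0, 2.0])}

    # element-wise maximum of the raw values
    assert torch.equal(state_dict_max([sd_1, sd_2])["w"], torch.tensor([3.0, 2.0]))
    # element-wise maximum of the absolute values (signs are discarded)
    assert torch.equal(state_dict_max_abs([sd_1, sd_2])["w"], torch.tensor([3.0, 4.0]))
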
@@ -2,14 +2,18 @@
 functions deal with tensorboard logs.
 """

-from typing import Dict, Iterable, List
+from pathlib import Path
+from typing import Dict, Iterable, List, Union

 import numpy as np
 import pandas as pd
 from tensorboard.backend.event_processing import event_accumulator


-def parse_tensorboard_as_dict(path: str, scalars: Iterable[str]):
+def parse_tensorboard_as_dict(
+    path: Union[str, Path],
+    scalars: Iterable[str],
+) -> Dict[str, pd.DataFrame]:
     """
     returns a dictionary of pandas dataframes for each requested scalar.

@@ -20,7 +24,19 @@ def parse_tensorboard_as_dict(path: str, scalars: Iterable[str]):

     Returns:
         Dict[str, pandas.DataFrame]: a dictionary of pandas dataframes for each requested scalar
+
+    Example:
+
+        >>> from fusion_bench.utils.tensorboard import parse_tensorboard_as_dict
+        >>> path = "path/to/tensorboard/logs"
+        >>> scalars = ["train/loss", "val/accuracy"]
+        >>> data = parse_tensorboard_as_dict(path, scalars)
+        >>> train_loss_df = data["train/loss"]
+        >>> val_accuracy_df = data["val/accuracy"]
     """
+    if isinstance(path, Path):
+        path = str(path)
+    assert isinstance(path, str), "path must be a string"
     ea = event_accumulator.EventAccumulator(
         path,
         size_guidance={event_accumulator.SCALARS: 0},
@@ -33,7 +49,9 @@ def parse_tensorboard_as_dict(path: str, scalars: Iterable[str]):
     return {k: pd.DataFrame(ea.Scalars(k)) for k in scalars}


-def parse_tensorboard_as_list(path: str, scalars: Iterable[str]):
+def parse_tensorboard_as_list(
+    path: Union[str, Path], scalars: Iterable[str]
+) -> List[pd.DataFrame]:
     """
     returns a list of pandas dataframes for each requested scalar.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion-bench
-Version: 0.2.31
+Version: 0.2.32
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 Project-URL: Repository, https://github.com/tanganke/fusion_bench
@@ -61,6 +61,8 @@ Dynamic: license-file

 FusionBench is a benchmark suite designed to evaluate the performance of various deep model fusion techniques. It aims to provide a comprehensive comparison of different methods on a variety of datasets and tasks.

+## :newspaper: News and Related
+
 Projects based on FusionBench and news from the community (descending order of date. If you have any work based on FusionBench, please feel free to let us know, we are willing to add it to the list. :partying_face:):

 <details>