fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/constants/runtime.py +4 -1
  4. fusion_bench/dataset/__init__.py +2 -0
  5. fusion_bench/dataset/clip_dataset.py +4 -72
  6. fusion_bench/dataset/image_dataset.py +44 -18
  7. fusion_bench/method/base_algorithm.py +4 -0
  8. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  9. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  10. fusion_bench/method/dop/dop.py +0 -22
  11. fusion_bench/method/dop/dop_general.py +489 -0
  12. fusion_bench/method/dop/utils.py +24 -4
  13. fusion_bench/method/emr_merging/__init__.py +1 -0
  14. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  15. fusion_bench/method/emr_merging/utils.py +162 -0
  16. fusion_bench/method/opcm/opcm.py +6 -2
  17. fusion_bench/method/opcm/opcm_general.py +356 -0
  18. fusion_bench/method/opcm/utils.py +1 -4
  19. fusion_bench/method/simple_average.py +52 -18
  20. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  21. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  22. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  23. fusion_bench/mixins/lightning_fabric.py +110 -11
  24. fusion_bench/mixins/openclip_classification.py +155 -1
  25. fusion_bench/mixins/serialization.py +1 -1
  26. fusion_bench/modelpool/base_pool.py +37 -0
  27. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  28. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  29. fusion_bench/models/hf_clip.py +20 -0
  30. fusion_bench/models/modulator/__init__.py +1 -0
  31. fusion_bench/models/modulator/base.py +123 -0
  32. fusion_bench/models/open_clip/modeling.py +61 -5
  33. fusion_bench/models/open_clip/utils.py +13 -2
  34. fusion_bench/models/parameter_dict.py +119 -29
  35. fusion_bench/models/utils.py +190 -2
  36. fusion_bench/models/wrappers/switch.py +90 -0
  37. fusion_bench/programs/base_program.py +6 -0
  38. fusion_bench/programs/fabric_fusion_program.py +4 -0
  39. fusion_bench/py.typed +1 -0
  40. fusion_bench/scripts/cli.py +25 -23
  41. fusion_bench/scripts/imgui.py +2 -2
  42. fusion_bench/scripts/webui.py +2 -2
  43. fusion_bench/taskpool/image_classification.py +270 -0
  44. fusion_bench/utils/__init__.py +20 -1
  45. fusion_bench/utils/data.py +1 -1
  46. fusion_bench/utils/dict.py +19 -0
  47. fusion_bench/utils/dtype.py +19 -0
  48. fusion_bench/utils/hydra_utils.py +75 -0
  49. fusion_bench/utils/misc.py +1 -0
  50. fusion_bench/utils/packages.py +4 -0
  51. fusion_bench/utils/parameters.py +33 -0
  52. fusion_bench/utils/rich_utils.py +42 -19
  53. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  54. fusion_bench/utils/tensorboard.py +21 -3
  55. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  56. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
  57. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  58. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  59. fusion_bench_config/README.md +9 -0
  60. fusion_bench_config/fabric/auto.yaml +1 -0
  61. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  62. fusion_bench_config/hydra/default.yaml +3 -1
  63. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  64. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  65. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  66. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  67. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  68. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  69. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  70. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/image_classification.py
@@ -0,0 +1,270 @@
+ import itertools
+ import json
+ import os
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+ import torch
+ from omegaconf import DictConfig, OmegaConf
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader, Dataset
+ from torchmetrics import Accuracy, MeanMetric
+ from tqdm.auto import tqdm
+
+ from fusion_bench import (
+     BaseTaskPool,
+     LightningFabricMixin,
+     RuntimeConstants,
+     auto_register_config,
+     get_rankzero_logger,
+     instantiate,
+ )
+ from fusion_bench.dataset import ImageClassificationDataset
+ from fusion_bench.models.wrappers.switch import set_active_option
+ from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+ from fusion_bench.utils import count_parameters
+
+ if TYPE_CHECKING:
+     from transformers import AutoModelForImageClassification
+
+ log = get_rankzero_logger(__name__)
+
+
+ def _get_logits_from_model_output(outputs) -> Tensor:
+     """Extract logits from model outputs."""
+     match outputs:
+         case Tensor():
+             logits = outputs
+         case dict() | DictConfig() if "logits" in outputs:
+             logits = outputs["logits"]
+             assert isinstance(
+                 logits, Tensor
+             ), "The 'logits' in the model output dictionary is not a Tensor."
+         case _:
+             if hasattr(outputs, "logits"):
+                 logits = outputs.logits
+                 assert isinstance(
+                     logits, Tensor
+                 ), "The 'logits' attribute of the model output is not a Tensor."
+             else:
+                 raise ValueError(
+                     "Model output is not a Tensor and does not have 'logits' attribute."
+                 )
+     return logits
+
+
+ @auto_register_config
+ class ImageClassificationTaskPool(LightningFabricMixin, BaseTaskPool):
+     _config_mapping = BaseTaskPool._config_mapping | {
+         "_test_datasets": "test_datasets",
+         "_processor": "processor",
+     }
+
+     _processor_instance = None
+     _is_setup: bool = False
+
+     def __init__(
+         self,
+         test_datasets: DictConfig | Dict[str, Any],
+         *,
+         processor: DictConfig | Any,
+         dataloader_kwargs: DictConfig,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # if the processor is given as a transformers processor instance, store it directly
+         if callable(processor):
+             self._processor_instance = processor
+
+     @property
+     def processor(self) -> Any:
+         if self._processor is None:
+             return None
+
+         if self._processor_instance is not None:
+             return self._processor_instance
+
+         match self._processor:
+             case dict() | DictConfig() if "_target_" in self._processor:
+                 self._processor_instance = instantiate(self._processor)
+                 return self._processor_instance
+             case str():
+                 from transformers import AutoProcessor
+
+                 self._processor_instance = AutoProcessor.from_pretrained(
+                     self._processor
+                 )
+                 return self._processor_instance
+
+         raise ValueError("Processor is not properly configured.")
+
+     def setup(self):
+         # Load test datasets
+         test_datasets = {
+             ds_name: ImageClassificationDataset(
+                 self.load_test_dataset(ds_name), processor=self.processor
+             )
+             for ds_name in self._test_datasets
+         }
+         self.test_datasets = test_datasets
+         self.test_dataloaders = {
+             ds_name: self.fabric.setup_dataloaders(
+                 self.get_dataloader(ds, stage="test")
+             )
+             for ds_name, ds in test_datasets.items()
+         }
+
+     def load_test_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
+         """
+         Load the testing dataset for the specified model.
+
+         Args:
+             dataset_name (str): The name of the model.
+
+         Returns:
+             Dataset: The instantiated testing dataset.
+         """
+         test_dataset = self._test_datasets[dataset_name]
+         if isinstance(test_dataset, (DictConfig, dict)):
+             return instantiate(test_dataset, *args, **kwargs)
+         else:
+             return test_dataset
+
+     def get_dataloader(self, dataset, stage: str):
+         """Create a DataLoader for the specified dataset and training stage.
+
+         Constructs a PyTorch DataLoader with stage-appropriate configurations:
+         - Training stage: shuffling enabled by default
+         - Validation/test stages: shuffling disabled by default
+
+         Args:
+             dataset: The dataset to wrap in a DataLoader.
+             stage (str): Training stage, must be one of "train", "val", or "test".
+                 Determines default shuffling behavior.
+
+         Returns:
+             DataLoader: Configured DataLoader for the given dataset and stage.
+         """
+         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+         dataloader_kwargs = dict(self.dataloader_kwargs)
+         if "shuffle" not in dataloader_kwargs:
+             dataloader_kwargs["shuffle"] = stage == "train"
+         return DataLoader(dataset, **dataloader_kwargs)
+
+     @torch.no_grad()
+     def _evaluate(
+         self,
+         classifier,
+         test_loader,
+         num_classes: int,
+         task_name: str = None,
+     ):
+         classifier.eval()
+         accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+         loss_metric = MeanMetric()
+         if RuntimeConstants.debug:
+             log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+             test_loader = itertools.islice(test_loader, 1)
+         else:
+             test_loader = test_loader
+
+         pbar = tqdm(
+             test_loader,
+             desc=f"Evaluating {task_name}" if task_name is not None else "Evaluating",
+             leave=False,
+             dynamic_ncols=True,
+         )
+         for batch in pbar:
+             inputs, targets = batch
+             outputs = classifier(inputs)
+             logits = _get_logits_from_model_output(outputs)
+             if logits.device != targets.device:
+                 targets = targets.to(logits.device)
+
+             loss = F.cross_entropy(logits, targets)
+             loss_metric.update(loss.detach().cpu())
+             acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+             pbar.set_postfix(
+                 {
+                     "accuracy": accuracy.compute().item(),
+                     "loss": loss_metric.compute().item(),
+                 }
+             )
+
+         acc = accuracy.compute().item()
+         loss = loss_metric.compute().item()
+         results = {"accuracy": acc, "loss": loss}
+         return results
+
+     def evaluate(
+         self,
+         model: Union["AutoModelForImageClassification", nn.Module],
+         name: str = None,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         assert isinstance(
+             model, nn.Module
+         ), f"Expected model to be an instance of nn.Module, but got {type(model)}"
+
+         if not self._is_setup:
+             self.setup()
+
+         classifier = self.fabric.to_device(model)
+         classifier.eval()
+         report = {}
+         # collect basic model information
+         training_params, all_params = count_parameters(model)
+         report["model_info"] = {
+             "trainable_params": training_params,
+             "all_params": all_params,
+             "trainable_percentage": training_params / all_params,
+         }
+         if name is not None:
+             report["model_info"]["name"] = name
+
+         # evaluate on each task
+         pbar = tqdm(
+             self.test_dataloaders.items(),
+             desc="Evaluating tasks",
+             total=len(self.test_dataloaders),
+         )
+         for task_name, test_dataloader in pbar:
+             set_active_option(classifier, task_name)
+             num_classes = get_num_classes(task_name)
+             result = self._evaluate(
+                 classifier,
+                 test_dataloader,
+                 num_classes=num_classes,
+                 task_name=task_name,
+             )
+             report[task_name] = result
+
+         # calculate the average accuracy and loss
+         if "average" not in report:
+             report["average"] = {}
+         accuracies = [
+             value["accuracy"]
+             for key, value in report.items()
+             if "accuracy" in value
+         ]
+         if len(accuracies) > 0:
+             average_accuracy = sum(accuracies) / len(accuracies)
+             report["average"]["accuracy"] = average_accuracy
+         losses = [value["loss"] for key, value in report.items() if "loss" in value]
+         if len(losses) > 0:
+             average_loss = sum(losses) / len(losses)
+             report["average"]["loss"] = average_loss
+
+         log.info(f"Evaluation Result: {report}")
+         if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+             save_path = os.path.join(self.log_dir, "report.json")
+             for version in itertools.count(1):
+                 if not os.path.exists(save_path):
+                     break
+                 # if the file already exists, increment the version to avoid overwriting
+                 save_path = os.path.join(self.log_dir, f"report_{version}.json")
+             with open(save_path, "w") as fp:
+                 json.dump(report, fp)
+             log.info(f"Evaluation report saved to {save_path}")
+         return report
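For orientation, a minimal usage sketch of the new taskpool driven outside a full Hydra run follows. The dataset entry, processor id, and dataloader settings are illustrative placeholders rather than values shipped in the wheel, and `model` stands for any `nn.Module` whose forward pass yields logits (or an object with a `logits` attribute) for the configured tasks.

```python
# Illustrative sketch only: dataset config, processor id, and model are assumptions.
from fusion_bench.taskpool.image_classification import ImageClassificationTaskPool

taskpool = ImageClassificationTaskPool(
    test_datasets={
        # each entry is either a Dataset instance or an instantiable config with `_target_`
        "cifar10": {"_target_": "datasets.load_dataset", "path": "cifar10", "split": "test"},
    },
    processor="google/vit-base-patch16-224",  # a string is resolved via AutoProcessor.from_pretrained
    dataloader_kwargs={"batch_size": 64, "num_workers": 4},
)

# assume `model` is an image classifier nn.Module produced by a fusion algorithm
report = taskpool.evaluate(model, name="merged_model")
print(report["model_info"], report["average"])  # per-task entries plus averaged accuracy/loss
```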
fusion_bench/utils/__init__.py
@@ -31,6 +31,11 @@ _import_structure = {
      ],
      "dtype": ["get_dtype", "parse_dtype"],
      "fabric": ["seed_everything_by_time"],
+     "hydra_utils": [
+         "initialize_hydra_config",
+         "get_default_config_path",
+         "get_hydra_output_dir",
+     ],
      "instantiate_utils": [
          "instantiate",
          "is_instantiable",
@@ -40,6 +45,7 @@ _import_structure = {
      "json": ["load_from_json", "save_to_json", "print_json"],
      "lazy_state_dict": ["LazyStateDict"],
      "misc": [
+         "DeprecationWarningMeta",
          "first",
          "has_length",
          "join_lists",
@@ -53,6 +59,7 @@ _import_structure = {
          "get_parameter_summary",
          "human_readable",
          "print_parameters",
+         "print_trainable_parameters",
          "state_dict_to_vector",
          "trainable_state_dict",
          "vector_to_state_dict",
@@ -121,6 +128,11 @@ if TYPE_CHECKING:
      )
      from .dtype import get_dtype, parse_dtype
      from .fabric import seed_everything_by_time
+     from .hydra_utils import (
+         get_default_config_path,
+         get_hydra_output_dir,
+         initialize_hydra_config,
+     )
      from .instantiate_utils import (
          instantiate,
          is_instantiable,
@@ -129,7 +141,13 @@ if TYPE_CHECKING:
      )
      from .json import load_from_json, print_json, save_to_json
      from .lazy_state_dict import LazyStateDict
-     from .misc import first, has_length, join_lists, validate_and_suggest_corrections
+     from .misc import (
+         DeprecationWarningMeta,
+         first,
+         has_length,
+         join_lists,
+         validate_and_suggest_corrections,
+     )
      from .packages import compare_versions, import_object
      from .parameters import (
          check_parameters_all_equal,
@@ -138,6 +156,7 @@ if TYPE_CHECKING:
          get_parameter_summary,
          human_readable,
          print_parameters,
+         print_trainable_parameters,
          state_dict_to_vector,
          trainable_state_dict,
          vector_to_state_dict,
fusion_bench/utils/data.py
@@ -95,7 +95,7 @@ class InfiniteDataLoader:
                  f"Failed to retrieve data from data loader after {self.max_retries} attempts. "
                  f"Last error: [{type(last_exception).__name__}]{last_exception}. "
                  + (
-                     f"The data loader appears to be empty."
+                     f"The data loader may be empty."
                      if isinstance(last_exception, StopIteration)
                      else ""
                  )
fusion_bench/utils/dict.py
@@ -41,3 +41,22 @@ def dict_map(f, d: dict, *, max_level: int = -1, skip_levels=0, inplace=False):

      dict_map_impl(d, ans, 0)
      return ans
+
+
+ def dict_merge(dicts: Iterable[dict], disjoint: bool = True) -> dict:
+     """Merge multiple dictionaries into one.
+
+     Args:
+         dicts (Iterable[dict]): iterable of dictionaries to merge
+         disjoint (bool, optional): if True, raises an error on key conflicts. Defaults to True.
+
+     Returns:
+         dict: merged dictionary
+     """
+     merged_dict = type(dicts[0])()
+     for d in dicts:
+         for k, v in d.items():
+             if disjoint and k in merged_dict:
+                 raise ValueError(f"Key conflict when merging dictionaries: {k}")
+             merged_dict[k] = v
+     return merged_dict
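A quick illustration of the merge semantics (not part of the diff): the result takes the container type of the first dictionary, and `disjoint=True` turns duplicate keys into an error. Note that, despite the `Iterable` annotation, the first element is indexed (`dicts[0]`), so an indexable sequence such as a list is expected.

```python
from collections import OrderedDict

from fusion_bench.utils.dict import dict_merge

a = OrderedDict(lr=1e-3, epochs=10)
b = {"batch_size": 64}

merged = dict_merge([a, b])
print(type(merged).__name__)  # OrderedDict -- the container type follows the first dict
print(merged)                 # OrderedDict([('lr', 0.001), ('epochs', 10), ('batch_size', 64)])

# duplicate keys raise when disjoint=True (the default):
# dict_merge([a, {"lr": 1e-4}])  -> ValueError: Key conflict when merging dictionaries: lr
```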
fusion_bench/utils/dtype.py
@@ -146,3 +146,22 @@ def validate_expected_param_dtype(
          raise ValueError(
              f"Parameter {name} has dtype {param.dtype}, but expected {dtype}"
          )
+
+
+ def dtype_support_svd(dtype: torch.dtype) -> bool:
+     """
+     Check if the given dtype is supported for SVD operation in PyTorch.
+
+     Args:
+         dtype (torch.dtype): The data type to check.
+
+     Returns:
+         bool: True if the dtype is supported for SVD, False otherwise.
+     """
+     supported_dtypes = {
+         torch.float32,
+         torch.float64,
+         torch.complex64,
+         torch.complex128,
+     }
+     return dtype in supported_dtypes
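A small sketch of how the new helper might guard an SVD call; the upcast-to-float32 fallback is illustrative rather than something the diff prescribes.

```python
import torch

from fusion_bench.utils.dtype import dtype_support_svd


def safe_svd(matrix: torch.Tensor):
    # torch.linalg.svd only handles float/complex inputs, so upcast e.g. bfloat16 weights first
    if not dtype_support_svd(matrix.dtype):
        matrix = matrix.to(torch.float32)
    return torch.linalg.svd(matrix, full_matrices=False)


u, s, vh = safe_svd(torch.randn(8, 4, dtype=torch.bfloat16))
print(s.dtype)  # torch.float32
```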
fusion_bench/utils/hydra_utils.py
@@ -1,4 +1,79 @@
+ import logging
+ import os
+
   import hydra.core.hydra_config
+ from hydra import compose, initialize
+ from omegaconf import DictConfig
+
+ from fusion_bench.constants import PROJECT_ROOT_PATH
+
+ log = logging.getLogger(__name__)
+
+
+ def get_default_config_path():
+     """
+     Get the default configuration path by searching in common locations.
+     """
+     for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+         for config_dir in ["config", "fusion_bench_config"]:
+             config_path = os.path.join(config_path_root, config_dir)
+             if os.path.exists(config_path) and os.path.isdir(config_path):
+                 return os.path.abspath(config_path)
+     return None
+
+
+ def initialize_hydra_config(
+     config_name: str,
+     overrides: list[str] = None,
+     config_path: str = None,
+     return_hydra_config: bool = False,
+ ) -> DictConfig:
+     """
+     Load the Hydra configuration.
+
+     Args:
+         config_name (str): The name of the configuration file (without .yaml extension).
+         overrides (list[str]): A list of configuration overrides.
+         config_path (str): The path to the configuration directory. If None, it will be automatically detected.
+         return_hydra_config (bool): If True, return the Hydra configuration object.
+
+     Returns:
+         DictConfig: The loaded configuration.
+
+     Example:
+         >>> cfg = initialize_hydra_config(
+         ...     config_name="fabric_model_fusion",
+         ...     overrides=["method=dummy", "modelpool=dummy"],
+         ... )
+         >>> print(cfg.method)
+     """
+     if config_path is None:
+         config_path = get_default_config_path()
+
+     # check config_path validity
+     if config_path is None:
+         raise FileNotFoundError("Could not find configuration directory.")
+     if not os.path.isdir(config_path):
+         raise NotADirectoryError(
+             f"Configuration path {config_path} do not exists or is not a directory."
+         )
+
+     if overrides is None:
+         overrides = []
+
+     with initialize(
+         version_base=None,
+         config_path=os.path.relpath(
+             config_path,
+             start=os.path.dirname(__file__),
+         ),
+     ):
+         cfg = compose(
+             config_name=config_name,
+             overrides=overrides,
+             return_hydra_config=return_hydra_config,
+         )
+     return cfg


   def get_hydra_output_dir():
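Beyond the docstring example above, here is a hedged sketch of how the two new helpers fit together; the config name and override values are copied from the docstring and may not correspond to configs present in your working directory.

```python
from fusion_bench.utils.hydra_utils import get_default_config_path, initialize_hydra_config

# looks for a `config/` or `fusion_bench_config/` directory in the CWD or the project root
print(get_default_config_path())

cfg = initialize_hydra_config(
    config_name="fabric_model_fusion",
    overrides=["method=dummy", "modelpool=dummy"],
)
print(cfg.method)
```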
fusion_bench/utils/misc.py
@@ -9,6 +9,7 @@ __all__ = [
      "join_lists",
      "attr_equal",
      "validate_and_suggest_corrections",
+     "DeprecationWarningMeta",
  ]

fusion_bench/utils/packages.py
@@ -20,6 +20,10 @@ def _get_package_version(name: str) -> "Version":
          return version.parse("0.0.0")


+ def is_ray_available():
+     return _is_package_available("ray")
+
+
  def is_pyav_available():
      return _is_package_available("av")

fusion_bench/utils/parameters.py
@@ -10,6 +10,7 @@ from .type import StateDictType
  __all__ = [
      "count_parameters",
      "print_parameters",
+     "print_trainable_parameters",
      "check_parameters_all_equal",
      "get_parameter_statistics",
      "state_dict_to_vector",
@@ -282,6 +283,38 @@ def print_parameters(
      )


+ def print_trainable_parameters(
+     module: nn.Module,
+     is_human_readable: bool = True,
+     print_fn=print,
+     non_zero_only: bool = False,
+ ):
+     """
+     Print the names and number of trainable parameters in a PyTorch model.
+
+     Args:
+         module (nn.Module): The PyTorch model.
+         is_human_readable (bool, optional): Whether to print the number of parameters in a human-readable format. Defaults to True.
+         print_fn (callable, optional): The function to use for printing. Defaults to print.
+         non_zero_only (bool, optional): Whether to count only non-zero parameters. Defaults to False.
+
+     Prints:
+         The names and number of trainable parameters in the model.
+
+     ```python
+     print_trainable_parameters(model)
+     # weight: 1.50M parameters
+     # bias: 500.00K parameters
+     ```
+     """
+     for name, param in module.named_parameters():
+         if param.requires_grad:
+             num_params = _numel(param, non_zero_only=non_zero_only)
+             if is_human_readable:
+                 num_params = human_readable(num_params)
+             print_fn(f"{name}: {num_params} parameters")
+
+
  def check_parameters_all_equal(
      list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
  ) -> None:
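To make the docstring example above concrete, a small illustrative call follows; the module and the frozen bias are made up for the demo, and the `500.00K` formatting is assumed to follow `human_readable` as shown in the docstring.

```python
import torch.nn as nn

from fusion_bench.utils.parameters import print_trainable_parameters

layer = nn.Linear(in_features=1000, out_features=500)
layer.bias.requires_grad_(False)  # frozen parameters are skipped entirely

print_trainable_parameters(layer)
# weight: 500.00K parameters
```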
fusion_bench/utils/rich_utils.py
@@ -93,11 +93,11 @@ def print_bordered(
      Print a message with a colored border.

      Args:
-         message (str): The message to print.
-         title (str, optional): The title of the panel. Defaults to None.
-         style (str, optional): The color style for the border. Defaults to "cyan".
-         code_style (str, optional): The syntax highlighting style if the message is code.
-         Set to None for plain text. Defaults to "python".
+         message (str): The message to print.
+         title (str, optional): The title of the panel. Defaults to None.
+         style (str, optional): The color style for the border. Defaults to "cyan".
+         code_style (str, optional): The syntax highlighting style if the message is code.
+             Set to None for plain text. Defaults to "python".
      """
      if code_style:
          if format_code:
@@ -168,7 +168,7 @@
          "callbacks",
          "logger",
          "trainer",
-         "paths",
+         "path",
          "extras",
      ),
      resolve: bool = False,
@@ -179,11 +179,20 @@
  ) -> None:
      """Prints the contents of a DictConfig as a tree structure using the Rich library.

-     :param cfg: A DictConfig composed by Hydra.
-     :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
-         "callbacks", "logger", "trainer", "paths", "extras")``.
-     :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
-     :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
+     Args:
+         cfg (DictConfig): A DictConfig composed by Hydra.
+         print_order (Sequence[str], optional): Determines in what order config components are printed.
+             Defaults to ``("data", "model", "callbacks", "logger", "trainer", "paths", "extras")``.
+         resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+             Defaults to ``False``.
+         save_to_file (bool, optional): Whether to export config to the hydra output folder.
+             Defaults to ``False``.
+         theme (str, optional): The theme to use for syntax highlighting. Defaults to "monokai".
+         background_color (str, optional): The background color to use for syntax highlighting.
+             Defaults to "default".
+
+     Returns:
+         None
      """
      style = "tree"
      tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
@@ -200,18 +209,13 @@
              )
          )

-     # add all the other fields to queue (not specified in `print_order`)
-     for field in cfg:
-         if field not in queue:
-             queue.append(field)
-
      # generate config tree from queue
      for field in queue:
          branch = tree.add(field, style=style, guide_style=style)

          config_group = cfg[field]
          if isinstance(config_group, DictConfig):
-             branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+             branch_content = OmegaConf.to_yaml(config_group, resolve=resolve).strip()
          else:
              branch_content = str(config_group)

@@ -224,13 +228,32 @@
              )
          )

+     # add all the other fields to queue (not specified in `print_order`)
+     other_fields = [field for field in cfg if field not in queue]
+     if other_fields:
+         others_branch = tree.add(Text("[others]"), style=style, guide_style=style)
+
+         other_cfg = OmegaConf.create({field: cfg[field] for field in other_fields})
+         branch_content = OmegaConf.to_yaml(other_cfg, resolve=resolve).strip()
+
+         others_branch.add(
+             rich.syntax.Syntax(
+                 branch_content, "yaml", theme=theme, background_color=background_color
+             )
+         )
+
      # print config tree
      rich.print(tree)

      # save config tree to file
      if save_to_file:
-         with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
-             rich.print(tree, file=file)
+         if not cfg.get("paths") or not cfg.paths.get("output_dir"):
+             log.error(
+                 "Cannot save config tree to file. 'paths.output_dir' is not specified in the config."
+             )
+         else:
+             with open(Path(cfg.path.output_dir, "config_tree.log"), "w") as file:
+                 rich.print(tree, file=file)


  @rank_zero_only
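A short sketch of the reworked `print_config_tree` behavior: fields outside `print_order` are now grouped under a single `[others]` branch instead of each receiving its own top-level node. The config below is a made-up example, not one shipped with the package.

```python
from omegaconf import OmegaConf

from fusion_bench.utils.rich_utils import print_config_tree

cfg = OmegaConf.create(
    {
        "model": {"name": "resnet18"},
        "trainer": {"max_epochs": 10},
        "seed": 42,
        "fast_dev_run": True,
    }
)

# "model" and "trainer" appear in print_order, so they get their own branches;
# "seed" and "fast_dev_run" are grouped together under the "[others]" branch.
print_config_tree(cfg, resolve=True)
```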