fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. fusion_bench/constants/__init__.py +5 -1
  2. fusion_bench/constants/runtime.py +111 -7
  3. fusion_bench/dataset/gsm8k.py +6 -2
  4. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  5. fusion_bench/method/__init__.py +10 -2
  6. fusion_bench/method/base_algorithm.py +29 -19
  7. fusion_bench/method/classification/image_classification_finetune.py +1 -2
  8. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  9. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  10. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  11. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  12. fusion_bench/metrics/model_kinship/utility.py +184 -0
  13. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  14. fusion_bench/metrics/nyuv2/depth.py +30 -0
  15. fusion_bench/metrics/nyuv2/loss.py +40 -0
  16. fusion_bench/metrics/nyuv2/noise.py +24 -0
  17. fusion_bench/metrics/nyuv2/normal.py +34 -1
  18. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  19. fusion_bench/mixins/clip_classification.py +30 -2
  20. fusion_bench/mixins/lightning_fabric.py +46 -5
  21. fusion_bench/mixins/rich_live.py +76 -0
  22. fusion_bench/modelpool/base_pool.py +86 -5
  23. fusion_bench/models/masks/mask_model.py +8 -2
  24. fusion_bench/models/open_clip/modeling.py +7 -0
  25. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  26. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  27. fusion_bench/scripts/cli.py +14 -0
  28. fusion_bench/scripts/webui.py +250 -17
  29. fusion_bench/utils/__init__.py +14 -0
  30. fusion_bench/utils/data.py +100 -9
  31. fusion_bench/utils/devices.py +3 -1
  32. fusion_bench/utils/fabric.py +185 -4
  33. fusion_bench/utils/instantiate_utils.py +29 -18
  34. fusion_bench/utils/json.py +6 -0
  35. fusion_bench/utils/misc.py +16 -0
  36. fusion_bench/utils/rich_utils.py +123 -6
  37. fusion_bench/utils/validation.py +197 -0
  38. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
  39. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
  40. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  41. fusion_bench_config/llama_full_finetune.yaml +4 -16
  42. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  43. fusion_bench_config/nyuv2_config.yaml +4 -13
  44. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  45. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  46. fusion_bench/utils/auto.py +0 -31
  47. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
  48. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
  49. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
  50. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,205 @@
1
1
  import time
2
- from typing import Optional
2
+ from contextlib import contextmanager
3
+ from functools import wraps
4
+ from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
3
5
 
4
6
  import lightning as L
7
+ import torch
5
8
 
6
9
  from fusion_bench.utils.pylogger import get_rankzero_logger
7
10
 
8
11
  log = get_rankzero_logger(__name__)
9
12
 
13
+ T = TypeVar("T")
10
14
 
11
- def seed_everything_by_time(fabric: Optional[L.Fabric] = None):
15
+
16
+ def seed_everything_by_time(fabric: Optional[L.Fabric] = None) -> int:
12
17
  """
13
- Set seed for all processes by time.
18
+ Set seed for all processes based on current timestamp.
19
+
20
+ This function generates a time-based seed on the global zero process and broadcasts
21
+ it to all other processes in a distributed setting to ensure reproducibility across
22
+ all workers. When no fabric instance is provided, it generates a seed locally without
23
+ synchronization.
24
+
25
+ Args:
26
+ fabric: Optional Lightning Fabric instance for distributed synchronization.
27
+ If None, seed is generated locally without broadcasting.
28
+
29
+ Returns:
30
+ The seed value used for random number generation.
31
+
32
+ Example:
33
+ ```python
34
+ import lightning as L
35
+ from fusion_bench.utils.fabric import seed_everything_by_time
36
+
37
+ # With fabric (distributed)
38
+ fabric = L.Fabric(accelerator="auto", devices=2)
39
+ fabric.launch()
40
+ seed = seed_everything_by_time(fabric)
41
+ print(f"All processes using seed: {seed}")
42
+
43
+ # Without fabric (single process)
44
+ seed = seed_everything_by_time()
45
+ print(f"Using seed: {seed}")
46
+ ```
47
+
48
+ Note:
49
+ - In distributed settings, only the global zero process generates the seed
50
+ - All other processes receive the broadcasted seed for consistency
51
+ - The seed is based on `time.time()`, so it will differ across runs
14
52
  """
15
- # set seed for all processes
53
+ # Generate seed on global zero process, None on others
16
54
  if fabric is None or fabric.is_global_zero:
17
55
  seed = int(time.time())
56
+ log.info(f"Generated time-based seed: {seed}")
18
57
  else:
19
58
  seed = None
59
+
60
+ # Broadcast seed to all processes in distributed setting
20
61
  if fabric is not None:
21
62
  log.debug(f"Broadcasting seed `{seed}` to all processes")
22
63
  fabric.barrier()
23
64
  seed = fabric.broadcast(seed, src=0)
65
+
66
+ # Apply seed to all random number generators
24
67
  L.seed_everything(seed)
68
+ return seed
69
+
70
+
71
+ def is_distributed(fabric: Optional[L.Fabric] = None) -> bool:
72
+ """
73
+ Check if running in distributed mode (multi-process).
74
+
75
+ Args:
76
+ fabric: Optional Lightning Fabric instance. If None, returns False.
77
+
78
+ Returns:
79
+ True if running with multiple processes, False otherwise.
80
+
81
+ Example:
82
+ ```python
83
+ fabric = L.Fabric(accelerator="auto", devices=2)
84
+ fabric.launch()
85
+ if is_distributed(fabric):
86
+ print("Running in distributed mode")
87
+ ```
88
+ """
89
+ return fabric is not None and fabric.world_size > 1
90
+
91
+
92
+ def get_world_info(fabric: Optional[L.Fabric] = None) -> Dict[str, Any]:
93
+ """
94
+ Get comprehensive information about the distributed setup.
95
+
96
+ Args:
97
+ fabric: Optional Lightning Fabric instance.
98
+
99
+ Returns:
100
+ Dictionary containing:
101
+ - world_size: Total number of processes
102
+ - global_rank: Global rank of current process
103
+ - local_rank: Local rank on current node
104
+ - is_global_zero: Whether this is the main process
105
+ - is_distributed: Whether running in distributed mode
106
+
107
+ Example:
108
+ ```python
109
+ fabric = L.Fabric(accelerator="auto", devices=2)
110
+ fabric.launch()
111
+ info = get_world_info(fabric)
112
+ print(f"Process {info['global_rank']}/{info['world_size']}")
113
+ ```
114
+ """
115
+ if fabric is None:
116
+ return {
117
+ "world_size": 1,
118
+ "global_rank": 0,
119
+ "local_rank": 0,
120
+ "is_global_zero": True,
121
+ "is_distributed": False,
122
+ }
123
+
124
+ return {
125
+ "world_size": fabric.world_size,
126
+ "global_rank": fabric.global_rank,
127
+ "local_rank": fabric.local_rank,
128
+ "is_global_zero": fabric.is_global_zero,
129
+ "is_distributed": fabric.world_size > 1,
130
+ }
131
+
132
+
133
+ def wait_for_everyone(
134
+ fabric: Optional[L.Fabric] = None, message: Optional[str] = None
135
+ ) -> None:
136
+ """
137
+ Synchronize all processes with optional logging.
138
+
139
+ This is a wrapper around fabric.barrier() with optional message logging.
140
+
141
+ Args:
142
+ fabric: Optional Lightning Fabric instance. If None, does nothing.
143
+ message: Optional message to log before synchronization.
144
+
145
+ Example:
146
+ ```python
147
+ fabric = L.Fabric(accelerator="auto", devices=2)
148
+ fabric.launch()
149
+
150
+ # Do some work...
151
+ wait_for_everyone(fabric, "Waiting after model loading")
152
+ # All processes synchronized
153
+ ```
154
+ """
155
+ if fabric is not None:
156
+ if message and fabric.is_global_zero:
157
+ log.info(message)
158
+ fabric.barrier()
159
+
160
+
161
+ @contextmanager
162
+ def rank_zero_only_context(fabric: Optional[L.Fabric] = None):
163
+ """
164
+ Context manager to execute code block only on global rank 0.
165
+
166
+ Args:
167
+ fabric: Optional Lightning Fabric instance.
168
+
169
+ Example:
170
+ ```python
171
+ fabric = L.Fabric(accelerator="auto", devices=2)
172
+ fabric.launch()
173
+
174
+ with rank_zero_only_context(fabric):
175
+ print("This prints only on rank 0")
176
+ save_checkpoint(model, "checkpoint.pt")
177
+ ```
178
+ """
179
+ should_execute = fabric is None or fabric.is_global_zero
180
+ try:
181
+ yield should_execute
182
+ finally:
183
+ pass
184
+
185
+
186
+ def print_on_rank_zero(*args, fabric: Optional[L.Fabric] = None, **kwargs) -> None:
187
+ """
188
+ Print message only on global rank 0.
189
+
190
+ Args:
191
+ *args: Arguments to pass to print().
192
+ fabric: Optional Lightning Fabric instance.
193
+ **kwargs: Keyword arguments to pass to print().
194
+
195
+ Example:
196
+ ```python
197
+ fabric = L.Fabric(accelerator="auto", devices=2)
198
+ fabric.launch()
199
+
200
+ print_on_rank_zero("Starting training", fabric=fabric)
201
+ # Prints only on rank 0
202
+ ```
203
+ """
204
+ if fabric is None or fabric.is_global_zero:
205
+ print(*args, **kwargs)
@@ -14,8 +14,8 @@ from lightning_utilities.core.rank_zero import rank_zero_only
14
14
  from omegaconf import DictConfig, OmegaConf, SCMode
15
15
  from omegaconf._utils import is_structured_config
16
16
  from rich import print
17
- from rich.panel import Panel
18
- from rich.syntax import Syntax
17
+
18
+ from fusion_bench.utils.rich_utils import print_bordered
19
19
 
20
20
  PRINT_FUNCTION_CALL = True
21
21
  """
@@ -67,12 +67,22 @@ def _resolve_callable_name(f: Callable[..., Any]) -> str:
67
67
  return full_name
68
68
 
69
69
 
70
- def _format_args_kwargs(args, kwargs):
70
+ def _get_obj_str(obj: Any) -> str:
71
+ if isinstance(obj, (str, int, float, bool, type(None))):
72
+ return repr(obj)
73
+ else:
74
+ return f"'<{type(obj).__name__} object>'"
75
+
76
+
77
+ def _format_args_kwargs(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> str:
71
78
  result_strings = []
72
79
  if len(args) > 0:
73
- result_strings.append(", ".join(repr(arg) for arg in args))
80
+ result_strings.append(", ".join(_get_obj_str(arg) for arg in args))
81
+
74
82
  if len(kwargs) > 0:
75
- result_strings.append(", ".join(f"{k}={repr(v)}" for k, v in kwargs.items()))
83
+ result_strings.append(
84
+ ", ".join(f"{k}={_get_obj_str(v)}" for k, v in kwargs.items())
85
+ )
76
86
 
77
87
  if len(result_strings) == 0:
78
88
  return ""
@@ -145,14 +155,14 @@ def _call_target(
145
155
  if _partial_:
146
156
  if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
147
157
  call_str = f"functools.partial({_resolve_callable_name(_target_)}, {_format_args_kwargs(args, kwargs)})"
148
- PRINT_FUNCTION_CALL_FUNC(
149
- Panel(
150
- Syntax(call_str, "python", theme="monokai", word_wrap=True),
151
- title="Instantiate by calling partial",
152
- border_style="cyan",
153
- )
158
+ print_bordered(
159
+ call_str,
160
+ code_style="python",
161
+ title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
162
+ style="cyan",
163
+ expand=False,
164
+ print_fn=PRINT_FUNCTION_CALL_FUNC,
154
165
  )
155
-
156
166
  if CATCH_EXCEPTION:
157
167
  try:
158
168
  return functools.partial(_target_, *args, **kwargs)
@@ -169,12 +179,13 @@ def _call_target(
169
179
  else:
170
180
  if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
171
181
  call_str = f"{_resolve_callable_name(_target_)}({_format_args_kwargs(args, kwargs)})"
172
- PRINT_FUNCTION_CALL_FUNC(
173
- Panel(
174
- Syntax(call_str, "python", theme="monokai", word_wrap=True),
175
- title="Instantiate by calling function",
176
- border_style="green",
177
- )
182
+ print_bordered(
183
+ call_str,
184
+ code_style="python",
185
+ title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
186
+ style="green",
187
+ expand=False,
188
+ print_fn=PRINT_FUNCTION_CALL_FUNC,
178
189
  )
179
190
  if CATCH_EXCEPTION:
180
191
  try:
@@ -2,6 +2,8 @@ import json
2
2
  from pathlib import Path
3
3
  from typing import TYPE_CHECKING, Any, Union
4
4
 
5
+ from fusion_bench.utils.validation import validate_file_exists
6
+
5
7
  if TYPE_CHECKING:
6
8
  from pyarrow.fs import FileSystem
7
9
 
@@ -49,6 +51,9 @@ def load_from_json(
49
51
 
50
52
  Returns:
51
53
  Union[dict, list]: the loaded object
54
+
55
+ Raises:
56
+ ValidationError: If the file doesn't exist (when using local filesystem)
52
57
  """
53
58
  if filesystem is not None:
54
59
  # Check if it's an fsspec-based filesystem (like s3fs)
@@ -65,6 +70,7 @@ def load_from_json(
65
70
  return json.loads(json_data)
66
71
  else:
67
72
  # Use standard Python file operations
73
+ validate_file_exists(path)
68
74
  with open(path, "r") as f:
69
75
  return json.load(f)
70
76
 
@@ -178,3 +178,19 @@ def validate_and_suggest_corrections(
178
178
  if matches:
179
179
  msg += f". Did you mean {', '.join(repr(m) for m in matches)}?"
180
180
  raise ValueError(msg)
181
+
182
+
183
+ class DeprecationWarningMeta(type):
184
+ """
185
+ Metaclass that issues a deprecation warning whenever a class using it is instantiated.
186
+ """
187
+
188
+ def __call__(cls, *args, **kwargs):
189
+ import warnings
190
+
191
+ warnings.warn(
192
+ f"{cls.__name__} is deprecated and will be removed in a future version. ",
193
+ DeprecationWarning,
194
+ stacklevel=2,
195
+ )
196
+ return super(DeprecationWarningMeta, cls).__call__(*args, **kwargs)
@@ -1,6 +1,6 @@
1
1
  import logging
2
2
  from pathlib import Path
3
- from typing import Sequence
3
+ from typing import Optional, Sequence
4
4
 
5
5
  import rich
6
6
  import rich.syntax
@@ -19,6 +19,9 @@ from rich.text import Text
19
19
  from rich.traceback import install as install_rich_traceback
20
20
 
21
21
  from fusion_bench.utils import pylogger
22
+ from fusion_bench.utils.packages import _is_package_available
23
+
24
+ install_rich_traceback()
22
25
 
23
26
  log = pylogger.RankedLogger(__name__, rank_zero_only=True)
24
27
 
@@ -61,7 +64,31 @@ def display_available_styles():
61
64
  console.print(Columns(style_samples, equal=True, expand=False))
62
65
 
63
66
 
64
- def print_bordered(message, title=None, style="blue", code_style=None):
67
+ def format_code_str(message: str, code_style="python"):
68
+ if code_style.lower() == "python" and _is_package_available("black"):
69
+ # Use black formatting for python code if black is available
70
+ import black
71
+
72
+ try:
73
+ message = black.format_str(message, mode=black.Mode())
74
+ except black.InvalidInput:
75
+ pass # If black fails, use the original message
76
+
77
+ return message.strip()
78
+
79
+
80
+ def print_bordered(
81
+ message,
82
+ title=None,
83
+ style="blue",
84
+ code_style=None,
85
+ *,
86
+ expand: bool = True,
87
+ theme: str = "monokai",
88
+ background_color: Optional[str] = "default",
89
+ print_fn=print,
90
+ format_code: bool = True,
91
+ ):
65
92
  """
66
93
  Print a message with a colored border.
67
94
 
@@ -73,12 +100,63 @@ def print_bordered(message, title=None, style="blue", code_style=None):
73
100
  Set to None for plain text. Defaults to "python".
74
101
  """
75
102
  if code_style:
76
- content = Syntax(message, code_style, theme="monokai", word_wrap=True)
103
+ if format_code:
104
+ message = format_code_str(message, code_style)
105
+ content = Syntax(
106
+ message,
107
+ code_style,
108
+ word_wrap=True,
109
+ theme=theme,
110
+ background_color=background_color,
111
+ )
77
112
  else:
78
113
  content = Text(message)
79
114
 
80
- panel = Panel(content, title=title, border_style=style)
81
- print(panel)
115
+ panel = Panel(content, title=title, border_style=style, expand=expand)
116
+ print_fn(panel)
117
+
118
+
119
+ def print_code(
120
+ message,
121
+ title=None,
122
+ code_style=None,
123
+ *,
124
+ expand: bool = True,
125
+ theme: str = "monokai",
126
+ background_color: Optional[str] = "default",
127
+ print_fn=print,
128
+ ):
129
+ """
130
+ Print code or plain text with optional syntax highlighting.
131
+
132
+ Args:
133
+ message (str): The message or code to print.
134
+ title (str, optional): Optional title associated with this output. Currently
135
+ not used by this function, but kept for API compatibility. Defaults to None.
136
+ code_style (str, optional): The language/lexer name for syntax highlighting
137
+ (for example, ``"python"``). If ``None``, the message is rendered as plain
138
+ text without syntax highlighting. Defaults to ``None``.
139
+ expand (bool, optional): Placeholder flag for API symmetry with other printing
140
+ helpers. It is not used in the current implementation. Defaults to True.
141
+ theme (str, optional): Name of the Rich syntax highlighting theme to use when
142
+ ``code_style`` is provided. Defaults to ``"monokai"``.
143
+ background_color (str, optional): Background color style to apply to the code
144
+ block when using syntax highlighting. Defaults to ``"default"``.
145
+ print_fn (Callable, optional): Function used to render the resulting Rich
146
+ object. Defaults to :func:`rich.print`.
147
+ """
148
+ if code_style:
149
+ content = Syntax(
150
+ message,
151
+ code_style,
152
+ word_wrap=True,
153
+ theme=theme,
154
+ background_color=background_color,
155
+ )
156
+ else:
157
+ content = Text(message)
158
+
159
+ print_fn(content)
82
160
 
83
161
 
84
162
  @rank_zero_only
@@ -95,6 +173,9 @@ def print_config_tree(
95
173
  ),
96
174
  resolve: bool = False,
97
175
  save_to_file: bool = False,
176
+ *,
177
+ theme: str = "monokai",
178
+ background_color: Optional[str] = "default",
98
179
  ) -> None:
99
180
  """Prints the contents of a DictConfig as a tree structure using the Rich library.
100
181
 
@@ -134,7 +215,14 @@ def print_config_tree(
134
215
  else:
135
216
  branch_content = str(config_group)
136
217
 
137
- branch.add(rich.syntax.Syntax(branch_content, "yaml"))
218
+ branch.add(
219
+ rich.syntax.Syntax(
220
+ branch_content,
221
+ "yaml",
222
+ theme=theme,
223
+ background_color=background_color,
224
+ )
225
+ )
138
226
 
139
227
  # print config tree
140
228
  rich.print(tree)
@@ -145,6 +233,35 @@ def print_config_tree(
145
233
  rich.print(tree, file=file)
146
234
 
147
235
 
236
+ @rank_zero_only
237
+ def print_config_yaml(
238
+ cfg: DictConfig,
239
+ resolve: bool = False,
240
+ output_path: Optional[str] = False,
241
+ *,
242
+ theme: str = "monokai",
243
+ background_color: Optional[str] = "default",
244
+ ) -> None:
245
+ """
246
+ Prints the contents of a DictConfig as a YAML string using the Rich library.
247
+
248
+ Args:
249
+ cfg: A DictConfig composed by Hydra.
250
+ resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
251
+ output_path: Optional path to export the config YAML to. If provided, the file is written to this path.
252
+ """
253
+ config_yaml = OmegaConf.to_yaml(cfg, resolve=resolve)
254
+ syntax = rich.syntax.Syntax(
255
+ config_yaml, "yaml", theme=theme, background_color=background_color
256
+ )
257
+ rich.print(syntax)
258
+
259
+ if output_path:
260
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
261
+ with open(Path(output_path), "w") as file:
262
+ rich.print(syntax, file=file)
263
+
264
+
148
265
  @rank_zero_only
149
266
  def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
150
267
  """Prompts user to input tags from command line if no tags are provided in config.