fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/constants/__init__.py +5 -1
- fusion_bench/constants/runtime.py +111 -7
- fusion_bench/dataset/gsm8k.py +6 -2
- fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
- fusion_bench/method/__init__.py +10 -2
- fusion_bench/method/base_algorithm.py +29 -19
- fusion_bench/method/classification/image_classification_finetune.py +1 -2
- fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
- fusion_bench/metrics/model_kinship/__init__.py +2 -0
- fusion_bench/metrics/model_kinship/calculate.py +77 -0
- fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
- fusion_bench/metrics/model_kinship/utility.py +184 -0
- fusion_bench/metrics/nyuv2/__init__.py +31 -0
- fusion_bench/metrics/nyuv2/depth.py +30 -0
- fusion_bench/metrics/nyuv2/loss.py +40 -0
- fusion_bench/metrics/nyuv2/noise.py +24 -0
- fusion_bench/metrics/nyuv2/normal.py +34 -1
- fusion_bench/metrics/nyuv2/segmentation.py +35 -1
- fusion_bench/mixins/clip_classification.py +30 -2
- fusion_bench/mixins/lightning_fabric.py +46 -5
- fusion_bench/mixins/rich_live.py +76 -0
- fusion_bench/modelpool/base_pool.py +86 -5
- fusion_bench/models/masks/mask_model.py +8 -2
- fusion_bench/models/open_clip/modeling.py +7 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
- fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
- fusion_bench/scripts/cli.py +14 -0
- fusion_bench/scripts/webui.py +250 -17
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/data.py +100 -9
- fusion_bench/utils/devices.py +3 -1
- fusion_bench/utils/fabric.py +185 -4
- fusion_bench/utils/instantiate_utils.py +29 -18
- fusion_bench/utils/json.py +6 -0
- fusion_bench/utils/misc.py +16 -0
- fusion_bench/utils/rich_utils.py +123 -6
- fusion_bench/utils/validation.py +197 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
- fusion_bench_config/llama_full_finetune.yaml +4 -16
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +4 -13
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench/utils/auto.py +0 -31
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
fusion_bench/utils/fabric.py
CHANGED
|
@@ -1,24 +1,205 @@
|
|
|
1
1
|
import time
|
|
2
|
-
from
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from functools import wraps
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
|
3
5
|
|
|
4
6
|
import lightning as L
|
|
7
|
+
import torch
|
|
5
8
|
|
|
6
9
|
from fusion_bench.utils.pylogger import get_rankzero_logger
|
|
7
10
|
|
|
8
11
|
log = get_rankzero_logger(__name__)
|
|
9
12
|
|
|
13
|
+
T = TypeVar("T")
|
|
10
14
|
|
|
11
|
-
|
|
15
|
+
|
|
16
|
+
def seed_everything_by_time(fabric: Optional[L.Fabric] = None) -> int:
    """
    Seed every random number generator from the current Unix timestamp.

    The seed is drawn from ``time.time()`` (truncated to whole seconds) on a
    single process: the global zero process when ``fabric`` is given, or the
    current process otherwise. With a fabric, rank 0's value is broadcast so
    that every process calls ``L.seed_everything`` with the same seed.

    Args:
        fabric: Optional Lightning Fabric instance used to share the seed across
            processes. When ``None`` the seed is generated locally and nothing
            is broadcast.

    Returns:
        The seed that was passed to ``L.seed_everything``.

    Example:
        ```python
        import lightning as L
        from fusion_bench.utils.fabric import seed_everything_by_time

        # With fabric (distributed)
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        seed = seed_everything_by_time(fabric)
        print(f"All processes using seed: {seed}")

        # Without fabric (single process)
        seed = seed_everything_by_time()
        print(f"Using seed: {seed}")
        ```

    Note:
        - Only one process reads the clock; the others adopt its value.
        - The seed has one-second resolution, so runs started within the same
          second get the same seed; otherwise seeds differ between runs.
    """
    owns_seed = fabric is None or fabric.is_global_zero
    seed = None
    if owns_seed:
        seed = int(time.time())
        log.info(f"Generated time-based seed: {seed}")

    if fabric is not None:
        # Every rank must take part in the collective: non-zero ranks send
        # None and receive rank 0's value.
        log.debug(f"Broadcasting seed `{seed}` to all processes")
        fabric.barrier()
        seed = fabric.broadcast(seed, src=0)

    L.seed_everything(seed)
    return seed
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def is_distributed(fabric: Optional[L.Fabric] = None) -> bool:
    """
    Report whether execution spans more than one process.

    Args:
        fabric: Optional Lightning Fabric instance. ``None`` is treated as a
            single-process run.

    Returns:
        ``True`` when ``fabric`` is given and its ``world_size`` is greater
        than 1, ``False`` otherwise.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        if is_distributed(fabric):
            print("Running in distributed mode")
        ```
    """
    if fabric is None:
        return False
    return fabric.world_size > 1
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def get_world_info(fabric: Optional[L.Fabric] = None) -> Dict[str, Any]:
    """
    Summarize the process topology as a plain dictionary.

    Args:
        fabric: Optional Lightning Fabric instance. ``None`` describes a
            single-process run (world size 1, rank 0).

    Returns:
        A dict with the keys:
            - ``world_size``: total number of processes
            - ``global_rank``: global rank of this process
            - ``local_rank``: rank of this process on its node
            - ``is_global_zero``: whether this is the main process
            - ``is_distributed``: ``world_size > 1``

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        info = get_world_info(fabric)
        print(f"Process {info['global_rank']}/{info['world_size']}")
        ```
    """
    if fabric is None:
        world_size, global_rank, local_rank, is_zero = 1, 0, 0, True
    else:
        world_size = fabric.world_size
        global_rank = fabric.global_rank
        local_rank = fabric.local_rank
        is_zero = fabric.is_global_zero

    return {
        "world_size": world_size,
        "global_rank": global_rank,
        "local_rank": local_rank,
        "is_global_zero": is_zero,
        "is_distributed": world_size > 1,
    }
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def wait_for_everyone(
    fabric: Optional[L.Fabric] = None, message: Optional[str] = None
) -> None:
    """
    Block until every process reaches this point, optionally logging first.

    Thin wrapper around ``fabric.barrier()``. The message, if any, is logged
    only by the global zero process so it appears once.

    Args:
        fabric: Optional Lightning Fabric instance. When ``None`` this is a
            no-op.
        message: Optional text logged (rank 0 only) before the barrier.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        # Do some work...
        wait_for_everyone(fabric, "Waiting after model loading")
        # All processes synchronized
        ```
    """
    if fabric is None:
        return
    if message and fabric.is_global_zero:
        log.info(message)
    fabric.barrier()
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@contextmanager
def rank_zero_only_context(fabric: Optional[L.Fabric] = None):
    """
    Context manager that tells the enclosed block whether it is on global rank 0.

    A ``with`` statement cannot skip its own body, so this does NOT stop the
    block from running on other ranks. It yields a boolean flag, which the
    caller must check:

    - ``True`` on the global zero process, or when ``fabric`` is ``None``;
    - ``False`` on every other rank.

    Args:
        fabric: Optional Lightning Fabric instance. ``None`` is treated as a
            single-process run, so the flag is ``True``.

    Yields:
        bool: Whether the current process is the global zero process.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        with rank_zero_only_context(fabric) as is_rank_zero:
            if is_rank_zero:
                print("This prints only on rank 0")
                save_checkpoint(model, "checkpoint.pt")
        ```
    """
    # No cleanup is needed on exit, so there is no try/finally. Exceptions
    # raised inside the ``with`` body propagate unchanged.
    yield fabric is None or fabric.is_global_zero
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def print_on_rank_zero(*args, fabric: Optional[L.Fabric] = None, **kwargs) -> None:
    """
    Forward to ``print()`` on the global zero process only.

    Args:
        *args: Positional arguments forwarded to ``print()``.
        fabric: Optional Lightning Fabric instance. When ``None`` the message
            is always printed.
        **kwargs: Keyword arguments forwarded to ``print()`` (e.g. ``end``,
            ``file``).

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        print_on_rank_zero("Starting training", fabric=fabric)
        # Prints only on rank 0
        ```
    """
    on_rank_zero = fabric is None or fabric.is_global_zero
    if not on_rank_zero:
        return
    print(*args, **kwargs)
|
|
@@ -14,8 +14,8 @@ from lightning_utilities.core.rank_zero import rank_zero_only
|
|
|
14
14
|
from omegaconf import DictConfig, OmegaConf, SCMode
|
|
15
15
|
from omegaconf._utils import is_structured_config
|
|
16
16
|
from rich import print
|
|
17
|
-
|
|
18
|
-
from
|
|
17
|
+
|
|
18
|
+
from fusion_bench.utils.rich_utils import print_bordered
|
|
19
19
|
|
|
20
20
|
PRINT_FUNCTION_CALL = True
|
|
21
21
|
"""
|
|
@@ -67,12 +67,22 @@ def _resolve_callable_name(f: Callable[..., Any]) -> str:
|
|
|
67
67
|
return full_name
|
|
68
68
|
|
|
69
69
|
|
|
70
|
-
def
|
|
70
|
+
def _get_obj_str(obj: Any) -> str:
|
|
71
|
+
if isinstance(obj, (str, int, float, bool, type(None))):
|
|
72
|
+
return repr(obj)
|
|
73
|
+
else:
|
|
74
|
+
return f"'<{type(obj).__name__} object>'"
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _format_args_kwargs(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> str:
|
|
71
78
|
result_strings = []
|
|
72
79
|
if len(args) > 0:
|
|
73
|
-
result_strings.append(", ".join(
|
|
80
|
+
result_strings.append(", ".join(_get_obj_str(arg) for arg in args))
|
|
81
|
+
|
|
74
82
|
if len(kwargs) > 0:
|
|
75
|
-
result_strings.append(
|
|
83
|
+
result_strings.append(
|
|
84
|
+
", ".join(f"{k}={_get_obj_str(v)}" for k, v in kwargs.items())
|
|
85
|
+
)
|
|
76
86
|
|
|
77
87
|
if len(result_strings) == 0:
|
|
78
88
|
return ""
|
|
@@ -145,14 +155,14 @@ def _call_target(
|
|
|
145
155
|
if _partial_:
|
|
146
156
|
if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
|
|
147
157
|
call_str = f"functools.partial({_resolve_callable_name(_target_)}, {_format_args_kwargs(args, kwargs)})"
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
158
|
+
print_bordered(
|
|
159
|
+
call_str,
|
|
160
|
+
code_style="python",
|
|
161
|
+
title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
|
|
162
|
+
style="cyan",
|
|
163
|
+
expand=False,
|
|
164
|
+
print_fn=PRINT_FUNCTION_CALL_FUNC,
|
|
154
165
|
)
|
|
155
|
-
|
|
156
166
|
if CATCH_EXCEPTION:
|
|
157
167
|
try:
|
|
158
168
|
return functools.partial(_target_, *args, **kwargs)
|
|
@@ -169,12 +179,13 @@ def _call_target(
|
|
|
169
179
|
else:
|
|
170
180
|
if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
|
|
171
181
|
call_str = f"{_resolve_callable_name(_target_)}({_format_args_kwargs(args, kwargs)})"
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
182
|
+
print_bordered(
|
|
183
|
+
call_str,
|
|
184
|
+
code_style="python",
|
|
185
|
+
title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
|
|
186
|
+
style="green",
|
|
187
|
+
expand=False,
|
|
188
|
+
print_fn=PRINT_FUNCTION_CALL_FUNC,
|
|
178
189
|
)
|
|
179
190
|
if CATCH_EXCEPTION:
|
|
180
191
|
try:
|
fusion_bench/utils/json.py
CHANGED
|
@@ -2,6 +2,8 @@ import json
|
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from typing import TYPE_CHECKING, Any, Union
|
|
4
4
|
|
|
5
|
+
from fusion_bench.utils.validation import validate_file_exists
|
|
6
|
+
|
|
5
7
|
if TYPE_CHECKING:
|
|
6
8
|
from pyarrow.fs import FileSystem
|
|
7
9
|
|
|
@@ -49,6 +51,9 @@ def load_from_json(
|
|
|
49
51
|
|
|
50
52
|
Returns:
|
|
51
53
|
Union[dict, list]: the loaded object
|
|
54
|
+
|
|
55
|
+
Raises:
|
|
56
|
+
ValidationError: If the file doesn't exist (when using local filesystem)
|
|
52
57
|
"""
|
|
53
58
|
if filesystem is not None:
|
|
54
59
|
# Check if it's an fsspec-based filesystem (like s3fs)
|
|
@@ -65,6 +70,7 @@ def load_from_json(
|
|
|
65
70
|
return json.loads(json_data)
|
|
66
71
|
else:
|
|
67
72
|
# Use standard Python file operations
|
|
73
|
+
validate_file_exists(path)
|
|
68
74
|
with open(path, "r") as f:
|
|
69
75
|
return json.load(f)
|
|
70
76
|
|
fusion_bench/utils/misc.py
CHANGED
|
@@ -178,3 +178,19 @@ def validate_and_suggest_corrections(
|
|
|
178
178
|
if matches:
|
|
179
179
|
msg += f". Did you mean {', '.join(repr(m) for m in matches)}?"
|
|
180
180
|
raise ValueError(msg)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class DeprecationWarningMeta(type):
    """
    Metaclass that emits a ``DeprecationWarning`` each time one of its classes is instantiated.

    Usage::

        class OldThing(metaclass=DeprecationWarningMeta):
            ...

        OldThing()  # DeprecationWarning: "OldThing is deprecated and will be removed in a future version."

    Note that Python's default warning filters hide ``DeprecationWarning``
    outside ``__main__``; enable them (e.g. ``-W default``) to see the message.
    """

    def __call__(cls, *args, **kwargs):
        import warnings

        warnings.warn(
            f"{cls.__name__} is deprecated and will be removed in a future version.",
            DeprecationWarning,
            # stacklevel=2 attributes the warning to the line that calls
            # ``cls(...)``, not to this metaclass method.
            stacklevel=2,
        )
        return super().__call__(*args, **kwargs)
|
fusion_bench/utils/rich_utils.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Sequence
|
|
3
|
+
from typing import Optional, Sequence
|
|
4
4
|
|
|
5
5
|
import rich
|
|
6
6
|
import rich.syntax
|
|
@@ -19,6 +19,9 @@ from rich.text import Text
|
|
|
19
19
|
from rich.traceback import install as install_rich_traceback
|
|
20
20
|
|
|
21
21
|
from fusion_bench.utils import pylogger
|
|
22
|
+
from fusion_bench.utils.packages import _is_package_available
|
|
23
|
+
|
|
24
|
+
install_rich_traceback()
|
|
22
25
|
|
|
23
26
|
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
|
24
27
|
|
|
@@ -61,7 +64,31 @@ def display_available_styles():
|
|
|
61
64
|
console.print(Columns(style_samples, equal=True, expand=False))
|
|
62
65
|
|
|
63
66
|
|
|
64
|
-
def
|
|
67
|
+
def format_code_str(message: str, code_style="python"):
    """
    Optionally reformat a code snippet, then trim surrounding whitespace.

    The snippet is passed through ``black.format_str`` only when both hold:

    - ``code_style`` is ``"python"`` (case-insensitive);
    - the ``black`` package is installed.

    Input that black rejects as invalid is kept as-is.

    Args:
        message: Source text to format.
        code_style: Language of ``message``; only Python is reformatted.

    Returns:
        str: The possibly reformatted text, with leading and trailing
        whitespace stripped.
    """
    is_python = code_style.lower() == "python"
    if not (is_python and _is_package_available("black")):
        return message.strip()

    import black

    formatted = message
    try:
        formatted = black.format_str(message, mode=black.Mode())
    except black.InvalidInput:
        pass  # black could not parse it; keep the original text
    return formatted.strip()
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def print_bordered(
    message,
    title=None,
    style="blue",
    code_style=None,
    *,
    expand: bool = True,
    theme: str = "monokai",
    background_color: Optional[str] = "default",
    print_fn=print,
    format_code: bool = True,
):
    """
    Print a message with a colored border.

    Args:
        message (str): The message to print.
        title (str, optional): Title shown on the panel border. Defaults to None.
        style (str, optional): Rich style used for the border. Defaults to "blue".
        code_style (str, optional): Language/lexer name for syntax highlighting
            (e.g. "python", "yaml"). Set to None for plain text. Defaults to None.
        expand (bool, optional): Whether the panel stretches to the available
            width; forwarded to ``Panel``. Defaults to True.
        theme (str, optional): Syntax highlighting theme, used only when
            ``code_style`` is set. Defaults to "monokai".
        background_color (str, optional): Background color for highlighted code,
            used only when ``code_style`` is set. Defaults to "default".
        print_fn (Callable, optional): Function that renders the resulting panel.
            Defaults to the module-level ``print``.
        format_code (bool, optional): When True and ``code_style`` is set, the
            message first goes through ``format_code_str``, which runs ``black``
            on Python code if available and strips surrounding whitespace.
            Defaults to True.
    """
    if code_style:
        if format_code:
            message = format_code_str(message, code_style)
        content = Syntax(
            message,
            code_style,
            word_wrap=True,
            theme=theme,
            background_color=background_color,
        )
    else:
        content = Text(message)

    panel = Panel(content, title=title, border_style=style, expand=expand)
    print_fn(panel)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def print_code(
    message,
    title=None,
    code_style=None,
    *,
    expand: bool = True,
    theme: str = "monokai",
    background_color: Optional[str] = "default",
    print_fn=print,
):
    """
    Print ``message`` without a border, syntax-highlighted when a language is given.

    Args:
        message (str): The message or code to print.
        title (str, optional): Accepted for signature parity with
            ``print_bordered``; ignored here. Defaults to None.
        code_style (str, optional): Language/lexer name for highlighting
            (for example ``"python"``). When falsy, the message is printed as
            plain ``Text``. Defaults to ``None``.
        expand (bool, optional): Accepted for signature parity with
            ``print_bordered``; ignored here. Defaults to True.
        theme (str, optional): Rich syntax theme, used only when ``code_style``
            is set. Defaults to ``"monokai"``.
        background_color (str, optional): Background color for highlighted code,
            used only when ``code_style`` is set. Defaults to ``"default"``.
        print_fn (Callable, optional): Function that renders the Rich object.
            Defaults to the module-level ``print``.
    """
    if not code_style:
        print_fn(Text(message))
        return

    highlighted = Syntax(
        message,
        code_style,
        word_wrap=True,
        theme=theme,
        background_color=background_color,
    )
    print_fn(highlighted)
|
|
82
160
|
|
|
83
161
|
|
|
84
162
|
@rank_zero_only
|
|
@@ -95,6 +173,9 @@ def print_config_tree(
|
|
|
95
173
|
),
|
|
96
174
|
resolve: bool = False,
|
|
97
175
|
save_to_file: bool = False,
|
|
176
|
+
*,
|
|
177
|
+
theme: str = "monokai",
|
|
178
|
+
background_color: Optional[str] = "default",
|
|
98
179
|
) -> None:
|
|
99
180
|
"""Prints the contents of a DictConfig as a tree structure using the Rich library.
|
|
100
181
|
|
|
@@ -134,7 +215,14 @@ def print_config_tree(
|
|
|
134
215
|
else:
|
|
135
216
|
branch_content = str(config_group)
|
|
136
217
|
|
|
137
|
-
branch.add(
|
|
218
|
+
branch.add(
|
|
219
|
+
rich.syntax.Syntax(
|
|
220
|
+
branch_content,
|
|
221
|
+
"yaml",
|
|
222
|
+
theme=theme,
|
|
223
|
+
background_color=background_color,
|
|
224
|
+
)
|
|
225
|
+
)
|
|
138
226
|
|
|
139
227
|
# print config tree
|
|
140
228
|
rich.print(tree)
|
|
@@ -145,6 +233,35 @@ def print_config_tree(
|
|
|
145
233
|
rich.print(tree, file=file)
|
|
146
234
|
|
|
147
235
|
|
|
236
|
+
@rank_zero_only
def print_config_yaml(
    cfg: DictConfig,
    resolve: bool = False,
    output_path: Optional[str] = None,
    *,
    theme: str = "monokai",
    background_color: Optional[str] = "default",
) -> None:
    """
    Print a DictConfig as syntax-highlighted YAML using Rich, optionally saving it.

    Args:
        cfg: A DictConfig composed by Hydra.
        resolve: Whether to resolve reference fields of the DictConfig before
            dumping. Default is ``False``.
        output_path: Optional file path. When given, its parent directories are
            created and the plain YAML text from ``OmegaConf.to_yaml`` is
            written there (UTF-8). Defaults to ``None``, meaning nothing is
            written.
        theme: Rich syntax highlighting theme for the console output.
        background_color: Background color for the console output.
    """
    config_yaml = OmegaConf.to_yaml(cfg, resolve=resolve)
    rich.print(
        rich.syntax.Syntax(
            config_yaml, "yaml", theme=theme, background_color=background_color
        )
    )

    if output_path:
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        # Write the raw YAML, not the Rich-rendered Syntax object. Rich renders
        # at a fixed console width and may pad or crop lines, so its output is
        # not guaranteed to match the YAML that was dumped.
        output_file.write_text(config_yaml, encoding="utf-8")
|
|
263
|
+
|
|
264
|
+
|
|
148
265
|
@rank_zero_only
|
|
149
266
|
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
|
|
150
267
|
"""Prompts user to input tags from command line if no tags are provided in config.
|