fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +18 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +58 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +8 -3
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
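Note that several linear-merging modules were renamed from `*_for_llama` to `*_for_causallm` in this release, so import paths change accordingly. A hedged sketch of the affected imports, using only the module paths taken from the file list above (the `__init__.py` changes suggest re-exports also exist, but direct submodule imports are shown here):

# Old 0.2.23 module paths (no longer present):
#   fusion_bench.method.linear.simple_average_for_llama
#   fusion_bench.method.linear.task_arithmetic_for_llama
# Their 0.2.25 equivalents, per the rename entries above:
from fusion_bench.method.linear import simple_average_for_causallm
from fusion_bench.method.linear import task_arithmetic_for_causallm
from fusion_bench.method.linear import ties_merging_for_causallm  # new in this release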
fusion_bench/utils/lazy_state_dict.py
CHANGED

@@ -1,3 +1,11 @@
+"""
+Utilities for handling model checkpoints and state dictionaries.
+
+This module provides classes and functions for lazily loading state dictionaries
+from various checkpoint formats, including PyTorch .bin files, SafeTensors files,
+and sharded checkpoints.
+"""
+
 import json
 import logging
 import os
@@ -43,6 +51,21 @@ def resolve_checkpoint_path(
     hf_cache_dir: Optional[str] = None,
     hf_proxies: Optional[Dict] = None,
 ):
+    """
+    Resolve a checkpoint path, downloading from Hugging Face Hub if necessary.
+
+    Args:
+        checkpoint: Path to local checkpoint or Hugging Face model ID.
+        hf_revision: Specific revision to download from HF Hub.
+        hf_cache_dir: Local cache directory for HF downloads.
+        hf_proxies: Proxy settings for HF downloads.
+
+    Returns:
+        Local path to the checkpoint.
+
+    Raises:
+        FileNotFoundError: If the checkpoint cannot be resolved.
+    """
     # If it's a local file or directory, return as is
     if os.path.exists(checkpoint):
         return checkpoint
@@ -64,11 +87,11 @@ def resolve_checkpoint_path(

 class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     """
-
+    A dictionary-like object that lazily loads tensors from model checkpoints.
     """

     _local_path: str
-    """
+    """Local path to the checkpoint."""
     _state_dict_cache: Optional[Dict]
     """Cache for the state dict, if enabled."""
     _index_filename: Optional[str]
@@ -76,6 +99,9 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     _index: Optional[Dict[str, str]]
     """Mapping of parameter names to checkpoint files."""

+    meta_module: TorchModelType = None
+    meta_module_class: Optional[Type[TorchModelType]] = None
+
     def __init__(
         self,
         checkpoint: str,
@@ -89,6 +115,8 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         hf_proxies: Optional[Dict] = None,
     ):
         """
+        Initialize LazyStateDict with a checkpoint path.
+
         Args:
             checkpoint (str): Path to the checkpoint file or directory.
             meta_module_class (Type[nn.Module], optional): Class of the meta module to instantiate.
@@ -113,6 +141,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
             self.meta_module_class = import_object(self.meta_module_class)
         self.meta_module = meta_module

+        # Instantiate meta module if class provided
         if self.meta_module_class is not None:
             with init_empty_weights():
                 self.meta_module = self.meta_module_class.from_pretrained(
@@ -123,6 +152,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                 proxies=hf_proxies,
             )

+        # Store original checkpoint path and resolve to local path
         self._checkpoint = checkpoint
         self._local_path = resolve_checkpoint_path(
             checkpoint,
@@ -131,10 +161,12 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
             hf_proxies=hf_proxies,
         )

+        # Detect checkpoint file type and set up indexing
         self._index, self._index_filename, self._checkpoint_files = (
             self._resolve_checkpoint_files(self._local_path)
         )

+        # Set up based on checkpoint type
         if self._index is not None:
             # if meta_module is provided, remove the keys that are not in the meta_module
             if self.meta_module is not None:
@@ -149,7 +181,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
             SAFE_WEIGHTS_NAME
         ):
-            #
+            # SafeTensors file: create index mapping all keys to this file
             with safe_open(
                 self._checkpoint_files[0], framework="pt", device=device
             ) as f:
@@ -161,6 +193,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
             WEIGHTS_NAME
         ):
+            # PyTorch .bin file: load entire state dict immediately
             log.info(f"Loading full state dict from {WEIGHTS_NAME}")
             self._state_dict_cache = torch.load(self._checkpoint_files[0])
             # if meta_module is provided, remove the keys that are not in the meta_module
@@ -170,6 +203,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                 if key not in meta_module_state_dict:
                     self._state_dict_cache.pop(key)
         else:
+            # Unsupported checkpoint format
             raise ValueError(
                 f"Cannot determine the type of checkpoint, please provide a checkpoint path to a file containing a whole state dict with file name {WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME}, or the index of a sharded checkpoint ending with `.index.json`."
             )
@@ -206,10 +240,19 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
         return deepcopy(self)

     def _resolve_checkpoint_files(self, checkpoint: str):
-
+        """
+        Detect and resolve checkpoint files based on the checkpoint path.
+
+        Handles single files, directories with state dict files, and sharded checkpoints.
+
+        Returns:
+            Tuple of (index_dict, index_filename, checkpoint_files)
+        """
+        # Reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
         checkpoint_files = None
         index_filename = None
         if os.path.isfile(checkpoint):
+            # Single file: check if it's an index or a state dict
             if str(checkpoint).endswith(".json"):
                 index_filename = checkpoint
             else:
@@ -229,7 +272,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     os.path.join(checkpoint, potential_state_safetensor[0])
                 ]
             else:
-                #
+                # Check for sharded checkpoints
                 potential_index = [
                     f for f in os.listdir(checkpoint) if f.endswith(".index.json")
                 ]
@@ -244,18 +287,22 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                     f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
                 )
             else:
+                # Invalid checkpoint path
                 raise ValueError(
                     "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
                     f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
                 )

+        # Load index file if present
         if index_filename is not None:
             checkpoint_folder = os.path.split(index_filename)[0]
             with open(index_filename) as f:
                 index = json.loads(f.read())

+            # Extract weight_map if present (standard format)
             if "weight_map" in index:
                 index = index["weight_map"]
+            # Get list of unique checkpoint files
             checkpoint_files = sorted(list(set(index.values())))
             checkpoint_files = [
                 os.path.join(checkpoint_folder, f) for f in checkpoint_files
@@ -267,6 +314,11 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     def _load_tensor_from_checkpoint_file(
         self, checkpoint_file: str, key: str, update_cache: bool = True
     ) -> torch.Tensor:
+        """
+        Load a tensor from the checkpoint file.
+        For safetensors, loads only the requested tensor.
+        For PyTorch files, loads the entire state dict on first access.
+        """
         if checkpoint_file.endswith(".safetensors"):
             with safe_open(checkpoint_file, framework="pt", device=self._device) as f:
                 tensor = f.get_tensor(key)
@@ -276,6 +328,7 @@ class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
                 self._state_dict_cache[key] = tensor
             return tensor
         else:
+            # PyTorch .bin file: load entire state dict
             state_dict = torch.load(checkpoint_file, map_location=self._device)
             if update_cache:
                 if self._state_dict_cache is not None:
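A minimal usage sketch based on the hunks above: LazyStateDict is a read-only Mapping[str, torch.Tensor] that materializes tensors only on access. The module path is taken from the file list; the checkpoint path below is illustrative.

from fusion_bench.utils.lazy_state_dict import LazyStateDict

sd = LazyStateDict("models/my-model")  # local file, directory, or HF model ID

for name in sd:          # keys come from the shard index or safetensors header
    t = sd[name]         # for safetensors, only this tensor is read from disk
    print(name, tuple(t.shape))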
fusion_bench/utils/misc.py
CHANGED
@@ -1,35 +1,126 @@
 from difflib import get_close_matches
-from typing import Any, Iterable, List, Optional
+from typing import Any, Iterable, List, Optional, TypeVar, Union
+
+T = TypeVar("T")

 __all__ = [
     "first",
     "has_length",
-    "
+    "join_lists",
     "attr_equal",
     "validate_and_suggest_corrections",
 ]


-def first(iterable: Iterable):
-
+def first(iterable: Iterable[T], default: Optional[T] = None) -> Optional[T]:
+    """
+    Return the first element of an iterable.

+    Args:
+        iterable: The iterable to get the first element from.
+        default: The value to return if the iterable is empty. If None and
+            the iterable is empty, raises StopIteration.

-
+    Returns:
+        The first element of the iterable, or the default value if empty.
+
+    Raises:
+        StopIteration: If the iterable is empty and no default is provided.
+        TypeError: If the object is not iterable.
     """
-
+    try:
+        iterator = iter(iterable)
+        return next(iterator)
+    except StopIteration:
+        if default is not None:
+            return default
+        raise
+    except TypeError as e:
+        raise TypeError(
+            f"Object of type {type(iterable).__name__} is not iterable"
+        ) from e
+
+
+def has_length(obj: Any) -> bool:
     """
+    Check if an object has a length (implements __len__) and len() works correctly.
+
+    Args:
+        obj: The object to check for length support.
+
+    Returns:
+        bool: True if the object supports len() and doesn't raise an error,
+            False otherwise.
+    """
+    if obj is None:
+        return False
+
     try:
-
-
+        # Check if __len__ method exists
+        if not hasattr(obj, "__len__"):
+            return False
+
+        # Try to get the length - this will raise TypeError for unsized objects
+        length = len(obj)
+
+        # Verify the length is a non-negative integer
+        return isinstance(length, int) and length >= 0
+    except (TypeError, AttributeError):
         # TypeError: len() of unsized object
+        # AttributeError: if __len__ is not callable somehow
         return False
+    except Exception:
+        # Any other unexpected error
+        return False
+
+
+def join_lists(list_of_lists: Iterable[Iterable[T]]) -> List[T]:
+    """
+    Flatten a collection of iterables into a single list.
+
+    Args:
+        list_of_lists: An iterable containing iterables to be flattened.

+    Returns:
+        List[T]: A new list containing all elements from the input iterables
+            in order.

-
-
-
-
-
+    Raises:
+        TypeError: If any item in list_of_lists is not iterable.
+
+    Examples:
+        >>> join_lists([[1, 2], [3, 4], [5]])
+        [1, 2, 3, 4, 5]
+        >>> join_lists([])
+        []
+        >>> join_lists([[], [1], [], [2, 3]])
+        [1, 2, 3]
+    """
+    if not list_of_lists:
+        return []
+
+    result = []
+    for i, item in enumerate(list_of_lists):
+        try:
+            # Check if item is iterable (but not string, which is iterable but
+            # usually not what we want to flatten character by character)
+            if isinstance(item, (str, bytes)):
+                raise TypeError(
+                    f"Item at index {i} is a string/bytes, not a list-like iterable"
+                )
+
+            # Try to extend with the item
+            result.extend(item)
+        except TypeError as e:
+            if "not iterable" in str(e):
+                raise TypeError(
+                    f"Item at index {i} (type: {type(item).__name__}) is not iterable"
+                ) from e
+            else:
+                # Re-raise our custom error or other TypeError
+                raise
+
+    return result


 def attr_equal(obj, attr: str, value):
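The new `first` accepts a `default`, and `join_lists` replaces a removed export. One caveat visible in the hunk: because the sentinel is `None`, `first(empty_iterable, default=None)` still raises StopIteration. A short usage sketch grounded in the docstrings above:

from fusion_bench.utils.misc import first, join_lists

assert first([3, 1, 2]) == 3
assert first([], default=0) == 0          # default suppresses StopIteration
assert join_lists([[1, 2], [3, 4], [5]]) == [1, 2, 3, 4, 5]
assert join_lists([]) == []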
fusion_bench/utils/packages.py
CHANGED
fusion_bench/utils/path.py
CHANGED
@@ -2,6 +2,8 @@ import logging
 import os
 from typing import List

+from lightning_utilities.core.rank_zero import rank_zero_only
+
 log = logging.getLogger(__name__)


@@ -25,6 +27,7 @@ def listdir_fullpath(dir: str) -> List[str]:
     return [os.path.join(dir, name) for name in names]


+@rank_zero_only
 def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
     """
     Creates a symbolic link from src_dir to dst_dir.
@@ -59,6 +62,10 @@ def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
         link_name = os.path.basename(src_dir)

     link_path = os.path.join(dst_dir, link_name)
+    # if the link already exists, skip
+    if os.path.exists(link_path):
+        log.warning(f"Symbolic link already exists, skipping: {link_path}")
+        return

     try:
         # if the system is windows, use the `mklink` command in "CMD" to create the symlink
fusion_bench/utils/pylogger.py
CHANGED
@@ -3,6 +3,12 @@ from typing import Mapping, Optional

 from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only

+__all__ = [
+    "RankedLogger",
+    "RankZeroLogger",
+    "get_rankzero_logger",
+]
+

 class RankedLogger(logging.LoggerAdapter):
     """A multi-GPU-friendly python command line logger."""
fusion_bench/utils/rich_utils.py
CHANGED
@@ -16,6 +16,7 @@ from rich.panel import Panel
 from rich.prompt import Prompt
 from rich.syntax import Syntax
 from rich.text import Text
+from rich.traceback import install as install_rich_traceback

 from fusion_bench.utils import pylogger

@@ -188,17 +189,21 @@ if __name__ == "__main__":
     display_available_styles()


-def setup_colorlogging(
+def setup_colorlogging(
+    force=False,
+    level=logging.INFO,
+    **kwargs,
+):
     """
     Sets up color logging for the application.
     """
     FORMAT = "%(message)s"

     logging.basicConfig(
-        level=
+        level=level,
         format=FORMAT,
         datefmt="[%X]",
         handlers=[RichHandler()],
         force=force,
-        **
+        **kwargs,
     )