fusion-bench 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +10 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +55 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +1 -0
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +48 -34
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
|
@@ -5,7 +5,7 @@ from pathlib import Path
|
|
|
5
5
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Union # noqa: F401
|
|
6
6
|
|
|
7
7
|
import lightning as L
|
|
8
|
-
from
|
|
8
|
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
9
9
|
from omegaconf import DictConfig, OmegaConf
|
|
10
10
|
from torch import nn
|
|
11
11
|
from tqdm.auto import tqdm
|
|
@@ -236,7 +236,7 @@ class FabricModelFusionProgram(
|
|
|
236
236
|
|
|
237
237
|
# create symbol link to hydra output directory
|
|
238
238
|
if (
|
|
239
|
-
|
|
239
|
+
rank_zero_only.rank == 0
|
|
240
240
|
and self.log_dir is not None
|
|
241
241
|
and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
|
|
242
242
|
):
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any, Dict, Iterable, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
import lightning as L
|
|
6
|
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
7
|
+
from omegaconf import DictConfig, OmegaConf
|
|
8
|
+
from torch import nn
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
|
|
11
|
+
from fusion_bench import (
|
|
12
|
+
BaseAlgorithm,
|
|
13
|
+
BaseHydraProgram,
|
|
14
|
+
BaseModelPool,
|
|
15
|
+
BaseTaskPool,
|
|
16
|
+
RuntimeConstants,
|
|
17
|
+
auto_register_config,
|
|
18
|
+
get_rankzero_logger,
|
|
19
|
+
import_object,
|
|
20
|
+
instantiate,
|
|
21
|
+
timeit_context,
|
|
22
|
+
)
|
|
23
|
+
from fusion_bench.utils.json import print_json
|
|
24
|
+
from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
|
|
25
|
+
|
|
26
|
+
log = get_rankzero_logger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@auto_register_config
class ModelFusionProgram(BaseHydraProgram):
    """Program that runs a fusion method on a model pool and optionally evaluates the result.

    The constructor only stores the (not yet instantiated) configurations of
    the method, the model pool and the task pool. The corresponding objects
    are created in :meth:`run`, which then runs the method, saves the merged
    model and, if a task pool is configured, evaluates it and stores the
    report.
    """

    method: BaseAlgorithm
    modelpool: BaseModelPool
    # stays `None` when no task pool configuration is given; `run` relies on this default
    taskpool: Optional[BaseTaskPool] = None

    # extends the mapping inherited from `BaseHydraProgram` with the entries of this program
    _config_mapping = BaseHydraProgram._config_mapping | {
        "_method": "method",
        "_modelpool": "modelpool",
        "_taskpool": "taskpool",
        "fast_dev_run": "fast_dev_run",
        "seed": "seed",
        "path": "path",
    }

    def __init__(
        self,
        method: DictConfig,
        modelpool: DictConfig,
        taskpool: Optional[DictConfig] = None,
        *,
        print_config: bool = True,
        dry_run: bool = False,
        report_save_path: Optional[str] = None,
        merged_model_save_path: Optional[str] = None,
        merged_model_save_kwargs: Optional[DictConfig] = None,
        fast_dev_run: bool = False,
        seed: Optional[int] = None,
        print_function_call: bool = True,
        path: Optional[DictConfig] = None,
        **kwargs,
    ):
        """Store the configurations and publish the global runtime flags.

        Args:
            method: Configuration of the fusion method.
            modelpool: Configuration of the model pool.
            taskpool: Optional configuration of the task pool used for evaluation.
            print_config: If True, print the configuration tree of the program.
            dry_run: If True, exit with status 0 at the end of the constructor.
            report_save_path: Where to write the evaluation report as JSON. May
                contain the placeholder ``{log_dir}``.
            merged_model_save_path: Where to save the merged model. May contain
                the placeholder ``{log_dir}``.
            merged_model_save_kwargs: Extra keyword arguments forwarded to
                ``modelpool.save_model``.
            fast_dev_run: Stored on the program and copied to
                ``RuntimeConstants.debug``.
            seed: If not None, passed to ``L.seed_everything`` at the start of `run`.
            print_function_call: Copied to ``RuntimeConstants.print_function_call``.
            path: Optional path configuration; its ``cache_dir`` and ``log_dir``
                entries are read when present.
            **kwargs: Forwarded to the base class constructor.

        Raises:
            SystemExit: With code 0 when ``dry_run`` is True.
        """
        super().__init__(**kwargs)
        self._method = method
        self._modelpool = modelpool
        self._taskpool = taskpool
        self.report_save_path = report_save_path
        self.merged_model_save_path = merged_model_save_path
        self.merged_model_save_kwargs = merged_model_save_kwargs
        self.fast_dev_run = fast_dev_run
        self.seed = seed
        self.path = path
        RuntimeConstants.debug = fast_dev_run
        RuntimeConstants.print_function_call = print_function_call
        if path is not None:
            RuntimeConstants.cache_dir = path.get("cache_dir", None)

        if print_config:
            print_config_tree(
                self.config,
                print_order=["method", "modelpool", "taskpool"],
            )
        if dry_run:
            log.info("The program is running in dry-run mode. Exiting.")
            # raise SystemExit directly: the `exit` builtin is injected by the
            # `site` module and is not guaranteed to exist in every interpreter setup
            raise SystemExit(0)

    def _instantiate_and_setup(
        self, config: DictConfig, compat_load_fn: Optional[str] = None
    ):
        R"""
        Instantiates and sets up an object based on the provided configuration.

        This method performs the following steps:
        1. Checks if the configuration contains the key "_target_".
        2. If "_target_" is not found (for v0.1.x), logs a warning and falls back to `compat_load_fn`:
            - If `compat_load_fn` is provided, it is imported with `import_object` and called with the configuration.
            - If `compat_load_fn` is not provided, a ValueError is raised.
        3. If "_target_" is found (for v0.2.0 and above):
            - The target is imported first, so that an invalid target fails early.
            - The object is created with `instantiate`, with `_recursive_` taken from the configuration (default: False).
        4. If the created object has a `_program` attribute, it is set to `self`.
        5. Returns the created object.

        Args:
            config: Configuration of the object to create.
            compat_load_fn: Import path of a loader function used when the
                configuration has no "_target_" key.

        Returns:
            The instantiated object.

        Raises:
            ValueError: If the configuration has no "_target_" key and no
                `compat_load_fn` is given.
        """
        if "_target_" not in config:
            log.warning(
                "No '_target_' key found in config. Attempting to instantiate the object in a compatible way."
            )
            if compat_load_fn is not None:
                compat_load_fn = import_object(compat_load_fn)
                if rank_zero_only.rank == 0:
                    print_bordered(
                        OmegaConf.to_yaml(config),
                        title="instantiate compat object",
                        style="magenta",
                        code_style="yaml",
                    )
                obj = compat_load_fn(config)
            else:
                raise ValueError(
                    "No load function provided. Please provide a load function to instantiate the object."
                )
        else:
            # try to import the object from the target
            # this checks if the target is valid and can be imported
            import_object(config._target_)
            obj = instantiate(
                config,
                _recursive_=config.get("_recursive_", False),
            )
        if hasattr(obj, "_program"):
            obj._program = self
        return obj

    def save_merged_model(self, merged_model):
        """
        Saves the merged model to `merged_model_save_path` through the model pool.

        Nothing is saved (only a message is logged) when no save path is
        configured.

        Args:
            merged_model: The object returned by the fusion method; it is
                passed unchanged to ``self.modelpool.save_model``.
        """
        if self.merged_model_save_path is None:
            log.info("No save path specified for the merged model. Skipping saving.")
            return

        # path to save the merged model, use "{log_dir}" to refer to the logger directory
        save_path: str = self.merged_model_save_path
        # NOTE(review): `self.log_dir` is not defined in this class; presumably it is
        # provided by the base class or a mixin -- confirm.
        if "{log_dir}" in save_path and self.log_dir is not None:
            save_path = save_path.format(log_dir=self.log_dir)

        # a bare file name has an empty dirname, which `os.makedirs` rejects
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # save the merged model
        if self.merged_model_save_kwargs is not None:
            merged_model_save_kwargs = self.merged_model_save_kwargs
        else:
            merged_model_save_kwargs = {}
        with timeit_context(f"Saving the merged model to {save_path}"):
            self.modelpool.save_model(
                merged_model,
                save_path,
                **merged_model_save_kwargs,
            )

    def evaluate_merged_model(
        self,
        taskpool: BaseTaskPool,
        merged_model: Union[nn.Module, Dict, Iterable],
        *args: Any,
        **kwargs: Any,
    ) -> Union[Dict, List, Any]:
        """
        Evaluates the merged model using the provided task pool.

        Depending on the type of the merged model, the evaluation is handled differently:
        - If the merged model is an instance of `nn.Module`, it is evaluated directly.
        - If the merged model is a dictionary, every value that is an `nn.Module` is
          evaluated and stored in the report under the same key, the value stored under
          the key "models" is evaluated recursively, and every other item is copied to
          the report unchanged (metadata).
        - If the merged model is any other iterable, each element is evaluated recursively.
        - Otherwise a `ValueError` is raised. Strings and bytes are rejected explicitly:
          they are iterables of themselves and would otherwise recurse without end.

        Args:
            taskpool: The task pool used for evaluating the merged model.
            merged_model: The merged model to be evaluated. It can be an instance of `nn.Module`, a dictionary, or an iterable.
            *args: Additional positional arguments to be passed to the `evaluate` method of the taskpool.
            **kwargs: Additional keyword arguments to be passed to the `evaluate` method of the taskpool.

        Returns:
            The evaluation report. The type of the report depends on the type of the merged model:
            - For an `nn.Module`, whatever `taskpool.evaluate` returns.
            - For a dictionary, a dictionary with the same keys.
            - For an iterable, a list with one report per element.

        Raises:
            ValueError: If the merged model is of an unsupported type.
        """
        if isinstance(merged_model, nn.Module):
            return taskpool.evaluate(merged_model, *args, **kwargs)

        if isinstance(merged_model, dict):
            report = {}
            for key, item in merged_model.items():
                if isinstance(item, nn.Module):
                    report[key] = taskpool.evaluate(item, *args, **kwargs)
                elif key == "models":
                    # for multi-model evaluation
                    report[key] = self.evaluate_merged_model(
                        taskpool, item, *args, **kwargs
                    )
                else:
                    # metadata
                    report[key] = item
            return report

        if isinstance(merged_model, Iterable) and not isinstance(
            merged_model, (str, bytes)
        ):
            return [
                self.evaluate_merged_model(taskpool, m, *args, **kwargs)
                for m in tqdm(merged_model, desc="Evaluating models")
            ]

        raise ValueError(f"Invalid type for merged model: {type(merged_model)}")

    def _save_report(self, report) -> None:
        """Write the evaluation report as JSON to `report_save_path`.

        The placeholder ``{log_dir}`` in the path is replaced by the
        ``log_dir`` entry of the path configuration when that entry is
        available; the resolved path is stored back on the program.
        """
        # `path` is optional and may lack a `log_dir` entry, so never access it unguarded
        log_dir = self.path.get("log_dir", None) if self.path is not None else None
        if "{log_dir}" in self.report_save_path and log_dir is not None:
            self.report_save_path = self.report_save_path.format(log_dir=log_dir)

        # if the directory of the report does not exist, create it
        # (a bare file name has an empty dirname, which `os.makedirs` rejects)
        report_dir = os.path.dirname(self.report_save_path)
        if report_dir:
            os.makedirs(report_dir, exist_ok=True)
        with open(self.report_save_path, "w", encoding="utf-8") as report_file:
            json.dump(report, report_file)

    def run(self):
        """
        Executes the model fusion program.

        The steps are: seed everything (if a seed is set), instantiate the
        model pool, the method and (if configured) the task pool, run the
        method on the model pool, save the merged model, evaluate it with the
        task pool and write the report to `report_save_path`. Saving and
        evaluation are skipped when the method returns None; evaluation is
        skipped when no task pool is configured.
        """
        if self.seed is not None:
            L.seed_everything(self.seed)

        log.info("Running the model fusion program.")
        # setup the modelpool, method, and taskpool
        log.info("loading model pool")
        self.modelpool = self._instantiate_and_setup(
            self._modelpool,
            compat_load_fn="fusion_bench.compat.modelpool.load_modelpool_from_config",
        )
        log.info("loading method")
        self.method = self._instantiate_and_setup(
            self._method,
            compat_load_fn="fusion_bench.compat.method.load_algorithm_from_config",
        )
        if self._taskpool is not None:
            log.info("loading task pool")
            self.taskpool = self._instantiate_and_setup(
                self._taskpool,
                compat_load_fn="fusion_bench.compat.taskpool.load_taskpool_from_config",
            )

        self.method.on_run_start()
        merged_model = self.method.run(self.modelpool)
        self.method.on_run_end()

        if merged_model is None:
            log.info(
                "No merged model returned by the method. Skipping saving and evaluation."
            )
            return

        self.save_merged_model(merged_model)
        if self.taskpool is None:
            log.info("No task pool specified. Skipping evaluation.")
            return

        report = self.evaluate_merged_model(self.taskpool, merged_model)
        try:
            if rank_zero_only.rank == 0:
                print_json(report, print_type=False)
        except Exception as e:
            # pretty printing is best effort; fall back to the plain representation
            log.warning(f"Failed to pretty print the report: {e}")
            log.info(report)
        if self.report_save_path is not None:
            # save report (Dict) to a file
            self._save_report(report)
|
|
@@ -183,3 +183,18 @@ class CLIPTemplateFactory:
|
|
|
183
183
|
|
|
184
184
|
def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
    """Return the class names and the templates registered for a dataset.

    Module-level shortcut that forwards to
    ``CLIPTemplateFactory.get_classnames_and_templates``.

    Args:
        dataset_name: Name of the dataset to look up.

    Returns:
        The pair ``(classnames, templates)`` produced by the factory.
    """
    classnames_and_templates = CLIPTemplateFactory.get_classnames_and_templates(
        dataset_name
    )
    return classnames_and_templates
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def get_num_classes(dataset_name: str) -> int:
    """Return the number of class names registered for a dataset.

    Args:
        dataset_name: Name of the dataset to look up.

    Returns:
        The length of the class-name list returned by
        ``get_classnames_and_templates``.
    """
    # first element of the pair is the list of class names
    return len(get_classnames_and_templates(dataset_name)[0])
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def get_classnames(dataset_name: str) -> List[str]:
    """Return only the class names registered for a dataset.

    Args:
        dataset_name: Name of the dataset to look up.

    Returns:
        The first element of the pair returned by
        ``get_classnames_and_templates``.
    """
    return get_classnames_and_templates(dataset_name)[0]
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def get_templates(dataset_name: str) -> List[Callable]:
    """Return only the templates registered for a dataset.

    Args:
        dataset_name: Name of the dataset to look up.

    Returns:
        The second element of the pair returned by
        ``get_classnames_and_templates``.
    """
    return get_classnames_and_templates(dataset_name)[1]
|
fusion_bench/utils/__init__.py
CHANGED
|
@@ -1,23 +1,169 @@
|
|
|
1
1
|
# flake8: noqa: F401
|
|
2
|
-
import
|
|
3
|
-
from typing import
|
|
2
|
+
import sys
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
4
|
|
|
5
|
-
from . import
|
|
6
|
-
from .
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
5
|
+
from . import functools
|
|
6
|
+
from .lazy_imports import LazyImporter
|
|
7
|
+
|
|
8
|
+
# Objects that are already imported above and are exposed as-is by the lazy module.
_extra_objects = {"functools": functools}

# Submodule name -> names exported from that submodule. The submodules are only
# imported by `LazyImporter` when one of these names is accessed.
# Keep this table in sync with the `TYPE_CHECKING` imports below.
_import_structure = {
    "cache_utils": [
        "cache_to_disk",
        "cache_with_joblib",
        "set_default_cache_dir",
    ],
    "data": [
        "InfiniteDataLoader",
        "load_tensor_from_file",
        "train_validation_split",
        "train_validation_test_split",
    ],
    "devices": [
        "clear_cuda_cache",
        "get_current_device",
        "get_device",
        "get_device_capabilities",
        "get_device_memory_info",
        "num_devices",
        "to_device",
    ],
    "dtype": ["get_dtype", "parse_dtype"],
    "fabric": ["seed_everything_by_time"],
    "instantiate_utils": [
        "instantiate",
        "is_instantiable",
        "set_print_function_call",
        "set_print_function_call_permeanent",
    ],
    "json": ["load_from_json", "save_to_json", "print_json"],
    "lazy_state_dict": ["LazyStateDict"],
    "misc": [
        "first",
        "has_length",
        "join_lists",
        "validate_and_suggest_corrections",
    ],
    "packages": ["compare_versions", "import_object"],
    "parameters": [
        "check_parameters_all_equal",
        "count_parameters",
        "get_parameter_statistics",
        "get_parameter_summary",
        "human_readable",
        "print_parameters",
        "state_dict_to_vector",
        "trainable_state_dict",
        "vector_to_state_dict",
    ],
    "path": [
        "create_symlink",
        "listdir_fullpath",
        "path_is_dir_and_not_empty",
    ],
    "pylogger": [
        "RankedLogger",
        "RankZeroLogger",
        "get_rankzero_logger",
    ],
    "state_dict_arithmetic": [
        "ArithmeticStateDict",
        "state_dicts_check_keys",
        "num_params_of_state_dict",
        "state_dict_to_device",
        "state_dict_flatten",
        "state_dict_avg",
        "state_dict_sub",
        "state_dict_add",
        "state_dict_add_scalar",
        "state_dict_mul",
        "state_dict_div",
        "state_dict_power",
        "state_dict_interpolation",
        "state_dict_sum",
        "state_dict_weighted_sum",
        "state_dict_diff_abs",
        "state_dict_binary_mask",
        "state_dict_hadamard_product",
    ],
    "timer": ["timeit_context"],
    "type": [
        "BoolStateDictType",
        "StateDictType",
        "TorchModelType",
    ],
}

if TYPE_CHECKING:
    # Real imports, seen only by static type checkers and IDEs.
    from .cache_utils import cache_to_disk, cache_with_joblib, set_default_cache_dir
    from .data import (
        InfiniteDataLoader,
        load_tensor_from_file,
        train_validation_split,
        train_validation_test_split,
    )
    from .devices import (
        clear_cuda_cache,
        get_current_device,
        get_device,
        get_device_capabilities,
        get_device_memory_info,
        num_devices,
        to_device,
    )
    from .dtype import get_dtype, parse_dtype
    from .fabric import seed_everything_by_time
    from .instantiate_utils import (
        instantiate,
        is_instantiable,
        set_print_function_call,
        set_print_function_call_permeanent,
    )
    from .json import load_from_json, print_json, save_to_json
    from .lazy_state_dict import LazyStateDict
    from .misc import first, has_length, join_lists, validate_and_suggest_corrections
    from .packages import compare_versions, import_object
    from .parameters import (
        check_parameters_all_equal,
        count_parameters,
        get_parameter_statistics,
        get_parameter_summary,
        human_readable,
        print_parameters,
        state_dict_to_vector,
        trainable_state_dict,
        vector_to_state_dict,
    )
    from .path import create_symlink, listdir_fullpath, path_is_dir_and_not_empty
    from .pylogger import RankedLogger, RankZeroLogger, get_rankzero_logger
    from .state_dict_arithmetic import (
        ArithmeticStateDict,
        num_params_of_state_dict,
        state_dict_add,
        state_dict_add_scalar,
        state_dict_avg,
        state_dict_binary_mask,
        state_dict_diff_abs,
        state_dict_div,
        state_dict_flatten,
        state_dict_hadamard_product,
        state_dict_interpolation,
        state_dict_mul,
        state_dict_power,
        state_dict_sub,
        state_dict_sum,
        state_dict_to_device,
        state_dict_weighted_sum,
        state_dicts_check_keys,
    )
    from .timer import timeit_context
    from .type import BoolStateDictType, StateDictType, TorchModelType

else:
    # At runtime, replace this module in `sys.modules` with a lazy proxy that
    # resolves the names of `_import_structure` on first attribute access.
    sys.modules[__name__] = LazyImporter(
        __name__,
        __file__,
        _import_structure,
        extra_objects=_extra_objects,
    )
|
|
@@ -24,36 +24,78 @@ to publish it as a standalone package.
|
|
|
24
24
|
import importlib
|
|
25
25
|
import os
|
|
26
26
|
from types import ModuleType
|
|
27
|
-
from typing import Any
|
|
27
|
+
from typing import Any, Dict, List, Optional, Set, Union
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class LazyImporter(ModuleType):
|
|
31
|
-
"""
|
|
31
|
+
"""Lazy importer for modules and their components.
|
|
32
|
+
|
|
33
|
+
This class allows for lazy importing of modules, meaning modules are only
|
|
34
|
+
imported when they are actually accessed. This can help reduce startup
|
|
35
|
+
time and memory usage for large packages with many optional dependencies.
|
|
36
|
+
|
|
37
|
+
Attributes:
|
|
38
|
+
_modules: Set of module names available for import.
|
|
39
|
+
_class_to_module: Mapping from class/function names to their module names.
|
|
40
|
+
_objects: Dictionary of extra objects to include in the module.
|
|
41
|
+
_name: Name of the module.
|
|
42
|
+
_import_structure: Dictionary mapping module names to lists of their exports.
|
|
43
|
+
"""
|
|
32
44
|
|
|
33
45
|
# Very heavily inspired by optuna.integration._IntegrationModule
|
|
34
46
|
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
|
35
|
-
def __init__(
|
|
47
|
+
def __init__(
|
|
48
|
+
self,
|
|
49
|
+
name: str,
|
|
50
|
+
module_file: str,
|
|
51
|
+
import_structure: Dict[str, List[str]],
|
|
52
|
+
extra_objects: Optional[Dict[str, Any]] = None,
|
|
53
|
+
) -> None:
|
|
54
|
+
"""Initialize the LazyImporter.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
name: The name of the module.
|
|
58
|
+
module_file: Path to the module file.
|
|
59
|
+
import_structure: Dictionary mapping module names to lists of their exports.
|
|
60
|
+
extra_objects: Optional dictionary of extra objects to include.
|
|
61
|
+
"""
|
|
36
62
|
super().__init__(name)
|
|
37
|
-
self._modules = set(import_structure.keys())
|
|
38
|
-
self._class_to_module = {}
|
|
63
|
+
self._modules: Set[str] = set(import_structure.keys())
|
|
64
|
+
self._class_to_module: Dict[str, str] = {}
|
|
39
65
|
for key, values in import_structure.items():
|
|
40
66
|
for value in values:
|
|
41
67
|
self._class_to_module[value] = key
|
|
42
68
|
# Needed for autocompletion in an IDE
|
|
43
|
-
self.__all__ = list(import_structure.keys()) + sum(
|
|
69
|
+
self.__all__: List[str] = list(import_structure.keys()) + sum(
|
|
44
70
|
import_structure.values(), []
|
|
45
71
|
)
|
|
46
72
|
self.__file__ = module_file
|
|
47
73
|
self.__path__ = [os.path.dirname(module_file)]
|
|
48
|
-
self._objects = {} if extra_objects is None else extra_objects
|
|
74
|
+
self._objects: Dict[str, Any] = {} if extra_objects is None else extra_objects
|
|
49
75
|
self._name = name
|
|
50
76
|
self._import_structure = import_structure
|
|
51
77
|
|
|
52
78
|
# Needed for autocompletion in an IDE
|
|
53
|
-
def __dir__(self):
|
|
79
|
+
def __dir__(self) -> List[str]:
|
|
80
|
+
"""Return list of available attributes for autocompletion.
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
List of all available attribute names.
|
|
84
|
+
"""
|
|
54
85
|
return super().__dir__() + self.__all__
|
|
55
86
|
|
|
56
87
|
def __getattr__(self, name: str) -> Any:
|
|
88
|
+
"""Get attribute lazily, importing the module if necessary.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
name: The name of the attribute to retrieve.
|
|
92
|
+
|
|
93
|
+
Returns:
|
|
94
|
+
The requested attribute.
|
|
95
|
+
|
|
96
|
+
Raises:
|
|
97
|
+
AttributeError: If the attribute is not found in any module.
|
|
98
|
+
"""
|
|
57
99
|
if name in self._objects:
|
|
58
100
|
return self._objects[name]
|
|
59
101
|
if name in self._modules:
|
|
@@ -67,31 +109,68 @@ class LazyImporter(ModuleType):
|
|
|
67
109
|
setattr(self, name, value)
|
|
68
110
|
return value
|
|
69
111
|
|
|
70
|
-
def _get_module(self, module_name: str):
|
|
112
|
+
def _get_module(self, module_name: str) -> ModuleType:
|
|
113
|
+
"""Import and return the specified module.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
module_name: Name of the module to import.
|
|
117
|
+
|
|
118
|
+
Returns:
|
|
119
|
+
The imported module.
|
|
120
|
+
"""
|
|
71
121
|
return importlib.import_module("." + module_name, self.__name__)
|
|
72
122
|
|
|
73
|
-
def __reduce__(self):
|
|
123
|
+
def __reduce__(self) -> tuple:
|
|
124
|
+
"""Support for pickling the LazyImporter.
|
|
125
|
+
|
|
126
|
+
Returns:
|
|
127
|
+
Tuple containing the class and arguments needed to reconstruct the object.
|
|
128
|
+
"""
|
|
74
129
|
return (self.__class__, (self._name, self.__file__, self._import_structure))
|
|
75
130
|
|
|
76
131
|
|
|
77
|
-
class
|
|
132
|
+
class LazyPyModule(ModuleType):
    """Placeholder module that imports the real module on first use.

    Adapted from Optuna: https://github.com/optuna/optuna/blob/1f92d496b0c4656645384e31539e4ee74992ff55/optuna/__init__.py

    The wrapped module is imported only when an attribute that is not yet
    present on the placeholder is requested. This can help reduce startup
    time and memory usage by deferring module imports until they are needed.

    Args:
        name: Name of module to apply lazy import.

    Attributes:
        _name: The name of the module to be lazily imported.
    """

    def __init__(self, name: str) -> None:
        """Remember which module to import later.

        Args:
            name: The name of the module to be lazily imported.
        """
        super().__init__(name)
        self._name: str = name

    def _load(self) -> ModuleType:
        """Import the real module and copy its namespace onto this object.

        Returns:
            The loaded module.
        """
        loaded = importlib.import_module(self._name)
        # after this copy, later lookups are served from the instance
        # dictionary and no longer reach `__getattr__`
        self.__dict__.update(loaded.__dict__)
        return loaded

    def __getattr__(self, item: str) -> Any:
        """Resolve an attribute that is not yet present on the placeholder.

        Args:
            item: The name of the attribute to retrieve.

        Returns:
            The requested attribute from the loaded module.
        """
        real_module = self._load()
        return getattr(real_module, item)
|