fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff shows the changes between the two package versions as published to the public registry.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +18 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +58 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +8 -3
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/clip_vision/taskpool.py
CHANGED

@@ -27,7 +27,7 @@ from tqdm.autonotebook import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
-from fusion_bench import RuntimeConstants
+from fusion_bench import RuntimeConstants, auto_register_config
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.mixins import HydraConfigMixin, LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
@@ -86,6 +86,7 @@ class LayerWiseFeatureSaver:
         torch.save(features, self.save_path)
 
 
+@auto_register_config
 class CLIPVisionModelTaskPool(
     HydraConfigMixin,
     LightningFabricMixin,
@@ -134,11 +135,13 @@ class CLIPVisionModelTaskPool(
         layer_wise_feature_first_token_only: bool = True,
         layer_wise_feature_max_num: Optional[int] = None,
         fast_dev_run: Optional[bool] = None,
+        move_to_device: bool = True,
         **kwargs,
     ):
         """
         Initialize the CLIPVisionModelTaskPool.
         """
+        super().__init__(**kwargs)
         self._test_datasets = test_datasets
         self._processor = processor
         self._data_processor = data_processor
@@ -159,7 +162,6 @@ class CLIPVisionModelTaskPool(
             self.fast_dev_run = RuntimeConstants().debug
         else:
             self.fast_dev_run = fast_dev_run
-        super().__init__(**kwargs)
 
     def setup(self):
         """
@@ -220,7 +222,9 @@ class CLIPVisionModelTaskPool(
             for name, dataset in self.test_datasets.items()
         }
         self.test_dataloaders = {
-            name: self.fabric.setup_dataloaders(dataloader)
+            name: self.fabric.setup_dataloaders(
+                dataloader, move_to_device=self.move_to_device
+            )
             for name, dataloader in self.test_dataloaders.items()
         }
 
@@ -273,6 +277,8 @@ class CLIPVisionModelTaskPool(
                 task_name=task_name,
             )
             logits: Tensor = outputs["logits"]
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
 
             loss = F.cross_entropy(logits, targets)
             loss_metric.update(loss.detach().cpu())
@@ -321,7 +327,8 @@ class CLIPVisionModelTaskPool(
             self.clip_model,
             processor=self.processor,
         )
-
+        if self.move_to_device:
+            classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
         # collect basic model information
         training_params, all_params = count_parameters(model)
         report["model_info"] = {
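Taken together, these taskpool changes make device placement optional: with `move_to_device=False`, Fabric no longer moves the test dataloaders or the classifier, and the evaluation loop instead aligns the targets with wherever the logits come out (useful when the merged model is itself spread across devices). Below is a minimal, self-contained sketch of that pattern, not the FusionBench code itself:

```python
import torch
import torch.nn.functional as F


def eval_step(model: torch.nn.Module, images: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mirror of the added lines above: align targets to the logits' device
    before the loss, so batches left on CPU still work with a GPU/sharded model."""
    logits = model(images)
    if logits.device != targets.device:
        targets = targets.to(logits.device)
    return F.cross_entropy(logits, targets)


# Toy usage; with a real sharded model only the final logits' device matters here.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
images, targets = torch.randn(4, 3, 8, 8), torch.randint(0, 10, (4,))
print(eval_step(model, images, targets).item())
```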
fusion_bench/tasks/clip_classification/__init__.py
CHANGED

@@ -183,3 +183,18 @@ class CLIPTemplateFactory:
 
 def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
     return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
+
+
+def get_num_classes(dataset_name: str) -> int:
+    classnames, _ = get_classnames_and_templates(dataset_name)
+    return len(classnames)
+
+
+def get_classnames(dataset_name: str) -> List[str]:
+    classnames, _ = get_classnames_and_templates(dataset_name)
+    return classnames
+
+
+def get_templates(dataset_name: str) -> List[Callable]:
+    _, templates = get_classnames_and_templates(dataset_name)
+    return templates
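The three new helpers are thin conveniences over `get_classnames_and_templates`. A usage sketch; the dataset name is illustrative and must be one that `CLIPTemplateFactory` knows about:

```python
from fusion_bench.tasks.clip_classification import (
    get_classnames,
    get_num_classes,
    get_templates,
)

# "cifar10" is an assumed registered dataset name; substitute your own.
print(get_num_classes("cifar10"))        # number of class names
print(get_classnames("cifar10")[:3])     # first few class names
template = get_templates("cifar10")[0]   # a callable mapping classname -> prompt
print(template(get_classnames("cifar10")[0]))
```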
fusion_bench/utils/__init__.py
CHANGED
@@ -1,23 +1,169 @@
 # flake8: noqa: F401
-import …
-from typing import …
+import sys
+from typing import TYPE_CHECKING
 
-from . import …
-from . …
-… (17 more removed lines, content truncated in the diff rendering)
+from . import functools
+from .lazy_imports import LazyImporter
+
+_extra_objects = {
+    "functools": functools,
+}
+_import_structure = {
+    "cache_utils": [
+        "cache_to_disk",
+        "cache_with_joblib",
+        "set_default_cache_dir",
+    ],
+    "data": [
+        "InfiniteDataLoader",
+        "load_tensor_from_file",
+        "train_validation_split",
+        "train_validation_test_split",
+    ],
+    "devices": [
+        "clear_cuda_cache",
+        "get_current_device",
+        "get_device",
+        "get_device_capabilities",
+        "get_device_memory_info",
+        "num_devices",
+        "to_device",
+    ],
+    "dtype": ["get_dtype", "parse_dtype"],
+    "fabric": ["seed_everything_by_time"],
+    "instantiate_utils": [
+        "instantiate",
+        "is_instantiable",
+        "set_print_function_call",
+        "set_print_function_call_permeanent",
+    ],
+    "json": ["load_from_json", "save_to_json", "print_json"],
+    "lazy_state_dict": ["LazyStateDict"],
+    "misc": [
+        "first",
+        "has_length",
+        "join_lists",
+        "validate_and_suggest_corrections",
+    ],
+    "packages": ["compare_versions", "import_object"],
+    "parameters": [
+        "check_parameters_all_equal",
+        "count_parameters",
+        "get_parameter_statistics",
+        "get_parameter_summary",
+        "human_readable",
+        "print_parameters",
+        "state_dict_to_vector",
+        "trainable_state_dict",
+        "vector_to_state_dict",
+    ],
+    "path": [
+        "create_symlink",
+        "listdir_fullpath",
+        "path_is_dir_and_not_empty",
+    ],
+    "pylogger": [
+        "RankedLogger",
+        "RankZeroLogger",
+        "get_rankzero_logger",
+    ],
+    "state_dict_arithmetic": [
+        "ArithmeticStateDict",
+        "state_dicts_check_keys",
+        "num_params_of_state_dict",
+        "state_dict_to_device",
+        "state_dict_flatten",
+        "state_dict_avg",
+        "state_dict_sub",
+        "state_dict_add",
+        "state_dict_add_scalar",
+        "state_dict_mul",
+        "state_dict_div",
+        "state_dict_power",
+        "state_dict_interpolation",
+        "state_dict_sum",
+        "state_dict_weighted_sum",
+        "state_dict_diff_abs",
+        "state_dict_binary_mask",
+        "state_dict_hadamard_product",
+    ],
+    "timer": ["timeit_context"],
+    "type": [
+        "BoolStateDictType",
+        "StateDictType",
+        "TorchModelType",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .cache_utils import cache_to_disk, cache_with_joblib, set_default_cache_dir
+    from .data import (
+        InfiniteDataLoader,
+        load_tensor_from_file,
+        train_validation_split,
+        train_validation_test_split,
+    )
+    from .devices import (
+        clear_cuda_cache,
+        get_current_device,
+        get_device,
+        get_device_capabilities,
+        get_device_memory_info,
+        num_devices,
+        to_device,
+    )
+    from .dtype import get_dtype, parse_dtype
+    from .fabric import seed_everything_by_time
+    from .instantiate_utils import (
+        instantiate,
+        is_instantiable,
+        set_print_function_call,
+        set_print_function_call_permeanent,
+    )
+    from .json import load_from_json, print_json, save_to_json
+    from .lazy_state_dict import LazyStateDict
+    from .misc import first, has_length, join_lists, validate_and_suggest_corrections
+    from .packages import compare_versions, import_object
+    from .parameters import (
+        check_parameters_all_equal,
+        count_parameters,
+        get_parameter_statistics,
+        get_parameter_summary,
+        human_readable,
+        print_parameters,
+        state_dict_to_vector,
+        trainable_state_dict,
+        vector_to_state_dict,
+    )
+    from .path import create_symlink, listdir_fullpath, path_is_dir_and_not_empty
+    from .pylogger import RankedLogger, RankZeroLogger, get_rankzero_logger
+    from .state_dict_arithmetic import (
+        ArithmeticStateDict,
+        num_params_of_state_dict,
+        state_dict_add,
+        state_dict_add_scalar,
+        state_dict_avg,
+        state_dict_binary_mask,
+        state_dict_diff_abs,
+        state_dict_div,
+        state_dict_flatten,
+        state_dict_hadamard_product,
+        state_dict_interpolation,
+        state_dict_mul,
+        state_dict_power,
+        state_dict_sub,
+        state_dict_sum,
+        state_dict_to_device,
+        state_dict_weighted_sum,
+        state_dicts_check_keys,
+    )
+    from .timer import timeit_context
+    from .type import BoolStateDictType, StateDictType, TorchModelType
+
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
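The rewritten `__init__.py` follows the `LazyImporter` pattern (the same one used by `transformers` and Optuna): under `TYPE_CHECKING` everything is imported for static analyzers, but at runtime the module object in `sys.modules` is swapped for a `LazyImporter`, so a submodule such as `state_dict_arithmetic` is only imported on first attribute access. A quick way to observe the deferral (a sketch, assuming nothing else has imported the submodule first):

```python
import sys

import fusion_bench.utils as utils

# The heavy submodule is not imported just by importing the package...
print("fusion_bench.utils.state_dict_arithmetic" in sys.modules)  # expect False

# ...but the first attribute access routes through LazyImporter.__getattr__,
# which imports the submodule and caches the attribute on the wrapper.
avg = utils.state_dict_avg
print("fusion_bench.utils.state_dict_arithmetic" in sys.modules)  # expect True
```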
fusion_bench/utils/devices.py
CHANGED
@@ -39,7 +39,12 @@ def clear_cuda_cache():
         log.warning("CUDA is not available. No cache to clear.")
 
 
-def to_device(obj: T, device: Optional[torch.device], **kwargs: Any) -> T:
+def to_device(
+    obj: T,
+    device: Optional[torch.device],
+    copy_on_move: bool = False,
+    **kwargs: Any,
+) -> T:
     """
     Move a given object to the specified device.
 
@@ -49,12 +54,20 @@ def to_device(obj: T, device: Optional[torch.device], **kwargs: Any) -> T:
     Args:
         obj: The object to be moved to the device. This can be a torch.Tensor, torch.nn.Module, list, tuple, or dict.
         device (torch.device): The target device to move the object to. This can be `None`.
-
+        copy_on_move (bool, optional): Whether to force a copy operation when moving tensors to a different device.
+            If True, tensors will be copied when moved to a different device (copy=True is passed to tensor.to()).
+            If False (default), tensors are moved without forcing a copy operation, allowing PyTorch to optimize
+            the operation. This parameter only affects torch.Tensor objects; modules and other types are unaffected.
+            Defaults to False.
+        **kwargs: Additional keyword arguments to be passed to the `to` method of torch.Tensor or torch.nn.Module.
+            For example, `non_blocking=True`, `dtype=torch.float16`. Note that if `copy_on_move=True`, the `copy`
+            keyword argument will be automatically set and should not be provided manually.
 
     Returns:
         The object moved to the specified device. The type of the returned object matches the type of the input object.
 
     Examples:
+        ```python
         >>> tensor = torch.tensor([1, 2, 3])
         >>> to_device(tensor, torch.device('cuda'))
         tensor([1, 2, 3], device='cuda:0')
@@ -66,17 +79,26 @@ def to_device(obj: T, device: Optional[torch.device], **kwargs: Any) -> T:
         >>> data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
         >>> to_device(data, torch.device('cuda'))
         [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')]
+
+        >>> # Force copy when moving to different device
+        >>> tensor = torch.tensor([1, 2, 3], device='cpu')
+        >>> copied_tensor = to_device(tensor, torch.device('cuda'), copy_on_move=True)
+        >>> # tensor and copied_tensor will have different memory locations
+        ```
     """
-    if isinstance(obj, (torch.Tensor, torch.nn.Module)):
+    if isinstance(obj, torch.Tensor):
+        if copy_on_move:
+            if obj.device != torch.device(device):
+                kwargs["copy"] = True
+        return obj.to(device, **kwargs)
+    elif isinstance(obj, torch.nn.Module):
         return obj.to(device, **kwargs)
     elif isinstance(obj, list):
-        return [to_device(o, device) for o in obj]
+        return [to_device(o, device, **kwargs) for o in obj]
     elif isinstance(obj, tuple):
-        return tuple(to_device(o, device) for o in obj)
+        return tuple(to_device(o, device, **kwargs) for o in obj)
     elif isinstance(obj, dict):
-        for key in obj:
-            obj[key] = to_device(obj[key], device)
-        return obj
+        return {key: to_device(value, device, **kwargs) for key, value in obj.items()}
     else:
         # the default behavior is to return the object as is
         return obj
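Beyond the new `copy_on_move` flag, note that `**kwargs` is now propagated into nested containers, so e.g. a `dtype` passed at the top level reaches every tensor leaf, and dicts are rebuilt instead of mutated in place. A CPU-only sketch of the behavior:

```python
import torch

from fusion_bench.utils.devices import to_device

batch = {"x": torch.ones(2), "meta": [torch.zeros(1), "keep-as-is"]}

# kwargs now reach nested tensors; non-tensor leaves are returned unchanged.
moved = to_device(batch, torch.device("cpu"), dtype=torch.float64)
print(moved["x"].dtype)   # torch.float64
print(moved["meta"][1])   # keep-as-is

# copy_on_move only forces a copy when the device actually changes, so a
# CPU -> CPU move can still alias the original storage.
same = to_device(batch["x"], torch.device("cpu"), copy_on_move=True)
print(same.data_ptr() == batch["x"].data_ptr())  # expect True
```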
fusion_bench/utils/lazy_imports.py
CHANGED

@@ -24,36 +24,78 @@ to publish it as a standalone package.
 import importlib
 import os
 from types import ModuleType
-from typing import Any
+from typing import Any, Dict, List, Optional, Set, Union
 
 
 class LazyImporter(ModuleType):
-    """…"""
+    """Lazy importer for modules and their components.
+
+    This class allows for lazy importing of modules, meaning modules are only
+    imported when they are actually accessed. This can help reduce startup
+    time and memory usage for large packages with many optional dependencies.
+
+    Attributes:
+        _modules: Set of module names available for import.
+        _class_to_module: Mapping from class/function names to their module names.
+        _objects: Dictionary of extra objects to include in the module.
+        _name: Name of the module.
+        _import_structure: Dictionary mapping module names to lists of their exports.
+    """
 
     # Very heavily inspired by optuna.integration._IntegrationModule
     # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
-    def __init__(self, name, module_file, import_structure, extra_objects=None):
+    def __init__(
+        self,
+        name: str,
+        module_file: str,
+        import_structure: Dict[str, List[str]],
+        extra_objects: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initialize the LazyImporter.
+
+        Args:
+            name: The name of the module.
+            module_file: Path to the module file.
+            import_structure: Dictionary mapping module names to lists of their exports.
+            extra_objects: Optional dictionary of extra objects to include.
+        """
         super().__init__(name)
-        self._modules = set(import_structure.keys())
-        self._class_to_module = {}
+        self._modules: Set[str] = set(import_structure.keys())
+        self._class_to_module: Dict[str, str] = {}
         for key, values in import_structure.items():
             for value in values:
                 self._class_to_module[value] = key
         # Needed for autocompletion in an IDE
-        self.__all__ = list(import_structure.keys()) + sum(
+        self.__all__: List[str] = list(import_structure.keys()) + sum(
             import_structure.values(), []
         )
         self.__file__ = module_file
         self.__path__ = [os.path.dirname(module_file)]
-        self._objects = {} if extra_objects is None else extra_objects
+        self._objects: Dict[str, Any] = {} if extra_objects is None else extra_objects
         self._name = name
         self._import_structure = import_structure
 
     # Needed for autocompletion in an IDE
-    def __dir__(self):
+    def __dir__(self) -> List[str]:
+        """Return list of available attributes for autocompletion.
+
+        Returns:
+            List of all available attribute names.
+        """
         return super().__dir__() + self.__all__
 
     def __getattr__(self, name: str) -> Any:
+        """Get attribute lazily, importing the module if necessary.
+
+        Args:
+            name: The name of the attribute to retrieve.
+
+        Returns:
+            The requested attribute.
+
+        Raises:
+            AttributeError: If the attribute is not found in any module.
+        """
         if name in self._objects:
             return self._objects[name]
         if name in self._modules:
@@ -67,31 +109,68 @@ class LazyImporter(ModuleType):
         setattr(self, name, value)
         return value
 
-    def _get_module(self, module_name: str):
+    def _get_module(self, module_name: str) -> ModuleType:
+        """Import and return the specified module.
+
+        Args:
+            module_name: Name of the module to import.
+
+        Returns:
+            The imported module.
+        """
         return importlib.import_module("." + module_name, self.__name__)
 
-    def __reduce__(self):
+    def __reduce__(self) -> tuple:
+        """Support for pickling the LazyImporter.
+
+        Returns:
+            Tuple containing the class and arguments needed to reconstruct the object.
+        """
         return (self.__class__, (self._name, self.__file__, self._import_structure))
 
 
-class …
+class LazyPyModule(ModuleType):
     """Module wrapper for lazy import.
+
     Adapted from Optuna: https://github.com/optuna/optuna/blob/1f92d496b0c4656645384e31539e4ee74992ff55/optuna/__init__.py
 
     This class wraps specified module and lazily import it when they are actually accessed.
+    This can help reduce startup time and memory usage by deferring module imports
+    until they are needed.
 
     Args:
         name: Name of module to apply lazy import.
+
+    Attributes:
+        _name: The name of the module to be lazily imported.
     """
 
     def __init__(self, name: str) -> None:
+        """Initialize the LazyPyModule.
+
+        Args:
+            name: The name of the module to be lazily imported.
+        """
         super().__init__(name)
-        self._name = name
+        self._name: str = name
 
     def _load(self) -> ModuleType:
+        """Load the actual module and update this object's dictionary.
+
+        Returns:
+            The loaded module.
+        """
         module = importlib.import_module(self._name)
         self.__dict__.update(module.__dict__)
         return module
 
     def __getattr__(self, item: str) -> Any:
+        """Get attribute from the lazily loaded module.
+
+        Args:
+            item: The name of the attribute to retrieve.
+
+        Returns:
+            The requested attribute from the loaded module.
+        """
         return getattr(self._load(), item)
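`LazyPyModule` complements `LazyImporter`: rather than a package facade, it wraps a single module and defers the actual import until first attribute access. A small sketch; `numpy` is just an illustrative target:

```python
from fusion_bench.utils.lazy_imports import LazyPyModule

np = LazyPyModule("numpy")  # nothing is imported yet, only the name is stored

# First attribute access calls _load(), which imports numpy and copies its
# namespace into the wrapper so subsequent lookups hit __dict__ directly.
print(np.array([1, 2, 3]).sum())  # expect 6
```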