fusion-bench 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +10 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/opcm/opcm.py +1 -0
  8. fusion_bench/method/pwe_moe/module.py +0 -2
  9. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  10. fusion_bench/mixins/__init__.py +2 -0
  11. fusion_bench/mixins/pyinstrument.py +174 -0
  12. fusion_bench/mixins/simple_profiler.py +106 -23
  13. fusion_bench/modelpool/__init__.py +2 -0
  14. fusion_bench/modelpool/base_pool.py +77 -14
  15. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  16. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  17. fusion_bench/models/__init__.py +35 -9
  18. fusion_bench/optim/__init__.py +40 -2
  19. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  20. fusion_bench/optim/muon.py +339 -0
  21. fusion_bench/programs/__init__.py +2 -0
  22. fusion_bench/programs/fabric_fusion_program.py +2 -2
  23. fusion_bench/programs/fusion_program.py +271 -0
  24. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  25. fusion_bench/utils/__init__.py +167 -21
  26. fusion_bench/utils/lazy_imports.py +91 -12
  27. fusion_bench/utils/lazy_state_dict.py +55 -5
  28. fusion_bench/utils/misc.py +104 -13
  29. fusion_bench/utils/packages.py +4 -0
  30. fusion_bench/utils/path.py +7 -0
  31. fusion_bench/utils/pylogger.py +6 -0
  32. fusion_bench/utils/rich_utils.py +1 -0
  33. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  34. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +1 -1
  35. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +48 -34
  36. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  37. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  38. fusion_bench_config/model_fusion.yaml +45 -0
  39. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  40. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  41. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  42. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  43. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  44. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  45. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  46. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  47. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  48. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@ from pathlib import Path
5
5
  from typing import Any, Callable, Dict, Iterable, List, Optional, Union # noqa: F401
6
6
 
7
7
  import lightning as L
8
- from lightning.fabric.utilities.rank_zero import rank_zero_only
8
+ from lightning_utilities.core.rank_zero import rank_zero_only
9
9
  from omegaconf import DictConfig, OmegaConf
10
10
  from torch import nn
11
11
  from tqdm.auto import tqdm
@@ -236,7 +236,7 @@ class FabricModelFusionProgram(
236
236
 
237
237
  # create symbol link to hydra output directory
238
238
  if (
239
- self.fabric.is_global_zero
239
+ rank_zero_only.rank == 0
240
240
  and self.log_dir is not None
241
241
  and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
242
242
  ):
@@ -0,0 +1,271 @@
1
+ import json
2
+ import os
3
+ from typing import Any, Dict, Iterable, List, Optional, Union
4
+
5
+ import lightning as L
6
+ from lightning_utilities.core.rank_zero import rank_zero_only
7
+ from omegaconf import DictConfig, OmegaConf
8
+ from torch import nn
9
+ from tqdm.auto import tqdm
10
+
11
+ from fusion_bench import (
12
+ BaseAlgorithm,
13
+ BaseHydraProgram,
14
+ BaseModelPool,
15
+ BaseTaskPool,
16
+ RuntimeConstants,
17
+ auto_register_config,
18
+ get_rankzero_logger,
19
+ import_object,
20
+ instantiate,
21
+ timeit_context,
22
+ )
23
+ from fusion_bench.utils.json import print_json
24
+ from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
25
+
26
+ log = get_rankzero_logger(__name__)
27
+
28
+
29
+ @auto_register_config
30
+ class ModelFusionProgram(BaseHydraProgram):
31
+ method: BaseAlgorithm
32
+ modelpool: BaseModelPool
33
+ taskpool: Optional[BaseTaskPool] = None
34
+
35
+ _config_mapping = BaseHydraProgram._config_mapping | {
36
+ "_method": "method",
37
+ "_modelpool": "modelpool",
38
+ "_taskpool": "taskpool",
39
+ "fast_dev_run": "fast_dev_run",
40
+ "seed": "seed",
41
+ "path": "path",
42
+ }
43
+
44
+ def __init__(
45
+ self,
46
+ method: DictConfig,
47
+ modelpool: DictConfig,
48
+ taskpool: Optional[DictConfig] = None,
49
+ *,
50
+ print_config: bool = True,
51
+ dry_run: bool = False,
52
+ report_save_path: Optional[str] = None,
53
+ merged_model_save_path: Optional[str] = None,
54
+ merged_model_save_kwargs: Optional[DictConfig] = None,
55
+ fast_dev_run: bool = False,
56
+ seed: Optional[int] = None,
57
+ print_function_call: bool = True,
58
+ path: DictConfig = None,
59
+ **kwargs,
60
+ ):
61
+ super().__init__(**kwargs)
62
+ self._method = method
63
+ self._modelpool = modelpool
64
+ self._taskpool = taskpool
65
+ self.report_save_path = report_save_path
66
+ self.merged_model_save_path = merged_model_save_path
67
+ self.merged_model_save_kwargs = merged_model_save_kwargs
68
+ self.fast_dev_run = fast_dev_run
69
+ self.seed = seed
70
+ self.path = path
71
+ RuntimeConstants.debug = fast_dev_run
72
+ RuntimeConstants.print_function_call = print_function_call
73
+ if path is not None:
74
+ RuntimeConstants.cache_dir = path.get("cache_dir", None)
75
+
76
+ if print_config:
77
+ print_config_tree(
78
+ self.config,
79
+ print_order=["method", "modelpool", "taskpool"],
80
+ )
81
+ if dry_run:
82
+ log.info("The program is running in dry-run mode. Exiting.")
83
+ exit(0)
84
+
85
+ def _instantiate_and_setup(
86
+ self, config: DictConfig, compat_load_fn: Optional[str] = None
87
+ ):
88
+ R"""
89
+ Instantiates and sets up an object based on the provided configuration.
90
+
91
+ This method performs the following steps:
92
+ 1. Checks if the configuration dictionary contains the key "_target_".
93
+ 2. If "_target_" is not found (for v0.1.x), attempts to instantiate the object using a compatible load function if provided.
94
+ - Logs a warning if "_target_" is missing.
95
+ - If `compat_load_fn` is provided, imports the function and uses it to instantiate the object.
96
+ - If `compat_load_fn` is not provided, raises a ValueError.
97
+ 3. If "_target_" is found (for v0.2.0 and above), attempts to import and instantiate the object using the `instantiate` function.
98
+ - Ensures the target can be imported.
99
+ - Uses the `instantiate` function with `_recursive_` set based on the configuration.
100
+ 4. Sets the `_program` attribute of the instantiated object to `self` if the object has this attribute.
101
+ 5. Returns the instantiated and set up object.
102
+ """
103
+ if "_target_" not in config:
104
+ log.warning(
105
+ "No '_target_' key found in config. Attempting to instantiate the object in a compatible way."
106
+ )
107
+ if compat_load_fn is not None:
108
+ compat_load_fn = import_object(compat_load_fn)
109
+ if rank_zero_only.rank == 0:
110
+ print_bordered(
111
+ OmegaConf.to_yaml(config),
112
+ title="instantiate compat object",
113
+ style="magenta",
114
+ code_style="yaml",
115
+ )
116
+ obj = compat_load_fn(config)
117
+ else:
118
+ raise ValueError(
119
+ "No load function provided. Please provide a load function to instantiate the object."
120
+ )
121
+ else:
122
+ # try to import the object from the target
123
+ # this checks if the target is valid and can be imported
124
+ import_object(config._target_)
125
+ obj = instantiate(
126
+ config,
127
+ _recursive_=config.get("_recursive_", False),
128
+ )
129
+ if hasattr(obj, "_program"):
130
+ obj._program = self
131
+ return obj
132
+
133
+ def save_merged_model(self, merged_model):
134
+ """
135
+ Saves the merged model to the specified path.
136
+ """
137
+ if self.merged_model_save_path is not None:
138
+ # path to save the merged model, use "{log_dir}" to refer to the logger directory
139
+ save_path: str = self.merged_model_save_path
140
+ if "{log_dir}" in save_path and self.log_dir is not None:
141
+ save_path = save_path.format(log_dir=self.log_dir)
142
+
143
+ if os.path.dirname(save_path):
144
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
145
+
146
+ # save the merged model
147
+ if self.merged_model_save_kwargs is not None:
148
+ merged_model_save_kwargs = self.merged_model_save_kwargs
149
+ else:
150
+ merged_model_save_kwargs = {}
151
+ with timeit_context(f"Saving the merged model to {save_path}"):
152
+ self.modelpool.save_model(
153
+ merged_model,
154
+ save_path,
155
+ **merged_model_save_kwargs,
156
+ )
157
+ else:
158
+ print("No save path specified for the merged model. Skipping saving.")
159
+
160
+ def evaluate_merged_model(
161
+ self,
162
+ taskpool: BaseTaskPool,
163
+ merged_model: Union[nn.Module, Dict, Iterable],
164
+ *args: Any,
165
+ **kwargs: Any,
166
+ ) -> Union[Dict, List, Any]:
167
+ """
168
+ Evaluates the merged model using the provided task pool.
169
+
170
+ Depending on the type of the merged model, this function handles the evaluation differently:
171
+ - If the merged model is an instance of `nn.Module`, it directly evaluates the model.
172
+ - If the merged model is a dictionary, it extracts the model from the dictionary and evaluates it.
173
+ The evaluation report is then updated with the remaining dictionary items.
174
+ - If the merged model is an iterable, it recursively evaluates each model in the iterable.
175
+ - Raises a `ValueError` if the merged model is of an invalid type.
176
+
177
+ Args:
178
+ taskpool: The task pool used for evaluating the merged model.
179
+ merged_model: The merged model to be evaluated. It can be an instance of `nn.Module`, a dictionary, or an iterable.
180
+ *args: Additional positional arguments to be passed to the `evaluate` method of the taskpool.
181
+ **kwargs: Additional keyword arguments to be passed to the `evaluate` method of the taskpool.
182
+
183
+ Returns:
184
+ The evaluation report. The type of the report depends on the type of the merged model:
185
+ - If the merged model is an instance of `nn.Module`, the report is a dictionary.
186
+ - If the merged model is a dictionary, the report is a dictionary updated with the remaining dictionary items.
187
+ - If the merged model is an iterable, the report is a list of evaluation reports.
188
+ """
189
+ if isinstance(merged_model, nn.Module):
190
+ report = taskpool.evaluate(merged_model, *args, **kwargs)
191
+ return report
192
+ elif isinstance(merged_model, Dict):
193
+ report = {}
194
+ for key, item in merged_model.items():
195
+ if isinstance(item, nn.Module):
196
+ report[key] = taskpool.evaluate(item, *args, **kwargs)
197
+ elif key == "models":
198
+ # for multi-model evaluation
199
+ report[key] = self.evaluate_merged_model(
200
+ taskpool, item, *args, **kwargs
201
+ )
202
+ else:
203
+ # metadata
204
+ report[key] = item
205
+ return report
206
+ elif isinstance(merged_model, Iterable):
207
+ return [
208
+ self.evaluate_merged_model(taskpool, m, *args, **kwargs)
209
+ for m in tqdm(merged_model, desc="Evaluating models")
210
+ ]
211
+ else:
212
+ raise ValueError(f"Invalid type for merged model: {type(merged_model)}")
213
+
214
+ def run(self):
215
+ """
216
+ Executes the model fusion program.
217
+ """
218
+ if self.seed is not None:
219
+ L.seed_everything(self.seed)
220
+
221
+ log.info("Running the model fusion program.")
222
+ # setup the modelpool, method, and taskpool
223
+ log.info("loading model pool")
224
+ self.modelpool = self._instantiate_and_setup(
225
+ self._modelpool,
226
+ compat_load_fn="fusion_bench.compat.modelpool.load_modelpool_from_config",
227
+ )
228
+ log.info("loading method")
229
+ self.method = self._instantiate_and_setup(
230
+ self._method,
231
+ compat_load_fn="fusion_bench.compat.method.load_algorithm_from_config",
232
+ )
233
+ if self._taskpool is not None:
234
+ log.info("loading task pool")
235
+ self.taskpool = self._instantiate_and_setup(
236
+ self._taskpool,
237
+ compat_load_fn="fusion_bench.compat.taskpool.load_taskpool_from_config",
238
+ )
239
+
240
+ self.method.on_run_start()
241
+ merged_model = self.method.run(self.modelpool)
242
+ self.method.on_run_end()
243
+
244
+ if merged_model is None:
245
+ log.info(
246
+ "No merged model returned by the method. Skipping saving and evaluation."
247
+ )
248
+ else:
249
+ self.save_merged_model(merged_model)
250
+ if self.taskpool is not None:
251
+ report = self.evaluate_merged_model(self.taskpool, merged_model)
252
+ try:
253
+ if rank_zero_only.rank == 0:
254
+ print_json(report, print_type=False)
255
+ except Exception as e:
256
+ log.warning(f"Failed to pretty print the report: {e}")
257
+ log.info(report)
258
+ if self.report_save_path is not None:
259
+ # save report (Dict) to a file
260
+ # if the directory of `save_report` does not exist, create it
261
+ if (
262
+ "{log_dir}" in self.report_save_path
263
+ and self.path.log_dir is not None
264
+ ):
265
+ self.report_save_path = self.report_save_path.format(
266
+ log_dir=self.path.log_dir
267
+ )
268
+ os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
269
+ json.dump(report, open(self.report_save_path, "w"))
270
+ else:
271
+ log.info("No task pool specified. Skipping evaluation.")
@@ -183,3 +183,18 @@ class CLIPTemplateFactory:
183
183
 
184
184
  def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
185
185
  return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
186
+
187
+
188
+ def get_num_classes(dataset_name: str) -> int:
189
+ classnames, _ = get_classnames_and_templates(dataset_name)
190
+ return len(classnames)
191
+
192
+
193
+ def get_classnames(dataset_name: str) -> List[str]:
194
+ classnames, _ = get_classnames_and_templates(dataset_name)
195
+ return classnames
196
+
197
+
198
+ def get_templates(dataset_name: str) -> List[Callable]:
199
+ _, templates = get_classnames_and_templates(dataset_name)
200
+ return templates
@@ -1,23 +1,169 @@
1
1
  # flake8: noqa: F401
2
- import importlib
3
- from typing import Iterable
2
+ import sys
3
+ from typing import TYPE_CHECKING
4
4
 
5
- from . import data, functools, path, pylogger
6
- from .cache_utils import *
7
- from .devices import *
8
- from .dtype import parse_dtype
9
- from .fabric import seed_everything_by_time
10
- from .instantiate_utils import (
11
- instantiate,
12
- is_instantiable,
13
- set_print_function_call,
14
- set_print_function_call_permeanent,
15
- )
16
- from .json import load_from_json, save_to_json
17
- from .lazy_state_dict import LazyStateDict
18
- from .misc import *
19
- from .packages import import_object
20
- from .parameters import *
21
- from .pylogger import get_rankzero_logger
22
- from .timer import timeit_context
23
- from .type import BoolStateDictType, StateDictType, TorchModelType
5
+ from . import functools
6
+ from .lazy_imports import LazyImporter
7
+
8
+ _extra_objects = {
9
+ "functools": functools,
10
+ }
11
+ _import_structure = {
12
+ "cache_utils": [
13
+ "cache_to_disk",
14
+ "cache_with_joblib",
15
+ "set_default_cache_dir",
16
+ ],
17
+ "data": [
18
+ "InfiniteDataLoader",
19
+ "load_tensor_from_file",
20
+ "train_validation_split",
21
+ "train_validation_test_split",
22
+ ],
23
+ "devices": [
24
+ "clear_cuda_cache",
25
+ "get_current_device",
26
+ "get_device",
27
+ "get_device_capabilities",
28
+ "get_device_memory_info",
29
+ "num_devices",
30
+ "to_device",
31
+ ],
32
+ "dtype": ["get_dtype", "parse_dtype"],
33
+ "fabric": ["seed_everything_by_time"],
34
+ "instantiate_utils": [
35
+ "instantiate",
36
+ "is_instantiable",
37
+ "set_print_function_call",
38
+ "set_print_function_call_permeanent",
39
+ ],
40
+ "json": ["load_from_json", "save_to_json", "print_json"],
41
+ "lazy_state_dict": ["LazyStateDict"],
42
+ "misc": [
43
+ "first",
44
+ "has_length",
45
+ "join_lists",
46
+ "validate_and_suggest_corrections",
47
+ ],
48
+ "packages": ["compare_versions", "import_object"],
49
+ "parameters": [
50
+ "check_parameters_all_equal",
51
+ "count_parameters",
52
+ "get_parameter_statistics",
53
+ "get_parameter_summary",
54
+ "human_readable",
55
+ "print_parameters",
56
+ "state_dict_to_vector",
57
+ "trainable_state_dict",
58
+ "vector_to_state_dict",
59
+ ],
60
+ "path": [
61
+ "create_symlink",
62
+ "listdir_fullpath",
63
+ "path_is_dir_and_not_empty",
64
+ ],
65
+ "pylogger": [
66
+ "RankedLogger",
67
+ "RankZeroLogger",
68
+ "get_rankzero_logger",
69
+ ],
70
+ "state_dict_arithmetic": [
71
+ "ArithmeticStateDict",
72
+ "state_dicts_check_keys",
73
+ "num_params_of_state_dict",
74
+ "state_dict_to_device",
75
+ "state_dict_flatten",
76
+ "state_dict_avg",
77
+ "state_dict_sub",
78
+ "state_dict_add",
79
+ "state_dict_add_scalar",
80
+ "state_dict_mul",
81
+ "state_dict_div",
82
+ "state_dict_power",
83
+ "state_dict_interpolation",
84
+ "state_dict_sum",
85
+ "state_dict_weighted_sum",
86
+ "state_dict_diff_abs",
87
+ "state_dict_binary_mask",
88
+ "state_dict_hadamard_product",
89
+ ],
90
+ "timer": ["timeit_context"],
91
+ "type": [
92
+ "BoolStateDictType",
93
+ "StateDictType",
94
+ "TorchModelType",
95
+ ],
96
+ }
97
+
98
+ if TYPE_CHECKING:
99
+ from .cache_utils import cache_to_disk, cache_with_joblib, set_default_cache_dir
100
+ from .data import (
101
+ InfiniteDataLoader,
102
+ load_tensor_from_file,
103
+ train_validation_split,
104
+ train_validation_test_split,
105
+ )
106
+ from .devices import (
107
+ clear_cuda_cache,
108
+ get_current_device,
109
+ get_device,
110
+ get_device_capabilities,
111
+ get_device_memory_info,
112
+ num_devices,
113
+ to_device,
114
+ )
115
+ from .dtype import get_dtype, parse_dtype
116
+ from .fabric import seed_everything_by_time
117
+ from .instantiate_utils import (
118
+ instantiate,
119
+ is_instantiable,
120
+ set_print_function_call,
121
+ set_print_function_call_permeanent,
122
+ )
123
+ from .json import load_from_json, print_json, save_to_json
124
+ from .lazy_state_dict import LazyStateDict
125
+ from .misc import first, has_length, join_lists, validate_and_suggest_corrections
126
+ from .packages import compare_versions, import_object
127
+ from .parameters import (
128
+ check_parameters_all_equal,
129
+ count_parameters,
130
+ get_parameter_statistics,
131
+ get_parameter_summary,
132
+ human_readable,
133
+ print_parameters,
134
+ state_dict_to_vector,
135
+ trainable_state_dict,
136
+ vector_to_state_dict,
137
+ )
138
+ from .path import create_symlink, listdir_fullpath, path_is_dir_and_not_empty
139
+ from .pylogger import RankedLogger, RankZeroLogger, get_rankzero_logger
140
+ from .state_dict_arithmetic import (
141
+ ArithmeticStateDict,
142
+ num_params_of_state_dict,
143
+ state_dict_add,
144
+ state_dict_add_scalar,
145
+ state_dict_avg,
146
+ state_dict_binary_mask,
147
+ state_dict_diff_abs,
148
+ state_dict_div,
149
+ state_dict_flatten,
150
+ state_dict_hadamard_product,
151
+ state_dict_interpolation,
152
+ state_dict_mul,
153
+ state_dict_power,
154
+ state_dict_sub,
155
+ state_dict_sum,
156
+ state_dict_to_device,
157
+ state_dict_weighted_sum,
158
+ state_dicts_check_keys,
159
+ )
160
+ from .timer import timeit_context
161
+ from .type import BoolStateDictType, StateDictType, TorchModelType
162
+
163
+ else:
164
+ sys.modules[__name__] = LazyImporter(
165
+ __name__,
166
+ globals()["__file__"],
167
+ _import_structure,
168
+ extra_objects=_extra_objects,
169
+ )
@@ -24,36 +24,78 @@ to publish it as a standalone package.
24
24
  import importlib
25
25
  import os
26
26
  from types import ModuleType
27
- from typing import Any
27
+ from typing import Any, Dict, List, Optional, Set, Union
28
28
 
29
29
 
30
30
  class LazyImporter(ModuleType):
31
- """Do lazy imports."""
31
+ """Lazy importer for modules and their components.
32
+
33
+ This class allows for lazy importing of modules, meaning modules are only
34
+ imported when they are actually accessed. This can help reduce startup
35
+ time and memory usage for large packages with many optional dependencies.
36
+
37
+ Attributes:
38
+ _modules: Set of module names available for import.
39
+ _class_to_module: Mapping from class/function names to their module names.
40
+ _objects: Dictionary of extra objects to include in the module.
41
+ _name: Name of the module.
42
+ _import_structure: Dictionary mapping module names to lists of their exports.
43
+ """
32
44
 
33
45
  # Very heavily inspired by optuna.integration._IntegrationModule
34
46
  # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
35
- def __init__(self, name, module_file, import_structure, extra_objects=None):
47
+ def __init__(
48
+ self,
49
+ name: str,
50
+ module_file: str,
51
+ import_structure: Dict[str, List[str]],
52
+ extra_objects: Optional[Dict[str, Any]] = None,
53
+ ) -> None:
54
+ """Initialize the LazyImporter.
55
+
56
+ Args:
57
+ name: The name of the module.
58
+ module_file: Path to the module file.
59
+ import_structure: Dictionary mapping module names to lists of their exports.
60
+ extra_objects: Optional dictionary of extra objects to include.
61
+ """
36
62
  super().__init__(name)
37
- self._modules = set(import_structure.keys())
38
- self._class_to_module = {}
63
+ self._modules: Set[str] = set(import_structure.keys())
64
+ self._class_to_module: Dict[str, str] = {}
39
65
  for key, values in import_structure.items():
40
66
  for value in values:
41
67
  self._class_to_module[value] = key
42
68
  # Needed for autocompletion in an IDE
43
- self.__all__ = list(import_structure.keys()) + sum(
69
+ self.__all__: List[str] = list(import_structure.keys()) + sum(
44
70
  import_structure.values(), []
45
71
  )
46
72
  self.__file__ = module_file
47
73
  self.__path__ = [os.path.dirname(module_file)]
48
- self._objects = {} if extra_objects is None else extra_objects
74
+ self._objects: Dict[str, Any] = {} if extra_objects is None else extra_objects
49
75
  self._name = name
50
76
  self._import_structure = import_structure
51
77
 
52
78
  # Needed for autocompletion in an IDE
53
- def __dir__(self):
79
+ def __dir__(self) -> List[str]:
80
+ """Return list of available attributes for autocompletion.
81
+
82
+ Returns:
83
+ List of all available attribute names.
84
+ """
54
85
  return super().__dir__() + self.__all__
55
86
 
56
87
  def __getattr__(self, name: str) -> Any:
88
+ """Get attribute lazily, importing the module if necessary.
89
+
90
+ Args:
91
+ name: The name of the attribute to retrieve.
92
+
93
+ Returns:
94
+ The requested attribute.
95
+
96
+ Raises:
97
+ AttributeError: If the attribute is not found in any module.
98
+ """
57
99
  if name in self._objects:
58
100
  return self._objects[name]
59
101
  if name in self._modules:
@@ -67,31 +109,68 @@ class LazyImporter(ModuleType):
67
109
  setattr(self, name, value)
68
110
  return value
69
111
 
70
- def _get_module(self, module_name: str):
112
+ def _get_module(self, module_name: str) -> ModuleType:
113
+ """Import and return the specified module.
114
+
115
+ Args:
116
+ module_name: Name of the module to import.
117
+
118
+ Returns:
119
+ The imported module.
120
+ """
71
121
  return importlib.import_module("." + module_name, self.__name__)
72
122
 
73
- def __reduce__(self):
123
+ def __reduce__(self) -> tuple:
124
+ """Support for pickling the LazyImporter.
125
+
126
+ Returns:
127
+ Tuple containing the class and arguments needed to reconstruct the object.
128
+ """
74
129
  return (self.__class__, (self._name, self.__file__, self._import_structure))
75
130
 
76
131
 
77
- class LazyModule(ModuleType):
132
+ class LazyPyModule(ModuleType):
78
133
  """Module wrapper for lazy import.
134
+
79
135
  Adapted from Optuna: https://github.com/optuna/optuna/blob/1f92d496b0c4656645384e31539e4ee74992ff55/optuna/__init__.py
80
136
 
81
137
  This class wraps the specified module and lazily imports it when it is actually accessed.
138
+ This can help reduce startup time and memory usage by deferring module imports
139
+ until they are needed.
82
140
 
83
141
  Args:
84
142
  name: Name of module to apply lazy import.
143
+
144
+ Attributes:
145
+ _name: The name of the module to be lazily imported.
85
146
  """
86
147
 
87
148
  def __init__(self, name: str) -> None:
149
+ """Initialize the LazyPyModule.
150
+
151
+ Args:
152
+ name: The name of the module to be lazily imported.
153
+ """
88
154
  super().__init__(name)
89
- self._name = name
155
+ self._name: str = name
90
156
 
91
157
  def _load(self) -> ModuleType:
158
+ """Load the actual module and update this object's dictionary.
159
+
160
+ Returns:
161
+ The loaded module.
162
+ """
92
163
  module = importlib.import_module(self._name)
93
164
  self.__dict__.update(module.__dict__)
94
165
  return module
95
166
 
96
167
  def __getattr__(self, item: str) -> Any:
168
+ """Get attribute from the lazily loaded module.
169
+
170
+ Args:
171
+ item: The name of the attribute to retrieve.
172
+
173
+ Returns:
174
+ The requested attribute from the loaded module.
175
+ """
97
176
  return getattr(self._load(), item)