fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registry.
Files changed (77)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +18 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/ensemble.py +17 -2
  8. fusion_bench/method/linear/__init__.py +6 -2
  9. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  10. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  11. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/simple_average.py +2 -2
  15. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  16. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  17. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  18. fusion_bench/method/wudi/__init__.py +1 -0
  19. fusion_bench/method/wudi/wudi.py +105 -0
  20. fusion_bench/mixins/__init__.py +2 -0
  21. fusion_bench/mixins/lightning_fabric.py +4 -0
  22. fusion_bench/mixins/pyinstrument.py +174 -0
  23. fusion_bench/mixins/serialization.py +25 -78
  24. fusion_bench/mixins/simple_profiler.py +106 -23
  25. fusion_bench/modelpool/__init__.py +2 -0
  26. fusion_bench/modelpool/base_pool.py +77 -14
  27. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  28. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  29. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  30. fusion_bench/models/__init__.py +35 -9
  31. fusion_bench/models/hf_clip.py +4 -0
  32. fusion_bench/models/hf_utils.py +2 -1
  33. fusion_bench/models/model_card_templates/default.md +8 -1
  34. fusion_bench/models/wrappers/ensemble.py +136 -7
  35. fusion_bench/optim/__init__.py +40 -2
  36. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  37. fusion_bench/optim/muon.py +339 -0
  38. fusion_bench/programs/__init__.py +2 -0
  39. fusion_bench/programs/fabric_fusion_program.py +2 -2
  40. fusion_bench/programs/fusion_program.py +271 -0
  41. fusion_bench/scripts/cli.py +2 -2
  42. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  43. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  44. fusion_bench/utils/__init__.py +167 -21
  45. fusion_bench/utils/devices.py +30 -8
  46. fusion_bench/utils/lazy_imports.py +91 -12
  47. fusion_bench/utils/lazy_state_dict.py +58 -5
  48. fusion_bench/utils/misc.py +104 -13
  49. fusion_bench/utils/packages.py +4 -0
  50. fusion_bench/utils/path.py +7 -0
  51. fusion_bench/utils/pylogger.py +6 -0
  52. fusion_bench/utils/rich_utils.py +8 -3
  53. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  54. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
  55. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
  56. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  57. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  58. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  59. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  60. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  61. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  62. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  63. fusion_bench_config/model_fusion.yaml +45 -0
  64. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  65. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  66. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  73. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  74. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  75. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
--- a/fusion_bench/taskpool/clip_vision/taskpool.py
+++ b/fusion_bench/taskpool/clip_vision/taskpool.py
@@ -27,7 +27,7 @@ from tqdm.autonotebook import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
-from fusion_bench import RuntimeConstants
+from fusion_bench import RuntimeConstants, auto_register_config
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.mixins import HydraConfigMixin, LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
@@ -86,6 +86,7 @@ class LayerWiseFeatureSaver:
         torch.save(features, self.save_path)
 
 
+@auto_register_config
 class CLIPVisionModelTaskPool(
     HydraConfigMixin,
     LightningFabricMixin,
@@ -134,11 +135,13 @@ class CLIPVisionModelTaskPool(
         layer_wise_feature_first_token_only: bool = True,
         layer_wise_feature_max_num: Optional[int] = None,
         fast_dev_run: Optional[bool] = None,
+        move_to_device: bool = True,
         **kwargs,
     ):
         """
         Initialize the CLIPVisionModelTaskPool.
         """
+        super().__init__(**kwargs)
         self._test_datasets = test_datasets
         self._processor = processor
         self._data_processor = data_processor
@@ -159,7 +162,6 @@ class CLIPVisionModelTaskPool(
             self.fast_dev_run = RuntimeConstants().debug
         else:
             self.fast_dev_run = fast_dev_run
-        super().__init__(**kwargs)
 
     def setup(self):
         """
@@ -220,7 +222,9 @@ class CLIPVisionModelTaskPool(
             for name, dataset in self.test_datasets.items()
         }
         self.test_dataloaders = {
-            name: self.fabric.setup_dataloaders(dataloader)
+            name: self.fabric.setup_dataloaders(
+                dataloader, move_to_device=self.move_to_device
+            )
             for name, dataloader in self.test_dataloaders.items()
         }
 
@@ -273,6 +277,8 @@ class CLIPVisionModelTaskPool(
                 task_name=task_name,
             )
             logits: Tensor = outputs["logits"]
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
 
             loss = F.cross_entropy(logits, targets)
             loss_metric.update(loss.detach().cpu())
@@ -321,7 +327,8 @@ class CLIPVisionModelTaskPool(
             self.clip_model,
             processor=self.processor,
         )
-        classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
+        if self.move_to_device:
+            classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
         # collect basic model information
         training_params, all_params = count_parameters(model)
         report["model_info"] = {
--- a/fusion_bench/tasks/clip_classification/__init__.py
+++ b/fusion_bench/tasks/clip_classification/__init__.py
@@ -183,3 +183,18 @@ class CLIPTemplateFactory:
 
 def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
     return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
+
+
+def get_num_classes(dataset_name: str) -> int:
+    classnames, _ = get_classnames_and_templates(dataset_name)
+    return len(classnames)
+
+
+def get_classnames(dataset_name: str) -> List[str]:
+    classnames, _ = get_classnames_and_templates(dataset_name)
+    return classnames
+
+
+def get_templates(dataset_name: str) -> List[Callable]:
+    _, templates = get_classnames_and_templates(dataset_name)
+    return templates
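These three wrappers just project out pieces of the existing `get_classnames_and_templates` tuple. A usage sketch; the dataset key `"sun397"` is only an example and must match whatever names `CLIPTemplateFactory` actually registers:

```python
from fusion_bench.tasks.clip_classification import (
    get_classnames,
    get_num_classes,
    get_templates,
)

classnames = get_classnames("sun397")    # List[str] of class labels
num_classes = get_num_classes("sun397")  # == len(classnames)
templates = get_templates("sun397")      # List[Callable] of prompt templates

# Build zero-shot prompts for the first class from every template.
prompts = [template(classnames[0]) for template in templates]
```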
--- a/fusion_bench/utils/__init__.py
+++ b/fusion_bench/utils/__init__.py
@@ -1,23 +1,169 @@
 # flake8: noqa: F401
-import importlib
-from typing import Iterable
+import sys
+from typing import TYPE_CHECKING
 
-from . import data, functools, path, pylogger
-from .cache_utils import *
-from .devices import *
-from .dtype import parse_dtype
-from .fabric import seed_everything_by_time
-from .instantiate_utils import (
-    instantiate,
-    is_instantiable,
-    set_print_function_call,
-    set_print_function_call_permeanent,
-)
-from .json import load_from_json, save_to_json
-from .lazy_state_dict import LazyStateDict
-from .misc import *
-from .packages import import_object
-from .parameters import *
-from .pylogger import get_rankzero_logger
-from .timer import timeit_context
-from .type import BoolStateDictType, StateDictType, TorchModelType
+from . import functools
+from .lazy_imports import LazyImporter
+
+_extra_objects = {
+    "functools": functools,
+}
+_import_structure = {
+    "cache_utils": [
+        "cache_to_disk",
+        "cache_with_joblib",
+        "set_default_cache_dir",
+    ],
+    "data": [
+        "InfiniteDataLoader",
+        "load_tensor_from_file",
+        "train_validation_split",
+        "train_validation_test_split",
+    ],
+    "devices": [
+        "clear_cuda_cache",
+        "get_current_device",
+        "get_device",
+        "get_device_capabilities",
+        "get_device_memory_info",
+        "num_devices",
+        "to_device",
+    ],
+    "dtype": ["get_dtype", "parse_dtype"],
+    "fabric": ["seed_everything_by_time"],
+    "instantiate_utils": [
+        "instantiate",
+        "is_instantiable",
+        "set_print_function_call",
+        "set_print_function_call_permeanent",
+    ],
+    "json": ["load_from_json", "save_to_json", "print_json"],
+    "lazy_state_dict": ["LazyStateDict"],
+    "misc": [
+        "first",
+        "has_length",
+        "join_lists",
+        "validate_and_suggest_corrections",
+    ],
+    "packages": ["compare_versions", "import_object"],
+    "parameters": [
+        "check_parameters_all_equal",
+        "count_parameters",
+        "get_parameter_statistics",
+        "get_parameter_summary",
+        "human_readable",
+        "print_parameters",
+        "state_dict_to_vector",
+        "trainable_state_dict",
+        "vector_to_state_dict",
+    ],
+    "path": [
+        "create_symlink",
+        "listdir_fullpath",
+        "path_is_dir_and_not_empty",
+    ],
+    "pylogger": [
+        "RankedLogger",
+        "RankZeroLogger",
+        "get_rankzero_logger",
+    ],
+    "state_dict_arithmetic": [
+        "ArithmeticStateDict",
+        "state_dicts_check_keys",
+        "num_params_of_state_dict",
+        "state_dict_to_device",
+        "state_dict_flatten",
+        "state_dict_avg",
+        "state_dict_sub",
+        "state_dict_add",
+        "state_dict_add_scalar",
+        "state_dict_mul",
+        "state_dict_div",
+        "state_dict_power",
+        "state_dict_interpolation",
+        "state_dict_sum",
+        "state_dict_weighted_sum",
+        "state_dict_diff_abs",
+        "state_dict_binary_mask",
+        "state_dict_hadamard_product",
+    ],
+    "timer": ["timeit_context"],
+    "type": [
+        "BoolStateDictType",
+        "StateDictType",
+        "TorchModelType",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .cache_utils import cache_to_disk, cache_with_joblib, set_default_cache_dir
+    from .data import (
+        InfiniteDataLoader,
+        load_tensor_from_file,
+        train_validation_split,
+        train_validation_test_split,
+    )
+    from .devices import (
+        clear_cuda_cache,
+        get_current_device,
+        get_device,
+        get_device_capabilities,
+        get_device_memory_info,
+        num_devices,
+        to_device,
+    )
+    from .dtype import get_dtype, parse_dtype
+    from .fabric import seed_everything_by_time
+    from .instantiate_utils import (
+        instantiate,
+        is_instantiable,
+        set_print_function_call,
+        set_print_function_call_permeanent,
+    )
+    from .json import load_from_json, print_json, save_to_json
+    from .lazy_state_dict import LazyStateDict
+    from .misc import first, has_length, join_lists, validate_and_suggest_corrections
+    from .packages import compare_versions, import_object
+    from .parameters import (
+        check_parameters_all_equal,
+        count_parameters,
+        get_parameter_statistics,
+        get_parameter_summary,
+        human_readable,
+        print_parameters,
+        state_dict_to_vector,
+        trainable_state_dict,
+        vector_to_state_dict,
+    )
+    from .path import create_symlink, listdir_fullpath, path_is_dir_and_not_empty
+    from .pylogger import RankedLogger, RankZeroLogger, get_rankzero_logger
+    from .state_dict_arithmetic import (
+        ArithmeticStateDict,
+        num_params_of_state_dict,
+        state_dict_add,
+        state_dict_add_scalar,
+        state_dict_avg,
+        state_dict_binary_mask,
+        state_dict_diff_abs,
+        state_dict_div,
+        state_dict_flatten,
+        state_dict_hadamard_product,
+        state_dict_interpolation,
+        state_dict_mul,
+        state_dict_power,
+        state_dict_sub,
+        state_dict_sum,
+        state_dict_to_device,
+        state_dict_weighted_sum,
+        state_dicts_check_keys,
+    )
+    from .timer import timeit_context
+    from .type import BoolStateDictType, StateDictType, TorchModelType
+
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
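The rewritten `fusion_bench/utils/__init__.py` replaces eager re-exports with a `LazyImporter` proxy: `_import_structure` declares where each name lives, the `TYPE_CHECKING` branch keeps static analyzers and IDEs happy, and the `else` branch swaps the proxy into `sys.modules`. The observable effect is sketched below; the first print may already say `True` if something else (e.g. `fusion_bench/__init__.py`) imported the submodule earlier:

```python
import sys

import fusion_bench.utils  # installs the LazyImporter proxy; submodules not imported yet

# Nothing from state_dict_arithmetic has been touched, so it is likely absent.
print("fusion_bench.utils.state_dict_arithmetic" in sys.modules)

from fusion_bench.utils import state_dict_avg  # first access triggers the real import

print("fusion_bench.utils.state_dict_arithmetic" in sys.modules)  # True
```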
--- a/fusion_bench/utils/devices.py
+++ b/fusion_bench/utils/devices.py
@@ -39,7 +39,12 @@ def clear_cuda_cache():
         log.warning("CUDA is not available. No cache to clear.")
 
 
-def to_device(obj: T, device: Optional[torch.device], **kwargs: Any) -> T:
+def to_device(
+    obj: T,
+    device: Optional[torch.device],
+    copy_on_move: bool = False,
+    **kwargs: Any,
+) -> T:
     """
     Move a given object to the specified device.
 
@@ -49,12 +54,20 @@ def to_device(obj: T, device: Optional[torch.device], **kwargs: Any) -> T:
     Args:
         obj: The object to be moved to the device. This can be a torch.Tensor, torch.nn.Module, list, tuple, or dict.
         device (torch.device): The target device to move the object to. This can be `None`.
-        **kwargs: Additional keyword arguments to be passed to the `to` method of torch.Tensor or torch.nn.Module. For example, `non_blocking=True`, `dtype=torch.float16`.
+        copy_on_move (bool, optional): Whether to force a copy operation when moving tensors to a different device.
+            If True, tensors will be copied when moved to a different device (copy=True is passed to tensor.to()).
+            If False (default), tensors are moved without forcing a copy operation, allowing PyTorch to optimize
+            the operation. This parameter only affects torch.Tensor objects; modules and other types are unaffected.
+            Defaults to False.
+        **kwargs: Additional keyword arguments to be passed to the `to` method of torch.Tensor or torch.nn.Module.
+            For example, `non_blocking=True`, `dtype=torch.float16`. Note that if `copy_on_move=True`, the `copy`
+            keyword argument will be automatically set and should not be provided manually.
 
     Returns:
         The object moved to the specified device. The type of the returned object matches the type of the input object.
 
     Examples:
+        ```python
         >>> tensor = torch.tensor([1, 2, 3])
         >>> to_device(tensor, torch.device('cuda'))
         tensor([1, 2, 3], device='cuda:0')
@@ -66,17 +79,26 @@ def to_device(obj: T, device: Optional[torch.device], **kwargs: Any) -> T:
         >>> data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
         >>> to_device(data, torch.device('cuda'))
         [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')]
+
+        >>> # Force copy when moving to different device
+        >>> tensor = torch.tensor([1, 2, 3], device='cpu')
+        >>> copied_tensor = to_device(tensor, torch.device('cuda'), copy_on_move=True)
+        >>> # tensor and copied_tensor will have different memory locations
+        ```
     """
-    if isinstance(obj, (torch.Tensor, torch.nn.Module)):
+    if isinstance(obj, torch.Tensor):
+        if copy_on_move:
+            if obj.device != torch.device(device):
+                kwargs["copy"] = True
+        return obj.to(device, **kwargs)
+    elif isinstance(obj, torch.nn.Module):
         return obj.to(device, **kwargs)
     elif isinstance(obj, list):
-        return [to_device(o, device) for o in obj]
+        return [to_device(o, device, **kwargs) for o in obj]
     elif isinstance(obj, tuple):
-        return tuple(to_device(o, device) for o in obj)
+        return tuple(to_device(o, device, **kwargs) for o in obj)
     elif isinstance(obj, dict):
-        for key in obj:
-            obj[key] = to_device(obj[key], device)
-        return obj
+        return {key: to_device(value, device, **kwargs) for key, value in obj.items()}
     else:
         # the default behavior is to return the object as is
         return obj
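Two behavioral changes in the rewritten `to_device` are worth noting: dicts are now rebuilt rather than mutated in place, and `**kwargs` propagate into nested containers. A small CPU-only sketch of both, plus the `copy_on_move` semantics:

```python
import torch

from fusion_bench.utils.devices import to_device

batch = {"pixels": torch.rand(2, 3), "labels": [torch.tensor(0), torch.tensor(1)]}
moved = to_device(batch, torch.device("cpu"))
assert moved is not batch  # a new dict is returned; the input dict is left untouched

# copy_on_move only forces copy=True when the device actually changes, so a
# same-device move still returns the original tensor storage.
t = torch.tensor([1, 2, 3])
same = to_device(t, torch.device("cpu"), copy_on_move=True)
assert same.data_ptr() == t.data_ptr()
```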
--- a/fusion_bench/utils/lazy_imports.py
+++ b/fusion_bench/utils/lazy_imports.py
@@ -24,36 +24,78 @@ to publish it as a standalone package.
 import importlib
 import os
 from types import ModuleType
-from typing import Any
+from typing import Any, Dict, List, Optional, Set, Union
 
 
 class LazyImporter(ModuleType):
-    """Do lazy imports."""
+    """Lazy importer for modules and their components.
+
+    This class allows for lazy importing of modules, meaning modules are only
+    imported when they are actually accessed. This can help reduce startup
+    time and memory usage for large packages with many optional dependencies.
+
+    Attributes:
+        _modules: Set of module names available for import.
+        _class_to_module: Mapping from class/function names to their module names.
+        _objects: Dictionary of extra objects to include in the module.
+        _name: Name of the module.
+        _import_structure: Dictionary mapping module names to lists of their exports.
+    """
 
     # Very heavily inspired by optuna.integration._IntegrationModule
     # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
-    def __init__(self, name, module_file, import_structure, extra_objects=None):
+    def __init__(
+        self,
+        name: str,
+        module_file: str,
+        import_structure: Dict[str, List[str]],
+        extra_objects: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initialize the LazyImporter.
+
+        Args:
+            name: The name of the module.
+            module_file: Path to the module file.
+            import_structure: Dictionary mapping module names to lists of their exports.
+            extra_objects: Optional dictionary of extra objects to include.
+        """
         super().__init__(name)
-        self._modules = set(import_structure.keys())
-        self._class_to_module = {}
+        self._modules: Set[str] = set(import_structure.keys())
+        self._class_to_module: Dict[str, str] = {}
         for key, values in import_structure.items():
             for value in values:
                 self._class_to_module[value] = key
         # Needed for autocompletion in an IDE
-        self.__all__ = list(import_structure.keys()) + sum(
+        self.__all__: List[str] = list(import_structure.keys()) + sum(
             import_structure.values(), []
         )
         self.__file__ = module_file
         self.__path__ = [os.path.dirname(module_file)]
-        self._objects = {} if extra_objects is None else extra_objects
+        self._objects: Dict[str, Any] = {} if extra_objects is None else extra_objects
         self._name = name
         self._import_structure = import_structure
 
     # Needed for autocompletion in an IDE
-    def __dir__(self):
+    def __dir__(self) -> List[str]:
+        """Return list of available attributes for autocompletion.
+
+        Returns:
+            List of all available attribute names.
+        """
         return super().__dir__() + self.__all__
 
     def __getattr__(self, name: str) -> Any:
+        """Get attribute lazily, importing the module if necessary.
+
+        Args:
+            name: The name of the attribute to retrieve.
+
+        Returns:
+            The requested attribute.
+
+        Raises:
+            AttributeError: If the attribute is not found in any module.
+        """
         if name in self._objects:
             return self._objects[name]
         if name in self._modules:
@@ -67,31 +109,68 @@ class LazyImporter(ModuleType):
         setattr(self, name, value)
         return value
 
-    def _get_module(self, module_name: str):
+    def _get_module(self, module_name: str) -> ModuleType:
+        """Import and return the specified module.
+
+        Args:
+            module_name: Name of the module to import.
+
+        Returns:
+            The imported module.
+        """
         return importlib.import_module("." + module_name, self.__name__)
 
-    def __reduce__(self):
+    def __reduce__(self) -> tuple:
+        """Support for pickling the LazyImporter.
+
+        Returns:
+            Tuple containing the class and arguments needed to reconstruct the object.
+        """
         return (self.__class__, (self._name, self.__file__, self._import_structure))
 
 
-class LazyModule(ModuleType):
+class LazyPyModule(ModuleType):
     """Module wrapper for lazy import.
+
     Adapted from Optuna: https://github.com/optuna/optuna/blob/1f92d496b0c4656645384e31539e4ee74992ff55/optuna/__init__.py
 
     This class wraps specified module and lazily import it when they are actually accessed.
+    This can help reduce startup time and memory usage by deferring module imports
+    until they are needed.
 
     Args:
        name: Name of module to apply lazy import.
+
+    Attributes:
+        _name: The name of the module to be lazily imported.
     """
 
     def __init__(self, name: str) -> None:
+        """Initialize the LazyPyModule.
+
+        Args:
+            name: The name of the module to be lazily imported.
+        """
         super().__init__(name)
-        self._name = name
+        self._name: str = name
 
     def _load(self) -> ModuleType:
+        """Load the actual module and update this object's dictionary.
+
+        Returns:
+            The loaded module.
+        """
         module = importlib.import_module(self._name)
         self.__dict__.update(module.__dict__)
         return module
 
     def __getattr__(self, item: str) -> Any:
+        """Get attribute from the lazily loaded module.
+
+        Args:
+            item: The name of the attribute to retrieve.
+
+        Returns:
+            The requested attribute from the loaded module.
+        """
         return getattr(self._load(), item)
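Note that `LazyModule` is renamed to `LazyPyModule` in this release, a rename external imports of the old name would need to follow unless an alias is kept elsewhere. A usage sketch; `numpy` is just a convenient heavyweight example:

```python
import sys

from fusion_bench.utils.lazy_imports import LazyPyModule

np = LazyPyModule("numpy")     # no import happens here
print("numpy" in sys.modules)  # may already be True if another import pulled it in

print(np.arange(3))  # first attribute access calls _load() and imports numpy
```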