fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +18 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/ensemble.py +17 -2
  8. fusion_bench/method/linear/__init__.py +6 -2
  9. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  10. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  11. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/simple_average.py +2 -2
  15. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  16. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  17. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  18. fusion_bench/method/wudi/__init__.py +1 -0
  19. fusion_bench/method/wudi/wudi.py +105 -0
  20. fusion_bench/mixins/__init__.py +2 -0
  21. fusion_bench/mixins/lightning_fabric.py +4 -0
  22. fusion_bench/mixins/pyinstrument.py +174 -0
  23. fusion_bench/mixins/serialization.py +25 -78
  24. fusion_bench/mixins/simple_profiler.py +106 -23
  25. fusion_bench/modelpool/__init__.py +2 -0
  26. fusion_bench/modelpool/base_pool.py +77 -14
  27. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  28. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  29. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  30. fusion_bench/models/__init__.py +35 -9
  31. fusion_bench/models/hf_clip.py +4 -0
  32. fusion_bench/models/hf_utils.py +2 -1
  33. fusion_bench/models/model_card_templates/default.md +8 -1
  34. fusion_bench/models/wrappers/ensemble.py +136 -7
  35. fusion_bench/optim/__init__.py +40 -2
  36. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  37. fusion_bench/optim/muon.py +339 -0
  38. fusion_bench/programs/__init__.py +2 -0
  39. fusion_bench/programs/fabric_fusion_program.py +2 -2
  40. fusion_bench/programs/fusion_program.py +271 -0
  41. fusion_bench/scripts/cli.py +2 -2
  42. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  43. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  44. fusion_bench/utils/__init__.py +167 -21
  45. fusion_bench/utils/devices.py +30 -8
  46. fusion_bench/utils/lazy_imports.py +91 -12
  47. fusion_bench/utils/lazy_state_dict.py +58 -5
  48. fusion_bench/utils/misc.py +104 -13
  49. fusion_bench/utils/packages.py +4 -0
  50. fusion_bench/utils/path.py +7 -0
  51. fusion_bench/utils/pylogger.py +6 -0
  52. fusion_bench/utils/rich_utils.py +8 -3
  53. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  54. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
  55. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
  56. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  57. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  58. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  59. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  60. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  61. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  62. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  63. fusion_bench_config/model_fusion.yaml +45 -0
  64. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  65. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  66. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  73. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  74. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  75. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
@@ -1,22 +1,27 @@
1
1
  import logging
2
+ import os
2
3
  from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
3
4
 
4
5
  from typing_extensions import override
5
6
 
6
- from fusion_bench import timeit_context
7
+ from fusion_bench import auto_register_config, timeit_context
7
8
  from fusion_bench.method import TaskArithmeticAlgorithm
8
9
  from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
9
10
  from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
11
+ from fusion_bench.models.hf_utils import create_default_model_card
10
12
 
11
13
  log = logging.getLogger(__name__)
12
14
 
13
15
 
14
- class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
16
+ @auto_register_config
17
+ class TaskArithmeticForCausalLM(
18
+ TaskArithmeticAlgorithm,
19
+ ):
15
20
  R"""
16
21
  Examples:
17
22
 
18
23
  fusion_bench \
19
- method=linear/task_arithmetic_for_llama \
24
+ method=linear/task_arithmetic_for_causallm \
20
25
  method.scaling_factor=0.3 \
21
26
  method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
22
27
  modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
@@ -29,18 +34,14 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
29
34
  def __init__(
30
35
  self,
31
36
  scaling_factor: float,
32
- merge_backbone: bool,
37
+ merge_backbone: bool = False,
33
38
  model_save_path: Optional[str] = None,
39
+ **kwargs,
34
40
  ):
35
- self.merge_backbone = merge_backbone
36
- self.model_save_path = model_save_path
37
- super().__init__(scaling_factor=scaling_factor)
41
+ super().__init__(scaling_factor=scaling_factor, **kwargs)
38
42
 
39
43
  @override
40
44
  def run(self, modelpool: CausalLMPool):
41
- if self.model_save_path:
42
- tokenizer = modelpool.load_tokenizer()
43
-
44
45
  if self.merge_backbone:
45
46
  assert modelpool.has_pretrained
46
47
  backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
@@ -52,6 +53,15 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
52
53
 
53
54
  if self.model_save_path is not None:
54
55
  with timeit_context(f"Saving the model to {self.model_save_path}"):
55
- tokenizer.save_pretrained(self.model_save_path)
56
- model.save_pretrained(self.model_save_path)
56
+ description = f"Merged model using task arithmetic with scaling factor {self.scaling_factor}."
57
+ modelpool.save_model(
58
+ model=model,
59
+ path=self.model_save_path,
60
+ save_tokenizer=True,
61
+ algorithm_config=self.config,
62
+ description=description,
63
+ )
57
64
  return model
65
+
66
+
67
+ TaskArithmeticForLlama = TaskArithmeticForCausalLM
@@ -0,0 +1,70 @@
1
+ import logging
2
+ import os
3
+ from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
4
+
5
+ from typing_extensions import override
6
+
7
+ from fusion_bench import auto_register_config, timeit_context
8
+ from fusion_bench.method import TiesMergingAlgorithm
9
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
10
+ from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
11
+ from fusion_bench.models.hf_utils import create_default_model_card
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
@auto_register_config
class TiesMergingForCausalLM(
    TiesMergingAlgorithm,
):
    R"""
    TIES merging algorithm for causal language models.

    Extends :class:`TiesMergingAlgorithm` with two CausalLM-specific features:
    optionally merging only the transformer backbone (``model.model.layers``)
    while keeping the rest of the pretrained model unchanged, and saving the
    merged model (with tokenizer and an algorithm-config model card) to disk.
    """

    # NOTE(review): `@auto_register_config` is expected to register the
    # __init__ parameters; this explicit mapping mirrors the parent's and adds
    # the subclass-specific option — confirm whether it is still required.
    _config_mapping = TiesMergingAlgorithm._config_mapping | {
        "merge_backbone": "merge_backbone",
    }

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: Optional[List[str]] = None,
        merge_func: str = "sum",
        merge_backbone: bool = False,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            scaling_factor: Scaling factor applied to the merged task vector.
            threshold: Trimming threshold used by TIES merging.
            remove_keys: State-dict keys to exclude from merging, if any.
            merge_func: Aggregation function used by TIES (default ``"sum"``).
            merge_backbone: If True, merge only the decoder backbone and graft
                it onto the pretrained model.
            model_save_path: If given, save the merged model to this path.
        """
        super().__init__(
            scaling_factor=scaling_factor,
            threshold=threshold,
            remove_keys=remove_keys,
            merge_func=merge_func,
            **kwargs,
        )

    @override
    def run(self, modelpool: CausalLMPool):
        """Run TIES merging over ``modelpool`` and return the merged model.

        When ``merge_backbone`` is set, only ``model.model.layers`` is merged;
        embeddings and the LM head come from the pretrained model.
        """
        if self.merge_backbone:
            assert modelpool.has_pretrained
            backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
            model = modelpool.load_model("_pretrained_")
            backbone_model = super().run(backbone_modelpool)
            model.model.layers = backbone_model
        else:
            model = super().run(modelpool)

        if self.model_save_path is not None:
            with timeit_context(f"Saving the model to {self.model_save_path}"):
                description = f"Merged model using TIES merging with scaling factor {self.scaling_factor} and threshold {self.threshold}."
                modelpool.save_model(
                    model=model,
                    path=self.model_save_path,
                    save_tokenizer=True,
                    algorithm_config=self.config,
                    description=description,
                )
        return model
@@ -87,6 +87,7 @@ class OPCMForCLIP(
87
87
  # get the average model
88
88
  with self.profile("loading model"):
89
89
  merged_model = modelpool.load_model(model_names[0])
90
+ assert merged_model is not None, "Failed to load the first model"
90
91
 
91
92
  if self.evaluate_on_every_step:
92
93
  with self.profile("evaluating model"):
@@ -13,8 +13,6 @@ import torch.func
13
13
  from torch import Tensor, nn
14
14
  from torch.nn import functional as F
15
15
 
16
- from fusion_bench.utils import join_list
17
-
18
16
  log = logging.getLogger(__name__)
19
17
 
20
18
 
@@ -89,7 +89,7 @@ class SimpleAverageAlgorithm(
89
89
  modelpool = BaseModelPool(modelpool)
90
90
 
91
91
  log.info(
92
- f"Fusing models using simple average on {len(modelpool.model_names)} models."
92
+ f"Fusing models using simple average on {len(modelpool.model_names)} models. "
93
93
  f"models: {modelpool.model_names}"
94
94
  )
95
95
  sd: Optional[StateDictType] = None
@@ -119,7 +119,7 @@ class SimpleAverageAlgorithm(
119
119
 
120
120
  if isinstance(forward_model, LazyStateDict):
121
121
  # if the model is a LazyStateDict, convert it to an empty module
122
- forward_model = forward_model.meta_module.to_empty(
122
+ forward_model = deepcopy(forward_model.meta_module).to_empty(
123
123
  device=forward_model._device
124
124
  )
125
125
  result = forward_model.load_state_dict(sd, strict=False)
@@ -15,7 +15,7 @@ from fusion_bench.utils.state_dict_arithmetic import (
15
15
  state_dict_add,
16
16
  state_dict_binary_mask,
17
17
  state_dict_diff_abs,
18
- state_dict_hadmard_product,
18
+ state_dict_hadamard_product,
19
19
  state_dict_mul,
20
20
  state_dict_sub,
21
21
  state_dict_sum,
@@ -111,7 +111,7 @@ class TallMaskTaskArithmeticAlgorithm(
111
111
 
112
112
  with self.profile("compress and retrieve"):
113
113
  for model_name in modelpool.model_names:
114
- retrieved_task_vector = state_dict_hadmard_product(
114
+ retrieved_task_vector = state_dict_hadamard_product(
115
115
  tall_masks[model_name], multi_task_vector
116
116
  )
117
117
  retrieved_state_dict = state_dict_add(
@@ -6,11 +6,20 @@ http://arxiv.org/abs/2212.04089
6
6
 
7
7
  import logging
8
8
  from copy import deepcopy
9
- from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
9
+ from typing import ( # noqa: F401
10
+ TYPE_CHECKING,
11
+ Dict,
12
+ List,
13
+ Mapping,
14
+ Optional,
15
+ TypeVar,
16
+ Union,
17
+ )
10
18
 
11
19
  import torch
12
20
  from torch import nn
13
21
 
22
+ from fusion_bench import LazyStateDict
14
23
  from fusion_bench.method.base_algorithm import BaseAlgorithm
15
24
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
16
25
  from fusion_bench.modelpool import BaseModelPool
@@ -21,6 +30,8 @@ from fusion_bench.utils.state_dict_arithmetic import (
21
30
  )
22
31
  from fusion_bench.utils.type import StateDictType, TorchModelType
23
32
 
33
+ if TYPE_CHECKING:
34
+ from transformers import PreTrainedModel
24
35
  log = logging.getLogger(__name__)
25
36
 
26
37
 
@@ -125,25 +136,39 @@ class TaskArithmeticAlgorithm(
125
136
  with self.profile("merge weights"):
126
137
  if task_vector is None:
127
138
  task_vector = state_dict_sub(
128
- model.state_dict(keep_vars=True),
129
- pretrained_model.state_dict(keep_vars=True),
139
+ model.state_dict(),
140
+ pretrained_model.state_dict(),
130
141
  )
131
142
  else:
132
143
  task_vector = state_dict_add(
133
144
  task_vector,
134
145
  state_dict_sub(
135
- model.state_dict(keep_vars=True),
136
- pretrained_model.state_dict(keep_vars=True),
146
+ model.state_dict(),
147
+ pretrained_model.state_dict(),
137
148
  ),
138
149
  )
139
150
  with self.profile("merge weights"):
140
151
  # scale the task vector
141
152
  task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
142
153
  # add the task vector to the pretrained model
143
- state_dict = state_dict_add(
144
- pretrained_model.state_dict(keep_vars=True), task_vector
145
- )
154
+ state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
146
155
 
147
156
  self.print_profile_summary()
148
- pretrained_model.load_state_dict(state_dict)
149
- return pretrained_model
157
+
158
+ # apply state dict to model
159
+ if isinstance(pretrained_model, nn.Module):
160
+ model = pretrained_model
161
+ model.load_state_dict(state_dict)
162
+ elif isinstance(pretrained_model, LazyStateDict):
163
+ model = deepcopy(pretrained_model.meta_module)
164
+ model = model.to_empty(device=pretrained_model._device)
165
+ result = model.load_state_dict(state_dict, strict=False)
166
+ if result.unexpected_keys:
167
+ raise ValueError(
168
+ f"Unexpected keys in state dict: {result.unexpected_keys}"
169
+ )
170
+ if result.missing_keys:
171
+ log.warning(f"Missing keys in state dict: {result.missing_keys}")
172
+ else:
173
+ raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
174
+ return model
@@ -9,11 +9,14 @@ Overview of Ties-Merging:
9
9
  """
10
10
 
11
11
  import logging
12
+ from copy import deepcopy
12
13
  from typing import Any, Dict, List, Literal, Mapping, Union # noqa: F401
13
14
 
14
15
  import torch
15
16
  from torch import Tensor, nn
17
+ from transformers import PreTrainedModel
16
18
 
19
+ from fusion_bench import LazyStateDict
17
20
  from fusion_bench.compat.modelpool import to_modelpool
18
21
  from fusion_bench.method import BaseAlgorithm
19
22
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
@@ -98,12 +101,25 @@ class TiesMergingAlgorithm(
98
101
  merge_func=merge_func,
99
102
  )
100
103
  merged_check = flat_ptm + scaling_factor * merged_tv
101
- merged_state_dict = vector_to_state_dict(
104
+ state_dict = vector_to_state_dict(
102
105
  merged_check, ptm_check, remove_keys=remove_keys
103
106
  )
104
-
105
- # Load the merged state dict into the pretrained model
106
- pretrained_model.load_state_dict(merged_state_dict)
107
-
108
107
  self.print_profile_summary()
109
- return pretrained_model
108
+
109
+ # apply state dict to model
110
+ if isinstance(pretrained_model, nn.Module):
111
+ model = pretrained_model
112
+ model.load_state_dict(state_dict)
113
+ elif isinstance(pretrained_model, LazyStateDict):
114
+ model = deepcopy(pretrained_model.meta_module)
115
+ model = model.to_empty(device=pretrained_model._device)
116
+ result = model.load_state_dict(state_dict, strict=False)
117
+ if result.unexpected_keys:
118
+ raise ValueError(
119
+ f"Unexpected keys in state dict: {result.unexpected_keys}"
120
+ )
121
+ if result.missing_keys:
122
+ log.warning(f"Missing keys in state dict: {result.missing_keys}")
123
+ else:
124
+ raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
125
+ return model
@@ -0,0 +1 @@
1
+ from .wudi import WUDIMerging, wudi_merging
@@ -0,0 +1,105 @@
1
+ """
2
+ Whoever Started the Interference Should End It: Guiding Data-Free Model Merging via Task Vectors
3
+ Arxiv: http://arxiv.org/abs/2503.08099
4
+ """
5
+
6
+ from typing import List
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
12
+ from fusion_bench.mixins import LightningFabricMixin
13
+ from fusion_bench.utils import timeit_context
14
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
15
+
16
+
17
def wudi_merging(
    task_vectors: List[dict],
    accelerator="cuda",
    iter_num: int = 300,
    exclude_keys: List[str] = None,
):
    """Merge task vectors with WUDI (data-free, interference-guided) merging.

    Each entry of the merged vector starts as the sum of the per-task entries.
    For 2-D entries (weight matrices) not listed in ``exclude_keys``, the entry
    is then refined by minimizing the interference loss with Adam; all other
    entries are simply averaged.

    Args:
        task_vectors: One state dict (parameter name -> tensor) per task.
        accelerator: Device on which the per-key optimization runs.
        iter_num: Number of Adam steps per refined entry.
        exclude_keys: Keys that should be averaged instead of refined.

    Returns:
        dict: Merged task vector; tensors are detached and placed back on
        each entry's original device.
    """
    exclude_keys = [] if exclude_keys is None else exclude_keys

    with timeit_context("WUDI Merging"):
        new_vector = {}
        for key in tqdm(task_vectors[0], desc="WUDI Merging", leave=False):
            tqdm.write(f"key: {key}")
            original_device = task_vectors[0][key].device
            tvs = torch.stack(
                [
                    task_vector[key].to(device=accelerator, non_blocking=True)
                    for task_vector in task_vectors
                ]
            )
            num_tvs = len(tvs)
            new_vector[key] = torch.nn.Parameter(torch.sum(tvs, dim=0))

            if len(task_vectors[0][key].shape) == 2 and key not in exclude_keys:
                # The refinement is gradient-based, so it must run with
                # autograd enabled even if the caller wrapped this call in
                # `torch.no_grad()` (as the loading code typically does).
                with torch.enable_grad():
                    optimizer = torch.optim.Adam(
                        [new_vector[key]], lr=1e-5, weight_decay=0
                    )
                    # Squared L2 norm of each task vector, used to normalize
                    # the per-task interference terms.
                    l2_norms = torch.square(
                        torch.norm(tvs.reshape(tvs.shape[0], -1), p=2, dim=-1)
                    )
                    for _ in tqdm(range(iter_num)):
                        disturbing_vectors = new_vector[key].unsqueeze(0) - tvs
                        product = torch.matmul(
                            disturbing_vectors, tvs.transpose(1, 2)
                        )
                        loss = torch.sum(
                            torch.square(product)
                            / l2_norms.unsqueeze(-1).unsqueeze(-1)
                        )
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
            else:
                # Non-matrix (or excluded) entries: plain average.
                new_vector[key] = new_vector[key] / num_tvs
            # Detach from the optimization graph and restore the original
            # device so the result is an ordinary state-dict tensor.
            new_vector[key] = new_vector[key].detach().to(
                device=original_device, non_blocking=True
            )
    return new_vector
61
+
62
+
63
@auto_register_config
class WUDIMerging(
    LightningFabricMixin,
    BaseAlgorithm,
):
    """
    Whoever Started the Interference Should End It: Guiding Data-Free Model
    Merging via Task Vectors.  http://arxiv.org/abs/2503.08099
    """

    def __init__(
        self,
        iter_num: int,
        exclude_keys: List[str] = None,
        **kwargs,
    ):
        """
        Args:
            iter_num: Number of Adam steps used to refine each 2-D entry of
                the merged task vector.
            exclude_keys: State-dict keys excluded from the gradient-based
                refinement (they are averaged instead); ``None`` means none.
        """
        # NOTE(review): `@auto_register_config` presumably registers these
        # parameters as attributes; the explicit assignments make `run()`'s
        # reads of `self.iter_num` / `self.exclude_keys` robust either way.
        self.iter_num = iter_num
        self.exclude_keys = exclude_keys
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """Merge all models in ``modelpool`` with WUDI and return the result."""
        # Load the pretrained model and compute the task vectors
        # (finetuned - pretrained) without tracking gradients.
        with torch.no_grad():
            pretrained_model = modelpool.load_pretrained_model()
            task_vectors = []
            for model_name in modelpool.model_names:
                finetuned_model = modelpool.load_model(model_name)
                task_vectors.append(
                    state_dict_sub(
                        finetuned_model.state_dict(), pretrained_model.state_dict()
                    )
                )
                del finetuned_model  # free memory

        # `wudi_merging` optimizes the merged vector with Adam and calls
        # `loss.backward()`, so it must run OUTSIDE the no_grad block above.
        merged_tv = wudi_merging(
            task_vectors,
            accelerator=self.fabric.device,
            iter_num=self.iter_num,
            exclude_keys=self.exclude_keys,
        )

        with torch.no_grad():
            pretrained_model.load_state_dict(
                state_dict_add(pretrained_model.state_dict(), merged_tv)
            )

        return pretrained_model
@@ -11,6 +11,7 @@ _import_structure = {
11
11
  "hydra_config": ["HydraConfigMixin"],
12
12
  "lightning_fabric": ["LightningFabricMixin"],
13
13
  "openclip_classification": ["OpenCLIPClassificationMixin"],
14
+ "pyinstrument": ["PyinstrumentProfilerMixin"],
14
15
  "serialization": [
15
16
  "BaseYAMLSerializable",
16
17
  "YAMLSerializationMixin",
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
25
26
  from .hydra_config import HydraConfigMixin
26
27
  from .lightning_fabric import LightningFabricMixin
27
28
  from .openclip_classification import OpenCLIPClassificationMixin
29
+ from .pyinstrument import PyinstrumentProfilerMixin
28
30
  from .serialization import (
29
31
  BaseYAMLSerializable,
30
32
  YAMLSerializationMixin,
@@ -100,6 +100,10 @@ class LightningFabricMixin:
100
100
  self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
101
101
  return self._fabric_instance
102
102
 
103
+ @fabric.setter
104
+ def fabric(self, instance: L.Fabric):
105
+ self._fabric_instance = instance
106
+
103
107
  @property
104
108
  def log_dir(self):
105
109
  """
@@ -0,0 +1,174 @@
1
+ from contextlib import contextmanager
2
+ from pathlib import Path
3
+ from typing import Generator, Optional, Union
4
+
5
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
6
+ from pyinstrument import Profiler
7
+
8
+ __all__ = ["PyinstrumentProfilerMixin"]
9
+
10
+
11
class PyinstrumentProfilerMixin:
    """
    A mixin class that provides statistical profiling capabilities using pyinstrument.

    This mixin allows for easy profiling of code blocks using a context manager.
    It provides methods to start and stop profiling actions, save profiling results
    to files, and print profiling summaries.

    Note:
        This mixin requires the `pyinstrument` package to be installed.
        If not available, an ImportError will be raised when importing this module.

    Examples:

    ```python
    class MyClass(PyinstrumentProfilerMixin):
        def do_something(self):
            with self.profile("work"):
                # do some work here
                ...

            # save the profiling results
            self.save_profile_report("profile_report.html")

            # or print the summary
            self.print_profile_summary()
    ```

    Attributes:
        _profiler (Profiler): Lazily-created pyinstrument Profiler instance.
        _is_profiling (bool): True while the profiler is currently running.
    """

    _profiler: Optional[Profiler] = None
    _is_profiling: bool = False

    @property
    def profiler(self) -> Profiler:
        """Get the profiler instance, creating it if necessary.

        Never returns None: a Profiler is created on first access.
        """
        if self._profiler is None:
            self._profiler = Profiler()
        return self._profiler

    @contextmanager
    def profile(self, action_name: Optional[str] = None) -> Generator:
        """
        Context manager for profiling a code block.

        Args:
            action_name: Optional name for the profiling action (for logging purposes).

        Example:

        ```python
        with self.profile("expensive_operation"):
            # do some expensive work here
            expensive_function()
        ```
        """
        try:
            self.start_profile(action_name)
            yield action_name
        finally:
            self.stop_profile(action_name)

    def start_profile(self, action_name: Optional[str] = None):
        """
        Start profiling. No-op if profiling is already running.

        Args:
            action_name: Optional name for the profiling action.
        """
        if self._is_profiling:
            return

        self.profiler.start()
        self._is_profiling = True
        if action_name:
            print(f"Started profiling: {action_name}")

    def stop_profile(self, action_name: Optional[str] = None):
        """
        Stop profiling. No-op if profiling is not running.

        Args:
            action_name: Optional name for the profiling action.
        """
        if not self._is_profiling:
            return

        self.profiler.stop()
        self._is_profiling = False
        if action_name:
            print(f"Stopped profiling: {action_name}")

    @rank_zero_only
    def print_profile_summary(
        self, title: Optional[str] = None, unicode: bool = True, color: bool = True
    ):
        """
        Print a summary of the profiling results.

        Args:
            title: Optional title to print before the summary.
            unicode: Whether to use unicode characters in the output.
            color: Whether to use color in the output.
        """
        # Check the backing field: the `profiler` property lazily creates an
        # (empty) Profiler, so guarding on it could never report "no data".
        if self._profiler is None:
            print("No profiling data available.")
            return

        if title is not None:
            print(title)

        print(self.profiler.output_text(unicode=unicode, color=color))

    @rank_zero_only
    def save_profile_report(
        self,
        output_path: Union[str, Path] = "profile_report.html",
        format: str = "html",
        title: Optional[str] = None,
    ):
        """
        Save the profiling results to a file.

        Args:
            output_path: Path where to save the profiling report.
            format: Output format ('html', or 'text').
            title: Optional title for the report.

        Raises:
            ValueError: If `format` is not 'html' or 'text'.
        """
        # Guard on the backing field (see print_profile_summary): without
        # profiling data we should not emit an empty report.
        if self._profiler is None:
            print("No profiling data available.")
            return

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format.lower() == "html":
            content = self.profiler.output_html()
        elif format.lower() == "text":
            content = self.profiler.output_text(unicode=True, color=False)
        else:
            raise ValueError(f"Unsupported format: {format}. Use 'html', or 'text'.")

        with open(output_path, "w", encoding="utf-8") as f:
            f.write(content)

        print(f"Profile report saved to: {output_path}")

    def reset_profile(self):
        """Reset the profiler to start fresh."""
        if self._is_profiling:
            self.stop_profile()

        self._profiler = None

    def __del__(self):
        """Cleanup when the object is destroyed."""
        if self._is_profiling:
            self.stop_profile()

        if self._profiler is not None:
            del self._profiler
            self._profiler = None