cache-dit 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit has been flagged as possibly problematic.

@@ -1,5 +1,6 @@
 import torch
 
+from typing import Dict, Any
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -31,7 +32,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
         # encoder_hidden_states: None Pattern 3, else 4, 5
-        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+        hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
             hidden_states,
             *args,
             **kwargs,
@@ -60,11 +61,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         if can_use_cache:
             self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
-            hidden_states, encoder_hidden_states = (
+            hidden_states, new_encoder_hidden_states = (
                 self.cache_manager.apply_cache(
                     hidden_states,
-                    # None Pattern 3, else 4, 5
-                    encoder_hidden_states,
+                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
                         if self.cache_manager.is_cache_residual()
@@ -80,12 +80,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
         else:
             self.cache_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
@@ -99,19 +99,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
+            old_encoder_hidden_states = new_encoder_hidden_states
             (
                 hidden_states,
-                encoder_hidden_states,
+                new_encoder_hidden_states,
                 hidden_states_residual,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states_residual,
             ) = self.call_Mn_blocks(  # middle
                 hidden_states,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states,
                 *args,
                 **kwargs,
             )
+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
             torch._dynamo.graph_break()
             if self.cache_manager.is_cache_residual():
                 self.cache_manager.set_Bn_buffer(
@@ -119,34 +120,32 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
+
             if self.cache_manager.is_encoder_cache_residual():
-                self.cache_manager.set_Bn_encoder_buffer(
-                    # None Pattern 3, else 4, 5
-                    encoder_hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_residual",
-                )
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_residual",
+                    )
             else:
-                # TaylorSeer
-                self.cache_manager.set_Bn_encoder_buffer(
-                    # None Pattern 3, else 4, 5
-                    encoder_hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                )
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                    )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                hidden_states,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
 
         torch._dynamo.graph_break()
 
@@ -154,12 +153,21 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             hidden_states
             if self.forward_pattern.Return_H_Only
             else (
-                (hidden_states, encoder_hidden_states)
+                (hidden_states, new_encoder_hidden_states)
                 if self.forward_pattern.Return_H_First
-                else (encoder_hidden_states, hidden_states)
+                else (new_encoder_hidden_states, hidden_states)
             )
         )
 
+    @torch.compiler.disable
+    def maybe_update_kwargs(
+        self, encoder_hidden_states, kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        # if "encoder_hidden_states" in kwargs:
+        #     kwargs["encoder_hidden_states"] = encoder_hidden_states
+        # return kwargs
+        return kwargs
+
     def call_Fn_blocks(
         self,
         hidden_states: torch.Tensor,
@@ -172,7 +180,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        encoder_hidden_states = None  # Pattern 3
+        new_encoder_hidden_states = None
         for block in self._Fn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -180,25 +188,27 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states, encoder_hidden_states = hidden_states
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
-        return hidden_states, encoder_hidden_states
+        return hidden_states, new_encoder_hidden_states
 
     def call_Mn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
         original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
+        new_encoder_hidden_states = None
         for block in self._Mn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -206,44 +216,33 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states, encoder_hidden_states = hidden_states
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
         hidden_states_residual = hidden_states - original_hidden_states
-        if (
-            original_encoder_hidden_states is not None
-            and encoder_hidden_states is not None
-        ):  # Pattern 4, 5
-            encoder_hidden_states_residual = (
-                encoder_hidden_states - original_encoder_hidden_states
-            )
-        else:
-            encoder_hidden_states_residual = None  # Pattern 3
 
         return (
             hidden_states,
-            encoder_hidden_states,
+            new_encoder_hidden_states,
             hidden_states_residual,
-            encoder_hidden_states_residual,
         )
 
     def call_Bn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
         assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
@@ -264,11 +263,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                hidden_states, encoder_hidden_states = hidden_states
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
-        return hidden_states, encoder_hidden_states
+        return hidden_states, new_encoder_hidden_states
@@ -25,6 +25,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         cache_prefix: str = None,  # maybe un-need.
@@ -38,6 +39,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
+        self.check_forward_pattern = check_forward_pattern
         self.check_num_outputs = check_num_outputs
         # 1. Cache context configuration
         self.cache_prefix = cache_prefix
@@ -52,6 +54,12 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         )
 
     def _check_forward_pattern(self):
+        if not self.check_forward_pattern:
+            logger.warning(
+                f"Skipped Forward Pattern Check: {self.forward_pattern}"
+            )
+            return
+
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
@@ -59,6 +67,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
+                # Special case for HiDreamBlock
+                if hasattr(block, "block"):
+                    if isinstance(block.block, torch.nn.Module):
+                        block = block.block
+
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
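
The new `check_forward_pattern` flag threads through the constructor above and lets callers opt out of the signature check: when it is False, `_check_forward_pattern()` only logs a warning and returns. A minimal sketch of passing the flag, assuming the remaining constructor arguments keep their defaults and that `blocks`/`transformer` are placeholder modules not shown in this diff:

    # Hedged sketch, not taken from the package sources: wrap a block list whose
    # forward() signature does not match a registered pattern, skipping the check.
    cached_blocks = CachedBlocks_Pattern_Base(
        transformer_blocks=blocks,        # placeholder torch.nn.ModuleList
        transformer=transformer,          # placeholder torch.nn.Module
        forward_pattern=ForwardPattern.Pattern_0,
        check_forward_pattern=False,      # _check_forward_pattern() warns and returns early
    )
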
@@ -733,17 +733,15 @@ class CachedContextManager:
             encoder_prefix
         )
 
-        assert (
-            encoder_hidden_states_prev is not None
-        ), f"{prefix}_encoder_buffer must be set before"
+        if encoder_hidden_states_prev is not None:
 
-        if self.is_encoder_cache_residual():
-            encoder_hidden_states = (
-                encoder_hidden_states_prev + encoder_hidden_states
-            )
-        else:
-            # If encoder cache is not residual, we use the encoder hidden states directly
-            encoder_hidden_states = encoder_hidden_states_prev
+            if self.is_encoder_cache_residual():
+                encoder_hidden_states = (
+                    encoder_hidden_states_prev + encoder_hidden_states
+                )
+            else:
+                # If encoder cache is not residual, we use the encoder hidden states directly
+                encoder_hidden_states = encoder_hidden_states_prev
 
         encoder_hidden_states = encoder_hidden_states.contiguous()
 
@@ -1,11 +1,9 @@
-import torch
-from typing import Any, Tuple, List
+from typing import Any, Tuple, List, Union
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 from cache_dit.cache_factory.cache_adapters import CachedAdapter
-from cache_dit.cache_factory.cache_contexts import CachedContextManager
 
 from cache_dit.logger import init_logger
 
@@ -14,7 +12,10 @@ logger = init_logger(__name__)
 
 def enable_cache(
     # DiffusionPipeline or BlockAdapter
-    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+    pipe_or_adapter: Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ],
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
     Bn_compute_blocks: int = 0,
@@ -32,7 +33,10 @@
     taylorseer_cache_type: str = "residual",
     taylorseer_order: int = 2,
     **other_cache_context_kwargs,
-) -> BlockAdapter:
+) -> Union[
+    DiffusionPipeline,
+    BlockAdapter,
+]:
     r"""
     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
     that match the specific Input and Output patterns).
@@ -102,11 +106,11 @@
     >>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
     >>> output = pipe(...) # Just call the pipe as normal.
     >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
+    >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
     """
-
     # Collect cache context kwargs
     cache_context_kwargs = other_cache_context_kwargs.copy()
-    if cache_type := cache_context_kwargs.get("cache_type", None):
+    if (cache_type := cache_context_kwargs.get("cache_type", None)) is not None:
         if cache_type == CacheType.NONE:
             return pipe_or_adapter
 
@@ -145,79 +149,17 @@
 
 
 def disable_cache(
-    # DiffusionPipeline or BlockAdapter
-    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+    pipe_or_adapter: Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ],
 ):
-    from cache_dit.cache_factory.cache_blocks.utils import (
-        remove_cached_stats,
+    CachedAdapter.maybe_release_hooks(pipe_or_adapter)
+    logger.warning(
+        f"Cache Acceleration is disabled for: "
+        f"{pipe_or_adapter.__class__.__name__}."
     )
 
-    def _disable_blocks(blocks: torch.nn.ModuleList):
-        if blocks is None:
-            return
-        if hasattr(blocks, "_forward_pattern"):
-            del blocks._forward_pattern
-        if hasattr(blocks, "_cache_context_kwargs"):
-            del blocks._cache_context_kwargs
-        remove_cached_stats(blocks)
-
-    def _disable_transformer(transformer: torch.nn.Module):
-        if transformer is None or not BlockAdapter.is_cached(transformer):
-            return
-        if original_forward := getattr(transformer, "_original_forward"):
-            transformer.forward = original_forward.__get__(transformer)
-            del transformer._original_forward
-        if hasattr(transformer, "_is_cached"):
-            del transformer._is_cached
-        if hasattr(transformer, "_forward_pattern"):
-            del transformer._forward_pattern
-        if hasattr(transformer, "_has_separate_cfg"):
-            del transformer._has_separate_cfg
-        if hasattr(transformer, "_cache_context_kwargs"):
-            del transformer._cache_context_kwargs
-        remove_cached_stats(transformer)
-        for blocks in BlockAdapter.find_blocks(transformer):
-            _disable_blocks(blocks)
-
-    def _disable_pipe(pipe: DiffusionPipeline):
-        if pipe is None or not BlockAdapter.is_cached(pipe):
-            return
-        if original_call := getattr(pipe, "_original_call"):
-            pipe.__class__.__call__ = original_call
-            del pipe.__class__._original_call
-        if cache_manager := getattr(pipe, "_cache_manager"):
-            assert isinstance(cache_manager, CachedContextManager)
-            cache_manager.clear_contexts()
-            del pipe._cache_manager
-        if hasattr(pipe, "_is_cached"):
-            del pipe.__class__._is_cached
-        if hasattr(pipe, "_cache_context_kwargs"):
-            del pipe._cache_context_kwargs
-        remove_cached_stats(pipe)
-
-    if isinstance(pipe_or_adapter, DiffusionPipeline):
-        pipe = pipe_or_adapter
-        _disable_pipe(pipe)
-        if hasattr(pipe, "transformer"):
-            _disable_transformer(pipe.transformer)
-        if hasattr(pipe, "transformer_2"):  # Wan 2.2
-            _disable_transformer(pipe.transformer_2)
-        pipe_cls_name = pipe.__class__.__name__
-        logger.warning(f"Cache Acceleration is disabled for: {pipe_cls_name}")
-    elif isinstance(pipe_or_adapter, BlockAdapter):
-        # BlockAdapter
-        adapter = pipe_or_adapter
-        BlockAdapter.assert_normalized(adapter)
-        _disable_pipe(adapter.pipe)
-        for transformer in BlockAdapter.flatten(adapter.transformer):
-            _disable_transformer(transformer)
-        for blocks in BlockAdapter.flatten(adapter.blocks):
-            _disable_blocks(blocks)
-        pipe_cls_name = adapter.pipe.__class__.__name__
-        logger.warning(f"Cache Acceleration is disabled for: {pipe_cls_name}")
-    else:
-        pass  # do nothing
-
 
 def supported_pipelines(
     **kwargs,
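
Taken together with the docstring addition above, the public flow is unchanged apart from the simpler one-call disable path: enable caching on a pipeline, run it, read the summary, then release the hooks. A minimal sketch of that round trip, assuming a Flux checkpoint purely for illustration (any supported DiffusionPipeline works the same way):

    import cache_dit
    from diffusers import DiffusionPipeline

    # Illustrative checkpoint; the model id is an assumption, not part of the diff.
    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

    cache_dit.enable_cache(pipe)      # one-line enable with the default cache options
    output = pipe(...)                # call the pipe as normal (prompt arguments elided)
    stats = cache_dit.summary(pipe)   # summary of cache acceleration stats
    cache_dit.disable_cache(pipe)     # release hooks and restore the original pipe
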
@@ -22,11 +22,11 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
     if isinstance(type_hint, CacheType):
         return type_hint
 
-    elif type_hint.lower() in (
-        "dual_block_cache",
-        "db_cache",
-        "dbcache",
-        "db",
+    elif type_hint.upper() in (
+        "DUAL_BLOCK_CACHE",
+        "DB_CACHE",
+        "DBCACHE",
+        "DB",
     ):
         return CacheType.DBCache
     return CacheType.NONE
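
The rewritten matcher is functionally equivalent to the old one: comparison now goes through `.upper()` instead of `.lower()`, so the same aliases still resolve case-insensitively. A short sketch of the mapping it implements, with the alias strings taken from the hunk above (the import path is an assumption, since the diff does not name the file):

    from cache_dit.cache_factory.cache_types import CacheType, cache_type

    assert cache_type(CacheType.DBCache) == CacheType.DBCache  # CacheType passes through
    assert cache_type("db_cache") == CacheType.DBCache         # alias, any casing
    assert cache_type("DBCACHE") == CacheType.DBCache
    assert cache_type("something_else") == CacheType.NONE      # unknown strings fall back to NONE
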
@@ -3,3 +3,9 @@ from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_chroma import (
     ChromaPatchFunctor,
 )
+from cache_dit.cache_factory.patch_functors.functor_hidream import (
+    HiDreamPatchFunctor,
+)
+from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
+    HunyuanDiTPatchFunctor,
+)
@@ -46,8 +46,10 @@ class ChromaPatchFunctor(PatchFunctor):
             block.forward = __patch_single_forward__.__get__(block)
             is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched Chroma for cache-dit.")
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -56,9 +58,9 @@ class ChromaPatchFunctor(PatchFunctor):
             transformer.forward = __patch_transformer_forward__.__get__(
                 transformer
             )
-            transformer._is_patched = True
 
-        cls_name = transformer.__class__.__name__
+        transformer._is_patched = is_patched  # True or False
+
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."
@@ -47,8 +47,10 @@ class FluxPatchFunctor(PatchFunctor):
             block.forward = __patch_single_forward__.__get__(block)
             is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched Flux for cache-dit.")
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -57,9 +59,9 @@ class FluxPatchFunctor(PatchFunctor):
             transformer.forward = __patch_transformer_forward__.__get__(
                 transformer
             )
-            transformer._is_patched = True
 
-        cls_name = transformer.__class__.__name__
+        transformer._is_patched = is_patched  # True or False
+
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."