cache-dit 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

The registry flags this release of cache-dit as potentially problematic.
Files changed (29)
  1. cache_dit/__init__.py +5 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +2 -0
  4. cache_dit/cache_factory/cache_adapters.py +375 -26
  5. cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
  8. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
  9. cache_dit/cache_factory/cache_blocks/utils.py +19 -0
  10. cache_dit/cache_factory/cache_context.py +32 -25
  11. cache_dit/cache_factory/cache_interface.py +8 -3
  12. cache_dit/cache_factory/forward_pattern.py +45 -24
  13. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  14. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  15. cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
  16. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
  17. cache_dit/compile/utils.py +1 -1
  18. cache_dit/quantize/__init__.py +1 -0
  19. cache_dit/quantize/quantize_ao.py +196 -0
  20. cache_dit/quantize/quantize_interface.py +46 -0
  21. cache_dit/utils.py +49 -17
  22. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/METADATA +43 -18
  23. cache_dit-0.2.26.dist-info/RECORD +42 -0
  24. cache_dit-0.2.24.dist-info/RECORD +0 -32
  25. /cache_dit/{cache_factory/patch/__init__.py → quantize/quantize_svdq.py} +0 -0
  26. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py (new file)
@@ -0,0 +1,270 @@
+import torch
+
+from cache_dit.cache_factory import cache_context
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+)
+from cache_dit.cache_factory.cache_blocks.pattern_base import (
+    DBCachedBlocks_Pattern_Base,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        # Call first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        # encoder_hidden_states: None Pattern 3, else 4, 5
+        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+            hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        Fn_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        cache_context.mark_step_begin()
+        # Residual L1 diff or Hidden States L1 diff
+        can_use_cache = cache_context.get_can_use_cache(
+            (
+                Fn_hidden_states_residual
+                if not cache_context.is_l1_diff_enabled()
+                else hidden_states
+            ),
+            parallelized=self._is_parallelized(),
+            prefix=(
+                "Fn_residual"
+                if not cache_context.is_l1_diff_enabled()
+                else "Fn_hidden_states"
+            ),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            cache_context.add_cached_step()
+            del Fn_hidden_states_residual
+            hidden_states, encoder_hidden_states = (
+                cache_context.apply_hidden_states_residual(
+                    hidden_states,
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states,
+                    prefix=(
+                        "Bn_residual"
+                        if cache_context.is_cache_residual()
+                        else "Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        "Bn_residual"
+                        if cache_context.is_encoder_cache_residual()
+                        else "Bn_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+        else:
+            cache_context.set_Fn_buffer(
+                Fn_hidden_states_residual, prefix="Fn_residual"
+            )
+            if cache_context.is_l1_diff_enabled():
+                # for hidden states L1 diff
+                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+            del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
+            (
+                hidden_states,
+                encoder_hidden_states,
+                hidden_states_residual,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states_residual,
+            ) = self.call_Mn_blocks(  # middle
+                hidden_states,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            torch._dynamo.graph_break()
+            if cache_context.is_cache_residual():
+                cache_context.set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                cache_context.set_Bn_buffer(
+                    hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            if cache_context.is_encoder_cache_residual():
+                cache_context.set_Bn_encoder_buffer(
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                cache_context.set_Bn_encoder_buffer(
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        patch_cached_stats(self.transformer)
+        torch._dynamo.graph_break()
+
+        return (
+            hidden_states
+            if self.forward_pattern.Return_H_Only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.forward_pattern.Return_H_First
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
+    def call_Fn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        assert cache_context.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        encoder_hidden_states = None  # Pattern 3
+        for block in self._Fn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        return hidden_states, encoder_hidden_states
+
+    def call_Mn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        # None Pattern 3, else 4, 5
+        encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+        for block in self._Mn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        # compute hidden_states residual
+        hidden_states = hidden_states.contiguous()
+        hidden_states_residual = hidden_states - original_hidden_states
+        if (
+            original_encoder_hidden_states is not None
+            and encoder_hidden_states is not None
+        ):  # Pattern 4, 5
+            encoder_hidden_states_residual = (
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+        else:
+            encoder_hidden_states_residual = None  # Pattern 3
+
+        return (
+            hidden_states,
+            encoder_hidden_states,
+            hidden_states_residual,
+            encoder_hidden_states_residual,
+        )
+
+    def call_Bn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        # None Pattern 3, else 4, 5
+        encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        if cache_context.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        assert cache_context.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+            raise ValueError(
+                f"Bn_compute_blocks_ids is not support for "
+                f"patterns: {self._supported_patterns}."
+            )
+        else:
+            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+            for block in self._Bn_blocks():
+                hidden_states = block(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
+                if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
+                    hidden_states, encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, encoder_hidden_states = (
+                            encoder_hidden_states,
+                            hidden_states,
+                        )
+
+        return hidden_states, encoder_hidden_states
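
For orientation: Patterns 3, 4 and 5 cover transformers whose blocks take only hidden_states as the tensor input and return either a single tensor (Pattern 3) or a two-tensor tuple (Patterns 4 and 5), which is exactly what the isinstance checks above unpack. The toy block below is a hypothetical sketch, not part of cache-dit, showing the shape of forward that call_Fn_blocks / call_Mn_blocks / call_Bn_blocks expect:

import torch


class ToyPattern4Block(torch.nn.Module):
    # Hypothetical block: takes hidden_states only (Forward_H_only),
    # returns (hidden_states, encoder_hidden_states), i.e. Pattern_4.
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.encoder_proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs):
        hidden_states = self.proj(hidden_states)
        encoder_hidden_states = self.encoder_proj(hidden_states)
        # A Pattern_3 block would return only hidden_states;
        # a Pattern_5 block would swap the tuple order.
        return hidden_states, encoder_hidden_states
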
cache_dit/cache_factory/cache_blocks/pattern_base.py (renamed from cache_dit/cache_factory/cache_blocks.py)
@@ -4,12 +4,15 @@ import torch.distributed as dist
 
 from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class DBCachedTransformerBlocks(torch.nn.Module):
+class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -29,18 +32,30 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
         self._check_forward_pattern()
+        logger.info(f"Match Cached Blocks: {self.__class__.__name__}")
 
     def _check_forward_pattern(self):
        assert (
            self.forward_pattern.Supported
            and self.forward_pattern in self._supported_patterns
-        ), f"Pattern {self.forward_pattern} is not support for DBCache now!"
+        ), f"Pattern {self.forward_pattern} is not supported now!"
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
+                num_outputs = str(
+                    inspect.signature(block.forward).return_annotation
+                ).count("torch.Tensor")
+
+                if num_outputs > 0:
+                    assert len(self.forward_pattern.Out) == num_outputs, (
+                        f"The number of block's outputs is {num_outputs} don't not "
+                        f"match the number of the pattern: {self.forward_pattern}, "
+                        f"Out: {len(self.forward_pattern.Out)}."
+                    )
+
                 for required_param in self.forward_pattern.In:
                     assert (
                         required_param in forward_parameters
@@ -479,19 +494,3 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         )
 
         return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def patch_cached_stats(
-    transformer,
-):
-    # Patch the cached stats to the transformer, the cached stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
-        return
-
-    # TODO: Patch more cached stats to the transformer
-    transformer._cached_steps = cache_context.get_cached_steps()
-    transformer._residual_diffs = cache_context.get_residual_diffs()
-    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
-    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
cache_dit/cache_factory/cache_blocks/utils.py (new file)
@@ -0,0 +1,19 @@
+import torch
+
+from cache_dit.cache_factory import cache_context
+
+
+@torch.compiler.disable
+def patch_cached_stats(
+    transformer,
+):
+    # Patch the cached stats to the transformer, the cached stats
+    # will be reset for each calling of pipe.__call__(**kwargs).
+    if transformer is None:
+        return
+
+    # TODO: Patch more cached stats to the transformer
+    transformer._cached_steps = cache_context.get_cached_steps()
+    transformer._residual_diffs = cache_context.get_residual_diffs()
+    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
+    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
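
The relocated patch_cached_stats helper attaches per-call statistics directly to the transformer module so they can be read back after a run. A minimal inspection sketch; the pipeline object here is illustrative and only the attribute names come from the diff above:

# Hypothetical inspection after a cached run; `pipe` is any Diffusers
# pipeline whose transformer blocks were wrapped by cache-dit.
transformer = pipe.transformer

# Attributes set by patch_cached_stats:
print(getattr(transformer, "_cached_steps", None))        # steps that reused the cache
print(getattr(transformer, "_residual_diffs", None))      # per-step residual diffs
print(getattr(transformer, "_cfg_cached_steps", None))    # CFG branch cached steps
print(getattr(transformer, "_cfg_residual_diffs", None))  # CFG branch residual diffs
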
cache_dit/cache_factory/cache_context.py
@@ -328,6 +328,33 @@ class DBCacheContext:
         return self.get_current_step() < self.max_warmup_steps
 
 
+# TODO: Support context manager for different cache_context
+
+
+def create_cache_context(*args, **kwargs):
+    return DBCacheContext(*args, **kwargs)
+
+
+def get_current_cache_context():
+    return _current_cache_context
+
+
+def set_current_cache_context(cache_context=None):
+    global _current_cache_context
+    _current_cache_context = cache_context
+
+
+@contextlib.contextmanager
+def cache_context(cache_context):
+    global _current_cache_context
+    old_cache_context = _current_cache_context
+    _current_cache_context = cache_context
+    try:
+        yield
+    finally:
+        _current_cache_context = old_cache_context
+
+
 @torch.compiler.disable
 def get_residual_diff_threshold():
     cache_context = get_current_cache_context()
@@ -657,19 +684,6 @@ def cfg_diff_compute_separate():
 _current_cache_context: DBCacheContext = None
 
 
-def create_cache_context(*args, **kwargs):
-    return DBCacheContext(*args, **kwargs)
-
-
-def get_current_cache_context():
-    return _current_cache_context
-
-
-def set_current_cache_context(cache_context=None):
-    global _current_cache_context
-    _current_cache_context = cache_context
-
-
 def collect_cache_kwargs(default_attrs: dict, **kwargs):
     # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
     # default_attrs: specific settings for different pipelines
@@ -716,17 +730,6 @@ def collect_cache_kwargs(default_attrs: dict, **kwargs):
     return cache_kwargs, kwargs
 
 
-@contextlib.contextmanager
-def cache_context(cache_context):
-    global _current_cache_context
-    old_cache_context = _current_cache_context
-    _current_cache_context = cache_context
-    try:
-        yield
-    finally:
-        _current_cache_context = old_cache_context
-
-
 @torch.compiler.disable
 def are_two_tensors_similar(
     t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
@@ -938,7 +941,11 @@ def get_Bn_buffer(prefix: str = "Bn"):
 
 
 @torch.compiler.disable
-def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
+def set_Bn_encoder_buffer(buffer: torch.Tensor | None, prefix: str = "Bn"):
+    # DON'T set None Buffer
+    if buffer is None:
+        return
+
     # This buffer is use for encoder hidden states approximation.
     if is_encoder_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
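
The context-manager plumbing (create_cache_context, set_current_cache_context, and the cache_context contextmanager) is moved near the top of cache_context.py; its behavior is unchanged. A usage sketch, assuming the module import path used by the blocks code above:

from cache_dit.cache_factory import cache_context

# Create a fresh DBCacheContext and make it current for the duration of a
# forward pass; the previous context is restored when the block exits.
ctx = cache_context.create_cache_context()
with cache_context.cache_context(ctx):
    assert cache_context.get_current_cache_context() is ctx
    # ... run the cached transformer blocks here ...
assert cache_context.get_current_cache_context() is not ctx
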
cache_dit/cache_factory/cache_interface.py
@@ -1,3 +1,4 @@
+from typing import Any, Tuple, List
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
@@ -9,9 +10,13 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+def supported_pipelines() -> Tuple[int, List[str]]:
+    return UnifiedCacheAdapter.supported_pipelines()
+
+
 def enable_cache(
     # BlockAdapter & forward pattern
-    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+    pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
@@ -23,14 +28,14 @@ def enable_cache(
     # Cache CFG or not
     do_separate_cfg: bool = False,
     cfg_compute_first: bool = False,
-    cfg_diff_compute_separate: bool = False,
+    cfg_diff_compute_separate: bool = True,
     # Hybird TaylorSeer
     enable_taylorseer: bool = False,
     enable_encoder_taylorseer: bool = False,
     taylorseer_cache_type: str = "residual",
     taylorseer_order: int = 2,
     **other_cache_kwargs,
-) -> DiffusionPipeline:
+) -> DiffusionPipeline | Any:
     r"""
     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
     that match the specific Input and Output patterns).
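
With the widened signature, enable_cache now accepts any object alongside DiffusionPipeline | BlockAdapter, and cfg_diff_compute_separate defaults to True. A hedged usage sketch; the model id, the chosen forward pattern, and the import paths below are assumptions for illustration rather than values taken from this diff:

import torch
from diffusers import FluxPipeline

# Import paths assumed from this file's location; the package may also
# re-export these at the cache_dit top level.
from cache_dit.cache_factory.cache_interface import enable_cache, supported_pipelines
from cache_dit.cache_factory.forward_pattern import ForwardPattern

print(supported_pipelines())  # (count, [supported pipeline names])

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Defaults per the new signature: Fn_compute_blocks=8,
# cfg_diff_compute_separate=True, TaylorSeer disabled.
pipe = enable_cache(
    pipe,
    forward_pattern=ForwardPattern.Pattern_1,  # pattern choice is illustrative
    Fn_compute_blocks=8,
    enable_taylorseer=False,
)
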
cache_dit/cache_factory/forward_pattern.py
@@ -19,39 +19,57 @@ class ForwardPattern(Enum):
         self.Supported = Supported
 
     Pattern_0 = (
-        True,
-        False,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("hidden_states", "encoder_hidden_states"),
-        True,
+        True,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("hidden_states", "encoder_hidden_states"),  # Out
+        True,  # Supported
     )
 
     Pattern_1 = (
-        False,
-        False,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("encoder_hidden_states", "hidden_states"),
-        True,
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("encoder_hidden_states", "hidden_states"),  # Out
+        True,  # Supported
     )
 
     Pattern_2 = (
-        False,
-        True,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("hidden_states",),
-        True,
+        False,  # Return_H_First
+        True,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("hidden_states",),  # Out
+        True,  # Supported
    )
 
     Pattern_3 = (
-        False,
-        True,
-        False,
-        ("hidden_states",),
-        ("hidden_states",),
-        False,
+        False,  # Return_H_First
+        True,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("hidden_states",),  # Out
+        True,  # Supported
+    )
+
+    Pattern_4 = (
+        True,  # Return_H_First
+        False,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("hidden_states", "encoder_hidden_states"),  # Out
+        True,  # Supported
+    )
+
+    Pattern_5 = (
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("encoder_hidden_states", "hidden_states"),  # Out
+        True,  # Supported
     )
 
     @staticmethod
@@ -60,4 +78,7 @@
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_1,
             ForwardPattern.Pattern_2,
+            ForwardPattern.Pattern_3,
+            ForwardPattern.Pattern_4,
+            ForwardPattern.Pattern_5,
         ]
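
The new inline comments make the tuple layout explicit: (Return_H_First, Return_H_Only, Forward_H_only, In, Out, Supported). A small sketch of reading those fields back; the first two field names plus In, Out and Supported appear as attributes elsewhere in this diff, while Forward_H_only is taken only from the comments and is an assumption about the enum's attribute name:

from cache_dit.cache_factory.forward_pattern import ForwardPattern

p = ForwardPattern.Pattern_4
print(p.In)              # ('hidden_states',): block forward takes hidden_states only
print(p.Out)             # ('hidden_states', 'encoder_hidden_states')
print(p.Return_H_First)  # True: tuple output is (hidden_states, encoder_hidden_states)
print(p.Return_H_Only)   # False
print(p.Supported)       # True: Patterns 3/4/5 are now listed by supported_patterns()
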
cache_dit/cache_factory/patch_functors/__init__.py (new file)
@@ -0,0 +1,5 @@
+from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
+from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
+from cache_dit.cache_factory.patch_functors.functor_chroma import (
+    ChromaPatchFunctor,
+)
cache_dit/cache_factory/patch_functors/functor_base.py (new file)
@@ -0,0 +1,18 @@
+import torch
+from abc import abstractmethod
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class PatchFunctor:
+
+    @abstractmethod
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        *args,
+        **kwargs,
+    ) -> torch.nn.Module:
+        raise NotImplementedError("apply method is not implemented.")
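
PatchFunctor is a thin interface: subclasses implement apply and return the (possibly rewritten) transformer. A hypothetical no-op subclass to show the contract only; the real FluxPatchFunctor and ChromaPatchFunctor rewrite block forwards and are not reproduced in this diff excerpt:

import torch

from cache_dit.cache_factory.patch_functors import PatchFunctor


class NoOpPatchFunctor(PatchFunctor):
    # Hypothetical example subclass: returns the transformer unchanged.
    def apply(
        self,
        transformer: torch.nn.Module,
        *args,
        **kwargs,
    ) -> torch.nn.Module:
        # A real functor would patch block.forward here before returning.
        return transformer
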