cache-dit 0.2.30-py3-none-any.whl → 0.2.31-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.30'
32
- __version_tuple__ = version_tuple = (0, 2, 30)
31
+ __version__ = version = '0.2.31'
32
+ __version_tuple__ = version_tuple = (0, 2, 31)
33
33
 
34
34
  __commit_id__ = commit_id = None
cache_dit/cache_factory/block_adapters/__init__.py CHANGED
@@ -501,7 +501,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
501
501
  )
502
502
 
503
503
 
504
- @BlockAdapterRegistry.register("HiDream", supported=True)
504
+ @BlockAdapterRegistry.register("HiDream")
505
505
  def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
506
506
  # NOTE: Need to patch Transformer forward to fully support
507
507
  # double_stream_blocks and single_stream_blocks, namely, need
@@ -509,29 +509,32 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
509
509
  # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
510
510
  # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
511
511
  from diffusers import HiDreamImageTransformer2DModel
512
+ from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
512
513
 
513
514
  assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
514
515
  return BlockAdapter(
515
516
  pipe=pipe,
516
517
  transformer=pipe.transformer,
517
518
  blocks=[
518
- # pipe.transformer.double_stream_blocks,
519
+ pipe.transformer.double_stream_blocks,
519
520
  pipe.transformer.single_stream_blocks,
520
521
  ],
521
522
  forward_pattern=[
522
- # ForwardPattern.Pattern_4,
523
+ ForwardPattern.Pattern_0,
523
524
  ForwardPattern.Pattern_3,
524
525
  ],
525
- # The type hint in diffusers is wrong
526
- check_num_outputs=False,
526
+ patch_functor=HiDreamPatchFunctor(),
527
+ # NOTE: The type hint in diffusers is wrong
528
+ check_forward_pattern=True,
529
+ check_num_outputs=True,
527
530
  **kwargs,
528
531
  )
529
532
 
530
533
 
531
- @BlockAdapterRegistry.register("HunyuanDiT", supported=False)
534
+ @BlockAdapterRegistry.register("HunyuanDiT")
532
535
  def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
533
- # TODO: Patch Transformer forward
534
536
  from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
537
+ from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
535
538
 
536
539
  assert isinstance(
537
540
  pipe.transformer,
@@ -542,14 +545,15 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
542
545
  transformer=pipe.transformer,
543
546
  blocks=pipe.transformer.blocks,
544
547
  forward_pattern=ForwardPattern.Pattern_3,
548
+ patch_functor=HunyuanDiTPatchFunctor(),
545
549
  **kwargs,
546
550
  )
547
551
 
548
552
 
549
- @BlockAdapterRegistry.register("HunyuanDiTPAG", supported=False)
553
+ @BlockAdapterRegistry.register("HunyuanDiTPAG")
550
554
  def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
551
- # TODO: Patch Transformer forward
552
555
  from diffusers import HunyuanDiT2DModel
556
+ from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
553
557
 
554
558
  assert isinstance(pipe.transformer, HunyuanDiT2DModel)
555
559
  return BlockAdapter(
@@ -557,5 +561,6 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
557
561
  transformer=pipe.transformer,
558
562
  blocks=pipe.transformer.blocks,
559
563
  forward_pattern=ForwardPattern.Pattern_3,
564
+ patch_functor=HunyuanDiTPatchFunctor(),
560
565
  **kwargs,
561
566
  )
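With this release, the HiDream, HunyuanDiT and HunyuanDiTPAG adapters are registered without `supported=False`, so they go through the same one-call cache API as the other pipelines. A minimal usage sketch follows; the `cache_dit.enable_cache(pipe)` call is inferred from the assertion messages in this diff, and the model id, loading details (HiDream normally also needs its Llama text-encoder components) and prompt are illustrative assumptions, not taken from this diff:

```python
import torch
from diffusers import HiDreamImagePipeline  # HiDream pipeline class in diffusers
import cache_dit

# Load HiDream as usual. Model id is illustrative; the real checkpoint may also
# require tokenizer_4 / text_encoder_4 (Llama 3.1) to be passed explicitly.
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    torch_dtype=torch.bfloat16,
).to("cuda")

# One call enables caching; the registered "HiDream" BlockAdapter now covers
# both double_stream_blocks (Pattern_0) and single_stream_blocks (Pattern_3)
# and applies HiDreamPatchFunctor under the hood.
cache_dit.enable_cache(pipe)

image = pipe("a cabin on a snowy mountain, golden hour").images[0]
image.save("hidream_cached.png")
```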
cache_dit/cache_factory/block_adapters/block_adapters.py CHANGED
@@ -75,6 +75,7 @@ class BlockAdapter:
75
75
  List[List[ParamsModifier]],
76
76
  ] = None
77
77
 
78
+ check_forward_pattern: bool = True
78
79
  check_num_outputs: bool = False
79
80
 
80
81
  # Pipeline Level Flags
@@ -391,11 +392,20 @@ class BlockAdapter:
391
392
  forward_pattern: ForwardPattern,
392
393
  **kwargs,
393
394
  ) -> bool:
395
+
396
+ if not kwargs.get("check_forward_pattern", True):
397
+ return True
398
+
394
399
  assert (
395
400
  forward_pattern.Supported
396
401
  and forward_pattern in ForwardPattern.supported_patterns()
397
402
  ), f"Pattern {forward_pattern} is not support now!"
398
403
 
404
+ # NOTE: Special case for HiDreamBlock
405
+ if hasattr(block, "block"):
406
+ if isinstance(block.block, torch.nn.Module):
407
+ block = block.block
408
+
399
409
  forward_parameters = set(
400
410
  inspect.signature(block.forward).parameters.keys()
401
411
  )
@@ -425,6 +435,14 @@ class BlockAdapter:
425
435
  logging: bool = True,
426
436
  **kwargs,
427
437
  ) -> bool:
438
+
439
+ if not kwargs.get("check_forward_pattern", True):
440
+ if logging:
441
+ logger.warning(
442
+ f"Skipped Forward Pattern Check: {forward_pattern}"
443
+ )
444
+ return True
445
+
428
446
  assert (
429
447
  forward_pattern.Supported
430
448
  and forward_pattern in ForwardPattern.supported_patterns()
@@ -531,6 +549,7 @@ class BlockAdapter:
531
549
  blocks,
532
550
  forward_pattern=forward_pattern,
533
551
  check_num_outputs=adapter.check_num_outputs,
552
+ check_forward_pattern=adapter.check_forward_pattern,
534
553
  ), (
535
554
  "No block forward pattern matched, "
536
555
  f"supported lists: {ForwardPattern.supported_patterns()}"
cache_dit/cache_factory/cache_adapters.py CHANGED
@@ -345,6 +345,7 @@ class CachedAdapter:
345
345
  block_adapter.blocks[i][j],
346
346
  transformer=block_adapter.transformer[i],
347
347
  forward_pattern=block_adapter.forward_pattern[i][j],
348
+ check_forward_pattern=block_adapter.check_forward_pattern,
348
349
  check_num_outputs=block_adapter.check_num_outputs,
349
350
  # 1. Cache context configuration
350
351
  cache_prefix=block_adapter.blocks_name[i][j],
cache_dit/cache_factory/cache_blocks/__init__.py CHANGED
@@ -25,6 +25,7 @@ class CachedBlocks:
25
25
  transformer_blocks: torch.nn.ModuleList,
26
26
  transformer: torch.nn.Module = None,
27
27
  forward_pattern: ForwardPattern = None,
28
+ check_forward_pattern: bool = True,
28
29
  check_num_outputs: bool = True,
29
30
  # 1. Cache context configuration
30
31
  # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
@@ -45,6 +46,7 @@ class CachedBlocks:
45
46
  transformer_blocks,
46
47
  transformer=transformer,
47
48
  forward_pattern=forward_pattern,
49
+ check_forward_pattern=check_forward_pattern,
48
50
  check_num_outputs=check_num_outputs,
49
51
  # 1. Cache context configuration
50
52
  cache_prefix=cache_prefix,
@@ -58,6 +60,7 @@ class CachedBlocks:
58
60
  transformer_blocks,
59
61
  transformer=transformer,
60
62
  forward_pattern=forward_pattern,
63
+ check_forward_pattern=check_forward_pattern,
61
64
  check_num_outputs=check_num_outputs,
62
65
  # 1. Cache context configuration
63
66
  cache_prefix=cache_prefix,
cache_dit/cache_factory/cache_blocks/pattern_base.py CHANGED
@@ -25,6 +25,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
25
25
  transformer_blocks: torch.nn.ModuleList,
26
26
  transformer: torch.nn.Module = None,
27
27
  forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
28
+ check_forward_pattern: bool = True,
28
29
  check_num_outputs: bool = True,
29
30
  # 1. Cache context configuration
30
31
  cache_prefix: str = None, # maybe un-need.
@@ -38,6 +39,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
38
39
  self.transformer = transformer
39
40
  self.transformer_blocks = transformer_blocks
40
41
  self.forward_pattern = forward_pattern
42
+ self.check_forward_pattern = check_forward_pattern
41
43
  self.check_num_outputs = check_num_outputs
42
44
  # 1. Cache context configuration
43
45
  self.cache_prefix = cache_prefix
@@ -52,6 +54,12 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
52
54
  )
53
55
 
54
56
  def _check_forward_pattern(self):
57
+ if not self.check_forward_pattern:
58
+ logger.warning(
59
+ f"Skipped Forward Pattern Check: {self.forward_pattern}"
60
+ )
61
+ return
62
+
55
63
  assert (
56
64
  self.forward_pattern.Supported
57
65
  and self.forward_pattern in self._supported_patterns
@@ -59,6 +67,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
59
67
 
60
68
  if self.transformer_blocks is not None:
61
69
  for block in self.transformer_blocks:
70
+ # Special case for HiDreamBlock
71
+ if hasattr(block, "block"):
72
+ if isinstance(block.block, torch.nn.Module):
73
+ block = block.block
74
+
62
75
  forward_parameters = set(
63
76
  inspect.signature(block.forward).parameters.keys()
64
77
  )
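The `hasattr(block, "block")` special case above exists because HiDream wraps each transformer block in a `HiDreamBlock` whose `forward` is just `(hidden_states, *args, **kwargs)`, so inspecting the wrapper's signature tells the pattern checker nothing. A self-contained toy illustration of why the checker unwraps to the inner module; the class names here are stand-ins, not the real diffusers classes:

```python
import inspect
import torch

class InnerBlock(torch.nn.Module):  # stand-in for HiDreamImageTransformerBlock
    def forward(self, hidden_states, encoder_hidden_states, temb=None):
        return hidden_states, encoder_hidden_states

class WrapperBlock(torch.nn.Module):  # stand-in for HiDreamBlock
    def __init__(self):
        super().__init__()
        self.block = InnerBlock()

    def forward(self, hidden_states, *args, **kwargs):
        return self.block(hidden_states, *args, **kwargs)

block = WrapperBlock()
print(set(inspect.signature(block.forward).parameters))
# {'hidden_states', 'args', 'kwargs'}  -> useless for pattern matching

if hasattr(block, "block") and isinstance(block.block, torch.nn.Module):
    block = block.block  # the same unwrapping step the checker performs
print(set(inspect.signature(block.forward).parameters))
# {'hidden_states', 'encoder_hidden_states', 'temb'} -> real parameter names
```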
cache_dit/cache_factory/patch_functors/__init__.py CHANGED
@@ -3,3 +3,9 @@ from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
3
3
  from cache_dit.cache_factory.patch_functors.functor_chroma import (
4
4
  ChromaPatchFunctor,
5
5
  )
6
+ from cache_dit.cache_factory.patch_functors.functor_hidream import (
7
+ HiDreamPatchFunctor,
8
+ )
9
+ from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
10
+ HunyuanDiTPatchFunctor,
11
+ )
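The two new exports are normally consumed indirectly, through the `patch_functor=` argument of the adapters above, but they can also be applied by hand when experimenting. A sketch, using the `apply()` method shown in the new functor sources below; the model id and subfolder used to load the transformer are illustrative assumptions:

```python
import torch
from diffusers import HiDreamImageTransformer2DModel
from cache_dit.cache_factory.patch_functors import (
    HiDreamPatchFunctor,
    HunyuanDiTPatchFunctor,
)

# Model id / subfolder are illustrative; any HiDreamImageTransformer2DModel works.
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

transformer = HiDreamPatchFunctor().apply(transformer)
assert getattr(transformer, "_is_patched", False)

# HunyuanDiTPatchFunctor().apply(...) works the same way on a HunyuanDiT2DModel.
```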
cache_dit/cache_factory/patch_functors/functor_chroma.py CHANGED
@@ -46,8 +46,10 @@ class ChromaPatchFunctor(PatchFunctor):
46
46
  block.forward = __patch_single_forward__.__get__(block)
47
47
  is_patched = True
48
48
 
49
+ cls_name = transformer.__class__.__name__
50
+
49
51
  if is_patched:
50
- logger.warning("Patched Chroma for cache-dit.")
52
+ logger.warning(f"Patched {cls_name} for cache-dit.")
51
53
  assert not getattr(transformer, "_is_parallelized", False), (
52
54
  "Please call `cache_dit.enable_cache` before Parallelize, "
53
55
  "the __patch_transformer_forward__ will overwrite the "
@@ -59,7 +61,6 @@ class ChromaPatchFunctor(PatchFunctor):
59
61
 
60
62
  transformer._is_patched = is_patched # True or False
61
63
 
62
- cls_name = transformer.__class__.__name__
63
64
  logger.info(
64
65
  f"Applied {self.__class__.__name__} for {cls_name}, "
65
66
  f"Patch: {is_patched}."
cache_dit/cache_factory/patch_functors/functor_flux.py CHANGED
@@ -47,8 +47,10 @@ class FluxPatchFunctor(PatchFunctor):
47
47
  block.forward = __patch_single_forward__.__get__(block)
48
48
  is_patched = True
49
49
 
50
+ cls_name = transformer.__class__.__name__
51
+
50
52
  if is_patched:
51
- logger.warning("Patched Flux for cache-dit.")
53
+ logger.warning(f"Patched {cls_name} for cache-dit.")
52
54
  assert not getattr(transformer, "_is_parallelized", False), (
53
55
  "Please call `cache_dit.enable_cache` before Parallelize, "
54
56
  "the __patch_transformer_forward__ will overwrite the "
@@ -60,7 +62,6 @@ class FluxPatchFunctor(PatchFunctor):
60
62
 
61
63
  transformer._is_patched = is_patched # True or False
62
64
 
63
- cls_name = transformer.__class__.__name__
64
65
  logger.info(
65
66
  f"Applied {self.__class__.__name__} for {cls_name}, "
66
67
  f"Patch: {is_patched}."
cache_dit/cache_factory/patch_functors/functor_hidream.py ADDED
@@ -0,0 +1,412 @@
1
+ import torch
2
+ from typing import Tuple, Optional, Dict, Any, Union, List
3
+ from diffusers import HiDreamImageTransformer2DModel
4
+ from diffusers.models.transformers.transformer_hidream_image import (
5
+ HiDreamBlock,
6
+ HiDreamImageTransformerBlock,
7
+ HiDreamImageSingleTransformerBlock,
8
+ Transformer2DModelOutput,
9
+ )
10
+ from diffusers.utils import (
11
+ deprecate,
12
+ USE_PEFT_BACKEND,
13
+ scale_lora_layers,
14
+ unscale_lora_layers,
15
+ )
16
+ from cache_dit.cache_factory.patch_functors.functor_base import (
17
+ PatchFunctor,
18
+ )
19
+ from cache_dit.logger import init_logger
20
+
21
+ logger = init_logger(__name__)
22
+
23
+
24
+ class HiDreamPatchFunctor(PatchFunctor):
25
+
26
+ def apply(
27
+ self,
28
+ transformer: HiDreamImageTransformer2DModel,
29
+ **kwargs,
30
+ ) -> HiDreamImageTransformer2DModel:
31
+ if hasattr(transformer, "_is_patched"):
32
+ return transformer
33
+
34
+ is_patched = False
35
+
36
+ _block_id = 0
37
+ for block in transformer.double_stream_blocks:
38
+ assert isinstance(block, HiDreamBlock)
39
+ block.forward = __patch_block_forward__.__get__(block)
40
+ # NOTE: Patch Inner block and block_id
41
+ _block = block.block
42
+ assert isinstance(_block, HiDreamImageTransformerBlock)
43
+ _block._block_id = _block_id
44
+ _block.forward = __patch_double_forward__.__get__(_block)
45
+ _block_id += 1
46
+
47
+ for block in transformer.single_stream_blocks:
48
+ assert isinstance(block, HiDreamBlock)
49
+ block.forward = __patch_block_forward__.__get__(block)
50
+ # NOTE: Patch Inner block and block_id
51
+ _block = block.block
52
+ assert isinstance(_block, HiDreamImageSingleTransformerBlock)
53
+ _block._block_id = _block_id
54
+ _block.forward = __patch_single_forward__.__get__(_block)
55
+ _block_id += 1
56
+
57
+ is_patched = True
58
+ cls_name = transformer.__class__.__name__
59
+
60
+ if is_patched:
61
+ logger.warning(f"Patched {cls_name} for cache-dit.")
62
+ assert not getattr(transformer, "_is_parallelized", False), (
63
+ "Please call `cache_dit.enable_cache` before Parallelize, "
64
+ "the __patch_transformer_forward__ will overwrite the "
65
+ "parallized forward and cause a downgrade of performance."
66
+ )
67
+ transformer.forward = __patch_transformer_forward__.__get__(
68
+ transformer
69
+ )
70
+
71
+ transformer._is_patched = is_patched # True or False
72
+
73
+ logger.info(
74
+ f"Applied {self.__class__.__name__} for {cls_name}, "
75
+ f"Patch: {is_patched}."
76
+ )
77
+
78
+ return transformer
79
+
80
+
81
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
82
+ def __patch_double_forward__(
83
+ self: HiDreamImageTransformerBlock,
84
+ hidden_states: torch.Tensor,
85
+ encoder_hidden_states: torch.Tensor, # initial_encoder_hidden_states
86
+ hidden_states_masks: Optional[torch.Tensor] = None,
87
+ temb: Optional[torch.Tensor] = None,
88
+ image_rotary_emb: torch.Tensor = None,
89
+ llama31_encoder_hidden_states: List[torch.Tensor] = None,
90
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
91
+ # Assume block_id was patched in transformer forward:
92
+ # for i, block in enumerate(blocks): block._block_id = i;
93
+ block_id = self._block_id
94
+ initial_encoder_hidden_states_seq_len = encoder_hidden_states.shape[1]
95
+ cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
96
+ cur_encoder_hidden_states = torch.cat(
97
+ [encoder_hidden_states, cur_llama31_encoder_hidden_states],
98
+ dim=1,
99
+ )
100
+ encoder_hidden_states = cur_encoder_hidden_states
101
+
102
+ wtype = hidden_states.dtype
103
+ (
104
+ shift_msa_i,
105
+ scale_msa_i,
106
+ gate_msa_i,
107
+ shift_mlp_i,
108
+ scale_mlp_i,
109
+ gate_mlp_i,
110
+ shift_msa_t,
111
+ scale_msa_t,
112
+ gate_msa_t,
113
+ shift_mlp_t,
114
+ scale_mlp_t,
115
+ gate_mlp_t,
116
+ ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
117
+
118
+ # 1. MM-Attention
119
+ norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
120
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
121
+ norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(
122
+ dtype=wtype
123
+ )
124
+ norm_encoder_hidden_states = (
125
+ norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
126
+ )
127
+
128
+ attn_output_i, attn_output_t = self.attn1(
129
+ norm_hidden_states,
130
+ hidden_states_masks,
131
+ norm_encoder_hidden_states,
132
+ image_rotary_emb=image_rotary_emb,
133
+ )
134
+
135
+ hidden_states = gate_msa_i * attn_output_i + hidden_states
136
+ encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
137
+
138
+ # 2. Feed-forward
139
+ norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
140
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
141
+ norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(
142
+ dtype=wtype
143
+ )
144
+ norm_encoder_hidden_states = (
145
+ norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
146
+ )
147
+
148
+ ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
149
+ ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
150
+ hidden_states = ff_output_i + hidden_states
151
+ encoder_hidden_states = ff_output_t + encoder_hidden_states
152
+
153
+ initial_encoder_hidden_states = encoder_hidden_states[
154
+ :, :initial_encoder_hidden_states_seq_len
155
+ ]
156
+
157
+ return hidden_states, initial_encoder_hidden_states
158
+
159
+
160
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
161
+ def __patch_single_forward__(
162
+ self: HiDreamImageSingleTransformerBlock,
163
+ hidden_states: torch.Tensor,
164
+ hidden_states_masks: Optional[torch.Tensor] = None,
165
+ temb: Optional[torch.Tensor] = None,
166
+ image_rotary_emb: torch.Tensor = None,
167
+ llama31_encoder_hidden_states: List[torch.Tensor] = None,
168
+ ) -> torch.Tensor:
169
+ # Assume block_id was patched in transformer forward:
170
+ # for i, block in enumerate(blocks): block._block_id = i;
171
+ block_id = self._block_id
172
+ hidden_states_seq_len = hidden_states.shape[1]
173
+ cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
174
+ hidden_states = torch.cat(
175
+ [hidden_states, cur_llama31_encoder_hidden_states], dim=1
176
+ )
177
+
178
+ wtype = hidden_states.dtype
179
+ (
180
+ shift_msa_i,
181
+ scale_msa_i,
182
+ gate_msa_i,
183
+ shift_mlp_i,
184
+ scale_mlp_i,
185
+ gate_mlp_i,
186
+ ) = self.adaLN_modulation(temb)[:, None].chunk(6, dim=-1)
187
+
188
+ # 1. MM-Attention
189
+ norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
190
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
191
+ attn_output_i = self.attn1(
192
+ norm_hidden_states,
193
+ hidden_states_masks,
194
+ image_rotary_emb=image_rotary_emb,
195
+ )
196
+ hidden_states = gate_msa_i * attn_output_i + hidden_states
197
+
198
+ # 2. Feed-forward
199
+ norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
200
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
201
+ ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
202
+ hidden_states = ff_output_i + hidden_states
203
+
204
+ hidden_states = hidden_states[:, :hidden_states_seq_len]
205
+
206
+ return hidden_states
207
+
208
+
209
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
210
+ def __patch_block_forward__(
211
+ self: HiDreamBlock,
212
+ hidden_states: torch.Tensor,
213
+ *args,
214
+ **kwargs,
215
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
216
+ return self.block(hidden_states, *args, **kwargs)
217
+
218
+
219
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
220
+ def __patch_transformer_forward__(
221
+ self: HiDreamImageTransformer2DModel,
222
+ hidden_states: torch.Tensor,
223
+ timesteps: torch.LongTensor = None,
224
+ encoder_hidden_states_t5: torch.Tensor = None,
225
+ encoder_hidden_states_llama3: torch.Tensor = None,
226
+ pooled_embeds: torch.Tensor = None,
227
+ img_ids: Optional[torch.Tensor] = None,
228
+ img_sizes: Optional[List[Tuple[int, int]]] = None,
229
+ hidden_states_masks: Optional[torch.Tensor] = None,
230
+ attention_kwargs: Optional[Dict[str, Any]] = None,
231
+ return_dict: bool = True,
232
+ **kwargs,
233
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
234
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
235
+
236
+ if encoder_hidden_states is not None:
237
+ deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
238
+ deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
239
+ encoder_hidden_states_t5 = encoder_hidden_states[0]
240
+ encoder_hidden_states_llama3 = encoder_hidden_states[1]
241
+
242
+ if (
243
+ img_ids is not None
244
+ and img_sizes is not None
245
+ and hidden_states_masks is None
246
+ ):
247
+ deprecation_message = "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
248
+ deprecate("img_ids", "0.35.0", deprecation_message)
249
+
250
+ if hidden_states_masks is not None and (
251
+ img_ids is None or img_sizes is None
252
+ ):
253
+ raise ValueError(
254
+ "if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed."
255
+ )
256
+ elif hidden_states_masks is not None and hidden_states.ndim != 3:
257
+ raise ValueError(
258
+ "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
259
+ )
260
+
261
+ if attention_kwargs is not None:
262
+ attention_kwargs = attention_kwargs.copy()
263
+ lora_scale = attention_kwargs.pop("scale", 1.0)
264
+ else:
265
+ lora_scale = 1.0
266
+
267
+ if USE_PEFT_BACKEND:
268
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
269
+ scale_lora_layers(self, lora_scale)
270
+ else:
271
+ if (
272
+ attention_kwargs is not None
273
+ and attention_kwargs.get("scale", None) is not None
274
+ ):
275
+ logger.warning(
276
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
277
+ )
278
+
279
+ # spatial forward
280
+ batch_size = hidden_states.shape[0]
281
+ hidden_states_type = hidden_states.dtype
282
+
283
+ # Patchify the input
284
+ if hidden_states_masks is None:
285
+ hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(
286
+ hidden_states
287
+ )
288
+
289
+ # Embed the hidden states
290
+ hidden_states = self.x_embedder(hidden_states)
291
+
292
+ # 0. time
293
+ timesteps = self.t_embedder(timesteps, hidden_states_type)
294
+ p_embedder = self.p_embedder(pooled_embeds)
295
+ temb = timesteps + p_embedder
296
+
297
+ encoder_hidden_states = [
298
+ encoder_hidden_states_llama3[k] for k in self.config.llama_layers
299
+ ]
300
+
301
+ if self.caption_projection is not None:
302
+ new_encoder_hidden_states = []
303
+ for i, enc_hidden_state in enumerate(encoder_hidden_states):
304
+ enc_hidden_state = self.caption_projection[i](enc_hidden_state)
305
+ enc_hidden_state = enc_hidden_state.view(
306
+ batch_size, -1, hidden_states.shape[-1]
307
+ )
308
+ new_encoder_hidden_states.append(enc_hidden_state)
309
+ encoder_hidden_states = new_encoder_hidden_states
310
+ encoder_hidden_states_t5 = self.caption_projection[-1](
311
+ encoder_hidden_states_t5
312
+ )
313
+ encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
314
+ batch_size, -1, hidden_states.shape[-1]
315
+ )
316
+ encoder_hidden_states.append(encoder_hidden_states_t5)
317
+
318
+ txt_ids = torch.zeros(
319
+ batch_size,
320
+ encoder_hidden_states[-1].shape[1]
321
+ + encoder_hidden_states[-2].shape[1]
322
+ + encoder_hidden_states[0].shape[1],
323
+ 3,
324
+ device=img_ids.device,
325
+ dtype=img_ids.dtype,
326
+ )
327
+ ids = torch.cat((img_ids, txt_ids), dim=1)
328
+ image_rotary_emb = self.pe_embedder(ids)
329
+
330
+ # 2. Blocks
331
+ # NOTE: block_id is no-need anymore.
332
+ initial_encoder_hidden_states = torch.cat(
333
+ [encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1
334
+ )
335
+ llama31_encoder_hidden_states = encoder_hidden_states
336
+ for bid, block in enumerate(self.double_stream_blocks):
337
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
338
+ hidden_states, initial_encoder_hidden_states = (
339
+ self._gradient_checkpointing_func(
340
+ block,
341
+ hidden_states,
342
+ initial_encoder_hidden_states,
343
+ hidden_states_masks,
344
+ temb,
345
+ image_rotary_emb,
346
+ llama31_encoder_hidden_states,
347
+ )
348
+ )
349
+ else:
350
+ hidden_states, initial_encoder_hidden_states = block(
351
+ hidden_states,
352
+ initial_encoder_hidden_states, # encoder_hidden_states
353
+ hidden_states_masks=hidden_states_masks,
354
+ temb=temb,
355
+ image_rotary_emb=image_rotary_emb,
356
+ llama31_encoder_hidden_states=llama31_encoder_hidden_states,
357
+ )
358
+
359
+ image_tokens_seq_len = hidden_states.shape[1]
360
+ hidden_states = torch.cat(
361
+ [hidden_states, initial_encoder_hidden_states], dim=1
362
+ )
363
+ if hidden_states_masks is not None:
364
+ # NOTE: Patched
365
+ cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
366
+ self.double_stream_blocks[-1].block._block_id
367
+ ]
368
+ encoder_attention_mask_ones = torch.ones(
369
+ (
370
+ batch_size,
371
+ initial_encoder_hidden_states.shape[1]
372
+ + cur_llama31_encoder_hidden_states.shape[1],
373
+ ),
374
+ device=hidden_states_masks.device,
375
+ dtype=hidden_states_masks.dtype,
376
+ )
377
+ hidden_states_masks = torch.cat(
378
+ [hidden_states_masks, encoder_attention_mask_ones], dim=1
379
+ )
380
+
381
+ for bid, block in enumerate(self.single_stream_blocks):
382
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
383
+ hidden_states = self._gradient_checkpointing_func(
384
+ block,
385
+ hidden_states,
386
+ hidden_states_masks,
387
+ temb,
388
+ image_rotary_emb,
389
+ llama31_encoder_hidden_states,
390
+ )
391
+ else:
392
+ hidden_states = block(
393
+ hidden_states,
394
+ hidden_states_masks=hidden_states_masks,
395
+ temb=temb,
396
+ image_rotary_emb=image_rotary_emb,
397
+ llama31_encoder_hidden_states=llama31_encoder_hidden_states,
398
+ )
399
+
400
+ hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
401
+ output = self.final_layer(hidden_states, temb)
402
+ output = self.unpatchify(output, img_sizes, self.training)
403
+ if hidden_states_masks is not None:
404
+ hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
405
+
406
+ if USE_PEFT_BACKEND:
407
+ # remove `lora_scale` from each PEFT layer
408
+ unscale_lora_layers(self, lora_scale)
409
+
410
+ if not return_dict:
411
+ return (output,)
412
+ return Transformer2DModelOutput(sample=output)
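A key detail in this new functor is that every inner block receives a persistent `_block_id` at patch time and uses it to index `llama31_encoder_hidden_states`, instead of relying on its position in the transformer forward's `for` loop. That is what keeps the per-block Llama-3.1 text states aligned once cache-dit starts skipping blocks. A minimal, self-contained illustration of the idea with stand-in modules (the names below are not the real diffusers classes):

```python
import torch

blocks = torch.nn.ModuleList(torch.nn.Identity() for _ in range(4))
for i, blk in enumerate(blocks):
    blk._block_id = i  # assigned once, as HiDreamPatchFunctor.apply() does

llama31_states = [torch.full((1, 4), float(i)) for i in range(4)]

# Suppose caching skips blocks 1 and 2 on this step.
executed = [blocks[0], blocks[3]]

for pos, blk in enumerate(executed):
    by_position = llama31_states[pos]            # wrong once blocks are skipped
    by_block_id = llama31_states[blk._block_id]  # stays correct
    print(pos, float(by_position[0, 0]), float(by_block_id[0, 0]))
# 0 0.0 0.0
# 1 1.0 3.0   <- position-based lookup would feed block 3 the states of block 1
```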
cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch
2
+ from typing import Optional, Union, List
3
+ from diffusers import HunyuanDiT2DModel
4
+ from diffusers.models.transformers.hunyuan_transformer_2d import (
5
+ HunyuanDiTBlock,
6
+ Transformer2DModelOutput,
7
+ )
8
+ from cache_dit.cache_factory.patch_functors.functor_base import (
9
+ PatchFunctor,
10
+ )
11
+ from cache_dit.logger import init_logger
12
+
13
+ logger = init_logger(__name__)
14
+
15
+
16
+ class HunyuanDiTPatchFunctor(PatchFunctor):
17
+
18
+ def apply(
19
+ self,
20
+ transformer: HunyuanDiT2DModel,
21
+ **kwargs,
22
+ ) -> HunyuanDiT2DModel:
23
+ if hasattr(transformer, "_is_patched"):
24
+ return transformer
25
+
26
+ is_patched = False
27
+
28
+ num_layers = transformer.config.num_layers
29
+ layer_id = 0
30
+ for block in transformer.blocks:
31
+ assert isinstance(block, HunyuanDiTBlock)
32
+ block._num_layers = num_layers
33
+ block._layer_id = layer_id
34
+ block.forward = __patch_block_forward__.__get__(block)
35
+ layer_id += 1
36
+
37
+ is_patched = True
38
+
39
+ cls_name = transformer.__class__.__name__
40
+
41
+ if is_patched:
42
+ logger.warning(f"Patched {cls_name} for cache-dit.")
43
+ assert not getattr(transformer, "_is_parallelized", False), (
44
+ "Please call `cache_dit.enable_cache` before Parallelize, "
45
+ "the __patch_transformer_forward__ will overwrite the "
46
+ "parallized forward and cause a downgrade of performance."
47
+ )
48
+ transformer.forward = __patch_transformer_forward__.__get__(
49
+ transformer
50
+ )
51
+
52
+ transformer._is_patched = is_patched # True or False
53
+
54
+ logger.info(
55
+ f"Applied {self.__class__.__name__} for {cls_name}, "
56
+ f"Patch: {is_patched}."
57
+ )
58
+
59
+ return transformer
60
+
61
+
62
+ def __patch_block_forward__(
63
+ self: HunyuanDiTBlock,
64
+ hidden_states: torch.Tensor,
65
+ encoder_hidden_states: Optional[torch.Tensor] = None,
66
+ temb: Optional[torch.Tensor] = None,
67
+ image_rotary_emb: torch.Tensor = None,
68
+ controlnet_block_samples: torch.Tensor = None,
69
+ skips: List[torch.Tensor] = [],
70
+ ) -> torch.Tensor:
71
+ # Notice that normalization is always applied before the real computation in the following blocks.
72
+ # 0. Long Skip Connection
73
+ num_layers = self._num_layers
74
+ layer_id = self._layer_id
75
+
76
+ if layer_id > num_layers // 2:
77
+ if controlnet_block_samples is not None:
78
+ skip = skips.pop() + controlnet_block_samples.pop()
79
+ else:
80
+ skip = skips.pop()
81
+ else:
82
+ skip = None
83
+
84
+ if self.skip_linear is not None:
85
+ cat = torch.cat([hidden_states, skip], dim=-1)
86
+ cat = self.skip_norm(cat)
87
+ hidden_states = self.skip_linear(cat)
88
+
89
+ # 1. Self-Attention
90
+ norm_hidden_states = self.norm1(
91
+ hidden_states, temb
92
+ ) # checked: self.norm1 is correct
93
+ attn_output = self.attn1(
94
+ norm_hidden_states,
95
+ image_rotary_emb=image_rotary_emb,
96
+ )
97
+ hidden_states = hidden_states + attn_output
98
+
99
+ # 2. Cross-Attention
100
+ hidden_states = hidden_states + self.attn2(
101
+ self.norm2(hidden_states),
102
+ encoder_hidden_states=encoder_hidden_states,
103
+ image_rotary_emb=image_rotary_emb,
104
+ )
105
+
106
+ # FFN Layer
107
+ mlp_inputs = self.norm3(hidden_states)
108
+ hidden_states = hidden_states + self.ff(mlp_inputs)
109
+
110
+ if layer_id < (num_layers // 2 - 1):
111
+ skips.append(hidden_states)
112
+
113
+ return hidden_states
114
+
115
+
116
+ def __patch_transformer_forward__(
117
+ self: HunyuanDiT2DModel,
118
+ hidden_states,
119
+ timestep,
120
+ encoder_hidden_states=None,
121
+ text_embedding_mask=None,
122
+ encoder_hidden_states_t5=None,
123
+ text_embedding_mask_t5=None,
124
+ image_meta_size=None,
125
+ style=None,
126
+ image_rotary_emb=None,
127
+ controlnet_block_samples=None,
128
+ return_dict=True,
129
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
130
+ height, width = hidden_states.shape[-2:]
131
+
132
+ hidden_states = self.pos_embed(hidden_states)
133
+
134
+ temb = self.time_extra_emb(
135
+ timestep,
136
+ encoder_hidden_states_t5,
137
+ image_meta_size,
138
+ style,
139
+ hidden_dtype=timestep.dtype,
140
+ ) # [B, D]
141
+
142
+ # text projection
143
+ batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
144
+ encoder_hidden_states_t5 = self.text_embedder(
145
+ encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
146
+ )
147
+ encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
148
+ batch_size, sequence_length, -1
149
+ )
150
+
151
+ encoder_hidden_states = torch.cat(
152
+ [encoder_hidden_states, encoder_hidden_states_t5], dim=1
153
+ )
154
+ text_embedding_mask = torch.cat(
155
+ [text_embedding_mask, text_embedding_mask_t5], dim=-1
156
+ )
157
+ text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
158
+
159
+ encoder_hidden_states = torch.where(
160
+ text_embedding_mask, encoder_hidden_states, self.text_embedding_padding
161
+ )
162
+
163
+ skips = []
164
+ for layer, block in enumerate(self.blocks):
165
+ hidden_states = block(
166
+ hidden_states,
167
+ temb=temb,
168
+ encoder_hidden_states=encoder_hidden_states,
169
+ image_rotary_emb=image_rotary_emb,
170
+ controlnet_block_samples=controlnet_block_samples,
171
+ skips=skips,
172
+ ) # (N, L, D)
173
+
174
+ if (
175
+ controlnet_block_samples is not None
176
+ and len(controlnet_block_samples) != 0
177
+ ):
178
+ raise ValueError(
179
+ "The number of controls is not equal to the number of skip connections."
180
+ )
181
+
182
+ # final layer
183
+ hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
184
+ hidden_states = self.proj_out(hidden_states)
185
+ # (N, L, patch_size ** 2 * out_channels)
186
+
187
+ # unpatchify: (N, out_channels, H, W)
188
+ patch_size = self.pos_embed.patch_size
189
+ height = height // patch_size
190
+ width = width // patch_size
191
+
192
+ hidden_states = hidden_states.reshape(
193
+ shape=(
194
+ hidden_states.shape[0],
195
+ height,
196
+ width,
197
+ patch_size,
198
+ patch_size,
199
+ self.out_channels,
200
+ )
201
+ )
202
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
203
+ output = hidden_states.reshape(
204
+ shape=(
205
+ hidden_states.shape[0],
206
+ self.out_channels,
207
+ height * patch_size,
208
+ width * patch_size,
209
+ )
210
+ )
211
+ if not return_dict:
212
+ return (output,)
213
+ return Transformer2DModelOutput(sample=output)
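With this functor in place and the registry entries no longer marked `supported=False`, HunyuanDiT can be cached the same way as the other pipelines. A usage sketch, assuming the standard `HunyuanDiTPipeline` from diffusers and the one-call `cache_dit.enable_cache` API referenced in the assertion messages; the model id and prompt are illustrative:

```python
import torch
from diffusers import HunyuanDiTPipeline
import cache_dit

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
    torch_dtype=torch.float16,
).to("cuda")

# The registered "HunyuanDiT" BlockAdapter applies HunyuanDiTPatchFunctor,
# which moves the long-skip bookkeeping into each block before caching.
cache_dit.enable_cache(pipe)

image = pipe(prompt="a small cabin in a snowy forest, ink painting").images[0]
image.save("hunyuandit_cached.png")
```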
cache_dit-0.2.30.dist-info/METADATA → cache_dit-0.2.31.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.30
3
+ Version: 0.2.31
4
4
  Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -59,18 +59,21 @@ Dynamic: requires-python
59
59
  🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
60
60
  </p>
61
61
  <p align="center">
62
- 🎉Now, <b>cache-dit</b> covers <b>mainstream</b> Diffusers' <b>DiT-based</b> Pipelines🎉<br>
63
- 🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
62
+ 🎉Now, <b>cache-dit</b> covers <b>most</b> mainstream Diffusers' <b>DiT</b> Pipelines🎉<br>
63
+ 🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1</a> | <a href="#supported"> Wan 2.2 </a> | <a href="#supported">HunyuanVideo</a>🔥<br>
64
+ 🔥<a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">Mochi</a> | <a href="#supported"> CogVideoX </a> | <a href="#supported">CogVideoX1.5</a>🔥<br>
65
+ 🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">PixArt</a>🔥<br>
66
+ 🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported"> ... </a> | <a href="#supported">Lumina2</a>🔥
64
67
  </p>
65
68
  </div>
66
69
  <div align='center'>
67
- <img src=./assets/gifs/wan2.2.C0_Q0_NONE.gif width=160px>
68
- <img src=./assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
69
- <img src=./assets/gifs/wan2.2.C1_Q1_fp8_w8a8_dq_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
70
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=160px>
71
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
72
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q1_fp8_w8a8_dq_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
70
73
  <p><b>🔥Wan2.2 MoE</b> Baseline | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~2.0x↑🎉</b> | +FP8 DQ:<b>~2.4x↑🎉</b></p>
71
- <img src=./assets/qwen-image.C0_Q0_NONE.png width=160px>
72
- <img src=./assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
73
- <img src=./assets/qwen-image.C1_Q1_fp8_w8a8_dq_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S18.png width=160px>
74
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C0_Q0_NONE.png width=160px>
75
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
76
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q1_fp8_w8a8_dq_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S18.png width=160px>
74
77
  <p><b>🔥Qwen-Image</b> Baseline | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~1.8x↑🎉</b> | +FP8 DQ:<b>~2.2x↑🎉</b><br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
75
78
  </p>
76
79
  </div>
@@ -141,7 +144,10 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
141
144
  - [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
142
145
  - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
143
146
  - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
147
+ - [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
144
148
  - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
149
+ - [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
150
+ - [🚀HiDream](https://github.com/vipshop/cache-dit/raw/main/examples)
145
151
 
146
152
  </details>
147
153
 
cache_dit-0.2.30.dist-info/RECORD → cache_dit-0.2.31.dist-info/RECORD CHANGED
@@ -1,30 +1,32 @@
1
1
  cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
2
- cache_dit/_version.py,sha256=6uKAYeE03adIcUS0SDwp52AaQx0KO8z_-07D_lPHrz8,706
2
+ cache_dit/_version.py,sha256=cMx3p02rk8iaGjj6X7bw0aOcGW7d-iY_EBO9S_9o-b4,706
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
5
5
  cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
6
6
  cache_dit/cache_factory/__init__.py,sha256=Iw6-iJLFbdzCsIDZXXOw371L-HPmoeZO_P9a3sDjP5s,1103
7
- cache_dit/cache_factory/cache_adapters.py,sha256=TA_0mEHMdSQDrt4rYASeX4-BD8pJOznSJfMV3hkrGuk,17851
7
+ cache_dit/cache_factory/cache_adapters.py,sha256=6YbBSfKEGdWi9oY1ceuxi-MpHcaDYoQ-t6NTaLZITR4,17938
8
8
  cache_dit/cache_factory/cache_interface.py,sha256=y1nY6R3MucRmAnG2UJRI_tIKrRk27FktGWLbfckf3zE,8543
9
9
  cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
10
10
  cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
11
11
  cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
12
- cache_dit/cache_factory/block_adapters/__init__.py,sha256=EA-4mEVy-JJ5vRDo6C3nJIOXu0ZDNc6FQ-ZLAKHDtB0,17251
13
- cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=OrKhuNdcGBCgSsPchdf4h32Ad-bQVUXNigMhPJ4cCvk,21069
12
+ cache_dit/cache_factory/block_adapters/__init__.py,sha256=x2ivShzOy2z3p1WUArzoChR4jaLHhNXkXMSk-RPzR3g,17534
13
+ cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=EQBiJYyoInKU1ND69wTm7M0n5Ja4I8QW01SgRpBjSn8,21671
14
14
  cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
15
- cache_dit/cache_factory/cache_blocks/__init__.py,sha256=OWjnpJxA8EJVoRzuyb5miuiRphUFj831-bbtWsTDjnM,2750
15
+ cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
16
16
  cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
17
17
  cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=nf2f5wdxp6tfq9AhFyMyBeKiZfxh63WG1g8q-c2BBSg,10182
18
- cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=XSDy3hsaKbAZPGZY92YgGA0qLgjQyIX8irQkb2R5T2c,20331
18
+ cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=_sajtb-Cz8yrCRBRSiJREzFG7h6265K9pXeAz5i1meY,20814
19
19
  cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
20
20
  cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
21
21
  cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=N88WLdd4KE9DuMWmpX8URcF55E2zWNwcKMxgVYkxMJY,13691
22
22
  cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=_NUXcMYYEIVfDHpc4HJr9RUjU5RUEkZmAgFGE8bh5Wc,34883
23
23
  cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
24
- cache_dit/cache_factory/patch_functors/__init__.py,sha256=yK05iONMGILsTZ83ynrUUJtiJKJ_FDjxmVIzRLy416s,252
24
+ cache_dit/cache_factory/patch_functors/__init__.py,sha256=06zdddrjvSCgBzJ0a8niRHd3ucF2qsbzlbL00d4aCvk,451
25
25
  cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
26
- cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=N3UzszCM55g3GHeVdyXkid2_n72VJrfqBM2gdtD52gw,10042
27
- cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=rJsbGEIxWPGnZyGI4ZwLLBdg8u6ZItsOeh0UoD_bVwk,9551
26
+ cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=2iLxlsc-1dDHRveqCXaC07E9CeMNOuBNkvpJ1atpK7E,10048
27
+ cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=UMkyuEYjO7UO_zmXi9Djd-nD-XMgCUgE-qkYA3plWSM,9559
28
+ cache_dit/cache_factory/patch_functors/functor_hidream.py,sha256=pi_vvpDy1lsgQHxu3eK3v93rdJL7oNwkt3WakRP8pbw,15375
29
+ cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py,sha256=iSo5dD5uKnjQQeysDUIkKt0wdnK5bzXTc_F_lfHG70w,6401
28
30
  cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
29
31
  cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
30
32
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -39,9 +41,9 @@ cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70
39
41
  cache_dit/quantize/quantize_ao.py,sha256=mGspqYgQtenl3QnKPtsSYsSD7LbVX93f1M940bhXKLU,6066
40
42
  cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
41
43
  cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
- cache_dit-0.2.30.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
- cache_dit-0.2.30.dist-info/METADATA,sha256=8Ln_X5fw14U3greCM7cSukrei1SRiMDpksFalg5ZBAU,22130
44
- cache_dit-0.2.30.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
- cache_dit-0.2.30.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
- cache_dit-0.2.30.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
- cache_dit-0.2.30.dist-info/RECORD,,
44
+ cache_dit-0.2.31.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
45
+ cache_dit-0.2.31.dist-info/METADATA,sha256=MrRvt7HL8pNm0ZsBxKO25pBcCJhHPG7HddwjT_euy_I,23198
46
+ cache_dit-0.2.31.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
47
+ cache_dit-0.2.31.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
48
+ cache_dit-0.2.31.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
49
+ cache_dit-0.2.31.dist-info/RECORD,,