cache-dit 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of cache-dit has been flagged as potentially problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +15 -10
- cache_dit/cache_factory/block_adapters/block_adapters.py +19 -0
- cache_dit/cache_factory/cache_adapters.py +8 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +3 -0
- cache_dit/cache_factory/cache_blocks/pattern_base.py +34 -7
- cache_dit/cache_factory/cache_interface.py +2 -2
- cache_dit/cache_factory/patch_functors/__init__.py +6 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -2
- cache_dit/cache_factory/patch_functors/functor_flux.py +3 -2
- cache_dit/cache_factory/patch_functors/functor_hidream.py +412 -0
- cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py +213 -0
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/METADATA +28 -13
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/RECORD +18 -16
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.30'
-__version_tuple__ = version_tuple = (0, 2, 30)
+__version__ = version = '0.2.32'
+__version_tuple__ = version_tuple = (0, 2, 32)
 
 __commit_id__ = commit_id = None

cache_dit/cache_factory/block_adapters/__init__.py
CHANGED
@@ -254,7 +254,7 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("
+@BlockAdapterRegistry.register("StableDiffusion3")
 def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import SD3Transformer2DModel
 
@@ -501,7 +501,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("HiDream"
+@BlockAdapterRegistry.register("HiDream")
 def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     # NOTE: Need to patch Transformer forward to fully support
     # double_stream_blocks and single_stream_blocks, namely, need
@@ -509,29 +509,32 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
     from diffusers import HiDreamImageTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,
         transformer=pipe.transformer,
         blocks=[
-
+            pipe.transformer.double_stream_blocks,
             pipe.transformer.single_stream_blocks,
         ],
         forward_pattern=[
-
+            ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_3,
         ],
-
-
+        patch_functor=HiDreamPatchFunctor(),
+        # NOTE: The type hint in diffusers is wrong
+        check_forward_pattern=True,
+        check_num_outputs=True,
         **kwargs,
     )
 
 
-@BlockAdapterRegistry.register("HunyuanDiT"
+@BlockAdapterRegistry.register("HunyuanDiT")
 def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
-    # TODO: Patch Transformer forward
     from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
+    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(
         pipe.transformer,
@@ -542,14 +545,15 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=HunyuanDiTPatchFunctor(),
         **kwargs,
     )
 
 
-@BlockAdapterRegistry.register("HunyuanDiTPAG"
+@BlockAdapterRegistry.register("HunyuanDiTPAG")
 def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
-    # TODO: Patch Transformer forward
     from diffusers import HunyuanDiT2DModel
+    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(pipe.transformer, HunyuanDiT2DModel)
     return BlockAdapter(
@@ -557,5 +561,6 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
        forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=HunyuanDiTPatchFunctor(),
         **kwargs,
     )
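
Example (not part of the diff) of what the new HiDream adapter enables at the user level. This is a hedged sketch: the model id, dtype and device handling are assumptions, and HiDream checkpoints may require extra components (e.g. the Llama text encoder) to be passed to from_pretrained; only cache_dit.enable_cache itself is taken from the package's unified API.

# Hedged usage sketch for the newly supported HiDream pipeline.
import torch
from diffusers import HiDreamImagePipeline  # available in recent diffusers releases

import cache_dit

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # assumed model id
    torch_dtype=torch.bfloat16,
).to("cuda")

# The registered "HiDream" BlockAdapter applies HiDreamPatchFunctor and caches
# both double_stream_blocks (Pattern_0) and single_stream_blocks (Pattern_3).
cache_dit.enable_cache(pipe)

image = pipe("a cat astronaut, watercolor", num_inference_steps=28).images[0]
image.save("hidream_cached.png")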

cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED
@@ -75,6 +75,7 @@ class BlockAdapter:
         List[List[ParamsModifier]],
     ] = None
 
+    check_forward_pattern: bool = True
     check_num_outputs: bool = False
 
     # Pipeline Level Flags
@@ -391,11 +392,20 @@ class BlockAdapter:
         forward_pattern: ForwardPattern,
         **kwargs,
     ) -> bool:
+
+        if not kwargs.get("check_forward_pattern", True):
+            return True
+
         assert (
             forward_pattern.Supported
             and forward_pattern in ForwardPattern.supported_patterns()
         ), f"Pattern {forward_pattern} is not support now!"
 
+        # NOTE: Special case for HiDreamBlock
+        if hasattr(block, "block"):
+            if isinstance(block.block, torch.nn.Module):
+                block = block.block
+
         forward_parameters = set(
             inspect.signature(block.forward).parameters.keys()
         )
@@ -425,6 +435,14 @@ class BlockAdapter:
         logging: bool = True,
         **kwargs,
     ) -> bool:
+
+        if not kwargs.get("check_forward_pattern", True):
+            if logging:
+                logger.warning(
+                    f"Skipped Forward Pattern Check: {forward_pattern}"
+                )
+            return True
+
         assert (
             forward_pattern.Supported
             and forward_pattern in ForwardPattern.supported_patterns()
@@ -531,6 +549,7 @@ class BlockAdapter:
             blocks,
             forward_pattern=forward_pattern,
             check_num_outputs=adapter.check_num_outputs,
+            check_forward_pattern=adapter.check_forward_pattern,
         ), (
             "No block forward pattern matched, "
             f"supported lists: {ForwardPattern.supported_patterns()}"
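
The two additions above are easier to see in isolation: a keyword opt-out that skips pattern validation entirely, and an unwrap step for wrapper blocks that expose their real block as `.block` (the HiDreamBlock case). The sketch below is illustrative stand-in code, not the library's implementation.

import inspect
import torch


class Inner(torch.nn.Module):
    def forward(self, hidden_states, encoder_hidden_states=None):
        return hidden_states


class Wrapper(torch.nn.Module):  # stands in for diffusers' HiDreamBlock
    def __init__(self):
        super().__init__()
        self.block = Inner()

    def forward(self, hidden_states, *args, **kwargs):
        return self.block(hidden_states, *args, **kwargs)


def matches_pattern(block, required=("hidden_states", "encoder_hidden_states"),
                    check_forward_pattern=True):
    if not check_forward_pattern:
        return True  # explicit opt-out, new in 0.2.32
    if hasattr(block, "block") and isinstance(block.block, torch.nn.Module):
        block = block.block  # inspect the wrapped block instead of the shim
    params = set(inspect.signature(block.forward).parameters)
    return all(name in params for name in required)


print(matches_pattern(Wrapper()))                               # True: unwrapped Inner matches
print(matches_pattern(Wrapper(), check_forward_pattern=False))  # True: check skipped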

cache_dit/cache_factory/cache_adapters.py
CHANGED
@@ -114,7 +114,7 @@ class CachedAdapter:
         **cache_context_kwargs,
     ):
         # Check cache_context_kwargs
-        if
+        if cache_context_kwargs["enable_spearate_cfg"] is None:
             # Check cfg for some specific case if users don't set it as True
             if BlockAdapterRegistry.has_separate_cfg(block_adapter):
                 cache_context_kwargs["enable_spearate_cfg"] = True
@@ -131,6 +131,12 @@ class CachedAdapter:
                     f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
+        else:
+            logger.info(
+                f"Use custom 'enable_spearate_cfg' from cache context "
+                f"kwargs: {cache_context_kwargs['enable_spearate_cfg']}. "
+                f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+            )
 
         if (
             cache_type := cache_context_kwargs.pop("cache_type", None)
@@ -345,6 +351,7 @@ class CachedAdapter:
                 block_adapter.blocks[i][j],
                 transformer=block_adapter.transformer[i],
                 forward_pattern=block_adapter.forward_pattern[i][j],
+                check_forward_pattern=block_adapter.check_forward_pattern,
                 check_num_outputs=block_adapter.check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=block_adapter.blocks_name[i][j],

cache_dit/cache_factory/cache_blocks/__init__.py
CHANGED
@@ -25,6 +25,7 @@ class CachedBlocks:
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = None,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
@@ -45,6 +46,7 @@ class CachedBlocks:
                 transformer_blocks,
                 transformer=transformer,
                 forward_pattern=forward_pattern,
+                check_forward_pattern=check_forward_pattern,
                 check_num_outputs=check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,
@@ -58,6 +60,7 @@ class CachedBlocks:
                 transformer_blocks,
                 transformer=transformer,
                 forward_pattern=forward_pattern,
+                check_forward_pattern=check_forward_pattern,
                 check_num_outputs=check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,

cache_dit/cache_factory/cache_blocks/pattern_base.py
CHANGED
@@ -25,6 +25,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         cache_prefix: str = None,  # maybe un-need.
@@ -38,6 +39,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
+        self.check_forward_pattern = check_forward_pattern
         self.check_num_outputs = check_num_outputs
         # 1. Cache context configuration
         self.cache_prefix = cache_prefix
@@ -52,6 +54,12 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         )
 
     def _check_forward_pattern(self):
+        if not self.check_forward_pattern:
+            logger.warning(
+                f"Skipped Forward Pattern Check: {self.forward_pattern}"
+            )
+            return
+
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
@@ -59,6 +67,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
+                # Special case for HiDreamBlock
+                if hasattr(block, "block"):
+                    if isinstance(block.block, torch.nn.Module):
+                        block = block.block
+
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
@@ -332,12 +345,19 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
-        encoder_hidden_states = encoder_hidden_states.contiguous()
 
         hidden_states_residual = hidden_states - original_hidden_states
-
-
-
+
+        if (
+            encoder_hidden_states is not None
+            and original_encoder_hidden_states is not None
+        ):
+            encoder_hidden_states = encoder_hidden_states.contiguous()
+            encoder_hidden_states_residual = (
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+        else:
+            encoder_hidden_states_residual = None
 
         return (
             hidden_states,
@@ -387,9 +407,16 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
-
-                    encoder_hidden_states
-
+                if (
+                    encoder_hidden_states is not None
+                    and Bn_i_original_encoder_hidden_states is not None
+                ):
+                    Bn_i_encoder_hidden_states_residual = (
+                        encoder_hidden_states
+                        - Bn_i_original_encoder_hidden_states
+                    )
+                else:
+                    Bn_i_encoder_hidden_states_residual = None
 
                 # Save original_hidden_states for diff calculation.
                 self.cache_manager.set_Bn_buffer(
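
The residual computation above is now guarded so that patterns without encoder hidden states (single-output blocks) no longer trip over a None. A standalone illustration of the guarded logic, not library code:

import torch


def compute_residuals(hidden_states, original_hidden_states,
                      encoder_hidden_states=None,
                      original_encoder_hidden_states=None):
    hidden_states = hidden_states.contiguous()
    hidden_states_residual = hidden_states - original_hidden_states

    if encoder_hidden_states is not None and original_encoder_hidden_states is not None:
        encoder_hidden_states = encoder_hidden_states.contiguous()
        encoder_residual = encoder_hidden_states - original_encoder_hidden_states
    else:
        encoder_residual = None  # nothing to cache for encoder-free patterns

    return hidden_states_residual, encoder_residual


h, h0 = torch.randn(1, 16, 8), torch.randn(1, 16, 8)
print(compute_residuals(h, h0)[1])  # None: no encoder states were involved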

cache_dit/cache_factory/cache_interface.py
CHANGED
@@ -24,7 +24,7 @@ def enable_cache(
     max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
-    enable_spearate_cfg: bool =
+    enable_spearate_cfg: bool | None = None,
     cfg_compute_first: bool = False,
     cfg_diff_compute_separate: bool = True,
     # Hybird TaylorSeer
@@ -70,7 +70,7 @@ def enable_cache(
         residual_diff_threshold (`float`, *required*, defaults to 0.08):
             he value of residual diff threshold, a higher value leads to faster performance at the
             cost of lower precision.
-        enable_spearate_cfg (`bool`, *required*, defaults to
+        enable_spearate_cfg (`bool`, *required*, defaults to None):
             Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
             and non-CFG into single forward step, should set enable_spearate_cfg as False, for example:
             CogVideoX, HunyuanVideo, Mochi, etc.
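
With the default switched to None, the separate-CFG decision is inferred from the registered BlockAdapter unless the caller overrides it (and the override is now logged by CachedAdapter). A toy sketch of that resolution rule, with has_separate_cfg standing in for BlockAdapterRegistry.has_separate_cfg:

def resolve_separate_cfg(user_value, has_separate_cfg: bool) -> bool:
    if user_value is None:   # new default: infer from the registered adapter
        return has_separate_cfg
    return user_value        # explicit True/False always wins


print(resolve_separate_cfg(None, True))   # True, e.g. Wan 2.1 / Qwen-Image style pipelines
print(resolve_separate_cfg(False, True))  # False, user override for fused-CFG models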

cache_dit/cache_factory/patch_functors/__init__.py
CHANGED
@@ -3,3 +3,9 @@ from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_chroma import (
     ChromaPatchFunctor,
 )
+from cache_dit.cache_factory.patch_functors.functor_hidream import (
+    HiDreamPatchFunctor,
+)
+from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
+    HunyuanDiTPatchFunctor,
+)

cache_dit/cache_factory/patch_functors/functor_chroma.py
CHANGED
@@ -46,8 +46,10 @@ class ChromaPatchFunctor(PatchFunctor):
             block.forward = __patch_single_forward__.__get__(block)
             is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -59,7 +61,6 @@ class ChromaPatchFunctor(PatchFunctor):
 
         transformer._is_patched = is_patched  # True or False
 
-        cls_name = transformer.__class__.__name__
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."

cache_dit/cache_factory/patch_functors/functor_flux.py
CHANGED
@@ -47,8 +47,10 @@ class FluxPatchFunctor(PatchFunctor):
             block.forward = __patch_single_forward__.__get__(block)
             is_patched = True
 
+        cls_name = transformer.__class__.__name__
+
         if is_patched:
-            logger.warning("Patched
+            logger.warning(f"Patched {cls_name} for cache-dit.")
             assert not getattr(transformer, "_is_parallelized", False), (
                 "Please call `cache_dit.enable_cache` before Parallelize, "
                 "the __patch_transformer_forward__ will overwrite the "
@@ -60,7 +62,6 @@ class FluxPatchFunctor(PatchFunctor):
 
         transformer._is_patched = is_patched  # True or False
 
-        cls_name = transformer.__class__.__name__
         logger.info(
             f"Applied {self.__class__.__name__} for {cls_name}, "
             f"Patch: {is_patched}."

cache_dit/cache_factory/patch_functors/functor_hidream.py
ADDED
@@ -0,0 +1,412 @@
+import torch
+from typing import Tuple, Optional, Dict, Any, Union, List
+from diffusers import HiDreamImageTransformer2DModel
+from diffusers.models.transformers.transformer_hidream_image import (
+    HiDreamBlock,
+    HiDreamImageTransformerBlock,
+    HiDreamImageSingleTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    deprecate,
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class HiDreamPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: HiDreamImageTransformer2DModel,
+        **kwargs,
+    ) -> HiDreamImageTransformer2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        _block_id = 0
+        for block in transformer.double_stream_blocks:
+            assert isinstance(block, HiDreamBlock)
+            block.forward = __patch_block_forward__.__get__(block)
+            # NOTE: Patch Inner block and block_id
+            _block = block.block
+            assert isinstance(_block, HiDreamImageTransformerBlock)
+            _block._block_id = _block_id
+            _block.forward = __patch_double_forward__.__get__(_block)
+            _block_id += 1
+
+        for block in transformer.single_stream_blocks:
+            assert isinstance(block, HiDreamBlock)
+            block.forward = __patch_block_forward__.__get__(block)
+            # NOTE: Patch Inner block and block_id
+            _block = block.block
+            assert isinstance(_block, HiDreamImageSingleTransformerBlock)
+            _block._block_id = _block_id
+            _block.forward = __patch_single_forward__.__get__(_block)
+            _block_id += 1
+
+        is_patched = True
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_double_forward__(
+    self: HiDreamImageTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,  # initial_encoder_hidden_states
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    llama31_encoder_hidden_states: List[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Assume block_id was patched in transformer forward:
+    # for i, block in enumerate(blocks): block._block_id = i;
+    block_id = self._block_id
+    initial_encoder_hidden_states_seq_len = encoder_hidden_states.shape[1]
+    cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
+    cur_encoder_hidden_states = torch.cat(
+        [encoder_hidden_states, cur_llama31_encoder_hidden_states],
+        dim=1,
+    )
+    encoder_hidden_states = cur_encoder_hidden_states
+
+    wtype = hidden_states.dtype
+    (
+        shift_msa_i,
+        scale_msa_i,
+        gate_msa_i,
+        shift_mlp_i,
+        scale_mlp_i,
+        gate_mlp_i,
+        shift_msa_t,
+        scale_msa_t,
+        gate_msa_t,
+        shift_mlp_t,
+        scale_mlp_t,
+        gate_mlp_t,
+    ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
+
+    # 1. MM-Attention
+    norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+    norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(
+        dtype=wtype
+    )
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
+    )
+
+    attn_output_i, attn_output_t = self.attn1(
+        norm_hidden_states,
+        hidden_states_masks,
+        norm_encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+
+    hidden_states = gate_msa_i * attn_output_i + hidden_states
+    encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
+
+    # 2. Feed-forward
+    norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+    norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(
+        dtype=wtype
+    )
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
+    )
+
+    ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+    ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
+    hidden_states = ff_output_i + hidden_states
+    encoder_hidden_states = ff_output_t + encoder_hidden_states
+
+    initial_encoder_hidden_states = encoder_hidden_states[
+        :, :initial_encoder_hidden_states_seq_len
+    ]
+
+    return hidden_states, initial_encoder_hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_single_forward__(
+    self: HiDreamImageSingleTransformerBlock,
+    hidden_states: torch.Tensor,
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    llama31_encoder_hidden_states: List[torch.Tensor] = None,
+) -> torch.Tensor:
+    # Assume block_id was patched in transformer forward:
+    # for i, block in enumerate(blocks): block._block_id = i;
+    block_id = self._block_id
+    hidden_states_seq_len = hidden_states.shape[1]
+    cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[block_id]
+    hidden_states = torch.cat(
+        [hidden_states, cur_llama31_encoder_hidden_states], dim=1
+    )
+
+    wtype = hidden_states.dtype
+    (
+        shift_msa_i,
+        scale_msa_i,
+        gate_msa_i,
+        shift_mlp_i,
+        scale_mlp_i,
+        gate_mlp_i,
+    ) = self.adaLN_modulation(temb)[:, None].chunk(6, dim=-1)
+
+    # 1. MM-Attention
+    norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+    attn_output_i = self.attn1(
+        norm_hidden_states,
+        hidden_states_masks,
+        image_rotary_emb=image_rotary_emb,
+    )
+    hidden_states = gate_msa_i * attn_output_i + hidden_states
+
+    # 2. Feed-forward
+    norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+    norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+    ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+    hidden_states = ff_output_i + hidden_states
+
+    hidden_states = hidden_states[:, :hidden_states_seq_len]
+
+    return hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_block_forward__(
+    self: HiDreamBlock,
+    hidden_states: torch.Tensor,
+    *args,
+    **kwargs,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    return self.block(hidden_states, *args, **kwargs)
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py
+def __patch_transformer_forward__(
+    self: HiDreamImageTransformer2DModel,
+    hidden_states: torch.Tensor,
+    timesteps: torch.LongTensor = None,
+    encoder_hidden_states_t5: torch.Tensor = None,
+    encoder_hidden_states_llama3: torch.Tensor = None,
+    pooled_embeds: torch.Tensor = None,
+    img_ids: Optional[torch.Tensor] = None,
+    img_sizes: Optional[List[Tuple[int, int]]] = None,
+    hidden_states_masks: Optional[torch.Tensor] = None,
+    attention_kwargs: Optional[Dict[str, Any]] = None,
+    return_dict: bool = True,
+    **kwargs,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+    if encoder_hidden_states is not None:
+        deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+        deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
+        encoder_hidden_states_t5 = encoder_hidden_states[0]
+        encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+    if (
+        img_ids is not None
+        and img_sizes is not None
+        and hidden_states_masks is None
+    ):
+        deprecation_message = "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
+        deprecate("img_ids", "0.35.0", deprecation_message)
+
+    if hidden_states_masks is not None and (
+        img_ids is None or img_sizes is None
+    ):
+        raise ValueError(
+            "if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed."
+        )
+    elif hidden_states_masks is not None and hidden_states.ndim != 3:
+        raise ValueError(
+            "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+        )
+
+    if attention_kwargs is not None:
+        attention_kwargs = attention_kwargs.copy()
+        lora_scale = attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            attention_kwargs is not None
+            and attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    # spatial forward
+    batch_size = hidden_states.shape[0]
+    hidden_states_type = hidden_states.dtype
+
+    # Patchify the input
+    if hidden_states_masks is None:
+        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(
+            hidden_states
+        )
+
+    # Embed the hidden states
+    hidden_states = self.x_embedder(hidden_states)
+
+    # 0. time
+    timesteps = self.t_embedder(timesteps, hidden_states_type)
+    p_embedder = self.p_embedder(pooled_embeds)
+    temb = timesteps + p_embedder
+
+    encoder_hidden_states = [
+        encoder_hidden_states_llama3[k] for k in self.config.llama_layers
+    ]
+
+    if self.caption_projection is not None:
+        new_encoder_hidden_states = []
+        for i, enc_hidden_state in enumerate(encoder_hidden_states):
+            enc_hidden_state = self.caption_projection[i](enc_hidden_state)
+            enc_hidden_state = enc_hidden_state.view(
+                batch_size, -1, hidden_states.shape[-1]
+            )
+            new_encoder_hidden_states.append(enc_hidden_state)
+        encoder_hidden_states = new_encoder_hidden_states
+        encoder_hidden_states_t5 = self.caption_projection[-1](
+            encoder_hidden_states_t5
+        )
+        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
+            batch_size, -1, hidden_states.shape[-1]
+        )
+        encoder_hidden_states.append(encoder_hidden_states_t5)
+
+    txt_ids = torch.zeros(
+        batch_size,
+        encoder_hidden_states[-1].shape[1]
+        + encoder_hidden_states[-2].shape[1]
+        + encoder_hidden_states[0].shape[1],
+        3,
+        device=img_ids.device,
+        dtype=img_ids.dtype,
+    )
+    ids = torch.cat((img_ids, txt_ids), dim=1)
+    image_rotary_emb = self.pe_embedder(ids)
+
+    # 2. Blocks
+    # NOTE: block_id is no-need anymore.
+    initial_encoder_hidden_states = torch.cat(
+        [encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1
+    )
+    llama31_encoder_hidden_states = encoder_hidden_states
+    for bid, block in enumerate(self.double_stream_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states, initial_encoder_hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    initial_encoder_hidden_states,
+                    hidden_states_masks,
+                    temb,
+                    image_rotary_emb,
+                    llama31_encoder_hidden_states,
+                )
+            )
+        else:
+            hidden_states, initial_encoder_hidden_states = block(
+                hidden_states,
+                initial_encoder_hidden_states,  # encoder_hidden_states
+                hidden_states_masks=hidden_states_masks,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                llama31_encoder_hidden_states=llama31_encoder_hidden_states,
+            )
+
+    image_tokens_seq_len = hidden_states.shape[1]
+    hidden_states = torch.cat(
+        [hidden_states, initial_encoder_hidden_states], dim=1
+    )
+    if hidden_states_masks is not None:
+        # NOTE: Patched
+        cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
+            self.double_stream_blocks[-1].block._block_id
+        ]
+        encoder_attention_mask_ones = torch.ones(
+            (
+                batch_size,
+                initial_encoder_hidden_states.shape[1]
+                + cur_llama31_encoder_hidden_states.shape[1],
+            ),
+            device=hidden_states_masks.device,
+            dtype=hidden_states_masks.dtype,
+        )
+        hidden_states_masks = torch.cat(
+            [hidden_states_masks, encoder_attention_mask_ones], dim=1
+        )
+
+    for bid, block in enumerate(self.single_stream_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(
+                block,
+                hidden_states,
+                hidden_states_masks,
+                temb,
+                image_rotary_emb,
+                llama31_encoder_hidden_states,
+            )
+        else:
+            hidden_states = block(
+                hidden_states,
+                hidden_states_masks=hidden_states_masks,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                llama31_encoder_hidden_states=llama31_encoder_hidden_states,
+            )
+
+    hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
+    output = self.final_layer(hidden_states, temb)
+    output = self.unpatchify(output, img_sizes, self.training)
+    if hidden_states_masks is not None:
+        hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
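
Both new functors patch forward methods per instance by rebinding a plain function with function.__get__(obj), which overrides that one module without touching the class. A minimal, self-contained illustration of the trick (not library code):

import torch


class TinyBlock(torch.nn.Module):
    def forward(self, x):
        return x + 1


def __patched_forward__(self, x):
    # per-instance bookkeeping of the kind the functors inject (e.g. _block_id)
    return x + getattr(self, "_block_id", 0)


block = TinyBlock()
block._block_id = 10
block.forward = __patched_forward__.__get__(block)  # instance-level override

print(block(torch.zeros(1)))        # tensor([10.]) -> patched path
print(TinyBlock()(torch.zeros(1)))  # tensor([1.])  -> class default untouched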

cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py
ADDED
@@ -0,0 +1,213 @@
+import torch
+from typing import Optional, Union, List
+from diffusers import HunyuanDiT2DModel
+from diffusers.models.transformers.hunyuan_transformer_2d import (
+    HunyuanDiTBlock,
+    Transformer2DModelOutput,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class HunyuanDiTPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: HunyuanDiT2DModel,
+        **kwargs,
+    ) -> HunyuanDiT2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        num_layers = transformer.config.num_layers
+        layer_id = 0
+        for block in transformer.blocks:
+            assert isinstance(block, HunyuanDiTBlock)
+            block._num_layers = num_layers
+            block._layer_id = layer_id
+            block.forward = __patch_block_forward__.__get__(block)
+            layer_id += 1
+
+        is_patched = True
+
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+def __patch_block_forward__(
+    self: HunyuanDiTBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    temb: Optional[torch.Tensor] = None,
+    image_rotary_emb: torch.Tensor = None,
+    controlnet_block_samples: torch.Tensor = None,
+    skips: List[torch.Tensor] = [],
+) -> torch.Tensor:
+    # Notice that normalization is always applied before the real computation in the following blocks.
+    # 0. Long Skip Connection
+    num_layers = self._num_layers
+    layer_id = self._layer_id
+
+    if layer_id > num_layers // 2:
+        if controlnet_block_samples is not None:
+            skip = skips.pop() + controlnet_block_samples.pop()
+        else:
+            skip = skips.pop()
+    else:
+        skip = None
+
+    if self.skip_linear is not None:
+        cat = torch.cat([hidden_states, skip], dim=-1)
+        cat = self.skip_norm(cat)
+        hidden_states = self.skip_linear(cat)
+
+    # 1. Self-Attention
+    norm_hidden_states = self.norm1(
+        hidden_states, temb
+    )  # checked: self.norm1 is correct
+    attn_output = self.attn1(
+        norm_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+    hidden_states = hidden_states + attn_output
+
+    # 2. Cross-Attention
+    hidden_states = hidden_states + self.attn2(
+        self.norm2(hidden_states),
+        encoder_hidden_states=encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+    )
+
+    # FFN Layer
+    mlp_inputs = self.norm3(hidden_states)
+    hidden_states = hidden_states + self.ff(mlp_inputs)
+
+    if layer_id < (num_layers // 2 - 1):
+        skips.append(hidden_states)
+
+    return hidden_states
+
+
+def __patch_transformer_forward__(
+    self: HunyuanDiT2DModel,
+    hidden_states,
+    timestep,
+    encoder_hidden_states=None,
+    text_embedding_mask=None,
+    encoder_hidden_states_t5=None,
+    text_embedding_mask_t5=None,
+    image_meta_size=None,
+    style=None,
+    image_rotary_emb=None,
+    controlnet_block_samples=None,
+    return_dict=True,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    height, width = hidden_states.shape[-2:]
+
+    hidden_states = self.pos_embed(hidden_states)
+
+    temb = self.time_extra_emb(
+        timestep,
+        encoder_hidden_states_t5,
+        image_meta_size,
+        style,
+        hidden_dtype=timestep.dtype,
+    )  # [B, D]
+
+    # text projection
+    batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
+    encoder_hidden_states_t5 = self.text_embedder(
+        encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
+    )
+    encoder_hidden_states_t5 = encoder_hidden_states_t5.view(
+        batch_size, sequence_length, -1
+    )
+
+    encoder_hidden_states = torch.cat(
+        [encoder_hidden_states, encoder_hidden_states_t5], dim=1
+    )
+    text_embedding_mask = torch.cat(
+        [text_embedding_mask, text_embedding_mask_t5], dim=-1
+    )
+    text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
+
+    encoder_hidden_states = torch.where(
+        text_embedding_mask, encoder_hidden_states, self.text_embedding_padding
+    )
+
+    skips = []
+    for layer, block in enumerate(self.blocks):
+        hidden_states = block(
+            hidden_states,
+            temb=temb,
+            encoder_hidden_states=encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            controlnet_block_samples=controlnet_block_samples,
+            skips=skips,
+        )  # (N, L, D)
+
+    if (
+        controlnet_block_samples is not None
+        and len(controlnet_block_samples) != 0
+    ):
+        raise ValueError(
+            "The number of controls is not equal to the number of skip connections."
+        )
+
+    # final layer
+    hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
+    hidden_states = self.proj_out(hidden_states)
+    # (N, L, patch_size ** 2 * out_channels)
+
+    # unpatchify: (N, out_channels, H, W)
+    patch_size = self.pos_embed.patch_size
+    height = height // patch_size
+    width = width // patch_size
+
+    hidden_states = hidden_states.reshape(
+        shape=(
+            hidden_states.shape[0],
+            height,
+            width,
+            patch_size,
+            patch_size,
+            self.out_channels,
+        )
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(
+            hidden_states.shape[0],
+            self.out_channels,
+            height * patch_size,
+            width * patch_size,
+        )
+    )
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
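
The patched HunyuanDiT block forward threads a shared skips list through the blocks: the first half of the layers push their hidden states and the second half pops them for the long skip connections, which is why each block needs its _layer_id and _num_layers patched in. A toy sketch of that bookkeeping, not library code:

num_layers = 8
skips = []
for layer_id in range(num_layers):
    if layer_id > num_layers // 2:
        skip = skips.pop()            # late half: consume a stored skip
    if layer_id < (num_layers // 2 - 1):
        skips.append(f"h{layer_id}")  # early half: store hidden states
print(skips)  # [] -> pushes and pops are balanced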

{cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.30
+Version: 0.2.32
 Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -59,29 +59,37 @@ Dynamic: requires-python
 🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
 </p>
 <p align="center">
-
-
+🎉Now, <b>cache-dit</b> covers <b>most</b> mainstream Diffusers' <b>DiT</b> Pipelines🎉<br>
+🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Qwen-Image-Lightning</a> | <a href="#supported"> Wan 2.1/2.2 </a>🔥<br>
+🔥<a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">Mochi</a> | <a href="#supported">CogVideoX 1/1.5</a>🔥<br>
+🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">PixArt</a>🔥<br>
+🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported"> OmniGen </a> | <a href="#supported">Lumina 1/2</a>🔥<br>
+🔥<a href="#supported">Allegro</a> | <a href="#supported">EasyAnimate</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported"> ... </a> | <a href="#supported">DiT-XL</a>🔥
 </p>
 </div>
 <div align='center'>
-<img src
-<img src
-<img src
-<p><b>🔥Wan2.2 MoE</b>
-<img src
-<img src
-<img src
-<p><b>🔥Qwen-Image</b>
-
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q1_fp8_w8a8_dq_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
+<p><b>🔥Wan2.2 MoE</b> | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~2.0x↑🎉</b> | +FP8 DQ:<b>~2.4x↑🎉</b></p>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C0_Q0_NONE.png width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q1_fp8_w8a8_dq_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S18.png width=160px>
+<p><b>🔥Qwen-Image</b> | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~1.8x↑🎉</b> | +FP8 DQ:<b>~2.2x↑🎉</b></p>
+<img src=./assets/qwen-image-lightning.4steps.C0_L1_Q0_NONE.png width=200px>
+<img src=./assets/qwen-image-lightning.4steps.C0_L1_Q0_DBCACHE_F16B16_W2M1MC1_T0O2_R0.9_S1.png width=200px>
+<p><b>🔥Qwen-Image-Lightning</b> 4 steps | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a></b> 3.5 steps:<b>~1.14x↑🎉</b>
+<br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
 </div>
 
 ## 🔥News
 
+- [2025-09-08] 🔥[**Qwen-Image-Lightning**](./examples/pipeline/run_qwen_image_lightning.py) **7.1/3.5 steps🎉** inference with **[DBCache: F16B16](https://github.com/vipshop/cache-dit)**.
 - [2025-09-03] 🎉[**Wan2.2-MoE**](https://github.com/Wan-Video) **2.4x↑🎉** speedup! Please refer to [run_wan_2.2.py](./examples/pipeline/run_wan_2.2.py) as an example.
 - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x↑🎉** speedup! Check the example: [run_qwen_image_edit.py](./examples/pipeline/run_qwen_image_edit.py).
 - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x↑🎉** speedup! Please refer to [run_qwen_image.py](./examples/pipeline/run_qwen_image.py) as an example.
-- [2025-07-13] 🎉[**FLUX.1-
+- [2025-07-13] 🎉[**FLUX.1-dev**](https://github.com/xlite-dev/flux-faster) **3.3x↑🎉** speedup! NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)** + **compile + FP8 DQ**.
 
 <details>
 <summary> Previous News </summary>
@@ -131,6 +139,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 
 Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models listed:
 
+- [🚀Qwen-Image-Lightning](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -141,7 +150,13 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀HiDream-I1-Full](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀PixArt-Alpha](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀PixArt-Sigma](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀SD-3/3.5](https://github.com/vipshop/cache-dit/raw/main/examples)
 
 </details>
 

{cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/RECORD
CHANGED
@@ -1,30 +1,32 @@
 cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=J0YTFDgdG9rY1Xk5pUbWWGgbT2rbSasvUHcntxayVtA,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=Iw6-iJLFbdzCsIDZXXOw371L-HPmoeZO_P9a3sDjP5s,1103
-cache_dit/cache_factory/cache_adapters.py,sha256=
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_adapters.py,sha256=dmNX68nBD52HtQvHnNAuSn1zjDWrQdycD0qXy-w-mwc,18212
+cache_dit/cache_factory/cache_interface.py,sha256=LpyCy-tQ_GcTRAYLpMMf9hFVIktABHI6CObn5Ll8bMw,8548
 cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
-cache_dit/cache_factory/block_adapters/__init__.py,sha256=
-cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=
+cache_dit/cache_factory/block_adapters/__init__.py,sha256=OZM5vJwmQIkoIwVmMxKXiHqKvs31NyAva1Z91C_ko3w,17547
+cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=EQBiJYyoInKU1ND69wTm7M0n5Ja4I8QW01SgRpBjSn8,21671
 cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
-cache_dit/cache_factory/cache_blocks/__init__.py,sha256=
+cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
 cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=nf2f5wdxp6tfq9AhFyMyBeKiZfxh63WG1g8q-c2BBSg,10182
-cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=
+cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=f1ojREQcDoBtDG3dzl8t1g_Vru8140LVDRPWlY-kAXw,21311
 cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
 cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
 cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=N88WLdd4KE9DuMWmpX8URcF55E2zWNwcKMxgVYkxMJY,13691
 cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=_NUXcMYYEIVfDHpc4HJr9RUjU5RUEkZmAgFGE8bh5Wc,34883
 cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
-cache_dit/cache_factory/patch_functors/__init__.py,sha256=
+cache_dit/cache_factory/patch_functors/__init__.py,sha256=06zdddrjvSCgBzJ0a8niRHd3ucF2qsbzlbL00d4aCvk,451
 cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
-cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=
-cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=
+cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=2iLxlsc-1dDHRveqCXaC07E9CeMNOuBNkvpJ1atpK7E,10048
+cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=UMkyuEYjO7UO_zmXi9Djd-nD-XMgCUgE-qkYA3plWSM,9559
+cache_dit/cache_factory/patch_functors/functor_hidream.py,sha256=pi_vvpDy1lsgQHxu3eK3v93rdJL7oNwkt3WakRP8pbw,15375
+cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py,sha256=iSo5dD5uKnjQQeysDUIkKt0wdnK5bzXTc_F_lfHG70w,6401
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
 cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -39,9 +41,9 @@ cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70
 cache_dit/quantize/quantize_ao.py,sha256=mGspqYgQtenl3QnKPtsSYsSD7LbVX93f1M940bhXKLU,6066
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
 cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit-0.2.32.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.32.dist-info/METADATA,sha256=WQ9GP-Om05j3NBvtifkmbz5t20XBU_-KJQptrK7jQBs,24222
+cache_dit-0.2.32.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.32.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.32.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.32.dist-info/RECORD,,

{cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/WHEEL
File without changes

{cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/entry_points.txt
File without changes

{cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/licenses/LICENSE
File without changes

{cache_dit-0.2.30.dist-info → cache_dit-0.2.32.dist-info}/top_level.txt
File without changes