cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,118 @@
1
+ import torch
2
+ from typing import Any, Tuple, List, Dict, Callable, Union
3
+
4
+ from diffusers import DiffusionPipeline
5
+ from cache_dit.caching.block_adapters.block_adapters import (
6
+ BlockAdapter,
7
+ FakeDiffusionPipeline,
8
+ )
9
+
10
+ from cache_dit.logger import init_logger
11
+
12
+ logger = init_logger(__name__)
13
+
14
+
15
class BlockAdapterRegistry:
    """Registry mapping pipeline class-name prefixes to BlockAdapter factories.

    Adapter factories are registered under a name (e.g. "Flux") and matched
    against a pipeline/module class name via ``str.startswith``, so one entry
    covers a whole family of pipeline classes (e.g. "FluxPipeline",
    "FluxImg2ImgPipeline").
    """

    # name prefix -> factory producing a configured BlockAdapter
    _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
    # Pipeline families known to run CFG as a separate forward pass; used as a
    # fallback when no registered adapter (or its flag) says otherwise.
    _predefined_adapters_has_separate_cfg: List[str] = [
        "QwenImage",
        "Wan",
        "CogView4",
        "Cosmos",
        "SkyReelsV2",
        "Chroma",
        "Lumina2",
        "Kandinsky5",
    ]

    @classmethod
    def register(cls, name: str, supported: bool = True):
        """Return a decorator registering *func* under *name*.

        When ``supported`` is False the factory is left unregistered but still
        returned unchanged, so callers can conditionally disable adapters
        without altering the decorated function.
        """

        def decorator(
            func: Callable[..., BlockAdapter]
        ) -> Callable[..., BlockAdapter]:
            if supported:
                cls._adapters[name] = func
            return func

        return decorator

    @classmethod
    def get_adapter(
        cls,
        pipe_or_module: DiffusionPipeline | torch.nn.Module | str | Any,
        **kwargs,
    ) -> BlockAdapter | None:
        """Build a BlockAdapter for *pipe_or_module*, or None if unmatched.

        Matching is by class-name prefix (or the string itself when a str is
        given). Registration order decides ties, since the first matching
        prefix wins.
        """
        if isinstance(pipe_or_module, str):
            cls_name = pipe_or_module
        else:
            cls_name = pipe_or_module.__class__.__name__

        for name in cls._adapters:
            if not cls_name.startswith(name):
                continue
            if isinstance(pipe_or_module, DiffusionPipeline):
                return cls._adapters[name](pipe_or_module, **kwargs)
            # NOTE: Make pre-registered adapters support Transformer-only case.
            # WARN: This branch is not officially supported and only for testing
            # purpose. We construct a fake diffusion pipeline that contains the
            # given transformer module. Currently, only works for DiT models which
            # only have one transformer module. Case like multiple transformers
            # is not supported, e.g, Wan2.2. Please use BlockAdapter directly for
            # such cases.
            assert isinstance(pipe_or_module, torch.nn.Module), (
                f"Expected a DiffusionPipeline or torch.nn.Module, "
                f"got {type(pipe_or_module).__name__}"
            )
            return cls._adapters[name](
                FakeDiffusionPipeline(pipe_or_module), **kwargs
            )

        return None

    @classmethod
    def has_separate_cfg(
        cls,
        pipe_or_adapter: Union[
            DiffusionPipeline,
            FakeDiffusionPipeline,
            BlockAdapter,
            Any,
        ],
    ) -> bool:
        """Whether *pipe_or_adapter* runs classifier-free guidance separately.

        Resolution order: an explicit BlockAdapter setting wins; fake pipelines
        never have separate CFG; then a registered adapter's flag is consulted;
        finally the predefined class-name list is checked as a fallback (so a
        listed family reports True even if its adapter's flag is False).
        """
        # Prefer custom setting from block adapter.
        if isinstance(pipe_or_adapter, BlockAdapter):
            return pipe_or_adapter.has_separate_cfg

        if isinstance(pipe_or_adapter, FakeDiffusionPipeline):
            return False

        if isinstance(pipe_or_adapter, DiffusionPipeline):
            adapter = cls.get_adapter(
                pipe_or_adapter,
                skip_post_init=True,  # check cfg setting only
            )
            if adapter is not None and adapter.has_separate_cfg:
                return True

        pipe_cls_name = pipe_or_adapter.__class__.__name__
        return any(
            pipe_cls_name.startswith(name)
            for name in cls._predefined_adapters_has_separate_cfg
        )

    @classmethod
    def is_supported(cls, pipe_or_module) -> bool:
        """True if a registered adapter matches *pipe_or_module*'s class name."""
        cls_name: str = pipe_or_module.__class__.__name__
        return any(cls_name.startswith(name) for name in cls._adapters)

    @classmethod
    def supported_pipelines(cls, **kwargs) -> Tuple[int, List[str]]:
        """Return the count and names of all registered pipeline prefixes."""
        names = list(cls._adapters)
        return len(names), names
@@ -0,0 +1 @@
1
+ from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter