cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py
@@ -0,0 +1,263 @@
+import torch
+import numpy as np
+from typing import Tuple, Optional, Dict, Any, Union, List
+from diffusers import QwenImageTransformer2DModel
+from diffusers.models.transformers.transformer_qwenimage import (
+    QwenImageTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from cache_dit.caching.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class QwenImageControlNetPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: QwenImageTransformer2DModel,
+        **kwargs,
+    ) -> QwenImageTransformer2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        _index_block = 0
+        _num_blocks = len(transformer.transformer_blocks)
+        for block in transformer.transformer_blocks:
+            assert isinstance(block, QwenImageTransformerBlock)
+            block._index_block = _index_block
+            block._num_blocks = _num_blocks
+            block.forward = __patch_block_forward__.__get__(block)
+            _index_block += 1
+
+        is_patched = True
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+def __patch_block_forward__(
+    self: QwenImageTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    encoder_hidden_states_mask: torch.Tensor,
+    temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples: Optional[List[torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Get modulation parameters for both streams
+    img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+    txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+
+    # Split modulation parameters for norm1 and norm2
+    img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+    txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+
+    # Process image stream - norm1 + modulation
+    img_normed = self.img_norm1(hidden_states)
+    img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+
+    # Process text stream - norm1 + modulation
+    txt_normed = self.txt_norm1(encoder_hidden_states)
+    txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+
+    # Use QwenAttnProcessor2_0 for joint attention computation
+    # This directly implements the DoubleStreamLayerMegatron logic:
+    # 1. Computes QKV for both streams
+    # 2. Applies QK normalization and RoPE
+    # 3. Concatenates and runs joint attention
+    # 4. Splits results back to separate streams
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    attn_output = self.attn(
+        hidden_states=img_modulated,  # Image stream (will be processed as "sample")
+        encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
+        encoder_hidden_states_mask=encoder_hidden_states_mask,
+        image_rotary_emb=image_rotary_emb,
+        **joint_attention_kwargs,
+    )
+
+    # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
+    img_attn_output, txt_attn_output = attn_output
+
+    # Apply attention gates and add residual (like in Megatron)
+    hidden_states = hidden_states + img_gate1 * img_attn_output
+    encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+
+    # Process image stream - norm2 + MLP
+    img_normed2 = self.img_norm2(hidden_states)
+    img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+    img_mlp_output = self.img_mlp(img_modulated2)
+    hidden_states = hidden_states + img_gate2 * img_mlp_output
+
+    # Process text stream - norm2 + MLP
+    txt_normed2 = self.txt_norm2(encoder_hidden_states)
+    txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+    txt_mlp_output = self.txt_mlp(txt_modulated2)
+    encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
+
+    # Clip to prevent overflow for fp16
+    if encoder_hidden_states.dtype == torch.float16:
+        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+    if hidden_states.dtype == torch.float16:
+        hidden_states = hidden_states.clip(-65504, 65504)
+
+    if controlnet_block_samples is not None:
+        # Add ControlNet conditioning
+        num_blocks = self._num_blocks
+        index_block = self._index_block
+        interval_control = num_blocks / len(controlnet_block_samples)
+        interval_control = int(np.ceil(interval_control))
+        hidden_states = (
+            hidden_states
+            + controlnet_block_samples[index_block // interval_control]
+        )
+
+    return encoder_hidden_states, hidden_states
+
+
+def __patch_transformer_forward__(
+    self: QwenImageTransformer2DModel,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    encoder_hidden_states_mask: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+    txt_seq_lens: Optional[List[int]] = None,
+    guidance: torch.Tensor = None,  # TODO: this should probably be removed
+    attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples=None,
+    return_dict: bool = True,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    """
+    The [`QwenTransformer2DModel`] forward method.
+
+    Args:
+        hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+            Input `hidden_states`.
+        encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+        encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
+            Mask of the input conditions.
+        timestep ( `torch.LongTensor`):
+            Used to indicate denoising step.
+        attention_kwargs (`dict`, *optional*):
+            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+            `self.processor` in
+            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+        return_dict (`bool`, *optional*, defaults to `True`):
+            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+            tuple.
+
+    Returns:
+        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+        `tuple` where the first element is the sample tensor.
+    """
+    if attention_kwargs is not None:
+        attention_kwargs = attention_kwargs.copy()
+        lora_scale = attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            attention_kwargs is not None
+            and attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    hidden_states = self.img_in(hidden_states)
+
+    timestep = timestep.to(hidden_states.dtype)
+    encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+    encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+    if guidance is not None:
+        guidance = guidance.to(hidden_states.dtype) * 1000
+
+    temb = (
+        self.time_text_embed(timestep, hidden_states)
+        if guidance is None
+        else self.time_text_embed(timestep, guidance, hidden_states)
+    )
+
+    image_rotary_emb = self.pos_embed(
+        img_shapes, txt_seq_lens, device=hidden_states.device
+    )
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    encoder_hidden_states_mask,
+                    temb,
+                    image_rotary_emb,
+                    controlnet_block_samples,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                controlnet_block_samples=controlnet_block_samples,
+                joint_attention_kwargs=attention_kwargs,
+            )
+
+    # # controlnet residual
+    # if controlnet_block_samples is not None:
+    #     interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+    #     interval_control = int(np.ceil(interval_control))
+    #     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+    # Use only the image part (hidden_states) from the dual-stream blocks
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
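
For orientation, a minimal usage sketch of the new functor follows. Only the import path and `apply()` call come from the file above; the checkpoint id and loading code are illustrative assumptions.

# Hedged sketch: patch a QwenImage transformer in place before enabling caching.
# The model id / from_pretrained call is illustrative, not part of this diff.
from diffusers import QwenImageTransformer2DModel
from cache_dit.caching.patch_functors.functor_qwen_image_controlnet import (
    QwenImageControlNetPatchFunctor,
)

transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer"  # illustrative checkpoint
)

# apply() rebinds each block's forward and the transformer's forward so that
# controlnet_block_samples can be added per block; it is idempotent
# (guarded by transformer._is_patched) and must run before any parallelization.
transformer = QwenImageControlNetPatchFunctor().apply(transformer)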
cache_dit/caching/utils.py
@@ -0,0 +1,68 @@
+import yaml
+
+
+def load_cache_options_from_yaml(yaml_file_path):
+    try:
+        with open(yaml_file_path, "r") as f:
+            kwargs: dict = yaml.safe_load(f)
+
+        required_keys = [
+            "residual_diff_threshold",
+        ]
+        for key in required_keys:
+            if key not in kwargs:
+                raise ValueError(
+                    f"Configuration file missing required item: {key}"
+                )
+
+        cache_context_kwargs = {}
+        if kwargs.get("enable_taylorseer", False):
+            from cache_dit.caching.cache_contexts.calibrators import (
+                TaylorSeerCalibratorConfig,
+            )
+
+            cache_context_kwargs["calibrator_config"] = (
+                TaylorSeerCalibratorConfig(
+                    enable_calibrator=kwargs.pop("enable_taylorseer"),
+                    enable_encoder_calibrator=kwargs.pop(
+                        "enable_encoder_taylorseer", False
+                    ),
+                    calibrator_cache_type=kwargs.pop(
+                        "taylorseer_cache_type", "residual"
+                    ),
+                    taylorseer_order=kwargs.pop("taylorseer_order", 1),
+                )
+            )
+
+        if "cache_type" not in kwargs:
+            from cache_dit.caching.cache_contexts import BasicCacheConfig
+
+            cache_context_kwargs["cache_config"] = BasicCacheConfig()
+            cache_context_kwargs["cache_config"].update(**kwargs)
+        else:
+            cache_type = kwargs.pop("cache_type")
+            if cache_type == "DBCache":
+                from cache_dit.caching.cache_contexts import DBCacheConfig
+
+                cache_context_kwargs["cache_config"] = DBCacheConfig()
+                cache_context_kwargs["cache_config"].update(**kwargs)
+            elif cache_type == "DBPrune":
+                from cache_dit.caching.cache_contexts import DBPruneConfig
+
+                cache_context_kwargs["cache_config"] = DBPruneConfig()
+                cache_context_kwargs["cache_config"].update(**kwargs)
+            else:
+                raise ValueError(f"Unsupported cache_type: {cache_type}.")
+
+        return cache_context_kwargs
+
+    except FileNotFoundError:
+        raise FileNotFoundError(
+            f"Configuration file not found: {yaml_file_path}"
+        )
+    except yaml.YAMLError as e:
+        raise yaml.YAMLError(f"YAML file parsing error: {str(e)}")
+
+
+def load_options(path: str):
+    return load_cache_options_from_yaml(path)
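
A small sketch of driving this loader end to end. The file name and the numeric values are illustrative; the keys are exactly the ones the loader above reads.

# Hedged sketch: build cache kwargs from a YAML config using load_options above.
import yaml
from cache_dit.caching.utils import load_options

cfg = {
    "cache_type": "DBCache",           # or "DBPrune"
    "residual_diff_threshold": 0.08,   # required key; value is illustrative
    "enable_taylorseer": True,         # triggers the TaylorSeer calibrator branch
    "taylorseer_order": 1,
}
with open("cache_config.yaml", "w") as f:
    yaml.safe_dump(cfg, f)

cache_context_kwargs = load_options("cache_config.yaml")
# Per the code above, the result holds a "cache_config" (DBCacheConfig here)
# and, because enable_taylorseer is set, a "calibrator_config" entry.
print(cache_context_kwargs)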
cache_dit/metrics/__init__.py
@@ -1,3 +1,14 @@
+try:
+    import ImageReward
+    import lpips
+    import skimage
+    import scipy
+except ImportError:
+    raise ImportError(
+        "Metrics functionality requires the 'metrics' extra dependencies. "
+        "Install with:\npip install cache-dit[metrics]"
+    )
+
 from cache_dit.metrics.metrics import compute_psnr
 from cache_dit.metrics.metrics import compute_ssim
 from cache_dit.metrics.metrics import compute_mse
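
The new guard makes the metrics module fail fast unless the extra is installed (`pip install cache-dit[metrics]`). A hedged sketch of the exported helpers; the path-based call signature is an assumption inferred from the CLI hunks below, which compare image files or directories.

# Hedged sketch: compare a baseline render against a cache-accelerated one.
# The (reference, test) path signature is assumed, not confirmed by this diff.
from cache_dit.metrics import compute_psnr, compute_ssim

psnr = compute_psnr("baseline.png", "cached.png")  # assumed signature
ssim = compute_ssim("baseline.png", "cached.png")  # assumed signature
print(psnr, ssim)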
cache_dit/metrics/metrics.py
@@ -646,6 +646,7 @@ def entrypoint():
             not os.path.exists(img_test),
         )
     ):
+        logger.error(f"Not exist: {img_true} or {img_test}, skip.")
         return
     # img_true and img_test can be files or dirs
     img_true_info = os.path.basename(img_true)
@@ -684,6 +685,7 @@ def entrypoint():
             not os.path.exists(img_test),  # dir
         )
     ):
+        logger.error(f"Not exist: {prompt_true} or {img_test}, skip.")
         return
 
     # img_true and img_test can be files or dirs
@@ -714,6 +716,7 @@ def entrypoint():
             not os.path.exists(video_test),
         )
     ):
+        logger.error(f"Not exist: {video_true} or {video_test}, skip.")
         return
 
     # video_true and video_test can be files or dirs
cache_dit/parallelism/__init__.py
@@ -0,0 +1,3 @@
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.parallel_interface import enable_parallelism
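
A hedged sketch of the new top-level parallelism entry points. Treating the ParallelismConfig attributes used later in this diff (backend, ulysses_size, ring_size, parallel_kwargs) as constructor keywords, the enable_parallelism call shape, and the model loading are all assumptions; only the import paths and names are taken from the diff.

# Hedged sketch, not the documented API.
import torch
from diffusers import QwenImageTransformer2DModel
from cache_dit.parallelism import (
    ParallelismBackend,
    ParallelismConfig,
    enable_parallelism,
)

transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)  # illustrative checkpoint

parallelism_config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,  # enum value used below
    ulysses_size=2,                              # illustrative parallel degree
    ring_size=None,
    parallel_kwargs={},
)
# Assumed call shape: the backend helpers below take (transformer, config).
transformer = enable_parallelism(transformer, parallelism_config)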
cache_dit/parallelism/backends/native_diffusers/__init__.py
@@ -0,0 +1,6 @@
+from cache_dit.parallelism.backends.native_diffusers.context_parallelism import (
+    ContextParallelismPlannerRegister,
+)
+from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
+    maybe_enable_parallelism,
+)
cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py
@@ -0,0 +1,164 @@
+import torch
+from typing import Optional
+
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.logger import init_logger
+from ..utils import (
+    native_diffusers_parallelism_available,
+    ContextParallelConfig,
+)
+from .attention import maybe_resigter_native_attention_backend
+from .cp_planners import *
+
+try:
+    maybe_resigter_native_attention_backend()
+except ImportError as e:
+    raise ImportError(e)
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_context_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, ModelMixin), (
+        "transformer must be an instance of diffusers' ModelMixin, "
+        f"but got {type(transformer)}"
+    )
+    if parallelism_config is None:
+        return transformer
+
+    assert isinstance(parallelism_config, ParallelismConfig), (
+        "parallelism_config must be an instance of ParallelismConfig"
+        f" but got {type(parallelism_config)}"
+    )
+
+    if (
+        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
+        and native_diffusers_parallelism_available()
+    ):
+        cp_config = None
+        if (
+            parallelism_config.ulysses_size is not None
+            or parallelism_config.ring_size is not None
+        ):
+            cp_config = ContextParallelConfig(
+                ulysses_degree=parallelism_config.ulysses_size,
+                ring_degree=parallelism_config.ring_size,
+            )
+        if cp_config is not None:
+            attention_backend = parallelism_config.parallel_kwargs.get(
+                "attention_backend", None
+            )
+            if hasattr(transformer, "enable_parallelism"):
+                if hasattr(transformer, "set_attention_backend"):
+                    # native, _native_cudnn, flash, etc.
+                    if attention_backend is None:
+                        # Now only _native_cudnn is supported for parallelism
+                        # issue: https://github.com/huggingface/diffusers/pull/12443
+                        transformer.set_attention_backend("_native_cudnn")
+                        logger.warning(
+                            "attention_backend is None, set default attention backend "
+                            "to _native_cudnn for parallelism because of the issue: "
+                            "https://github.com/huggingface/diffusers/pull/12443"
+                        )
+                    else:
+                        transformer.set_attention_backend(attention_backend)
+                        logger.info(
+                            "Found attention_backend from config, set attention "
+                            f"backend to: {attention_backend}"
+                        )
+                # Prefer custom cp_plan if provided
+                cp_plan = parallelism_config.parallel_kwargs.get(
+                    "cp_plan", None
+                )
+                if cp_plan is not None:
+                    logger.info(
+                        f"Using custom context parallelism plan: {cp_plan}"
+                    )
+                else:
+                    # Try get context parallelism plan from register if not provided
+                    extra_parallel_kwargs = {}
+                    if parallelism_config.parallel_kwargs is not None:
+                        extra_parallel_kwargs = (
+                            parallelism_config.parallel_kwargs
+                        )
+                    cp_plan = ContextParallelismPlannerRegister.get_planner(
+                        transformer
+                    )().apply(transformer=transformer, **extra_parallel_kwargs)
+
+                transformer.enable_parallelism(
+                    config=cp_config, cp_plan=cp_plan
+                )
+                _maybe_patch_native_parallel_config(transformer)
+            else:
+                raise ValueError(
+                    f"{transformer.__class__.__name__} does not support context parallelism."
+                )
+
+    return transformer
+
+
+def _maybe_patch_native_parallel_config(
+    transformer: torch.nn.Module,
+) -> torch.nn.Module:
+
+    cls_name = transformer.__class__.__name__
+    if not cls_name.startswith("Nunchaku"):
+        return transformer
+
+    from diffusers import FluxTransformer2DModel, QwenImageTransformer2DModel
+
+    try:
+        from nunchaku.models.transformers.transformer_flux_v2 import (
+            NunchakuFluxTransformer2DModelV2,
+            NunchakuFluxAttention,
+            NunchakuFluxFA2Processor,
+        )
+        from nunchaku.models.transformers.transformer_qwenimage import (
+            NunchakuQwenAttention,
+            NunchakuQwenImageNaiveFA2Processor,
+            NunchakuQwenImageTransformer2DModel,
+        )
+    except ImportError:
+        raise ImportError(
+            "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
+            "requires the 'nunchaku' package. Please install nunchaku before using "
+            "the context parallelism for nunchaku 4-bits models."
+        )
+    assert isinstance(
+        transformer,
+        (
+            NunchakuFluxTransformer2DModelV2,
+            FluxTransformer2DModel,
+        ),
+    ) or isinstance(
+        transformer,
+        (
+            NunchakuQwenImageTransformer2DModel,
+            QwenImageTransformer2DModel,
+        ),
+    ), (
+        "transformer must be an instance of NunchakuFluxTransformer2DModelV2 "
+        f"or NunchakuQwenImageTransformer2DModel, but got {type(transformer)}"
+    )
+    config = transformer._parallel_config
+
+    attention_classes = (
+        NunchakuFluxAttention,
+        NunchakuFluxFA2Processor,
+        NunchakuQwenAttention,
+        NunchakuQwenImageNaiveFA2Processor,
+    )
+    for module in transformer.modules():
+        if not isinstance(module, attention_classes):
+            continue
+        processor = getattr(module, "processor", None)
+        if processor is None or not hasattr(processor, "_parallel_config"):
+            continue
+        processor._parallel_config = config
+
+    return transformer
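
One detail worth calling out from the helper above: the attention backend and the context-parallel plan can both be overridden through `parallelism_config.parallel_kwargs` instead of relying on the `_native_cudnn` fallback and the planner registry. A minimal sketch, assuming the keyword-style ParallelismConfig construction used in the earlier example:

# Hedged sketch: override the attention backend via parallel_kwargs.
# ParallelismConfig keyword construction is an assumption; the backend names
# mirror the "# native, _native_cudnn, flash, etc." comment above, and only
# _native_cudnn is noted as supported for parallelism at the moment.
from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig

config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    ulysses_size=2,
    ring_size=None,
    parallel_kwargs={
        "attention_backend": "_native_cudnn",
        # "cp_plan": ...,  # optional custom plan; otherwise the planner registry is used
    },
)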
cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py
@@ -0,0 +1,4 @@
+def maybe_resigter_native_attention_backend():
+    """Maybe re-register native attention backend to enable context parallelism."""
+    # Import custom attention backend ensuring registration
+    from ._attention_dispatch import _native_attention