cache-dit 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit has been flagged as possibly problematic.

Files changed (29)
  1. cache_dit/__init__.py +5 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +2 -0
  4. cache_dit/cache_factory/cache_adapters.py +375 -26
  5. cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
  8. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
  9. cache_dit/cache_factory/cache_blocks/utils.py +19 -0
  10. cache_dit/cache_factory/cache_context.py +32 -25
  11. cache_dit/cache_factory/cache_interface.py +8 -3
  12. cache_dit/cache_factory/forward_pattern.py +45 -24
  13. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  14. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  15. cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
  16. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
  17. cache_dit/compile/utils.py +1 -1
  18. cache_dit/quantize/__init__.py +1 -0
  19. cache_dit/quantize/quantize_ao.py +196 -0
  20. cache_dit/quantize/quantize_interface.py +46 -0
  21. cache_dit/utils.py +49 -17
  22. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/METADATA +43 -18
  23. cache_dit-0.2.26.dist-info/RECORD +42 -0
  24. cache_dit-0.2.24.dist-info/RECORD +0 -32
  25. /cache_dit/{cache_factory/patch/__init__.py → quantize/quantize_svdq.py} +0 -0
  26. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/patch_functors/functor_chroma.py
@@ -0,0 +1,273 @@
+import inspect
+
+import torch
+import numpy as np
+from typing import Tuple, Optional, Dict, Any, Union
+from diffusers import ChromaTransformer2DModel
+from diffusers.models.transformers.transformer_chroma import (
+    ChromaSingleTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class ChromaPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: ChromaTransformer2DModel,
+        blocks: torch.nn.ModuleList = None,
+        **kwargs,
+    ) -> ChromaTransformer2DModel:
+        if blocks is None:
+            blocks = transformer.single_transformer_blocks
+
+        is_patched = False
+        for block in blocks:
+            if isinstance(block, ChromaSingleTransformerBlock):
+                forward_parameters = inspect.signature(
+                    block.forward
+                ).parameters.keys()
+                if "encoder_hidden_states" not in forward_parameters:
+                    block.forward = __patch_single_forward__.__get__(block)
+                    is_patched = True
+
+        if is_patched:
+            logger.warning("Patched Chroma for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+            transformer._is_patched = True
+
+        cls_name = transformer.__class__.__name__
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+# adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_single_forward__(
+    self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    text_seq_len = encoder_hidden_states.shape[1]
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    residual = hidden_states
+    norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+    mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    attn_output = self.attn(
+        hidden_states=norm_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+        **joint_attention_kwargs,
+    )
+
+    hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+    gate = gate.unsqueeze(1)
+    hidden_states = gate * self.proj_out(hidden_states)
+    hidden_states = residual + hidden_states
+    if hidden_states.dtype == torch.float16:
+        hidden_states = hidden_states.clip(-65504, 65504)
+
+    encoder_hidden_states, hidden_states = (
+        hidden_states[:, :text_seq_len],
+        hidden_states[:, text_seq_len:],
+    )
+    return encoder_hidden_states, hidden_states
+
+
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_transformer_forward__(
+    self: ChromaTransformer2DModel,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    attention_mask: torch.Tensor = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples=None,
+    controlnet_single_block_samples=None,
+    return_dict: bool = True,
+    controlnet_blocks_repeat: bool = False,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    if joint_attention_kwargs is not None:
+        joint_attention_kwargs = joint_attention_kwargs.copy()
+        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            joint_attention_kwargs is not None
+            and joint_attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    hidden_states = self.x_embedder(hidden_states)
+
+    timestep = timestep.to(hidden_states.dtype) * 1000
+
+    input_vec = self.time_text_embed(timestep)
+    pooled_temb = self.distilled_guidance_layer(input_vec)
+
+    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+    if txt_ids.ndim == 3:
+        logger.warning(
+            "Passing `txt_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        txt_ids = txt_ids[0]
+    if img_ids.ndim == 3:
+        logger.warning(
+            "Passing `img_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        img_ids = img_ids[0]
+
+    ids = torch.cat((txt_ids, img_ids), dim=0)
+    image_rotary_emb = self.pos_embed(ids)
+
+    if (
+        joint_attention_kwargs is not None
+        and "ip_adapter_image_embeds" in joint_attention_kwargs
+    ):
+        ip_adapter_image_embeds = joint_attention_kwargs.pop(
+            "ip_adapter_image_embeds"
+        )
+        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        img_offset = 3 * len(self.single_transformer_blocks)
+        txt_offset = img_offset + 6 * len(self.transformer_blocks)
+        img_modulation = img_offset + 6 * index_block
+        text_modulation = txt_offset + 6 * index_block
+        temb = torch.cat(
+            (
+                pooled_temb[:, img_modulation : img_modulation + 6],
+                pooled_temb[:, text_modulation : text_modulation + 6],
+            ),
+            dim=1,
+        )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                attention_mask=attention_mask,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_block_samples is not None:
+            interval_control = len(self.transformer_blocks) / len(
+                controlnet_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            # For Xlabs ControlNet.
+            if controlnet_blocks_repeat:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[
+                        index_block % len(controlnet_block_samples)
+                    ]
+                )
+            else:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[index_block // interval_control]
+                )
+
+    for index_block, block in enumerate(self.single_transformer_blocks):
+        start_idx = 3 * index_block
+        temb = pooled_temb[:, start_idx : start_idx + 3]
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                attention_mask=attention_mask,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_single_block_samples is not None:
+            interval_control = len(self.single_transformer_blocks) / len(
+                controlnet_single_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            hidden_states = (
+                hidden_states
+                + controlnet_single_block_samples[
+                    index_block // interval_control
+                ]
+            )
+
+    temb = pooled_temb[:, -2:]
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
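
Usage sketch (not part of the diff): the new ChromaPatchFunctor rewrites ChromaSingleTransformerBlock.forward and the transformer's forward so that single blocks take and return encoder_hidden_states explicitly, matching the forward patterns cache-dit expects. A minimal, hypothetical way to apply it directly; the checkpoint path is a placeholder:

    import torch
    from diffusers import ChromaTransformer2DModel
    from cache_dit.cache_factory.patch_functors.functor_chroma import (
        ChromaPatchFunctor,
    )

    # Hypothetical checkpoint path; substitute a real Chroma checkpoint.
    transformer = ChromaTransformer2DModel.from_pretrained(
        "path/to/chroma", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # Patches single blocks and transformer.forward; must run before any
    # parallelization (see the assert in ChromaPatchFunctor.apply).
    transformer = ChromaPatchFunctor().apply(transformer)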

cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py}
@@ -14,12 +14,56 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 
-
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
+class FluxPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: FluxTransformer2DModel,
+        blocks: torch.nn.ModuleList = None,
+        **kwargs,
+    ) -> FluxTransformer2DModel:
+        if blocks is None:
+            blocks = transformer.single_transformer_blocks
+
+        is_patched = False
+        for block in blocks:
+            if isinstance(block, FluxSingleTransformerBlock):
+                forward_parameters = inspect.signature(
+                    block.forward
+                ).parameters.keys()
+                if "encoder_hidden_states" not in forward_parameters:
+                    block.forward = __patch_single_forward__.__get__(block)
+                    is_patched = True
+
+        if is_patched:
+            logger.warning("Patched Flux for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+            transformer._is_patched = True
+
+        cls_name = transformer.__class__.__name__
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
 # copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
 def __patch_single_forward__(
     self: FluxSingleTransformerBlock,
@@ -217,33 +261,3 @@ def __patch_transformer_forward__(
         return (output,)
 
     return Transformer2DModelOutput(sample=output)
-
-
-def maybe_patch_flux_transformer(
-    transformer: FluxTransformer2DModel,
-    blocks: torch.nn.ModuleList = None,
-) -> FluxTransformer2DModel:
-    if blocks is None:
-        blocks = transformer.single_transformer_blocks
-
-    is_patched = False
-    for block in blocks:
-        if isinstance(block, FluxSingleTransformerBlock):
-            forward_parameters = inspect.signature(
-                block.forward
-            ).parameters.keys()
-            if "encoder_hidden_states" not in forward_parameters:
-                block.forward = __patch_single_forward__.__get__(block)
-                is_patched = True
-
-    if is_patched:
-        logger.warning("Patched Flux for cache-dit.")
-        assert not getattr(transformer, "_is_parallelized", False), (
-            "Please call apply_cache_on_pipe before Parallelize, "
-            "the __patch_transformer_forward__ will overwrite the "
-            "parallized forward and cause a downgrade of performance."
-        )
-        transformer.forward = __patch_transformer_forward__.__get__(transformer)
-        transformer._is_patched = True
-
-    return transformer
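
Migration sketch (not part of the diff): the module-level helper maybe_patch_flux_transformer is removed in favor of the FluxPatchFunctor class. A hypothetical call site, using an example Flux checkpoint id:

    import torch
    from diffusers import FluxPipeline
    from cache_dit.cache_factory.patch_functors.functor_flux import (
        FluxPatchFunctor,
    )

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # example checkpoint
        torch_dtype=torch.bfloat16,
    )
    # Replaces the removed maybe_patch_flux_transformer(pipe.transformer) call.
    pipe.transformer = FluxPatchFunctor().apply(pipe.transformer)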

cache_dit/compile/utils.py
@@ -24,7 +24,7 @@ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
 
 
 def set_compile_configs(
-    descent_tuning: bool = True,
+    descent_tuning: bool = False,
     cuda_graphs: bool = False,
     force_disable_compile_caches: bool = False,
     use_fast_math: bool = False,
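
Note (not part of the diff): the default of descent_tuning flips from True to False in this release; callers that want the previous behavior can pass it explicitly. A minimal sketch, importing the function from its defining module:

    from cache_dit.compile.utils import set_compile_configs

    # Restore the pre-0.2.26 default explicitly.
    set_compile_configs(descent_tuning=True)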

cache_dit/quantize/__init__.py
@@ -0,0 +1 @@
+from cache_dit.quantize.quantize_interface import quantize

cache_dit/quantize/quantize_ao.py
@@ -0,0 +1,196 @@
+import gc
+import time
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize_ao(
+    module: torch.nn.Module,
+    quant_type: str = "fp8_w8a8_dq",
+    exclude_layers: List[str] = [
+        "embedder",
+        "embed",
+    ],
+    filter_fn: Optional[Callable] = None,
+    # paramters for fp8 quantization
+    per_row: bool = True,
+    **kwargs,
+) -> torch.nn.Module:
+    # Apply FP8 DQ for module and skip any `embed` modules
+    # by default to avoid non-trivial precision downgrade. Please
+    # set `exclude_layers` as `[]` if you don't want this behavior.
+    assert isinstance(module, torch.nn.Module)
+
+    quant_type = quant_type.lower()
+    assert quant_type in (
+        "fp8_w8a8_dq",
+        "fp8_w8a16_wo",
+        "int8_w8a8_dq",
+        "int8_w8a16_wo",
+        "int4_w4a8_dq",
+        "int4_w4a4_dq",
+        "int4_w4a16_wo",
+    ), f"{quant_type} is not supported for torchao backend now!"
+
+    if "fp8" in quant_type:
+        assert torch.cuda.get_device_capability() >= (
+            8,
+            9,
+        ), "FP8 is not supported for current device."
+
+    num_quant_linear = 0
+    num_skip_linear = 0
+    num_linear_layers = 0
+    num_layers = 0
+
+    # Ensure bfloat16 for per_row
+    def _filter_fn(m: torch.nn.Module, name: str) -> bool:
+        nonlocal num_quant_linear, num_skip_linear, num_linear_layers, num_layers
+        num_layers += 1
+        if isinstance(m, torch.nn.Linear):
+            num_linear_layers += 1
+            for exclude_name in exclude_layers:
+                if exclude_name in name:
+                    logger.info(
+                        f"Skip Quantization: {name} -> "
+                        f"pattern<{exclude_name}>"
+                    )
+
+                    num_skip_linear += 1
+                    return False
+
+            if (
+                per_row
+                and m.weight.dtype != torch.bfloat16
+                and quant_type == "fp8_w8a8_dq"
+            ):
+                logger.info(
+                    f"Skip Quantization: {name} -> "
+                    f"pattern<dtype({m.weight.dtype})!=bfloat16>"
+                )
+
+                num_skip_linear += 1
+                return False
+
+            num_quant_linear += 1
+            return True
+
+        return False
+
+    def _quantization_fn():
+        try:
+            if quant_type == "fp8_w8a8_dq":
+                from torchao.quantization import (
+                    float8_dynamic_activation_float8_weight,
+                    PerTensor,
+                    PerRow,
+                )
+
+                quantization_fn = float8_dynamic_activation_float8_weight(
+                    weight_dtype=kwargs.get(
+                        "weight_dtype",
+                        torch.float8_e4m3fn,
+                    ),
+                    activation_dtype=kwargs.get(
+                        "activation_dtype",
+                        torch.float8_e4m3fn,
+                    ),
+                    granularity=(
+                        ((PerRow(), PerRow()))
+                        if per_row
+                        else ((PerTensor(), PerTensor()))
+                    ),
+                )
+
+            elif quant_type == "fp8_w8a16_wo":
+                from torchao.quantization import float8_weight_only
+
+                quantization_fn = float8_weight_only(
+                    weight_dtype=kwargs.get(
+                        "weight_dtype",
+                        torch.float8_e4m3fn,
+                    ),
+                )
+
+            elif quant_type == "int8_w8a8_dq":
+                from torchao.quantization import (
+                    int8_dynamic_activation_int8_weight,
+                )
+
+                quantization_fn = int8_dynamic_activation_int8_weight()
+
+            elif quant_type == "int8_w8a16_wo":
+                from torchao.quantization import int8_weight_only
+
+                quantization_fn = int8_weight_only(
+                    # group_size is None -> per_channel, else per group
+                    group_size=kwargs.get("group_size", None),
+                )
+
+            elif quant_type == "int4_w4a8_dq":
+                from torchao.quantization import (
+                    int8_dynamic_activation_int4_weight,
+                )
+
+                quantization_fn = int8_dynamic_activation_int4_weight(
+                    group_size=kwargs.get("group_size", 32),
+                )
+
+            elif quant_type == "int4_w4a4_dq":
+                from torchao.quantization import (
+                    int4_dynamic_activation_int4_weight,
+                )
+
+                quantization_fn = int4_dynamic_activation_int4_weight()
+
+            elif quant_type == "int4_w4a16_wo":
+                from torchao.quantization import int4_weight_only
+
+                quantization_fn = int4_weight_only(
+                    group_size=kwargs.get("group_size", 32),
+                )
+
+            else:
+                raise ValueError(
+                    f"quant_type: {quant_type} is not supported now!"
+                )
+
+        except ImportError as e:
+            e.msg += (
+                f"{quant_type} is not supported in torchao backend now! "
+                "Please upgrade the torchao library."
+            )
+            raise e
+
+        return quantization_fn
+
+    from torchao.quantization import quantize_
+
+    quantize_(
+        module,
+        _quantization_fn(),
+        filter_fn=_filter_fn if filter_fn is None else filter_fn,
+        device=kwargs.get("device", None),
+    )
+
+    force_empty_cache()
+
+    logger.info(
+        f"Quantized Linear Layers: {num_quant_linear:>5}\n"
+        f"Skipped Linear Layers: {num_skip_linear:>5}\n"
+        f"Total Linear Layers: {num_linear_layers:>5}\n"
+        f"Total (all) Layers: {num_layers:>5}"
+    )
+    return module
+
+
+def force_empty_cache():
+    time.sleep(1)
+    gc.collect()
+    torch.cuda.empty_cache()
+    time.sleep(1)
+    gc.collect()
+    torch.cuda.empty_cache()
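
Usage sketch (not part of the diff): quantize_ao wraps torchao's quantize_ and only touches nn.Linear layers, skipping any whose name matches the default exclude_layers patterns ("embedder", "embed"). A hypothetical toy module (TinyBlock is made up for illustration) using int8 weight-only quantization so no FP8-capable GPU is required:

    import torch
    from cache_dit.quantize.quantize_ao import quantize_ao

    class TinyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embedder = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
            self.proj = torch.nn.Linear(64, 64, dtype=torch.bfloat16)

    block = TinyBlock().to("cuda")
    # `embedder` matches the default exclude_layers and is skipped;
    # `proj` is quantized to int8 weight-only (per-channel by default).
    block = quantize_ao(block, quant_type="int8_w8a16_wo")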

cache_dit/quantize/quantize_interface.py
@@ -0,0 +1,46 @@
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize(
+    module: torch.nn.Module,
+    quant_type: str = "fp8_w8a8_dq",
+    backend: str = "ao",
+    exclude_layers: List[str] = [
+        "embedder",
+        "embed",
+    ],
+    filter_fn: Optional[Callable] = None,
+    # only for fp8_w8a8_dq
+    per_row: bool = True,
+    **kwargs,
+) -> torch.nn.Module:
+    assert isinstance(module, torch.nn.Module)
+
+    if backend.lower() in ("ao", "torchao"):
+        from cache_dit.quantize.quantize_ao import quantize_ao
+
+        quant_type = quant_type.lower()
+        assert quant_type in (
+            "fp8_w8a8_dq",
+            "fp8_w8a16_wo",
+            "int8_w8a8_dq",
+            "int8_w8a16_wo",
+            "int4_w4a8_dq",
+            "int4_w4a4_dq",
+            "int4_w4a16_wo",
+        ), f"{quant_type} is not supported for torchao backend now!"
+
+        return quantize_ao(
+            module,
+            quant_type=quant_type,
+            per_row=per_row,
+            exclude_layers=exclude_layers,
+            filter_fn=filter_fn,
+            **kwargs,
+        )
+    else:
+        raise ValueError(f"backend: {backend} is not supported now!")
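
Usage sketch (not part of the diff): quantize is the public entry point re-exported from cache_dit.quantize, dispatching to quantize_ao for the "ao"/"torchao" backend. A hypothetical FP8 example with an example Flux checkpoint id, assuming a GPU with compute capability >= 8.9 and bfloat16 weights as the default per-row FP8 path requires:

    import torch
    from diffusers import FluxTransformer2DModel
    from cache_dit.quantize import quantize

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # example checkpoint
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    # Defaults: quant_type="fp8_w8a8_dq", backend="ao", per_row=True,
    # exclude_layers=["embedder", "embed"].
    transformer = quantize(transformer)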