cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py
@@ -0,0 +1,153 @@
+ from typing import Optional, Union
+
+ import torch
+ from torch import nn
+ from torch.distributed import DeviceMesh, init_device_mesh
+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     RowwiseParallel,
+     parallelize_module,
+ )
+
+ from cache_dit.logger import init_logger
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+ from .tp_plan_registers import (
+     TensorParallelismPlanner,
+     TensorParallelismPlannerRegister,
+ )
+
+ logger = init_logger(__name__)
+
+
+ class DistributedRMSNorm(nn.Module):
+     def __init__(
+         self,
+         tp_mesh: DeviceMesh,
+         normalized_shape: Union[int, list[int], torch.Size],
+         eps: Optional[float],
+         elementwise_affine: bool,
+         weight: torch.nn.parameter.Parameter,
+     ):
+         super().__init__()
+         self.tp_mesh = tp_mesh
+         self.elementwise_affine = elementwise_affine
+         self.normalized_shape = normalized_shape
+         self.eps = eps
+         if self.elementwise_affine:
+             assert weight is not None
+             self.weight = weight
+
+     @classmethod
+     def from_rmsnorm(cls, tp_mesh: DeviceMesh, rmsnorm: nn.RMSNorm):
+         if not isinstance(rmsnorm, int):
+             assert len(rmsnorm.normalized_shape) == 1
+
+         if rmsnorm.weight is not None:
+             tp_size = tp_mesh.get_group().size()
+             tp_rank = tp_mesh.get_group().rank()
+             weight = rmsnorm.weight.chunk(tp_size, dim=0)[tp_rank]
+         else:
+             weight = None
+         norm = cls(
+             tp_mesh=tp_mesh,
+             normalized_shape=rmsnorm.normalized_shape,
+             eps=rmsnorm.eps,
+             elementwise_affine=rmsnorm.elementwise_affine,
+             weight=weight,
+         )
+         return norm
+
+     def forward(self, x):
+         if self.elementwise_affine:
+             assert x.shape[-1] == self.weight.shape[0]
+         mean_square = torch.mean(x * x, dim=-1, keepdim=True)
+         torch.distributed.all_reduce(
+             mean_square,
+             op=torch.distributed.ReduceOp.AVG,
+             group=self.tp_mesh.get_group(),
+         )
+         root_mean_square = torch.sqrt(mean_square + self.eps)
+         x_normed = x / root_mean_square
+         if self.elementwise_affine:
+             x_normed = x_normed * self.weight.to(device=x.device)
+             assert x_normed.device.type != "cpu"
+         return x_normed
+
+
+ @TensorParallelismPlannerRegister.register("Wan")
+ class WanTensorParallelismPlanner(TensorParallelismPlanner):
+     def apply(
+         self,
+         transformer: torch.nn.Module,
+         parallelism_config: ParallelismConfig,
+         **kwargs,
+     ) -> torch.nn.Module:
+         assert (
+             parallelism_config.tp_size is not None
+             and parallelism_config.tp_size > 1
+         ), (
+             "parallel_config.tp_size must be set and greater than 1 for "
+             "tensor parallelism"
+         )
+
+         device_type = torch.accelerator.current_accelerator().type
+         tp_mesh: DeviceMesh = init_device_mesh(
+             device_type=device_type,
+             mesh_shape=[parallelism_config.tp_size],
+         )
+
+         transformer = self.parallelize_transformer(
+             transformer=transformer,
+             tp_mesh=tp_mesh,
+         )
+
+         return transformer
+
+     def parallelize_transformer(
+         self,
+         transformer: nn.Module,
+         tp_mesh: DeviceMesh,
+     ):
+         for _, block in transformer.blocks.named_children():
+             block.attn1.heads //= tp_mesh.size()
+             block.attn2.heads //= tp_mesh.size()
+             layer_plan = {
+                 "attn1.to_q": ColwiseParallel(),
+                 "attn1.to_k": ColwiseParallel(),
+                 "attn1.to_v": ColwiseParallel(),
+                 "attn1.to_out.0": RowwiseParallel(),
+                 "attn2.to_q": ColwiseParallel(),
+                 "attn2.to_k": ColwiseParallel(),
+                 "attn2.to_v": ColwiseParallel(),
+                 "attn2.to_out.0": RowwiseParallel(),
+                 "ffn.net.0.proj": ColwiseParallel(),
+                 "ffn.net.2": RowwiseParallel(),
+             }
+             if getattr(block.attn2, "add_k_proj", None):
+                 layer_plan["attn2.add_k_proj"] = ColwiseParallel()
+             if getattr(block.attn2, "add_v_proj", None):
+                 layer_plan["attn2.add_v_proj"] = ColwiseParallel()
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+
+             block.attn1.norm_q = DistributedRMSNorm.from_rmsnorm(
+                 tp_mesh, block.attn1.norm_q
+             )
+             block.attn1.norm_k = DistributedRMSNorm.from_rmsnorm(
+                 tp_mesh, block.attn1.norm_k
+             )
+             block.attn2.norm_q = DistributedRMSNorm.from_rmsnorm(
+                 tp_mesh, block.attn2.norm_q
+             )
+             block.attn2.norm_k = DistributedRMSNorm.from_rmsnorm(
+                 tp_mesh, block.attn2.norm_k
+             )
+             if getattr(block.attn2, "norm_added_k", None):
+                 block.attn2.norm_added_k = DistributedRMSNorm.from_rmsnorm(
+                     tp_mesh, block.attn2.norm_added_k
+                 )
+         return transformer
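For orientation: the plan above shards every q/k/v and FFN up-projection column-wise and the matching output projection row-wise, so each attention/FFN pair needs only one all-reduce. Below is a standalone toy sketch of that Colwise/Rowwise pairing using stock PyTorch APIs only; it is not cache-dit code and assumes a `torchrun --nproc-per-node 2` launch on CUDA devices.

    import torch
    from torch import nn
    from torch.distributed import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class ToyFFN(nn.Module):
        # Mirrors the "ffn.net.0.proj" (colwise) / "ffn.net.2" (rowwise) pairing.
        def __init__(self, dim=512, hidden=2048):
            super().__init__()
            self.up = nn.Linear(dim, hidden)
            self.act = nn.GELU()
            self.down = nn.Linear(hidden, dim)

        def forward(self, x):
            return self.down(self.act(self.up(x)))

    mesh = init_device_mesh("cuda", (2,))  # one mesh dimension of size 2, i.e. tp_size=2
    ffn = ToyFFN().cuda()
    parallelize_module(
        ffn,
        device_mesh=mesh,
        parallelize_plan={"up": ColwiseParallel(), "down": RowwiseParallel()},
    )
    x = torch.randn(4, 512, device="cuda")
    print(ffn(x).shape)  # (4, 512): full output after the row-wise all-reduce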
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py
@@ -0,0 +1,14 @@
+ # NOTE: must import all planner classes to register them
+ from .tp_plan_flux import FluxTensorParallelismPlanner
+ from .tp_plan_kandinsky5 import Kandinsky5TensorParallelismPlanner
+ from .tp_plan_qwen_image import QwenImageTensorParallelismPlanner
+ from .tp_plan_registers import TensorParallelismPlannerRegister
+ from .tp_plan_wan import WanTensorParallelismPlanner
+
+ __all__ = [
+     "FluxTensorParallelismPlanner",
+     "Kandinsky5TensorParallelismPlanner",
+     "QwenImageTensorParallelismPlanner",
+     "TensorParallelismPlannerRegister",
+     "WanTensorParallelismPlanner",
+ ]
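The leading comment is the point of this module: each planner registers itself with the registry as a side effect of being imported. The registry itself (`tp_plan_registers.py`) is not shown in this diff, so the following is only a generic sketch of the decorator pattern it implies, with hypothetical names:

    class PlannerRegister:
        """Hypothetical stand-in for TensorParallelismPlannerRegister."""

        _registry: dict = {}

        @classmethod
        def register(cls, name: str):
            def _wrap(planner_cls):
                cls._registry[name] = planner_cls  # side effect at import time
                return planner_cls
            return _wrap

        @classmethod
        def get(cls, name: str):
            return cls._registry[name]


    @PlannerRegister.register("Wan")
    class WanPlanner:  # stand-in for WanTensorParallelismPlanner
        pass


    assert PlannerRegister.get("Wan") is WanPlanner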
cache_dit/parallelism/parallel_backend.py
@@ -0,0 +1,26 @@
+ from enum import Enum
+
+
+ class ParallelismBackend(Enum):
+     NATIVE_DIFFUSER = "Native_Diffuser"
+     NATIVE_PYTORCH = "Native_PyTorch"
+     NONE = "None"
+
+     @classmethod
+     def is_supported(cls, backend: "ParallelismBackend") -> bool:
+         if backend == cls.NATIVE_PYTORCH:
+             return True
+         elif backend == cls.NATIVE_DIFFUSER:
+             try:
+                 from diffusers.models._modeling_parallel import (  # noqa F401
+                     ContextParallelModelPlan,
+                 )
+             except ImportError:
+                 raise ImportError(
+                     "NATIVE_DIFFUSER parallelism backend requires the latest "
+                     "version of diffusers(>=0.36.dev0). Please install latest "
+                     "version of diffusers from source: \npip3 install "
+                     "git+https://github.com/huggingface/diffusers.git"
+                 )
+             return True
+         return False
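A short usage sketch of this capability check, grounded directly in the code above; note that the NATIVE_DIFFUSER branch raises instead of returning False when diffusers is too old:

    from cache_dit.parallelism.parallel_backend import ParallelismBackend

    assert ParallelismBackend.is_supported(ParallelismBackend.NATIVE_PYTORCH)
    assert not ParallelismBackend.is_supported(ParallelismBackend.NONE)

    # Requires diffusers >= 0.36.dev0; otherwise this call raises ImportError
    # rather than returning False.
    ParallelismBackend.is_supported(ParallelismBackend.NATIVE_DIFFUSER)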
cache_dit/parallelism/parallel_config.py
@@ -0,0 +1,88 @@
+ import dataclasses
+ from typing import Optional, Dict, Any
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @dataclasses.dataclass
+ class ParallelismConfig:
+     # Parallelism backend, defaults to NATIVE_DIFFUSER
+     backend: ParallelismBackend = ParallelismBackend.NATIVE_DIFFUSER
+     # Context parallelism config
+     # ulysses_size (`int`, *optional*):
+     #     The degree of ulysses parallelism.
+     ulysses_size: int = None
+     # ring_size (`int`, *optional*):
+     #     The degree of ring parallelism.
+     ring_size: int = None
+     # Tensor parallelism config
+     # tp_size (`int`, *optional*):
+     #     The degree of tensor parallelism.
+     tp_size: int = None
+     # parallel_kwargs (`dict`, *optional*):
+     #     Additional kwargs for parallelism backends. For example, for
+     #     NATIVE_DIFFUSER backend, it can include `cp_plan` and
+     #     `attention_backend` arguments for `Context Parallelism`.
+     parallel_kwargs: Optional[Dict[str, Any]] = dataclasses.field(
+         default_factory=dict
+     )
+
+     def __post_init__(self):
+         assert ParallelismBackend.is_supported(self.backend), (
+             f"Parallel backend {self.backend} is not supported. "
+             f"Please make sure the required packages are installed."
+         )
+
+         # Validate the parallelism configuration and auto adjust the backend if needed
+         if self.tp_size is not None and self.tp_size > 1:
+             assert (
+                 self.ulysses_size is None or self.ulysses_size == 1
+             ), "Tensor parallelism plus Ulysses parallelism is not supported right now."
+             assert (
+                 self.ring_size is None or self.ring_size == 1
+             ), "Tensor parallelism plus Ring parallelism is not supported right now."
+             if self.backend != ParallelismBackend.NATIVE_PYTORCH:
+                 logger.warning(
+                     "Tensor parallelism is only supported for NATIVE_PYTORCH backend "
+                     "right now. Force set backend to NATIVE_PYTORCH."
+                 )
+                 self.backend = ParallelismBackend.NATIVE_PYTORCH
+         elif (
+             self.ulysses_size is not None
+             and self.ulysses_size > 1
+             and self.ring_size is not None
+             and self.ring_size > 1
+         ):
+             raise ValueError(
+                 "Ulysses parallelism plus Ring parallelism is not fully supported right now."
+             )
+         else:
+             if (self.ulysses_size is not None and self.ulysses_size > 1) or (
+                 self.ring_size is not None and self.ring_size > 1
+             ):
+                 if self.backend != ParallelismBackend.NATIVE_DIFFUSER:
+                     logger.warning(
+                         "Ulysses/Ring parallelism is only supported for NATIVE_DIFFUSER "
+                         "backend right now. Force set backend to NATIVE_DIFFUSER."
+                     )
+                     self.backend = ParallelismBackend.NATIVE_DIFFUSER
+
+     def strify(self, details: bool = False) -> str:
+         if details:
+             return (
+                 f"ParallelismConfig(backend={self.backend}, "
+                 f"ulysses_size={self.ulysses_size}, "
+                 f"ring_size={self.ring_size}, "
+                 f"tp_size={self.tp_size})"
+             )
+         else:
+             parallel_str = ""
+             if self.ulysses_size is not None:
+                 parallel_str += f"Ulysses{self.ulysses_size}"
+             if self.ring_size is not None:
+                 parallel_str += f"Ring{self.ring_size}"
+             if self.tp_size is not None:
+                 parallel_str += f"TP{self.tp_size}"
+             return parallel_str
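A minimal sketch of how `__post_init__` validates and auto-adjusts the backend (assumes diffusers >= 0.36.dev0 so the default-backend check in `is_supported` passes):

    from cache_dit.parallelism.parallel_config import ParallelismConfig

    cp = ParallelismConfig(ulysses_size=4)  # keeps the default NATIVE_DIFFUSER backend
    print(cp.strify())                      # "Ulysses4"

    tp = ParallelismConfig(tp_size=2)       # backend forced to NATIVE_PYTORCH, with a warning
    print(tp.backend)                       # ParallelismBackend.NATIVE_PYTORCH
    print(tp.strify(details=True))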
cache_dit/parallelism/parallel_interface.py
@@ -0,0 +1,77 @@
+ import torch
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from cache_dit.utils import maybe_empty_cache
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ def enable_parallelism(
+     transformer: torch.nn.Module,
+     parallelism_config: ParallelismConfig,
+ ) -> torch.nn.Module:
+     assert isinstance(transformer, torch.nn.Module), (
+         "transformer must be an instance of torch.nn.Module, "
+         f"but got {type(transformer)}"
+     )
+     if getattr(transformer, "_is_parallelized", False):
+         logger.warning(
+             "The transformer is already parallelized. "
+             "Skipping parallelism enabling."
+         )
+         return transformer
+
+     if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
+         from cache_dit.parallelism.backends.native_diffusers import (
+             maybe_enable_parallelism,
+         )
+
+         transformer = maybe_enable_parallelism(
+             transformer,
+             parallelism_config,
+         )
+     elif parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH:
+         from cache_dit.parallelism.backends.native_pytorch import (
+             maybe_enable_parallelism,
+         )
+
+         transformer = maybe_enable_parallelism(
+             transformer,
+             parallelism_config,
+         )
+     else:
+         raise ValueError(
+             f"Parallel backend {parallelism_config.backend} is not supported yet."
+         )
+
+     transformer._is_parallelized = True  # type: ignore[attr-defined]
+     # Use `parallelism` not `parallel` to avoid name conflict with diffusers.
+     transformer._parallelism_config = parallelism_config  # type: ignore[attr-defined]
+     logger.info(
+         f"Enabled parallelism: {parallelism_config.strify(True)}, "
+         f"transformer id:{id(transformer)}"
+     )
+
+     # NOTE: Workaround for potential memory peak issue after parallelism
+     # enabling, specially for tensor parallelism in native pytorch backend.
+     maybe_empty_cache()
+
+     return transformer
+
+
+ def remove_parallelism_stats(
+     transformer: torch.nn.Module,
+ ) -> torch.nn.Module:
+     if not getattr(transformer, "_is_parallelized", False):
+         logger.warning(
+             "The transformer is not parallelized. "
+             "Skipping removing parallelism."
+         )
+         return transformer
+
+     if hasattr(transformer, "_is_parallelized"):
+         del transformer._is_parallelized  # type: ignore[attr-defined]
+     if hasattr(transformer, "_parallelism_config"):
+         del transformer._parallelism_config  # type: ignore[attr-defined]
+     return transformer
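An end-to-end sketch of this entry point (model name illustrative; assumes a `torchrun --nproc-per-node 2` launch and that the FLUX tensor-parallel plan registered by `tp_plan_flux.py` covers this transformer):

    import torch
    from diffusers import FluxPipeline

    from cache_dit.parallelism.parallel_backend import ParallelismBackend
    from cache_dit.parallelism.parallel_config import ParallelismConfig
    from cache_dit.parallelism.parallel_interface import enable_parallelism

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    pipe.transformer = enable_parallelism(
        pipe.transformer,
        ParallelismConfig(backend=ParallelismBackend.NATIVE_PYTORCH, tp_size=2),
    )
    assert pipe.transformer._is_parallelized  # flag set by enable_parallelism above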
cache_dit/quantize/__init__.py
@@ -1 +1,8 @@
+ try:
+     import torchao
+ except ImportError:
+     raise ImportError(
+         "Quantization functionality requires the 'quantization' extra dependencies. "
+         "Install with: pip install cache-dit[quantization]"
+     )
  from cache_dit.quantize.quantize_interface import quantize
cache_dit/quantize/backends/__init__.py
@@ -0,0 +1 @@
+ from .torchao import quantize_ao
cache_dit/quantize/backends/bitsandbytes/__init__.py: file without changes
cache_dit/quantize/backends/torchao/__init__.py
@@ -0,0 +1 @@
+ from .quantize_ao import quantize_ao
cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py}
@@ -1,7 +1,6 @@
- import gc
- import time
  import torch
  from typing import Callable, Optional, List
+ from cache_dit.utils import maybe_empty_cache
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)
@@ -9,7 +8,7 @@ logger = init_logger(__name__)

  def quantize_ao(
      module: torch.nn.Module,
-     quant_type: str = "fp8_w8a8_dq",
+     quant_type: str = "float8_weight_only",
      exclude_layers: List[str] = [
          "embedder",
          "embed",
@@ -24,6 +23,18 @@ def quantize_ao(
      # set `exclude_layers` as `[]` if you don't want this behavior.
      assert isinstance(module, torch.nn.Module)

+     alias_map = {
+         "float8": "fp8_w8a8_dq",
+         "float8_weight_only": "fp8_w8a16_wo",
+         "int8": "int8_w8a8_dq",
+         "int8_weight_only": "int8_w8a16_wo",
+         "int4": "int4_w4a8_dq",
+         "int4_w4a4": "int4_w4a4_dq",
+         "int4_weight_only": "int4_w4a16_wo",
+     }
+     if quant_type.lower() in alias_map:
+         quant_type = alias_map[quant_type.lower()]
+
      quant_type = quant_type.lower()
      assert quant_type in (
          "fp8_w8a8_dq",
@@ -80,11 +91,11 @@

              return False

-     def _quantization_fn():
+     def _quant_config():
          try:
              if quant_type == "fp8_w8a8_dq":
                  from torchao.quantization import (
-                     float8_dynamic_activation_float8_weight,
+                     Float8DynamicActivationFloat8WeightConfig,
                      PerTensor,
                      PerRow,
                  )
@@ -92,7 +103,7 @@ def quantize_ao(
                  if per_row:  # Ensure bfloat16
                      module.to(torch.bfloat16)

-                 quantization_fn = float8_dynamic_activation_float8_weight(
+                 quant_config = Float8DynamicActivationFloat8WeightConfig(
                      weight_dtype=kwargs.get(
                          "weight_dtype",
                          torch.float8_e4m3fn,
@@ -109,9 +120,9 @@
                  )

              elif quant_type == "fp8_w8a16_wo":
-                 from torchao.quantization import float8_weight_only
+                 from torchao.quantization import Float8WeightOnlyConfig

-                 quantization_fn = float8_weight_only(
+                 quant_config = Float8WeightOnlyConfig(
                      weight_dtype=kwargs.get(
                          "weight_dtype",
                          torch.float8_e4m3fn,
@@ -120,39 +131,43 @@

              elif quant_type == "int8_w8a8_dq":
                  from torchao.quantization import (
-                     int8_dynamic_activation_int8_weight,
+                     Int8DynamicActivationInt8WeightConfig,
                  )

-                 quantization_fn = int8_dynamic_activation_int8_weight()
+                 quant_config = Int8DynamicActivationInt8WeightConfig()

              elif quant_type == "int8_w8a16_wo":
-                 from torchao.quantization import int8_weight_only

-                 quantization_fn = int8_weight_only(
+                 from torchao.quantization import Int8WeightOnlyConfig
+
+                 quant_config = Int8WeightOnlyConfig(
                      # group_size is None -> per_channel, else per group
                      group_size=kwargs.get("group_size", None),
                  )

              elif quant_type == "int4_w4a8_dq":
+
                  from torchao.quantization import (
-                     int8_dynamic_activation_int4_weight,
+                     Int8DynamicActivationInt4WeightConfig,
                  )

-                 quantization_fn = int8_dynamic_activation_int4_weight(
+                 quant_config = Int8DynamicActivationInt4WeightConfig(
                      group_size=kwargs.get("group_size", 32),
                  )

              elif quant_type == "int4_w4a4_dq":
+
                  from torchao.quantization import (
-                     int4_dynamic_activation_int4_weight,
+                     Int4DynamicActivationInt4WeightConfig,
                  )

-                 quantization_fn = int4_dynamic_activation_int4_weight()
+                 quant_config = Int4DynamicActivationInt4WeightConfig()

              elif quant_type == "int4_w4a16_wo":
-                 from torchao.quantization import int4_weight_only

-                 quantization_fn = int4_weight_only(
+                 from torchao.quantization import Int4WeightOnlyConfig
+
+                 quant_config = Int4WeightOnlyConfig(
                      group_size=kwargs.get("group_size", 32),
                  )

@@ -168,18 +183,22 @@
                  )
              raise e

-         return quantization_fn
+         return quant_config

      from torchao.quantization import quantize_

      quantize_(
          module,
-         _quantization_fn(),
+         _quant_config(),
          filter_fn=_filter_fn if filter_fn is None else filter_fn,
          device=kwargs.get("device", None),
      )

-     force_empty_cache()
+     maybe_empty_cache()
+
+     alias_map_rev = {v: k for k, v in alias_map.items()}
+     if quant_type in alias_map_rev:
+         quant_type = alias_map_rev[quant_type]

      logger.info(
          f"Quantized Module: {module.__class__.__name__:>5}\n"
@@ -193,12 +212,3 @@
      module._quantize_type = quant_type
      module._is_quantized = True
      return module
-
-
- def force_empty_cache():
-     time.sleep(1)
-     gc.collect()
-     torch.cuda.empty_cache()
-     time.sleep(1)
-     gc.collect()
-     torch.cuda.empty_cache()
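The alias map lets callers pass the friendlier quant-type names, and the reverse map restores the caller's spelling before it is recorded on the module. A minimal sketch (assumes a torchao version that provides the `*Config` classes and a CUDA device):

    import torch
    from cache_dit.quantize.backends.torchao import quantize_ao

    mlp = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 1024),
    ).cuda()

    # "float8_weight_only" resolves to the internal "fp8_w8a16_wo" via alias_map,
    # then is mapped back through alias_map_rev before being recorded.
    mlp = quantize_ao(mlp, quant_type="float8_weight_only")
    assert mlp._is_quantized and mlp._quantize_type == "float8_weight_only"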
cache_dit/quantize/quantize_backend.py: file without changes
cache_dit/quantize/quantize_config.py: file without changes
cache_dit/quantize/quantize_interface.py
@@ -7,37 +7,24 @@ logger = init_logger(__name__)

  def quantize(
      module: torch.nn.Module,
-     quant_type: str = "fp8_w8a8_dq",
+     quant_type: str = "float8_weight_only",
      backend: str = "ao",
      exclude_layers: List[str] = [
          "embedder",
          "embed",
      ],
      filter_fn: Optional[Callable] = None,
-     # only for fp8_w8a8_dq
-     per_row: bool = True,
      **kwargs,
  ) -> torch.nn.Module:
      assert isinstance(module, torch.nn.Module)

      if backend.lower() in ("ao", "torchao"):
-         from cache_dit.quantize.quantize_ao import quantize_ao
-
-         quant_type = quant_type.lower()
-         assert quant_type in (
-             "fp8_w8a8_dq",
-             "fp8_w8a16_wo",
-             "int8_w8a8_dq",
-             "int8_w8a16_wo",
-             "int4_w4a8_dq",
-             "int4_w4a4_dq",
-             "int4_w4a16_wo",
-         ), f"{quant_type} is not supported for torchao backend now!"
+         from cache_dit.quantize.backends.torchao import quantize_ao

          return quantize_ao(
              module,
              quant_type=quant_type,
-             per_row=per_row,
+             per_row=kwargs.pop("per_row", True),
              exclude_layers=exclude_layers,
              filter_fn=filter_fn,
              **kwargs,
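Callers of the public `quantize()` wrapper see two behavior changes: the default `quant_type` is now weight-only FP8, and `per_row` is no longer a named parameter but is popped from `**kwargs` with a default of True. A sketch (torchao and a CUDA device assumed):

    import torch
    from cache_dit.quantize import quantize

    mlp = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 1024),
    ).cuda()

    # New default: "float8_weight_only" via the torchao backend.
    mlp = quantize(mlp)

    # The previous default must now be requested explicitly; per_row travels
    # through **kwargs (FP8 dynamic-activation quantization with per-row scales).
    # mlp = quantize(mlp, quant_type="fp8_w8a8_dq", per_row=True)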