cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py

@@ -0,0 +1,153 @@
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+from .tp_plan_registers import (
+    TensorParallelismPlanner,
+    TensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+class DistributedRMSNorm(nn.Module):
+    def __init__(
+        self,
+        tp_mesh: DeviceMesh,
+        normalized_shape: Union[int, list[int], torch.Size],
+        eps: Optional[float],
+        elementwise_affine: bool,
+        weight: torch.nn.parameter.Parameter,
+    ):
+        super().__init__()
+        self.tp_mesh = tp_mesh
+        self.elementwise_affine = elementwise_affine
+        self.normalized_shape = normalized_shape
+        self.eps = eps
+        if self.elementwise_affine:
+            assert weight is not None
+            self.weight = weight
+
+    @classmethod
+    def from_rmsnorm(cls, tp_mesh: DeviceMesh, rmsnorm: nn.RMSNorm):
+        if not isinstance(rmsnorm, int):
+            assert len(rmsnorm.normalized_shape) == 1
+
+        if rmsnorm.weight is not None:
+            tp_size = tp_mesh.get_group().size()
+            tp_rank = tp_mesh.get_group().rank()
+            weight = rmsnorm.weight.chunk(tp_size, dim=0)[tp_rank]
+        else:
+            weight = None
+        norm = cls(
+            tp_mesh=tp_mesh,
+            normalized_shape=rmsnorm.normalized_shape,
+            eps=rmsnorm.eps,
+            elementwise_affine=rmsnorm.elementwise_affine,
+            weight=weight,
+        )
+        return norm
+
+    def forward(self, x):
+        if self.elementwise_affine:
+            assert x.shape[-1] == self.weight.shape[0]
+        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
+        torch.distributed.all_reduce(
+            mean_square,
+            op=torch.distributed.ReduceOp.AVG,
+            group=self.tp_mesh.get_group(),
+        )
+        root_mean_square = torch.sqrt(mean_square + self.eps)
+        x_normed = x / root_mean_square
+        if self.elementwise_affine:
+            x_normed = x_normed * self.weight.to(device=x.device)
+            assert x_normed.device.type != "cpu"
+        return x_normed
+
+
+@TensorParallelismPlannerRegister.register("Wan")
+class WanTensorParallelismPlanner(TensorParallelismPlanner):
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        parallelism_config: ParallelismConfig,
+        **kwargs,
+    ) -> torch.nn.Module:
+        assert (
+            parallelism_config.tp_size is not None
+            and parallelism_config.tp_size > 1
+        ), (
+            "parallel_config.tp_size must be set and greater than 1 for "
+            "tensor parallelism"
+        )
+
+        device_type = torch.accelerator.current_accelerator().type
+        tp_mesh: DeviceMesh = init_device_mesh(
+            device_type=device_type,
+            mesh_shape=[parallelism_config.tp_size],
+        )
+
+        transformer = self.parallelize_transformer(
+            transformer=transformer,
+            tp_mesh=tp_mesh,
+        )
+
+        return transformer
+
+    def parallelize_transformer(
+        self,
+        transformer: nn.Module,
+        tp_mesh: DeviceMesh,
+    ):
+        for _, block in transformer.blocks.named_children():
+            block.attn1.heads //= tp_mesh.size()
+            block.attn2.heads //= tp_mesh.size()
+            layer_plan = {
+                "attn1.to_q": ColwiseParallel(),
+                "attn1.to_k": ColwiseParallel(),
+                "attn1.to_v": ColwiseParallel(),
+                "attn1.to_out.0": RowwiseParallel(),
+                "attn2.to_q": ColwiseParallel(),
+                "attn2.to_k": ColwiseParallel(),
+                "attn2.to_v": ColwiseParallel(),
+                "attn2.to_out.0": RowwiseParallel(),
+                "ffn.net.0.proj": ColwiseParallel(),
+                "ffn.net.2": RowwiseParallel(),
+            }
+            if getattr(block.attn2, "add_k_proj", None):
+                layer_plan["attn2.add_k_proj"] = ColwiseParallel()
+            if getattr(block.attn2, "add_v_proj", None):
+                layer_plan["attn2.add_v_proj"] = ColwiseParallel()
+            parallelize_module(
+                module=block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+
+            block.attn1.norm_q = DistributedRMSNorm.from_rmsnorm(
+                tp_mesh, block.attn1.norm_q
+            )
+            block.attn1.norm_k = DistributedRMSNorm.from_rmsnorm(
+                tp_mesh, block.attn1.norm_k
+            )
+            block.attn2.norm_q = DistributedRMSNorm.from_rmsnorm(
+                tp_mesh, block.attn2.norm_q
+            )
+            block.attn2.norm_k = DistributedRMSNorm.from_rmsnorm(
+                tp_mesh, block.attn2.norm_k
+            )
+            if getattr(block.attn2, "norm_added_k", None):
+                block.attn2.norm_added_k = DistributedRMSNorm.from_rmsnorm(
+                    tp_mesh, block.attn2.norm_added_k
+                )
+        return transformer
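Note on the sharded RMSNorm above: each rank holds a column-wise shard of the head dimension, computes the mean square of its local shard, and averages across the TP group with an `all_reduce(AVG)`; because every shard has the same width, that average equals the mean square of the full, unsharded activation. A minimal single-process sketch of this identity (shard count and tensor sizes are illustrative; no `torch.distributed` involved):

import torch

# Illustrative sizes: hidden dim 128 split evenly across a TP group of 4.
hidden, tp_size = 128, 4
x = torch.randn(2, 16, hidden)

# Mean square over the full hidden dim, as a non-parallel RMSNorm sees it.
full_ms = torch.mean(x * x, dim=-1, keepdim=True)

# Per-shard mean squares, then averaged: this is what the AVG all_reduce in
# DistributedRMSNorm.forward reconstructs across ranks.
shard_ms = [torch.mean(s * s, dim=-1, keepdim=True) for s in x.chunk(tp_size, dim=-1)]
avg_ms = torch.stack(shard_ms).mean(dim=0)

assert torch.allclose(full_ms, avg_ms, atol=1e-6)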
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py

@@ -0,0 +1,14 @@
+# NOTE: must import all planner classes to register them
+from .tp_plan_flux import FluxTensorParallelismPlanner
+from .tp_plan_kandinsky5 import Kandinsky5TensorParallelismPlanner
+from .tp_plan_qwen_image import QwenImageTensorParallelismPlanner
+from .tp_plan_registers import TensorParallelismPlannerRegister
+from .tp_plan_wan import WanTensorParallelismPlanner
+
+__all__ = [
+    "FluxTensorParallelismPlanner",
+    "Kandinsky5TensorParallelismPlanner",
+    "QwenImageTensorParallelismPlanner",
+    "TensorParallelismPlannerRegister",
+    "WanTensorParallelismPlanner",
+]
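These imports exist purely for their registration side effect: each planner module decorates its class with `TensorParallelismPlannerRegister.register(...)` so it can later be looked up by model name. `tp_plan_registers.py` itself is not among the hunks shown here, so the following is only a rough sketch of how such a decorator-based registry is commonly structured; the class bodies and the `get` helper are assumptions, not the package's actual implementation:

from typing import Dict, Type


class TensorParallelismPlanner:  # assumed base class
    def apply(self, transformer, parallelism_config, **kwargs):
        raise NotImplementedError


class TensorParallelismPlannerRegister:
    _planners: Dict[str, Type[TensorParallelismPlanner]] = {}

    @classmethod
    def register(cls, name: str):
        # Decorator: store the planner class under `name` (e.g. "Wan") so it
        # can later be resolved from the transformer's class name.
        def _wrap(planner_cls):
            cls._planners[name] = planner_cls
            return planner_cls

        return _wrap

    @classmethod
    def get(cls, name: str):
        return cls._planners.get(name)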
cache_dit/parallelism/parallel_backend.py

@@ -0,0 +1,26 @@
+from enum import Enum
+
+
+class ParallelismBackend(Enum):
+    NATIVE_DIFFUSER = "Native_Diffuser"
+    NATIVE_PYTORCH = "Native_PyTorch"
+    NONE = "None"
+
+    @classmethod
+    def is_supported(cls, backend: "ParallelismBackend") -> bool:
+        if backend == cls.NATIVE_PYTORCH:
+            return True
+        elif backend == cls.NATIVE_DIFFUSER:
+            try:
+                from diffusers.models._modeling_parallel import (  # noqa F401
+                    ContextParallelModelPlan,
+                )
+            except ImportError:
+                raise ImportError(
+                    "NATIVE_DIFFUSER parallelism backend requires the latest "
+                    "version of diffusers(>=0.36.dev0). Please install latest "
+                    "version of diffusers from source: \npip3 install "
+                    "git+https://github.com/huggingface/diffusers.git"
+                )
+            return True
+        return False
cache_dit/parallelism/parallel_config.py

@@ -0,0 +1,88 @@
+import dataclasses
+from typing import Optional, Dict, Any
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class ParallelismConfig:
+    # Parallelism backend, defaults to NATIVE_DIFFUSER
+    backend: ParallelismBackend = ParallelismBackend.NATIVE_DIFFUSER
+    # Context parallelism config
+    # ulysses_size (`int`, *optional*):
+    #     The degree of ulysses parallelism.
+    ulysses_size: int = None
+    # ring_size (`int`, *optional*):
+    #     The degree of ring parallelism.
+    ring_size: int = None
+    # Tensor parallelism config
+    # tp_size (`int`, *optional*):
+    #     The degree of tensor parallelism.
+    tp_size: int = None
+    # parallel_kwargs (`dict`, *optional*):
+    #     Additional kwargs for parallelism backends. For example, for
+    #     NATIVE_DIFFUSER backend, it can include `cp_plan` and
+    #     `attention_backend` arguments for `Context Parallelism`.
+    parallel_kwargs: Optional[Dict[str, Any]] = dataclasses.field(
+        default_factory=dict
+    )
+
+    def __post_init__(self):
+        assert ParallelismBackend.is_supported(self.backend), (
+            f"Parallel backend {self.backend} is not supported. "
+            f"Please make sure the required packages are installed."
+        )
+
+        # Validate the parallelism configuration and auto adjust the backend if needed
+        if self.tp_size is not None and self.tp_size > 1:
+            assert (
+                self.ulysses_size is None or self.ulysses_size == 1
+            ), "Tensor parallelism plus Ulysses parallelism is not supported right now."
+            assert (
+                self.ring_size is None or self.ring_size == 1
+            ), "Tensor parallelism plus Ring parallelism is not supported right now."
+            if self.backend != ParallelismBackend.NATIVE_PYTORCH:
+                logger.warning(
+                    "Tensor parallelism is only supported for NATIVE_PYTORCH backend "
+                    "right now. Force set backend to NATIVE_PYTORCH."
+                )
+                self.backend = ParallelismBackend.NATIVE_PYTORCH
+        elif (
+            self.ulysses_size is not None
+            and self.ulysses_size > 1
+            and self.ring_size is not None
+            and self.ring_size > 1
+        ):
+            raise ValueError(
+                "Ulysses parallelism plus Ring parallelism is not fully supported right now."
+            )
+        else:
+            if (self.ulysses_size is not None and self.ulysses_size > 1) or (
+                self.ring_size is not None and self.ring_size > 1
+            ):
+                if self.backend != ParallelismBackend.NATIVE_DIFFUSER:
+                    logger.warning(
+                        "Ulysses/Ring parallelism is only supported for NATIVE_DIFFUSER "
+                        "backend right now. Force set backend to NATIVE_DIFFUSER."
+                    )
+                    self.backend = ParallelismBackend.NATIVE_DIFFUSER
+
+    def strify(self, details: bool = False) -> str:
+        if details:
+            return (
+                f"ParallelismConfig(backend={self.backend}, "
+                f"ulysses_size={self.ulysses_size}, "
+                f"ring_size={self.ring_size}, "
+                f"tp_size={self.tp_size})"
+            )
+        else:
+            parallel_str = ""
+            if self.ulysses_size is not None:
+                parallel_str += f"Ulysses{self.ulysses_size}"
+            if self.ring_size is not None:
+                parallel_str += f"Ring{self.ring_size}"
+            if self.tp_size is not None:
+                parallel_str += f"TP{self.tp_size}"
+            return parallel_str
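In practice `__post_init__` both validates and silently normalizes the backend: requesting tensor parallelism under the default NATIVE_DIFFUSER backend is rerouted to NATIVE_PYTORCH with a warning, while Ulysses/Ring sizes keep (or force) NATIVE_DIFFUSER. A small usage sketch, assuming the dependencies required by the chosen backend are installed:

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig

# Context parallelism stays on the NATIVE_DIFFUSER backend.
cp_cfg = ParallelismConfig(ulysses_size=2)
print(cp_cfg.strify())        # "Ulysses2"

# Tensor parallelism: __post_init__ warns and forces NATIVE_PYTORCH.
tp_cfg = ParallelismConfig(tp_size=2)
print(tp_cfg.backend)         # ParallelismBackend.NATIVE_PYTORCH
print(tp_cfg.strify(details=True))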
cache_dit/parallelism/parallel_interface.py

@@ -0,0 +1,77 @@
+import torch
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def enable_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: ParallelismConfig,
+) -> torch.nn.Module:
+    assert isinstance(transformer, torch.nn.Module), (
+        "transformer must be an instance of torch.nn.Module, "
+        f"but got {type(transformer)}"
+    )
+    if getattr(transformer, "_is_parallelized", False):
+        logger.warning(
+            "The transformer is already parallelized. "
+            "Skipping parallelism enabling."
+        )
+        return transformer
+
+    if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
+        from cache_dit.parallelism.backends.native_diffusers import (
+            maybe_enable_parallelism,
+        )
+
+        transformer = maybe_enable_parallelism(
+            transformer,
+            parallelism_config,
+        )
+    elif parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH:
+        from cache_dit.parallelism.backends.native_pytorch import (
+            maybe_enable_parallelism,
+        )
+
+        transformer = maybe_enable_parallelism(
+            transformer,
+            parallelism_config,
+        )
+    else:
+        raise ValueError(
+            f"Parallel backend {parallelism_config.backend} is not supported yet."
+        )
+
+    transformer._is_parallelized = True  # type: ignore[attr-defined]
+    # Use `parallelism` not `parallel` to avoid name conflict with diffusers.
+    transformer._parallelism_config = parallelism_config  # type: ignore[attr-defined]
+    logger.info(
+        f"Enabled parallelism: {parallelism_config.strify(True)}, "
+        f"transformer id:{id(transformer)}"
+    )
+
+    # NOTE: Workaround for potential memory peak issue after parallelism
+    # enabling, specially for tensor parallelism in native pytorch backend.
+    maybe_empty_cache()
+
+    return transformer
+
+
+def remove_parallelism_stats(
+    transformer: torch.nn.Module,
+) -> torch.nn.Module:
+    if not getattr(transformer, "_is_parallelized", False):
+        logger.warning(
+            "The transformer is not parallelized. "
+            "Skipping removing parallelism."
+        )
+        return transformer
+
+    if hasattr(transformer, "_is_parallelized"):
+        del transformer._is_parallelized  # type: ignore[attr-defined]
+    if hasattr(transformer, "_parallelism_config"):
+        del transformer._parallelism_config  # type: ignore[attr-defined]
+    return transformer
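`enable_parallelism` is the user-facing entry point: it dispatches on `parallelism_config.backend`, tags the module with `_is_parallelized` / `_parallelism_config`, and empties the cache afterwards. A hedged end-to-end sketch, assuming a Wan-style diffusers pipeline launched with `torchrun --nproc_per_node=2` (the model id and per-rank device handling are illustrative and simplified):

import torch
from diffusers import WanPipeline  # illustrative model choice

from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.parallel_interface import enable_parallelism

# Each rank loads the pipeline; init_device_mesh inside the planner builds
# the TP group from the torchrun environment.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# tp_size > 1 forces the NATIVE_PYTORCH backend (see ParallelismConfig above),
# which looks up the WanTensorParallelismPlanner registered earlier.
pipe.transformer = enable_parallelism(
    pipe.transformer, ParallelismConfig(tp_size=2)
)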
cache_dit/quantize/__init__.py CHANGED

@@ -0,0 +1 @@
+from .torchao import quantize_ao

File without changes

@@ -0,0 +1 @@
+from .quantize_ao import quantize_ao
cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} CHANGED

@@ -1,7 +1,6 @@
-import gc
-import time
 import torch
 from typing import Callable, Optional, List
+from cache_dit.utils import maybe_empty_cache
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)

@@ -9,7 +8,7 @@ logger = init_logger(__name__)
 
 def quantize_ao(
     module: torch.nn.Module,
-    quant_type: str = "
+    quant_type: str = "float8_weight_only",
     exclude_layers: List[str] = [
         "embedder",
         "embed",

@@ -24,6 +23,18 @@ def quantize_ao(
     # set `exclude_layers` as `[]` if you don't want this behavior.
     assert isinstance(module, torch.nn.Module)
 
+    alias_map = {
+        "float8": "fp8_w8a8_dq",
+        "float8_weight_only": "fp8_w8a16_wo",
+        "int8": "int8_w8a8_dq",
+        "int8_weight_only": "int8_w8a16_wo",
+        "int4": "int4_w4a8_dq",
+        "int4_w4a4": "int4_w4a4_dq",
+        "int4_weight_only": "int4_w4a16_wo",
+    }
+    if quant_type.lower() in alias_map:
+        quant_type = alias_map[quant_type.lower()]
+
     quant_type = quant_type.lower()
     assert quant_type in (
         "fp8_w8a8_dq",

@@ -80,11 +91,11 @@ def quantize_ao(
 
         return False
 
-    def
+    def _quant_config():
         try:
             if quant_type == "fp8_w8a8_dq":
                 from torchao.quantization import (
-
+                    Float8DynamicActivationFloat8WeightConfig,
                     PerTensor,
                     PerRow,
                 )

@@ -92,7 +103,7 @@ def quantize_ao(
                 if per_row:  # Ensure bfloat16
                     module.to(torch.bfloat16)
 
-
+                quant_config = Float8DynamicActivationFloat8WeightConfig(
                     weight_dtype=kwargs.get(
                         "weight_dtype",
                         torch.float8_e4m3fn,

@@ -109,9 +120,9 @@ def quantize_ao(
                 )
 
             elif quant_type == "fp8_w8a16_wo":
-                from torchao.quantization import
+                from torchao.quantization import Float8WeightOnlyConfig
 
-
+                quant_config = Float8WeightOnlyConfig(
                     weight_dtype=kwargs.get(
                         "weight_dtype",
                         torch.float8_e4m3fn,

@@ -120,39 +131,43 @@ def quantize_ao(
 
             elif quant_type == "int8_w8a8_dq":
                 from torchao.quantization import (
-
+                    Int8DynamicActivationInt8WeightConfig,
                 )
 
-
+                quant_config = Int8DynamicActivationInt8WeightConfig()
 
             elif quant_type == "int8_w8a16_wo":
-                from torchao.quantization import int8_weight_only
 
-
+                from torchao.quantization import Int8WeightOnlyConfig
+
+                quant_config = Int8WeightOnlyConfig(
                     # group_size is None -> per_channel, else per group
                     group_size=kwargs.get("group_size", None),
                 )
 
             elif quant_type == "int4_w4a8_dq":
+
                 from torchao.quantization import (
-
+                    Int8DynamicActivationInt4WeightConfig,
                 )
 
-
+                quant_config = Int8DynamicActivationInt4WeightConfig(
                     group_size=kwargs.get("group_size", 32),
                 )
 
             elif quant_type == "int4_w4a4_dq":
+
                 from torchao.quantization import (
-
+                    Int4DynamicActivationInt4WeightConfig,
                 )
 
-
+                quant_config = Int4DynamicActivationInt4WeightConfig()
 
             elif quant_type == "int4_w4a16_wo":
-                from torchao.quantization import int4_weight_only
 
-
+                from torchao.quantization import Int4WeightOnlyConfig
+
+                quant_config = Int4WeightOnlyConfig(
                     group_size=kwargs.get("group_size", 32),
                 )
 

@@ -168,33 +183,32 @@ def quantize_ao(
            )
            raise e
 
-        return
+        return quant_config
 
    from torchao.quantization import quantize_
 
    quantize_(
        module,
-
+        _quant_config(),
        filter_fn=_filter_fn if filter_fn is None else filter_fn,
        device=kwargs.get("device", None),
    )
 
-
+    maybe_empty_cache()
+
+    alias_map_rev = {v: k for k, v in alias_map.items()}
+    if quant_type in alias_map_rev:
+        quant_type = alias_map_rev[quant_type]
 
    logger.info(
+        f"Quantized Module: {module.__class__.__name__:>5}\n"
        f"Quantized Method: {quant_type:>5}\n"
        f"Quantized Linear Layers: {num_quant_linear:>5}\n"
        f"Skipped Linear Layers: {num_skip_linear:>5}\n"
        f"Total Linear Layers: {num_linear_layers:>5}\n"
        f"Total (all) Layers: {num_layers:>5}"
    )
-    return module
 
-
-
-
-    gc.collect()
-    torch.cuda.empty_cache()
-    time.sleep(1)
-    gc.collect()
-    torch.cuda.empty_cache()
+    module._quantize_type = quant_type
+    module._is_quantized = True
+    return module
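With the new `alias_map`, callers can pass the friendlier torchao-style names; the resolved internal type is mapped back to the alias for the summary log and the new `_quantize_type` / `_is_quantized` attributes. A short hedged sketch on a toy module (assumes a recent torchao is installed):

import torch
from cache_dit.quantize.backends.torchao import quantize_ao

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64)
)

# "int8_weight_only" resolves to "int8_w8a16_wo" internally, then is mapped
# back for logging and for the `_quantize_type` attribute.
model = quantize_ao(model, quant_type="int8_weight_only")
print(model._is_quantized, model._quantize_type)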
File without changes

File without changes
cache_dit/quantize/quantize_interface.py CHANGED

@@ -7,37 +7,24 @@ logger = init_logger(__name__)
 
 def quantize(
     module: torch.nn.Module,
-    quant_type: str = "
+    quant_type: str = "float8_weight_only",
     backend: str = "ao",
     exclude_layers: List[str] = [
         "embedder",
         "embed",
     ],
     filter_fn: Optional[Callable] = None,
-    # only for fp8_w8a8_dq
-    per_row: bool = True,
     **kwargs,
 ) -> torch.nn.Module:
     assert isinstance(module, torch.nn.Module)
 
     if backend.lower() in ("ao", "torchao"):
-        from cache_dit.quantize.
-
-        quant_type = quant_type.lower()
-        assert quant_type in (
-            "fp8_w8a8_dq",
-            "fp8_w8a16_wo",
-            "int8_w8a8_dq",
-            "int8_w8a16_wo",
-            "int4_w4a8_dq",
-            "int4_w4a4_dq",
-            "int4_w4a16_wo",
-        ), f"{quant_type} is not supported for torchao backend now!"
+        from cache_dit.quantize.backends.torchao import quantize_ao
 
         return quantize_ao(
             module,
             quant_type=quant_type,
-            per_row=per_row,
+            per_row=kwargs.pop("per_row", True),
             exclude_layers=exclude_layers,
             filter_fn=filter_fn,
             **kwargs,