cache-dit 1.0.9__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +37 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +51 -3
- cache_dit/cache_factory/block_adapters/block_registers.py +41 -14
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +68 -30
- cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
- cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
- cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
- cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
- cache_dit/cache_factory/cache_interface.py +29 -3
- cache_dit/cache_factory/forward_pattern.py +14 -14
- cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -61
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
- cache_dit/parallelism/parallel_backend.py +2 -0
- cache_dit/parallelism/parallel_config.py +8 -1
- cache_dit/parallelism/parallel_interface.py +9 -4
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/utils.py +22 -2
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/METADATA +22 -13
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0

cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py
@@ -0,0 +1,78 @@
+import torch
+from torch import nn
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed._tensor import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from .tp_plan_registers import (
+    TensorParallelismPlanner,
+    TensorParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@TensorParallelismPlannerRegister.register("QwenImage")
+class QwenImageTensorParallelismPlanner(TensorParallelismPlanner):
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        parallelism_config: ParallelismConfig,
+        **kwargs,
+    ) -> torch.nn.Module:
+        assert (
+            parallelism_config.tp_size is not None
+            and parallelism_config.tp_size > 1
+        ), (
+            "parallel_config.tp_size must be set and greater than 1 for "
+            "tensor parallelism"
+        )
+
+        device_type = torch.accelerator.current_accelerator().type
+        tp_mesh: DeviceMesh = init_device_mesh(
+            device_type=device_type,
+            mesh_shape=[parallelism_config.tp_size],
+        )
+
+        transformer = self.parallelize_transformer(
+            transformer=transformer,
+            tp_mesh=tp_mesh,
+        )
+
+        return transformer
+
+    def parallelize_transformer(
+        self,
+        transformer: nn.Module,
+        tp_mesh: DeviceMesh,
+    ):
+        for _, block in transformer.transformer_blocks.named_children():
+            block.attn.heads //= tp_mesh.size()
+            layer_plan = {
+                "attn.to_q": ColwiseParallel(),
+                "attn.to_k": ColwiseParallel(),
+                "attn.to_v": ColwiseParallel(),
+                "attn.to_out.0": RowwiseParallel(),
+                "img_mod.1": ColwiseParallel(output_layouts=Replicate()),
+                "img_mlp.net.0.proj": ColwiseParallel(),
+                "img_mlp.net.2": RowwiseParallel(),
+                "attn.add_q_proj": ColwiseParallel(),
+                "attn.add_k_proj": ColwiseParallel(),
+                "attn.add_v_proj": ColwiseParallel(),
+                "attn.to_add_out": RowwiseParallel(),
+                "txt_mod.1": ColwiseParallel(output_layouts=Replicate()),
+                "txt_mlp.net.0.proj": ColwiseParallel(),
+                "txt_mlp.net.2": RowwiseParallel(),
+            }
+            parallelize_module(
+                module=block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+        return transformer
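
The plan above follows the usual Megatron-style pairing: the attention and MLP input projections (`to_q`/`to_k`/`to_v`, `add_*_proj`, `*.net.0.proj`) are column-sharded so each rank holds a slice of the heads and hidden features, while the matching output projections (`to_out.0`, `to_add_out`, `*.net.2`) are row-sharded so their partial results are reduced back to a replicated activation; `block.attn.heads //= tp_mesh.size()` keeps the attention reshape consistent with the locally held heads, and the modulation layers are column-sharded but forced back to `Replicate()` outputs. Below is a minimal sketch of the same Colwise → Rowwise pairing on a toy module; it is my own illustration, not cache-dit code, and the module name and sizes are invented.

```python
# Toy sketch (not cache-dit code). Launch with `torchrun --nproc-per-node=2 toy_tp.py`.
import os

import torch
from torch import nn
from torch.distributed import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim)    # column-sharded: each rank keeps a feature slice
        self.down = nn.Linear(4 * dim, dim)  # row-sharded: partial outputs are all-reduced

    def forward(self, x):
        return self.down(torch.nn.functional.gelu(self.up(x)))


if __name__ == "__main__":
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    tp_mesh = init_device_mesh("cpu", mesh_shape=[world_size])
    mlp = parallelize_module(
        ToyMLP(),
        device_mesh=tp_mesh,
        parallelize_plan={"up": ColwiseParallel(), "down": RowwiseParallel()},
    )
    out = mlp(torch.randn(2, 64))
    print(out.shape)  # torch.Size([2, 64]), replicated on every rank
```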

cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py
@@ -0,0 +1,58 @@
+import torch
+import logging
+from abc import abstractmethod
+from typing import Dict
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class TensorParallelismPlanner:
+    # TODO: add `apply_extra` abstract method for extra
+    # parallelism kwargs handling
+
+    @abstractmethod
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        parallelism_config: ParallelismConfig,
+        **kwargs,
+    ) -> torch.nn.Module:
+        raise NotImplementedError(
+            "apply method must be implemented by subclasses"
+        )
+
+
+class TensorParallelismPlannerRegister:
+    _tp_planner_registry: Dict[str, TensorParallelismPlanner] = {}
+
+    @classmethod
+    def register(cls, name: str):
+        def decorator(planner_cls: type[TensorParallelismPlanner]):
+            assert (
+                name not in cls._tp_planner_registry
+            ), f"TensorParallelismPlanner with name {name} is already registered."
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(f"Registering TensorParallelismPlanner: {name}")
+            cls._tp_planner_registry[name] = planner_cls
+            return planner_cls
+
+        return decorator
+
+    @classmethod
+    def get_planner(
+        cls, transformer: str | torch.nn.Module
+    ) -> type[TensorParallelismPlanner]:
+        if isinstance(transformer, torch.nn.Module):
+            name = transformer.__class__.__name__
+        else:
+            name = transformer
+        planner_cls = None
+        for planner_name in cls._tp_planner_registry:
+            if name.startswith(planner_name):
+                planner_cls = cls._tp_planner_registry.get(planner_name)
+                break
+        if planner_cls is None:
+            raise ValueError(f"No planner registered under name: {name}")
+        return planner_cls
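
The registry resolves planners by class-name prefix: `get_planner` accepts either a name or a module and matches `name.startswith(planner_name)`, so the `"QwenImage"` registration above covers any transformer class whose name begins with that prefix. A hypothetical example of plugging in a planner for a custom model; the `MyVideo*` class names below are invented for illustration:

```python
import torch

from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_pytorch.tensor_parallelism.tp_plan_registers import (
    TensorParallelismPlanner,
    TensorParallelismPlannerRegister,
)


@TensorParallelismPlannerRegister.register("MyVideo")
class MyVideoTensorParallelismPlanner(TensorParallelismPlanner):
    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        # Shard the blocks here, mirroring the QwenImage/Wan planners.
        return transformer


class MyVideoTransformer3DModel(torch.nn.Module):
    """Stand-in for a real diffusers transformer class."""


# Lookup matches on the class-name prefix:
planner_cls = TensorParallelismPlannerRegister.get_planner(MyVideoTransformer3DModel())
assert planner_cls is MyVideoTensorParallelismPlanner
```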

cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py
@@ -0,0 +1,153 @@
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+from .tp_plan_registers import (
+    TensorParallelismPlanner,
+    TensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+class DistributedRMSNorm(nn.Module):
+    def __init__(
+        self,
+        tp_mesh: DeviceMesh,
+        normalized_shape: Union[int, list[int], torch.Size],
+        eps: Optional[float],
+        elementwise_affine: bool,
+        weight: torch.nn.parameter.Parameter,
+    ):
+        super().__init__()
+        self.tp_mesh = tp_mesh
+        self.elementwise_affine = elementwise_affine
+        self.normalized_shape = normalized_shape
+        self.eps = eps
+        if self.elementwise_affine:
+            assert weight is not None
+            self.weight = weight
+
+    @classmethod
+    def from_rmsnorm(cls, tp_mesh: DeviceMesh, rmsnorm: nn.RMSNorm):
+        if not isinstance(rmsnorm, int):
+            assert len(rmsnorm.normalized_shape) == 1
+
+        if rmsnorm.weight is not None:
+            tp_size = tp_mesh.get_group().size()
+            tp_rank = tp_mesh.get_group().rank()
+            weight = rmsnorm.weight.chunk(tp_size, dim=0)[tp_rank]
+        else:
+            weight = None
+        norm = cls(
+            tp_mesh=tp_mesh,
+            normalized_shape=rmsnorm.normalized_shape,
+            eps=rmsnorm.eps,
+            elementwise_affine=rmsnorm.elementwise_affine,
+            weight=weight,
+        )
+        return norm
+
+    def forward(self, x):
+        if self.elementwise_affine:
+            assert x.shape[-1] == self.weight.shape[0]
+        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
+        torch.distributed.all_reduce(
+            mean_square,
+            op=torch.distributed.ReduceOp.AVG,
+            group=self.tp_mesh.get_group(),
+        )
+        root_mean_square = torch.sqrt(mean_square + self.eps)
+        x_normed = x / root_mean_square
+        if self.elementwise_affine:
+            x_normed = x_normed * self.weight.to(device=x.device)
+            assert x_normed.device.type != "cpu"
+        return x_normed
+
+
+@TensorParallelismPlannerRegister.register("Wan")
+class WanTensorParallelismPlanner(TensorParallelismPlanner):
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        parallelism_config: ParallelismConfig,
+        **kwargs,
+    ) -> torch.nn.Module:
+        assert (
+            parallelism_config.tp_size is not None
+            and parallelism_config.tp_size > 1
+        ), (
+            "parallel_config.tp_size must be set and greater than 1 for "
+            "tensor parallelism"
+        )
+
+        device_type = torch.accelerator.current_accelerator().type
+        tp_mesh: DeviceMesh = init_device_mesh(
+            device_type=device_type,
+            mesh_shape=[parallelism_config.tp_size],
+        )
+
+        transformer = self.parallelize_transformer(
+            transformer=transformer,
+            tp_mesh=tp_mesh,
+        )
+
+        return transformer
+
+    def parallelize_transformer(
+        self,
+        transformer: nn.Module,
+        tp_mesh: DeviceMesh,
+    ):
+        for _, block in transformer.blocks.named_children():
+            block.attn1.heads //= tp_mesh.size()
+            block.attn2.heads //= tp_mesh.size()
+            layer_plan = {
+                "attn1.to_q": ColwiseParallel(),
+                "attn1.to_k": ColwiseParallel(),
+                "attn1.to_v": ColwiseParallel(),
+                "attn1.to_out.0": RowwiseParallel(),
+                "attn2.to_q": ColwiseParallel(),
+                "attn2.to_k": ColwiseParallel(),
+                "attn2.to_v": ColwiseParallel(),
+                "attn2.to_out.0": RowwiseParallel(),
+                "ffn.net.0.proj": ColwiseParallel(),
+                "ffn.net.2": RowwiseParallel(),
+            }
+            if getattr(block.attn2, "add_k_proj", None):
+                layer_plan["attn2.add_k_proj"] = ColwiseParallel()
+            if getattr(block.attn2, "add_v_proj", None):
+                layer_plan["attn2.add_v_proj"] = ColwiseParallel()
+            parallelize_module(
+                module=block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+
+            block.attn1.norm_q = DistributedRMSNorm.from_rmsnorm(
+                tp_mesh, block.attn1.norm_q
+            )
+            block.attn1.norm_k = DistributedRMSNorm.from_rmsnorm(
+                tp_mesh, block.attn1.norm_k
+            )
+            block.attn2.norm_q = DistributedRMSNorm.from_rmsnorm(
+                tp_mesh, block.attn2.norm_q
+            )
+            block.attn2.norm_k = DistributedRMSNorm.from_rmsnorm(
+                tp_mesh, block.attn2.norm_k
+            )
+            if getattr(block.attn2, "norm_added_k", None):
+                block.attn2.norm_added_k = DistributedRMSNorm.from_rmsnorm(
+                    tp_mesh, block.attn2.norm_added_k
+                )
+        return transformer
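
`DistributedRMSNorm` is needed because Wan's q/k norms act across the full projected width (all heads), which the Colwise-sharded projections split across ranks: each rank only sees `1/tp_size` of the features, so the local mean of x² is averaged over the group (`ReduceOp.AVG`) to recover the full-width RMS denominator, and the affine weight is chunked to match the local shard. Since the shards are equal-sized, the average of per-shard means equals the global mean; the single-process sketch below (my own, not from the diff) checks exactly that equivalence:

```python
import torch

torch.manual_seed(0)
tp_size, dim, eps = 4, 128, 1e-6
x = torch.randn(2, 16, dim)

# Full (non-sharded) RMS denominator.
full_rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

# Per-shard mean squares over equal chunks, then the AVG "all-reduce".
shards = x.chunk(tp_size, dim=-1)
mean_sq = torch.stack([s.pow(2).mean(dim=-1, keepdim=True) for s in shards]).mean(0)
sharded_rms = torch.sqrt(mean_sq + eps)

assert torch.allclose(full_rms, sharded_rms, atol=1e-6)
print("sharded RMS matches full RMS")
```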

cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py
@@ -0,0 +1,12 @@
+# NOTE: must import all planner classes to register them
+from .tp_plan_registers import TensorParallelismPlannerRegister
+from .tp_plan_flux import FluxTensorParallelismPlanner
+from .tp_plan_qwen_image import QwenImageTensorParallelismPlanner
+from .tp_plan_wan import WanTensorParallelismPlanner
+
+__all__ = [
+    "TensorParallelismPlannerRegister",
+    "FluxTensorParallelismPlanner",
+    "QwenImageTensorParallelismPlanner",
+    "WanTensorParallelismPlanner",
+]

cache_dit/parallelism/parallel_backend.py CHANGED
@@ -8,6 +8,8 @@ class ParallelismBackend(Enum):
 
     @classmethod
     def is_supported(cls, backend: "ParallelismBackend") -> bool:
+        if backend in [cls.NATIVE_PYTORCH]:
+            return True
         # Now, only Native_Diffuser backend is supported
         if backend in [cls.NATIVE_DIFFUSER]:
             try:

cache_dit/parallelism/parallel_config.py CHANGED
@@ -34,7 +34,14 @@ class ParallelismConfig:
                 f"Parallel backend {self.backend} is not supported. "
                 f"Please make sure the required packages are installed."
             )
-
+
+        if self.tp_size is not None and self.tp_size > 1:
+            assert (
+                self.ulysses_size is None or self.ulysses_size == 1
+            ), "Tensor parallelism plus Ulysses parallelism is not supported right now."
+            assert (
+                self.ring_size is None or self.ring_size == 1
+            ), "Tensor parallelism plus Ring parallelism is not supported right now."
 
     def strify(self, details: bool = False) -> str:
         if details:

cache_dit/parallelism/parallel_interface.py CHANGED
@@ -24,12 +24,17 @@ def enable_parallelism(
     if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
         from cache_dit.parallelism.backends.native_diffusers import (
             maybe_enable_parallelism,
-            native_diffusers_parallelism_available,
         )
 
-
-
-
+        transformer = maybe_enable_parallelism(
+            transformer,
+            parallelism_config,
+        )
+    elif parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH:
+        from cache_dit.parallelism.backends.native_pytorch import (
+            maybe_enable_parallelism,
+        )
+
         transformer = maybe_enable_parallelism(
             transformer,
             parallelism_config,
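
Taken together with the `NATIVE_PYTORCH` additions to `ParallelismBackend` and `ParallelismConfig` above, the new tensor-parallel backend should be reachable through `enable_parallelism`. The sketch below is an assumption-laden illustration rather than documented usage: it assumes `ParallelismConfig` accepts `backend`/`tp_size` as keyword arguments, that `enable_parallelism` is importable from `cache_dit.parallelism.parallel_interface` with these parameter names, and that the script runs under `torchrun` so a process group exists.

```python
# Launch with e.g. `torchrun --nproc-per-node=2 run_flux_tp.py`.
import torch
from diffusers import FluxPipeline

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.parallel_interface import enable_parallelism

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_PYTORCH,  # new backend path in this release
    tp_size=2,  # must be > 1; ulysses_size/ring_size must stay unset (or 1)
)
# Parameter names assumed from the function body shown in the hunk above.
pipe.transformer = enable_parallelism(
    transformer=pipe.transformer,
    parallelism_config=config,
)
```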

cache_dit/quantize/backends/__init__.py
@@ -0,0 +1 @@
+from .torchao import quantize_ao

cache_dit/quantize/backends/bitsandbytes/__init__.py
File without changes

cache_dit/quantize/backends/torchao/__init__.py
@@ -0,0 +1 @@
+from .quantize_ao import quantize_ao

cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} CHANGED
@@ -9,7 +9,7 @@ logger = init_logger(__name__)
 
 def quantize_ao(
     module: torch.nn.Module,
-    quant_type: str = "
+    quant_type: str = "float8_weight_only",
     exclude_layers: List[str] = [
         "embedder",
         "embed",
@@ -24,6 +24,18 @@ def quantize_ao(
     # set `exclude_layers` as `[]` if you don't want this behavior.
     assert isinstance(module, torch.nn.Module)
 
+    alias_map = {
+        "float8": "fp8_w8a8_dq",
+        "float8_weight_only": "fp8_w8a16_wo",
+        "int8": "int8_w8a8_dq",
+        "int8_weight_only": "int8_w8a16_wo",
+        "int4": "int4_w4a8_dq",
+        "int4_w4a4": "int4_w4a4_dq",
+        "int4_weight_only": "int4_w4a16_wo",
+    }
+    if quant_type.lower() in alias_map:
+        quant_type = alias_map[quant_type.lower()]
+
     quant_type = quant_type.lower()
     assert quant_type in (
         "fp8_w8a8_dq",
@@ -183,7 +195,11 @@ def quantize_ao(
         device=kwargs.get("device", None),
     )
 
-
+    maybe_empty_cache()
+
+    alias_map_rev = {v: k for k, v in alias_map.items()}
+    if quant_type in alias_map_rev:
+        quant_type = alias_map_rev[quant_type]
 
     logger.info(
         f"Quantized Module: {module.__class__.__name__:>5}\n"
@@ -199,10 +215,13 @@ def quantize_ao(
     return module
 
 
-def
-
-
-
-
-
+def maybe_empty_cache():
+    try:
+        time.sleep(1)
+        gc.collect()
+        torch.cuda.empty_cache()
+        time.sleep(1)
+        gc.collect()
+        torch.cuda.empty_cache()
+    except Exception:
+        pass

cache_dit/quantize/quantize_backend.py
File without changes

cache_dit/quantize/quantize_config.py
File without changes

cache_dit/quantize/quantize_interface.py CHANGED
@@ -7,37 +7,24 @@ logger = init_logger(__name__)
 
 def quantize(
     module: torch.nn.Module,
-    quant_type: str = "
+    quant_type: str = "float8_weight_only",
     backend: str = "ao",
     exclude_layers: List[str] = [
         "embedder",
         "embed",
     ],
     filter_fn: Optional[Callable] = None,
-    # only for fp8_w8a8_dq
-    per_row: bool = True,
     **kwargs,
 ) -> torch.nn.Module:
     assert isinstance(module, torch.nn.Module)
 
     if backend.lower() in ("ao", "torchao"):
-        from cache_dit.quantize.
-
-        quant_type = quant_type.lower()
-        assert quant_type in (
-            "fp8_w8a8_dq",
-            "fp8_w8a16_wo",
-            "int8_w8a8_dq",
-            "int8_w8a16_wo",
-            "int4_w4a8_dq",
-            "int4_w4a4_dq",
-            "int4_w4a16_wo",
-        ), f"{quant_type} is not supported for torchao backend now!"
+        from cache_dit.quantize.backends.torchao import quantize_ao
 
         return quantize_ao(
             module,
             quant_type=quant_type,
-            per_row=per_row,
+            per_row=kwargs.pop("per_row", True),
            exclude_layers=exclude_layers,
            filter_fn=filter_fn,
            **kwargs,
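
With the validation moved into the torchao backend, the public `quantize()` entry point now accepts both the internal names (`fp8_w8a8_dq`, `int8_w8a16_wo`, ...) and the aliases added above (`float8`, `float8_weight_only`, `int8`, `int4_weight_only`, ...), and `per_row` is picked up from `**kwargs` instead of being a dedicated parameter. A usage sketch; the import path is the module shown in this diff, and whether cache_dit also re-exports `quantize` at a higher level is not shown here:

```python
import torch
from cache_dit.quantize.quantize_interface import quantize

# Toy module purely for illustration; real use targets a pipeline's transformer.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)).cuda()

# Alias and internal spellings now resolve to the same torchao config:
model = quantize(model, quant_type="float8_weight_only", backend="ao")
# equivalent to: quantize(model, quant_type="fp8_w8a16_wo")

# Dynamic-quant variants still honor per_row, now forwarded via **kwargs:
# model = quantize(model, quant_type="float8", per_row=True)
```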

cache_dit/utils.py CHANGED
@@ -13,6 +13,7 @@ from cache_dit.cache_factory import CacheType
 from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import BasicCacheConfig
 from cache_dit.cache_factory import CalibratorConfig
+from cache_dit.cache_factory import FakeDiffusionPipeline
 from cache_dit.parallelism import ParallelismConfig
 from cache_dit.logger import init_logger
 
@@ -64,6 +65,7 @@ def summary(
     adapter_or_others: Union[
         BlockAdapter,
         DiffusionPipeline,
+        FakeDiffusionPipeline,
         torch.nn.Module,
     ],
     details: bool = False,
@@ -73,9 +75,15 @@ def summary(
     if adapter_or_others is None:
         return [CacheStats()]
 
+    if isinstance(adapter_or_others, FakeDiffusionPipeline):
+        raise ValueError(
+            "Please pass DiffusionPipeline, BlockAdapter or transfomer, "
+            "not FakeDiffusionPipeline."
+        )
+
     if not isinstance(adapter_or_others, BlockAdapter):
         if not isinstance(adapter_or_others, DiffusionPipeline):
-            transformer = adapter_or_others
+            transformer = adapter_or_others  # transformer-only
             transformer_2 = None
         else:
             transformer = adapter_or_others.transformer
@@ -165,11 +173,18 @@ def strify(
     adapter_or_others: Union[
         BlockAdapter,
         DiffusionPipeline,
+        FakeDiffusionPipeline,
+        torch.nn.Module,
         CacheStats,
         List[CacheStats],
         Dict[str, Any],
     ],
 ) -> str:
+    if isinstance(adapter_or_others, FakeDiffusionPipeline):
+        raise ValueError(
+            "Please pass DiffusionPipeline, BlockAdapter or transfomer, "
+            "not FakeDiffusionPipeline."
+        )
 
     parallelism_config: ParallelismConfig = None
     if isinstance(adapter_or_others, BlockAdapter):
@@ -180,6 +195,10 @@ def strify(
         stats = summary(adapter_or_others, logging=False)[-1]
         cache_options = stats.cache_options
         cached_steps = len(stats.cached_steps)
+    elif isinstance(adapter_or_others, torch.nn.Module):
+        stats = summary(adapter_or_others, logging=False)[-1]
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
     elif isinstance(adapter_or_others, CacheStats):
         stats = adapter_or_others
         cache_options = stats.cache_options
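
`strify` now mirrors `summary`'s transformer-only path: a bare `torch.nn.Module` is summarized directly, while a `FakeDiffusionPipeline` is rejected with an explicit error. A sketch of that transformer-only flow; the top-level `cache_dit.enable_cache` / `cache_dit.summary` / `cache_dit.strify` API is assumed from the project README, not from this diff:

```python
import cache_dit
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
cache_dit.enable_cache(pipe)

# ... run inference with `pipe` ...

stats = cache_dit.summary(pipe.transformer)  # transformer-only path added in 1.0.10
print(cache_dit.strify(pipe.transformer))    # one-line cache tag for the transformer
```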
@@ -202,7 +221,8 @@ def strify(
     else:
         raise ValueError(
             "Please set pipe_or_stats param as one of: "
-            "DiffusionPipeline | CacheStats | Dict[str, Any]"
+            "DiffusionPipeline | CacheStats | Dict[str, Any] | List[CacheStats]"
+            " | BlockAdapter | Transformer"
         )
 
     if stats is not None: