cache-dit 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. See the package's registry page for more details.

Files changed (45)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/block_adapters/__init__.py +37 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +75 -4
  5. cache_dit/cache_factory/block_adapters/block_registers.py +44 -17
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +72 -30
  7. cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
  8. cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
  9. cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
  10. cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
  11. cache_dit/cache_factory/cache_interface.py +102 -28
  12. cache_dit/cache_factory/forward_pattern.py +14 -14
  13. cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
  14. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
  15. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
  16. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
  17. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -49
  18. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  19. cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
  20. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  21. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  22. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
  23. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  24. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
  25. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  26. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
  27. cache_dit/parallelism/parallel_backend.py +2 -0
  28. cache_dit/parallelism/parallel_config.py +10 -3
  29. cache_dit/parallelism/parallel_interface.py +14 -5
  30. cache_dit/quantize/backends/__init__.py +1 -0
  31. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  32. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  33. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
  34. cache_dit/quantize/quantize_backend.py +0 -0
  35. cache_dit/quantize/quantize_config.py +0 -0
  36. cache_dit/quantize/quantize_interface.py +3 -16
  37. cache_dit/utils.py +56 -20
  38. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/METADATA +24 -13
  39. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
  40. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  41. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  42. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
  43. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
  44. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
  45. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,48 @@
1
try:
    import einops  # noqa: F401  (imported only to verify availability)
except ImportError as err:
    # The tensor-parallel planners (e.g. the Flux planner) use
    # `einops.rearrange`, which ships with the optional 'parallelism' extra.
    raise ImportError(
        "Tensor parallelism functionality requires the 'parallelism' extra "
        "dependencies. Install with:\npip install cache-dit[parallelism]"
    ) from err
8
+
9
+ import torch
10
+ from typing import Optional
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
13
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
14
+ from cache_dit.logger import init_logger
15
+ from .tp_planners import *
16
+
17
+ logger = init_logger(__name__)
18
+
19
+
20
def maybe_enable_tensor_parallelism(
    transformer: torch.nn.Module | ModelMixin,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Apply the registered tensor-parallelism planner to ``transformer``.

    Args:
        transformer: A diffusers ``ModelMixin`` (which is also a
            ``torch.nn.Module``) to parallelize.
        parallelism_config: Parallelism settings. When ``None`` the
            transformer is returned untouched; otherwise its ``backend`` must
            be ``ParallelismBackend.NATIVE_PYTORCH``. Any ``parallel_kwargs``
            are forwarded to the planner's ``apply`` as keyword arguments.

    Returns:
        Whatever the matching planner's ``apply`` returns.

    Raises:
        AssertionError: If ``transformer`` is not a ``torch.nn.Module`` /
            ``ModelMixin`` or the backend is not NATIVE_PYTORCH.
        ValueError: From ``TensorParallelismPlannerRegister.get_planner`` when
            no planner is registered for the transformer's class name.
    """
    # Type checks run before the None short-circuit, exactly as before.
    for expected_type, label in (
        (torch.nn.Module, "torch.nn.Module"),
        (ModelMixin, "diffusers' ModelMixin"),
    ):
        assert isinstance(transformer, expected_type), (
            f"transformer must be an instance of {label}, "
            f"but got {type(transformer)}"
        )

    if parallelism_config is None:
        return transformer

    backend = parallelism_config.backend
    assert backend == ParallelismBackend.NATIVE_PYTORCH, (
        "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
        f"but got {backend}"
    )

    forwarded_kwargs = parallelism_config.parallel_kwargs
    if forwarded_kwargs is None:
        forwarded_kwargs = {}

    planner_cls = TensorParallelismPlannerRegister.get_planner(transformer)
    planner = planner_cls()
    return planner.apply(
        transformer=transformer,
        parallelism_config=parallelism_config,
        **forwarded_kwargs,
    )
@@ -0,0 +1,159 @@
1
+ import torch
2
+ from diffusers.models.transformers.transformer_flux import (
3
+ FluxSingleTransformerBlock,
4
+ )
5
+ from einops import rearrange
6
+ from torch import nn
7
+ from torch.distributed import DeviceMesh, init_device_mesh
8
+ from torch.distributed._tensor import Replicate
9
+ from torch.distributed.tensor.parallel import (
10
+ ColwiseParallel,
11
+ RowwiseParallel,
12
+ parallelize_module,
13
+ )
14
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
15
+ from .tp_plan_registers import (
16
+ TensorParallelismPlanner,
17
+ TensorParallelismPlannerRegister,
18
+ )
19
+
20
+ from cache_dit.logger import init_logger
21
+
22
+ logger = init_logger(__name__)
23
+
24
+
25
@TensorParallelismPlannerRegister.register("Flux")
class FluxTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallelism planner for Flux transformers.

    Registered under the name ``"Flux"``; the register matches it against
    transformer class names by prefix. Both the dual-stream
    ``transformer_blocks`` and the single-stream
    ``single_transformer_blocks`` are sharded with DTensor
    ``ColwiseParallel`` / ``RowwiseParallel`` styles over a 1-D device mesh.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Create a 1-D TP mesh of ``tp_size`` devices and shard ``transformer``.

        Args:
            transformer: The Flux transformer to parallelize.
            parallelism_config: Must have ``tp_size`` set and greater than 1.
            **kwargs: Extra parallel kwargs; currently unused.

        Returns:
            The transformer returned by :meth:`parallelize_transformer`.
        """
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        device_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        transformer = self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )
        # TODO: Parallelize t5 text encoder via `apply_extra`
        # abstract method and `extra_parallel_kwargs` ?

        return transformer

    def parallelize_t5(
        self,
        text_encoder: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard a T5 encoder's self-attention and feed-forward projections.

        Not called by :meth:`apply` yet (see the TODO there). Head count and
        inner dim of every self-attention layer are divided by the mesh size
        so each rank operates on its local slice of heads.
        """
        for i, block in enumerate(text_encoder.encoder.block):
            block.layer[0].SelfAttention.n_heads //= tp_mesh.size()
            block.layer[0].SelfAttention.inner_dim //= tp_mesh.size()
            layer_plan = {
                "layer.0.SelfAttention.q": ColwiseParallel(),
                "layer.0.SelfAttention.k": ColwiseParallel(),
                "layer.0.SelfAttention.v": ColwiseParallel(),
                "layer.0.SelfAttention.o": RowwiseParallel(),
                "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
                "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
                "layer.1.DenseReluDense.wo": RowwiseParallel(),
            }
            if i == 0:
                # Presumably only the first T5 block owns the relative
                # attention bias table, hence the special case.
                layer_plan["layer.0.SelfAttention.relative_attention_bias"] = (
                    ColwiseParallel()
                )
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

        return text_encoder

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard all dual- and single-stream Flux blocks over ``tp_mesh``.

        Attention head counts are divided by the mesh size so each rank runs
        attention over its local slice of heads.
        """
        for _, block in transformer.transformer_blocks.named_children():
            block.attn.heads //= tp_mesh.size()
            layer_plan = {
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "attn.to_out.0": RowwiseParallel(),
                "norm1.linear": ColwiseParallel(output_layouts=Replicate()),
                "ff.net.0.proj": ColwiseParallel(),
                "ff.net.2": RowwiseParallel(),
                "attn.add_q_proj": ColwiseParallel(),
                "attn.add_k_proj": ColwiseParallel(),
                "attn.add_v_proj": ColwiseParallel(),
                "attn.to_add_out": RowwiseParallel(),
                "norm1_context.linear": ColwiseParallel(
                    output_layouts=Replicate()
                ),
                "ff_context.net.0.proj": ColwiseParallel(),
                "ff_context.net.2": RowwiseParallel(),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

        # NOTE: special handling for FluxSingleTransformerBlock, we have to
        # rearrange the proj_out weight because it contains both out and down
        # projection weights in a single matrix.
        def rearrange_proj_out_weight(
            single_block: FluxSingleTransformerBlock, tp_group_size
        ):
            """Reorder proj_out's input rows to match per-rank [attn, mlp] inputs.

            proj_out's input dimension is the attention-output part (first
            ``hidden_dim`` features) followed by the MLP part (the remaining
            ``proj_mlp.out_features`` features). RowwiseParallel splits that
            input dimension into ``tp_group_size`` contiguous chunks, so the
            rows are regrouped as [attn_0, mlp_0, attn_1, mlp_1, ...].
            """
            # rowwise
            requires_grad = single_block.proj_out.weight.requires_grad
            # Transpose to (in_features, out_features) so rows index inputs.
            linear2_weight_data = (
                single_block.proj_out.weight.data.T.detach().clone()
            )
            # Derive the split point from the modules instead of the former
            # hard-coded 3072, so Flux variants of other widths work too.
            mlp_dim = single_block.proj_mlp.out_features
            hidden_dim = linear2_weight_data.shape[0] - mlp_dim

            out_weight = rearrange(
                linear2_weight_data[:hidden_dim, ...],
                "(G D) C -> G D C",
                G=tp_group_size,
            )
            down_weight = rearrange(
                linear2_weight_data[hidden_dim:, ...],
                "(G D) C -> G D C",
                G=tp_group_size,
            )
            new_linear2_weight = torch.cat([out_weight, down_weight], dim=1)
            new_linear2_weight = rearrange(
                new_linear2_weight, "G D C -> (G D) C"
            )
            single_block.proj_out.weight.data.copy_(new_linear2_weight.T)
            single_block.proj_out.weight.requires_grad_(requires_grad)

        for _, block in transformer.single_transformer_blocks.named_children():
            # Must run before parallelize_module shards proj_out.
            rearrange_proj_out_weight(block, tp_mesh.size())
            block.attn.heads //= tp_mesh.size()
            layer_plan = {
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "proj_mlp": ColwiseParallel(),
                "proj_out": RowwiseParallel(),
                "norm.linear": ColwiseParallel(output_layouts=Replicate()),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )
        return transformer
@@ -0,0 +1,78 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.distributed import DeviceMesh, init_device_mesh
4
+ from torch.distributed._tensor import Replicate
5
+ from torch.distributed.tensor.parallel import (
6
+ ColwiseParallel,
7
+ RowwiseParallel,
8
+ parallelize_module,
9
+ )
10
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
11
+ from .tp_plan_registers import (
12
+ TensorParallelismPlanner,
13
+ TensorParallelismPlannerRegister,
14
+ )
15
+
16
+ from cache_dit.logger import init_logger
17
+
18
+ logger = init_logger(__name__)
19
+
20
+
21
@TensorParallelismPlannerRegister.register("QwenImage")
class QwenImageTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallelism planner for QwenImage transformers.

    Selected through the ``"QwenImage"`` class-name prefix. Only the
    dual-stream ``transformer_blocks`` are sharded.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Build a 1-D TP mesh of ``tp_size`` devices and shard ``transformer``.

        ``kwargs`` is accepted for interface compatibility and ignored.
        """
        tp_size = parallelism_config.tp_size
        assert tp_size is not None and tp_size > 1, (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        accelerator = torch.accelerator.current_accelerator()
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=accelerator.type,
            mesh_shape=[tp_size],
        )
        return self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard attention, modulation and MLP layers of every block.

        Q/K/V and first MLP projections are column-sharded, output
        projections row-sharded, and the ``*_mod.1`` linears column-sharded
        with a replicated output layout.
        """
        for block in transformer.transformer_blocks.children():
            block.attn.heads //= tp_mesh.size()

            image_stream_plan = {
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "attn.to_out.0": RowwiseParallel(),
                "img_mod.1": ColwiseParallel(output_layouts=Replicate()),
                "img_mlp.net.0.proj": ColwiseParallel(),
                "img_mlp.net.2": RowwiseParallel(),
            }
            text_stream_plan = {
                "attn.add_q_proj": ColwiseParallel(),
                "attn.add_k_proj": ColwiseParallel(),
                "attn.add_v_proj": ColwiseParallel(),
                "attn.to_add_out": RowwiseParallel(),
                "txt_mod.1": ColwiseParallel(output_layouts=Replicate()),
                "txt_mlp.net.0.proj": ColwiseParallel(),
                "txt_mlp.net.2": RowwiseParallel(),
            }

            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=image_stream_plan | text_stream_plan,
            )
        return transformer
@@ -0,0 +1,58 @@
1
+ import torch
2
+ import logging
3
+ from abc import abstractmethod
4
+ from typing import Dict
5
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
6
+ from cache_dit.logger import init_logger
7
+
8
+ logger = init_logger(__name__)
9
+
10
+
11
class TensorParallelismPlanner:
    """Base class for per-model tensor-parallelism planners.

    Subclasses implement :meth:`apply` to shard a transformer in place.

    NOTE: this class does not use ``abc.ABCMeta``, so ``@abstractmethod`` is
    not enforced at instantiation time; calling the base :meth:`apply`
    raises ``NotImplementedError`` instead.
    """

    # TODO: add `apply_extra` abstract method for extra
    # parallelism kwargs handling

    @abstractmethod
    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Parallelize ``transformer`` according to ``parallelism_config``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        message = "apply method must be implemented by subclasses"
        raise NotImplementedError(message)
25
+
26
+
27
class TensorParallelismPlannerRegister:
    """Name-keyed registry of :class:`TensorParallelismPlanner` subclasses.

    Planners are registered with :meth:`register` and looked up with
    :meth:`get_planner`, which matches registration names as prefixes of a
    transformer's class name.
    """

    # Maps a registration name to the planner *class* (not an instance);
    # callers instantiate the class returned by `get_planner`.
    _tp_planner_registry: Dict[str, type[TensorParallelismPlanner]] = {}

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers a planner under ``name``.

        Raises:
            AssertionError: If ``name`` is already registered.
        """

        def decorator(planner_cls: type[TensorParallelismPlanner]):
            assert (
                name not in cls._tp_planner_registry
            ), f"TensorParallelismPlanner with name {name} is already registered."
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Registering TensorParallelismPlanner: {name}")
            cls._tp_planner_registry[name] = planner_cls
            return planner_cls

        return decorator

    @classmethod
    def get_planner(
        cls, transformer: str | torch.nn.Module
    ) -> type[TensorParallelismPlanner]:
        """Return the planner class registered for ``transformer``.

        Args:
            transformer: A module (its class name is used) or a name string.

        Returns:
            The planner class whose registration name is the longest prefix
            of the transformer's name.

        Raises:
            ValueError: If no registration name is a prefix of the name.
        """
        if isinstance(transformer, torch.nn.Module):
            name = transformer.__class__.__name__
        else:
            name = transformer

        matches = [
            planner_name
            for planner_name in cls._tp_planner_registry
            if name.startswith(planner_name)
        ]
        if not matches:
            raise ValueError(f"No planner registered under name: {name}")
        # Prefer the most specific (longest) prefix so overlapping names
        # (e.g. a hypothetical "Flux" and "Flux2") resolve independently of
        # registration order.
        return cls._tp_planner_registry[max(matches, key=len)]
@@ -0,0 +1,153 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.distributed import DeviceMesh, init_device_mesh
6
+ from torch.distributed.tensor.parallel import (
7
+ ColwiseParallel,
8
+ RowwiseParallel,
9
+ parallelize_module,
10
+ )
11
+
12
+ from cache_dit.logger import init_logger
13
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
14
+
15
+ from .tp_plan_registers import (
16
+ TensorParallelismPlanner,
17
+ TensorParallelismPlannerRegister,
18
+ )
19
+
20
+ logger = init_logger(__name__)
21
+
22
+
23
+ class DistributedRMSNorm(nn.Module):
24
+ def __init__(
25
+ self,
26
+ tp_mesh: DeviceMesh,
27
+ normalized_shape: Union[int, list[int], torch.Size],
28
+ eps: Optional[float],
29
+ elementwise_affine: bool,
30
+ weight: torch.nn.parameter.Parameter,
31
+ ):
32
+ super().__init__()
33
+ self.tp_mesh = tp_mesh
34
+ self.elementwise_affine = elementwise_affine
35
+ self.normalized_shape = normalized_shape
36
+ self.eps = eps
37
+ if self.elementwise_affine:
38
+ assert weight is not None
39
+ self.weight = weight
40
+
41
+ @classmethod
42
+ def from_rmsnorm(cls, tp_mesh: DeviceMesh, rmsnorm: nn.RMSNorm):
43
+ if not isinstance(rmsnorm, int):
44
+ assert len(rmsnorm.normalized_shape) == 1
45
+
46
+ if rmsnorm.weight is not None:
47
+ tp_size = tp_mesh.get_group().size()
48
+ tp_rank = tp_mesh.get_group().rank()
49
+ weight = rmsnorm.weight.chunk(tp_size, dim=0)[tp_rank]
50
+ else:
51
+ weight = None
52
+ norm = cls(
53
+ tp_mesh=tp_mesh,
54
+ normalized_shape=rmsnorm.normalized_shape,
55
+ eps=rmsnorm.eps,
56
+ elementwise_affine=rmsnorm.elementwise_affine,
57
+ weight=weight,
58
+ )
59
+ return norm
60
+
61
+ def forward(self, x):
62
+ if self.elementwise_affine:
63
+ assert x.shape[-1] == self.weight.shape[0]
64
+ mean_square = torch.mean(x * x, dim=-1, keepdim=True)
65
+ torch.distributed.all_reduce(
66
+ mean_square,
67
+ op=torch.distributed.ReduceOp.AVG,
68
+ group=self.tp_mesh.get_group(),
69
+ )
70
+ root_mean_square = torch.sqrt(mean_square + self.eps)
71
+ x_normed = x / root_mean_square
72
+ if self.elementwise_affine:
73
+ x_normed = x_normed * self.weight.to(device=x.device)
74
+ assert x_normed.device.type != "cpu"
75
+ return x_normed
76
+
77
+
78
@TensorParallelismPlannerRegister.register("Wan")
class WanTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallelism planner for Wan transformers.

    Selected through the ``"Wan"`` class-name prefix. For each block, the
    self-attention (``attn1``), the second attention (``attn2``) and the FFN
    are sharded. The q/k norms, whose features are split along with the
    column-sharded projections, are replaced by :class:`DistributedRMSNorm`.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Create a 1-D TP mesh of ``tp_size`` devices and shard ``transformer``.

        ``kwargs`` is accepted for interface compatibility and ignored.
        """
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        device_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        transformer = self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )

        return transformer

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Shard every block in ``transformer.blocks`` over ``tp_mesh``.

        Head counts of both attentions are divided by the mesh size so each
        rank attends over its local slice of heads.
        """
        for _, block in transformer.blocks.named_children():
            block.attn1.heads //= tp_mesh.size()
            block.attn2.heads //= tp_mesh.size()
            layer_plan = {
                "attn1.to_q": ColwiseParallel(),
                "attn1.to_k": ColwiseParallel(),
                "attn1.to_v": ColwiseParallel(),
                "attn1.to_out.0": RowwiseParallel(),
                "attn2.to_q": ColwiseParallel(),
                "attn2.to_k": ColwiseParallel(),
                "attn2.to_v": ColwiseParallel(),
                "attn2.to_out.0": RowwiseParallel(),
                "ffn.net.0.proj": ColwiseParallel(),
                "ffn.net.2": RowwiseParallel(),
            }
            # Optional extra K/V projections on attn2: only sharded when the
            # attribute exists and is not None. Use `is not None` rather than
            # module truthiness, which containers with __len__ can make False.
            if getattr(block.attn2, "add_k_proj", None) is not None:
                layer_plan["attn2.add_k_proj"] = ColwiseParallel()
            if getattr(block.attn2, "add_v_proj", None) is not None:
                layer_plan["attn2.add_v_proj"] = ColwiseParallel()
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

            # Swap q/k norms for distributed versions; any norm the block
            # does not define (or sets to None) is skipped instead of
            # crashing inside from_rmsnorm.
            for attn, norm_name in (
                (block.attn1, "norm_q"),
                (block.attn1, "norm_k"),
                (block.attn2, "norm_q"),
                (block.attn2, "norm_k"),
                (block.attn2, "norm_added_k"),
            ):
                norm = getattr(attn, norm_name, None)
                if norm is not None:
                    setattr(
                        attn,
                        norm_name,
                        DistributedRMSNorm.from_rmsnorm(tp_mesh, norm),
                    )
        return transformer
@@ -0,0 +1,12 @@
1
+ # NOTE: must import all planner classes to register them
2
+ from .tp_plan_registers import TensorParallelismPlannerRegister
3
+ from .tp_plan_flux import FluxTensorParallelismPlanner
4
+ from .tp_plan_qwen_image import QwenImageTensorParallelismPlanner
5
+ from .tp_plan_wan import WanTensorParallelismPlanner
6
+
7
+ __all__ = [
8
+ "TensorParallelismPlannerRegister",
9
+ "FluxTensorParallelismPlanner",
10
+ "QwenImageTensorParallelismPlanner",
11
+ "WanTensorParallelismPlanner",
12
+ ]
@@ -8,6 +8,8 @@ class ParallelismBackend(Enum):
8
8
 
9
9
  @classmethod
10
10
  def is_supported(cls, backend: "ParallelismBackend") -> bool:
11
+ if backend in [cls.NATIVE_PYTORCH]:
12
+ return True
11
13
  # Now, only Native_Diffuser backend is supported
12
14
  if backend in [cls.NATIVE_DIFFUSER]:
13
15
  try:
@@ -23,8 +23,8 @@ class ParallelismConfig:
23
23
  tp_size: int = None
24
24
  # parallel_kwargs (`dict`, *optional*):
25
25
  # Additional kwargs for parallelism backends. For example, for
26
- # NATIVE_DIFFUSER backend, it can include `cp_plan` and other
27
- # arguments for `Context Parallelism`.
26
+ # NATIVE_DIFFUSER backend, it can include `cp_plan` and
27
+ # `attention_backend` arguments for `Context Parallelism`.
28
28
  parallel_kwargs: Optional[Dict[str, Any]] = dataclasses.field(
29
29
  default_factory=dict
30
30
  )
@@ -34,7 +34,14 @@ class ParallelismConfig:
34
34
  f"Parallel backend {self.backend} is not supported. "
35
35
  f"Please make sure the required packages are installed."
36
36
  )
37
- assert self.tp_size is None, "Tensor parallelism is not supported yet."
37
+
38
+ if self.tp_size is not None and self.tp_size > 1:
39
+ assert (
40
+ self.ulysses_size is None or self.ulysses_size == 1
41
+ ), "Tensor parallelism plus Ulysses parallelism is not supported right now."
42
+ assert (
43
+ self.ring_size is None or self.ring_size == 1
44
+ ), "Tensor parallelism plus Ring parallelism is not supported right now."
38
45
 
39
46
  def strify(self, details: bool = False) -> str:
40
47
  if details:
@@ -24,12 +24,17 @@ def enable_parallelism(
24
24
  if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
25
25
  from cache_dit.parallelism.backends.native_diffusers import (
26
26
  maybe_enable_parallelism,
27
- native_diffusers_parallelism_available,
28
27
  )
29
28
 
30
- assert (
31
- native_diffusers_parallelism_available()
32
- ), "Please install diffusers>=0.36.dev0 to use Native_Diffuser backend."
29
+ transformer = maybe_enable_parallelism(
30
+ transformer,
31
+ parallelism_config,
32
+ )
33
+ elif parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH:
34
+ from cache_dit.parallelism.backends.native_pytorch import (
35
+ maybe_enable_parallelism,
36
+ )
37
+
33
38
  transformer = maybe_enable_parallelism(
34
39
  transformer,
35
40
  parallelism_config,
@@ -40,8 +45,12 @@ def enable_parallelism(
40
45
  )
41
46
 
42
47
  transformer._is_parallelized = True # type: ignore[attr-defined]
48
+ # Use `parallelism` not `parallel` to avoid name conflict with diffusers.
43
49
  transformer._parallelism_config = parallelism_config # type: ignore[attr-defined]
44
- logger.info(f"Enabled parallelism: {parallelism_config.strify(True)}")
50
+ logger.info(
51
+ f"Enabled parallelism: {parallelism_config.strify(True)}, "
52
+ f"transformer id:{id(transformer)}"
53
+ )
45
54
  return transformer
46
55
 
47
56
 
@@ -0,0 +1 @@
1
+ from .torchao import quantize_ao
File without changes
@@ -0,0 +1 @@
1
+ from .quantize_ao import quantize_ao
@@ -9,7 +9,7 @@ logger = init_logger(__name__)
9
9
 
10
10
  def quantize_ao(
11
11
  module: torch.nn.Module,
12
- quant_type: str = "fp8_w8a8_dq",
12
+ quant_type: str = "float8_weight_only",
13
13
  exclude_layers: List[str] = [
14
14
  "embedder",
15
15
  "embed",
@@ -24,6 +24,18 @@ def quantize_ao(
24
24
  # set `exclude_layers` as `[]` if you don't want this behavior.
25
25
  assert isinstance(module, torch.nn.Module)
26
26
 
27
+ alias_map = {
28
+ "float8": "fp8_w8a8_dq",
29
+ "float8_weight_only": "fp8_w8a16_wo",
30
+ "int8": "int8_w8a8_dq",
31
+ "int8_weight_only": "int8_w8a16_wo",
32
+ "int4": "int4_w4a8_dq",
33
+ "int4_w4a4": "int4_w4a4_dq",
34
+ "int4_weight_only": "int4_w4a16_wo",
35
+ }
36
+ if quant_type.lower() in alias_map:
37
+ quant_type = alias_map[quant_type.lower()]
38
+
27
39
  quant_type = quant_type.lower()
28
40
  assert quant_type in (
29
41
  "fp8_w8a8_dq",
@@ -183,7 +195,11 @@ def quantize_ao(
183
195
  device=kwargs.get("device", None),
184
196
  )
185
197
 
186
- force_empty_cache()
198
+ maybe_empty_cache()
199
+
200
+ alias_map_rev = {v: k for k, v in alias_map.items()}
201
+ if quant_type in alias_map_rev:
202
+ quant_type = alias_map_rev[quant_type]
187
203
 
188
204
  logger.info(
189
205
  f"Quantized Module: {module.__class__.__name__:>5}\n"
@@ -199,10 +215,13 @@ def quantize_ao(
199
215
  return module
200
216
 
201
217
 
202
- def force_empty_cache():
203
- time.sleep(1)
204
- gc.collect()
205
- torch.cuda.empty_cache()
206
- time.sleep(1)
207
- gc.collect()
208
- torch.cuda.empty_cache()
218
def maybe_empty_cache(delay: float = 1.0):
    """Best-effort release of Python garbage and cached CUDA memory.

    Runs two rounds of ``time.sleep(delay)``, ``gc.collect()`` and
    ``torch.cuda.empty_cache()``. Any exception is logged at debug level and
    otherwise ignored, so cleanup can never make the caller fail.

    Args:
        delay: Seconds to sleep before each round. Defaults to 1.0, matching
            the previously hard-coded value.
    """
    try:
        for _ in range(2):
            time.sleep(delay)
            gc.collect()
            torch.cuda.empty_cache()
    except Exception as err:  # deliberate best-effort cleanup
        logger.debug(f"maybe_empty_cache: ignored error during cleanup: {err}")
File without changes
File without changes