cache-dit 1.0.9__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (45) hide show
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/block_adapters/__init__.py +37 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +51 -3
  5. cache_dit/cache_factory/block_adapters/block_registers.py +41 -14
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +68 -30
  7. cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
  8. cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
  9. cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
  10. cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
  11. cache_dit/cache_factory/cache_interface.py +29 -3
  12. cache_dit/cache_factory/forward_pattern.py +14 -14
  13. cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
  14. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
  15. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
  16. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
  17. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -61
  18. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  19. cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
  20. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  21. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  22. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
  23. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  24. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
  25. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  26. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
  27. cache_dit/parallelism/parallel_backend.py +2 -0
  28. cache_dit/parallelism/parallel_config.py +8 -1
  29. cache_dit/parallelism/parallel_interface.py +9 -4
  30. cache_dit/quantize/backends/__init__.py +1 -0
  31. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  32. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  33. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
  34. cache_dit/quantize/quantize_backend.py +0 -0
  35. cache_dit/quantize/quantize_config.py +0 -0
  36. cache_dit/quantize/quantize_interface.py +3 -16
  37. cache_dit/utils.py +22 -2
  38. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/METADATA +22 -13
  39. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
  40. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  41. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  42. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
  43. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
  44. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
  45. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,78 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.distributed import DeviceMesh, init_device_mesh
4
+ from torch.distributed._tensor import Replicate
5
+ from torch.distributed.tensor.parallel import (
6
+ ColwiseParallel,
7
+ RowwiseParallel,
8
+ parallelize_module,
9
+ )
10
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
11
+ from .tp_plan_registers import (
12
+ TensorParallelismPlanner,
13
+ TensorParallelismPlannerRegister,
14
+ )
15
+
16
+ from cache_dit.logger import init_logger
17
+
18
+ logger = init_logger(__name__)
19
+
20
+
21
@TensorParallelismPlannerRegister.register("QwenImage")
class QwenImageTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallelism planner registered under the name ``"QwenImage"``.

    ``apply`` builds a 1-D device mesh of ``tp_size`` ranks and then shards
    the linear layers of every child of ``transformer.transformer_blocks``.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Shard ``transformer`` across ``parallelism_config.tp_size`` ranks.

        Args:
            transformer: Module exposing a ``transformer_blocks`` container.
            parallelism_config: Must carry ``tp_size > 1``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The same ``transformer`` object, parallelized in place.
        """
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        device_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        transformer = self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )

        return transformer

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Apply the per-block sharding plan over ``tp_mesh``.

        Raises:
            ValueError: If a block's attention head count is not divisible
                by the mesh size.
        """
        tp_size = tp_mesh.size()
        for block_name, block in transformer.transformer_blocks.named_children():
            # Floor division would silently drop heads; fail early with a
            # clear message instead of a shape error deep in attention.
            if block.attn.heads % tp_size != 0:
                raise ValueError(
                    f"Block {block_name}: attn.heads={block.attn.heads} is "
                    f"not divisible by tensor parallel size {tp_size}."
                )
            # Each rank keeps only its share of the attention heads.
            block.attn.heads //= tp_size
            layer_plan = {
                # Image stream: column-split the projections, row-split the
                # output projection.
                "attn.to_q": ColwiseParallel(),
                "attn.to_k": ColwiseParallel(),
                "attn.to_v": ColwiseParallel(),
                "attn.to_out.0": RowwiseParallel(),
                # Modulation output is requested as replicated, not sharded.
                "img_mod.1": ColwiseParallel(output_layouts=Replicate()),
                "img_mlp.net.0.proj": ColwiseParallel(),
                "img_mlp.net.2": RowwiseParallel(),
                # Text stream: same layout as the image stream.
                "attn.add_q_proj": ColwiseParallel(),
                "attn.add_k_proj": ColwiseParallel(),
                "attn.add_v_proj": ColwiseParallel(),
                "attn.to_add_out": RowwiseParallel(),
                "txt_mod.1": ColwiseParallel(output_layouts=Replicate()),
                "txt_mlp.net.0.proj": ColwiseParallel(),
                "txt_mlp.net.2": RowwiseParallel(),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )
        return transformer
@@ -0,0 +1,58 @@
1
+ import torch
2
+ import logging
3
+ from abc import abstractmethod
4
+ from typing import Dict
5
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
6
+ from cache_dit.logger import init_logger
7
+
8
+ logger = init_logger(__name__)
9
+
10
+
11
class TensorParallelismPlanner:
    """Base interface for per-model tensor-parallelism planners.

    Subclasses override ``apply``. The class does not use ``ABCMeta``, so
    ``@abstractmethod`` is not enforced at instantiation time; calling the
    base implementation raises ``NotImplementedError`` instead.
    """

    # TODO: add `apply_extra` abstract method for extra
    # parallelism kwargs handling

    @abstractmethod
    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Parallelize ``transformer`` according to ``parallelism_config``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        message = "apply method must be implemented by subclasses"
        raise NotImplementedError(message)
25
+
26
+
27
class TensorParallelismPlannerRegister:
    """Registry mapping a model-name prefix to a planner class."""

    # Values are planner classes (not instances); callers instantiate them.
    _tp_planner_registry: Dict[str, type[TensorParallelismPlanner]] = {}

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers a planner under ``name``.

        The decorated class is returned unchanged. Registering the same
        name twice trips an assertion.
        """

        def decorator(planner_cls: type[TensorParallelismPlanner]):
            assert (
                name not in cls._tp_planner_registry
            ), f"TensorParallelismPlanner with name {name} is already registered."
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Registering TensorParallelismPlanner: {name}")
            cls._tp_planner_registry[name] = planner_cls
            return planner_cls

        return decorator

    @classmethod
    def get_planner(
        cls, transformer: str | torch.nn.Module
    ) -> type[TensorParallelismPlanner]:
        """Look up the planner class for a transformer.

        Args:
            transformer: A module (its class name is used) or a name string.

        Returns:
            The planner class whose registered name is the longest prefix
            of the transformer name.

        Raises:
            ValueError: If no registered name is a prefix of the name.
        """
        if isinstance(transformer, torch.nn.Module):
            name = transformer.__class__.__name__
        else:
            name = transformer

        candidates = [
            registered
            for registered in cls._tp_planner_registry
            if name.startswith(registered)
        ]
        if not candidates:
            raise ValueError(f"No planner registered under name: {name}")
        # Prefer the most specific prefix so that overlapping names resolve
        # the same way regardless of import/registration order.
        return cls._tp_planner_registry[max(candidates, key=len)]
@@ -0,0 +1,153 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.distributed import DeviceMesh, init_device_mesh
6
+ from torch.distributed.tensor.parallel import (
7
+ ColwiseParallel,
8
+ RowwiseParallel,
9
+ parallelize_module,
10
+ )
11
+
12
+ from cache_dit.logger import init_logger
13
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
14
+
15
+ from .tp_plan_registers import (
16
+ TensorParallelismPlanner,
17
+ TensorParallelismPlannerRegister,
18
+ )
19
+
20
+ logger = init_logger(__name__)
21
+
22
+
23
+ class DistributedRMSNorm(nn.Module):
24
+ def __init__(
25
+ self,
26
+ tp_mesh: DeviceMesh,
27
+ normalized_shape: Union[int, list[int], torch.Size],
28
+ eps: Optional[float],
29
+ elementwise_affine: bool,
30
+ weight: torch.nn.parameter.Parameter,
31
+ ):
32
+ super().__init__()
33
+ self.tp_mesh = tp_mesh
34
+ self.elementwise_affine = elementwise_affine
35
+ self.normalized_shape = normalized_shape
36
+ self.eps = eps
37
+ if self.elementwise_affine:
38
+ assert weight is not None
39
+ self.weight = weight
40
+
41
+ @classmethod
42
+ def from_rmsnorm(cls, tp_mesh: DeviceMesh, rmsnorm: nn.RMSNorm):
43
+ if not isinstance(rmsnorm, int):
44
+ assert len(rmsnorm.normalized_shape) == 1
45
+
46
+ if rmsnorm.weight is not None:
47
+ tp_size = tp_mesh.get_group().size()
48
+ tp_rank = tp_mesh.get_group().rank()
49
+ weight = rmsnorm.weight.chunk(tp_size, dim=0)[tp_rank]
50
+ else:
51
+ weight = None
52
+ norm = cls(
53
+ tp_mesh=tp_mesh,
54
+ normalized_shape=rmsnorm.normalized_shape,
55
+ eps=rmsnorm.eps,
56
+ elementwise_affine=rmsnorm.elementwise_affine,
57
+ weight=weight,
58
+ )
59
+ return norm
60
+
61
+ def forward(self, x):
62
+ if self.elementwise_affine:
63
+ assert x.shape[-1] == self.weight.shape[0]
64
+ mean_square = torch.mean(x * x, dim=-1, keepdim=True)
65
+ torch.distributed.all_reduce(
66
+ mean_square,
67
+ op=torch.distributed.ReduceOp.AVG,
68
+ group=self.tp_mesh.get_group(),
69
+ )
70
+ root_mean_square = torch.sqrt(mean_square + self.eps)
71
+ x_normed = x / root_mean_square
72
+ if self.elementwise_affine:
73
+ x_normed = x_normed * self.weight.to(device=x.device)
74
+ assert x_normed.device.type != "cpu"
75
+ return x_normed
76
+
77
+
78
@TensorParallelismPlannerRegister.register("Wan")
class WanTensorParallelismPlanner(TensorParallelismPlanner):
    """Tensor-parallelism planner registered under the name ``"Wan"``.

    ``apply`` builds a 1-D device mesh of ``tp_size`` ranks and then shards
    the attention and feed-forward layers of every child of
    ``transformer.blocks``. It also swaps the q/k norms for
    ``DistributedRMSNorm``.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Shard ``transformer`` across ``parallelism_config.tp_size`` ranks.

        Args:
            transformer: Module exposing a ``blocks`` container.
            parallelism_config: Must carry ``tp_size > 1``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The same ``transformer`` object, parallelized in place.
        """
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        device_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        transformer = self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )

        return transformer

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """Apply the per-block sharding plan over ``tp_mesh``.

        Raises:
            ValueError: If ``attn1.heads`` or ``attn2.heads`` of a block is
                not divisible by the mesh size.
        """
        tp_size = tp_mesh.size()
        for block_name, block in transformer.blocks.named_children():
            # Check both attentions before mutating either, so a failure
            # leaves the block untouched.
            for attn_name in ("attn1", "attn2"):
                heads = getattr(block, attn_name).heads
                if heads % tp_size != 0:
                    raise ValueError(
                        f"Block {block_name}: {attn_name}.heads={heads} is "
                        f"not divisible by tensor parallel size {tp_size}."
                    )
            # Each rank keeps only its share of the attention heads.
            block.attn1.heads //= tp_size
            block.attn2.heads //= tp_size
            layer_plan = {
                "attn1.to_q": ColwiseParallel(),
                "attn1.to_k": ColwiseParallel(),
                "attn1.to_v": ColwiseParallel(),
                "attn1.to_out.0": RowwiseParallel(),
                "attn2.to_q": ColwiseParallel(),
                "attn2.to_k": ColwiseParallel(),
                "attn2.to_v": ColwiseParallel(),
                "attn2.to_out.0": RowwiseParallel(),
                "ffn.net.0.proj": ColwiseParallel(),
                "ffn.net.2": RowwiseParallel(),
            }
            # The added k/v projections are optional; test presence with
            # `is not None` rather than relying on module truthiness.
            if getattr(block.attn2, "add_k_proj", None) is not None:
                layer_plan["attn2.add_k_proj"] = ColwiseParallel()
            if getattr(block.attn2, "add_v_proj", None) is not None:
                layer_plan["attn2.add_v_proj"] = ColwiseParallel()
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

            # Replace the q/k norms with the variant that reduces its
            # statistics over the tensor-parallel group.
            block.attn1.norm_q = DistributedRMSNorm.from_rmsnorm(
                tp_mesh, block.attn1.norm_q
            )
            block.attn1.norm_k = DistributedRMSNorm.from_rmsnorm(
                tp_mesh, block.attn1.norm_k
            )
            block.attn2.norm_q = DistributedRMSNorm.from_rmsnorm(
                tp_mesh, block.attn2.norm_q
            )
            block.attn2.norm_k = DistributedRMSNorm.from_rmsnorm(
                tp_mesh, block.attn2.norm_k
            )
            if getattr(block.attn2, "norm_added_k", None) is not None:
                block.attn2.norm_added_k = DistributedRMSNorm.from_rmsnorm(
                    tp_mesh, block.attn2.norm_added_k
                )
        return transformer
@@ -0,0 +1,12 @@
1
+ # NOTE: must import all planner classes to register them
2
+ from .tp_plan_registers import TensorParallelismPlannerRegister
3
+ from .tp_plan_flux import FluxTensorParallelismPlanner
4
+ from .tp_plan_qwen_image import QwenImageTensorParallelismPlanner
5
+ from .tp_plan_wan import WanTensorParallelismPlanner
6
+
7
+ __all__ = [
8
+ "TensorParallelismPlannerRegister",
9
+ "FluxTensorParallelismPlanner",
10
+ "QwenImageTensorParallelismPlanner",
11
+ "WanTensorParallelismPlanner",
12
+ ]
@@ -8,6 +8,8 @@ class ParallelismBackend(Enum):
8
8
 
9
9
  @classmethod
10
10
  def is_supported(cls, backend: "ParallelismBackend") -> bool:
11
+ if backend in [cls.NATIVE_PYTORCH]:
12
+ return True
11
13
  # Now, only Native_Diffuser backend is supported
12
14
  if backend in [cls.NATIVE_DIFFUSER]:
13
15
  try:
@@ -34,7 +34,14 @@ class ParallelismConfig:
34
34
  f"Parallel backend {self.backend} is not supported. "
35
35
  f"Please make sure the required packages are installed."
36
36
  )
37
- assert self.tp_size is None, "Tensor parallelism is not supported yet."
37
+
38
+ if self.tp_size is not None and self.tp_size > 1:
39
+ assert (
40
+ self.ulysses_size is None or self.ulysses_size == 1
41
+ ), "Tensor parallelism plus Ulysses parallelism is not supported right now."
42
+ assert (
43
+ self.ring_size is None or self.ring_size == 1
44
+ ), "Tensor parallelism plus Ring parallelism is not supported right now."
38
45
 
39
46
  def strify(self, details: bool = False) -> str:
40
47
  if details:
@@ -24,12 +24,17 @@ def enable_parallelism(
24
24
  if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
25
25
  from cache_dit.parallelism.backends.native_diffusers import (
26
26
  maybe_enable_parallelism,
27
- native_diffusers_parallelism_available,
28
27
  )
29
28
 
30
- assert (
31
- native_diffusers_parallelism_available()
32
- ), "Please install diffusers>=0.36.dev0 to use Native_Diffuser backend."
29
+ transformer = maybe_enable_parallelism(
30
+ transformer,
31
+ parallelism_config,
32
+ )
33
+ elif parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH:
34
+ from cache_dit.parallelism.backends.native_pytorch import (
35
+ maybe_enable_parallelism,
36
+ )
37
+
33
38
  transformer = maybe_enable_parallelism(
34
39
  transformer,
35
40
  parallelism_config,
@@ -0,0 +1 @@
1
+ from .torchao import quantize_ao
File without changes
@@ -0,0 +1 @@
1
+ from .quantize_ao import quantize_ao
@@ -9,7 +9,7 @@ logger = init_logger(__name__)
9
9
 
10
10
  def quantize_ao(
11
11
  module: torch.nn.Module,
12
- quant_type: str = "fp8_w8a8_dq",
12
+ quant_type: str = "float8_weight_only",
13
13
  exclude_layers: List[str] = [
14
14
  "embedder",
15
15
  "embed",
@@ -24,6 +24,18 @@ def quantize_ao(
24
24
  # set `exclude_layers` as `[]` if you don't want this behavior.
25
25
  assert isinstance(module, torch.nn.Module)
26
26
 
27
+ alias_map = {
28
+ "float8": "fp8_w8a8_dq",
29
+ "float8_weight_only": "fp8_w8a16_wo",
30
+ "int8": "int8_w8a8_dq",
31
+ "int8_weight_only": "int8_w8a16_wo",
32
+ "int4": "int4_w4a8_dq",
33
+ "int4_w4a4": "int4_w4a4_dq",
34
+ "int4_weight_only": "int4_w4a16_wo",
35
+ }
36
+ if quant_type.lower() in alias_map:
37
+ quant_type = alias_map[quant_type.lower()]
38
+
27
39
  quant_type = quant_type.lower()
28
40
  assert quant_type in (
29
41
  "fp8_w8a8_dq",
@@ -183,7 +195,11 @@ def quantize_ao(
183
195
  device=kwargs.get("device", None),
184
196
  )
185
197
 
186
- force_empty_cache()
198
+ maybe_empty_cache()
199
+
200
+ alias_map_rev = {v: k for k, v in alias_map.items()}
201
+ if quant_type in alias_map_rev:
202
+ quant_type = alias_map_rev[quant_type]
187
203
 
188
204
  logger.info(
189
205
  f"Quantized Module: {module.__class__.__name__:>5}\n"
@@ -199,10 +215,13 @@ def quantize_ao(
199
215
  return module
200
216
 
201
217
 
202
- def force_empty_cache():
203
- time.sleep(1)
204
- gc.collect()
205
- torch.cuda.empty_cache()
206
- time.sleep(1)
207
- gc.collect()
208
- torch.cuda.empty_cache()
218
def maybe_empty_cache():
    """Best-effort garbage collection and CUDA cache release.

    Runs two rounds of sleep, ``gc.collect()`` and
    ``torch.cuda.empty_cache()``. Any failure is logged at debug level and
    otherwise ignored, so callers never see an exception from this helper.
    """
    try:
        for _ in range(2):
            time.sleep(1)
            gc.collect()
            torch.cuda.empty_cache()
    except Exception as exc:
        # Stay best-effort, but leave a trace instead of failing silently.
        logger.debug(f"Skipped emptying the CUDA cache: {exc}")
File without changes
File without changes
@@ -7,37 +7,24 @@ logger = init_logger(__name__)
7
7
 
8
8
  def quantize(
9
9
  module: torch.nn.Module,
10
- quant_type: str = "fp8_w8a8_dq",
10
+ quant_type: str = "float8_weight_only",
11
11
  backend: str = "ao",
12
12
  exclude_layers: List[str] = [
13
13
  "embedder",
14
14
  "embed",
15
15
  ],
16
16
  filter_fn: Optional[Callable] = None,
17
- # only for fp8_w8a8_dq
18
- per_row: bool = True,
19
17
  **kwargs,
20
18
  ) -> torch.nn.Module:
21
19
  assert isinstance(module, torch.nn.Module)
22
20
 
23
21
  if backend.lower() in ("ao", "torchao"):
24
- from cache_dit.quantize.quantize_ao import quantize_ao
25
-
26
- quant_type = quant_type.lower()
27
- assert quant_type in (
28
- "fp8_w8a8_dq",
29
- "fp8_w8a16_wo",
30
- "int8_w8a8_dq",
31
- "int8_w8a16_wo",
32
- "int4_w4a8_dq",
33
- "int4_w4a4_dq",
34
- "int4_w4a16_wo",
35
- ), f"{quant_type} is not supported for torchao backend now!"
22
+ from cache_dit.quantize.backends.torchao import quantize_ao
36
23
 
37
24
  return quantize_ao(
38
25
  module,
39
26
  quant_type=quant_type,
40
- per_row=per_row,
27
+ per_row=kwargs.pop("per_row", True),
41
28
  exclude_layers=exclude_layers,
42
29
  filter_fn=filter_fn,
43
30
  **kwargs,
cache_dit/utils.py CHANGED
@@ -13,6 +13,7 @@ from cache_dit.cache_factory import CacheType
13
13
  from cache_dit.cache_factory import BlockAdapter
14
14
  from cache_dit.cache_factory import BasicCacheConfig
15
15
  from cache_dit.cache_factory import CalibratorConfig
16
+ from cache_dit.cache_factory import FakeDiffusionPipeline
16
17
  from cache_dit.parallelism import ParallelismConfig
17
18
  from cache_dit.logger import init_logger
18
19
 
@@ -64,6 +65,7 @@ def summary(
64
65
  adapter_or_others: Union[
65
66
  BlockAdapter,
66
67
  DiffusionPipeline,
68
+ FakeDiffusionPipeline,
67
69
  torch.nn.Module,
68
70
  ],
69
71
  details: bool = False,
@@ -73,9 +75,15 @@ def summary(
73
75
  if adapter_or_others is None:
74
76
  return [CacheStats()]
75
77
 
78
+ if isinstance(adapter_or_others, FakeDiffusionPipeline):
79
+ raise ValueError(
80
+ "Please pass DiffusionPipeline, BlockAdapter or transfomer, "
81
+ "not FakeDiffusionPipeline."
82
+ )
83
+
76
84
  if not isinstance(adapter_or_others, BlockAdapter):
77
85
  if not isinstance(adapter_or_others, DiffusionPipeline):
78
- transformer = adapter_or_others
86
+ transformer = adapter_or_others # transformer-only
79
87
  transformer_2 = None
80
88
  else:
81
89
  transformer = adapter_or_others.transformer
@@ -165,11 +173,18 @@ def strify(
165
173
  adapter_or_others: Union[
166
174
  BlockAdapter,
167
175
  DiffusionPipeline,
176
+ FakeDiffusionPipeline,
177
+ torch.nn.Module,
168
178
  CacheStats,
169
179
  List[CacheStats],
170
180
  Dict[str, Any],
171
181
  ],
172
182
  ) -> str:
183
+ if isinstance(adapter_or_others, FakeDiffusionPipeline):
184
+ raise ValueError(
185
+ "Please pass DiffusionPipeline, BlockAdapter or transfomer, "
186
+ "not FakeDiffusionPipeline."
187
+ )
173
188
 
174
189
  parallelism_config: ParallelismConfig = None
175
190
  if isinstance(adapter_or_others, BlockAdapter):
@@ -180,6 +195,10 @@ def strify(
180
195
  stats = summary(adapter_or_others, logging=False)[-1]
181
196
  cache_options = stats.cache_options
182
197
  cached_steps = len(stats.cached_steps)
198
+ elif isinstance(adapter_or_others, torch.nn.Module):
199
+ stats = summary(adapter_or_others, logging=False)[-1]
200
+ cache_options = stats.cache_options
201
+ cached_steps = len(stats.cached_steps)
183
202
  elif isinstance(adapter_or_others, CacheStats):
184
203
  stats = adapter_or_others
185
204
  cache_options = stats.cache_options
@@ -202,7 +221,8 @@ def strify(
202
221
  else:
203
222
  raise ValueError(
204
223
  "Please set pipe_or_stats param as one of: "
205
- "DiffusionPipeline | CacheStats | Dict[str, Any]"
224
+ "DiffusionPipeline | CacheStats | Dict[str, Any] | List[CacheStats]"
225
+ " | BlockAdapter | Transformer"
206
226
  )
207
227
 
208
228
  if stats is not None: