cache-dit 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.0.2'
-__version_tuple__ = version_tuple = (1, 0, 2)
+__version__ = version = '1.0.3'
+__version_tuple__ = version_tuple = (1, 0, 3)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py CHANGED
@@ -33,14 +33,14 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            hidden_states, new_encoder_hidden_states = self._process_outputs(
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )
 
         return hidden_states, new_encoder_hidden_states
 
     @torch.compiler.disable
-    def _process_outputs(
+    def _process_block_outputs(
         self, hidden_states: torch.Tensor | tuple
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         # Process the outputs for the block.
@@ -66,7 +66,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         return hidden_states, new_encoder_hidden_states
 
     @torch.compiler.disable
-    def _forward_outputs(
+    def _process_forward_outputs(
         self,
         hidden_states: torch.Tensor,
         new_encoder_hidden_states: torch.Tensor | None,
@@ -100,7 +100,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            return self._forward_outputs(
+            return self._process_forward_outputs(
                 hidden_states, new_encoder_hidden_states
             )
 
@@ -227,7 +227,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
 
         torch._dynamo.graph_break()
 
-        return self._forward_outputs(hidden_states, new_encoder_hidden_states)
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
 
     def call_Fn_blocks(
         self,
@@ -242,8 +245,8 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            hidden_states, new_encoder_hidden_states = self._process_outputs(
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )
 
         return hidden_states, new_encoder_hidden_states
@@ -263,8 +266,8 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
 
-            hidden_states, new_encoder_hidden_states = self._process_outputs(
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )
 
         # compute hidden_states residual
@@ -296,8 +299,8 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
            )
 
-            hidden_states, new_encoder_hidden_states = self._process_outputs(
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )
 
         return hidden_states, new_encoder_hidden_states
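The renamed helper keeps the contract visible in the hunks above: `_process_block_outputs` accepts either a bare `torch.Tensor` or a tuple returned by a transformer block and always yields a `(hidden_states, new_encoder_hidden_states)` pair. The sketch below only illustrates that contract; the function body is hypothetical and is not the package's implementation.

```python
import torch


def _process_block_outputs_sketch(
    hidden_states: torch.Tensor | tuple,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Hypothetical illustration of the signature above: normalize whatever a
    # block returns into a (hidden_states, new_encoder_hidden_states) pair.
    if isinstance(hidden_states, tuple):
        first, *rest = hidden_states
        return first, (rest[0] if rest else None)
    return hidden_states, None
```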
cache_dit/cache_factory/cache_blocks/pattern_base.py CHANGED
@@ -135,7 +135,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         return hidden_states, encoder_hidden_states
 
     @torch.compiler.disable
-    def _process_outputs(
+    def _process_block_outputs(
         self,
         hidden_states: torch.Tensor | tuple,
         encoder_hidden_states: torch.Tensor | None,
@@ -150,7 +150,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         return hidden_states, encoder_hidden_states
 
     @torch.compiler.disable
-    def _forward_outputs(
+    def _process_forward_outputs(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor | None,
@@ -185,7 +185,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            return self._forward_outputs(hidden_states, encoder_hidden_states)
+            return self._process_forward_outputs(
+                hidden_states,
+                encoder_hidden_states,
+            )
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -304,7 +307,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()
 
-        return self._forward_outputs(hidden_states, encoder_hidden_states)
+        return self._process_forward_outputs(
+            hidden_states,
+            encoder_hidden_states,
+        )
 
     @torch.compiler.disable
     def _is_parallelized(self):
@@ -379,7 +385,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            hidden_states, encoder_hidden_states = self._process_outputs(
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
                 hidden_states, encoder_hidden_states
             )
 
@@ -401,7 +407,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            hidden_states, encoder_hidden_states = self._process_outputs(
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
                 hidden_states, encoder_hidden_states
             )
 
@@ -445,7 +451,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            hidden_states, encoder_hidden_states = self._process_outputs(
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
                 hidden_states, encoder_hidden_states
            )
 
cache_dit/cache_factory/cache_contexts/cache_context.py CHANGED
@@ -38,6 +38,10 @@ class BasicCacheConfig:
     # DBCache does not apply the caching strategy when the number of running steps is less than
     # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
     max_warmup_steps: int = 8  # DON'T Cache in warmup steps
+    # warmup_interval (`int`, *required*, defaults to 1):
+    # Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
+    # in warmup steps will be computed, others will use dynamic cache.
+    warmup_interval: int = 1  # skip interval in warmup steps
     # max_cached_steps (`int`, *required*, defaults to -1):
     # DBCache disables the caching strategy when the previous cached steps exceed this value to
     # prevent precision degradation.
@@ -71,6 +75,7 @@ class BasicCacheConfig:
             f"DBCACHE_F{self.Fn_compute_blocks}"
             f"B{self.Bn_compute_blocks}_"
             f"W{self.max_warmup_steps}"
+            f"I{self.warmup_interval}"
             f"M{max(0, self.max_cached_steps)}"
             f"MC{max(0, self.max_continuous_cached_steps)}_"
             f"R{self.residual_diff_threshold}"
@@ -346,5 +351,15 @@ class CachedContext:
         # CFG steps: 1, 3, 5, 7, ...
         return self.get_current_transformer_step() % 2 != 0
 
+    @property
+    def warmup_steps(self) -> List[int]:
+        return list(
+            range(
+                0,
+                self.cache_config.max_warmup_steps,
+                self.cache_config.warmup_interval,
+            )
+        )
+
     def is_in_warmup(self):
-        return self.get_current_step() < self.cache_config.max_warmup_steps
+        return self.get_current_step() in self.warmup_steps
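Taken together, the new `warmup_interval` field, the `warmup_steps` property, and the updated `is_in_warmup` mean that warmup no longer forces every early step to be computed: only steps `0, warmup_interval, 2 * warmup_interval, ...` below `max_warmup_steps` count as warmup steps. A minimal standalone sketch of that arithmetic (not the package code):

```python
def warmup_steps(max_warmup_steps: int, warmup_interval: int) -> list[int]:
    # Steps that are still fully computed during warmup.
    return list(range(0, max_warmup_steps, warmup_interval))


def is_in_warmup(step: int, max_warmup_steps: int, warmup_interval: int) -> bool:
    return step in warmup_steps(max_warmup_steps, warmup_interval)


# With max_warmup_steps=8 and warmup_interval=2, steps 0, 2, 4 and 6 are computed;
# steps 1, 3, 5 and 7 may already be served from the dynamic cache.
assert warmup_steps(8, 2) == [0, 2, 4, 6]
assert not is_in_warmup(3, 8, 2)
```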
cache_dit/cache_factory/cache_interface.py CHANGED
@@ -86,6 +86,9 @@ def enable_cache(
         max_warmup_steps (`int`, *required*, defaults to 8):
             DBCache does not apply the caching strategy when the number of running steps is less than
             or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+        warmup_interval (`int`, *required*, defaults to 1):
+            Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
+            in warmup steps will be computed, others will use dynamic cache.
         max_cached_steps (`int`, *required*, defaults to -1):
             DBCache disables the caching strategy when the previous cached steps exceed this value to
             prevent precision degradation.
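The docstring addition mirrors the new `BasicCacheConfig.warmup_interval` field. A usage sketch follows; it assumes a Diffusers pipeline object named `pipe` has already been constructed and that the field is passed through `cache_config`, as the YAML loader in this release does.

```python
import cache_dit
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig

# `pipe` is assumed to be an already-built Diffusers pipeline.
cache_dit.enable_cache(
    pipe,
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,   # warmup window: the first 8 steps
        warmup_interval=2,    # within warmup, fully compute steps 0, 2, 4, 6
        max_cached_steps=-1,  # no upper bound on total cached steps
    ),
)
```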
cache_dit/cache_factory/utils.py CHANGED
@@ -1,14 +1,12 @@
 import yaml
-from cache_dit.cache_factory import CacheType
 
 
 def load_cache_options_from_yaml(yaml_file_path):
     try:
         with open(yaml_file_path, "r") as f:
-            config = yaml.safe_load(f)
+            kwargs: dict = yaml.safe_load(f)
 
         required_keys = [
-            "cache_type",
             "max_warmup_steps",
             "max_cached_steps",
             "Fn_compute_blocks",
@@ -16,34 +14,36 @@ def load_cache_options_from_yaml(yaml_file_path):
             "residual_diff_threshold",
         ]
         for key in required_keys:
-            if key not in config:
+            if key not in kwargs:
                 raise ValueError(
                     f"Configuration file missing required item: {key}"
                 )
 
-        # Convert cache_type to CacheType enum
-        if isinstance(config["cache_type"], str):
-            try:
-                config["cache_type"] = CacheType[config["cache_type"]]
-            except KeyError:
-                valid_types = [ct.name for ct in CacheType]
-                raise ValueError(
-                    f"Invalid cache_type value: {config['cache_type']}, "
-                    f"valid values are: {valid_types}"
+        cache_context_kwargs = {}
+        if kwargs.get("enable_taylorseer", False):
+            from cache_dit.cache_factory.cache_contexts.calibrators import (
+                TaylorSeerCalibratorConfig,
+            )
+
+            cache_context_kwargs["calibrator_config"] = (
+                TaylorSeerCalibratorConfig(
+                    enable_calibrator=kwargs.pop("enable_taylorseer"),
+                    enable_encoder_calibrator=kwargs.pop(
+                        "enable_encoder_taylorseer", False
+                    ),
+                    calibrator_cache_type=kwargs.pop(
+                        "taylorseer_cache_type", "residual"
+                    ),
+                    taylorseer_order=kwargs.pop("taylorseer_order", 1),
                 )
-        elif not isinstance(config["cache_type"], CacheType):
-            raise ValueError(
-                f"cache_type must be a string or CacheType enum, "
-                f"got: {type(config['cache_type'])}"
             )
 
-        # Handle default value for taylorseer_kwargs
-        if "taylorseer_kwargs" not in config and config.get(
-            "enable_taylorseer", False
-        ):
-            config["taylorseer_kwargs"] = {"n_derivatives": 2}
+        from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+
+        cache_context_kwargs["cache_config"] = BasicCacheConfig()
+        cache_context_kwargs["cache_config"].update(**kwargs)
 
-        return config
+        return cache_context_kwargs
 
     except FileNotFoundError:
         raise FileNotFoundError(
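After this rewrite, `load_cache_options_from_yaml` no longer returns the raw YAML dict: it pops the TaylorSeer-related flags into a `TaylorSeerCalibratorConfig` and feeds the remaining keys into a `BasicCacheConfig`, returning both as cache-context kwargs. A hypothetical `config.yaml` and call are sketched below; the key names come from the required-keys list and the `BasicCacheConfig` fields visible in this diff, while the concrete values are illustrative only.

```python
# config.yaml (hypothetical):
#   max_warmup_steps: 8
#   warmup_interval: 2
#   max_cached_steps: -1
#   Fn_compute_blocks: 8
#   Bn_compute_blocks: 0
#   residual_diff_threshold: 0.08
#   enable_taylorseer: true
#   enable_encoder_taylorseer: false
#   taylorseer_order: 1
from cache_dit.cache_factory.utils import load_cache_options_from_yaml

cache_context_kwargs = load_cache_options_from_yaml("config.yaml")
# Expected shape (sketch): a "cache_config" entry holding a BasicCacheConfig,
# plus a "calibrator_config" entry when enable_taylorseer is true.
```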
cache_dit-1.0.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 1.0.2
+Version: 1.0.3
 Summary: A Unified, Flexible and Training-free Cache Acceleration Framework for 🤗Diffusers.
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -247,6 +247,8 @@ For more advanced features such as **Unified Cache APIs**, **Forward Pattern Mat
 - [⚙️Torch Compile](./docs/User_Guide.md#️torch-compile)
 - [📚API Documents](./docs/User_Guide.md#api-documentation)
 
+
+
 ## 👋Contribute
 <div id="contribute"></div>
 
@@ -260,8 +262,13 @@ How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](
 <img alt="Star History Chart" src="https://api.star-history.com/svg?repos=vipshop/cache-dit&type=Date" width=400px />
 </picture>
 </a>
+
 </div>
 
+## 🎉Projects Using CacheDiT
+
+Here is a curated list of open-source projects integrating **CacheDiT**, including popular repositories like [jetson-containers](https://github.com/dusty-nv/jetson-containers/blob/master/packages/diffusion/cache_edit/build.sh) ![](https://img.shields.io/github/stars/dusty-nv/jetson-containers.svg), [flux-fast](https://github.com/huggingface/flux-fast) ![](https://img.shields.io/github/stars/huggingface/flux-fast.svg), and [sdnext](https://github.com/vladmandic/sdnext/blob/dev/modules/cachedit.py) ![](https://img.shields.io/github/stars/vladmandic/sdnext.svg). **CacheDiT** has also been **recommended** by [Wan2.2](https://github.com/Wan-Video/Wan2.2) ![](https://img.shields.io/github/stars/Wan-Video/Wan2.2.svg), [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning) ![](https://img.shields.io/github/stars/ModelTC/Qwen-Image-Lightning.svg), [Qwen-Image](https://github.com/QwenLM/Qwen-Image) ![](https://img.shields.io/github/stars/QwenLM/Qwen-Image.svg), and <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src="https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg"></a> ![](https://img.shields.io/github/stars/huggingface/diffusers.svg), among others. We would be grateful if you could let us know if you have used CacheDiT.
+
 
 ## ©️Acknowledgements
 
cache_dit-1.0.3.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
 cache_dit/__init__.py,sha256=sHRg0swXZZiw6lvSQ53fcVtN9JRayx0az2lXAz5OOGI,1510
-cache_dit/_version.py,sha256=ZTgKq8LPNy3l9uR2ke-VtLhvvl5l71frQ9wO76n1L5k,704
+cache_dit/_version.py,sha256=l8k828IdTfzXAlmx4oT8GsiIf2eeMAlFDALjoYk-jrU,704
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/utils.py,sha256=AyYRwi5XBxYBH4GaXxOxv9-X24Te_IYOYwh54t_1d3A,10674
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=vy9I6Ofkj9jWeUoOvh-cY5a9QlDDKfj2FVPlVTf7BeA,1390
-cache_dit/cache_factory/cache_interface.py,sha256=KseSPyZ9D3m6pmpE7k-uYr0wfBI-hhscG1Nw54GCHxk,12316
+cache_dit/cache_factory/cache_interface.py,sha256=fJgsOSR_lP0cvNDrR0zMLLoZBZC6tLAQaPQs_oo2R1o,12577
 cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/params_modifier.py,sha256=zYJJsInTYCaYHBZ7mZJOP-PZnkSg3iN1WPewNOayXos,3628
-cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
+cache_dit/cache_factory/utils.py,sha256=mm8JNu6XG_w6nMYvv53TmugSb-l3W7l3Y4rJ2xBgktY,1891
 cache_dit/cache_factory/block_adapters/__init__.py,sha256=vM3aDMzPY79Tw4L0hlV2PdA3MFYomnf0eo0BGBo9P78,18087
 cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=2TVK_KqiYXC7AKZ2s07fzdOzUoeUBc9P1SzQtLVzhf4,22249
 cache_dit/cache_factory/block_adapters/block_registers.py,sha256=2L7QeM4ygnaKQpC9PoJod0QRYyxidUKU2AYpysDCUwE,2572
@@ -17,11 +17,11 @@ cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=HTyZdspd34G6QiJ2q
 cache_dit/cache_factory/cache_blocks/__init__.py,sha256=mivvm8YOfqT7YHs8y_MzGOGztPw8LxAqKGXuSRXxCv0,3032
 cache_dit/cache_factory/cache_blocks/offload_utils.py,sha256=wusgcqaCrwEjvv7Guy-6VXhNOgPPUrBV2sSVuRmGuvo,3513
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
-cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=rfq5-WEt-ErY28vcB4ur9E-uCb6BKP0S8v5lTw61ROk,10555
-cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=StNW2PyDiXEIxZd30byPUrZZ8jgSiuC_yrly2w7X2LQ,16176
+cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=mzs1S2YFwNAPMMTisTKbU6GA5m60J_20CAVy9OIWoMQ,10652
+cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=UeBYEz3hamO3CyVMj1KI7GnxRVQGBjQ5EJi90obVZyI,16306
 cache_dit/cache_factory/cache_blocks/pattern_utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
 cache_dit/cache_factory/cache_contexts/__init__.py,sha256=N3SxFnluXk5q09nhSqKIJCVzEGWzySJWm-vic6dH79E,412
-cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=3EhaMCz3VUQ_NF81VgYwWoSEGIvhScPxPYhjL1OcgxE,15240
+cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=FXvrR3XZr4iIsKSTBngzaRM6_WxiHkRNQ3wAJz40kbk,15798
 cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=X99XnmiY-Us8D2pqJGPKxWcXAhQQpk3xdEWOOOYXIZ4,30465
 cache_dit/cache_factory/cache_contexts/calibrators/__init__.py,sha256=mzYXO8tbytGpJJ9rpPu20kMoj1Iu_7Ym9tjfzV8rA98,5574
 cache_dit/cache_factory/cache_contexts/calibrators/base.py,sha256=mn6ZBkChGpGwN5csrHTUGMoX6BBPvqHXSLbIExiW-EU,748
@@ -50,9 +50,9 @@ cache_dit/metrics/metrics.py,sha256=AZbQyoavE-djvyRUZ_EfCIrWSQbiWQFo7n2dhn7XptE,
 cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
 cache_dit/quantize/quantize_ao.py,sha256=Pr3u3Qr6qLvFkd8k-_rfcz4Mkjlg36U9BHG2t6Bl-6M,6301
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
-cache_dit-1.0.2.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
-cache_dit-1.0.2.dist-info/METADATA,sha256=E6MkP_T9cwJEbqWE1DIRVkQLI7wLWr5zryY2poWgkyw,26766
-cache_dit-1.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-cache_dit-1.0.2.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
-cache_dit-1.0.2.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
-cache_dit-1.0.2.dist-info/RECORD,,
+cache_dit-1.0.3.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-1.0.3.dist-info/METADATA,sha256=gPY4pnvl4dvTTu7Twv6unzEesu1fXCDlGNMlSdFP3Lc,28103
+cache_dit-1.0.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-1.0.3.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-1.0.3.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-1.0.3.dist-info/RECORD,,