cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/__init__.py CHANGED
@@ -6,10 +6,12 @@ except ImportError:
6
6
 
7
7
  from cache_dit.cache_factory import load_options
8
8
  from cache_dit.cache_factory import enable_cache
9
+ from cache_dit.cache_factory import disable_cache
9
10
  from cache_dit.cache_factory import cache_type
10
11
  from cache_dit.cache_factory import block_range
11
12
  from cache_dit.cache_factory import CacheType
12
13
  from cache_dit.cache_factory import BlockAdapter
14
+ from cache_dit.cache_factory import ParamsModifier
13
15
  from cache_dit.cache_factory import ForwardPattern
14
16
  from cache_dit.cache_factory import PatchFunctor
15
17
  from cache_dit.cache_factory import supported_pipelines
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.27'
32
- __version_tuple__ = version_tuple = (0, 2, 27)
31
+ __version__ = version = '0.2.29'
32
+ __version_tuple__ = version_tuple = (0, 2, 29)
33
33
 
34
34
  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -7,14 +7,17 @@ from cache_dit.cache_factory.forward_pattern import ForwardPattern
7
7
  from cache_dit.cache_factory.patch_functors import PatchFunctor
8
8
 
9
9
  from cache_dit.cache_factory.block_adapters import BlockAdapter
10
+ from cache_dit.cache_factory.block_adapters import ParamsModifier
10
11
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
11
12
 
12
13
  from cache_dit.cache_factory.cache_contexts import CachedContext
14
+ from cache_dit.cache_factory.cache_contexts import CachedContextManager
13
15
  from cache_dit.cache_factory.cache_blocks import CachedBlocks
14
16
 
15
17
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
16
18
 
17
19
  from cache_dit.cache_factory.cache_interface import enable_cache
20
+ from cache_dit.cache_factory.cache_interface import disable_cache
18
21
  from cache_dit.cache_factory.cache_interface import supported_pipelines
19
22
  from cache_dit.cache_factory.cache_interface import get_adapter
20
23
 
cache_dit/cache_factory/block_adapters/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from cache_dit.cache_factory.forward_pattern import ForwardPattern
2
2
  from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
3
+ from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
3
4
  from cache_dit.cache_factory.block_adapters.block_registers import (
4
5
  BlockAdapterRegistry,
5
6
  )
@@ -8,23 +9,37 @@ from cache_dit.cache_factory.block_adapters.block_registers import (
8
9
  @BlockAdapterRegistry.register("Flux")
9
10
  def flux_adapter(pipe, **kwargs) -> BlockAdapter:
10
11
  from diffusers import FluxTransformer2DModel
11
- from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
12
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
12
13
 
13
14
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
14
-
15
- return BlockAdapter(
16
- pipe=pipe,
17
- transformer=pipe.transformer,
18
- blocks=(
19
- pipe.transformer.transformer_blocks
20
- + pipe.transformer.single_transformer_blocks
21
- ),
22
- blocks_name="transformer_blocks",
23
- dummy_blocks_names=["single_transformer_blocks"],
24
- patch_functor=FluxPatchFunctor(),
25
- forward_pattern=ForwardPattern.Pattern_1,
26
- disable_patch=kwargs.pop("disable_patch", False),
27
- )
15
+ if is_diffusers_at_least_0_3_5():
16
+ return BlockAdapter(
17
+ pipe=pipe,
18
+ transformer=pipe.transformer,
19
+ blocks=[
20
+ pipe.transformer.transformer_blocks,
21
+ pipe.transformer.single_transformer_blocks,
22
+ ],
23
+ forward_pattern=[
24
+ ForwardPattern.Pattern_1,
25
+ ForwardPattern.Pattern_1,
26
+ ],
27
+ **kwargs,
28
+ )
29
+ else:
30
+ return BlockAdapter(
31
+ pipe=pipe,
32
+ transformer=pipe.transformer,
33
+ blocks=[
34
+ pipe.transformer.transformer_blocks,
35
+ pipe.transformer.single_transformer_blocks,
36
+ ],
37
+ forward_pattern=[
38
+ ForwardPattern.Pattern_1,
39
+ ForwardPattern.Pattern_3,
40
+ ],
41
+ **kwargs,
42
+ )
28
43
 
29
44
 
30
45
  @BlockAdapterRegistry.register("Mochi")
@@ -36,9 +51,8 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
36
51
  pipe=pipe,
37
52
  transformer=pipe.transformer,
38
53
  blocks=pipe.transformer.transformer_blocks,
39
- blocks_name="transformer_blocks",
40
- dummy_blocks_names=[],
41
54
  forward_pattern=ForwardPattern.Pattern_0,
55
+ **kwargs,
42
56
  )
43
57
 
44
58
 
@@ -51,9 +65,8 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
51
65
  pipe=pipe,
52
66
  transformer=pipe.transformer,
53
67
  blocks=pipe.transformer.transformer_blocks,
54
- blocks_name="transformer_blocks",
55
- dummy_blocks_names=[],
56
68
  forward_pattern=ForwardPattern.Pattern_0,
69
+ **kwargs,
57
70
  )
58
71
 
59
72
 
@@ -69,15 +82,27 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
69
82
  (WanTransformer3DModel, WanVACETransformer3DModel),
70
83
  )
71
84
  if getattr(pipe, "transformer_2", None):
72
- # Wan 2.2, cache for low-noise transformer
85
+ assert isinstance(
86
+ pipe.transformer_2,
87
+ (WanTransformer3DModel, WanVACETransformer3DModel),
88
+ )
89
+ # Wan 2.2 MoE
73
90
  return BlockAdapter(
74
91
  pipe=pipe,
75
- transformer=pipe.transformer_2,
76
- blocks=pipe.transformer_2.blocks,
77
- blocks_name="blocks",
78
- dummy_blocks_names=[],
79
- forward_pattern=ForwardPattern.Pattern_2,
92
+ transformer=[
93
+ pipe.transformer,
94
+ pipe.transformer_2,
95
+ ],
96
+ blocks=[
97
+ pipe.transformer.blocks,
98
+ pipe.transformer_2.blocks,
99
+ ],
100
+ forward_pattern=[
101
+ ForwardPattern.Pattern_2,
102
+ ForwardPattern.Pattern_2,
103
+ ],
80
104
  has_separate_cfg=True,
105
+ **kwargs,
81
106
  )
82
107
  else:
83
108
  # Wan 2.1
@@ -85,10 +110,9 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
85
110
  pipe=pipe,
86
111
  transformer=pipe.transformer,
87
112
  blocks=pipe.transformer.blocks,
88
- blocks_name="blocks",
89
- dummy_blocks_names=[],
90
113
  forward_pattern=ForwardPattern.Pattern_2,
91
114
  has_separate_cfg=True,
115
+ **kwargs,
92
116
  )
93
117
 
94
118
 
@@ -99,13 +123,15 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
99
123
  assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
100
124
  return BlockAdapter(
101
125
  pipe=pipe,
102
- blocks=(
103
- pipe.transformer.transformer_blocks
104
- + pipe.transformer.single_transformer_blocks
105
- ),
106
- blocks_name="transformer_blocks",
107
- dummy_blocks_names=["single_transformer_blocks"],
108
- forward_pattern=ForwardPattern.Pattern_0,
126
+ blocks=[
127
+ pipe.transformer.transformer_blocks,
128
+ pipe.transformer.single_transformer_blocks,
129
+ ],
130
+ forward_pattern=[
131
+ ForwardPattern.Pattern_0,
132
+ ForwardPattern.Pattern_0,
133
+ ],
134
+ **kwargs,
109
135
  )
110
136
 
111
137
 
@@ -118,10 +144,9 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
118
144
  pipe=pipe,
119
145
  transformer=pipe.transformer,
120
146
  blocks=pipe.transformer.transformer_blocks,
121
- blocks_name="transformer_blocks",
122
- dummy_blocks_names=[],
123
147
  forward_pattern=ForwardPattern.Pattern_1,
124
148
  has_separate_cfg=True,
149
+ **kwargs,
125
150
  )
126
151
 
127
152
 
@@ -134,9 +159,8 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
134
159
  pipe=pipe,
135
160
  transformer=pipe.transformer,
136
161
  blocks=pipe.transformer.transformer_blocks,
137
- blocks_name="transformer_blocks",
138
- dummy_blocks_names=[],
139
162
  forward_pattern=ForwardPattern.Pattern_2,
163
+ **kwargs,
140
164
  )
141
165
 
142
166
 
@@ -149,9 +173,8 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
149
173
  pipe=pipe,
150
174
  transformer=pipe.transformer,
151
175
  blocks=pipe.transformer.transformer_blocks,
152
- blocks_name="transformer_blocks",
153
- dummy_blocks_names=[],
154
176
  forward_pattern=ForwardPattern.Pattern_2,
177
+ **kwargs,
155
178
  )
156
179
 
157
180
 
@@ -164,9 +187,8 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
164
187
  pipe=pipe,
165
188
  transformer=pipe.transformer,
166
189
  blocks=pipe.transformer.transformer_blocks,
167
- blocks_name="transformer_blocks",
168
- dummy_blocks_names=[],
169
190
  forward_pattern=ForwardPattern.Pattern_0,
191
+ **kwargs,
170
192
  )
171
193
 
172
194
 
@@ -179,10 +201,9 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
179
201
  pipe=pipe,
180
202
  transformer=pipe.transformer,
181
203
  blocks=pipe.transformer.transformer_blocks,
182
- blocks_name="transformer_blocks",
183
- dummy_blocks_names=[],
184
204
  forward_pattern=ForwardPattern.Pattern_0,
185
205
  has_separate_cfg=True,
206
+ **kwargs,
186
207
  )
187
208
 
188
209
 
@@ -195,10 +216,9 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
195
216
  pipe=pipe,
196
217
  transformer=pipe.transformer,
197
218
  blocks=pipe.transformer.transformer_blocks,
198
- blocks_name="transformer_blocks",
199
- dummy_blocks_names=[],
200
219
  forward_pattern=ForwardPattern.Pattern_2,
201
220
  has_separate_cfg=True,
221
+ **kwargs,
202
222
  )
203
223
 
204
224
 
@@ -211,9 +231,8 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
211
231
  pipe=pipe,
212
232
  transformer=pipe.transformer,
213
233
  blocks=pipe.transformer.transformer_blocks,
214
- blocks_name="transformer_blocks",
215
- dummy_blocks_names=[],
216
234
  forward_pattern=ForwardPattern.Pattern_0,
235
+ **kwargs,
217
236
  )
218
237
 
219
238
 
@@ -226,10 +245,9 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
226
245
  pipe=pipe,
227
246
  transformer=pipe.transformer,
228
247
  blocks=pipe.transformer.blocks,
229
- blocks_name="blocks",
230
- dummy_blocks_names=[],
231
248
  forward_pattern=ForwardPattern.Pattern_2,
232
249
  has_separate_cfg=True,
250
+ **kwargs,
233
251
  )
234
252
 
235
253
 
@@ -242,9 +260,8 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
242
260
  pipe=pipe,
243
261
  transformer=pipe.transformer,
244
262
  blocks=pipe.transformer.transformer_blocks,
245
- blocks_name="transformer_blocks",
246
- dummy_blocks_names=[],
247
263
  forward_pattern=ForwardPattern.Pattern_1,
264
+ **kwargs,
248
265
  )
249
266
 
250
267
 
@@ -257,9 +274,8 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
257
274
  pipe=pipe,
258
275
  transformer=pipe.transformer,
259
276
  blocks=pipe.transformer.transformer_blocks,
260
- blocks_name="transformer_blocks",
261
- dummy_blocks_names=[],
262
277
  forward_pattern=ForwardPattern.Pattern_0,
278
+ **kwargs,
263
279
  )
264
280
 
265
281
 
@@ -272,9 +288,8 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
272
288
  pipe=pipe,
273
289
  transformer=pipe.transformer,
274
290
  blocks=pipe.transformer.transformer_blocks,
275
- blocks_name="transformer_blocks",
276
- dummy_blocks_names=[],
277
291
  forward_pattern=ForwardPattern.Pattern_3,
292
+ **kwargs,
278
293
  )
279
294
 
280
295
 
@@ -287,9 +302,8 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
287
302
  pipe=pipe,
288
303
  transformer=pipe.transformer,
289
304
  blocks=pipe.transformer.transformer_layers,
290
- blocks_name="transformer_layers",
291
- dummy_blocks_names=[],
292
305
  forward_pattern=ForwardPattern.Pattern_3,
306
+ **kwargs,
293
307
  )
294
308
 
295
309
 
@@ -301,13 +315,15 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
301
315
  return BlockAdapter(
302
316
  pipe=pipe,
303
317
  transformer=pipe.transformer,
304
- blocks=(
305
- pipe.transformer.transformer_blocks
306
- + pipe.transformer.single_transformer_blocks
307
- ),
308
- blocks_name="transformer_blocks",
309
- dummy_blocks_names=["single_transformer_blocks"],
310
- forward_pattern=ForwardPattern.Pattern_0,
318
+ blocks=[
319
+ pipe.transformer.transformer_blocks,
320
+ pipe.transformer.single_transformer_blocks,
321
+ ],
322
+ forward_pattern=[
323
+ ForwardPattern.Pattern_0,
324
+ ForwardPattern.Pattern_0,
325
+ ],
326
+ **kwargs,
311
327
  )
312
328
 
313
329
 
@@ -323,9 +339,8 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
323
339
  pipe=pipe,
324
340
  transformer=pipe.transformer,
325
341
  blocks=pipe.transformer.blocks,
326
- blocks_name="blocks",
327
- dummy_blocks_names=[],
328
342
  forward_pattern=ForwardPattern.Pattern_3,
343
+ **kwargs,
329
344
  )
330
345
 
331
346
 
@@ -338,14 +353,13 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
338
353
  pipe=pipe,
339
354
  transformer=pipe.transformer,
340
355
  blocks=pipe.transformer.blocks,
341
- blocks_name="blocks",
342
- dummy_blocks_names=[],
343
356
  forward_pattern=ForwardPattern.Pattern_3,
357
+ **kwargs,
344
358
  )
345
359
 
346
360
 
347
361
  @BlockAdapterRegistry.register("Lumina")
348
- def lumina_adapter(pipe) -> BlockAdapter:
362
+ def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
349
363
  from diffusers import LuminaNextDiT2DModel
350
364
 
351
365
  assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
@@ -353,9 +367,8 @@ def lumina_adapter(pipe) -> BlockAdapter:
353
367
  pipe=pipe,
354
368
  transformer=pipe.transformer,
355
369
  blocks=pipe.transformer.layers,
356
- blocks_name="layers",
357
- dummy_blocks_names=[],
358
370
  forward_pattern=ForwardPattern.Pattern_3,
371
+ **kwargs,
359
372
  )
360
373
 
361
374
 
@@ -368,9 +381,8 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
368
381
  pipe=pipe,
369
382
  transformer=pipe.transformer,
370
383
  blocks=pipe.transformer.layers,
371
- blocks_name="layers",
372
- dummy_blocks_names=[],
373
384
  forward_pattern=ForwardPattern.Pattern_3,
385
+ **kwargs,
374
386
  )
375
387
 
376
388
 
@@ -383,9 +395,8 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
383
395
  pipe=pipe,
384
396
  transformer=pipe.transformer,
385
397
  blocks=pipe.transformer.layers,
386
- blocks_name="layers",
387
- dummy_blocks_names=[],
388
398
  forward_pattern=ForwardPattern.Pattern_3,
399
+ **kwargs,
389
400
  )
390
401
 
391
402
 
@@ -398,9 +409,8 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
398
409
  pipe=pipe,
399
410
  transformer=pipe.transformer,
400
411
  blocks=pipe.transformer.transformer_blocks,
401
- blocks_name="transformer_blocks",
402
- dummy_blocks_names=[],
403
412
  forward_pattern=ForwardPattern.Pattern_3,
413
+ **kwargs,
404
414
  )
405
415
 
406
416
 
@@ -413,9 +423,8 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
413
423
  pipe=pipe,
414
424
  transformer=pipe.transformer,
415
425
  blocks=pipe.transformer.transformer_blocks,
416
- blocks_name="transformer_blocks",
417
- dummy_blocks_names=[],
418
426
  forward_pattern=ForwardPattern.Pattern_3,
427
+ **kwargs,
419
428
  )
420
429
 
421
430
 
@@ -428,9 +437,8 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
428
437
  pipe=pipe,
429
438
  transformer=pipe.prior,
430
439
  blocks=pipe.prior.transformer_blocks,
431
- blocks_name="transformer_blocks",
432
- dummy_blocks_names=[],
433
440
  forward_pattern=ForwardPattern.Pattern_3,
441
+ **kwargs,
434
442
  )
435
443
 
436
444
 
@@ -443,29 +451,28 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
443
451
  pipe=pipe,
444
452
  transformer=pipe.transformer,
445
453
  blocks=pipe.transformer.transformer_blocks,
446
- blocks_name="transformer_blocks",
447
- dummy_blocks_names=[],
448
454
  forward_pattern=ForwardPattern.Pattern_3,
455
+ **kwargs,
449
456
  )
450
457
 
451
458
 
452
459
  @BlockAdapterRegistry.register("VisualCloze")
453
460
  def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
454
461
  from diffusers import FluxTransformer2DModel
455
- from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
456
462
 
457
463
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
458
464
  return BlockAdapter(
459
465
  pipe=pipe,
460
466
  transformer=pipe.transformer,
461
- blocks=(
462
- pipe.transformer.transformer_blocks
463
- + pipe.transformer.single_transformer_blocks
464
- ),
465
- blocks_name="transformer_blocks",
466
- dummy_blocks_names=["single_transformer_blocks"],
467
- patch_functor=FluxPatchFunctor(),
468
- forward_pattern=ForwardPattern.Pattern_1,
467
+ blocks=[
468
+ pipe.transformer.transformer_blocks,
469
+ pipe.transformer.single_transformer_blocks,
470
+ ],
471
+ forward_pattern=[
472
+ ForwardPattern.Pattern_1,
473
+ ForwardPattern.Pattern_3,
474
+ ],
475
+ **kwargs,
469
476
  )
470
477
 
471
478
 
@@ -477,19 +484,9 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
477
484
  return BlockAdapter(
478
485
  pipe=pipe,
479
486
  transformer=pipe.transformer,
480
- blocks=[
481
- # Only 4 joint blocks, apply no-cache
482
- # pipe.transformer.joint_transformer_blocks,
483
- pipe.transformer.single_transformer_blocks,
484
- ],
485
- blocks_name=[
486
- # "joint_transformer_blocks",
487
- "single_transformer_blocks",
488
- ],
489
- forward_pattern=[
490
- # ForwardPattern.Pattern_1,
491
- ForwardPattern.Pattern_3,
492
- ],
487
+ blocks=pipe.transformer.single_transformer_blocks,
488
+ forward_pattern=ForwardPattern.Pattern_3,
489
+ **kwargs,
493
490
  )
494
491
 
495
492
 
@@ -505,15 +502,12 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
505
502
  pipe.transformer.transformer_blocks,
506
503
  pipe.transformer.single_transformer_blocks,
507
504
  ],
508
- blocks_name=[
509
- "transformer_blocks",
510
- "single_transformer_blocks",
511
- ],
512
505
  forward_pattern=[
513
506
  ForwardPattern.Pattern_1,
514
507
  ForwardPattern.Pattern_3,
515
508
  ],
516
509
  has_separate_cfg=True,
510
+ **kwargs,
517
511
  )
518
512
 
519
513
 
@@ -529,10 +523,10 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
529
523
  pipe.transformer.double_stream_blocks,
530
524
  pipe.transformer.single_stream_blocks,
531
525
  ],
532
- dummy_blocks_names=[],
533
526
  forward_pattern=[
534
527
  ForwardPattern.Pattern_4,
535
528
  ForwardPattern.Pattern_3,
536
529
  ],
537
530
  check_num_outputs=False,
531
+ **kwargs,
538
532
  )