cache-dit 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/__init__.py CHANGED
@@ -6,6 +6,7 @@ except ImportError:
6
6
 
7
7
  from cache_dit.cache_factory import load_options
8
8
  from cache_dit.cache_factory import enable_cache
9
+ from cache_dit.cache_factory import disable_cache
9
10
  from cache_dit.cache_factory import cache_type
10
11
  from cache_dit.cache_factory import block_range
11
12
  from cache_dit.cache_factory import CacheType
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.28'
32
- __version_tuple__ = version_tuple = (0, 2, 28)
31
+ __version__ = version = '0.2.29'
32
+ __version_tuple__ = version_tuple = (0, 2, 29)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -17,6 +17,7 @@ from cache_dit.cache_factory.cache_blocks import CachedBlocks
17
17
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
18
18
 
19
19
  from cache_dit.cache_factory.cache_interface import enable_cache
20
+ from cache_dit.cache_factory.cache_interface import disable_cache
20
21
  from cache_dit.cache_factory.cache_interface import supported_pipelines
21
22
  from cache_dit.cache_factory.cache_interface import get_adapter
22
23
 
@@ -9,23 +9,37 @@ from cache_dit.cache_factory.block_adapters.block_registers import (
9
9
  @BlockAdapterRegistry.register("Flux")
10
10
  def flux_adapter(pipe, **kwargs) -> BlockAdapter:
11
11
  from diffusers import FluxTransformer2DModel
12
- from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
12
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
13
13
 
14
14
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
15
-
16
- return BlockAdapter(
17
- pipe=pipe,
18
- transformer=pipe.transformer,
19
- blocks=(
20
- pipe.transformer.transformer_blocks
21
- + pipe.transformer.single_transformer_blocks
22
- ),
23
- blocks_name="transformer_blocks",
24
- dummy_blocks_names=["single_transformer_blocks"],
25
- patch_functor=FluxPatchFunctor(),
26
- forward_pattern=ForwardPattern.Pattern_1,
27
- disable_patch=kwargs.pop("disable_patch", False),
28
- )
15
+ if is_diffusers_at_least_0_3_5():
16
+ return BlockAdapter(
17
+ pipe=pipe,
18
+ transformer=pipe.transformer,
19
+ blocks=[
20
+ pipe.transformer.transformer_blocks,
21
+ pipe.transformer.single_transformer_blocks,
22
+ ],
23
+ forward_pattern=[
24
+ ForwardPattern.Pattern_1,
25
+ ForwardPattern.Pattern_1,
26
+ ],
27
+ **kwargs,
28
+ )
29
+ else:
30
+ return BlockAdapter(
31
+ pipe=pipe,
32
+ transformer=pipe.transformer,
33
+ blocks=[
34
+ pipe.transformer.transformer_blocks,
35
+ pipe.transformer.single_transformer_blocks,
36
+ ],
37
+ forward_pattern=[
38
+ ForwardPattern.Pattern_1,
39
+ ForwardPattern.Pattern_3,
40
+ ],
41
+ **kwargs,
42
+ )
29
43
 
30
44
 
31
45
  @BlockAdapterRegistry.register("Mochi")
@@ -37,9 +51,8 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
37
51
  pipe=pipe,
38
52
  transformer=pipe.transformer,
39
53
  blocks=pipe.transformer.transformer_blocks,
40
- blocks_name="transformer_blocks",
41
- dummy_blocks_names=[],
42
54
  forward_pattern=ForwardPattern.Pattern_0,
55
+ **kwargs,
43
56
  )
44
57
 
45
58
 
@@ -52,9 +65,8 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
52
65
  pipe=pipe,
53
66
  transformer=pipe.transformer,
54
67
  blocks=pipe.transformer.transformer_blocks,
55
- blocks_name="transformer_blocks",
56
- dummy_blocks_names=[],
57
68
  forward_pattern=ForwardPattern.Pattern_0,
69
+ **kwargs,
58
70
  )
59
71
 
60
72
 
@@ -85,16 +97,12 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
85
97
  pipe.transformer.blocks,
86
98
  pipe.transformer_2.blocks,
87
99
  ],
88
- blocks_name=[
89
- "blocks",
90
- "blocks",
91
- ],
92
100
  forward_pattern=[
93
101
  ForwardPattern.Pattern_2,
94
102
  ForwardPattern.Pattern_2,
95
103
  ],
96
- dummy_blocks_names=[],
97
104
  has_separate_cfg=True,
105
+ **kwargs,
98
106
  )
99
107
  else:
100
108
  # Wan 2.1
@@ -102,10 +110,9 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
102
110
  pipe=pipe,
103
111
  transformer=pipe.transformer,
104
112
  blocks=pipe.transformer.blocks,
105
- blocks_name="blocks",
106
- dummy_blocks_names=[],
107
113
  forward_pattern=ForwardPattern.Pattern_2,
108
114
  has_separate_cfg=True,
115
+ **kwargs,
109
116
  )
110
117
 
111
118
 
@@ -116,13 +123,15 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
116
123
  assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
117
124
  return BlockAdapter(
118
125
  pipe=pipe,
119
- blocks=(
120
- pipe.transformer.transformer_blocks
121
- + pipe.transformer.single_transformer_blocks
122
- ),
123
- blocks_name="transformer_blocks",
124
- dummy_blocks_names=["single_transformer_blocks"],
125
- forward_pattern=ForwardPattern.Pattern_0,
126
+ blocks=[
127
+ pipe.transformer.transformer_blocks,
128
+ pipe.transformer.single_transformer_blocks,
129
+ ],
130
+ forward_pattern=[
131
+ ForwardPattern.Pattern_0,
132
+ ForwardPattern.Pattern_0,
133
+ ],
134
+ **kwargs,
126
135
  )
127
136
 
128
137
 
@@ -135,10 +144,9 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
135
144
  pipe=pipe,
136
145
  transformer=pipe.transformer,
137
146
  blocks=pipe.transformer.transformer_blocks,
138
- blocks_name="transformer_blocks",
139
- dummy_blocks_names=[],
140
147
  forward_pattern=ForwardPattern.Pattern_1,
141
148
  has_separate_cfg=True,
149
+ **kwargs,
142
150
  )
143
151
 
144
152
 
@@ -151,9 +159,8 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
151
159
  pipe=pipe,
152
160
  transformer=pipe.transformer,
153
161
  blocks=pipe.transformer.transformer_blocks,
154
- blocks_name="transformer_blocks",
155
- dummy_blocks_names=[],
156
162
  forward_pattern=ForwardPattern.Pattern_2,
163
+ **kwargs,
157
164
  )
158
165
 
159
166
 
@@ -166,9 +173,8 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
166
173
  pipe=pipe,
167
174
  transformer=pipe.transformer,
168
175
  blocks=pipe.transformer.transformer_blocks,
169
- blocks_name="transformer_blocks",
170
- dummy_blocks_names=[],
171
176
  forward_pattern=ForwardPattern.Pattern_2,
177
+ **kwargs,
172
178
  )
173
179
 
174
180
 
@@ -181,9 +187,8 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
181
187
  pipe=pipe,
182
188
  transformer=pipe.transformer,
183
189
  blocks=pipe.transformer.transformer_blocks,
184
- blocks_name="transformer_blocks",
185
- dummy_blocks_names=[],
186
190
  forward_pattern=ForwardPattern.Pattern_0,
191
+ **kwargs,
187
192
  )
188
193
 
189
194
 
@@ -196,10 +201,9 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
196
201
  pipe=pipe,
197
202
  transformer=pipe.transformer,
198
203
  blocks=pipe.transformer.transformer_blocks,
199
- blocks_name="transformer_blocks",
200
- dummy_blocks_names=[],
201
204
  forward_pattern=ForwardPattern.Pattern_0,
202
205
  has_separate_cfg=True,
206
+ **kwargs,
203
207
  )
204
208
 
205
209
 
@@ -212,10 +216,9 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
212
216
  pipe=pipe,
213
217
  transformer=pipe.transformer,
214
218
  blocks=pipe.transformer.transformer_blocks,
215
- blocks_name="transformer_blocks",
216
- dummy_blocks_names=[],
217
219
  forward_pattern=ForwardPattern.Pattern_2,
218
220
  has_separate_cfg=True,
221
+ **kwargs,
219
222
  )
220
223
 
221
224
 
@@ -228,9 +231,8 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
228
231
  pipe=pipe,
229
232
  transformer=pipe.transformer,
230
233
  blocks=pipe.transformer.transformer_blocks,
231
- blocks_name="transformer_blocks",
232
- dummy_blocks_names=[],
233
234
  forward_pattern=ForwardPattern.Pattern_0,
235
+ **kwargs,
234
236
  )
235
237
 
236
238
 
@@ -243,10 +245,9 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
243
245
  pipe=pipe,
244
246
  transformer=pipe.transformer,
245
247
  blocks=pipe.transformer.blocks,
246
- blocks_name="blocks",
247
- dummy_blocks_names=[],
248
248
  forward_pattern=ForwardPattern.Pattern_2,
249
249
  has_separate_cfg=True,
250
+ **kwargs,
250
251
  )
251
252
 
252
253
 
@@ -259,9 +260,8 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
259
260
  pipe=pipe,
260
261
  transformer=pipe.transformer,
261
262
  blocks=pipe.transformer.transformer_blocks,
262
- blocks_name="transformer_blocks",
263
- dummy_blocks_names=[],
264
263
  forward_pattern=ForwardPattern.Pattern_1,
264
+ **kwargs,
265
265
  )
266
266
 
267
267
 
@@ -274,9 +274,8 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
274
274
  pipe=pipe,
275
275
  transformer=pipe.transformer,
276
276
  blocks=pipe.transformer.transformer_blocks,
277
- blocks_name="transformer_blocks",
278
- dummy_blocks_names=[],
279
277
  forward_pattern=ForwardPattern.Pattern_0,
278
+ **kwargs,
280
279
  )
281
280
 
282
281
 
@@ -289,9 +288,8 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
289
288
  pipe=pipe,
290
289
  transformer=pipe.transformer,
291
290
  blocks=pipe.transformer.transformer_blocks,
292
- blocks_name="transformer_blocks",
293
- dummy_blocks_names=[],
294
291
  forward_pattern=ForwardPattern.Pattern_3,
292
+ **kwargs,
295
293
  )
296
294
 
297
295
 
@@ -304,9 +302,8 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
304
302
  pipe=pipe,
305
303
  transformer=pipe.transformer,
306
304
  blocks=pipe.transformer.transformer_layers,
307
- blocks_name="transformer_layers",
308
- dummy_blocks_names=[],
309
305
  forward_pattern=ForwardPattern.Pattern_3,
306
+ **kwargs,
310
307
  )
311
308
 
312
309
 
@@ -318,13 +315,15 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
318
315
  return BlockAdapter(
319
316
  pipe=pipe,
320
317
  transformer=pipe.transformer,
321
- blocks=(
322
- pipe.transformer.transformer_blocks
323
- + pipe.transformer.single_transformer_blocks
324
- ),
325
- blocks_name="transformer_blocks",
326
- dummy_blocks_names=["single_transformer_blocks"],
327
- forward_pattern=ForwardPattern.Pattern_0,
318
+ blocks=[
319
+ pipe.transformer.transformer_blocks,
320
+ pipe.transformer.single_transformer_blocks,
321
+ ],
322
+ forward_pattern=[
323
+ ForwardPattern.Pattern_0,
324
+ ForwardPattern.Pattern_0,
325
+ ],
326
+ **kwargs,
328
327
  )
329
328
 
330
329
 
@@ -340,9 +339,8 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
340
339
  pipe=pipe,
341
340
  transformer=pipe.transformer,
342
341
  blocks=pipe.transformer.blocks,
343
- blocks_name="blocks",
344
- dummy_blocks_names=[],
345
342
  forward_pattern=ForwardPattern.Pattern_3,
343
+ **kwargs,
346
344
  )
347
345
 
348
346
 
@@ -355,14 +353,13 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
355
353
  pipe=pipe,
356
354
  transformer=pipe.transformer,
357
355
  blocks=pipe.transformer.blocks,
358
- blocks_name="blocks",
359
- dummy_blocks_names=[],
360
356
  forward_pattern=ForwardPattern.Pattern_3,
357
+ **kwargs,
361
358
  )
362
359
 
363
360
 
364
361
  @BlockAdapterRegistry.register("Lumina")
365
- def lumina_adapter(pipe) -> BlockAdapter:
362
+ def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
366
363
  from diffusers import LuminaNextDiT2DModel
367
364
 
368
365
  assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
@@ -370,9 +367,8 @@ def lumina_adapter(pipe) -> BlockAdapter:
370
367
  pipe=pipe,
371
368
  transformer=pipe.transformer,
372
369
  blocks=pipe.transformer.layers,
373
- blocks_name="layers",
374
- dummy_blocks_names=[],
375
370
  forward_pattern=ForwardPattern.Pattern_3,
371
+ **kwargs,
376
372
  )
377
373
 
378
374
 
@@ -385,9 +381,8 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
385
381
  pipe=pipe,
386
382
  transformer=pipe.transformer,
387
383
  blocks=pipe.transformer.layers,
388
- blocks_name="layers",
389
- dummy_blocks_names=[],
390
384
  forward_pattern=ForwardPattern.Pattern_3,
385
+ **kwargs,
391
386
  )
392
387
 
393
388
 
@@ -400,9 +395,8 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
400
395
  pipe=pipe,
401
396
  transformer=pipe.transformer,
402
397
  blocks=pipe.transformer.layers,
403
- blocks_name="layers",
404
- dummy_blocks_names=[],
405
398
  forward_pattern=ForwardPattern.Pattern_3,
399
+ **kwargs,
406
400
  )
407
401
 
408
402
 
@@ -415,9 +409,8 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
415
409
  pipe=pipe,
416
410
  transformer=pipe.transformer,
417
411
  blocks=pipe.transformer.transformer_blocks,
418
- blocks_name="transformer_blocks",
419
- dummy_blocks_names=[],
420
412
  forward_pattern=ForwardPattern.Pattern_3,
413
+ **kwargs,
421
414
  )
422
415
 
423
416
 
@@ -430,9 +423,8 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
430
423
  pipe=pipe,
431
424
  transformer=pipe.transformer,
432
425
  blocks=pipe.transformer.transformer_blocks,
433
- blocks_name="transformer_blocks",
434
- dummy_blocks_names=[],
435
426
  forward_pattern=ForwardPattern.Pattern_3,
427
+ **kwargs,
436
428
  )
437
429
 
438
430
 
@@ -445,9 +437,8 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
445
437
  pipe=pipe,
446
438
  transformer=pipe.prior,
447
439
  blocks=pipe.prior.transformer_blocks,
448
- blocks_name="transformer_blocks",
449
- dummy_blocks_names=[],
450
440
  forward_pattern=ForwardPattern.Pattern_3,
441
+ **kwargs,
451
442
  )
452
443
 
453
444
 
@@ -460,29 +451,28 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
460
451
  pipe=pipe,
461
452
  transformer=pipe.transformer,
462
453
  blocks=pipe.transformer.transformer_blocks,
463
- blocks_name="transformer_blocks",
464
- dummy_blocks_names=[],
465
454
  forward_pattern=ForwardPattern.Pattern_3,
455
+ **kwargs,
466
456
  )
467
457
 
468
458
 
469
459
  @BlockAdapterRegistry.register("VisualCloze")
470
460
  def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
471
461
  from diffusers import FluxTransformer2DModel
472
- from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
473
462
 
474
463
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
475
464
  return BlockAdapter(
476
465
  pipe=pipe,
477
466
  transformer=pipe.transformer,
478
- blocks=(
479
- pipe.transformer.transformer_blocks
480
- + pipe.transformer.single_transformer_blocks
481
- ),
482
- blocks_name="transformer_blocks",
483
- dummy_blocks_names=["single_transformer_blocks"],
484
- patch_functor=FluxPatchFunctor(),
485
- forward_pattern=ForwardPattern.Pattern_1,
467
+ blocks=[
468
+ pipe.transformer.transformer_blocks,
469
+ pipe.transformer.single_transformer_blocks,
470
+ ],
471
+ forward_pattern=[
472
+ ForwardPattern.Pattern_1,
473
+ ForwardPattern.Pattern_3,
474
+ ],
475
+ **kwargs,
486
476
  )
487
477
 
488
478
 
@@ -494,19 +484,9 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
494
484
  return BlockAdapter(
495
485
  pipe=pipe,
496
486
  transformer=pipe.transformer,
497
- blocks=[
498
- # Only 4 joint blocks, apply no-cache
499
- # pipe.transformer.joint_transformer_blocks,
500
- pipe.transformer.single_transformer_blocks,
501
- ],
502
- blocks_name=[
503
- # "joint_transformer_blocks",
504
- "single_transformer_blocks",
505
- ],
506
- forward_pattern=[
507
- # ForwardPattern.Pattern_1,
508
- ForwardPattern.Pattern_3,
509
- ],
487
+ blocks=pipe.transformer.single_transformer_blocks,
488
+ forward_pattern=ForwardPattern.Pattern_3,
489
+ **kwargs,
510
490
  )
511
491
 
512
492
 
@@ -522,15 +502,12 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
522
502
  pipe.transformer.transformer_blocks,
523
503
  pipe.transformer.single_transformer_blocks,
524
504
  ],
525
- blocks_name=[
526
- "transformer_blocks",
527
- "single_transformer_blocks",
528
- ],
529
505
  forward_pattern=[
530
506
  ForwardPattern.Pattern_1,
531
507
  ForwardPattern.Pattern_3,
532
508
  ],
533
509
  has_separate_cfg=True,
510
+ **kwargs,
534
511
  )
535
512
 
536
513
 
@@ -546,10 +523,10 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
546
523
  pipe.transformer.double_stream_blocks,
547
524
  pipe.transformer.single_stream_blocks,
548
525
  ],
549
- dummy_blocks_names=[],
550
526
  forward_pattern=[
551
527
  ForwardPattern.Pattern_4,
552
528
  ForwardPattern.Pattern_3,
553
529
  ],
554
530
  check_num_outputs=False,
531
+ **kwargs,
555
532
  )