cache-dit 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit has been flagged as potentially problematic; see the package registry's page for this release for more details.

cache_dit/__init__.py CHANGED
@@ -6,6 +6,7 @@ except ImportError:
6
6
 
7
7
  from cache_dit.cache_factory import load_options
8
8
  from cache_dit.cache_factory import enable_cache
9
+ from cache_dit.cache_factory import disable_cache
9
10
  from cache_dit.cache_factory import cache_type
10
11
  from cache_dit.cache_factory import block_range
11
12
  from cache_dit.cache_factory import CacheType
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.28'
32
- __version_tuple__ = version_tuple = (0, 2, 28)
31
+ __version__ = version = '0.2.30'
32
+ __version_tuple__ = version_tuple = (0, 2, 30)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -17,6 +17,7 @@ from cache_dit.cache_factory.cache_blocks import CachedBlocks
17
17
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
18
18
 
19
19
  from cache_dit.cache_factory.cache_interface import enable_cache
20
+ from cache_dit.cache_factory.cache_interface import disable_cache
20
21
  from cache_dit.cache_factory.cache_interface import supported_pipelines
21
22
  from cache_dit.cache_factory.cache_interface import get_adapter
22
23
 
@@ -9,23 +9,37 @@ from cache_dit.cache_factory.block_adapters.block_registers import (
9
9
  @BlockAdapterRegistry.register("Flux")
10
10
  def flux_adapter(pipe, **kwargs) -> BlockAdapter:
11
11
  from diffusers import FluxTransformer2DModel
12
- from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
12
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
13
13
 
14
14
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
15
-
16
- return BlockAdapter(
17
- pipe=pipe,
18
- transformer=pipe.transformer,
19
- blocks=(
20
- pipe.transformer.transformer_blocks
21
- + pipe.transformer.single_transformer_blocks
22
- ),
23
- blocks_name="transformer_blocks",
24
- dummy_blocks_names=["single_transformer_blocks"],
25
- patch_functor=FluxPatchFunctor(),
26
- forward_pattern=ForwardPattern.Pattern_1,
27
- disable_patch=kwargs.pop("disable_patch", False),
28
- )
15
+ if is_diffusers_at_least_0_3_5():
16
+ return BlockAdapter(
17
+ pipe=pipe,
18
+ transformer=pipe.transformer,
19
+ blocks=[
20
+ pipe.transformer.transformer_blocks,
21
+ pipe.transformer.single_transformer_blocks,
22
+ ],
23
+ forward_pattern=[
24
+ ForwardPattern.Pattern_1,
25
+ ForwardPattern.Pattern_1,
26
+ ],
27
+ **kwargs,
28
+ )
29
+ else:
30
+ return BlockAdapter(
31
+ pipe=pipe,
32
+ transformer=pipe.transformer,
33
+ blocks=[
34
+ pipe.transformer.transformer_blocks,
35
+ pipe.transformer.single_transformer_blocks,
36
+ ],
37
+ forward_pattern=[
38
+ ForwardPattern.Pattern_1,
39
+ ForwardPattern.Pattern_3,
40
+ ],
41
+ **kwargs,
42
+ )
29
43
 
30
44
 
31
45
  @BlockAdapterRegistry.register("Mochi")
@@ -37,9 +51,8 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
37
51
  pipe=pipe,
38
52
  transformer=pipe.transformer,
39
53
  blocks=pipe.transformer.transformer_blocks,
40
- blocks_name="transformer_blocks",
41
- dummy_blocks_names=[],
42
54
  forward_pattern=ForwardPattern.Pattern_0,
55
+ **kwargs,
43
56
  )
44
57
 
45
58
 
@@ -52,9 +65,8 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
52
65
  pipe=pipe,
53
66
  transformer=pipe.transformer,
54
67
  blocks=pipe.transformer.transformer_blocks,
55
- blocks_name="transformer_blocks",
56
- dummy_blocks_names=[],
57
68
  forward_pattern=ForwardPattern.Pattern_0,
69
+ **kwargs,
58
70
  )
59
71
 
60
72
 
@@ -85,16 +97,12 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
85
97
  pipe.transformer.blocks,
86
98
  pipe.transformer_2.blocks,
87
99
  ],
88
- blocks_name=[
89
- "blocks",
90
- "blocks",
91
- ],
92
100
  forward_pattern=[
93
101
  ForwardPattern.Pattern_2,
94
102
  ForwardPattern.Pattern_2,
95
103
  ],
96
- dummy_blocks_names=[],
97
104
  has_separate_cfg=True,
105
+ **kwargs,
98
106
  )
99
107
  else:
100
108
  # Wan 2.1
@@ -102,10 +110,9 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
102
110
  pipe=pipe,
103
111
  transformer=pipe.transformer,
104
112
  blocks=pipe.transformer.blocks,
105
- blocks_name="blocks",
106
- dummy_blocks_names=[],
107
113
  forward_pattern=ForwardPattern.Pattern_2,
108
114
  has_separate_cfg=True,
115
+ **kwargs,
109
116
  )
110
117
 
111
118
 
@@ -116,13 +123,18 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
116
123
  assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
117
124
  return BlockAdapter(
118
125
  pipe=pipe,
119
- blocks=(
120
- pipe.transformer.transformer_blocks
121
- + pipe.transformer.single_transformer_blocks
122
- ),
123
- blocks_name="transformer_blocks",
124
- dummy_blocks_names=["single_transformer_blocks"],
125
- forward_pattern=ForwardPattern.Pattern_0,
126
+ transformer=pipe.transformer,
127
+ blocks=[
128
+ pipe.transformer.transformer_blocks,
129
+ pipe.transformer.single_transformer_blocks,
130
+ ],
131
+ forward_pattern=[
132
+ ForwardPattern.Pattern_0,
133
+ ForwardPattern.Pattern_0,
134
+ ],
135
+ # The type hint in diffusers is wrong
136
+ check_num_outputs=False,
137
+ **kwargs,
126
138
  )
127
139
 
128
140
 
@@ -135,10 +147,9 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
135
147
  pipe=pipe,
136
148
  transformer=pipe.transformer,
137
149
  blocks=pipe.transformer.transformer_blocks,
138
- blocks_name="transformer_blocks",
139
- dummy_blocks_names=[],
140
150
  forward_pattern=ForwardPattern.Pattern_1,
141
151
  has_separate_cfg=True,
152
+ **kwargs,
142
153
  )
143
154
 
144
155
 
@@ -151,9 +162,8 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
151
162
  pipe=pipe,
152
163
  transformer=pipe.transformer,
153
164
  blocks=pipe.transformer.transformer_blocks,
154
- blocks_name="transformer_blocks",
155
- dummy_blocks_names=[],
156
165
  forward_pattern=ForwardPattern.Pattern_2,
166
+ **kwargs,
157
167
  )
158
168
 
159
169
 
@@ -166,9 +176,8 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
166
176
  pipe=pipe,
167
177
  transformer=pipe.transformer,
168
178
  blocks=pipe.transformer.transformer_blocks,
169
- blocks_name="transformer_blocks",
170
- dummy_blocks_names=[],
171
179
  forward_pattern=ForwardPattern.Pattern_2,
180
+ **kwargs,
172
181
  )
173
182
 
174
183
 
@@ -181,9 +190,8 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
181
190
  pipe=pipe,
182
191
  transformer=pipe.transformer,
183
192
  blocks=pipe.transformer.transformer_blocks,
184
- blocks_name="transformer_blocks",
185
- dummy_blocks_names=[],
186
193
  forward_pattern=ForwardPattern.Pattern_0,
194
+ **kwargs,
187
195
  )
188
196
 
189
197
 
@@ -196,10 +204,9 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
196
204
  pipe=pipe,
197
205
  transformer=pipe.transformer,
198
206
  blocks=pipe.transformer.transformer_blocks,
199
- blocks_name="transformer_blocks",
200
- dummy_blocks_names=[],
201
207
  forward_pattern=ForwardPattern.Pattern_0,
202
208
  has_separate_cfg=True,
209
+ **kwargs,
203
210
  )
204
211
 
205
212
 
@@ -212,10 +219,9 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
212
219
  pipe=pipe,
213
220
  transformer=pipe.transformer,
214
221
  blocks=pipe.transformer.transformer_blocks,
215
- blocks_name="transformer_blocks",
216
- dummy_blocks_names=[],
217
222
  forward_pattern=ForwardPattern.Pattern_2,
218
223
  has_separate_cfg=True,
224
+ **kwargs,
219
225
  )
220
226
 
221
227
 
@@ -228,9 +234,8 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
228
234
  pipe=pipe,
229
235
  transformer=pipe.transformer,
230
236
  blocks=pipe.transformer.transformer_blocks,
231
- blocks_name="transformer_blocks",
232
- dummy_blocks_names=[],
233
237
  forward_pattern=ForwardPattern.Pattern_0,
238
+ **kwargs,
234
239
  )
235
240
 
236
241
 
@@ -243,10 +248,9 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
243
248
  pipe=pipe,
244
249
  transformer=pipe.transformer,
245
250
  blocks=pipe.transformer.blocks,
246
- blocks_name="blocks",
247
- dummy_blocks_names=[],
248
251
  forward_pattern=ForwardPattern.Pattern_2,
249
252
  has_separate_cfg=True,
253
+ **kwargs,
250
254
  )
251
255
 
252
256
 
@@ -259,9 +263,8 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
259
263
  pipe=pipe,
260
264
  transformer=pipe.transformer,
261
265
  blocks=pipe.transformer.transformer_blocks,
262
- blocks_name="transformer_blocks",
263
- dummy_blocks_names=[],
264
266
  forward_pattern=ForwardPattern.Pattern_1,
267
+ **kwargs,
265
268
  )
266
269
 
267
270
 
@@ -274,9 +277,8 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
274
277
  pipe=pipe,
275
278
  transformer=pipe.transformer,
276
279
  blocks=pipe.transformer.transformer_blocks,
277
- blocks_name="transformer_blocks",
278
- dummy_blocks_names=[],
279
280
  forward_pattern=ForwardPattern.Pattern_0,
281
+ **kwargs,
280
282
  )
281
283
 
282
284
 
@@ -289,9 +291,8 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
289
291
  pipe=pipe,
290
292
  transformer=pipe.transformer,
291
293
  blocks=pipe.transformer.transformer_blocks,
292
- blocks_name="transformer_blocks",
293
- dummy_blocks_names=[],
294
294
  forward_pattern=ForwardPattern.Pattern_3,
295
+ **kwargs,
295
296
  )
296
297
 
297
298
 
@@ -304,9 +305,8 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
304
305
  pipe=pipe,
305
306
  transformer=pipe.transformer,
306
307
  blocks=pipe.transformer.transformer_layers,
307
- blocks_name="transformer_layers",
308
- dummy_blocks_names=[],
309
308
  forward_pattern=ForwardPattern.Pattern_3,
309
+ **kwargs,
310
310
  )
311
311
 
312
312
 
@@ -318,51 +318,20 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
318
318
  return BlockAdapter(
319
319
  pipe=pipe,
320
320
  transformer=pipe.transformer,
321
- blocks=(
322
- pipe.transformer.transformer_blocks
323
- + pipe.transformer.single_transformer_blocks
324
- ),
325
- blocks_name="transformer_blocks",
326
- dummy_blocks_names=["single_transformer_blocks"],
327
- forward_pattern=ForwardPattern.Pattern_0,
328
- )
329
-
330
-
331
- @BlockAdapterRegistry.register("HunyuanDiT")
332
- def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
333
- from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
334
-
335
- assert isinstance(
336
- pipe.transformer,
337
- (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
338
- )
339
- return BlockAdapter(
340
- pipe=pipe,
341
- transformer=pipe.transformer,
342
- blocks=pipe.transformer.blocks,
343
- blocks_name="blocks",
344
- dummy_blocks_names=[],
345
- forward_pattern=ForwardPattern.Pattern_3,
346
- )
347
-
348
-
349
- @BlockAdapterRegistry.register("HunyuanDiTPAG")
350
- def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
351
- from diffusers import HunyuanDiT2DModel
352
-
353
- assert isinstance(pipe.transformer, HunyuanDiT2DModel)
354
- return BlockAdapter(
355
- pipe=pipe,
356
- transformer=pipe.transformer,
357
- blocks=pipe.transformer.blocks,
358
- blocks_name="blocks",
359
- dummy_blocks_names=[],
360
- forward_pattern=ForwardPattern.Pattern_3,
321
+ blocks=[
322
+ pipe.transformer.transformer_blocks,
323
+ pipe.transformer.single_transformer_blocks,
324
+ ],
325
+ forward_pattern=[
326
+ ForwardPattern.Pattern_0,
327
+ ForwardPattern.Pattern_0,
328
+ ],
329
+ **kwargs,
361
330
  )
362
331
 
363
332
 
364
333
  @BlockAdapterRegistry.register("Lumina")
365
- def lumina_adapter(pipe) -> BlockAdapter:
334
+ def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
366
335
  from diffusers import LuminaNextDiT2DModel
367
336
 
368
337
  assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
@@ -370,9 +339,8 @@ def lumina_adapter(pipe) -> BlockAdapter:
370
339
  pipe=pipe,
371
340
  transformer=pipe.transformer,
372
341
  blocks=pipe.transformer.layers,
373
- blocks_name="layers",
374
- dummy_blocks_names=[],
375
342
  forward_pattern=ForwardPattern.Pattern_3,
343
+ **kwargs,
376
344
  )
377
345
 
378
346
 
@@ -385,9 +353,8 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
385
353
  pipe=pipe,
386
354
  transformer=pipe.transformer,
387
355
  blocks=pipe.transformer.layers,
388
- blocks_name="layers",
389
- dummy_blocks_names=[],
390
356
  forward_pattern=ForwardPattern.Pattern_3,
357
+ **kwargs,
391
358
  )
392
359
 
393
360
 
@@ -400,9 +367,8 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
400
367
  pipe=pipe,
401
368
  transformer=pipe.transformer,
402
369
  blocks=pipe.transformer.layers,
403
- blocks_name="layers",
404
- dummy_blocks_names=[],
405
370
  forward_pattern=ForwardPattern.Pattern_3,
371
+ **kwargs,
406
372
  )
407
373
 
408
374
 
@@ -415,39 +381,24 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
415
381
  pipe=pipe,
416
382
  transformer=pipe.transformer,
417
383
  blocks=pipe.transformer.transformer_blocks,
418
- blocks_name="transformer_blocks",
419
- dummy_blocks_names=[],
420
384
  forward_pattern=ForwardPattern.Pattern_3,
385
+ **kwargs,
421
386
  )
422
387
 
423
388
 
424
- @BlockAdapterRegistry.register("Sana")
389
+ @BlockAdapterRegistry.register("Sana", supported=False)
425
390
  def sana_adapter(pipe, **kwargs) -> BlockAdapter:
426
391
  from diffusers import SanaTransformer2DModel
427
392
 
393
+ # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
394
+
428
395
  assert isinstance(pipe.transformer, SanaTransformer2DModel)
429
396
  return BlockAdapter(
430
397
  pipe=pipe,
431
398
  transformer=pipe.transformer,
432
399
  blocks=pipe.transformer.transformer_blocks,
433
- blocks_name="transformer_blocks",
434
- dummy_blocks_names=[],
435
- forward_pattern=ForwardPattern.Pattern_3,
436
- )
437
-
438
-
439
- @BlockAdapterRegistry.register("ShapE")
440
- def shape_adapter(pipe, **kwargs) -> BlockAdapter:
441
- from diffusers import PriorTransformer
442
-
443
- assert isinstance(pipe.prior, PriorTransformer)
444
- return BlockAdapter(
445
- pipe=pipe,
446
- transformer=pipe.prior,
447
- blocks=pipe.prior.transformer_blocks,
448
- blocks_name="transformer_blocks",
449
- dummy_blocks_names=[],
450
400
  forward_pattern=ForwardPattern.Pattern_3,
401
+ **kwargs,
451
402
  )
452
403
 
453
404
 
@@ -460,30 +411,45 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
460
411
  pipe=pipe,
461
412
  transformer=pipe.transformer,
462
413
  blocks=pipe.transformer.transformer_blocks,
463
- blocks_name="transformer_blocks",
464
- dummy_blocks_names=[],
465
414
  forward_pattern=ForwardPattern.Pattern_3,
415
+ **kwargs,
466
416
  )
467
417
 
468
418
 
469
419
  @BlockAdapterRegistry.register("VisualCloze")
470
420
  def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
471
421
  from diffusers import FluxTransformer2DModel
472
- from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
422
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
473
423
 
474
424
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
475
- return BlockAdapter(
476
- pipe=pipe,
477
- transformer=pipe.transformer,
478
- blocks=(
479
- pipe.transformer.transformer_blocks
480
- + pipe.transformer.single_transformer_blocks
481
- ),
482
- blocks_name="transformer_blocks",
483
- dummy_blocks_names=["single_transformer_blocks"],
484
- patch_functor=FluxPatchFunctor(),
485
- forward_pattern=ForwardPattern.Pattern_1,
486
- )
425
+ if is_diffusers_at_least_0_3_5():
426
+ return BlockAdapter(
427
+ pipe=pipe,
428
+ transformer=pipe.transformer,
429
+ blocks=[
430
+ pipe.transformer.transformer_blocks,
431
+ pipe.transformer.single_transformer_blocks,
432
+ ],
433
+ forward_pattern=[
434
+ ForwardPattern.Pattern_1,
435
+ ForwardPattern.Pattern_1,
436
+ ],
437
+ **kwargs,
438
+ )
439
+ else:
440
+ return BlockAdapter(
441
+ pipe=pipe,
442
+ transformer=pipe.transformer,
443
+ blocks=[
444
+ pipe.transformer.transformer_blocks,
445
+ pipe.transformer.single_transformer_blocks,
446
+ ],
447
+ forward_pattern=[
448
+ ForwardPattern.Pattern_1,
449
+ ForwardPattern.Pattern_3,
450
+ ],
451
+ **kwargs,
452
+ )
487
453
 
488
454
 
489
455
  @BlockAdapterRegistry.register("AuraFlow")
@@ -494,19 +460,9 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
494
460
  return BlockAdapter(
495
461
  pipe=pipe,
496
462
  transformer=pipe.transformer,
497
- blocks=[
498
- # Only 4 joint blocks, apply no-cache
499
- # pipe.transformer.joint_transformer_blocks,
500
- pipe.transformer.single_transformer_blocks,
501
- ],
502
- blocks_name=[
503
- # "joint_transformer_blocks",
504
- "single_transformer_blocks",
505
- ],
506
- forward_pattern=[
507
- # ForwardPattern.Pattern_1,
508
- ForwardPattern.Pattern_3,
509
- ],
463
+ blocks=pipe.transformer.single_transformer_blocks,
464
+ forward_pattern=ForwardPattern.Pattern_3,
465
+ **kwargs,
510
466
  )
511
467
 
512
468
 
@@ -522,20 +478,36 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
522
478
  pipe.transformer.transformer_blocks,
523
479
  pipe.transformer.single_transformer_blocks,
524
480
  ],
525
- blocks_name=[
526
- "transformer_blocks",
527
- "single_transformer_blocks",
528
- ],
529
481
  forward_pattern=[
530
482
  ForwardPattern.Pattern_1,
531
483
  ForwardPattern.Pattern_3,
532
484
  ],
533
485
  has_separate_cfg=True,
486
+ **kwargs,
487
+ )
488
+
489
+
490
+ @BlockAdapterRegistry.register("ShapE")
491
+ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
492
+ from diffusers import PriorTransformer
493
+
494
+ assert isinstance(pipe.prior, PriorTransformer)
495
+ return BlockAdapter(
496
+ pipe=pipe,
497
+ transformer=pipe.prior,
498
+ blocks=pipe.prior.transformer_blocks,
499
+ forward_pattern=ForwardPattern.Pattern_3,
500
+ **kwargs,
534
501
  )
535
502
 
536
503
 
537
- @BlockAdapterRegistry.register("HiDream")
504
+ @BlockAdapterRegistry.register("HiDream", supported=True)
538
505
  def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
506
+ # NOTE: Need to patch Transformer forward to fully support
507
+ # double_stream_blocks and single_stream_blocks, namely, need
508
+ # to remove the logics inside the blocks forward loop:
509
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
510
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
539
511
  from diffusers import HiDreamImageTransformer2DModel
540
512
 
541
513
  assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
@@ -543,13 +515,47 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
543
515
  pipe=pipe,
544
516
  transformer=pipe.transformer,
545
517
  blocks=[
546
- pipe.transformer.double_stream_blocks,
518
+ # pipe.transformer.double_stream_blocks,
547
519
  pipe.transformer.single_stream_blocks,
548
520
  ],
549
- dummy_blocks_names=[],
550
521
  forward_pattern=[
551
- ForwardPattern.Pattern_4,
522
+ # ForwardPattern.Pattern_4,
552
523
  ForwardPattern.Pattern_3,
553
524
  ],
525
+ # The type hint in diffusers is wrong
554
526
  check_num_outputs=False,
527
+ **kwargs,
528
+ )
529
+
530
+
531
+ @BlockAdapterRegistry.register("HunyuanDiT", supported=False)
532
+ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
533
+ # TODO: Patch Transformer forward
534
+ from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
535
+
536
+ assert isinstance(
537
+ pipe.transformer,
538
+ (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
539
+ )
540
+ return BlockAdapter(
541
+ pipe=pipe,
542
+ transformer=pipe.transformer,
543
+ blocks=pipe.transformer.blocks,
544
+ forward_pattern=ForwardPattern.Pattern_3,
545
+ **kwargs,
546
+ )
547
+
548
+
549
+ @BlockAdapterRegistry.register("HunyuanDiTPAG", supported=False)
550
+ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
551
+ # TODO: Patch Transformer forward
552
+ from diffusers import HunyuanDiT2DModel
553
+
554
+ assert isinstance(pipe.transformer, HunyuanDiT2DModel)
555
+ return BlockAdapter(
556
+ pipe=pipe,
557
+ transformer=pipe.transformer,
558
+ blocks=pipe.transformer.blocks,
559
+ forward_pattern=ForwardPattern.Pattern_3,
560
+ **kwargs,
555
561
  )