diffsynth_engine-0.3.6.dev5-py3-none-any.whl → diffsynth_engine-0.3.6.dev6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diffsynth_engine/models/flux/__init__.py

@@ -4,6 +4,7 @@ from .flux_vae import FluxVAEDecoder, FluxVAEEncoder, config as flux_vae_config
  from .flux_controlnet import FluxControlNet
  from .flux_ipadapter import FluxIPAdapter
  from .flux_redux import FluxRedux
+ from .flux_dit_fbcache import FluxDiTFBCache

  __all__ = [
      "FluxRedux",
@@ -14,6 +15,7 @@ __all__ = [
      "FluxTextEncoder2",
      "FluxVAEDecoder",
      "FluxVAEEncoder",
+     "FluxDiTFBCache",
      "flux_dit_config",
      "flux_text_encoder_config",
      "flux_vae_config",

diffsynth_engine/models/flux/flux_dit_fbcache.py (new file)

@@ -0,0 +1,205 @@
+ import torch
+ import numpy as np
+ from typing import Dict, Optional
+
+ from diffsynth_engine.models.utils import no_init_weights
+ from diffsynth_engine.utils.gguf import gguf_inference
+ from diffsynth_engine.utils.fp8_linear import fp8_inference
+ from diffsynth_engine.utils.parallel import (
+     cfg_parallel,
+     cfg_parallel_unshard,
+     sequence_parallel,
+     sequence_parallel_unshard,
+ )
+ from diffsynth_engine.utils import logging
+ from diffsynth_engine.models.flux.flux_dit import FluxDiT
+
+ logger = logging.get_logger(__name__)
+
+
+ class FluxDiTFBCache(FluxDiT):
+     def __init__(
+         self,
+         in_channel: int = 64,
+         attn_impl: Optional[str] = None,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+         relative_l1_threshold: float = 0.05,
+     ):
+         super().__init__(in_channel=in_channel, attn_impl=attn_impl, device=device, dtype=dtype)
+         self.relative_l1_threshold = relative_l1_threshold
+         self.step_count = 0
+         self.num_inference_steps = 0
+
+     def is_relative_l1_below_threshold(self, prev_residual, residual, threshold):
+         if threshold <= 0.0:
+             return False
+
+         if prev_residual.shape != residual.shape:
+             return False
+
+         mean_diff = (prev_residual - residual).abs().mean()
+         mean_prev_residual = prev_residual.abs().mean()
+         diff = mean_diff / mean_prev_residual
+         return diff.item() < threshold
+
+     def refresh_cache_status(self, num_inference_steps):
+         self.step_count = 0
+         self.num_inference_steps = num_inference_steps
+
+     def forward(
+         self,
+         hidden_states,
+         timestep,
+         prompt_emb,
+         pooled_prompt_emb,
+         image_emb,
+         guidance,
+         text_ids,
+         image_ids=None,
+         controlnet_double_block_output=None,
+         controlnet_single_block_output=None,
+         **kwargs,
+     ):
+         h, w = hidden_states.shape[-2:]
+         if image_ids is None:
+             image_ids = self.prepare_image_ids(hidden_states)
+         controlnet_double_block_output = (
+             controlnet_double_block_output if controlnet_double_block_output is not None else ()
+         )
+         controlnet_single_block_output = (
+             controlnet_single_block_output if controlnet_single_block_output is not None else ()
+         )
+
+         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
+         use_cfg = hidden_states.shape[0] > 1
+         with (
+             fp8_inference(fp8_linear_enabled),
+             gguf_inference(),
+             cfg_parallel(
+                 (
+                     hidden_states,
+                     timestep,
+                     prompt_emb,
+                     pooled_prompt_emb,
+                     image_emb,
+                     guidance,
+                     text_ids,
+                     image_ids,
+                     *controlnet_double_block_output,
+                     *controlnet_single_block_output,
+                 ),
+                 use_cfg=use_cfg,
+             ),
+         ):
+             # warning: keep the order of time_embedding + guidance_embedding + pooled_text_embedding
+             # addition of floating point numbers does not meet commutative law
+             conditioning = self.time_embedder(timestep, hidden_states.dtype)
+             if self.guidance_embedder is not None:
+                 guidance = guidance * 1000
+                 conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
+             conditioning += self.pooled_text_embedder(pooled_prompt_emb)
+             rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+             text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
+             image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
+             hidden_states = self.patchify(hidden_states)
+
+             with sequence_parallel(
+                 (
+                     hidden_states,
+                     prompt_emb,
+                     text_rope_emb,
+                     image_rope_emb,
+                     *controlnet_double_block_output,
+                     *controlnet_single_block_output,
+                 ),
+                 seq_dims=(
+                     1,
+                     1,
+                     2,
+                     2,
+                     *(1 for _ in controlnet_double_block_output),
+                     *(1 for _ in controlnet_single_block_output),
+                 ),
+             ):
+                 hidden_states = self.x_embedder(hidden_states)
+                 prompt_emb = self.context_embedder(prompt_emb)
+                 rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)
+
+                 # first block
+                 original_hidden_states = hidden_states
+                 hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                 first_hidden_states_residual = hidden_states - original_hidden_states
+
+                 (first_hidden_states_residual,) = sequence_parallel_unshard(
+                     (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,)
+                 )
+
+                 if self.step_count == 0 or self.step_count == (self.num_inference_steps - 1):
+                     should_calc = True
+                 else:
+                     skip = self.is_relative_l1_below_threshold(
+                         first_hidden_states_residual,
+                         self.prev_first_hidden_states_residual,
+                         threshold=self.relative_l1_threshold,
+                     )
+                     should_calc = not skip
+                 self.step_count += 1
+
+                 if not should_calc:
+                     hidden_states += self.previous_residual
+                 else:
+                     self.prev_first_hidden_states_residual = first_hidden_states_residual
+
+                     first_hidden_states = hidden_states.clone()
+                     for i, block in enumerate(self.blocks[1:]):
+                         hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                         if len(controlnet_double_block_output) > 0:
+                             interval_control = len(self.blocks) / len(controlnet_double_block_output)
+                             interval_control = int(np.ceil(interval_control))
+                             hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
+                     hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+                     for i, block in enumerate(self.single_blocks):
+                         hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                         if len(controlnet_single_block_output) > 0:
+                             interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
+                             interval_control = int(np.ceil(interval_control))
+                             hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
+
+                     hidden_states = hidden_states[:, prompt_emb.shape[1] :]
+
+                     previous_residual = hidden_states - first_hidden_states
+                     self.previous_residual = previous_residual
+
+                 hidden_states = self.final_norm_out(hidden_states, conditioning)
+                 hidden_states = self.final_proj_out(hidden_states)
+                 (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
+
+             hidden_states = self.unpatchify(hidden_states, h, w)
+             (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)
+
+         return hidden_states
+
+     @classmethod
+     def from_state_dict(
+         cls,
+         state_dict: Dict[str, torch.Tensor],
+         device: str,
+         dtype: torch.dtype,
+         in_channel: int = 64,
+         attn_impl: Optional[str] = None,
+         fb_cache_relative_l1_threshold: float = 0.05,
+     ):
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(
+                 cls,
+                 device=device,
+                 dtype=dtype,
+                 in_channel=in_channel,
+                 attn_impl=attn_impl,
+                 fb_cache_relative_l1_threshold=fb_cache_relative_l1_threshold,
+             )
+         model = model.requires_grad_(False)  # for loading gguf
+         model.load_state_dict(state_dict, assign=True)
+         model.to(device=device, dtype=dtype, non_blocking=True)
+         return model
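
The new class implements first-block caching (FBCache): every step runs only the first double-stream block, measures how much that block's output residual changed since the previous step, and, when the relative L1 distance mean(|r_t - r_(t-1)|) / mean(|r_(t-1)|) falls below relative_l1_threshold (and the step is neither the first nor the last), skips the remaining blocks and reuses the cached previous_residual. A minimal standalone sketch of just that decision rule, with dummy tensors standing in for real first-block residuals:

    import torch

    def is_relative_l1_below_threshold(prev_residual, residual, threshold):
        # Same metric as FluxDiTFBCache: mean(|r_t - r_(t-1)|) / mean(|r_(t-1)|)
        if threshold <= 0.0 or prev_residual.shape != residual.shape:
            return False
        mean_diff = (prev_residual - residual).abs().mean()
        return (mean_diff / prev_residual.abs().mean()).item() < threshold

    # Dummy residuals for two consecutive denoising steps (shapes are illustrative).
    prev = torch.randn(1, 4096, 3072)
    curr = prev + 0.01 * torch.randn_like(prev)  # small step-to-step change
    print(is_relative_l1_below_threshold(prev, curr, threshold=0.05))  # likely True -> skip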

diffsynth_engine/models/sd/sd_controlnet.py

@@ -12,18 +12,29 @@ from diffsynth_engine.models.basic.unet_helper import (
      DownSampler,
  )

+
  class ControlNetConditioningLayer(nn.Module):
-     def __init__(self, channels = (3, 16, 32, 96, 256, 320), device = "cuda:0", dtype=torch.float16):
+     def __init__(self, channels=(3, 16, 32, 96, 256, 320), device="cuda:0", dtype=torch.float16):
          super().__init__()
          self.blocks = torch.nn.ModuleList([])
-         self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, device=device, dtype=dtype))
+         self.blocks.append(
+             torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, device=device, dtype=dtype)
+         )
          self.blocks.append(torch.nn.SiLU())
          for i in range(1, len(channels) - 2):
-             self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1, device=device, dtype=dtype))
+             self.blocks.append(
+                 torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1, device=device, dtype=dtype)
+             )
              self.blocks.append(torch.nn.SiLU())
-             self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2, device=device, dtype=dtype))
+             self.blocks.append(
+                 torch.nn.Conv2d(
+                     channels[i], channels[i + 1], kernel_size=3, padding=1, stride=2, device=device, dtype=dtype
+                 )
+             )
              self.blocks.append(torch.nn.SiLU())
-         self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1, device=device, dtype=dtype))
+         self.blocks.append(
+             torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1, device=device, dtype=dtype)
+         )

      def forward(self, conditioning):
          for block in self.blocks:
@@ -38,15 +49,73 @@ class SDControlNetStateDictConverter(StateDictConverter):
      def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
          # architecture
          block_types = [
-             'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
-             'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
-             'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
-             'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
-             'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
-             'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
-             'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
-             'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
-             'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "DownSampler",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "DownSampler",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "DownSampler",
+             "PushBlock",
+             "ResnetBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "ResnetBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "UpSampler",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "UpSampler",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "UpSampler",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PopBlock",
+             "ResnetBlock",
+             "AttentionBlock",
          ]

          # controlnet_rename_dict
@@ -66,7 +135,7 @@ class SDControlNetStateDictConverter(StateDictConverter):
          "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
          "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
          "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
-         "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
+         "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
      }

      # Rename each parameter
@@ -91,7 +160,12 @@ class SDControlNetStateDictConverter(StateDictConverter):
          elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
              if names[0] == "mid_block":
                  names.insert(1, "0")
-             block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
+             block_type = {
+                 "resnets": "ResnetBlock",
+                 "attentions": "AttentionBlock",
+                 "downsamplers": "DownSampler",
+                 "upsamplers": "UpSampler",
+             }[names[2]]
              block_type_with_id = ".".join(names[:4])
              if block_type_with_id != last_block_type_with_id[block_type]:
                  block_id[block_type] += 1
@@ -102,9 +176,9 @@ class SDControlNetStateDictConverter(StateDictConverter):
              names = ["blocks", str(block_id[block_type])] + names[4:]
              if "ff" in names:
                  ff_index = names.index("ff")
-                 component = ".".join(names[ff_index:ff_index+3])
+                 component = ".".join(names[ff_index : ff_index + 3])
                  component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
-                 names = names[:ff_index] + [component] + names[ff_index+3:]
+                 names = names[:ff_index] + [component] + names[ff_index + 3 :]
              if "to_out" in names:
                  names.pop(names.index("to_out") + 1)
          else:
@@ -117,13 +191,21 @@ class SDControlNetStateDictConverter(StateDictConverter):
              if ".proj_in." in name or ".proj_out." in name:
                  param = param.squeeze()
              if rename_dict[name] in [
-                 "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
-                 "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
+                 "controlnet_blocks.1.bias",
+                 "controlnet_blocks.2.bias",
+                 "controlnet_blocks.3.bias",
+                 "controlnet_blocks.5.bias",
+                 "controlnet_blocks.6.bias",
+                 "controlnet_blocks.8.bias",
+                 "controlnet_blocks.9.bias",
+                 "controlnet_blocks.10.bias",
+                 "controlnet_blocks.11.bias",
+                 "controlnet_blocks.12.bias",
              ]:
                  continue
              state_dict_[rename_dict[name]] = param
          return state_dict_
-
+
      def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
          rename_dict = {
              "control_model.time_embed.0.weight": "time_embedding.timestep_embedder.0.weight",
@@ -496,69 +578,71 @@ class SDControlNet(PreTrainedModel):
          self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)
          self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)

-         self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype)
+         self.controlnet_conv_in = ControlNetConditioningLayer(
+             channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype
+         )

-         self.blocks = torch.nn.ModuleList([
-             # CrossAttnDownBlock2D
-             ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
-             AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
-             PushBlock(),
-             ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
-             AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
-             PushBlock(),
-             DownSampler(320, device=device, dtype=dtype),
-             PushBlock(),
-             # CrossAttnDownBlock2D
-             ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
-             AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
-             PushBlock(),
-             ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
-             AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
-             PushBlock(),
-             DownSampler(640, device=device, dtype=dtype),
-             PushBlock(),
-             # CrossAttnDownBlock2D
-             ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
-             AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
-             PushBlock(),
-             ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
-             AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
-             PushBlock(),
-             DownSampler(1280, device=device, dtype=dtype),
-             PushBlock(),
-             # DownBlock2D
-             ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
-             PushBlock(),
-             ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
-             PushBlock(),
-             # UNetMidBlock2DCrossAttn
-             ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
-             AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
-             ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
-             PushBlock()
-         ])
+         self.blocks = torch.nn.ModuleList(
+             [
+                 # CrossAttnDownBlock2D
+                 ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+                 AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
+                 PushBlock(),
+                 ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+                 AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
+                 PushBlock(),
+                 DownSampler(320, device=device, dtype=dtype),
+                 PushBlock(),
+                 # CrossAttnDownBlock2D
+                 ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
+                 AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
+                 PushBlock(),
+                 ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
+                 AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
+                 PushBlock(),
+                 DownSampler(640, device=device, dtype=dtype),
+                 PushBlock(),
+                 # CrossAttnDownBlock2D
+                 ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
+                 AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
+                 PushBlock(),
+                 ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                 AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
+                 PushBlock(),
+                 DownSampler(1280, device=device, dtype=dtype),
+                 PushBlock(),
+                 # DownBlock2D
+                 ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                 PushBlock(),
+                 ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                 PushBlock(),
+                 # UNetMidBlock2DCrossAttn
+                 ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                 AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
+                 ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                 PushBlock(),
+             ]
+         )

-         self.controlnet_blocks = torch.nn.ModuleList([
-             torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
-         ])
+         self.controlnet_blocks = torch.nn.ModuleList(
+             [
+                 torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+             ]
+         )

-     def forward(
-         self,
-         sample, timestep, encoder_hidden_states, conditioning,
-         **kwargs
-     ):
+     def forward(self, sample, timestep, encoder_hidden_states, conditioning, **kwargs):
          # 1. time
          time_emb = self.time_embedding(timestep, dtype=sample.dtype)

@@ -585,9 +669,7 @@ class SDControlNet(PreTrainedModel):
          attn_impl: Optional[str] = None,
      ):
          with no_init_weights():
-             model = torch.nn.utils.skip_init(
-                 cls, attn_impl=attn_impl, device=device, dtype=dtype
-             )
+             model = torch.nn.utils.skip_init(cls, attn_impl=attn_impl, device=device, dtype=dtype)
          model.load_state_dict(state_dict)
          model.to(device=device, dtype=dtype, non_blocking=True)
-         return model
+         return model
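
For reference, ControlNetConditioningLayer (reformatted above, behavior unchanged) maps a 3-channel control image onto the UNet's 320-channel feature grid at 1/8 resolution via three stride-2 convolutions. A small sketch on CPU; the 512x512 input size is illustrative:

    import torch

    from diffsynth_engine.models.sd.sd_controlnet import ControlNetConditioningLayer

    layer = ControlNetConditioningLayer(device="cpu", dtype=torch.float32)
    cond = torch.zeros(1, 3, 512, 512)  # dummy control image in NCHW
    out = layer(cond)
    print(out.shape)  # torch.Size([1, 320, 64, 64]) -- three stride-2 convs give /8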

diffsynth_engine/models/sdxl/__init__.py

@@ -9,7 +9,7 @@ __all__ = [
      "SDXLUNet",
      "SDXLVAEDecoder",
      "SDXLVAEEncoder",
-     "SDXLControlNetUnion",
+     "SDXLControlNetUnion",
      "sdxl_text_encoder_config",
      "sdxl_unet_config",
  ]

diffsynth_engine/models/sdxl/sdxl_controlnet.py

@@ -12,23 +12,27 @@ from diffsynth_engine.models.basic.timestep import TimestepEmbeddings, TemporalTimesteps

  from collections import OrderedDict

- class QuickGELU(torch.nn.Module):

+ class QuickGELU(torch.nn.Module):
      def forward(self, x: torch.Tensor):
          return x * torch.sigmoid(1.702 * x)

- class ResidualAttentionBlock(torch.nn.Module):

+ class ResidualAttentionBlock(torch.nn.Module):
      def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, device="cuda:0", dtype=torch.float16):
          super().__init__()

          self.attn = torch.nn.MultiheadAttention(d_model, n_head, device=device, dtype=dtype)
          self.ln_1 = torch.nn.LayerNorm(d_model, device=device, dtype=dtype)
-         self.mlp = torch.nn.Sequential(OrderedDict([
-             ("c_fc", torch.nn.Linear(d_model, d_model * 4, device=device, dtype=dtype)),
-             ("gelu", QuickGELU()),
-             ("c_proj", torch.nn.Linear(d_model * 4, d_model, device=device, dtype=dtype))
-         ]))
+         self.mlp = torch.nn.Sequential(
+             OrderedDict(
+                 [
+                     ("c_fc", torch.nn.Linear(d_model, d_model * 4, device=device, dtype=dtype)),
+                     ("gelu", QuickGELU()),
+                     ("c_proj", torch.nn.Linear(d_model * 4, d_model, device=device, dtype=dtype)),
+                 ]
+             )
+         )
          self.ln_2 = torch.nn.LayerNorm(d_model, device=device, dtype=dtype)
          self.attn_mask = attn_mask

@@ -49,10 +53,30 @@ class SDXLControlNetUnionStateDictConverter(StateDictConverter):
      def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
          # architecture
          block_types = [
-             "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
-             "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
-             "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
-             "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
+             "ResnetBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "PushBlock",
+             "DownSampler",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "DownSampler",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "PushBlock",
+             "ResnetBlock",
+             "AttentionBlock",
+             "ResnetBlock",
+             "PushBlock",
          ]

          # controlnet_rename_dict
@@ -107,7 +131,12 @@ class SDXLControlNetUnionStateDictConverter(StateDictConverter):
          elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
              if names[0] == "mid_block":
                  names.insert(1, "0")
-             block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
+             block_type = {
+                 "resnets": "ResnetBlock",
+                 "attentions": "AttentionBlock",
+                 "downsamplers": "DownSampler",
+                 "upsamplers": "UpSampler",
+             }[names[2]]
              block_type_with_id = ".".join(names[:4])
              if block_type_with_id != last_block_type_with_id[block_type]:
                  block_id[block_type] += 1
@@ -118,9 +147,9 @@ class SDXLControlNetUnionStateDictConverter(StateDictConverter):
              names = ["blocks", str(block_id[block_type])] + names[4:]
              if "ff" in names:
                  ff_index = names.index("ff")
-                 component = ".".join(names[ff_index:ff_index+3])
+                 component = ".".join(names[ff_index : ff_index + 3])
                  component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
-                 names = names[:ff_index] + [component] + names[ff_index+3:]
+                 names = names[:ff_index] + [component] + names[ff_index + 3 :]
              if "to_out" in names:
                  names.pop(names.index("to_out") + 1)
          else:
@@ -137,19 +166,20 @@ class SDXLControlNetUnionStateDictConverter(StateDictConverter):
              param = param.squeeze()
          state_dict_[rename_dict[name]] = param
          return state_dict_
-
+
      # TODO: check civitai
      def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
          return self._from_diffusers(state_dict)

-
      def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
          return self._from_diffusers(state_dict)

+
  class SDXLControlNetUnion(PreTrainedModel):
      converter = SDXLControlNetUnionStateDictConverter()

-     def __init__(self,
+     def __init__(
+         self,
          attn_impl: Optional[str] = None,
          device: str = "cuda:0",
          dtype: torch.dtype = torch.bfloat16,
@@ -157,68 +187,78 @@ class SDXLControlNetUnion(PreTrainedModel):
          super().__init__()
          self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)

-         self.add_time_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype)
+         self.add_time_proj = TemporalTimesteps(
+             256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype
+         )
          self.add_time_embedding = torch.nn.Sequential(
              torch.nn.Linear(2816, 1280, device=device, dtype=dtype),
              torch.nn.SiLU(),
-             torch.nn.Linear(1280, 1280, device=device, dtype=dtype)
+             torch.nn.Linear(1280, 1280, device=device, dtype=dtype),
+         )
+         self.control_type_proj = TemporalTimesteps(
+             256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype
          )
-         self.control_type_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype)
          self.control_type_embedding = torch.nn.Sequential(
              torch.nn.Linear(256 * 8, 1280, device=device, dtype=dtype),
              torch.nn.SiLU(),
-             torch.nn.Linear(1280, 1280, device=device, dtype=dtype)
+             torch.nn.Linear(1280, 1280, device=device, dtype=dtype),
          )
          self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)

-         self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype)
+         self.controlnet_conv_in = ControlNetConditioningLayer(
+             channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype
+         )
          self.controlnet_transformer = ResidualAttentionBlock(320, 8, device=device, dtype=dtype)
          self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
          self.spatial_ch_projs = torch.nn.Linear(320, 320, device=device, dtype=dtype)

-         self.blocks = torch.nn.ModuleList([
-             # DownBlock2D
-             ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
-             PushBlock(),
-             ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
-             PushBlock(),
-             DownSampler(320, device=device, dtype=dtype),
-             PushBlock(),
-             # CrossAttnDownBlock2D
-             ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
-             AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
-             PushBlock(),
-             ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
-             AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
-             PushBlock(),
-             DownSampler(640, device=device, dtype=dtype),
-             PushBlock(),
-             # CrossAttnDownBlock2D
-             ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
-             AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
-             PushBlock(),
-             ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
-             AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
-             PushBlock(),
-             # UNetMidBlock2DCrossAttn
-             ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
-             AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
-             ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
-             PushBlock()
-         ])
-
-         self.controlnet_blocks = torch.nn.ModuleList([
-             torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
-             torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
-         ])
+         self.blocks = torch.nn.ModuleList(
+             [
+                 # DownBlock2D
+                 ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+                 PushBlock(),
+                 ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+                 PushBlock(),
+                 DownSampler(320, device=device, dtype=dtype),
+                 PushBlock(),
+                 # CrossAttnDownBlock2D
+                 ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
+                 AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
+                 PushBlock(),
+                 ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
+                 AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
+                 PushBlock(),
+                 DownSampler(640, device=device, dtype=dtype),
+                 PushBlock(),
+                 # CrossAttnDownBlock2D
+                 ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
+                 AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                 PushBlock(),
+                 ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                 AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                 PushBlock(),
+                 # UNetMidBlock2DCrossAttn
+                 ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                 AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                 ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                 PushBlock(),
+             ]
+         )
+
+         self.controlnet_blocks = torch.nn.ModuleList(
+             [
+                 torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
+                 torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
+             ]
+         )

          # 0 -- openpose
          # 1 -- depth
@@ -236,10 +276,9 @@ class SDXLControlNetUnion(PreTrainedModel):
          "lineart": 3,
          "lineart_anime": 3,
          "tile": 6,
-         "inpaint": 7
+         "inpaint": 7,
      }

-
      def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
          controlnet_cond = self.controlnet_conv_in(conditioning)
          feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
@@ -247,19 +286,25 @@ class SDXLControlNetUnion(PreTrainedModel):
          x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
          x = self.controlnet_transformer(x)

-         alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
+         alpha = self.spatial_ch_projs(x[:, 0]).unsqueeze(-1).unsqueeze(-1)
          controlnet_cond_fuser = controlnet_cond + alpha

          hidden_states = hidden_states + controlnet_cond_fuser
          return hidden_states
-

      def forward(
          self,
-         sample, timestep, encoder_hidden_states,
-         conditioning, processor_name, add_time_id, add_text_embeds,
-         tiled=False, tile_size=64, tile_stride=32,
-         **kwargs
+         sample,
+         timestep,
+         encoder_hidden_states,
+         conditioning,
+         processor_name,
+         add_time_id,
+         add_text_embeds,
+         tiled=False,
+         tile_size=64,
+         tile_stride=32,
+         **kwargs,
      ):
          task_id = self.task_id[processor_name]

diffsynth_engine/models/sdxl/sdxl_unet.py

@@ -268,13 +268,12 @@ class SDXLUNet(PreTrainedModel):
                  text_emb,
                  res_stack,
              )
-
+
          # 3.2 Controlnet
          if i == controlnet_insert_block_id and controlnet_res_stack is not None:
              hidden_states += controlnet_res_stack.pop()
              res_stack = [res + controlnet_res for res, controlnet_res in zip(res_stack, controlnet_res_stack)]

-
      # 4. output
      hidden_states = self.conv_norm_out(hidden_states)
      hidden_states = self.conv_act(hidden_states)

diffsynth_engine/pipelines/controlnet_helper.py

@@ -6,6 +6,7 @@ from dataclasses import dataclass

  ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]

+
  @dataclass
  class ControlNetParams:
      image: ImageType
@@ -14,11 +15,12 @@ class ControlNetParams:
      mask: Optional[ImageType] = None
      control_start: float = 0
      control_end: float = 1
-     processor_name: Optional[str] = None # only used for sdxl controlnet union now
+     processor_name: Optional[str] = None  # only used for sdxl controlnet union now
+

  def accumulate(result, new_item):
      if result is None:
          return new_item
      for i, item in enumerate(new_item):
          result[i] += item
-     return result
+     return result
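
accumulate merges per-ControlNet residual lists elementwise so that several ControlNets can contribute to a single residual stack: the first call simply adopts the first list, and each later call adds into it in place. A small sketch with illustrative tensor shapes:

    import torch

    from diffsynth_engine.pipelines.controlnet_helper import accumulate

    # Two ControlNets each return a list of residual tensors (shapes are illustrative).
    res_a = [torch.ones(1, 320, 64, 64), torch.ones(1, 640, 32, 32)]
    res_b = [torch.full((1, 320, 64, 64), 2.0), torch.full((1, 640, 32, 32), 2.0)]

    stack = None
    for res in (res_a, res_b):
        stack = accumulate(stack, res)  # adopts res_a, then adds res_b in place

    print(stack[0].mean().item())  # 3.0 -> per-position sum of both ControlNet outputs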

diffsynth_engine/pipelines/flux_image.py

@@ -6,7 +6,7 @@ import torch.distributed as dist
  import math
  from einops import rearrange
  from enum import Enum
- from typing import Callable, Dict, List, Tuple, Optional
+ from typing import Callable, Dict, List, Tuple, Optional, Union
  from tqdm import tqdm
  from PIL import Image
  from dataclasses import dataclass
@@ -16,6 +16,7 @@ from diffsynth_engine.models.flux import (
      FluxVAEDecoder,
      FluxVAEEncoder,
      FluxDiT,
+     FluxDiTFBCache,
      flux_dit_config,
      flux_text_encoder_config,
  )
@@ -429,6 +430,7 @@ class ControlType(Enum):
          elif self == ControlType.bfl_fill:
              return 384

+
  @dataclass
  class FluxModelConfig:
      dit_path: str | os.PathLike
@@ -460,7 +462,7 @@ class FluxImagePipeline(BasePipeline):
          tokenizer_2: T5TokenizerFast,
          text_encoder_1: FluxTextEncoder1,
          text_encoder_2: FluxTextEncoder2,
-         dit: FluxDiT,
+         dit: Union[FluxDiT, FluxDiTFBCache],
          vae_decoder: FluxVAEDecoder,
          vae_encoder: FluxVAEEncoder,
          load_text_encoder: bool = True,
@@ -518,6 +520,8 @@ class FluxImagePipeline(BasePipeline):
          offload_mode: str | None = None,
          parallelism: int = 1,
          use_cfg_parallel: bool = False,
+         use_fb_cache: bool = False,
+         fb_cache_relative_l1_threshold: float = 0.05,
      ) -> "FluxImagePipeline":
          model_config = (
              model_path_or_config
@@ -562,13 +566,23 @@ class FluxImagePipeline(BasePipeline):
          vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)

          with LoRAContext():
-             dit = FluxDiT.from_state_dict(
-                 dit_state_dict,
-                 device=init_device,
-                 dtype=model_config.dit_dtype,
-                 in_channel=control_type.get_in_channel(),
-                 attn_impl=model_config.dit_attn_impl,
-             )
+             if use_fb_cache:
+                 dit = FluxDiTFBCache.from_state_dict(
+                     dit_state_dict,
+                     device=init_device,
+                     dtype=model_config.dit_dtype,
+                     in_channel=control_type.get_in_channel(),
+                     attn_impl=model_config.dit_attn_impl,
+                     relative_l1_threshold=fb_cache_relative_l1_threshold,
+                 )
+             else:
+                 dit = FluxDiT.from_state_dict(
+                     dit_state_dict,
+                     device=init_device,
+                     dtype=model_config.dit_dtype,
+                     in_channel=control_type.get_in_channel(),
+                     attn_impl=model_config.dit_attn_impl,
+                 )
          if model_config.use_fp8_linear:
              enable_fp8_linear(dit)

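
With these hooks in place, first-block caching is opt-in at pipeline construction time. A hypothetical usage sketch: the factory classmethod's name is outside this hunk and is assumed to be from_pretrained, and the model path and prompt are placeholders; only use_fb_cache and fb_cache_relative_l1_threshold come from the diff above:

    from diffsynth_engine.pipelines.flux_image import FluxImagePipeline

    # use_fb_cache selects FluxDiTFBCache instead of FluxDiT; a larger threshold
    # skips more steps, trading quality for speed. Method name and paths are
    # assumptions, not part of this diff.
    pipe = FluxImagePipeline.from_pretrained(
        "path/to/flux",  # placeholder
        use_fb_cache=True,
        fb_cache_relative_l1_threshold=0.05,
    )
    image = pipe(prompt="a photo of a cat", num_inference_steps=20)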
@@ -968,6 +982,8 @@ class FluxImagePipeline(BasePipeline):
          controlnet_params: List[ControlNetParams] | ControlNetParams = [],
          progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
      ):
+         if isinstance(self.dit, FluxDiTFBCache):
+             self.dit.refresh_cache_status(num_inference_steps)
          if not isinstance(controlnet_params, list):
              controlnet_params = [controlnet_params]
          if self.control_type != ControlType.normal:
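
Because the skip decision depends on step_count and num_inference_steps, the cache must be reset before every generation; the pipeline's __call__ above does this automatically. A sketch of the same discipline for code that drives the model directly (the run_steps helper and its inputs are hypothetical):

    from diffsynth_engine.models.flux import FluxDiTFBCache

    def run_steps(dit: FluxDiTFBCache, step_inputs, num_inference_steps: int):
        # Reset step_count and record the step budget, mirroring __call__ above.
        dit.refresh_cache_status(num_inference_steps)
        outputs = []
        for inputs in step_inputs:  # one dict of DiT forward kwargs per denoising step
            outputs.append(dit(**inputs))
        return outputs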

diffsynth_engine/pipelines/sd_image.py

@@ -291,7 +291,7 @@ class SDImagePipeline(BasePipeline):
          current_step: int,
          total_step: int,
      ):
-         controlnet_res_stack = None
+         controlnet_res_stack = None
          for param in controlnet_params:
              current_scale = param.scale
              if not (
@@ -303,15 +303,10 @@ class SDImagePipeline(BasePipeline):
              if self.offload_mode is not None:
                  empty_cache()
                  param.model.to(self.device)
-             controlnet_res = param.model(
-                 latents,
-                 timestep,
-                 prompt_emb,
-                 param.image
-             )
+             controlnet_res = param.model(latents, timestep, prompt_emb, param.image)
              controlnet_res = [res * current_scale for res in controlnet_res]
              if self.offload_mode is not None:
-                 param.model.to("cpu")
+                 param.model.to("cpu")
                  empty_cache()
              controlnet_res_stack = accumulate(controlnet_res_stack, controlnet_res)
          return controlnet_res_stack
@@ -324,16 +319,22 @@ class SDImagePipeline(BasePipeline):
          negative_prompt_emb: torch.Tensor,
          controlnet_params: List[ControlNetParams],
          current_step: int,
-         total_step: int,
+         total_step: int,
          cfg_scale: float,
          batch_cfg: bool = True,
      ):
          if cfg_scale <= 1.0:
-             return self.predict_noise(latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step)
+             return self.predict_noise(
+                 latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step
+             )
          if not batch_cfg:
              # cfg by predict noise one by one
-             positive_noise_pred = self.predict_noise(latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step)
-             negative_noise_pred = self.predict_noise(latents, timestep, negative_prompt_emb, controlnet_params, current_step, total_step)
+             positive_noise_pred = self.predict_noise(
+                 latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step
+             )
+             negative_noise_pred = self.predict_noise(
+                 latents, timestep, negative_prompt_emb, controlnet_params, current_step, total_step
+             )
              noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
              return noise_pred
          else:
@@ -341,12 +342,16 @@ class SDImagePipeline(BasePipeline):
              prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0)
              latents = torch.cat([latents, latents], dim=0)
              timestep = torch.cat([timestep, timestep], dim=0)
-             positive_noise_pred, negative_noise_pred = self.predict_noise(latents, timestep, prompt_emb, controlnet_params, current_step, total_step).chunk(2)
+             positive_noise_pred, negative_noise_pred = self.predict_noise(
+                 latents, timestep, prompt_emb, controlnet_params, current_step, total_step
+             ).chunk(2)
              noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
              return noise_pred

      def predict_noise(self, latents, timestep, prompt_emb, controlnet_params, current_step, total_step):
-         controlnet_res_stack = self.predict_multicontrolnet(latents, timestep, prompt_emb, controlnet_params, current_step, total_step)
+         controlnet_res_stack = self.predict_multicontrolnet(
+             latents, timestep, prompt_emb, controlnet_params, current_step, total_step
+         )

          noise_pred = self.unet(
              x=latents,
@@ -433,7 +438,7 @@ class SDImagePipeline(BasePipeline):
                  cfg_scale=cfg_scale,
                  controlnet_params=controlnet_params,
                  current_step=i,
-                 total_step=len(timesteps),
+                 total_step=len(timesteps),
                  batch_cfg=self.batch_cfg,
              )
              # Denoise
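
The batch_cfg path above folds classifier-free guidance into a single forward pass: the positive and negative branches are concatenated along the batch dimension, predicted together, split with .chunk(2), and combined as noise = neg + cfg_scale * (pos - neg). A toy demonstration that the batched path matches two separate predictions; toy_model stands in for the UNet call and all names and shapes are illustrative:

    import torch

    def toy_model(latents, prompt_emb):
        return latents * prompt_emb  # placeholder for a UNet call

    latents = torch.randn(1, 4, 64, 64)
    pos_emb, neg_emb = torch.tensor(2.0), torch.tensor(0.5)
    cfg_scale = 7.5

    batched = toy_model(
        torch.cat([latents, latents]), torch.stack([pos_emb, neg_emb]).view(2, 1, 1, 1)
    )
    positive, negative = batched.chunk(2)
    noise_pred = negative + cfg_scale * (positive - negative)

    # Matches computing the two passes separately:
    expected = toy_model(latents, neg_emb) + cfg_scale * (
        toy_model(latents, pos_emb) - toy_model(latents, neg_emb)
    )
    assert torch.allclose(noise_pred, expected)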

diffsynth_engine/pipelines/sdxl_image.py

@@ -31,6 +31,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

+
  class SDXLLoRAConverter(LoRAStateDictConverter):
      def _replace_kohya_te1_key(self, key):
          key = key.replace("lora_te1_text_model_encoder_layers_", "encoders.")
@@ -91,7 +92,7 @@ class SDXLLoRAConverter(LoRAStateDictConverter):
          else:
              raise ValueError(f"Unsupported key: {key}")
          # clip skip
-         te1_dict = {k: v for k, v in te1_dict.items() if not k.startswith('encoders.11')}
+         te1_dict = {k: v for k, v in te1_dict.items() if not k.startswith("encoders.11")}
          return {"unet": unet_dict, "text_encoder": te1_dict, "text_encoder_2": te2_dict}

      def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
@@ -279,10 +280,7 @@ class SDXLImagePipeline(BasePipeline):
              condition = self.preprocess_control_image(param.image).to(device=self.device, dtype=self.dtype)
              results.append(
                  ControlNetParams(
-                     model=param.model,
-                     scale=param.scale,
-                     image=condition,
-                     processor_name=param.processor_name
+                     model=param.model, scale=param.scale, image=condition, processor_name=param.processor_name
                  )
              )
          return results
@@ -307,13 +305,13 @@ class SDXLImagePipeline(BasePipeline):
          latents: torch.Tensor,
          timestep: torch.Tensor,
          prompt_emb: torch.Tensor,
-         add_text_embeds: torch.Tensor,
-         add_time_id: torch.Tensor,
+         add_text_embeds: torch.Tensor,
+         add_time_id: torch.Tensor,
          controlnet_params: List[ControlNetParams],
          current_step: int,
          total_step: int,
      ):
-         controlnet_res_stack = None
+         controlnet_res_stack = None
          for param in controlnet_params:
              current_scale = param.scale
              if not (
@@ -338,8 +336,8 @@ class SDXLImagePipeline(BasePipeline):
              )
              controlnet_res = [res * current_scale for res in controlnet_res]
              if self.offload_mode is not None:
-                 param.model.to("cpu")
-                 empty_cache()
+                 param.model.to("cpu")
+                 empty_cache()
              controlnet_res_stack = accumulate(controlnet_res_stack, controlnet_res)
          return controlnet_res_stack

@@ -353,20 +351,36 @@ class SDXLImagePipeline(BasePipeline):
          negative_add_text_embeds: torch.Tensor,
          controlnet_params: List[ControlNetParams],
          current_step: int,
-         total_step: int,
+         total_step: int,
          add_time_id: torch.Tensor,
          cfg_scale: float,
          batch_cfg: bool = True,
      ):
          if cfg_scale <= 1.0:
-             return self.predict_noise(latents, timestep, positive_prompt_emb, add_time_id, controlnet_params, current_step, total_step)
+             return self.predict_noise(
+                 latents, timestep, positive_prompt_emb, add_time_id, controlnet_params, current_step, total_step
+             )
          if not batch_cfg:
              # cfg by predict noise one by one
              positive_noise_pred = self.predict_noise(
-                 latents, timestep, positive_prompt_emb, positive_add_text_embeds, add_time_id, controlnet_params, current_step, total_step
+                 latents,
+                 timestep,
+                 positive_prompt_emb,
+                 positive_add_text_embeds,
+                 add_time_id,
+                 controlnet_params,
+                 current_step,
+                 total_step,
              )
              negative_noise_pred = self.predict_noise(
-                 latents, timestep, negative_prompt_emb, negative_add_text_embeds, add_time_id, controlnet_params, current_step, total_step
+                 latents,
+                 timestep,
+                 negative_prompt_emb,
+                 negative_add_text_embeds,
+                 add_time_id,
+                 controlnet_params,
+                 current_step,
+                 total_step,
              )
              noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
              return noise_pred
@@ -378,14 +392,25 @@ class SDXLImagePipeline(BasePipeline):
              latents = torch.cat([latents, latents], dim=0)
              timestep = torch.cat([timestep, timestep], dim=0)
              positive_noise_pred, negative_noise_pred = self.predict_noise(
-                 latents, timestep, prompt_emb, add_text_embeds, add_time_ids, controlnet_params, current_step, total_step
+                 latents,
+                 timestep,
+                 prompt_emb,
+                 add_text_embeds,
+                 add_time_ids,
+                 controlnet_params,
+                 current_step,
+                 total_step,
              )
              noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
              return noise_pred

-     def predict_noise(self, latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step):
+     def predict_noise(
+         self, latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step
+     ):
          y = self.prepare_add_embeds(add_text_embeds, add_time_id, self.dtype)
-         controlnet_res_stack = self.predict_multicontrolnet(latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step)
+         controlnet_res_stack = self.predict_multicontrolnet(
+             latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step
+         )

          noise_pred = self.unet(
              x=latents,
@@ -433,7 +458,7 @@ class SDXLImagePipeline(BasePipeline):
          width: int = 1024,
          num_inference_steps: int = 20,
          seed: int | None = None,
-         controlnet_params: List[ControlNetParams] | ControlNetParams = [],
+         controlnet_params: List[ControlNetParams] | ControlNetParams = [],
          progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
      ):
          if not isinstance(controlnet_params, list):
@@ -491,7 +516,7 @@ class SDXLImagePipeline(BasePipeline):
                  cfg_scale=cfg_scale,
                  controlnet_params=controlnet_params,
                  current_step=i,
-                 total_step=len(timesteps),
+                 total_step=len(timesteps),
                  batch_cfg=self.batch_cfg,
              )
              # Denoise

diffsynth_engine-0.3.6.dev6.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: diffsynth_engine
- Version: 0.3.6.dev5
+ Version: 0.3.6.dev6
  Author: MuseAI x ModelScope
  Classifier: Programming Language :: Python :: 3
  Classifier: Operating System :: OS Independent

diffsynth_engine-0.3.6.dev6.dist-info/RECORD

@@ -71,15 +71,16 @@ diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvB
  diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
  diffsynth_engine/models/basic/transformer_helper.py,sha256=LAEXInYSVEF61T3bzvu33jvEyR59B2feKhLmgZvUJH8,3384
  diffsynth_engine/models/basic/unet_helper.py,sha256=4lN6F80Ubm6ip4dkLVmB-Og5-Y25Wduhs9Q8qjyzK6E,9044
- diffsynth_engine/models/flux/__init__.py,sha256=RHjSG6CPYY3UXrG3GdF82PgHRg0bpWZBRcpr_9TTZyk,619
+ diffsynth_engine/models/flux/__init__.py,sha256=x0JoxL0CdiiVrY0BjkIrGinud7mcXecLleGO0km91XQ,686
  diffsynth_engine/models/flux/flux_controlnet.py,sha256=YvQ5gfzttP092gwLzTzf7IcEnG4smOvILghfoCVuHOo,8163
  diffsynth_engine/models/flux/flux_dit.py,sha256=4mpMYVbK8OoQSQeGNJcnkR-rkc0jvw5cT9SOz0gxig0,23335
+ diffsynth_engine/models/flux/flux_dit_fbcache.py,sha256=1uG24Cv7IOfrHzwqBv1H_BGAbBsJkylvADaA_gnNjss,8497
  diffsynth_engine/models/flux/flux_ipadapter.py,sha256=SHN2hlibLy0OIScGTygV2yNNJZ7FLzpFdBXzykH6wTo,7129
  diffsynth_engine/models/flux/flux_redux.py,sha256=tVK4hZxHtC_jx-6vpc0os6dtohBUWK_V8ucdZBkeb-o,2560
  diffsynth_engine/models/flux/flux_text_encoder.py,sha256=Qcs277RIPP-O5AkcAb5Fb0jToV5o6Qn8hh8nw8zx9g8,3663
  diffsynth_engine/models/flux/flux_vae.py,sha256=qJzcpfQ9-ATQmE5n8nOy1c5BEA0GZjIddm8zOUodXnE,2840
  diffsynth_engine/models/sd/__init__.py,sha256=hjoKRnwoXOLD0wude-w7I6wK5ak7ACMbnbkPuBB2oU0,380
- diffsynth_engine/models/sd/sd_controlnet.py,sha256=JFNIg_ff28yMOC-foE8z4YCeBGT4uzDnl5pgALwvN1w,49490
+ diffsynth_engine/models/sd/sd_controlnet.py,sha256=K4iHSCf_UgIlogfLcwvcnnu4RwtgkI-IL-PsWq3e-gM,50761
  diffsynth_engine/models/sd/sd_text_encoder.py,sha256=gTj8GOGnrUjLDvIFgbHCMJAkdh8nn0E6FtumQ-6L2Xs,5406
  diffsynth_engine/models/sd/sd_unet.py,sha256=EbF0Vdt4tKp0VjqozBK71-utFAPp6OHZOdWn0czqYNo,12897
  diffsynth_engine/models/sd/sd_vae.py,sha256=neHAC2HJoiSCQZsyH-ufvIbg1VWmXk3OHz5N_-I5BTQ,1650
@@ -87,10 +88,10 @@ diffsynth_engine/models/sd3/__init__.py,sha256=Kd5JCDlfP-FfW0Z_BiugQxuRhs2E_3np_
  diffsynth_engine/models/sd3/sd3_dit.py,sha256=3G7XVQRw2iD3lFDvzfYXykB3M0QpX42adPnyjtLihos,11447
  diffsynth_engine/models/sd3/sd3_text_encoder.py,sha256=9I57W5AScVoV7L2If4CsOxD_k903S5jlUxEOiUF-xf4,7081
  diffsynth_engine/models/sd3/sd3_vae.py,sha256=nxZpg616aAagSWALVC4r2B-NLYeA2-u9kPYN0lgBm1I,1428
- diffsynth_engine/models/sdxl/__init__.py,sha256=mcciSseHyA8TFUiuSkaBUmusQh3OofBQA6vqpgzRr1M,472
- diffsynth_engine/models/sdxl/sdxl_controlnet.py,sha256=o-F5ZA8AT9dDVV-qr_JoFdHHI1h1pDqekMRzE0-mZ_U,14630
+ diffsynth_engine/models/sdxl/__init__.py,sha256=3DyhN7p2N2n0MF7RqnTl54JNL0qNQHjIqBMb7W8V2zA,468
+ diffsynth_engine/models/sdxl/sdxl_controlnet.py,sha256=rfLDGbe9txKS7DCEtCCn1RjD08SnC5BJsXR3LEunCXM,15380
  diffsynth_engine/models/sdxl/sdxl_text_encoder.py,sha256=VNSx1hfqKAOK-T3m34SaONscAJK8wFkyuVtfcTWmeRA,12369
- diffsynth_engine/models/sdxl/sdxl_unet.py,sha256=bKRo3l6_AodpPQQZAGZTpFbhqKrz0je3_s2G81xxRn4,12169
+ diffsynth_engine/models/sdxl/sdxl_unet.py,sha256=rnD-ZmGf2g_iRuS8I1zBjFJkcoMAE4TxuTjmISWlQhk,12156
  diffsynth_engine/models/sdxl/sdxl_vae.py,sha256=kNGnn5wKipD31qgMYYLeFyyi8iCc8bsAMD1pEtDhue8,1654
  diffsynth_engine/models/text_encoder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  diffsynth_engine/models/text_encoder/clip.py,sha256=RwJYtSTWjh7yw8dQKbQbiL335naqJhiM96RDCMTnFh4,1805
@@ -105,10 +106,10 @@ diffsynth_engine/models/wan/wan_text_encoder.py,sha256=bkphxtqNNwXcEA_OaUrwV9CvI
  diffsynth_engine/models/wan/wan_vae.py,sha256=RxyuHExQmRjGBAqhZdIbtwZFdCibTzh__U4-Sa00zdI,29004
  diffsynth_engine/pipelines/__init__.py,sha256=TQnwuOGtnlIBdi-afGvS0W4PXPDVkOyjCAAys0DMowA,605
  diffsynth_engine/pipelines/base.py,sha256=mmJGBC0bJQR1qQp9cknz7t-n5PrqmasTKqSpwWaxVVA,13461
- diffsynth_engine/pipelines/controlnet_helper.py,sha256=X1SMjgsMB4cs0Im2Rfy6b0UlMehKMoyFb834c0eGa-0,676
- diffsynth_engine/pipelines/flux_image.py,sha256=B5B7tgtdB9xoJfQyc6H7ttGdjCSFJP3fp6ffL_kUDUw,49870
- diffsynth_engine/pipelines/sd_image.py,sha256=plPD9eVFsharaJYwwM1EbYSqQPrRJ3GajEDBcXgcxbQ,18778
- diffsynth_engine/pipelines/sdxl_image.py,sha256=zO_YYST5Eim1pZFBGaadt9AklwQcGlYvLxvvbhI0vJ8,22386
+ diffsynth_engine/pipelines/controlnet_helper.py,sha256=b6HnJFJfMKZq9s5DQ-9Se8OTSDeHVk4AskONSwcRShg,680
+ diffsynth_engine/pipelines/flux_image.py,sha256=4UBdlTlzTQnpnh6t0vJroouk423yoxY6GHvBZm0lexY,50590
+ diffsynth_engine/pipelines/sd_image.py,sha256=A92oBYZt37x4uOYhl3dAiCzkcgsPmRmH1z4_IWEULeA,18807
+ diffsynth_engine/pipelines/sdxl_image.py,sha256=P1bFvFg3Dd49BAEihig-e5QkoNtztJqVxdfAuD693FU,22671
  diffsynth_engine/pipelines/wan_video.py,sha256=LqDBZILcIOK6SR0zPrDJgBjpBgkIah8N6Q7I-8XHchY,21584
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
@@ -140,8 +141,8 @@ diffsynth_engine/utils/parallel.py,sha256=2WISMBTTmW0v2qPvpms421-B59v3bYlS6YrLq9
  diffsynth_engine/utils/platform.py,sha256=q9ifmdzoa66Cj9YKfwps21DsDdwA0JGpwroKQbG6shU,224
  diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
  diffsynth_engine/utils/video.py,sha256=Ne0rd2lb59UT1q5EotpjlY7OT8F9oTCFDyo1ST77uoQ,1004
- diffsynth_engine-0.3.6.dev5.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
- diffsynth_engine-0.3.6.dev5.dist-info/METADATA,sha256=H5eWg2pjUB0FgRWcz06ObCJJPD5Wpkox6mWlqyGiyG4,1068
- diffsynth_engine-0.3.6.dev5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- diffsynth_engine-0.3.6.dev5.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
- diffsynth_engine-0.3.6.dev5.dist-info/RECORD,,
+ diffsynth_engine-0.3.6.dev6.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+ diffsynth_engine-0.3.6.dev6.dist-info/METADATA,sha256=RlAdpLv_Z2WfeqivtN0OVMGSTi4_rMPfOTiBJgyB6BI,1068
+ diffsynth_engine-0.3.6.dev6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ diffsynth_engine-0.3.6.dev6.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+ diffsynth_engine-0.3.6.dev6.dist-info/RECORD,,