diffsynth-engine 0.3.6.dev5__py3-none-any.whl → 0.3.6.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/models/flux/__init__.py +2 -0
- diffsynth_engine/models/flux/flux_dit_fbcache.py +205 -0
- diffsynth_engine/models/sd/sd_controlnet.py +167 -85
- diffsynth_engine/models/sdxl/__init__.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_controlnet.py +118 -73
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -2
- diffsynth_engine/pipelines/controlnet_helper.py +4 -2
- diffsynth_engine/pipelines/flux_image.py +25 -9
- diffsynth_engine/pipelines/sd_image.py +20 -15
- diffsynth_engine/pipelines/sdxl_image.py +44 -19
- {diffsynth_engine-0.3.6.dev5.dist-info → diffsynth_engine-0.3.6.dev6.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.3.6.dev5.dist-info → diffsynth_engine-0.3.6.dev6.dist-info}/RECORD +15 -14
- {diffsynth_engine-0.3.6.dev5.dist-info → diffsynth_engine-0.3.6.dev6.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.3.6.dev5.dist-info → diffsynth_engine-0.3.6.dev6.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev5.dist-info → diffsynth_engine-0.3.6.dev6.dist-info}/top_level.txt +0 -0
|
@@ -4,6 +4,7 @@ from .flux_vae import FluxVAEDecoder, FluxVAEEncoder, config as flux_vae_config
|
|
|
4
4
|
from .flux_controlnet import FluxControlNet
|
|
5
5
|
from .flux_ipadapter import FluxIPAdapter
|
|
6
6
|
from .flux_redux import FluxRedux
|
|
7
|
+
from .flux_dit_fbcache import FluxDiTFBCache
|
|
7
8
|
|
|
8
9
|
__all__ = [
|
|
9
10
|
"FluxRedux",
|
|
@@ -14,6 +15,7 @@ __all__ = [
|
|
|
14
15
|
"FluxTextEncoder2",
|
|
15
16
|
"FluxVAEDecoder",
|
|
16
17
|
"FluxVAEEncoder",
|
|
18
|
+
"FluxDiTFBCache",
|
|
17
19
|
"flux_dit_config",
|
|
18
20
|
"flux_text_encoder_config",
|
|
19
21
|
"flux_vae_config",
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import Dict, Optional
|
|
4
|
+
|
|
5
|
+
from diffsynth_engine.models.utils import no_init_weights
|
|
6
|
+
from diffsynth_engine.utils.gguf import gguf_inference
|
|
7
|
+
from diffsynth_engine.utils.fp8_linear import fp8_inference
|
|
8
|
+
from diffsynth_engine.utils.parallel import (
|
|
9
|
+
cfg_parallel,
|
|
10
|
+
cfg_parallel_unshard,
|
|
11
|
+
sequence_parallel,
|
|
12
|
+
sequence_parallel_unshard,
|
|
13
|
+
)
|
|
14
|
+
from diffsynth_engine.utils import logging
|
|
15
|
+
from diffsynth_engine.models.flux.flux_dit import FluxDiT
|
|
16
|
+
|
|
17
|
+
logger = logging.get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class FluxDiTFBCache(FluxDiT):
    """FluxDiT variant that can reuse the previous step's block output.

    On every forward pass the first double block is always executed. The
    change it makes to ``hidden_states`` (the "first-block residual") is
    compared with the one recorded at the last fully computed step; when the
    relative L1 difference is below ``relative_l1_threshold`` the remaining
    blocks are skipped and the cached residual of those blocks is added
    instead.

    Call :meth:`refresh_cache_status` before each new denoising run so the
    step counter starts from zero.
    """

    def __init__(
        self,
        in_channel: int = 64,
        attn_impl: Optional[str] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        relative_l1_threshold: float = 0.05,
    ):
        """Build the underlying FluxDiT and reset the cache bookkeeping.

        Args:
            in_channel: Forwarded to ``FluxDiT.__init__``.
            attn_impl: Forwarded to ``FluxDiT.__init__``.
            device: Forwarded to ``FluxDiT.__init__``.
            dtype: Forwarded to ``FluxDiT.__init__``.
            relative_l1_threshold: Skip threshold for the first-block
                residual comparison; values ``<= 0`` disable skipping.
        """
        super().__init__(in_channel=in_channel, attn_impl=attn_impl, device=device, dtype=dtype)
        self.relative_l1_threshold = relative_l1_threshold
        self.step_count = 0
        self.num_inference_steps = 0
        # Cache state; populated by the first fully computed forward pass.
        self.prev_first_hidden_states_residual = None
        self.previous_residual = None

    def is_relative_l1_below_threshold(self, prev_residual, residual, threshold):
        """Return True when ``residual`` is close enough to ``prev_residual``.

        The metric is ``mean(|prev - cur|) / mean(|prev|)``.

        Args:
            prev_residual: Residual recorded at the last computed step.
            residual: Residual of the current step.
            threshold: Upper bound (exclusive) for the relative difference.

        Returns:
            ``False`` if ``threshold <= 0`` or the shapes differ, otherwise
            whether the relative difference is strictly below ``threshold``.
        """
        if threshold <= 0.0:
            return False

        if prev_residual.shape != residual.shape:
            return False

        mean_diff = (prev_residual - residual).abs().mean()
        mean_prev_residual = prev_residual.abs().mean()
        # A zero denominator gives inf/nan, which compares False below,
        # so the step is simply recomputed.
        diff = mean_diff / mean_prev_residual
        return diff.item() < threshold

    def refresh_cache_status(self, num_inference_steps):
        """Reset the step counter and record the length of the next run."""
        self.step_count = 0
        self.num_inference_steps = num_inference_steps

    def forward(
        self,
        hidden_states,
        timestep,
        prompt_emb,
        pooled_prompt_emb,
        image_emb,
        guidance,
        text_ids,
        image_ids=None,
        controlnet_double_block_output=None,
        controlnet_single_block_output=None,
        **kwargs,
    ):
        """Run one denoising step, skipping most blocks when the cache hits.

        The first and the last step of a run (per ``step_count`` and
        ``num_inference_steps``) are always fully computed.

        Args:
            hidden_states: Latent input; its last two dims are used as
                ``(h, w)`` and the patchified sequence length is
                ``h * w // 4``.
            timestep: Input to the time embedder.
            prompt_emb: Text embeddings fed to the context embedder.
            pooled_prompt_emb: Input to the pooled text embedder.
            image_emb: Passed through to every block.
            guidance: Scaled by 1000 and embedded when the model has a
                guidance embedder.
            text_ids: Position ids for the text tokens.
            image_ids: Position ids for the image tokens; derived from
                ``hidden_states`` when ``None``.
            controlnet_double_block_output: Optional sequence of residuals
                added after the double blocks.
            controlnet_single_block_output: Optional sequence of residuals
                added after the single blocks.
            **kwargs: Ignored.

        Returns:
            The unpatchified model output.
        """
        h, w = hidden_states.shape[-2:]
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)
        controlnet_double_block_output = (
            controlnet_double_block_output if controlnet_double_block_output is not None else ()
        )
        controlnet_single_block_output = (
            controlnet_single_block_output if controlnet_single_block_output is not None else ()
        )

        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = hidden_states.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel(
                (
                    hidden_states,
                    timestep,
                    prompt_emb,
                    pooled_prompt_emb,
                    image_emb,
                    guidance,
                    text_ids,
                    image_ids,
                    *controlnet_double_block_output,
                    *controlnet_single_block_output,
                ),
                use_cfg=use_cfg,
            ),
        ):
            # warning: keep the order of time_embedding + guidance_embedding + pooled_text_embedding
            # addition of floating point numbers does not meet commutative law
            conditioning = self.time_embedder(timestep, hidden_states.dtype)
            if self.guidance_embedder is not None:
                guidance = guidance * 1000
                conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
            conditioning += self.pooled_text_embedder(pooled_prompt_emb)
            rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
            text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
            image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
            hidden_states = self.patchify(hidden_states)

            with sequence_parallel(
                (
                    hidden_states,
                    prompt_emb,
                    text_rope_emb,
                    image_rope_emb,
                    *controlnet_double_block_output,
                    *controlnet_single_block_output,
                ),
                seq_dims=(
                    1,
                    1,
                    2,
                    2,
                    *(1 for _ in controlnet_double_block_output),
                    *(1 for _ in controlnet_single_block_output),
                ),
            ):
                hidden_states = self.x_embedder(hidden_states)
                prompt_emb = self.context_embedder(prompt_emb)
                rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)

                # first block: always executed, its residual drives the cache decision
                original_hidden_states = hidden_states
                hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
                first_hidden_states_residual = hidden_states - original_hidden_states

                (first_hidden_states_residual,) = sequence_parallel_unshard(
                    (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,)
                )

                is_boundary_step = self.step_count == 0 or self.step_count == (self.num_inference_steps - 1)
                cache_is_empty = self.prev_first_hidden_states_residual is None or self.previous_residual is None
                if is_boundary_step or cache_is_empty:
                    should_calc = True
                else:
                    # Keyword arguments keep previous/current in the roles the
                    # helper's signature defines (the ratio is normalised by
                    # the previous residual).
                    skip = self.is_relative_l1_below_threshold(
                        prev_residual=self.prev_first_hidden_states_residual,
                        residual=first_hidden_states_residual,
                        threshold=self.relative_l1_threshold,
                    )
                    should_calc = not skip
                self.step_count += 1

                if not should_calc:
                    # Cache hit: reuse what the remaining blocks added last time.
                    hidden_states += self.previous_residual
                else:
                    self.prev_first_hidden_states_residual = first_hidden_states_residual

                    first_hidden_states = hidden_states.clone()

                    # The ControlNet intervals do not change inside the loops.
                    num_double_ctrl = len(controlnet_double_block_output)
                    num_single_ctrl = len(controlnet_single_block_output)
                    double_interval = int(np.ceil(len(self.blocks) / num_double_ctrl)) if num_double_ctrl > 0 else 0
                    # Each residual list is spread by its OWN length.
                    single_interval = (
                        int(np.ceil(len(self.single_blocks) / num_single_ctrl)) if num_single_ctrl > 0 else 0
                    )

                    # NOTE(review): `i` starts at 0 for self.blocks[1], so block 0 gets no
                    # ControlNet residual and indices are shifted by one relative to a loop
                    # over all blocks - confirm this matches the parent FluxDiT.forward.
                    for i, block in enumerate(self.blocks[1:]):
                        hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
                        if num_double_ctrl > 0:
                            hidden_states = hidden_states + controlnet_double_block_output[i // double_interval]
                    hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
                    for i, block in enumerate(self.single_blocks):
                        hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
                        if num_single_ctrl > 0:
                            hidden_states = hidden_states + controlnet_single_block_output[i // single_interval]

                    # Drop the text tokens that were concatenated in front.
                    hidden_states = hidden_states[:, prompt_emb.shape[1] :]

                    previous_residual = hidden_states - first_hidden_states
                    self.previous_residual = previous_residual

                hidden_states = self.final_norm_out(hidden_states, conditioning)
                hidden_states = self.final_proj_out(hidden_states)
                (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))

            hidden_states = self.unpatchify(hidden_states, h, w)
            (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)

            return hidden_states

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        in_channel: int = 64,
        attn_impl: Optional[str] = None,
        fb_cache_relative_l1_threshold: float = 0.05,
    ):
        """Instantiate the model without weight init and load ``state_dict``.

        Args:
            state_dict: Parameters to assign to the new model.
            device: Target device.
            dtype: Target dtype.
            in_channel: Forwarded to ``__init__``.
            attn_impl: Forwarded to ``__init__``.
            fb_cache_relative_l1_threshold: Forwarded to ``__init__`` as
                ``relative_l1_threshold``.

        Returns:
            The loaded model with ``requires_grad`` disabled.
        """
        with no_init_weights():
            # skip_init forwards its kwargs to cls.__init__, so the keyword
            # must be the name __init__ actually declares.
            model = torch.nn.utils.skip_init(
                cls,
                device=device,
                dtype=dtype,
                in_channel=in_channel,
                attn_impl=attn_impl,
                relative_l1_threshold=fb_cache_relative_l1_threshold,
            )
            model = model.requires_grad_(False)  # for loading gguf
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model
|
|
@@ -12,18 +12,29 @@ from diffsynth_engine.models.basic.unet_helper import (
|
|
|
12
12
|
DownSampler,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
|
+
|
|
15
16
|
class ControlNetConditioningLayer(nn.Module):
|
|
16
|
-
def __init__(self, channels
|
|
17
|
+
def __init__(self, channels=(3, 16, 32, 96, 256, 320), device="cuda:0", dtype=torch.float16):
|
|
17
18
|
super().__init__()
|
|
18
19
|
self.blocks = torch.nn.ModuleList([])
|
|
19
|
-
self.blocks.append(
|
|
20
|
+
self.blocks.append(
|
|
21
|
+
torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, device=device, dtype=dtype)
|
|
22
|
+
)
|
|
20
23
|
self.blocks.append(torch.nn.SiLU())
|
|
21
24
|
for i in range(1, len(channels) - 2):
|
|
22
|
-
self.blocks.append(
|
|
25
|
+
self.blocks.append(
|
|
26
|
+
torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1, device=device, dtype=dtype)
|
|
27
|
+
)
|
|
23
28
|
self.blocks.append(torch.nn.SiLU())
|
|
24
|
-
self.blocks.append(
|
|
29
|
+
self.blocks.append(
|
|
30
|
+
torch.nn.Conv2d(
|
|
31
|
+
channels[i], channels[i + 1], kernel_size=3, padding=1, stride=2, device=device, dtype=dtype
|
|
32
|
+
)
|
|
33
|
+
)
|
|
25
34
|
self.blocks.append(torch.nn.SiLU())
|
|
26
|
-
self.blocks.append(
|
|
35
|
+
self.blocks.append(
|
|
36
|
+
torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1, device=device, dtype=dtype)
|
|
37
|
+
)
|
|
27
38
|
|
|
28
39
|
def forward(self, conditioning):
|
|
29
40
|
for block in self.blocks:
|
|
@@ -38,15 +49,73 @@ class SDControlNetStateDictConverter(StateDictConverter):
|
|
|
38
49
|
def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
39
50
|
# architecture
|
|
40
51
|
block_types = [
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
52
|
+
"ResnetBlock",
|
|
53
|
+
"AttentionBlock",
|
|
54
|
+
"PushBlock",
|
|
55
|
+
"ResnetBlock",
|
|
56
|
+
"AttentionBlock",
|
|
57
|
+
"PushBlock",
|
|
58
|
+
"DownSampler",
|
|
59
|
+
"PushBlock",
|
|
60
|
+
"ResnetBlock",
|
|
61
|
+
"AttentionBlock",
|
|
62
|
+
"PushBlock",
|
|
63
|
+
"ResnetBlock",
|
|
64
|
+
"AttentionBlock",
|
|
65
|
+
"PushBlock",
|
|
66
|
+
"DownSampler",
|
|
67
|
+
"PushBlock",
|
|
68
|
+
"ResnetBlock",
|
|
69
|
+
"AttentionBlock",
|
|
70
|
+
"PushBlock",
|
|
71
|
+
"ResnetBlock",
|
|
72
|
+
"AttentionBlock",
|
|
73
|
+
"PushBlock",
|
|
74
|
+
"DownSampler",
|
|
75
|
+
"PushBlock",
|
|
76
|
+
"ResnetBlock",
|
|
77
|
+
"PushBlock",
|
|
78
|
+
"ResnetBlock",
|
|
79
|
+
"PushBlock",
|
|
80
|
+
"ResnetBlock",
|
|
81
|
+
"AttentionBlock",
|
|
82
|
+
"ResnetBlock",
|
|
83
|
+
"PopBlock",
|
|
84
|
+
"ResnetBlock",
|
|
85
|
+
"PopBlock",
|
|
86
|
+
"ResnetBlock",
|
|
87
|
+
"PopBlock",
|
|
88
|
+
"ResnetBlock",
|
|
89
|
+
"UpSampler",
|
|
90
|
+
"PopBlock",
|
|
91
|
+
"ResnetBlock",
|
|
92
|
+
"AttentionBlock",
|
|
93
|
+
"PopBlock",
|
|
94
|
+
"ResnetBlock",
|
|
95
|
+
"AttentionBlock",
|
|
96
|
+
"PopBlock",
|
|
97
|
+
"ResnetBlock",
|
|
98
|
+
"AttentionBlock",
|
|
99
|
+
"UpSampler",
|
|
100
|
+
"PopBlock",
|
|
101
|
+
"ResnetBlock",
|
|
102
|
+
"AttentionBlock",
|
|
103
|
+
"PopBlock",
|
|
104
|
+
"ResnetBlock",
|
|
105
|
+
"AttentionBlock",
|
|
106
|
+
"PopBlock",
|
|
107
|
+
"ResnetBlock",
|
|
108
|
+
"AttentionBlock",
|
|
109
|
+
"UpSampler",
|
|
110
|
+
"PopBlock",
|
|
111
|
+
"ResnetBlock",
|
|
112
|
+
"AttentionBlock",
|
|
113
|
+
"PopBlock",
|
|
114
|
+
"ResnetBlock",
|
|
115
|
+
"AttentionBlock",
|
|
116
|
+
"PopBlock",
|
|
117
|
+
"ResnetBlock",
|
|
118
|
+
"AttentionBlock",
|
|
50
119
|
]
|
|
51
120
|
|
|
52
121
|
# controlnet_rename_dict
|
|
@@ -66,7 +135,7 @@ class SDControlNetStateDictConverter(StateDictConverter):
|
|
|
66
135
|
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
|
|
67
136
|
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
|
|
68
137
|
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
|
|
69
|
-
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
|
|
138
|
+
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
|
|
70
139
|
}
|
|
71
140
|
|
|
72
141
|
# Rename each parameter
|
|
@@ -91,7 +160,12 @@ class SDControlNetStateDictConverter(StateDictConverter):
|
|
|
91
160
|
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
|
92
161
|
if names[0] == "mid_block":
|
|
93
162
|
names.insert(1, "0")
|
|
94
|
-
block_type = {
|
|
163
|
+
block_type = {
|
|
164
|
+
"resnets": "ResnetBlock",
|
|
165
|
+
"attentions": "AttentionBlock",
|
|
166
|
+
"downsamplers": "DownSampler",
|
|
167
|
+
"upsamplers": "UpSampler",
|
|
168
|
+
}[names[2]]
|
|
95
169
|
block_type_with_id = ".".join(names[:4])
|
|
96
170
|
if block_type_with_id != last_block_type_with_id[block_type]:
|
|
97
171
|
block_id[block_type] += 1
|
|
@@ -102,9 +176,9 @@ class SDControlNetStateDictConverter(StateDictConverter):
|
|
|
102
176
|
names = ["blocks", str(block_id[block_type])] + names[4:]
|
|
103
177
|
if "ff" in names:
|
|
104
178
|
ff_index = names.index("ff")
|
|
105
|
-
component = ".".join(names[ff_index:ff_index+3])
|
|
179
|
+
component = ".".join(names[ff_index : ff_index + 3])
|
|
106
180
|
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
|
107
|
-
names = names[:ff_index] + [component] + names[ff_index+3:]
|
|
181
|
+
names = names[:ff_index] + [component] + names[ff_index + 3 :]
|
|
108
182
|
if "to_out" in names:
|
|
109
183
|
names.pop(names.index("to_out") + 1)
|
|
110
184
|
else:
|
|
@@ -117,13 +191,21 @@ class SDControlNetStateDictConverter(StateDictConverter):
|
|
|
117
191
|
if ".proj_in." in name or ".proj_out." in name:
|
|
118
192
|
param = param.squeeze()
|
|
119
193
|
if rename_dict[name] in [
|
|
120
|
-
"controlnet_blocks.1.bias",
|
|
121
|
-
"controlnet_blocks.
|
|
194
|
+
"controlnet_blocks.1.bias",
|
|
195
|
+
"controlnet_blocks.2.bias",
|
|
196
|
+
"controlnet_blocks.3.bias",
|
|
197
|
+
"controlnet_blocks.5.bias",
|
|
198
|
+
"controlnet_blocks.6.bias",
|
|
199
|
+
"controlnet_blocks.8.bias",
|
|
200
|
+
"controlnet_blocks.9.bias",
|
|
201
|
+
"controlnet_blocks.10.bias",
|
|
202
|
+
"controlnet_blocks.11.bias",
|
|
203
|
+
"controlnet_blocks.12.bias",
|
|
122
204
|
]:
|
|
123
205
|
continue
|
|
124
206
|
state_dict_[rename_dict[name]] = param
|
|
125
207
|
return state_dict_
|
|
126
|
-
|
|
208
|
+
|
|
127
209
|
def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
128
210
|
rename_dict = {
|
|
129
211
|
"control_model.time_embed.0.weight": "time_embedding.timestep_embedder.0.weight",
|
|
@@ -496,69 +578,71 @@ class SDControlNet(PreTrainedModel):
|
|
|
496
578
|
self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)
|
|
497
579
|
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)
|
|
498
580
|
|
|
499
|
-
self.controlnet_conv_in = ControlNetConditioningLayer(
|
|
581
|
+
self.controlnet_conv_in = ControlNetConditioningLayer(
|
|
582
|
+
channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype
|
|
583
|
+
)
|
|
500
584
|
|
|
501
|
-
self.blocks = torch.nn.ModuleList(
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
585
|
+
self.blocks = torch.nn.ModuleList(
|
|
586
|
+
[
|
|
587
|
+
# CrossAttnDownBlock2D
|
|
588
|
+
ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
|
|
589
|
+
AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
|
|
590
|
+
PushBlock(),
|
|
591
|
+
ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
|
|
592
|
+
AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
|
|
593
|
+
PushBlock(),
|
|
594
|
+
DownSampler(320, device=device, dtype=dtype),
|
|
595
|
+
PushBlock(),
|
|
596
|
+
# CrossAttnDownBlock2D
|
|
597
|
+
ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
|
|
598
|
+
AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
|
|
599
|
+
PushBlock(),
|
|
600
|
+
ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
|
|
601
|
+
AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
|
|
602
|
+
PushBlock(),
|
|
603
|
+
DownSampler(640, device=device, dtype=dtype),
|
|
604
|
+
PushBlock(),
|
|
605
|
+
# CrossAttnDownBlock2D
|
|
606
|
+
ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
|
|
607
|
+
AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
|
|
608
|
+
PushBlock(),
|
|
609
|
+
ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
|
|
610
|
+
AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
|
|
611
|
+
PushBlock(),
|
|
612
|
+
DownSampler(1280, device=device, dtype=dtype),
|
|
613
|
+
PushBlock(),
|
|
614
|
+
# DownBlock2D
|
|
615
|
+
ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
|
|
616
|
+
PushBlock(),
|
|
617
|
+
ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
|
|
618
|
+
PushBlock(),
|
|
619
|
+
# UNetMidBlock2DCrossAttn
|
|
620
|
+
ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
|
|
621
|
+
AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
|
|
622
|
+
ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
|
|
623
|
+
PushBlock(),
|
|
624
|
+
]
|
|
625
|
+
)
|
|
540
626
|
|
|
541
|
-
self.controlnet_blocks = torch.nn.ModuleList(
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
627
|
+
self.controlnet_blocks = torch.nn.ModuleList(
|
|
628
|
+
[
|
|
629
|
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
630
|
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
631
|
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
632
|
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
633
|
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
634
|
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
635
|
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
636
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
637
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
638
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
639
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
640
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
641
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
|
|
642
|
+
]
|
|
643
|
+
)
|
|
556
644
|
|
|
557
|
-
def forward(
|
|
558
|
-
self,
|
|
559
|
-
sample, timestep, encoder_hidden_states, conditioning,
|
|
560
|
-
**kwargs
|
|
561
|
-
):
|
|
645
|
+
def forward(self, sample, timestep, encoder_hidden_states, conditioning, **kwargs):
|
|
562
646
|
# 1. time
|
|
563
647
|
time_emb = self.time_embedding(timestep, dtype=sample.dtype)
|
|
564
648
|
|
|
@@ -585,9 +669,7 @@ class SDControlNet(PreTrainedModel):
|
|
|
585
669
|
attn_impl: Optional[str] = None,
|
|
586
670
|
):
|
|
587
671
|
with no_init_weights():
|
|
588
|
-
model = torch.nn.utils.skip_init(
|
|
589
|
-
cls, attn_impl=attn_impl, device=device, dtype=dtype
|
|
590
|
-
)
|
|
672
|
+
model = torch.nn.utils.skip_init(cls, attn_impl=attn_impl, device=device, dtype=dtype)
|
|
591
673
|
model.load_state_dict(state_dict)
|
|
592
674
|
model.to(device=device, dtype=dtype, non_blocking=True)
|
|
593
|
-
return model
|
|
675
|
+
return model
|
|
@@ -12,23 +12,27 @@ from diffsynth_engine.models.basic.timestep import TimestepEmbeddings, TemporalT
|
|
|
12
12
|
|
|
13
13
|
from collections import OrderedDict
|
|
14
14
|
|
|
15
|
-
class QuickGELU(torch.nn.Module):
|
|
16
15
|
|
|
16
|
+
class QuickGELU(torch.nn.Module):
    """Element-wise activation computing ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        """Scale ``x`` by a sigmoid gate evaluated on ``1.702 * x``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
|
19
19
|
|
|
20
|
-
class ResidualAttentionBlock(torch.nn.Module):
|
|
21
20
|
|
|
21
|
+
class ResidualAttentionBlock(torch.nn.Module):
|
|
22
22
|
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, device="cuda:0", dtype=torch.float16):
|
|
23
23
|
super().__init__()
|
|
24
24
|
|
|
25
25
|
self.attn = torch.nn.MultiheadAttention(d_model, n_head, device=device, dtype=dtype)
|
|
26
26
|
self.ln_1 = torch.nn.LayerNorm(d_model, device=device, dtype=dtype)
|
|
27
|
-
self.mlp = torch.nn.Sequential(
|
|
28
|
-
(
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
27
|
+
self.mlp = torch.nn.Sequential(
|
|
28
|
+
OrderedDict(
|
|
29
|
+
[
|
|
30
|
+
("c_fc", torch.nn.Linear(d_model, d_model * 4, device=device, dtype=dtype)),
|
|
31
|
+
("gelu", QuickGELU()),
|
|
32
|
+
("c_proj", torch.nn.Linear(d_model * 4, d_model, device=device, dtype=dtype)),
|
|
33
|
+
]
|
|
34
|
+
)
|
|
35
|
+
)
|
|
32
36
|
self.ln_2 = torch.nn.LayerNorm(d_model, device=device, dtype=dtype)
|
|
33
37
|
self.attn_mask = attn_mask
|
|
34
38
|
|
|
@@ -49,10 +53,30 @@ class SDXLControlNetUnionStateDictConverter(StateDictConverter):
|
|
|
49
53
|
def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
50
54
|
# architecture
|
|
51
55
|
block_types = [
|
|
52
|
-
"ResnetBlock",
|
|
53
|
-
"
|
|
54
|
-
"ResnetBlock",
|
|
55
|
-
"
|
|
56
|
+
"ResnetBlock",
|
|
57
|
+
"PushBlock",
|
|
58
|
+
"ResnetBlock",
|
|
59
|
+
"PushBlock",
|
|
60
|
+
"DownSampler",
|
|
61
|
+
"PushBlock",
|
|
62
|
+
"ResnetBlock",
|
|
63
|
+
"AttentionBlock",
|
|
64
|
+
"PushBlock",
|
|
65
|
+
"ResnetBlock",
|
|
66
|
+
"AttentionBlock",
|
|
67
|
+
"PushBlock",
|
|
68
|
+
"DownSampler",
|
|
69
|
+
"PushBlock",
|
|
70
|
+
"ResnetBlock",
|
|
71
|
+
"AttentionBlock",
|
|
72
|
+
"PushBlock",
|
|
73
|
+
"ResnetBlock",
|
|
74
|
+
"AttentionBlock",
|
|
75
|
+
"PushBlock",
|
|
76
|
+
"ResnetBlock",
|
|
77
|
+
"AttentionBlock",
|
|
78
|
+
"ResnetBlock",
|
|
79
|
+
"PushBlock",
|
|
56
80
|
]
|
|
57
81
|
|
|
58
82
|
# controlnet_rename_dict
|
|
@@ -107,7 +131,12 @@ class SDXLControlNetUnionStateDictConverter(StateDictConverter):
|
|
|
107
131
|
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
|
108
132
|
if names[0] == "mid_block":
|
|
109
133
|
names.insert(1, "0")
|
|
110
|
-
block_type = {
|
|
134
|
+
block_type = {
|
|
135
|
+
"resnets": "ResnetBlock",
|
|
136
|
+
"attentions": "AttentionBlock",
|
|
137
|
+
"downsamplers": "DownSampler",
|
|
138
|
+
"upsamplers": "UpSampler",
|
|
139
|
+
}[names[2]]
|
|
111
140
|
block_type_with_id = ".".join(names[:4])
|
|
112
141
|
if block_type_with_id != last_block_type_with_id[block_type]:
|
|
113
142
|
block_id[block_type] += 1
|
|
@@ -118,9 +147,9 @@ class SDXLControlNetUnionStateDictConverter(StateDictConverter):
|
|
|
118
147
|
names = ["blocks", str(block_id[block_type])] + names[4:]
|
|
119
148
|
if "ff" in names:
|
|
120
149
|
ff_index = names.index("ff")
|
|
121
|
-
component = ".".join(names[ff_index:ff_index+3])
|
|
150
|
+
component = ".".join(names[ff_index : ff_index + 3])
|
|
122
151
|
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
|
123
|
-
names = names[:ff_index] + [component] + names[ff_index+3:]
|
|
152
|
+
names = names[:ff_index] + [component] + names[ff_index + 3 :]
|
|
124
153
|
if "to_out" in names:
|
|
125
154
|
names.pop(names.index("to_out") + 1)
|
|
126
155
|
else:
|
|
@@ -137,19 +166,20 @@ class SDXLControlNetUnionStateDictConverter(StateDictConverter):
|
|
|
137
166
|
param = param.squeeze()
|
|
138
167
|
state_dict_[rename_dict[name]] = param
|
|
139
168
|
return state_dict_
|
|
140
|
-
|
|
169
|
+
|
|
141
170
|
# TODO: check civitai
|
|
142
171
|
def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
143
172
|
return self._from_diffusers(state_dict)
|
|
144
173
|
|
|
145
|
-
|
|
146
174
|
def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
147
175
|
return self._from_diffusers(state_dict)
|
|
148
176
|
|
|
177
|
+
|
|
149
178
|
class SDXLControlNetUnion(PreTrainedModel):
|
|
150
179
|
converter = SDXLControlNetUnionStateDictConverter()
|
|
151
180
|
|
|
152
|
-
def __init__(
|
|
181
|
+
def __init__(
|
|
182
|
+
self,
|
|
153
183
|
attn_impl: Optional[str] = None,
|
|
154
184
|
device: str = "cuda:0",
|
|
155
185
|
dtype: torch.dtype = torch.bfloat16,
|
|
@@ -157,68 +187,78 @@ class SDXLControlNetUnion(PreTrainedModel):
|
|
|
157
187
|
super().__init__()
|
|
158
188
|
self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)
|
|
159
189
|
|
|
160
|
-
self.add_time_proj = TemporalTimesteps(
|
|
190
|
+
self.add_time_proj = TemporalTimesteps(
|
|
191
|
+
256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype
|
|
192
|
+
)
|
|
161
193
|
self.add_time_embedding = torch.nn.Sequential(
|
|
162
194
|
torch.nn.Linear(2816, 1280, device=device, dtype=dtype),
|
|
163
195
|
torch.nn.SiLU(),
|
|
164
|
-
torch.nn.Linear(1280, 1280, device=device, dtype=dtype)
|
|
196
|
+
torch.nn.Linear(1280, 1280, device=device, dtype=dtype),
|
|
197
|
+
)
|
|
198
|
+
self.control_type_proj = TemporalTimesteps(
|
|
199
|
+
256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype
|
|
165
200
|
)
|
|
166
|
-
self.control_type_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype)
|
|
167
201
|
self.control_type_embedding = torch.nn.Sequential(
|
|
168
202
|
torch.nn.Linear(256 * 8, 1280, device=device, dtype=dtype),
|
|
169
203
|
torch.nn.SiLU(),
|
|
170
|
-
torch.nn.Linear(1280, 1280, device=device, dtype=dtype)
|
|
204
|
+
torch.nn.Linear(1280, 1280, device=device, dtype=dtype),
|
|
171
205
|
)
|
|
172
206
|
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)
|
|
173
207
|
|
|
174
|
-
self.controlnet_conv_in = ControlNetConditioningLayer(
|
|
208
|
+
self.controlnet_conv_in = ControlNetConditioningLayer(
|
|
209
|
+
channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype
|
|
210
|
+
)
|
|
175
211
|
self.controlnet_transformer = ResidualAttentionBlock(320, 8, device=device, dtype=dtype)
|
|
176
212
|
self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
|
|
177
213
|
self.spatial_ch_projs = torch.nn.Linear(320, 320, device=device, dtype=dtype)
|
|
178
214
|
|
|
179
|
-
self.blocks = torch.nn.ModuleList(
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
215
|
+
self.blocks = torch.nn.ModuleList(
|
|
216
|
+
[
|
|
217
|
+
# DownBlock2D
|
|
218
|
+
ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
|
|
219
|
+
PushBlock(),
|
|
220
|
+
ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
|
|
221
|
+
PushBlock(),
|
|
222
|
+
DownSampler(320, device=device, dtype=dtype),
|
|
223
|
+
PushBlock(),
|
|
224
|
+
# CrossAttnDownBlock2D
|
|
225
|
+
ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
|
|
226
|
+
AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
|
|
227
|
+
PushBlock(),
|
|
228
|
+
ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
|
|
229
|
+
AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
|
|
230
|
+
PushBlock(),
|
|
231
|
+
DownSampler(640, device=device, dtype=dtype),
|
|
232
|
+
PushBlock(),
|
|
233
|
+
# CrossAttnDownBlock2D
|
|
234
|
+
ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
|
|
235
|
+
AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
|
|
236
|
+
PushBlock(),
|
|
237
|
+
ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
|
|
238
|
+
AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
|
|
239
|
+
PushBlock(),
|
|
240
|
+
# UNetMidBlock2DCrossAttn
|
|
241
|
+
ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
|
|
242
|
+
AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
|
|
243
|
+
ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
|
|
244
|
+
PushBlock(),
|
|
245
|
+
]
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
self.controlnet_blocks = torch.nn.ModuleList(
|
|
249
|
+
[
|
|
250
|
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
251
|
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
252
|
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
253
|
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
254
|
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
255
|
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
256
|
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
257
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
258
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
259
|
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
|
|
260
|
+
]
|
|
261
|
+
)
|
|
222
262
|
|
|
223
263
|
# 0 -- openpose
|
|
224
264
|
# 1 -- depth
|
|
@@ -236,10 +276,9 @@ class SDXLControlNetUnion(PreTrainedModel):
|
|
|
236
276
|
"lineart": 3,
|
|
237
277
|
"lineart_anime": 3,
|
|
238
278
|
"tile": 6,
|
|
239
|
-
"inpaint": 7
|
|
279
|
+
"inpaint": 7,
|
|
240
280
|
}
|
|
241
281
|
|
|
242
|
-
|
|
243
282
|
def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
|
|
244
283
|
controlnet_cond = self.controlnet_conv_in(conditioning)
|
|
245
284
|
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
|
|
@@ -247,19 +286,25 @@ class SDXLControlNetUnion(PreTrainedModel):
|
|
|
247
286
|
x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
|
|
248
287
|
x = self.controlnet_transformer(x)
|
|
249
288
|
|
|
250
|
-
alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
|
|
289
|
+
alpha = self.spatial_ch_projs(x[:, 0]).unsqueeze(-1).unsqueeze(-1)
|
|
251
290
|
controlnet_cond_fuser = controlnet_cond + alpha
|
|
252
291
|
|
|
253
292
|
hidden_states = hidden_states + controlnet_cond_fuser
|
|
254
293
|
return hidden_states
|
|
255
|
-
|
|
256
294
|
|
|
257
295
|
def forward(
|
|
258
296
|
self,
|
|
259
|
-
sample,
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
297
|
+
sample,
|
|
298
|
+
timestep,
|
|
299
|
+
encoder_hidden_states,
|
|
300
|
+
conditioning,
|
|
301
|
+
processor_name,
|
|
302
|
+
add_time_id,
|
|
303
|
+
add_text_embeds,
|
|
304
|
+
tiled=False,
|
|
305
|
+
tile_size=64,
|
|
306
|
+
tile_stride=32,
|
|
307
|
+
**kwargs,
|
|
263
308
|
):
|
|
264
309
|
task_id = self.task_id[processor_name]
|
|
265
310
|
|
|
@@ -268,13 +268,12 @@ class SDXLUNet(PreTrainedModel):
|
|
|
268
268
|
text_emb,
|
|
269
269
|
res_stack,
|
|
270
270
|
)
|
|
271
|
-
|
|
271
|
+
|
|
272
272
|
# 3.2 Controlnet
|
|
273
273
|
if i == controlnet_insert_block_id and controlnet_res_stack is not None:
|
|
274
274
|
hidden_states += controlnet_res_stack.pop()
|
|
275
275
|
res_stack = [res + controlnet_res for res, controlnet_res in zip(res_stack, controlnet_res_stack)]
|
|
276
276
|
|
|
277
|
-
|
|
278
277
|
# 4. output
|
|
279
278
|
hidden_states = self.conv_norm_out(hidden_states)
|
|
280
279
|
hidden_states = self.conv_act(hidden_states)
|
|
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|
|
6
6
|
|
|
7
7
|
ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]
|
|
8
8
|
|
|
9
|
+
|
|
9
10
|
@dataclass
|
|
10
11
|
class ControlNetParams:
|
|
11
12
|
image: ImageType
|
|
@@ -14,11 +15,12 @@ class ControlNetParams:
|
|
|
14
15
|
mask: Optional[ImageType] = None
|
|
15
16
|
control_start: float = 0
|
|
16
17
|
control_end: float = 1
|
|
17
|
-
processor_name: Optional[str] = None
|
|
18
|
+
processor_name: Optional[str] = None # only used for sdxl controlnet union now
|
|
19
|
+
|
|
18
20
|
|
|
19
21
|
def accumulate(result, new_item):
|
|
20
22
|
if result is None:
|
|
21
23
|
return new_item
|
|
22
24
|
for i, item in enumerate(new_item):
|
|
23
25
|
result[i] += item
|
|
24
|
-
return result
|
|
26
|
+
return result
|
|
@@ -6,7 +6,7 @@ import torch.distributed as dist
|
|
|
6
6
|
import math
|
|
7
7
|
from einops import rearrange
|
|
8
8
|
from enum import Enum
|
|
9
|
-
from typing import Callable, Dict, List, Tuple, Optional
|
|
9
|
+
from typing import Callable, Dict, List, Tuple, Optional, Union
|
|
10
10
|
from tqdm import tqdm
|
|
11
11
|
from PIL import Image
|
|
12
12
|
from dataclasses import dataclass
|
|
@@ -16,6 +16,7 @@ from diffsynth_engine.models.flux import (
|
|
|
16
16
|
FluxVAEDecoder,
|
|
17
17
|
FluxVAEEncoder,
|
|
18
18
|
FluxDiT,
|
|
19
|
+
FluxDiTFBCache,
|
|
19
20
|
flux_dit_config,
|
|
20
21
|
flux_text_encoder_config,
|
|
21
22
|
)
|
|
@@ -429,6 +430,7 @@ class ControlType(Enum):
|
|
|
429
430
|
elif self == ControlType.bfl_fill:
|
|
430
431
|
return 384
|
|
431
432
|
|
|
433
|
+
|
|
432
434
|
@dataclass
|
|
433
435
|
class FluxModelConfig:
|
|
434
436
|
dit_path: str | os.PathLike
|
|
@@ -460,7 +462,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
460
462
|
tokenizer_2: T5TokenizerFast,
|
|
461
463
|
text_encoder_1: FluxTextEncoder1,
|
|
462
464
|
text_encoder_2: FluxTextEncoder2,
|
|
463
|
-
dit: FluxDiT,
|
|
465
|
+
dit: Union[FluxDiT, FluxDiTFBCache],
|
|
464
466
|
vae_decoder: FluxVAEDecoder,
|
|
465
467
|
vae_encoder: FluxVAEEncoder,
|
|
466
468
|
load_text_encoder: bool = True,
|
|
@@ -518,6 +520,8 @@ class FluxImagePipeline(BasePipeline):
|
|
|
518
520
|
offload_mode: str | None = None,
|
|
519
521
|
parallelism: int = 1,
|
|
520
522
|
use_cfg_parallel: bool = False,
|
|
523
|
+
use_fb_cache: bool = False,
|
|
524
|
+
fb_cache_relative_l1_threshold: float = 0.05,
|
|
521
525
|
) -> "FluxImagePipeline":
|
|
522
526
|
model_config = (
|
|
523
527
|
model_path_or_config
|
|
@@ -562,13 +566,23 @@ class FluxImagePipeline(BasePipeline):
|
|
|
562
566
|
vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
|
|
563
567
|
|
|
564
568
|
with LoRAContext():
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
569
|
+
if use_fb_cache:
|
|
570
|
+
dit = FluxDiTFBCache.from_state_dict(
|
|
571
|
+
dit_state_dict,
|
|
572
|
+
device=init_device,
|
|
573
|
+
dtype=model_config.dit_dtype,
|
|
574
|
+
in_channel=control_type.get_in_channel(),
|
|
575
|
+
attn_impl=model_config.dit_attn_impl,
|
|
576
|
+
relative_l1_threshold=fb_cache_relative_l1_threshold,
|
|
577
|
+
)
|
|
578
|
+
else:
|
|
579
|
+
dit = FluxDiT.from_state_dict(
|
|
580
|
+
dit_state_dict,
|
|
581
|
+
device=init_device,
|
|
582
|
+
dtype=model_config.dit_dtype,
|
|
583
|
+
in_channel=control_type.get_in_channel(),
|
|
584
|
+
attn_impl=model_config.dit_attn_impl,
|
|
585
|
+
)
|
|
572
586
|
if model_config.use_fp8_linear:
|
|
573
587
|
enable_fp8_linear(dit)
|
|
574
588
|
|
|
@@ -968,6 +982,8 @@ class FluxImagePipeline(BasePipeline):
|
|
|
968
982
|
controlnet_params: List[ControlNetParams] | ControlNetParams = [],
|
|
969
983
|
progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
|
|
970
984
|
):
|
|
985
|
+
if isinstance(self.dit, FluxDiTFBCache):
|
|
986
|
+
self.dit.refresh_cache_status(num_inference_steps)
|
|
971
987
|
if not isinstance(controlnet_params, list):
|
|
972
988
|
controlnet_params = [controlnet_params]
|
|
973
989
|
if self.control_type != ControlType.normal:
|
|
@@ -291,7 +291,7 @@ class SDImagePipeline(BasePipeline):
|
|
|
291
291
|
current_step: int,
|
|
292
292
|
total_step: int,
|
|
293
293
|
):
|
|
294
|
-
controlnet_res_stack = None
|
|
294
|
+
controlnet_res_stack = None
|
|
295
295
|
for param in controlnet_params:
|
|
296
296
|
current_scale = param.scale
|
|
297
297
|
if not (
|
|
@@ -303,15 +303,10 @@ class SDImagePipeline(BasePipeline):
|
|
|
303
303
|
if self.offload_mode is not None:
|
|
304
304
|
empty_cache()
|
|
305
305
|
param.model.to(self.device)
|
|
306
|
-
controlnet_res = param.model(
|
|
307
|
-
latents,
|
|
308
|
-
timestep,
|
|
309
|
-
prompt_emb,
|
|
310
|
-
param.image
|
|
311
|
-
)
|
|
306
|
+
controlnet_res = param.model(latents, timestep, prompt_emb, param.image)
|
|
312
307
|
controlnet_res = [res * current_scale for res in controlnet_res]
|
|
313
308
|
if self.offload_mode is not None:
|
|
314
|
-
param.model.to("cpu")
|
|
309
|
+
param.model.to("cpu")
|
|
315
310
|
empty_cache()
|
|
316
311
|
controlnet_res_stack = accumulate(controlnet_res_stack, controlnet_res)
|
|
317
312
|
return controlnet_res_stack
|
|
@@ -324,16 +319,22 @@ class SDImagePipeline(BasePipeline):
|
|
|
324
319
|
negative_prompt_emb: torch.Tensor,
|
|
325
320
|
controlnet_params: List[ControlNetParams],
|
|
326
321
|
current_step: int,
|
|
327
|
-
total_step: int,
|
|
322
|
+
total_step: int,
|
|
328
323
|
cfg_scale: float,
|
|
329
324
|
batch_cfg: bool = True,
|
|
330
325
|
):
|
|
331
326
|
if cfg_scale <= 1.0:
|
|
332
|
-
return self.predict_noise(
|
|
327
|
+
return self.predict_noise(
|
|
328
|
+
latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step
|
|
329
|
+
)
|
|
333
330
|
if not batch_cfg:
|
|
334
331
|
# cfg by predict noise one by one
|
|
335
|
-
positive_noise_pred = self.predict_noise(
|
|
336
|
-
|
|
332
|
+
positive_noise_pred = self.predict_noise(
|
|
333
|
+
latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step
|
|
334
|
+
)
|
|
335
|
+
negative_noise_pred = self.predict_noise(
|
|
336
|
+
latents, timestep, negative_prompt_emb, controlnet_params, current_step, total_step
|
|
337
|
+
)
|
|
337
338
|
noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
|
|
338
339
|
return noise_pred
|
|
339
340
|
else:
|
|
@@ -341,12 +342,16 @@ class SDImagePipeline(BasePipeline):
|
|
|
341
342
|
prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0)
|
|
342
343
|
latents = torch.cat([latents, latents], dim=0)
|
|
343
344
|
timestep = torch.cat([timestep, timestep], dim=0)
|
|
344
|
-
positive_noise_pred, negative_noise_pred = self.predict_noise(
|
|
345
|
+
positive_noise_pred, negative_noise_pred = self.predict_noise(
|
|
346
|
+
latents, timestep, prompt_emb, controlnet_params, current_step, total_step
|
|
347
|
+
).chunk(2)
|
|
345
348
|
noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
|
|
346
349
|
return noise_pred
|
|
347
350
|
|
|
348
351
|
def predict_noise(self, latents, timestep, prompt_emb, controlnet_params, current_step, total_step):
|
|
349
|
-
controlnet_res_stack = self.predict_multicontrolnet(
|
|
352
|
+
controlnet_res_stack = self.predict_multicontrolnet(
|
|
353
|
+
latents, timestep, prompt_emb, controlnet_params, current_step, total_step
|
|
354
|
+
)
|
|
350
355
|
|
|
351
356
|
noise_pred = self.unet(
|
|
352
357
|
x=latents,
|
|
@@ -433,7 +438,7 @@ class SDImagePipeline(BasePipeline):
|
|
|
433
438
|
cfg_scale=cfg_scale,
|
|
434
439
|
controlnet_params=controlnet_params,
|
|
435
440
|
current_step=i,
|
|
436
|
-
total_step=len(timesteps),
|
|
441
|
+
total_step=len(timesteps),
|
|
437
442
|
batch_cfg=self.batch_cfg,
|
|
438
443
|
)
|
|
439
444
|
# Denoise
|
|
@@ -31,6 +31,7 @@ from diffsynth_engine.utils import logging
|
|
|
31
31
|
|
|
32
32
|
logger = logging.get_logger(__name__)
|
|
33
33
|
|
|
34
|
+
|
|
34
35
|
class SDXLLoRAConverter(LoRAStateDictConverter):
|
|
35
36
|
def _replace_kohya_te1_key(self, key):
|
|
36
37
|
key = key.replace("lora_te1_text_model_encoder_layers_", "encoders.")
|
|
@@ -91,7 +92,7 @@ class SDXLLoRAConverter(LoRAStateDictConverter):
|
|
|
91
92
|
else:
|
|
92
93
|
raise ValueError(f"Unsupported key: {key}")
|
|
93
94
|
# clip skip
|
|
94
|
-
te1_dict = {k: v for k, v in te1_dict.items() if not k.startswith(
|
|
95
|
+
te1_dict = {k: v for k, v in te1_dict.items() if not k.startswith("encoders.11")}
|
|
95
96
|
return {"unet": unet_dict, "text_encoder": te1_dict, "text_encoder_2": te2_dict}
|
|
96
97
|
|
|
97
98
|
def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
|
@@ -279,10 +280,7 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
279
280
|
condition = self.preprocess_control_image(param.image).to(device=self.device, dtype=self.dtype)
|
|
280
281
|
results.append(
|
|
281
282
|
ControlNetParams(
|
|
282
|
-
model=param.model,
|
|
283
|
-
scale=param.scale,
|
|
284
|
-
image=condition,
|
|
285
|
-
processor_name=param.processor_name
|
|
283
|
+
model=param.model, scale=param.scale, image=condition, processor_name=param.processor_name
|
|
286
284
|
)
|
|
287
285
|
)
|
|
288
286
|
return results
|
|
@@ -307,13 +305,13 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
307
305
|
latents: torch.Tensor,
|
|
308
306
|
timestep: torch.Tensor,
|
|
309
307
|
prompt_emb: torch.Tensor,
|
|
310
|
-
add_text_embeds: torch.Tensor,
|
|
311
|
-
add_time_id: torch.Tensor,
|
|
308
|
+
add_text_embeds: torch.Tensor,
|
|
309
|
+
add_time_id: torch.Tensor,
|
|
312
310
|
controlnet_params: List[ControlNetParams],
|
|
313
311
|
current_step: int,
|
|
314
312
|
total_step: int,
|
|
315
313
|
):
|
|
316
|
-
controlnet_res_stack = None
|
|
314
|
+
controlnet_res_stack = None
|
|
317
315
|
for param in controlnet_params:
|
|
318
316
|
current_scale = param.scale
|
|
319
317
|
if not (
|
|
@@ -338,8 +336,8 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
338
336
|
)
|
|
339
337
|
controlnet_res = [res * current_scale for res in controlnet_res]
|
|
340
338
|
if self.offload_mode is not None:
|
|
341
|
-
param.model.to("cpu")
|
|
342
|
-
empty_cache()
|
|
339
|
+
param.model.to("cpu")
|
|
340
|
+
empty_cache()
|
|
343
341
|
controlnet_res_stack = accumulate(controlnet_res_stack, controlnet_res)
|
|
344
342
|
return controlnet_res_stack
|
|
345
343
|
|
|
@@ -353,20 +351,36 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
353
351
|
negative_add_text_embeds: torch.Tensor,
|
|
354
352
|
controlnet_params: List[ControlNetParams],
|
|
355
353
|
current_step: int,
|
|
356
|
-
total_step: int,
|
|
354
|
+
total_step: int,
|
|
357
355
|
add_time_id: torch.Tensor,
|
|
358
356
|
cfg_scale: float,
|
|
359
357
|
batch_cfg: bool = True,
|
|
360
358
|
):
|
|
361
359
|
if cfg_scale <= 1.0:
|
|
362
|
-
return self.predict_noise(
|
|
360
|
+
return self.predict_noise(
|
|
361
|
+
latents, timestep, positive_prompt_emb, add_time_id, controlnet_params, current_step, total_step
|
|
362
|
+
)
|
|
363
363
|
if not batch_cfg:
|
|
364
364
|
# cfg by predict noise one by one
|
|
365
365
|
positive_noise_pred = self.predict_noise(
|
|
366
|
-
latents,
|
|
366
|
+
latents,
|
|
367
|
+
timestep,
|
|
368
|
+
positive_prompt_emb,
|
|
369
|
+
positive_add_text_embeds,
|
|
370
|
+
add_time_id,
|
|
371
|
+
controlnet_params,
|
|
372
|
+
current_step,
|
|
373
|
+
total_step,
|
|
367
374
|
)
|
|
368
375
|
negative_noise_pred = self.predict_noise(
|
|
369
|
-
latents,
|
|
376
|
+
latents,
|
|
377
|
+
timestep,
|
|
378
|
+
negative_prompt_emb,
|
|
379
|
+
negative_add_text_embeds,
|
|
380
|
+
add_time_id,
|
|
381
|
+
controlnet_params,
|
|
382
|
+
current_step,
|
|
383
|
+
total_step,
|
|
370
384
|
)
|
|
371
385
|
noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
|
|
372
386
|
return noise_pred
|
|
@@ -378,14 +392,25 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
378
392
|
latents = torch.cat([latents, latents], dim=0)
|
|
379
393
|
timestep = torch.cat([timestep, timestep], dim=0)
|
|
380
394
|
positive_noise_pred, negative_noise_pred = self.predict_noise(
|
|
381
|
-
latents,
|
|
395
|
+
latents,
|
|
396
|
+
timestep,
|
|
397
|
+
prompt_emb,
|
|
398
|
+
add_text_embeds,
|
|
399
|
+
add_time_ids,
|
|
400
|
+
controlnet_params,
|
|
401
|
+
current_step,
|
|
402
|
+
total_step,
|
|
382
403
|
)
|
|
383
404
|
noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
|
|
384
405
|
return noise_pred
|
|
385
406
|
|
|
386
|
-
def predict_noise(
|
|
407
|
+
def predict_noise(
|
|
408
|
+
self, latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step
|
|
409
|
+
):
|
|
387
410
|
y = self.prepare_add_embeds(add_text_embeds, add_time_id, self.dtype)
|
|
388
|
-
controlnet_res_stack = self.predict_multicontrolnet(
|
|
411
|
+
controlnet_res_stack = self.predict_multicontrolnet(
|
|
412
|
+
latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step
|
|
413
|
+
)
|
|
389
414
|
|
|
390
415
|
noise_pred = self.unet(
|
|
391
416
|
x=latents,
|
|
@@ -433,7 +458,7 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
433
458
|
width: int = 1024,
|
|
434
459
|
num_inference_steps: int = 20,
|
|
435
460
|
seed: int | None = None,
|
|
436
|
-
controlnet_params: List[ControlNetParams] | ControlNetParams = [],
|
|
461
|
+
controlnet_params: List[ControlNetParams] | ControlNetParams = [],
|
|
437
462
|
progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
|
|
438
463
|
):
|
|
439
464
|
if not isinstance(controlnet_params, list):
|
|
@@ -491,7 +516,7 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
491
516
|
cfg_scale=cfg_scale,
|
|
492
517
|
controlnet_params=controlnet_params,
|
|
493
518
|
current_step=i,
|
|
494
|
-
total_step=len(timesteps),
|
|
519
|
+
total_step=len(timesteps),
|
|
495
520
|
batch_cfg=self.batch_cfg,
|
|
496
521
|
)
|
|
497
522
|
# Denoise
|
|
@@ -71,15 +71,16 @@ diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvB
|
|
|
71
71
|
diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
|
|
72
72
|
diffsynth_engine/models/basic/transformer_helper.py,sha256=LAEXInYSVEF61T3bzvu33jvEyR59B2feKhLmgZvUJH8,3384
|
|
73
73
|
diffsynth_engine/models/basic/unet_helper.py,sha256=4lN6F80Ubm6ip4dkLVmB-Og5-Y25Wduhs9Q8qjyzK6E,9044
|
|
74
|
-
diffsynth_engine/models/flux/__init__.py,sha256=
|
|
74
|
+
diffsynth_engine/models/flux/__init__.py,sha256=x0JoxL0CdiiVrY0BjkIrGinud7mcXecLleGO0km91XQ,686
|
|
75
75
|
diffsynth_engine/models/flux/flux_controlnet.py,sha256=YvQ5gfzttP092gwLzTzf7IcEnG4smOvILghfoCVuHOo,8163
|
|
76
76
|
diffsynth_engine/models/flux/flux_dit.py,sha256=4mpMYVbK8OoQSQeGNJcnkR-rkc0jvw5cT9SOz0gxig0,23335
|
|
77
|
+
diffsynth_engine/models/flux/flux_dit_fbcache.py,sha256=1uG24Cv7IOfrHzwqBv1H_BGAbBsJkylvADaA_gnNjss,8497
|
|
77
78
|
diffsynth_engine/models/flux/flux_ipadapter.py,sha256=SHN2hlibLy0OIScGTygV2yNNJZ7FLzpFdBXzykH6wTo,7129
|
|
78
79
|
diffsynth_engine/models/flux/flux_redux.py,sha256=tVK4hZxHtC_jx-6vpc0os6dtohBUWK_V8ucdZBkeb-o,2560
|
|
79
80
|
diffsynth_engine/models/flux/flux_text_encoder.py,sha256=Qcs277RIPP-O5AkcAb5Fb0jToV5o6Qn8hh8nw8zx9g8,3663
|
|
80
81
|
diffsynth_engine/models/flux/flux_vae.py,sha256=qJzcpfQ9-ATQmE5n8nOy1c5BEA0GZjIddm8zOUodXnE,2840
|
|
81
82
|
diffsynth_engine/models/sd/__init__.py,sha256=hjoKRnwoXOLD0wude-w7I6wK5ak7ACMbnbkPuBB2oU0,380
|
|
82
|
-
diffsynth_engine/models/sd/sd_controlnet.py,sha256=
|
|
83
|
+
diffsynth_engine/models/sd/sd_controlnet.py,sha256=K4iHSCf_UgIlogfLcwvcnnu4RwtgkI-IL-PsWq3e-gM,50761
|
|
83
84
|
diffsynth_engine/models/sd/sd_text_encoder.py,sha256=gTj8GOGnrUjLDvIFgbHCMJAkdh8nn0E6FtumQ-6L2Xs,5406
|
|
84
85
|
diffsynth_engine/models/sd/sd_unet.py,sha256=EbF0Vdt4tKp0VjqozBK71-utFAPp6OHZOdWn0czqYNo,12897
|
|
85
86
|
diffsynth_engine/models/sd/sd_vae.py,sha256=neHAC2HJoiSCQZsyH-ufvIbg1VWmXk3OHz5N_-I5BTQ,1650
|
|
@@ -87,10 +88,10 @@ diffsynth_engine/models/sd3/__init__.py,sha256=Kd5JCDlfP-FfW0Z_BiugQxuRhs2E_3np_
|
|
|
87
88
|
diffsynth_engine/models/sd3/sd3_dit.py,sha256=3G7XVQRw2iD3lFDvzfYXykB3M0QpX42adPnyjtLihos,11447
|
|
88
89
|
diffsynth_engine/models/sd3/sd3_text_encoder.py,sha256=9I57W5AScVoV7L2If4CsOxD_k903S5jlUxEOiUF-xf4,7081
|
|
89
90
|
diffsynth_engine/models/sd3/sd3_vae.py,sha256=nxZpg616aAagSWALVC4r2B-NLYeA2-u9kPYN0lgBm1I,1428
|
|
90
|
-
diffsynth_engine/models/sdxl/__init__.py,sha256=
|
|
91
|
-
diffsynth_engine/models/sdxl/sdxl_controlnet.py,sha256=
|
|
91
|
+
diffsynth_engine/models/sdxl/__init__.py,sha256=3DyhN7p2N2n0MF7RqnTl54JNL0qNQHjIqBMb7W8V2zA,468
|
|
92
|
+
diffsynth_engine/models/sdxl/sdxl_controlnet.py,sha256=rfLDGbe9txKS7DCEtCCn1RjD08SnC5BJsXR3LEunCXM,15380
|
|
92
93
|
diffsynth_engine/models/sdxl/sdxl_text_encoder.py,sha256=VNSx1hfqKAOK-T3m34SaONscAJK8wFkyuVtfcTWmeRA,12369
|
|
93
|
-
diffsynth_engine/models/sdxl/sdxl_unet.py,sha256=
|
|
94
|
+
diffsynth_engine/models/sdxl/sdxl_unet.py,sha256=rnD-ZmGf2g_iRuS8I1zBjFJkcoMAE4TxuTjmISWlQhk,12156
|
|
94
95
|
diffsynth_engine/models/sdxl/sdxl_vae.py,sha256=kNGnn5wKipD31qgMYYLeFyyi8iCc8bsAMD1pEtDhue8,1654
|
|
95
96
|
diffsynth_engine/models/text_encoder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
96
97
|
diffsynth_engine/models/text_encoder/clip.py,sha256=RwJYtSTWjh7yw8dQKbQbiL335naqJhiM96RDCMTnFh4,1805
|
|
@@ -105,10 +106,10 @@ diffsynth_engine/models/wan/wan_text_encoder.py,sha256=bkphxtqNNwXcEA_OaUrwV9CvI
|
|
|
105
106
|
diffsynth_engine/models/wan/wan_vae.py,sha256=RxyuHExQmRjGBAqhZdIbtwZFdCibTzh__U4-Sa00zdI,29004
|
|
106
107
|
diffsynth_engine/pipelines/__init__.py,sha256=TQnwuOGtnlIBdi-afGvS0W4PXPDVkOyjCAAys0DMowA,605
|
|
107
108
|
diffsynth_engine/pipelines/base.py,sha256=mmJGBC0bJQR1qQp9cknz7t-n5PrqmasTKqSpwWaxVVA,13461
|
|
108
|
-
diffsynth_engine/pipelines/controlnet_helper.py,sha256=
|
|
109
|
-
diffsynth_engine/pipelines/flux_image.py,sha256=
|
|
110
|
-
diffsynth_engine/pipelines/sd_image.py,sha256=
|
|
111
|
-
diffsynth_engine/pipelines/sdxl_image.py,sha256=
|
|
109
|
+
diffsynth_engine/pipelines/controlnet_helper.py,sha256=b6HnJFJfMKZq9s5DQ-9Se8OTSDeHVk4AskONSwcRShg,680
|
|
110
|
+
diffsynth_engine/pipelines/flux_image.py,sha256=4UBdlTlzTQnpnh6t0vJroouk423yoxY6GHvBZm0lexY,50590
|
|
111
|
+
diffsynth_engine/pipelines/sd_image.py,sha256=A92oBYZt37x4uOYhl3dAiCzkcgsPmRmH1z4_IWEULeA,18807
|
|
112
|
+
diffsynth_engine/pipelines/sdxl_image.py,sha256=P1bFvFg3Dd49BAEihig-e5QkoNtztJqVxdfAuD693FU,22671
|
|
112
113
|
diffsynth_engine/pipelines/wan_video.py,sha256=LqDBZILcIOK6SR0zPrDJgBjpBgkIah8N6Q7I-8XHchY,21584
|
|
113
114
|
diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
114
115
|
diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
|
|
@@ -140,8 +141,8 @@ diffsynth_engine/utils/parallel.py,sha256=2WISMBTTmW0v2qPvpms421-B59v3bYlS6YrLq9
|
|
|
140
141
|
diffsynth_engine/utils/platform.py,sha256=q9ifmdzoa66Cj9YKfwps21DsDdwA0JGpwroKQbG6shU,224
|
|
141
142
|
diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
|
|
142
143
|
diffsynth_engine/utils/video.py,sha256=Ne0rd2lb59UT1q5EotpjlY7OT8F9oTCFDyo1ST77uoQ,1004
|
|
143
|
-
diffsynth_engine-0.3.6.
|
|
144
|
-
diffsynth_engine-0.3.6.
|
|
145
|
-
diffsynth_engine-0.3.6.
|
|
146
|
-
diffsynth_engine-0.3.6.
|
|
147
|
-
diffsynth_engine-0.3.6.
|
|
144
|
+
diffsynth_engine-0.3.6.dev6.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
|
|
145
|
+
diffsynth_engine-0.3.6.dev6.dist-info/METADATA,sha256=RlAdpLv_Z2WfeqivtN0OVMGSTi4_rMPfOTiBJgyB6BI,1068
|
|
146
|
+
diffsynth_engine-0.3.6.dev6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
147
|
+
diffsynth_engine-0.3.6.dev6.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
|
|
148
|
+
diffsynth_engine-0.3.6.dev6.dist-info/RECORD,,
|
|
File without changes
|
{diffsynth_engine-0.3.6.dev5.dist-info → diffsynth_engine-0.3.6.dev6.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
{diffsynth_engine-0.3.6.dev5.dist-info → diffsynth_engine-0.3.6.dev6.dist-info}/top_level.txt
RENAMED
|
File without changes
|