diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +12 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
- diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
- diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +16 -1
- diffsynth_engine/configs/controlnet.py +13 -0
- diffsynth_engine/configs/pipeline.py +37 -11
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +105 -43
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +16 -30
- diffsynth_engine/models/flux/flux_dit.py +49 -62
- diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
- diffsynth_engine/models/flux/flux_vae.py +20 -2
- diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
- diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
- diffsynth_engine/models/sd/sd_unet.py +1 -1
- diffsynth_engine/models/sd3/sd3_dit.py +1 -1
- diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
- diffsynth_engine/models/vae/vae.py +1 -1
- diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
- diffsynth_engine/models/wan/wan_dit.py +65 -28
- diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
- diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
- diffsynth_engine/models/wan/wan_vae.py +2 -2
- diffsynth_engine/pipelines/base.py +73 -7
- diffsynth_engine/pipelines/flux_image.py +139 -120
- diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
- diffsynth_engine/pipelines/qwen_image.py +272 -87
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +25 -14
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/download.py +4 -2
- diffsynth_engine/utils/env.py +2 -0
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/loader.py +25 -6
- diffsynth_engine/utils/parallel.py +62 -29
- diffsynth_engine/utils/video.py +3 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/__init__.py
CHANGED
|
@@ -11,8 +11,14 @@ from .configs import (
|
|
|
11
11
|
FluxStateDicts,
|
|
12
12
|
WanStateDicts,
|
|
13
13
|
QwenImageStateDicts,
|
|
14
|
+
AttnImpl,
|
|
15
|
+
SpargeAttentionParams,
|
|
16
|
+
VideoSparseAttentionParams,
|
|
17
|
+
LoraConfig,
|
|
14
18
|
ControlNetParams,
|
|
15
19
|
ControlType,
|
|
20
|
+
QwenImageControlNetParams,
|
|
21
|
+
QwenImageControlType,
|
|
16
22
|
)
|
|
17
23
|
from .pipelines import (
|
|
18
24
|
SDImagePipeline,
|
|
@@ -54,8 +60,14 @@ __all__ = [
|
|
|
54
60
|
"FluxStateDicts",
|
|
55
61
|
"WanStateDicts",
|
|
56
62
|
"QwenImageStateDicts",
|
|
63
|
+
"AttnImpl",
|
|
64
|
+
"SpargeAttentionParams",
|
|
65
|
+
"VideoSparseAttentionParams",
|
|
66
|
+
"LoraConfig",
|
|
57
67
|
"ControlNetParams",
|
|
58
68
|
"ControlType",
|
|
69
|
+
"QwenImageControlNetParams",
|
|
70
|
+
"QwenImageControlType",
|
|
59
71
|
"SDImagePipeline",
|
|
60
72
|
"SDControlNet",
|
|
61
73
|
"SDXLImagePipeline",
|
|
@@ -6,5 +6,24 @@ def append_zero(x):
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class BaseScheduler:
|
|
9
|
+
def __init__(self):
|
|
10
|
+
self._stored_config = {}
|
|
11
|
+
|
|
12
|
+
def store_config(self):
|
|
13
|
+
self._stored_config = {
|
|
14
|
+
config_name: config_value
|
|
15
|
+
for config_name, config_value in vars(self).items()
|
|
16
|
+
if not config_name.startswith("_")
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
def update_config(self, config_dict):
|
|
20
|
+
for config_name, new_value in config_dict.items():
|
|
21
|
+
if hasattr(self, config_name):
|
|
22
|
+
setattr(self, config_name, new_value)
|
|
23
|
+
|
|
24
|
+
def restore_config(self):
|
|
25
|
+
for config_name, config_value in self._stored_config.items():
|
|
26
|
+
setattr(self, config_name, config_value)
|
|
27
|
+
|
|
9
28
|
def schedule(self, num_inference_steps: int):
|
|
10
29
|
raise NotImplementedError()
|
|
@@ -12,16 +12,23 @@ class RecifitedFlowScheduler(BaseScheduler):
|
|
|
12
12
|
def __init__(
|
|
13
13
|
self,
|
|
14
14
|
shift=1.0,
|
|
15
|
-
sigma_min=
|
|
16
|
-
sigma_max=
|
|
15
|
+
sigma_min=None,
|
|
16
|
+
sigma_max=None,
|
|
17
17
|
num_train_timesteps=1000,
|
|
18
18
|
use_dynamic_shifting=False,
|
|
19
|
+
shift_terminal=None,
|
|
20
|
+
exponential_shift_mu=None,
|
|
19
21
|
):
|
|
22
|
+
super().__init__()
|
|
20
23
|
self.shift = shift
|
|
21
24
|
self.sigma_min = sigma_min
|
|
22
25
|
self.sigma_max = sigma_max
|
|
23
26
|
self.num_train_timesteps = num_train_timesteps
|
|
24
27
|
self.use_dynamic_shifting = use_dynamic_shifting
|
|
28
|
+
self.shift_terminal = shift_terminal
|
|
29
|
+
# static mu for distill model
|
|
30
|
+
self.exponential_shift_mu = exponential_shift_mu
|
|
31
|
+
self.store_config()
|
|
25
32
|
|
|
26
33
|
def _sigma_to_t(self, sigma):
|
|
27
34
|
return sigma * self.num_train_timesteps
|
|
@@ -35,21 +42,30 @@ class RecifitedFlowScheduler(BaseScheduler):
|
|
|
35
42
|
def _shift_sigma(self, sigma: torch.Tensor, shift: float):
|
|
36
43
|
return shift * sigma / (1 + (shift - 1) * sigma)
|
|
37
44
|
|
|
45
|
+
def _stretch_shift_to_terminal(self, sigma: torch.Tensor):
|
|
46
|
+
one_minus_z = 1 - sigma
|
|
47
|
+
scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
|
|
48
|
+
return 1 - (one_minus_z / scale_factor)
|
|
49
|
+
|
|
38
50
|
def schedule(
|
|
39
51
|
self,
|
|
40
52
|
num_inference_steps: int,
|
|
41
53
|
mu: float | None = None,
|
|
42
|
-
sigma_min: float
|
|
43
|
-
sigma_max: float
|
|
54
|
+
sigma_min: float = 0.001,
|
|
55
|
+
sigma_max: float = 1.0,
|
|
44
56
|
append_value: float = 0,
|
|
45
57
|
):
|
|
46
|
-
sigma_min =
|
|
47
|
-
sigma_max =
|
|
58
|
+
sigma_min = sigma_min if self.sigma_min is None else self.sigma_min
|
|
59
|
+
sigma_max = sigma_max if self.sigma_max is None else self.sigma_max
|
|
48
60
|
sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
|
|
61
|
+
if self.exponential_shift_mu is not None:
|
|
62
|
+
mu = self.exponential_shift_mu
|
|
49
63
|
if self.use_dynamic_shifting:
|
|
50
64
|
sigmas = self._time_shift(mu, 1.0, sigmas) # FLUX
|
|
51
65
|
else:
|
|
52
66
|
sigmas = self._shift_sigma(sigmas, self.shift)
|
|
67
|
+
if self.shift_terminal is not None:
|
|
68
|
+
sigmas = self._stretch_shift_to_terminal(sigmas)
|
|
53
69
|
timesteps = sigmas * self.num_train_timesteps
|
|
54
70
|
sigmas = append(sigmas, append_value)
|
|
55
71
|
return sigmas, timesteps
|
|
@@ -101,5 +101,24 @@
|
|
|
101
101
|
"proj_mlp": "proj_in_besides_attn",
|
|
102
102
|
"proj_out": "proj_out"
|
|
103
103
|
}
|
|
104
|
-
}
|
|
104
|
+
},
|
|
105
|
+
"preferred_kontext_resolutions": [
|
|
106
|
+
[672, 1568],
|
|
107
|
+
[688, 1504],
|
|
108
|
+
[720, 1456],
|
|
109
|
+
[752, 1392],
|
|
110
|
+
[800, 1328],
|
|
111
|
+
[832, 1248],
|
|
112
|
+
[880, 1184],
|
|
113
|
+
[944, 1104],
|
|
114
|
+
[1024, 1024],
|
|
115
|
+
[1104, 944],
|
|
116
|
+
[1184, 880],
|
|
117
|
+
[1248, 832],
|
|
118
|
+
[1328, 800],
|
|
119
|
+
[1392, 752],
|
|
120
|
+
[1456, 720],
|
|
121
|
+
[1504, 688],
|
|
122
|
+
[1568, 672]
|
|
123
|
+
]
|
|
105
124
|
}
|
|
@@ -5,6 +5,8 @@
|
|
|
5
5
|
"decoder.conv_in.weight": "decoder.conv_in.weight",
|
|
6
6
|
"decoder.conv_out.bias": "decoder.conv_out.bias",
|
|
7
7
|
"decoder.conv_out.weight": "decoder.conv_out.weight",
|
|
8
|
+
"decoder.norm_out.bias": "decoder.conv_norm_out.bias",
|
|
9
|
+
"decoder.norm_out.weight": "decoder.conv_norm_out.weight",
|
|
8
10
|
"decoder.mid.attn_1.k.bias": "decoder.blocks.1.transformer_blocks.0.to_k.bias",
|
|
9
11
|
"decoder.mid.attn_1.k.weight": "decoder.blocks.1.transformer_blocks.0.to_k.weight",
|
|
10
12
|
"decoder.mid.attn_1.norm.bias": "decoder.blocks.1.norm.bias",
|
|
@@ -31,8 +33,6 @@
|
|
|
31
33
|
"decoder.mid.block_2.norm1.weight": "decoder.blocks.2.norm1.weight",
|
|
32
34
|
"decoder.mid.block_2.norm2.bias": "decoder.blocks.2.norm2.bias",
|
|
33
35
|
"decoder.mid.block_2.norm2.weight": "decoder.blocks.2.norm2.weight",
|
|
34
|
-
"decoder.norm_out.bias": "decoder.conv_norm_out.bias",
|
|
35
|
-
"decoder.norm_out.weight": "decoder.conv_norm_out.weight",
|
|
36
36
|
"decoder.up.0.block.0.conv1.bias": "decoder.blocks.15.conv1.bias",
|
|
37
37
|
"decoder.up.0.block.0.conv1.weight": "decoder.blocks.15.conv1.weight",
|
|
38
38
|
"decoder.up.0.block.0.conv2.bias": "decoder.blocks.15.conv2.bias",
|
|
@@ -143,6 +143,8 @@
|
|
|
143
143
|
"encoder.conv_in.weight": "encoder.conv_in.weight",
|
|
144
144
|
"encoder.conv_out.bias": "encoder.conv_out.bias",
|
|
145
145
|
"encoder.conv_out.weight": "encoder.conv_out.weight",
|
|
146
|
+
"encoder.norm_out.bias": "encoder.conv_norm_out.bias",
|
|
147
|
+
"encoder.norm_out.weight": "encoder.conv_norm_out.weight",
|
|
146
148
|
"encoder.down.0.block.0.conv1.bias": "encoder.blocks.0.conv1.bias",
|
|
147
149
|
"encoder.down.0.block.0.conv1.weight": "encoder.blocks.0.conv1.weight",
|
|
148
150
|
"encoder.down.0.block.0.conv2.bias": "encoder.blocks.0.conv2.bias",
|
|
@@ -242,9 +244,255 @@
|
|
|
242
244
|
"encoder.mid.block_2.norm1.bias": "encoder.blocks.13.norm1.bias",
|
|
243
245
|
"encoder.mid.block_2.norm1.weight": "encoder.blocks.13.norm1.weight",
|
|
244
246
|
"encoder.mid.block_2.norm2.bias": "encoder.blocks.13.norm2.bias",
|
|
245
|
-
"encoder.mid.block_2.norm2.weight": "encoder.blocks.13.norm2.weight"
|
|
246
|
-
|
|
247
|
-
|
|
247
|
+
"encoder.mid.block_2.norm2.weight": "encoder.blocks.13.norm2.weight"
|
|
248
|
+
}
|
|
249
|
+
},
|
|
250
|
+
"diffusers": {
|
|
251
|
+
"rename_dict": {
|
|
252
|
+
"decoder.conv_in.bias": "decoder.conv_in.bias",
|
|
253
|
+
"decoder.conv_in.weight": "decoder.conv_in.weight",
|
|
254
|
+
"decoder.conv_out.bias": "decoder.conv_out.bias",
|
|
255
|
+
"decoder.conv_out.weight": "decoder.conv_out.weight",
|
|
256
|
+
"decoder.conv_norm_out.bias": "decoder.conv_norm_out.bias",
|
|
257
|
+
"decoder.conv_norm_out.weight": "decoder.conv_norm_out.weight",
|
|
258
|
+
"decoder.mid_block.attentions.0.to_k.bias": "decoder.blocks.1.transformer_blocks.0.to_k.bias",
|
|
259
|
+
"decoder.mid_block.attentions.0.to_k.weight": "decoder.blocks.1.transformer_blocks.0.to_k.weight",
|
|
260
|
+
"decoder.mid_block.attentions.0.group_norm.bias": "decoder.blocks.1.norm.bias",
|
|
261
|
+
"decoder.mid_block.attentions.0.group_norm.weight": "decoder.blocks.1.norm.weight",
|
|
262
|
+
"decoder.mid_block.attentions.0.to_out.0.bias": "decoder.blocks.1.transformer_blocks.0.to_out.bias",
|
|
263
|
+
"decoder.mid_block.attentions.0.to_out.0.weight": "decoder.blocks.1.transformer_blocks.0.to_out.weight",
|
|
264
|
+
"decoder.mid_block.attentions.0.to_q.bias": "decoder.blocks.1.transformer_blocks.0.to_q.bias",
|
|
265
|
+
"decoder.mid_block.attentions.0.to_q.weight": "decoder.blocks.1.transformer_blocks.0.to_q.weight",
|
|
266
|
+
"decoder.mid_block.attentions.0.to_v.bias": "decoder.blocks.1.transformer_blocks.0.to_v.bias",
|
|
267
|
+
"decoder.mid_block.attentions.0.to_v.weight": "decoder.blocks.1.transformer_blocks.0.to_v.weight",
|
|
268
|
+
"decoder.mid_block.resnets.0.conv1.bias": "decoder.blocks.0.conv1.bias",
|
|
269
|
+
"decoder.mid_block.resnets.0.conv1.weight": "decoder.blocks.0.conv1.weight",
|
|
270
|
+
"decoder.mid_block.resnets.0.conv2.bias": "decoder.blocks.0.conv2.bias",
|
|
271
|
+
"decoder.mid_block.resnets.0.conv2.weight": "decoder.blocks.0.conv2.weight",
|
|
272
|
+
"decoder.mid_block.resnets.0.norm1.bias": "decoder.blocks.0.norm1.bias",
|
|
273
|
+
"decoder.mid_block.resnets.0.norm1.weight": "decoder.blocks.0.norm1.weight",
|
|
274
|
+
"decoder.mid_block.resnets.0.norm2.bias": "decoder.blocks.0.norm2.bias",
|
|
275
|
+
"decoder.mid_block.resnets.0.norm2.weight": "decoder.blocks.0.norm2.weight",
|
|
276
|
+
"decoder.mid_block.resnets.1.conv1.bias": "decoder.blocks.2.conv1.bias",
|
|
277
|
+
"decoder.mid_block.resnets.1.conv1.weight": "decoder.blocks.2.conv1.weight",
|
|
278
|
+
"decoder.mid_block.resnets.1.conv2.bias": "decoder.blocks.2.conv2.bias",
|
|
279
|
+
"decoder.mid_block.resnets.1.conv2.weight": "decoder.blocks.2.conv2.weight",
|
|
280
|
+
"decoder.mid_block.resnets.1.norm1.bias": "decoder.blocks.2.norm1.bias",
|
|
281
|
+
"decoder.mid_block.resnets.1.norm1.weight": "decoder.blocks.2.norm1.weight",
|
|
282
|
+
"decoder.mid_block.resnets.1.norm2.bias": "decoder.blocks.2.norm2.bias",
|
|
283
|
+
"decoder.mid_block.resnets.1.norm2.weight": "decoder.blocks.2.norm2.weight",
|
|
284
|
+
"decoder.up_blocks.0.resnets.0.conv1.bias": "decoder.blocks.3.conv1.bias",
|
|
285
|
+
"decoder.up_blocks.0.resnets.0.conv1.weight": "decoder.blocks.3.conv1.weight",
|
|
286
|
+
"decoder.up_blocks.0.resnets.0.conv2.bias": "decoder.blocks.3.conv2.bias",
|
|
287
|
+
"decoder.up_blocks.0.resnets.0.conv2.weight": "decoder.blocks.3.conv2.weight",
|
|
288
|
+
"decoder.up_blocks.0.resnets.0.norm1.bias": "decoder.blocks.3.norm1.bias",
|
|
289
|
+
"decoder.up_blocks.0.resnets.0.norm1.weight": "decoder.blocks.3.norm1.weight",
|
|
290
|
+
"decoder.up_blocks.0.resnets.0.norm2.bias": "decoder.blocks.3.norm2.bias",
|
|
291
|
+
"decoder.up_blocks.0.resnets.0.norm2.weight": "decoder.blocks.3.norm2.weight",
|
|
292
|
+
"decoder.up_blocks.0.resnets.1.conv1.bias": "decoder.blocks.4.conv1.bias",
|
|
293
|
+
"decoder.up_blocks.0.resnets.1.conv1.weight": "decoder.blocks.4.conv1.weight",
|
|
294
|
+
"decoder.up_blocks.0.resnets.1.conv2.bias": "decoder.blocks.4.conv2.bias",
|
|
295
|
+
"decoder.up_blocks.0.resnets.1.conv2.weight": "decoder.blocks.4.conv2.weight",
|
|
296
|
+
"decoder.up_blocks.0.resnets.1.norm1.bias": "decoder.blocks.4.norm1.bias",
|
|
297
|
+
"decoder.up_blocks.0.resnets.1.norm1.weight": "decoder.blocks.4.norm1.weight",
|
|
298
|
+
"decoder.up_blocks.0.resnets.1.norm2.bias": "decoder.blocks.4.norm2.bias",
|
|
299
|
+
"decoder.up_blocks.0.resnets.1.norm2.weight": "decoder.blocks.4.norm2.weight",
|
|
300
|
+
"decoder.up_blocks.0.resnets.2.conv1.bias": "decoder.blocks.5.conv1.bias",
|
|
301
|
+
"decoder.up_blocks.0.resnets.2.conv1.weight": "decoder.blocks.5.conv1.weight",
|
|
302
|
+
"decoder.up_blocks.0.resnets.2.conv2.bias": "decoder.blocks.5.conv2.bias",
|
|
303
|
+
"decoder.up_blocks.0.resnets.2.conv2.weight": "decoder.blocks.5.conv2.weight",
|
|
304
|
+
"decoder.up_blocks.0.resnets.2.norm1.bias": "decoder.blocks.5.norm1.bias",
|
|
305
|
+
"decoder.up_blocks.0.resnets.2.norm1.weight": "decoder.blocks.5.norm1.weight",
|
|
306
|
+
"decoder.up_blocks.0.resnets.2.norm2.bias": "decoder.blocks.5.norm2.bias",
|
|
307
|
+
"decoder.up_blocks.0.resnets.2.norm2.weight": "decoder.blocks.5.norm2.weight",
|
|
308
|
+
"decoder.up_blocks.0.upsamplers.0.conv.bias": "decoder.blocks.6.conv.bias",
|
|
309
|
+
"decoder.up_blocks.0.upsamplers.0.conv.weight": "decoder.blocks.6.conv.weight",
|
|
310
|
+
"decoder.up_blocks.1.resnets.0.conv1.bias": "decoder.blocks.7.conv1.bias",
|
|
311
|
+
"decoder.up_blocks.1.resnets.0.conv1.weight": "decoder.blocks.7.conv1.weight",
|
|
312
|
+
"decoder.up_blocks.1.resnets.0.conv2.bias": "decoder.blocks.7.conv2.bias",
|
|
313
|
+
"decoder.up_blocks.1.resnets.0.conv2.weight": "decoder.blocks.7.conv2.weight",
|
|
314
|
+
"decoder.up_blocks.1.resnets.0.norm1.bias": "decoder.blocks.7.norm1.bias",
|
|
315
|
+
"decoder.up_blocks.1.resnets.0.norm1.weight": "decoder.blocks.7.norm1.weight",
|
|
316
|
+
"decoder.up_blocks.1.resnets.0.norm2.bias": "decoder.blocks.7.norm2.bias",
|
|
317
|
+
"decoder.up_blocks.1.resnets.0.norm2.weight": "decoder.blocks.7.norm2.weight",
|
|
318
|
+
"decoder.up_blocks.1.resnets.1.conv1.bias": "decoder.blocks.8.conv1.bias",
|
|
319
|
+
"decoder.up_blocks.1.resnets.1.conv1.weight": "decoder.blocks.8.conv1.weight",
|
|
320
|
+
"decoder.up_blocks.1.resnets.1.conv2.bias": "decoder.blocks.8.conv2.bias",
|
|
321
|
+
"decoder.up_blocks.1.resnets.1.conv2.weight": "decoder.blocks.8.conv2.weight",
|
|
322
|
+
"decoder.up_blocks.1.resnets.1.norm1.bias": "decoder.blocks.8.norm1.bias",
|
|
323
|
+
"decoder.up_blocks.1.resnets.1.norm1.weight": "decoder.blocks.8.norm1.weight",
|
|
324
|
+
"decoder.up_blocks.1.resnets.1.norm2.bias": "decoder.blocks.8.norm2.bias",
|
|
325
|
+
"decoder.up_blocks.1.resnets.1.norm2.weight": "decoder.blocks.8.norm2.weight",
|
|
326
|
+
"decoder.up_blocks.1.resnets.2.conv1.bias": "decoder.blocks.9.conv1.bias",
|
|
327
|
+
"decoder.up_blocks.1.resnets.2.conv1.weight": "decoder.blocks.9.conv1.weight",
|
|
328
|
+
"decoder.up_blocks.1.resnets.2.conv2.bias": "decoder.blocks.9.conv2.bias",
|
|
329
|
+
"decoder.up_blocks.1.resnets.2.conv2.weight": "decoder.blocks.9.conv2.weight",
|
|
330
|
+
"decoder.up_blocks.1.resnets.2.norm1.bias": "decoder.blocks.9.norm1.bias",
|
|
331
|
+
"decoder.up_blocks.1.resnets.2.norm1.weight": "decoder.blocks.9.norm1.weight",
|
|
332
|
+
"decoder.up_blocks.1.resnets.2.norm2.bias": "decoder.blocks.9.norm2.bias",
|
|
333
|
+
"decoder.up_blocks.1.resnets.2.norm2.weight": "decoder.blocks.9.norm2.weight",
|
|
334
|
+
"decoder.up_blocks.1.upsamplers.0.conv.bias": "decoder.blocks.10.conv.bias",
|
|
335
|
+
"decoder.up_blocks.1.upsamplers.0.conv.weight": "decoder.blocks.10.conv.weight",
|
|
336
|
+
"decoder.up_blocks.2.resnets.0.conv1.bias": "decoder.blocks.11.conv1.bias",
|
|
337
|
+
"decoder.up_blocks.2.resnets.0.conv1.weight": "decoder.blocks.11.conv1.weight",
|
|
338
|
+
"decoder.up_blocks.2.resnets.0.conv2.bias": "decoder.blocks.11.conv2.bias",
|
|
339
|
+
"decoder.up_blocks.2.resnets.0.conv2.weight": "decoder.blocks.11.conv2.weight",
|
|
340
|
+
"decoder.up_blocks.2.resnets.0.conv_shortcut.bias": "decoder.blocks.11.conv_shortcut.bias",
|
|
341
|
+
"decoder.up_blocks.2.resnets.0.conv_shortcut.weight": "decoder.blocks.11.conv_shortcut.weight",
|
|
342
|
+
"decoder.up_blocks.2.resnets.0.norm1.bias": "decoder.blocks.11.norm1.bias",
|
|
343
|
+
"decoder.up_blocks.2.resnets.0.norm1.weight": "decoder.blocks.11.norm1.weight",
|
|
344
|
+
"decoder.up_blocks.2.resnets.0.norm2.bias": "decoder.blocks.11.norm2.bias",
|
|
345
|
+
"decoder.up_blocks.2.resnets.0.norm2.weight": "decoder.blocks.11.norm2.weight",
|
|
346
|
+
"decoder.up_blocks.2.resnets.1.conv1.bias": "decoder.blocks.12.conv1.bias",
|
|
347
|
+
"decoder.up_blocks.2.resnets.1.conv1.weight": "decoder.blocks.12.conv1.weight",
|
|
348
|
+
"decoder.up_blocks.2.resnets.1.conv2.bias": "decoder.blocks.12.conv2.bias",
|
|
349
|
+
"decoder.up_blocks.2.resnets.1.conv2.weight": "decoder.blocks.12.conv2.weight",
|
|
350
|
+
"decoder.up_blocks.2.resnets.1.norm1.bias": "decoder.blocks.12.norm1.bias",
|
|
351
|
+
"decoder.up_blocks.2.resnets.1.norm1.weight": "decoder.blocks.12.norm1.weight",
|
|
352
|
+
"decoder.up_blocks.2.resnets.1.norm2.bias": "decoder.blocks.12.norm2.bias",
|
|
353
|
+
"decoder.up_blocks.2.resnets.1.norm2.weight": "decoder.blocks.12.norm2.weight",
|
|
354
|
+
"decoder.up_blocks.2.resnets.2.conv1.bias": "decoder.blocks.13.conv1.bias",
|
|
355
|
+
"decoder.up_blocks.2.resnets.2.conv1.weight": "decoder.blocks.13.conv1.weight",
|
|
356
|
+
"decoder.up_blocks.2.resnets.2.conv2.bias": "decoder.blocks.13.conv2.bias",
|
|
357
|
+
"decoder.up_blocks.2.resnets.2.conv2.weight": "decoder.blocks.13.conv2.weight",
|
|
358
|
+
"decoder.up_blocks.2.resnets.2.norm1.bias": "decoder.blocks.13.norm1.bias",
|
|
359
|
+
"decoder.up_blocks.2.resnets.2.norm1.weight": "decoder.blocks.13.norm1.weight",
|
|
360
|
+
"decoder.up_blocks.2.resnets.2.norm2.bias": "decoder.blocks.13.norm2.bias",
|
|
361
|
+
"decoder.up_blocks.2.resnets.2.norm2.weight": "decoder.blocks.13.norm2.weight",
|
|
362
|
+
"decoder.up_blocks.2.upsamplers.0.conv.bias": "decoder.blocks.14.conv.bias",
|
|
363
|
+
"decoder.up_blocks.2.upsamplers.0.conv.weight": "decoder.blocks.14.conv.weight",
|
|
364
|
+
"decoder.up_blocks.3.resnets.0.conv1.bias": "decoder.blocks.15.conv1.bias",
|
|
365
|
+
"decoder.up_blocks.3.resnets.0.conv1.weight": "decoder.blocks.15.conv1.weight",
|
|
366
|
+
"decoder.up_blocks.3.resnets.0.conv2.bias": "decoder.blocks.15.conv2.bias",
|
|
367
|
+
"decoder.up_blocks.3.resnets.0.conv2.weight": "decoder.blocks.15.conv2.weight",
|
|
368
|
+
"decoder.up_blocks.3.resnets.0.conv_shortcut.bias": "decoder.blocks.15.conv_shortcut.bias",
|
|
369
|
+
"decoder.up_blocks.3.resnets.0.conv_shortcut.weight": "decoder.blocks.15.conv_shortcut.weight",
|
|
370
|
+
"decoder.up_blocks.3.resnets.0.norm1.bias": "decoder.blocks.15.norm1.bias",
|
|
371
|
+
"decoder.up_blocks.3.resnets.0.norm1.weight": "decoder.blocks.15.norm1.weight",
|
|
372
|
+
"decoder.up_blocks.3.resnets.0.norm2.bias": "decoder.blocks.15.norm2.bias",
|
|
373
|
+
"decoder.up_blocks.3.resnets.0.norm2.weight": "decoder.blocks.15.norm2.weight",
|
|
374
|
+
"decoder.up_blocks.3.resnets.1.conv1.bias": "decoder.blocks.16.conv1.bias",
|
|
375
|
+
"decoder.up_blocks.3.resnets.1.conv1.weight": "decoder.blocks.16.conv1.weight",
|
|
376
|
+
"decoder.up_blocks.3.resnets.1.conv2.bias": "decoder.blocks.16.conv2.bias",
|
|
377
|
+
"decoder.up_blocks.3.resnets.1.conv2.weight": "decoder.blocks.16.conv2.weight",
|
|
378
|
+
"decoder.up_blocks.3.resnets.1.norm1.bias": "decoder.blocks.16.norm1.bias",
|
|
379
|
+
"decoder.up_blocks.3.resnets.1.norm1.weight": "decoder.blocks.16.norm1.weight",
|
|
380
|
+
"decoder.up_blocks.3.resnets.1.norm2.bias": "decoder.blocks.16.norm2.bias",
|
|
381
|
+
"decoder.up_blocks.3.resnets.1.norm2.weight": "decoder.blocks.16.norm2.weight",
|
|
382
|
+
"decoder.up_blocks.3.resnets.2.conv1.bias": "decoder.blocks.17.conv1.bias",
|
|
383
|
+
"decoder.up_blocks.3.resnets.2.conv1.weight": "decoder.blocks.17.conv1.weight",
|
|
384
|
+
"decoder.up_blocks.3.resnets.2.conv2.bias": "decoder.blocks.17.conv2.bias",
|
|
385
|
+
"decoder.up_blocks.3.resnets.2.conv2.weight": "decoder.blocks.17.conv2.weight",
|
|
386
|
+
"decoder.up_blocks.3.resnets.2.norm1.bias": "decoder.blocks.17.norm1.bias",
|
|
387
|
+
"decoder.up_blocks.3.resnets.2.norm1.weight": "decoder.blocks.17.norm1.weight",
|
|
388
|
+
"decoder.up_blocks.3.resnets.2.norm2.bias": "decoder.blocks.17.norm2.bias",
|
|
389
|
+
"decoder.up_blocks.3.resnets.2.norm2.weight": "decoder.blocks.17.norm2.weight",
|
|
390
|
+
"encoder.conv_in.bias": "encoder.conv_in.bias",
|
|
391
|
+
"encoder.conv_in.weight": "encoder.conv_in.weight",
|
|
392
|
+
"encoder.conv_out.bias": "encoder.conv_out.bias",
|
|
393
|
+
"encoder.conv_out.weight": "encoder.conv_out.weight",
|
|
394
|
+
"encoder.conv_norm_out.bias": "encoder.conv_norm_out.bias",
|
|
395
|
+
"encoder.conv_norm_out.weight": "encoder.conv_norm_out.weight",
|
|
396
|
+
"encoder.down_blocks.0.resnets.0.conv1.bias": "encoder.blocks.0.conv1.bias",
|
|
397
|
+
"encoder.down_blocks.0.resnets.0.conv1.weight": "encoder.blocks.0.conv1.weight",
|
|
398
|
+
"encoder.down_blocks.0.resnets.0.conv2.bias": "encoder.blocks.0.conv2.bias",
|
|
399
|
+
"encoder.down_blocks.0.resnets.0.conv2.weight": "encoder.blocks.0.conv2.weight",
|
|
400
|
+
"encoder.down_blocks.0.resnets.0.norm1.bias": "encoder.blocks.0.norm1.bias",
|
|
401
|
+
"encoder.down_blocks.0.resnets.0.norm1.weight": "encoder.blocks.0.norm1.weight",
|
|
402
|
+
"encoder.down_blocks.0.resnets.0.norm2.bias": "encoder.blocks.0.norm2.bias",
|
|
403
|
+
"encoder.down_blocks.0.resnets.0.norm2.weight": "encoder.blocks.0.norm2.weight",
|
|
404
|
+
"encoder.down_blocks.0.resnets.1.conv1.bias": "encoder.blocks.1.conv1.bias",
|
|
405
|
+
"encoder.down_blocks.0.resnets.1.conv1.weight": "encoder.blocks.1.conv1.weight",
|
|
406
|
+
"encoder.down_blocks.0.resnets.1.conv2.bias": "encoder.blocks.1.conv2.bias",
|
|
407
|
+
"encoder.down_blocks.0.resnets.1.conv2.weight": "encoder.blocks.1.conv2.weight",
|
|
408
|
+
"encoder.down_blocks.0.resnets.1.norm1.bias": "encoder.blocks.1.norm1.bias",
|
|
409
|
+
"encoder.down_blocks.0.resnets.1.norm1.weight": "encoder.blocks.1.norm1.weight",
|
|
410
|
+
"encoder.down_blocks.0.resnets.1.norm2.bias": "encoder.blocks.1.norm2.bias",
|
|
411
|
+
"encoder.down_blocks.0.resnets.1.norm2.weight": "encoder.blocks.1.norm2.weight",
|
|
412
|
+
"encoder.down_blocks.0.downsamplers.0.conv.bias": "encoder.blocks.2.conv.bias",
|
|
413
|
+
"encoder.down_blocks.0.downsamplers.0.conv.weight": "encoder.blocks.2.conv.weight",
|
|
414
|
+
"encoder.down_blocks.1.resnets.0.conv1.bias": "encoder.blocks.3.conv1.bias",
|
|
415
|
+
"encoder.down_blocks.1.resnets.0.conv1.weight": "encoder.blocks.3.conv1.weight",
|
|
416
|
+
"encoder.down_blocks.1.resnets.0.conv2.bias": "encoder.blocks.3.conv2.bias",
|
|
417
|
+
"encoder.down_blocks.1.resnets.0.conv2.weight": "encoder.blocks.3.conv2.weight",
|
|
418
|
+
"encoder.down_blocks.1.resnets.0.conv_shortcut.bias": "encoder.blocks.3.conv_shortcut.bias",
|
|
419
|
+
"encoder.down_blocks.1.resnets.0.conv_shortcut.weight": "encoder.blocks.3.conv_shortcut.weight",
|
|
420
|
+
"encoder.down_blocks.1.resnets.0.norm1.bias": "encoder.blocks.3.norm1.bias",
|
|
421
|
+
"encoder.down_blocks.1.resnets.0.norm1.weight": "encoder.blocks.3.norm1.weight",
|
|
422
|
+
"encoder.down_blocks.1.resnets.0.norm2.bias": "encoder.blocks.3.norm2.bias",
|
|
423
|
+
"encoder.down_blocks.1.resnets.0.norm2.weight": "encoder.blocks.3.norm2.weight",
|
|
424
|
+
"encoder.down_blocks.1.resnets.1.conv1.bias": "encoder.blocks.4.conv1.bias",
|
|
425
|
+
"encoder.down_blocks.1.resnets.1.conv1.weight": "encoder.blocks.4.conv1.weight",
|
|
426
|
+
"encoder.down_blocks.1.resnets.1.conv2.bias": "encoder.blocks.4.conv2.bias",
|
|
427
|
+
"encoder.down_blocks.1.resnets.1.conv2.weight": "encoder.blocks.4.conv2.weight",
|
|
428
|
+
"encoder.down_blocks.1.resnets.1.norm1.bias": "encoder.blocks.4.norm1.bias",
|
|
429
|
+
"encoder.down_blocks.1.resnets.1.norm1.weight": "encoder.blocks.4.norm1.weight",
|
|
430
|
+
"encoder.down_blocks.1.resnets.1.norm2.bias": "encoder.blocks.4.norm2.bias",
|
|
431
|
+
"encoder.down_blocks.1.resnets.1.norm2.weight": "encoder.blocks.4.norm2.weight",
|
|
432
|
+
"encoder.down_blocks.1.downsamplers.0.conv.bias": "encoder.blocks.5.conv.bias",
|
|
433
|
+
"encoder.down_blocks.1.downsamplers.0.conv.weight": "encoder.blocks.5.conv.weight",
|
|
434
|
+
"encoder.down_blocks.2.resnets.0.conv1.bias": "encoder.blocks.6.conv1.bias",
|
|
435
|
+
"encoder.down_blocks.2.resnets.0.conv1.weight": "encoder.blocks.6.conv1.weight",
|
|
436
|
+
"encoder.down_blocks.2.resnets.0.conv2.bias": "encoder.blocks.6.conv2.bias",
|
|
437
|
+
"encoder.down_blocks.2.resnets.0.conv2.weight": "encoder.blocks.6.conv2.weight",
|
|
438
|
+
"encoder.down_blocks.2.resnets.0.conv_shortcut.bias": "encoder.blocks.6.conv_shortcut.bias",
|
|
439
|
+
"encoder.down_blocks.2.resnets.0.conv_shortcut.weight": "encoder.blocks.6.conv_shortcut.weight",
|
|
440
|
+
"encoder.down_blocks.2.resnets.0.norm1.bias": "encoder.blocks.6.norm1.bias",
|
|
441
|
+
"encoder.down_blocks.2.resnets.0.norm1.weight": "encoder.blocks.6.norm1.weight",
|
|
442
|
+
"encoder.down_blocks.2.resnets.0.norm2.bias": "encoder.blocks.6.norm2.bias",
|
|
443
|
+
"encoder.down_blocks.2.resnets.0.norm2.weight": "encoder.blocks.6.norm2.weight",
|
|
444
|
+
"encoder.down_blocks.2.resnets.1.conv1.bias": "encoder.blocks.7.conv1.bias",
|
|
445
|
+
"encoder.down_blocks.2.resnets.1.conv1.weight": "encoder.blocks.7.conv1.weight",
|
|
446
|
+
"encoder.down_blocks.2.resnets.1.conv2.bias": "encoder.blocks.7.conv2.bias",
|
|
447
|
+
"encoder.down_blocks.2.resnets.1.conv2.weight": "encoder.blocks.7.conv2.weight",
|
|
448
|
+
"encoder.down_blocks.2.resnets.1.norm1.bias": "encoder.blocks.7.norm1.bias",
|
|
449
|
+
"encoder.down_blocks.2.resnets.1.norm1.weight": "encoder.blocks.7.norm1.weight",
|
|
450
|
+
"encoder.down_blocks.2.resnets.1.norm2.bias": "encoder.blocks.7.norm2.bias",
|
|
451
|
+
"encoder.down_blocks.2.resnets.1.norm2.weight": "encoder.blocks.7.norm2.weight",
|
|
452
|
+
"encoder.down_blocks.2.downsamplers.0.conv.bias": "encoder.blocks.8.conv.bias",
|
|
453
|
+
"encoder.down_blocks.2.downsamplers.0.conv.weight": "encoder.blocks.8.conv.weight",
|
|
454
|
+
"encoder.down_blocks.3.resnets.0.conv1.bias": "encoder.blocks.9.conv1.bias",
|
|
455
|
+
"encoder.down_blocks.3.resnets.0.conv1.weight": "encoder.blocks.9.conv1.weight",
|
|
456
|
+
"encoder.down_blocks.3.resnets.0.conv2.bias": "encoder.blocks.9.conv2.bias",
|
|
457
|
+
"encoder.down_blocks.3.resnets.0.conv2.weight": "encoder.blocks.9.conv2.weight",
|
|
458
|
+
"encoder.down_blocks.3.resnets.0.norm1.bias": "encoder.blocks.9.norm1.bias",
|
|
459
|
+
"encoder.down_blocks.3.resnets.0.norm1.weight": "encoder.blocks.9.norm1.weight",
|
|
460
|
+
"encoder.down_blocks.3.resnets.0.norm2.bias": "encoder.blocks.9.norm2.bias",
|
|
461
|
+
"encoder.down_blocks.3.resnets.0.norm2.weight": "encoder.blocks.9.norm2.weight",
|
|
462
|
+
"encoder.down_blocks.3.resnets.1.conv1.bias": "encoder.blocks.10.conv1.bias",
|
|
463
|
+
"encoder.down_blocks.3.resnets.1.conv1.weight": "encoder.blocks.10.conv1.weight",
|
|
464
|
+
"encoder.down_blocks.3.resnets.1.conv2.bias": "encoder.blocks.10.conv2.bias",
|
|
465
|
+
"encoder.down_blocks.3.resnets.1.conv2.weight": "encoder.blocks.10.conv2.weight",
|
|
466
|
+
"encoder.down_blocks.3.resnets.1.norm1.bias": "encoder.blocks.10.norm1.bias",
|
|
467
|
+
"encoder.down_blocks.3.resnets.1.norm1.weight": "encoder.blocks.10.norm1.weight",
|
|
468
|
+
"encoder.down_blocks.3.resnets.1.norm2.bias": "encoder.blocks.10.norm2.bias",
|
|
469
|
+
"encoder.down_blocks.3.resnets.1.norm2.weight": "encoder.blocks.10.norm2.weight",
|
|
470
|
+
"encoder.mid_block.attentions.0.to_k.bias": "encoder.blocks.12.transformer_blocks.0.to_k.bias",
|
|
471
|
+
"encoder.mid_block.attentions.0.to_k.weight": "encoder.blocks.12.transformer_blocks.0.to_k.weight",
|
|
472
|
+
"encoder.mid_block.attentions.0.group_norm.bias": "encoder.blocks.12.norm.bias",
|
|
473
|
+
"encoder.mid_block.attentions.0.group_norm.weight": "encoder.blocks.12.norm.weight",
|
|
474
|
+
"encoder.mid_block.attentions.0.to_out.0.bias": "encoder.blocks.12.transformer_blocks.0.to_out.bias",
|
|
475
|
+
"encoder.mid_block.attentions.0.to_out.0.weight": "encoder.blocks.12.transformer_blocks.0.to_out.weight",
|
|
476
|
+
"encoder.mid_block.attentions.0.to_q.bias": "encoder.blocks.12.transformer_blocks.0.to_q.bias",
|
|
477
|
+
"encoder.mid_block.attentions.0.to_q.weight": "encoder.blocks.12.transformer_blocks.0.to_q.weight",
|
|
478
|
+
"encoder.mid_block.attentions.0.to_v.bias": "encoder.blocks.12.transformer_blocks.0.to_v.bias",
|
|
479
|
+
"encoder.mid_block.attentions.0.to_v.weight": "encoder.blocks.12.transformer_blocks.0.to_v.weight",
|
|
480
|
+
"encoder.mid_block.resnets.0.conv1.bias": "encoder.blocks.11.conv1.bias",
|
|
481
|
+
"encoder.mid_block.resnets.0.conv1.weight": "encoder.blocks.11.conv1.weight",
|
|
482
|
+
"encoder.mid_block.resnets.0.conv2.bias": "encoder.blocks.11.conv2.bias",
|
|
483
|
+
"encoder.mid_block.resnets.0.conv2.weight": "encoder.blocks.11.conv2.weight",
|
|
484
|
+
"encoder.mid_block.resnets.0.norm1.bias": "encoder.blocks.11.norm1.bias",
|
|
485
|
+
"encoder.mid_block.resnets.0.norm1.weight": "encoder.blocks.11.norm1.weight",
|
|
486
|
+
"encoder.mid_block.resnets.0.norm2.bias": "encoder.blocks.11.norm2.bias",
|
|
487
|
+
"encoder.mid_block.resnets.0.norm2.weight": "encoder.blocks.11.norm2.weight",
|
|
488
|
+
"encoder.mid_block.resnets.1.conv1.bias": "encoder.blocks.13.conv1.bias",
|
|
489
|
+
"encoder.mid_block.resnets.1.conv1.weight": "encoder.blocks.13.conv1.weight",
|
|
490
|
+
"encoder.mid_block.resnets.1.conv2.bias": "encoder.blocks.13.conv2.bias",
|
|
491
|
+
"encoder.mid_block.resnets.1.conv2.weight": "encoder.blocks.13.conv2.weight",
|
|
492
|
+
"encoder.mid_block.resnets.1.norm1.bias": "encoder.blocks.13.norm1.bias",
|
|
493
|
+
"encoder.mid_block.resnets.1.norm1.weight": "encoder.blocks.13.norm1.weight",
|
|
494
|
+
"encoder.mid_block.resnets.1.norm2.bias": "encoder.blocks.13.norm2.bias",
|
|
495
|
+
"encoder.mid_block.resnets.1.norm2.weight": "encoder.blocks.13.norm2.weight"
|
|
248
496
|
}
|
|
249
497
|
}
|
|
250
498
|
}
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
{
|
|
2
|
+
"diffusers": {
|
|
3
|
+
"global_rename_dict": {
|
|
4
|
+
"patch_embedding": "patch_embedding",
|
|
5
|
+
"condition_embedder.text_embedder.linear_1": "text_embedding.0",
|
|
6
|
+
"condition_embedder.text_embedder.linear_2": "text_embedding.2",
|
|
7
|
+
"condition_embedder.time_embedder.linear_1": "time_embedding.0",
|
|
8
|
+
"condition_embedder.time_embedder.linear_2": "time_embedding.2",
|
|
9
|
+
"condition_embedder.time_proj": "time_projection.1",
|
|
10
|
+
"condition_embedder.image_embedder.norm1": "img_emb.proj.0",
|
|
11
|
+
"condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1",
|
|
12
|
+
"condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3",
|
|
13
|
+
"condition_embedder.image_embedder.norm2": "img_emb.proj.4",
|
|
14
|
+
"condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
|
|
15
|
+
"proj_out": "head.head",
|
|
16
|
+
"scale_shift_table": "head.modulation"
|
|
17
|
+
},
|
|
18
|
+
"rename_dict": {
|
|
19
|
+
"attn1.to_q": "self_attn.q",
|
|
20
|
+
"attn1.to_k": "self_attn.k",
|
|
21
|
+
"attn1.to_v": "self_attn.v",
|
|
22
|
+
"attn1.to_out.0": "self_attn.o",
|
|
23
|
+
"attn1.norm_q": "self_attn.norm_q",
|
|
24
|
+
"attn1.norm_k": "self_attn.norm_k",
|
|
25
|
+
"to_gate_compress": "self_attn.gate_compress",
|
|
26
|
+
"attn2.to_q": "cross_attn.q",
|
|
27
|
+
"attn2.to_k": "cross_attn.k",
|
|
28
|
+
"attn2.to_v": "cross_attn.v",
|
|
29
|
+
"attn2.to_out.0": "cross_attn.o",
|
|
30
|
+
"attn2.norm_q": "cross_attn.norm_q",
|
|
31
|
+
"attn2.norm_k": "cross_attn.norm_k",
|
|
32
|
+
"attn2.add_k_proj": "cross_attn.k_img",
|
|
33
|
+
"attn2.add_v_proj": "cross_attn.v_img",
|
|
34
|
+
"attn2.norm_added_k": "cross_attn.norm_k_img",
|
|
35
|
+
"norm2": "norm3",
|
|
36
|
+
"ffn.net.0.proj": "ffn.0",
|
|
37
|
+
"ffn.net.2": "ffn.2",
|
|
38
|
+
"scale_shift_table": "modulation"
|
|
39
|
+
}
|
|
40
|
+
}
|
|
41
|
+
}
|
|
@@ -17,8 +17,17 @@ from .pipeline import (
|
|
|
17
17
|
WanStateDicts,
|
|
18
18
|
WanS2VStateDicts,
|
|
19
19
|
QwenImageStateDicts,
|
|
20
|
+
AttnImpl,
|
|
21
|
+
SpargeAttentionParams,
|
|
22
|
+
VideoSparseAttentionParams,
|
|
23
|
+
LoraConfig,
|
|
24
|
+
)
|
|
25
|
+
from .controlnet import (
|
|
26
|
+
ControlType,
|
|
27
|
+
ControlNetParams,
|
|
28
|
+
QwenImageControlType,
|
|
29
|
+
QwenImageControlNetParams,
|
|
20
30
|
)
|
|
21
|
-
from .controlnet import ControlType, ControlNetParams
|
|
22
31
|
|
|
23
32
|
__all__ = [
|
|
24
33
|
"BaseConfig",
|
|
@@ -39,6 +48,12 @@ __all__ = [
|
|
|
39
48
|
"WanStateDicts",
|
|
40
49
|
"WanS2VStateDicts",
|
|
41
50
|
"QwenImageStateDicts",
|
|
51
|
+
"AttnImpl",
|
|
52
|
+
"SpargeAttentionParams",
|
|
53
|
+
"VideoSparseAttentionParams",
|
|
54
|
+
"LoraConfig",
|
|
42
55
|
"ControlType",
|
|
43
56
|
"ControlNetParams",
|
|
57
|
+
"QwenImageControlType",
|
|
58
|
+
"QwenImageControlNetParams",
|
|
44
59
|
]
|
|
@@ -34,3 +34,16 @@ class ControlNetParams:
|
|
|
34
34
|
control_start: float = 0
|
|
35
35
|
control_end: float = 1
|
|
36
36
|
processor_name: Optional[str] = None # only used for sdxl controlnet union now
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class QwenImageControlType(Enum):
|
|
40
|
+
eligen = "eligen"
|
|
41
|
+
in_context = "in_context"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class QwenImageControlNetParams:
|
|
46
|
+
image: ImageType
|
|
47
|
+
model: str
|
|
48
|
+
control_type: QwenImageControlType
|
|
49
|
+
scale: float = 1.0
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import torch
|
|
3
|
+
from enum import Enum
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
from typing import List, Dict, Tuple, Optional
|
|
5
6
|
|
|
@@ -19,14 +20,36 @@ class BaseConfig:
|
|
|
19
20
|
offload_to_disk: bool = False
|
|
20
21
|
|
|
21
22
|
|
|
23
|
+
class AttnImpl(Enum):
|
|
24
|
+
AUTO = "auto"
|
|
25
|
+
EAGER = "eager" # Native Attention
|
|
26
|
+
FA2 = "fa2" # Flash Attention 2
|
|
27
|
+
FA3 = "fa3" # Flash Attention 3
|
|
28
|
+
FA3_FP8 = "fa3_fp8" # Flash Attention 3 with FP8
|
|
29
|
+
XFORMERS = "xformers" # XFormers
|
|
30
|
+
SDPA = "sdpa" # Scaled Dot Product Attention
|
|
31
|
+
SAGE = "sage" # Sage Attention
|
|
32
|
+
SPARGE = "sparge" # Sparge Attention
|
|
33
|
+
VSA = "vsa" # Video Sparse Attention
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
|
|
37
|
+
class SpargeAttentionParams:
|
|
38
|
+
smooth_k: bool = True
|
|
39
|
+
cdfthreshd: float = 0.6
|
|
40
|
+
simthreshd1: float = 0.98
|
|
41
|
+
pvthreshd: float = 50.0
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class VideoSparseAttentionParams:
|
|
46
|
+
sparsity: float = 0.9
|
|
47
|
+
|
|
48
|
+
|
|
22
49
|
@dataclass
|
|
23
50
|
class AttentionConfig:
|
|
24
|
-
dit_attn_impl:
|
|
25
|
-
|
|
26
|
-
sparge_smooth_k: bool = True
|
|
27
|
-
sparge_cdfthreshd: float = 0.6
|
|
28
|
-
sparge_simthreshd1: float = 0.98
|
|
29
|
-
sparge_pvthreshd: float = 50.0
|
|
51
|
+
dit_attn_impl: AttnImpl = AttnImpl.AUTO
|
|
52
|
+
attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None
|
|
30
53
|
|
|
31
54
|
|
|
32
55
|
@dataclass
|
|
@@ -221,14 +244,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
|
|
|
221
244
|
encoder_dtype: torch.dtype = torch.bfloat16
|
|
222
245
|
vae_dtype: torch.dtype = torch.float32
|
|
223
246
|
|
|
247
|
+
load_encoder: bool = True
|
|
248
|
+
|
|
224
249
|
# override OptimizationConfig
|
|
225
250
|
fbcache_relative_l1_threshold = 0.009
|
|
226
251
|
|
|
227
|
-
# override BaseConfig
|
|
228
|
-
vae_tiled: bool = True
|
|
229
|
-
vae_tile_size: Tuple[int, int] = (34, 34)
|
|
230
|
-
vae_tile_stride: Tuple[int, int] = (18, 16)
|
|
231
|
-
|
|
232
252
|
@classmethod
|
|
233
253
|
def basic_config(
|
|
234
254
|
cls,
|
|
@@ -352,3 +372,9 @@ def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig |
|
|
|
352
372
|
config.tp_degree = 1
|
|
353
373
|
else:
|
|
354
374
|
raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
@dataclass
|
|
378
|
+
class LoraConfig:
|
|
379
|
+
scale: float
|
|
380
|
+
scheduler_config: Optional[Dict] = None
|
diffsynth_engine/models/base.py
CHANGED
|
@@ -57,7 +57,7 @@ class PreTrainedModel(nn.Module):
|
|
|
57
57
|
def get_tp_plan(self):
|
|
58
58
|
raise NotImplementedError(f"{self.__class__.__name__} does not support TP")
|
|
59
59
|
|
|
60
|
-
def
|
|
60
|
+
def get_fsdp_module_cls(self):
|
|
61
61
|
raise NotImplementedError(f"{self.__class__.__name__} does not support FSDP")
|
|
62
62
|
|
|
63
63
|
|