diffsynth-engine 0.0.0 (py3-none-any.whl)
This diff shows the contents of a publicly available package version as released to a supported registry. It is provided for informational purposes only and reflects the package as it appears in the public registry.
- diffsynth_engine/__init__.py +28 -0
- diffsynth_engine/algorithm/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
- diffsynth_engine/algorithm/sampler/__init__.py +19 -0
- diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
- diffsynth_engine/conf/models/components/vae.json +254 -0
- diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
- diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
- diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
- diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
- diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
- diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
- diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
- diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
- diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
- diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
- diffsynth_engine/kernels/__init__.py +0 -0
- diffsynth_engine/models/__init__.py +7 -0
- diffsynth_engine/models/base.py +64 -0
- diffsynth_engine/models/basic/__init__.py +0 -0
- diffsynth_engine/models/basic/attention.py +217 -0
- diffsynth_engine/models/basic/lora.py +293 -0
- diffsynth_engine/models/basic/relative_position_emb.py +56 -0
- diffsynth_engine/models/basic/timestep.py +81 -0
- diffsynth_engine/models/basic/transformer_helper.py +88 -0
- diffsynth_engine/models/basic/unet_helper.py +244 -0
- diffsynth_engine/models/components/__init__.py +0 -0
- diffsynth_engine/models/components/clip.py +56 -0
- diffsynth_engine/models/components/t5.py +222 -0
- diffsynth_engine/models/components/vae.py +392 -0
- diffsynth_engine/models/flux/__init__.py +14 -0
- diffsynth_engine/models/flux/flux_dit.py +476 -0
- diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
- diffsynth_engine/models/flux/flux_vae.py +78 -0
- diffsynth_engine/models/sd/__init__.py +12 -0
- diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
- diffsynth_engine/models/sd/sd_unet.py +293 -0
- diffsynth_engine/models/sd/sd_vae.py +38 -0
- diffsynth_engine/models/sd3/__init__.py +14 -0
- diffsynth_engine/models/sd3/sd3_dit.py +302 -0
- diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
- diffsynth_engine/models/sd3/sd3_vae.py +43 -0
- diffsynth_engine/models/sdxl/__init__.py +13 -0
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
- diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
- diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
- diffsynth_engine/models/utils.py +54 -0
- diffsynth_engine/models/wan/__init__.py +0 -0
- diffsynth_engine/models/wan/wan_dit.py +497 -0
- diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
- diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
- diffsynth_engine/models/wan/wan_vae.py +771 -0
- diffsynth_engine/pipelines/__init__.py +18 -0
- diffsynth_engine/pipelines/base.py +253 -0
- diffsynth_engine/pipelines/flux_image.py +512 -0
- diffsynth_engine/pipelines/sd_image.py +352 -0
- diffsynth_engine/pipelines/sdxl_image.py +395 -0
- diffsynth_engine/pipelines/wan_video.py +524 -0
- diffsynth_engine/tokenizers/__init__.py +6 -0
- diffsynth_engine/tokenizers/base.py +157 -0
- diffsynth_engine/tokenizers/clip.py +288 -0
- diffsynth_engine/tokenizers/t5.py +194 -0
- diffsynth_engine/tokenizers/wan.py +74 -0
- diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine/utils/constants.py +34 -0
- diffsynth_engine/utils/download.py +135 -0
- diffsynth_engine/utils/env.py +7 -0
- diffsynth_engine/utils/flag.py +46 -0
- diffsynth_engine/utils/fp8_linear.py +64 -0
- diffsynth_engine/utils/gguf.py +415 -0
- diffsynth_engine/utils/loader.py +17 -0
- diffsynth_engine/utils/lock.py +56 -0
- diffsynth_engine/utils/logging.py +12 -0
- diffsynth_engine/utils/offload.py +44 -0
- diffsynth_engine/utils/parallel.py +390 -0
- diffsynth_engine/utils/prompt.py +9 -0
- diffsynth_engine/utils/video.py +40 -0
- diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
- diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
- diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
- diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
- diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
diffsynth_engine/models/base.py (new file, +64 lines)

```python
import os
import torch
import torch.nn as nn
from typing import Any, Dict, List, Union
from safetensors.torch import load_file

from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
from diffsynth_engine.models.utils import no_init_weights


class StateDictConverter:
    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return state_dict


class PreTrainedModel(nn.Module):
    converter = StateDictConverter()

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True, assign: bool = False):
        # remap checkpoint keys before delegating to nn.Module
        state_dict = self.converter.convert(state_dict)
        super().load_state_dict(state_dict, strict=strict, assign=assign)

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], device: str, dtype: torch.dtype, **kwargs):
        state_dict = load_file(pretrained_model_path, device=device)
        return cls.from_state_dict(state_dict, device=device, dtype=dtype, **kwargs)

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, **kwargs):
        # construct the module without running (wasteful) weight initialization
        with no_init_weights():
            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(state_dict)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model

    def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
        for args in lora_args:
            key = args["name"]
            module = self.get_submodule(key)
            if not isinstance(module, (LoRALinear, LoRAConv2d)):
                raise ValueError(f"Unsupported lora key: {key}")
            if fused:
                module.add_frozen_lora(**args)
            else:
                module.add_lora(**args)

    def unload_loras(self):
        for module in self.modules():
            if isinstance(module, (LoRALinear, LoRAConv2d)):
                module.clear()


def split_suffix(name: str):
    suffix_list = [
        ".lora_up.weight",
        ".lora_down.weight",
        ".weight",
        ".bias",
        ".alpha",
    ]
    for suffix in suffix_list:
        if name.endswith(suffix):
            # strip only the trailing suffix; str.replace would also remove an
            # identical substring earlier in the key
            return name[: -len(suffix)], suffix
    return name, ""
```
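A minimal usage sketch for `PreTrainedModel` (not part of the package; `ToyModel` and the checkpoint path are hypothetical): any subclass gains safetensors loading via `from_pretrained`, key remapping via its `converter`, and the LoRA attach/detach hooks.

```python
import torch
import torch.nn as nn

class ToyModel(PreTrainedModel):
    # Hypothetical model for illustration; the real subclasses live under
    # diffsynth_engine/models/{sd,sdxl,sd3,flux,wan}/.
    def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32):
        super().__init__()
        self.proj = nn.Linear(8, 8, device=device, dtype=dtype)

    def forward(self, x):
        return self.proj(x)

# "toy.safetensors" is a placeholder path; from_pretrained loads the weights,
# passes them through converter.convert, and materializes the module without
# the usual random initialization.
model = ToyModel.from_pretrained("toy.safetensors", device="cpu", dtype=torch.float32)

# split_suffix peels a known parameter suffix off a state-dict key:
assert split_suffix("blocks.0.to_q.lora_up.weight") == ("blocks.0.to_q", ".lora_up.weight")
```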
diffsynth_engine/models/basic/attention.py (new file, +217 lines)

```python
import torch
import torch.nn as nn
from einops import rearrange
from typing import Optional
from yunchang import LongContextAttention
from yunchang.kernels import AttnType

from diffsynth_engine.utils import logging
from diffsynth_engine.utils.flag import (
    FLASH_ATTN_3_AVAILABLE,
    FLASH_ATTN_2_AVAILABLE,
    XFORMERS_AVAILABLE,
    SDPA_AVAILABLE,
    SAGE_ATTN_AVAILABLE,
    SPARGE_ATTN_AVAILABLE,
)

logger = logging.get_logger(__name__)


if FLASH_ATTN_3_AVAILABLE:
    from flash_attn_interface import flash_attn_func as flash_attn3
if FLASH_ATTN_2_AVAILABLE:
    from flash_attn import flash_attn_func as flash_attn2
if XFORMERS_AVAILABLE:
    from xformers.ops import memory_efficient_attention as xformers_attn
if SDPA_AVAILABLE:

    def sdpa_attn(q, k, v, attn_mask=None, scale=None):
        # scaled_dot_product_attention expects [B, N, L, C]; inputs are [B, L, N, C]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
        return out.transpose(1, 2)


if SAGE_ATTN_AVAILABLE:
    from sageattention import sageattn

    def sage_attn(q, k, v, attn_mask=None, scale=None):
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        out = sageattn(q, k, v, attn_mask=attn_mask, sm_scale=scale)
        return out.transpose(1, 2)


if SPARGE_ATTN_AVAILABLE:
    from spas_sage_attn import spas_sage2_attn_meansim_cuda

    def sparge_attn(q, k, v, attn_mask=None, scale=None):
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        out = spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=attn_mask, scale=scale)
        return out.transpose(1, 2)


def eager_attn(q, k, v, attn_mask=None, scale=None):
    # pure-PyTorch reference implementation
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    scale = 1 / q.shape[-1] ** 0.5 if scale is None else scale
    q = q * scale
    attn = torch.matmul(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = attn.softmax(-1)
    out = attn @ v
    return out.transpose(1, 2)


def attention(
    q,
    k,
    v,
    attn_impl: Optional[str] = None,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """
    q: [B, Lq, Nq, C1]
    k: [B, Lk, Nk, C1]
    v: [B, Lk, Nk, C2]
    """
    assert attn_impl in [
        None,
        "auto",
        "eager",
        "flash_attn_2",
        "flash_attn_3",
        "xformers",
        "sdpa",
        "sage_attn",
        "sparge_attn",
    ]
    if attn_impl is None or attn_impl == "auto":
        # pick the fastest available backend
        if FLASH_ATTN_3_AVAILABLE:
            return flash_attn3(q, k, v, softmax_scale=scale)
        elif FLASH_ATTN_2_AVAILABLE:
            return flash_attn2(q, k, v, softmax_scale=scale)
        elif XFORMERS_AVAILABLE:
            return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
        elif SDPA_AVAILABLE:
            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        else:
            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
    else:
        if attn_impl == "eager":
            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "flash_attn_3":
            return flash_attn3(q, k, v, softmax_scale=scale)
        elif attn_impl == "flash_attn_2":
            return flash_attn2(q, k, v, softmax_scale=scale)
        elif attn_impl == "xformers":
            return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
        elif attn_impl == "sdpa":
            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "sage_attn":
            return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        elif attn_impl == "sparge_attn":
            return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
        else:
            raise ValueError(f"Invalid attention implementation: {attn_impl}")


class Attention(nn.Module):
    def __init__(
        self,
        q_dim,
        num_heads,
        head_dim,
        kv_dim=None,
        bias_q=False,
        bias_kv=False,
        bias_out=False,
        scale=None,
        attn_impl: Optional[str] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = nn.Linear(q_dim, dim_inner, bias=bias_q, device=device, dtype=dtype)
        self.to_k = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
        self.to_v = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
        self.to_out = nn.Linear(dim_inner, q_dim, bias=bias_out, device=device, dtype=dtype)
        self.attn_impl = attn_impl
        self.scale = scale

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        # self-attention when y is None, cross-attention otherwise
        if y is None:
            y = x
        q = rearrange(self.to_q(x), "b s (n d) -> b s n d", n=self.num_heads)
        k = rearrange(self.to_k(y), "b s (n d) -> b s n d", n=self.num_heads)
        v = rearrange(self.to_v(y), "b s (n d) -> b s n d", n=self.num_heads)
        out = attention(q, k, v, attn_mask=attn_mask, attn_impl=self.attn_impl, scale=self.scale)
        out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
        return self.to_out(out)


def long_context_attention(
    q,
    k,
    v,
    attn_impl: Optional[str] = None,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """
    q: [B, Lq, Nq, C1]
    k: [B, Lk, Nk, C1]
    v: [B, Lk, Nk, C2]
    """
    assert attn_impl in [
        None,
        "auto",
        "eager",
        "flash_attn_2",
        "flash_attn_3",
        "xformers",
        "sdpa",
        "sage_attn",
        "sparge_attn",
    ]
    if attn_impl is None or attn_impl == "auto":
        if FLASH_ATTN_3_AVAILABLE:
            attn_func = LongContextAttention(attn_type=AttnType.FA3)
        elif FLASH_ATTN_2_AVAILABLE:
            attn_func = LongContextAttention(attn_type=AttnType.FA)
        elif SDPA_AVAILABLE:
            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
        else:
            raise ValueError("No available long context attention implementation")
    else:
        if attn_impl == "flash_attn_3":
            attn_func = LongContextAttention(attn_type=AttnType.FA3)
        elif attn_impl == "flash_attn_2":
            attn_func = LongContextAttention(attn_type=AttnType.FA)
        elif attn_impl == "sdpa":
            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
        elif attn_impl == "sage_attn":
            attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
        elif attn_impl == "sparge_attn":
            attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE)
        else:
            raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
    return attn_func(q, k, v, softmax_scale=scale)
```
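As a quick sanity check of the dispatcher and the `Attention` module under the documented `[B, L, N, C]` layout, here is a small sketch (shapes chosen arbitrarily; forcing `attn_impl="eager"` avoids all optional backends):

```python
import torch

B, L, N, C = 2, 16, 8, 64  # batch, sequence length, heads, head dim
q = torch.randn(B, L, N, C)
k = torch.randn(B, L, N, C)
v = torch.randn(B, L, N, C)

# "eager" is the pure-PyTorch fallback; "auto" (or None) picks the fastest
# available backend in the order FA3 > FA2 > xformers > SDPA > eager.
out = attention(q, k, v, attn_impl="eager")
assert out.shape == (B, L, N, C)

# The nn.Module wrapper handles the head split/merge internally.
attn = Attention(q_dim=N * C, num_heads=N, head_dim=C, attn_impl="eager",
                 device="cpu", dtype=torch.float32)
x = torch.randn(B, L, N * C)
assert attn(x).shape == (B, L, N * C)
```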
diffsynth_engine/models/basic/lora.py (new file, +293 lines)

```python
import torch
import torch.nn as nn
from torch.nn.common_types import _size_2_t
from typing import Union
from collections import OrderedDict
from contextlib import contextmanager


class LoRA(nn.Module):
    def __init__(
        self,
        scale: float,
        rank: int,
        alpha: int,
        up: Union[nn.Linear, nn.Conv2d, torch.Tensor],
        down: Union[nn.Linear, nn.Conv2d, torch.Tensor],
        device: str,
        dtype: torch.dtype,
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.scale = scale
        self.rank = rank
        self.alpha = alpha.item() if isinstance(alpha, torch.Tensor) else alpha
        self.up = up.to(device=device, dtype=dtype)
        self.down = down.to(device=device, dtype=dtype)

    def forward(self, x):
        if isinstance(self.up, torch.Tensor) and isinstance(self.down, torch.Tensor):
            return self.scale * (self.alpha / self.rank) * (x @ self.down.T @ self.up.T)
        return self.scale * (self.alpha / self.rank) * (self.up(self.down(x)))

    def apply_to(self, w: Union[nn.Linear, nn.Conv2d, nn.Parameter, torch.Tensor]):
        # merge the low-rank update into an existing weight in place
        if isinstance(self.up, torch.Tensor) and isinstance(self.down, torch.Tensor):
            delta_w = self.scale * (self.alpha / self.rank) * (self.up @ self.down)
        else:
            delta_w = self.scale * (self.alpha / self.rank) * (self.up.weight @ self.down.weight)
        if isinstance(w, (nn.Linear, nn.Conv2d)):
            delta_w = delta_w.to(device=w.weight.data.device, dtype=w.weight.data.dtype)
            w.weight.data.add_(delta_w)
        elif isinstance(w, nn.Parameter):
            delta_w = delta_w.to(device=w.data.device, dtype=w.data.dtype)
            w.data.add_(delta_w)
        elif isinstance(w, torch.Tensor):
            delta_w = delta_w.to(device=w.device, dtype=w.dtype)
            w.add_(delta_w)


class LoRALinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        # LoRA
        self._lora_dict = OrderedDict()
        # Frozen LoRA
        self._frozen_lora_list = []
        self.register_buffer("_original_weight", None)

    @staticmethod
    def from_linear(linear: nn.Linear):
        lora_linear = torch.nn.utils.skip_init(
            LoRALinear,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )
        lora_linear.weight = linear.weight
        lora_linear.bias = linear.bias
        return lora_linear

    def add_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
        # nn.Linear(in_features, out_features) stores weight as (out_features, in_features)
        up_linear = torch.nn.utils.skip_init(
            nn.Linear, up.shape[1], up.shape[0], bias=False, device=device, dtype=dtype
        )
        down_linear = torch.nn.utils.skip_init(
            nn.Linear, down.shape[1], down.shape[0], bias=False, device=device, dtype=dtype
        )
        up_linear.weight.data = up
        down_linear.weight.data = down
        lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
        self._lora_dict[name] = lora

    def modify_scale(self, name: str, scale: float):
        if name not in self._lora_dict:
            raise ValueError(f"LoRA name {name} not found in LoRALinear {self.__class__.__name__}")
        self._lora_dict[name].scale = scale

    def add_frozen_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        save_original_weight: bool = True,
    ):
        if save_original_weight and self._original_weight is None:
            self._original_weight = self.weight.clone()
        lora = LoRA(scale, rank, alpha, up, down, device, dtype)
        lora.apply_to(self)
        self._frozen_lora_list.append(lora)

    def clear(self):
        if self._original_weight is None and len(self._frozen_lora_list) > 0:
            raise RuntimeError(
                "This LoRALinear has been patched by a frozen LoRA, but the original weight was not saved, so the LoRA cannot be cleared."
            )
        self._lora_dict.clear()
        self._frozen_lora_list = []
        if self._original_weight is not None:
            self.weight.data = self._original_weight
            self._original_weight = None

    def forward(self, x):
        w_x = super().forward(x)
        for lora in self._lora_dict.values():
            w_x += lora(x)
        return w_x


class LoRAConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype
        )
        # LoRA
        self._lora_dict = OrderedDict()
        # Frozen LoRA
        self._frozen_lora_list = []
        self._original_weight = None

    @staticmethod
    def from_conv2d(conv2d: nn.Conv2d):
        lora_conv2d = torch.nn.utils.skip_init(
            LoRAConv2d,
            conv2d.in_channels,
            conv2d.out_channels,
            conv2d.kernel_size,
            conv2d.stride,
            conv2d.padding,
            conv2d.dilation,
            conv2d.groups,
            conv2d.bias is not None,
            conv2d.padding_mode,
            device=conv2d.weight.device,
            dtype=conv2d.weight.dtype,
        )
        lora_conv2d.weight = conv2d.weight
        lora_conv2d.bias = conv2d.bias
        return lora_conv2d

    def _construct_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
    ):
        down_conv = torch.nn.utils.skip_init(
            nn.Conv2d,
            self.in_channels,
            rank,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=False,
            device=device,
            dtype=dtype,
        )
        down_conv.weight.data = down
        # according to the official kohya_ss trainer, kernel_size is always fixed for the up layer
        # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
        # adapted from diffusers
        up_conv = torch.nn.utils.skip_init(
            nn.Conv2d,
            rank,
            self.out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            bias=False,
            device=device,
            dtype=dtype,
        )
        up_conv.weight.data = up

        lora = LoRA(scale, rank, alpha, up_conv, down_conv, device, dtype)
        return lora

    def add_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
        self._lora_dict[name] = self._construct_lora(name, scale, rank, alpha, up, down, device, dtype)

    def modify_scale(self, name: str, scale: float):
        if name not in self._lora_dict:
            raise ValueError(f"LoRA name {name} not found in LoRAConv2d {self.__class__.__name__}")
        self._lora_dict[name].scale = scale

    def add_frozen_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        save_original_weight: bool = True,
    ):
        if save_original_weight and self._original_weight is None:
            self._original_weight = self.weight.clone()
        lora = self._construct_lora(name, scale, rank, alpha, up, down, device, dtype)
        lora.apply_to(self)
        self._frozen_lora_list.append(lora)

    def clear(self):
        if self._original_weight is None and len(self._frozen_lora_list) > 0:
            raise RuntimeError(
                "This LoRAConv2d has been patched by a frozen LoRA, but the original weight was not saved, so the LoRA cannot be cleared."
            )
        self._lora_dict.clear()
        self._frozen_lora_list = []
        if self._original_weight is not None:
            self.weight.copy_(self._original_weight)
            self._original_weight = None

    def forward(self, x):
        w_x = super().forward(x)
        for lora in self._lora_dict.values():
            w_x += lora(x)
        return w_x


@contextmanager
def LoRAContext():
    # temporarily patch torch.nn so modules built inside the context get LoRA support;
    # try/finally guarantees the originals are restored even on error
    origin_linear = torch.nn.Linear
    origin_conv2d = torch.nn.Conv2d

    torch.nn.Linear = LoRALinear
    torch.nn.Conv2d = LoRAConv2d
    try:
        yield
    finally:
        torch.nn.Linear = origin_linear
        torch.nn.Conv2d = origin_conv2d
```
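A short sketch of the two LoRA modes on `LoRALinear` (the names, rank, and random weights are illustrative): `add_lora` keeps the low-rank branch separate and re-scalable, while `add_frozen_lora` folds `scale * (alpha / rank) * (up @ down)` into the weight once and remembers the original weight so the patch can be undone.

```python
import torch
import torch.nn as nn

layer = LoRALinear.from_linear(nn.Linear(16, 16))
rank = 4
up = torch.randn(16, rank)    # (out_features, rank)
down = torch.randn(rank, 16)  # (rank, in_features)

# Unfused: evaluated as an extra branch on every forward; scale can be changed later.
layer.add_lora("demo", scale=1.0, rank=rank, alpha=rank, up=up, down=down,
               device="cpu", dtype=torch.float32)
layer.modify_scale("demo", 0.5)

# Fused: the delta is merged into layer.weight once; clear() restores the original.
layer.add_frozen_lora("demo_frozen", scale=1.0, rank=rank, alpha=rank, up=up, down=down,
                      device="cpu", dtype=torch.float32)

y = layer(torch.randn(2, 16))
layer.clear()  # drops unfused LoRAs and restores the saved original weight
```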
diffsynth_engine/models/basic/relative_position_emb.py (new file, +56 lines)

```python
import torch
import torch.nn as nn
import math


def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # now relative_position is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # the other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


class RelativePositionEmbedding(nn.Module):
    def __init__(
        self, num_buckets, max_distance, num_heads, device: str = "cuda:0", dtype: torch.dtype = torch.float16
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads, device=device, dtype=dtype)

    def forward(self, query_length, key_length):
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = _relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
```
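For reference, a sketch of how this T5-style bias is produced and consumed (the 32-bucket/128-distance hyperparameters mirror T5's defaults and are illustrative, not a configuration the package pins):

```python
import torch

emb = RelativePositionEmbedding(num_buckets=32, max_distance=128, num_heads=8,
                                device="cpu", dtype=torch.float32)
bias = emb(query_length=16, key_length=16)
assert bias.shape == (1, 8, 16, 16)  # (1, num_heads, query_length, key_length)

# The bias is broadcast-added to the pre-softmax attention logits,
# e.g. via the attn_mask argument of the attention() helper above.
```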