diffsynth_engine-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. diffsynth_engine/__init__.py +28 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/models/components/vae.json +254 -0
  30. diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
  31. diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
  32. diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
  33. diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
  34. diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
  35. diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
  36. diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
  37. diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
  38. diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
  39. diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
  40. diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
  41. diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
  42. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  43. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  44. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  45. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  46. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  47. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  48. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  49. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  50. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  51. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  52. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  53. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  54. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  55. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  56. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  57. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  58. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  59. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  60. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  61. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  62. diffsynth_engine/kernels/__init__.py +0 -0
  63. diffsynth_engine/models/__init__.py +7 -0
  64. diffsynth_engine/models/base.py +64 -0
  65. diffsynth_engine/models/basic/__init__.py +0 -0
  66. diffsynth_engine/models/basic/attention.py +217 -0
  67. diffsynth_engine/models/basic/lora.py +293 -0
  68. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  69. diffsynth_engine/models/basic/timestep.py +81 -0
  70. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  71. diffsynth_engine/models/basic/unet_helper.py +244 -0
  72. diffsynth_engine/models/components/__init__.py +0 -0
  73. diffsynth_engine/models/components/clip.py +56 -0
  74. diffsynth_engine/models/components/t5.py +222 -0
  75. diffsynth_engine/models/components/vae.py +392 -0
  76. diffsynth_engine/models/flux/__init__.py +14 -0
  77. diffsynth_engine/models/flux/flux_dit.py +476 -0
  78. diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
  79. diffsynth_engine/models/flux/flux_vae.py +78 -0
  80. diffsynth_engine/models/sd/__init__.py +12 -0
  81. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  82. diffsynth_engine/models/sd/sd_unet.py +293 -0
  83. diffsynth_engine/models/sd/sd_vae.py +38 -0
  84. diffsynth_engine/models/sd3/__init__.py +14 -0
  85. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  86. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  87. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  88. diffsynth_engine/models/sdxl/__init__.py +13 -0
  89. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  90. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  91. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  92. diffsynth_engine/models/utils.py +54 -0
  93. diffsynth_engine/models/wan/__init__.py +0 -0
  94. diffsynth_engine/models/wan/wan_dit.py +497 -0
  95. diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
  96. diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
  97. diffsynth_engine/models/wan/wan_vae.py +771 -0
  98. diffsynth_engine/pipelines/__init__.py +18 -0
  99. diffsynth_engine/pipelines/base.py +253 -0
  100. diffsynth_engine/pipelines/flux_image.py +512 -0
  101. diffsynth_engine/pipelines/sd_image.py +352 -0
  102. diffsynth_engine/pipelines/sdxl_image.py +395 -0
  103. diffsynth_engine/pipelines/wan_video.py +524 -0
  104. diffsynth_engine/tokenizers/__init__.py +6 -0
  105. diffsynth_engine/tokenizers/base.py +157 -0
  106. diffsynth_engine/tokenizers/clip.py +288 -0
  107. diffsynth_engine/tokenizers/t5.py +194 -0
  108. diffsynth_engine/tokenizers/wan.py +74 -0
  109. diffsynth_engine/utils/__init__.py +0 -0
  110. diffsynth_engine/utils/constants.py +34 -0
  111. diffsynth_engine/utils/download.py +135 -0
  112. diffsynth_engine/utils/env.py +7 -0
  113. diffsynth_engine/utils/flag.py +46 -0
  114. diffsynth_engine/utils/fp8_linear.py +64 -0
  115. diffsynth_engine/utils/gguf.py +415 -0
  116. diffsynth_engine/utils/loader.py +17 -0
  117. diffsynth_engine/utils/lock.py +56 -0
  118. diffsynth_engine/utils/logging.py +12 -0
  119. diffsynth_engine/utils/offload.py +44 -0
  120. diffsynth_engine/utils/parallel.py +390 -0
  121. diffsynth_engine/utils/prompt.py +9 -0
  122. diffsynth_engine/utils/video.py +40 -0
  123. diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
  124. diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
  125. diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
  126. diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
  127. diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
diffsynth_engine/models/__init__.py
@@ -0,0 +1,7 @@
+ from .base import PreTrainedModel, StateDictConverter
+
+
+ __all__ = [
+     "PreTrainedModel",
+     "StateDictConverter",
+ ]
diffsynth_engine/models/base.py
@@ -0,0 +1,64 @@
+ import os
+ import torch
+ import torch.nn as nn
+ from typing import Any, Dict, List, Union
+ from safetensors.torch import load_file
+
+ from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
+ from diffsynth_engine.models.utils import no_init_weights
+
+
+ class StateDictConverter:
+     def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         return state_dict
+
+
+ class PreTrainedModel(nn.Module):
+     converter = StateDictConverter()
+
+     def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True, assign: bool = False):
+         state_dict = self.converter.convert(state_dict)
+         return super().load_state_dict(state_dict, strict=strict, assign=assign)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], device: str, dtype: torch.dtype, **kwargs):
+         state_dict = load_file(pretrained_model_path, device=device)
+         return cls.from_state_dict(state_dict, device=device, dtype=dtype, **kwargs)
+
+     @classmethod
+     def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, **kwargs):
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, **kwargs)
+         model.load_state_dict(state_dict)
+         model.to(device=device, dtype=dtype, non_blocking=True)
+         return model
+
+     def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
+         for args in lora_args:
+             key = args["name"]
+             module = self.get_submodule(key)
+             if not isinstance(module, (LoRALinear, LoRAConv2d)):
+                 raise ValueError(f"Unsupported LoRA key: {key}")
+             if fused:
+                 module.add_frozen_lora(**args)
+             else:
+                 module.add_lora(**args)
+
+     def unload_loras(self):
+         for module in self.modules():
+             if isinstance(module, (LoRALinear, LoRAConv2d)):
+                 module.clear()
+
+
+ def split_suffix(name: str):
+     suffix_list = [
+         ".lora_up.weight",
+         ".lora_down.weight",
+         ".weight",
+         ".bias",
+         ".alpha",
+     ]
+     for suffix in suffix_list:
+         if name.endswith(suffix):
+             return name[: -len(suffix)], suffix  # strip only the trailing suffix
+     return name, ""
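For orientation, a minimal sketch of how a subclass would plug into this API; the names TinyConverter, TinyModel, and the "model." key-prefix rule are hypothetical and not shipped in the wheel:

    import torch
    import torch.nn as nn
    from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter

    class TinyConverter(StateDictConverter):
        def convert(self, state_dict):
            # hypothetical rename rule: strip a leftover "model." prefix from checkpoint keys
            return {k.removeprefix("model."): v for k, v in state_dict.items()}

    class TinyModel(PreTrainedModel):
        converter = TinyConverter()

        def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32):
            super().__init__()
            self.proj = nn.Linear(8, 8, device=device, dtype=dtype)

    # from_pretrained loads a .safetensors file, runs the converter, then materializes the module:
    # model = TinyModel.from_pretrained("tiny.safetensors", device="cpu", dtype=torch.float32)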
diffsynth_engine/models/basic/attention.py
@@ -0,0 +1,217 @@
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from typing import Optional
+ from yunchang import LongContextAttention
+ from yunchang.kernels import AttnType
+
+ from diffsynth_engine.utils import logging
+ from diffsynth_engine.utils.flag import (
+     FLASH_ATTN_3_AVAILABLE,
+     FLASH_ATTN_2_AVAILABLE,
+     XFORMERS_AVAILABLE,
+     SDPA_AVAILABLE,
+     SAGE_ATTN_AVAILABLE,
+     SPARGE_ATTN_AVAILABLE,
+ )
+
+ logger = logging.get_logger(__name__)
+
+
+ if FLASH_ATTN_3_AVAILABLE:
+     from flash_attn_interface import flash_attn_func as flash_attn3
+ if FLASH_ATTN_2_AVAILABLE:
+     from flash_attn import flash_attn_func as flash_attn2
+ if XFORMERS_AVAILABLE:
+     from xformers.ops import memory_efficient_attention as xformers_attn
+ if SDPA_AVAILABLE:
+
+     def sdpa_attn(q, k, v, attn_mask=None, scale=None):
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
+         return out.transpose(1, 2)
+
+
+ if SAGE_ATTN_AVAILABLE:
+     from sageattention import sageattn
+
+     def sage_attn(q, k, v, attn_mask=None, scale=None):
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+         out = sageattn(q, k, v, attn_mask=attn_mask, sm_scale=scale)
+         return out.transpose(1, 2)
+
+
+ if SPARGE_ATTN_AVAILABLE:
+     from spas_sage_attn import spas_sage2_attn_meansim_cuda
+
+     def sparge_attn(q, k, v, attn_mask=None, scale=None):
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+         out = spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=attn_mask, scale=scale)
+         return out.transpose(1, 2)
+
+
+ def eager_attn(q, k, v, attn_mask=None, scale=None):
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+     v = v.transpose(1, 2)
+     scale = 1 / q.shape[-1] ** 0.5 if scale is None else scale
+     q = q * scale
+     attn = torch.matmul(q, k.transpose(-2, -1))
+     if attn_mask is not None:
+         attn = attn + attn_mask
+     attn = attn.softmax(-1)
+     out = attn @ v
+     return out.transpose(1, 2)
+
+
+ def attention(
+     q,
+     k,
+     v,
+     attn_impl: Optional[str] = None,
+     attn_mask: Optional[torch.Tensor] = None,
+     scale: Optional[float] = None,
+ ):
+     """
+     q: [B, Lq, Nq, C1]
+     k: [B, Lk, Nk, C1]
+     v: [B, Lk, Nk, C2]
+     """
+     assert attn_impl in [
+         None,
+         "auto",
+         "eager",
+         "flash_attn_2",
+         "flash_attn_3",
+         "xformers",
+         "sdpa",
+         "sage_attn",
+         "sparge_attn",
+     ]
+     if attn_impl is None or attn_impl == "auto":
+         if FLASH_ATTN_3_AVAILABLE:
+             return flash_attn3(q, k, v, softmax_scale=scale)
+         elif FLASH_ATTN_2_AVAILABLE:
+             return flash_attn2(q, k, v, softmax_scale=scale)
+         elif XFORMERS_AVAILABLE:
+             return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
+         elif SDPA_AVAILABLE:
+             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+         else:
+             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+     else:
+         if attn_impl == "eager":
+             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+         elif attn_impl == "flash_attn_3":
+             return flash_attn3(q, k, v, softmax_scale=scale)
+         elif attn_impl == "flash_attn_2":
+             return flash_attn2(q, k, v, softmax_scale=scale)
+         elif attn_impl == "xformers":
+             return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
+         elif attn_impl == "sdpa":
+             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+         elif attn_impl == "sage_attn":
+             return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+         elif attn_impl == "sparge_attn":
+             return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+         else:
+             raise ValueError(f"Invalid attention implementation: {attn_impl}")
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         q_dim,
+         num_heads,
+         head_dim,
+         kv_dim=None,
+         bias_q=False,
+         bias_kv=False,
+         bias_out=False,
+         scale=None,
+         attn_impl: Optional[str] = None,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.float16,
+     ):
+         super().__init__()
+         dim_inner = head_dim * num_heads
+         kv_dim = kv_dim if kv_dim is not None else q_dim
+         self.num_heads = num_heads
+         self.head_dim = head_dim
+
+         self.to_q = nn.Linear(q_dim, dim_inner, bias=bias_q, device=device, dtype=dtype)
+         self.to_k = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
+         self.to_v = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
+         self.to_out = nn.Linear(dim_inner, q_dim, bias=bias_out, device=device, dtype=dtype)
+         self.attn_impl = attn_impl
+         self.scale = scale
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         y: Optional[torch.Tensor] = None,
+         attn_mask: Optional[torch.Tensor] = None,
+     ):
+         if y is None:
+             y = x
+         q = rearrange(self.to_q(x), "b s (n d) -> b s n d", n=self.num_heads)
+         k = rearrange(self.to_k(y), "b s (n d) -> b s n d", n=self.num_heads)
+         v = rearrange(self.to_v(y), "b s (n d) -> b s n d", n=self.num_heads)
+         out = attention(q, k, v, attn_mask=attn_mask, attn_impl=self.attn_impl, scale=self.scale)
+         out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
+         return self.to_out(out)
+
+
+ def long_context_attention(
+     q,
+     k,
+     v,
+     attn_impl: Optional[str] = None,
+     attn_mask: Optional[torch.Tensor] = None,
+     scale: Optional[float] = None,
+ ):
+     """
+     q: [B, Lq, Nq, C1]
+     k: [B, Lk, Nk, C1]
+     v: [B, Lk, Nk, C2]
+     """
+     assert attn_impl in [
+         None,
+         "auto",
+         "eager",
+         "flash_attn_2",
+         "flash_attn_3",
+         "xformers",
+         "sdpa",
+         "sage_attn",
+         "sparge_attn",
+     ]
+     if attn_impl is None or attn_impl == "auto":
+         if FLASH_ATTN_3_AVAILABLE:
+             attn_func = LongContextAttention(attn_type=AttnType.FA3)
+         elif FLASH_ATTN_2_AVAILABLE:
+             attn_func = LongContextAttention(attn_type=AttnType.FA)
+         elif SDPA_AVAILABLE:
+             attn_func = LongContextAttention(attn_type=AttnType.TORCH)
+         else:
+             raise ValueError("No available long context attention implementation")
+     else:
+         if attn_impl == "flash_attn_3":
+             attn_func = LongContextAttention(attn_type=AttnType.FA3)
+         elif attn_impl == "flash_attn_2":
+             attn_func = LongContextAttention(attn_type=AttnType.FA)
+         elif attn_impl == "sdpa":
+             attn_func = LongContextAttention(attn_type=AttnType.TORCH)
+         elif attn_impl == "sage_attn":
+             attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
+         elif attn_impl == "sparge_attn":
+             attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE)
+         else:
+             raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
+     return attn_func(q, k, v, softmax_scale=scale)
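A quick illustration of the dispatcher above (example code, not from the package): all backends take tensors in the [B, L, N, D] layout described in the docstring, and passing attn_impl="eager" exercises the dependency-free pure-PyTorch fallback.

    import torch
    from diffsynth_engine.models.basic.attention import attention

    # batch=2, sequence=16, heads=4, head_dim=32 -> [B, L, N, D]
    q = torch.randn(2, 16, 4, 32)
    k = torch.randn(2, 16, 4, 32)
    v = torch.randn(2, 16, 4, 32)

    out = attention(q, k, v, attn_impl="eager")  # no optional kernels required
    print(out.shape)  # torch.Size([2, 16, 4, 32])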
diffsynth_engine/models/basic/lora.py
@@ -0,0 +1,293 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn.common_types import _size_2_t
+ from typing import Union
+ from collections import OrderedDict
+ from contextlib import contextmanager
+
+
+ class LoRA(nn.Module):
+     def __init__(
+         self,
+         scale: float,
+         rank: int,
+         alpha: int,
+         up: Union[nn.Linear, nn.Conv2d, torch.Tensor],
+         down: Union[nn.Linear, nn.Conv2d, torch.Tensor],
+         device: str,
+         dtype: torch.dtype,
+     ):
+         super().__init__()
+         self.device = device
+         self.dtype = dtype
+         self.scale = scale
+         self.rank = rank
+         self.alpha = alpha.item() if isinstance(alpha, torch.Tensor) else alpha
+         self.up = up.to(device=device, dtype=dtype)
+         self.down = down.to(device=device, dtype=dtype)
+
+     def forward(self, x):
+         if isinstance(self.up, torch.Tensor) and isinstance(self.down, torch.Tensor):
+             return self.scale * (self.alpha / self.rank) * (x @ self.down.T @ self.up.T)
+         return self.scale * (self.alpha / self.rank) * (self.up(self.down(x)))
+
+     def apply_to(self, w: Union[nn.Linear, nn.Conv2d, nn.Parameter, torch.Tensor]):
+         if isinstance(self.up, torch.Tensor) and isinstance(self.down, torch.Tensor):
+             delta_w = self.scale * (self.alpha / self.rank) * (self.up @ self.down)
+         else:
+             delta_w = self.scale * (self.alpha / self.rank) * (self.up.weight @ self.down.weight)
+         if isinstance(w, (nn.Linear, nn.Conv2d)):
+             delta_w = delta_w.to(device=w.weight.data.device, dtype=w.weight.data.dtype)
+             w.weight.data.add_(delta_w)
+         elif isinstance(w, nn.Parameter):
+             delta_w = delta_w.to(device=w.data.device, dtype=w.data.dtype)
+             w.data.add_(delta_w)
+         elif isinstance(w, torch.Tensor):
+             delta_w = delta_w.to(device=w.device, dtype=w.dtype)
+             w.add_(delta_w)
+
+
+ class LoRALinear(nn.Linear):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         device=None,
+         dtype=None,
+     ) -> None:
+         super().__init__(in_features, out_features, bias, device, dtype)
+         # LoRA
+         self._lora_dict = OrderedDict()
+         # Frozen LoRA
+         self._frozen_lora_list = []
+         self.register_buffer("_original_weight", None)
+
+     @staticmethod
+     def from_linear(linear: nn.Linear):
+         lora_linear = torch.nn.utils.skip_init(
+             LoRALinear,
+             linear.in_features,
+             linear.out_features,
+             linear.bias is not None,
+             device=linear.weight.device,
+             dtype=linear.weight.dtype,
+         )
+         lora_linear.weight = linear.weight
+         lora_linear.bias = linear.bias
+         return lora_linear
+
+     def add_lora(
+         self,
+         name: str,
+         scale: float,
+         rank: int,
+         alpha: int,
+         up: torch.Tensor,
+         down: torch.Tensor,
+         device: str,
+         dtype: torch.dtype,
+         **kwargs,
+     ):
+         up_linear = torch.nn.utils.skip_init(
+             nn.Linear, up.shape[1], up.shape[0], bias=False, device=device, dtype=dtype
+         )
+         down_linear = torch.nn.utils.skip_init(
+             nn.Linear, down.shape[1], down.shape[0], bias=False, device=device, dtype=dtype
+         )
+         up_linear.weight.data = up
+         down_linear.weight.data = down
+         lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
+         self._lora_dict[name] = lora
+
+     def modify_scale(self, name: str, scale: float):
+         if name not in self._lora_dict:
+             raise ValueError(f"LoRA name {name} not found in LoRALinear {self.__class__.__name__}")
+         self._lora_dict[name].scale = scale
+
+     def add_frozen_lora(
+         self,
+         name: str,
+         scale: float,
+         rank: int,
+         alpha: int,
+         up: torch.Tensor,
+         down: torch.Tensor,
+         device: str,
+         dtype: torch.dtype,
+         save_original_weight: bool = True,
+     ):
+         if save_original_weight and self._original_weight is None:
+             self._original_weight = self.weight.clone()
+         lora = LoRA(scale, rank, alpha, up, down, device, dtype)
+         lora.apply_to(self)
+         self._frozen_lora_list.append(lora)
+
+     def clear(self):
+         if self._original_weight is None and len(self._frozen_lora_list) > 0:
+             raise RuntimeError(
+                 "This LoRALinear has been patched by a frozen LoRA, but the original weight was not saved, so the LoRA cannot be cleared."
+             )
+         self._lora_dict.clear()
+         self._frozen_lora_list = []
+         if self._original_weight is not None:
+             self.weight.data = self._original_weight
+             self._original_weight = None
+
+     def forward(self, x):
+         w_x = super().forward(x)
+         for name, lora in self._lora_dict.items():
+             w_x += lora(x)
+         return w_x
+
+
+ class LoRAConv2d(nn.Conv2d):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: _size_2_t,
+         stride: _size_2_t = 1,
+         padding: Union[str, _size_2_t] = 0,
+         dilation: _size_2_t = 1,
+         groups: int = 1,
+         bias: bool = True,
+         padding_mode: str = "zeros",  # TODO: refine this type
+         device=None,
+         dtype=None,
+     ) -> None:
+         super().__init__(
+             in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype
+         )
+         # LoRA
+         self._lora_dict = OrderedDict()
+         # Frozen LoRA
+         self._frozen_lora_list = []
+         self._original_weight = None
+
+     @staticmethod
+     def from_conv2d(conv2d: nn.Conv2d):
+         lora_conv2d = torch.nn.utils.skip_init(
+             LoRAConv2d,
+             conv2d.in_channels,
+             conv2d.out_channels,
+             conv2d.kernel_size,
+             conv2d.stride,
+             conv2d.padding,
+             conv2d.dilation,
+             conv2d.groups,
+             conv2d.bias is not None,
+             conv2d.padding_mode,
+             device=conv2d.weight.device,
+             dtype=conv2d.weight.dtype,
+         )
+         lora_conv2d.weight = conv2d.weight
+         lora_conv2d.bias = conv2d.bias
+         return lora_conv2d
+
+     def _construct_lora(
+         self,
+         name: str,
+         scale: float,
+         rank: int,
+         alpha: int,
+         up: torch.Tensor,
+         down: torch.Tensor,
+         device: str,
+         dtype: torch.dtype,
+     ):
+         down_conv = torch.nn.utils.skip_init(
+             nn.Conv2d,
+             self.in_channels,
+             rank,
+             kernel_size=self.kernel_size,
+             stride=self.stride,
+             padding=self.padding,
+             bias=False,
+             device=device,
+             dtype=dtype,
+         )
+         down_conv.weight.data = down
+         # According to the official kohya_ss trainer, the up layer always uses a fixed 1x1 kernel
+         # (same convention as diffusers);
+         # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
+         up_conv = torch.nn.utils.skip_init(
+             nn.Conv2d,
+             rank,
+             self.out_channels,
+             kernel_size=(1, 1),
+             stride=(1, 1),
+             bias=False,
+             device=device,
+             dtype=dtype,
+         )
+         up_conv.weight.data = up
+
+         lora = LoRA(scale, rank, alpha, up_conv, down_conv, device, dtype)
+         return lora
+
+     def add_lora(
+         self,
+         name: str,
+         scale: float,
+         rank: int,
+         alpha: int,
+         up: torch.Tensor,
+         down: torch.Tensor,
+         device: str,
+         dtype: torch.dtype,
+         **kwargs,
+     ):
+         self._lora_dict[name] = self._construct_lora(name, scale, rank, alpha, up, down, device, dtype)
+
+     def modify_scale(self, name: str, scale: float):
+         if name not in self._lora_dict:
+             raise ValueError(f"LoRA name {name} not found in LoRAConv2d {self.__class__.__name__}")
+         self._lora_dict[name].scale = scale
+
+     def add_frozen_lora(
+         self,
+         name: str,
+         scale: float,
+         rank: int,
+         alpha: int,
+         up: torch.Tensor,
+         down: torch.Tensor,
+         device: str,
+         dtype: torch.dtype,
+         save_original_weight: bool = True,
+     ):
+         if save_original_weight and self._original_weight is None:
+             self._original_weight = self.weight.clone()
+         lora = self._construct_lora(name, scale, rank, alpha, up, down, device, dtype)
+         lora.apply_to(self)
+         self._frozen_lora_list.append(lora)
+
+     def clear(self):
+         if self._original_weight is None and len(self._frozen_lora_list) > 0:
+             raise RuntimeError(
+                 "This LoRAConv2d has been patched by a frozen LoRA, but the original weight was not saved, so the LoRA cannot be cleared."
+             )
+         self._lora_dict.clear()
+         self._frozen_lora_list = []
+         if self._original_weight is not None:
+             self.weight.copy_(self._original_weight)
+             self._original_weight = None
+
+     def forward(self, x):
+         w_x = super().forward(x)
+         for name, lora in self._lora_dict.items():
+             w_x += lora(x)
+         return w_x
+
+
+ @contextmanager
+ def LoRAContext():
+     origin_linear = torch.nn.Linear
+     origin_conv2d = torch.nn.Conv2d
+
+     torch.nn.Linear = LoRALinear
+     torch.nn.Conv2d = LoRAConv2d
+     try:
+         yield
+     finally:
+         torch.nn.Linear = origin_linear
+         torch.nn.Conv2d = origin_conv2d
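A usage sketch for the hooks above (illustrative only; the rank and tensor values are arbitrary): wrap an existing nn.Linear, register a dynamic LoRA, and remove it again. With a zero-initialized up matrix the low-rank delta is zero, which makes the wiring easy to verify.

    import torch
    import torch.nn as nn
    from diffsynth_engine.models.basic.lora import LoRALinear

    base = nn.Linear(64, 64)
    lora_linear = LoRALinear.from_linear(base)  # shares weight/bias with `base`

    up = torch.zeros(64, 8)    # [out_features, rank]
    down = torch.randn(8, 64)  # [rank, in_features]
    lora_linear.add_lora(
        name="demo", scale=1.0, rank=8, alpha=8,
        up=up, down=down, device="cpu", dtype=torch.float32,
    )

    x = torch.randn(1, 64)
    assert torch.allclose(lora_linear(x), base(x))  # zero up matrix -> no delta yet
    lora_linear.clear()  # drops the registered LoRA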
diffsynth_engine/models/basic/relative_position_emb.py
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+     relative_buckets = 0
+     if bidirectional:
+         num_buckets //= 2
+         relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+         relative_position = torch.abs(relative_position)
+     else:
+         relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+     # now relative_position is in the range [0, inf)
+
+     # half of the buckets are for exact increments in positions
+     max_exact = num_buckets // 2
+     is_small = relative_position < max_exact
+
+     # the other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+     relative_position_if_large = max_exact + (
+         torch.log(relative_position.float() / max_exact)
+         / math.log(max_distance / max_exact)
+         * (num_buckets - max_exact)
+     ).to(torch.long)
+     relative_position_if_large = torch.min(
+         relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+     )
+
+     relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+     return relative_buckets
+
+
+ class RelativePositionEmbedding(nn.Module):
+     def __init__(
+         self, num_buckets, max_distance, num_heads, device: str = "cuda:0", dtype: torch.dtype = torch.float16
+     ):
+         super().__init__()
+         self.num_buckets = num_buckets
+         self.max_distance = max_distance
+         self.relative_attention_bias = nn.Embedding(num_buckets, num_heads, device=device, dtype=dtype)
+
+     def forward(self, query_length, key_length):
+         device = self.relative_attention_bias.weight.device
+         context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+         memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+         relative_position = memory_position - context_position  # shape (query_length, key_length)
+         relative_position_bucket = _relative_position_bucket(
+             relative_position,  # shape (query_length, key_length)
+             bidirectional=True,
+             num_buckets=self.num_buckets,
+             max_distance=self.max_distance,
+         )
+         values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+         return values
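A small sketch of what this module produces (illustrative, mirroring the T5-style bias): the output has shape (1, num_heads, query_length, key_length) and is added to the attention logits before softmax.

    import torch
    from diffsynth_engine.models.basic.relative_position_emb import RelativePositionEmbedding

    emb = RelativePositionEmbedding(num_buckets=32, max_distance=128, num_heads=8,
                                    device="cpu", dtype=torch.float32)
    bias = emb(query_length=16, key_length=16)
    print(bias.shape)  # torch.Size([1, 8, 16, 16])
    # typical use: scores = q @ k.transpose(-2, -1) + bias, then softmax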