diffsynth-engine 0.3.6.dev11__py3-none-any.whl → 0.3.6.dev13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,13 +3,14 @@ from .configs import (
3
3
  SDXLPipelineConfig,
4
4
  FluxPipelineConfig,
5
5
  WanPipelineConfig,
6
+ ControlNetParams,
7
+ ControlType,
6
8
  )
7
9
  from .pipelines import (
8
10
  FluxImagePipeline,
9
11
  SDXLImagePipeline,
10
12
  SDImagePipeline,
11
13
  WanVideoPipeline,
12
- ControlNetParams,
13
14
  )
14
15
  from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
15
16
  from .models.sd import SDControlNet
@@ -44,6 +45,7 @@ __all__ = [
44
45
  "FluxReplaceByControlTool",
45
46
  "FluxReduxRefTool",
46
47
  "ControlNetParams",
48
+ "ControlType",
47
49
  "fetch_model",
48
50
  "fetch_modelscope_model",
49
51
  "fetch_civitai_model",
@@ -8,7 +8,7 @@ from .pipeline import (
8
8
  FluxPipelineConfig,
9
9
  WanPipelineConfig,
10
10
  )
11
- from .controlnet import ControlType
11
+ from .controlnet import ControlType, ControlNetParams
12
12
 
13
13
  __all__ = [
14
14
  "BaseConfig",
@@ -20,4 +20,5 @@ __all__ = [
20
20
  "FluxPipelineConfig",
21
21
  "WanPipelineConfig",
22
22
  "ControlType",
23
+ "ControlNetParams",
23
24
  ]
@@ -1,5 +1,13 @@
1
+ from dataclasses import dataclass
1
2
  from enum import Enum
2
3
 
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import List, Union, Optional
7
+ from PIL import Image
8
+
9
+ ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]
10
+
3
11
 
4
12
  # FLUX ControlType
5
13
  class ControlType(Enum):
@@ -15,3 +23,14 @@ class ControlType(Enum):
15
23
  return 128
16
24
  elif self == ControlType.bfl_fill:
17
25
  return 384
26
+
27
+
28
@dataclass
class ControlNetParams:
    """Settings for one ControlNet applied by an image pipeline.

    This dataclass is imported by the Flux, SD and SDXL image pipelines, which
    interpret the fields. Notes marked "presumably" below describe intended
    semantics that should be confirmed in those pipelines.
    """

    # Conditioning input: a PIL image, a tensor, or a list of either (ImageType).
    image: ImageType
    # Strength multiplier for this ControlNet (presumably scales its output).
    scale: float = 1.0
    # ControlNet module to apply; optional here (defaults to None).
    model: Optional[nn.Module] = None
    # Optional mask, accepting the same types as ``image``.
    mask: Optional[ImageType] = None
    # Active window, presumably as fractions of the denoising schedule in [0, 1].
    # Written as float literals so the defaults match the ``float`` annotation.
    control_start: float = 0.0
    control_end: float = 1.0
    # Processor name; only used for the SDXL ControlNet Union model at present.
    processor_name: Optional[str] = None
@@ -14,6 +14,8 @@ from diffsynth_engine.utils.flag import (
14
14
  SPARGE_ATTN_AVAILABLE,
15
15
  )
16
16
 
17
+ FA3_MAX_HEADDIM = 256
18
+
17
19
  logger = logging.get_logger(__name__)
18
20
 
19
21
 
@@ -130,31 +132,40 @@ def attention(
130
132
  "sage_attn",
131
133
  "sparge_attn",
132
134
  ]
135
+ flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
133
136
  if attn_impl is None or attn_impl == "auto":
134
137
  if FLASH_ATTN_3_AVAILABLE:
135
- return flash_attn3(q, k, v, softmax_scale=scale)
136
- elif XFORMERS_AVAILABLE:
138
+ if flash_attn3_compatible:
139
+ return flash_attn3(q, k, v, softmax_scale=scale)
140
+ else:
141
+ logger.warning(
142
+ f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
143
+ )
144
+ if XFORMERS_AVAILABLE:
137
145
  return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
138
- elif SDPA_AVAILABLE:
146
+ if SDPA_AVAILABLE:
139
147
  return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
140
- elif FLASH_ATTN_2_AVAILABLE:
148
+ if FLASH_ATTN_2_AVAILABLE:
141
149
  return flash_attn2(q, k, v, softmax_scale=scale)
142
- else:
143
- return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
150
+ return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
144
151
  else:
145
152
  if attn_impl == "eager":
146
153
  return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
147
- elif attn_impl == "flash_attn_3":
154
+ if attn_impl == "flash_attn_3":
155
+ if not flash_attn3_compatible:
156
+ raise RuntimeError(
157
+ f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
158
+ )
148
159
  return flash_attn3(q, k, v, softmax_scale=scale)
149
- elif attn_impl == "flash_attn_2":
160
+ if attn_impl == "flash_attn_2":
150
161
  return flash_attn2(q, k, v, softmax_scale=scale)
151
- elif attn_impl == "xformers":
162
+ if attn_impl == "xformers":
152
163
  return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
153
- elif attn_impl == "sdpa":
164
+ if attn_impl == "sdpa":
154
165
  return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
155
- elif attn_impl == "sage_attn":
166
+ if attn_impl == "sage_attn":
156
167
  return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
157
- elif attn_impl == "sparge_attn":
168
+ if attn_impl == "sparge_attn":
158
169
  return sparge_attn(
159
170
  q,
160
171
  k,
@@ -166,8 +177,7 @@ def attention(
166
177
  cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
167
178
  pvthreshd=kwargs.get("sparge_pvthreshd", 50),
168
179
  )
169
- else:
170
- raise ValueError(f"Invalid attention implementation: {attn_impl}")
180
+ raise ValueError(f"Invalid attention implementation: {attn_impl}")
171
181
 
172
182
 
173
183
  class Attention(nn.Module):
@@ -240,32 +250,42 @@ def long_context_attention(
240
250
  "sage_attn",
241
251
  "sparge_attn",
242
252
  ]
253
+ flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
243
254
  if attn_impl is None or attn_impl == "auto":
244
255
  if FLASH_ATTN_3_AVAILABLE:
245
- attn_func = LongContextAttention(attn_type=AttnType.FA3)
246
- elif SDPA_AVAILABLE:
247
- attn_func = LongContextAttention(attn_type=AttnType.TORCH)
248
- elif FLASH_ATTN_2_AVAILABLE:
249
- attn_func = LongContextAttention(attn_type=AttnType.FA)
250
- else:
251
- raise ValueError("No available long context attention implementation")
256
+ if flash_attn3_compatible:
257
+ return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
258
+ else:
259
+ logger.warning(
260
+ f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
261
+ )
262
+ if SDPA_AVAILABLE:
263
+ return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
264
+ if FLASH_ATTN_2_AVAILABLE:
265
+ return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
266
+ raise ValueError("No available long context attention implementation")
252
267
  else:
253
268
  if attn_impl == "flash_attn_3":
254
- attn_func = LongContextAttention(attn_type=AttnType.FA3)
255
- elif attn_impl == "flash_attn_2":
256
- attn_func = LongContextAttention(attn_type=AttnType.FA)
257
- elif attn_impl == "sdpa":
258
- attn_func = LongContextAttention(attn_type=AttnType.TORCH)
259
- elif attn_impl == "sage_attn":
260
- attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
261
- elif attn_impl == "sparge_attn":
269
+ if flash_attn3_compatible:
270
+ return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
271
+ else:
272
+ raise RuntimeError(
273
+ f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
274
+ )
275
+ if attn_impl == "flash_attn_2":
276
+ return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
277
+ if attn_impl == "sdpa":
278
+ return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
279
+ if attn_impl == "sage_attn":
280
+ return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
281
+ if attn_impl == "sparge_attn":
262
282
  attn_processor = SparseAttentionMeansim()
263
283
  # default args from spas_sage2_attn_meansim_cuda
264
284
  attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
265
285
  attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
266
286
  attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
267
287
  attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
268
- attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)
269
- else:
270
- raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
271
- return attn_func(q, k, v, softmax_scale=scale)
288
+ return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
289
+ q, k, v, softmax_scale=scale
290
+ )
291
+ raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
@@ -1,5 +1,4 @@
1
1
  from .base import BasePipeline, LoRAStateDictConverter
2
- from .controlnet_helper import ControlNetParams
3
2
  from .flux_image import FluxImagePipeline
4
3
  from .sdxl_image import SDXLImagePipeline
5
4
  from .sd_image import SDImagePipeline
@@ -13,5 +12,4 @@ __all__ = [
13
12
  "SDXLImagePipeline",
14
13
  "SDImagePipeline",
15
14
  "WanVideoPipeline",
16
- "ControlNetParams",
17
15
  ]
@@ -17,10 +17,10 @@ from diffsynth_engine.models.flux import (
17
17
  flux_dit_config,
18
18
  flux_text_encoder_config,
19
19
  )
20
- from diffsynth_engine.configs import FluxPipelineConfig, ControlType
20
+ from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams
21
21
  from diffsynth_engine.models.basic.lora import LoRAContext
22
22
  from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
23
- from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
23
+ from diffsynth_engine.pipelines.utils import accumulate
24
24
  from diffsynth_engine.tokenizers import CLIPTokenizer, T5TokenizerFast
25
25
  from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
26
26
  from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
@@ -6,12 +6,12 @@ from typing import Callable, Dict, Optional, List
6
6
  from tqdm import tqdm
7
7
  from PIL import Image, ImageOps
8
8
 
9
- from diffsynth_engine.configs import SDPipelineConfig
9
+ from diffsynth_engine.configs import SDPipelineConfig, ControlNetParams
10
10
  from diffsynth_engine.models.base import split_suffix
11
11
  from diffsynth_engine.models.basic.lora import LoRAContext
12
12
  from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
13
13
  from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
14
- from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
14
+ from diffsynth_engine.pipelines.utils import accumulate
15
15
  from diffsynth_engine.tokenizers import CLIPTokenizer
16
16
  from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
17
17
  from diffsynth_engine.algorithm.sampler import EulerSampler
@@ -6,7 +6,7 @@ from typing import Callable, Dict, Optional, List
6
6
  from tqdm import tqdm
7
7
  from PIL import Image, ImageOps
8
8
 
9
- from diffsynth_engine.configs import SDXLPipelineConfig
9
+ from diffsynth_engine.configs import SDXLPipelineConfig, ControlNetParams
10
10
  from diffsynth_engine.models.base import split_suffix
11
11
  from diffsynth_engine.models.basic.lora import LoRAContext
12
12
  from diffsynth_engine.models.basic.timestep import TemporalTimesteps
@@ -19,7 +19,7 @@ from diffsynth_engine.models.sdxl import (
19
19
  sdxl_unet_config,
20
20
  )
21
21
  from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
22
- from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
22
+ from diffsynth_engine.pipelines.utils import accumulate
23
23
  from diffsynth_engine.tokenizers import CLIPTokenizer
24
24
  from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
25
25
  from diffsynth_engine.algorithm.sampler import EulerSampler
@@ -0,0 +1,6 @@
1
def accumulate(result, new_item):
    """Add the entries of ``new_item`` element-wise into the running ``result``.

    Helper shared by the image pipelines' ControlNet code (presumably to sum
    the per-block outputs of several ControlNets).

    Args:
        result: the value returned by a previous call, or ``None`` on the
            first call.
        new_item: an indexable sequence of addends (e.g. a list or tuple of
            tensors) whose i-th entry is added to ``result[i]``.

    Returns:
        The running sum. On the first call, a list/tuple ``new_item`` is
        returned as a fresh list. This lets tuple outputs be accumulated into
        on later calls, and the caller's container is never mutated. Any other
        type is returned unchanged, as before.

    Note:
        Entries are combined with ``+=``. For mutable element types such as
        ``torch.Tensor`` this is an in-place add, so element objects taken
        from the first ``new_item`` are updated in place.
    """
    if result is None:
        # Returning ``new_item`` itself would alias the caller's container, and
        # with a tuple the next ``result[i] += item`` would raise TypeError.
        if isinstance(new_item, (list, tuple)):
            return list(new_item)
        return new_item
    for i, item in enumerate(new_item):
        result[i] += item
    return result
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.3.6.dev11
3
+ Version: 0.3.6.dev13
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -1,4 +1,4 @@
1
- diffsynth_engine/__init__.py,sha256=PnsxBE7qAW_5yDsrl1S-I3UraXMOQKTHWxAfKHbwIYQ,1279
1
+ diffsynth_engine/__init__.py,sha256=ysgNUqKZwce7rt_JdytIOPAJH5KYiH_LQqh-JQ51ZY8,1315
2
2
  diffsynth_engine/algorithm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  diffsynth_engine/algorithm/noise_scheduler/__init__.py,sha256=YvcwE2tCNua-OAX9GEPm0EXsINNWH4XvJMNZb-uaZMM,745
4
4
  diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py,sha256=WICrLEh7b2TdZMMEN14NqiYydj7dxXT6RolXymKiMk8,188
@@ -60,15 +60,15 @@ diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json,sha256=e4q
60
60
  diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model,sha256=45CaZ7eAZQs1z1Kax4KtK2sm5tH4SdP7tqhykF9FJFg,4548313
61
61
  diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json,sha256=bhl7TT29cdoUtOslX0-pHJwfIGiyCi3iRylnyj0iYCs,16837417
62
62
  diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json,sha256=7Zo6iw-qcacKMoR-BDX-A25uES1N9O23u0ipIeNE3AU,61728
63
- diffsynth_engine/configs/__init__.py,sha256=dZ80g2GB3B2YdmoGMp9yvwK3FRJI5j8vShhB9L95j1U,460
64
- diffsynth_engine/configs/controlnet.py,sha256=OF_cznEw-NpGTM9vP_mIApr4MAJCywWoDWkcUWCz-bs,434
63
+ diffsynth_engine/configs/__init__.py,sha256=qvfbnHf3wK9THPU_mFr1Qx_lU80BaUp5HpxUmjoNy60,502
64
+ diffsynth_engine/configs/controlnet.py,sha256=EpUkCdRNk2G5uo56syaOzPFdR9g0sDHRXckagmMsgaQ,948
65
65
  diffsynth_engine/configs/pipeline.py,sha256=NPQlNz-AOpi8qFzRob0RNnOqSc8C-vCdHbstLyUugeo,7703
66
66
  diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
67
67
  diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
68
68
  diffsynth_engine/models/base.py,sha256=sbyyGP-ENnqicr6cxjEmXRf6dWrmKjCu6k5yamuJ518,2665
69
69
  diffsynth_engine/models/utils.py,sha256=r5xLSEog1_ODaFrpqzJvAj3r23PQiEpgivzErClTZTg,1561
70
70
  diffsynth_engine/models/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
71
- diffsynth_engine/models/basic/attention.py,sha256=Hquc9H57N37hz5JbCJHpuPGX1smG403DNiOZTnlCYYA,9103
71
+ diffsynth_engine/models/basic/attention.py,sha256=vwQZi4MPTE5AV3Tv3KDAQ4TIGOy4UlD5n93zsbtBMWg,10321
72
72
  diffsynth_engine/models/basic/lora.py,sha256=qEh44zfh7ZBblLpjmKzwzAxmTlVyY0wu9IkGsnr7Ih8,10614
73
73
  diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvBcYXMdhHS257B_PC8PZSWxvhNQ,2540
74
74
  diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
@@ -107,12 +107,12 @@ diffsynth_engine/models/wan/wan_dit.py,sha256=gUd9KeMl7y_VPLntGoGtT2Io94opPiKlrr
107
107
  diffsynth_engine/models/wan/wan_image_encoder.py,sha256=LYwcfCcQmXf9FP08DGaU2bfaPgFfdpJ23OpJP8UCggo,14397
108
108
  diffsynth_engine/models/wan/wan_text_encoder.py,sha256=bkphxtqNNwXcEA_OaUrwV9CvICV-s16awu5Z9gjjzsM,10912
109
109
  diffsynth_engine/models/wan/wan_vae.py,sha256=RxyuHExQmRjGBAqhZdIbtwZFdCibTzh__U4-Sa00zdI,29004
110
- diffsynth_engine/pipelines/__init__.py,sha256=Ewarnhf4K-sYFfSG4mghDoJh5FZKG9Xiz2DFZizNZ-I,452
110
+ diffsynth_engine/pipelines/__init__.py,sha256=kTvANqHcMPrHqiJVg-XohfqRdW6Cj4aElfItTb1B7Vs,380
111
111
  diffsynth_engine/pipelines/base.py,sha256=yVp4hSPCqk98azzy3ykKBfPAufvq_ncTFOURN95z7d0,12178
112
- diffsynth_engine/pipelines/controlnet_helper.py,sha256=b6HnJFJfMKZq9s5DQ-9Se8OTSDeHVk4AskONSwcRShg,680
113
- diffsynth_engine/pipelines/flux_image.py,sha256=CPI7WwXJz60rpQN_ZfVV5kcWxg8WlrYNNe6QI1E5EPk,48851
114
- diffsynth_engine/pipelines/sd_image.py,sha256=hiL2gvQcPgniRm8TlzUKhoo5bGUmmVlYDS__E_WFDiE,17834
115
- diffsynth_engine/pipelines/sdxl_image.py,sha256=qWpE5q0CeDrsKZVIxYHpYLyZAkKMqACDzf4RPPvdn7A,21587
112
+ diffsynth_engine/pipelines/flux_image.py,sha256=MtQqTnCqQjIFovhA3lzBXpnkS4DkZH2PtFUwNZdl42M,48839
113
+ diffsynth_engine/pipelines/sd_image.py,sha256=5dGIa6crtklO7xPd1eeBVkqj54Pe89Uo3bMyXVEaXxM,17822
114
+ diffsynth_engine/pipelines/sdxl_image.py,sha256=Ns4bCSO3BtCXdjGJEQ0s5oY0S3jrp5yE5lhfon-iNiw,21575
115
+ diffsynth_engine/pipelines/utils.py,sha256=VfSTwRejSVSKXIa7w0VhObmvaBFRvDP-uiYsHHkPAgs,165
116
116
  diffsynth_engine/pipelines/wan_video.py,sha256=vi_xW-jU4PeMtZzjkfQbnj8eOymJrTZMrOQau6tx6ks,20187
117
117
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
118
118
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
@@ -144,8 +144,8 @@ diffsynth_engine/utils/parallel.py,sha256=gbIeilfOYsqeDcgkaP68TfLjIXxvD0KfLiAsR_
144
144
  diffsynth_engine/utils/platform.py,sha256=2lXdw6YkqcRONCeT98n4cyg1Ii8Ybbyj2Ns72Se9tlk,496
145
145
  diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
146
146
  diffsynth_engine/utils/video.py,sha256=Ne0rd2lb59UT1q5EotpjlY7OT8F9oTCFDyo1ST77uoQ,1004
147
- diffsynth_engine-0.3.6.dev11.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
148
- diffsynth_engine-0.3.6.dev11.dist-info/METADATA,sha256=pFhxaPrL9JrwA9ZnbtL0VLrS6gVC52wFwRu2Yn_vqQc,1069
149
- diffsynth_engine-0.3.6.dev11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
150
- diffsynth_engine-0.3.6.dev11.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
151
- diffsynth_engine-0.3.6.dev11.dist-info/RECORD,,
147
+ diffsynth_engine-0.3.6.dev13.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
148
+ diffsynth_engine-0.3.6.dev13.dist-info/METADATA,sha256=2jH1jlJdbUga4JOoDHfRyKEn6E4xQ1w9wRhLKVYaqRk,1069
149
+ diffsynth_engine-0.3.6.dev13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
150
+ diffsynth_engine-0.3.6.dev13.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
151
+ diffsynth_engine-0.3.6.dev13.dist-info/RECORD,,
@@ -1,26 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from typing import List, Union, Optional
4
- from PIL import Image
5
- from dataclasses import dataclass
6
-
7
- ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]
8
-
9
-
10
- @dataclass
11
- class ControlNetParams:
12
- image: ImageType
13
- scale: float = 1.0
14
- model: Optional[nn.Module] = None
15
- mask: Optional[ImageType] = None
16
- control_start: float = 0
17
- control_end: float = 1
18
- processor_name: Optional[str] = None # only used for sdxl controlnet union now
19
-
20
-
21
- def accumulate(result, new_item):
22
- if result is None:
23
- return new_item
24
- for i, item in enumerate(new_item):
25
- result[i] += item
26
- return result