diffsynth-engine 0.7.0__py3-none-any.whl → 0.7.1.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +6 -0
- diffsynth_engine/conf/models/flux2/qwen3_8B_config.json +68 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +50 -1
- diffsynth_engine/models/flux2/__init__.py +7 -0
- diffsynth_engine/models/flux2/flux2_dit.py +1065 -0
- diffsynth_engine/models/flux2/flux2_vae.py +1992 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/flux2_klein_image.py +634 -0
- diffsynth_engine/utils/constants.py +1 -0
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/RECORD +15 -10
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/WHEEL +1 -1
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/top_level.txt +0 -0
diffsynth_engine/__init__.py
CHANGED
|
@@ -7,12 +7,14 @@ from .configs import (
|
|
|
7
7
|
QwenImagePipelineConfig,
|
|
8
8
|
HunyuanPipelineConfig,
|
|
9
9
|
ZImagePipelineConfig,
|
|
10
|
+
Flux2KleinPipelineConfig,
|
|
10
11
|
SDStateDicts,
|
|
11
12
|
SDXLStateDicts,
|
|
12
13
|
FluxStateDicts,
|
|
13
14
|
WanStateDicts,
|
|
14
15
|
QwenImageStateDicts,
|
|
15
16
|
ZImageStateDicts,
|
|
17
|
+
Flux2StateDicts,
|
|
16
18
|
AttnImpl,
|
|
17
19
|
SpargeAttentionParams,
|
|
18
20
|
VideoSparseAttentionParams,
|
|
@@ -26,6 +28,7 @@ from .pipelines import (
|
|
|
26
28
|
SDImagePipeline,
|
|
27
29
|
SDXLImagePipeline,
|
|
28
30
|
FluxImagePipeline,
|
|
31
|
+
Flux2KleinPipeline,
|
|
29
32
|
WanVideoPipeline,
|
|
30
33
|
WanDMDPipeline,
|
|
31
34
|
QwenImagePipeline,
|
|
@@ -59,12 +62,14 @@ __all__ = [
|
|
|
59
62
|
"QwenImagePipelineConfig",
|
|
60
63
|
"HunyuanPipelineConfig",
|
|
61
64
|
"ZImagePipelineConfig",
|
|
65
|
+
"Flux2KleinPipelineConfig",
|
|
62
66
|
"SDStateDicts",
|
|
63
67
|
"SDXLStateDicts",
|
|
64
68
|
"FluxStateDicts",
|
|
65
69
|
"WanStateDicts",
|
|
66
70
|
"QwenImageStateDicts",
|
|
67
71
|
"ZImageStateDicts",
|
|
72
|
+
"Flux2StateDicts",
|
|
68
73
|
"AttnImpl",
|
|
69
74
|
"SpargeAttentionParams",
|
|
70
75
|
"VideoSparseAttentionParams",
|
|
@@ -78,6 +83,7 @@ __all__ = [
|
|
|
78
83
|
"SDXLImagePipeline",
|
|
79
84
|
"SDXLControlNetUnion",
|
|
80
85
|
"FluxImagePipeline",
|
|
86
|
+
"Flux2KleinPipeline",
|
|
81
87
|
"FluxControlNet",
|
|
82
88
|
"FluxIPAdapter",
|
|
83
89
|
"FluxRedux",
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
{
|
|
2
|
+
"architectures": [
|
|
3
|
+
"Qwen3ForCausalLM"
|
|
4
|
+
],
|
|
5
|
+
"attention_bias": false,
|
|
6
|
+
"attention_dropout": 0.0,
|
|
7
|
+
"bos_token_id": 151643,
|
|
8
|
+
"dtype": "bfloat16",
|
|
9
|
+
"eos_token_id": 151645,
|
|
10
|
+
"head_dim": 128,
|
|
11
|
+
"hidden_act": "silu",
|
|
12
|
+
"hidden_size": 4096,
|
|
13
|
+
"initializer_range": 0.02,
|
|
14
|
+
"intermediate_size": 12288,
|
|
15
|
+
"layer_types": [
|
|
16
|
+
"full_attention",
|
|
17
|
+
"full_attention",
|
|
18
|
+
"full_attention",
|
|
19
|
+
"full_attention",
|
|
20
|
+
"full_attention",
|
|
21
|
+
"full_attention",
|
|
22
|
+
"full_attention",
|
|
23
|
+
"full_attention",
|
|
24
|
+
"full_attention",
|
|
25
|
+
"full_attention",
|
|
26
|
+
"full_attention",
|
|
27
|
+
"full_attention",
|
|
28
|
+
"full_attention",
|
|
29
|
+
"full_attention",
|
|
30
|
+
"full_attention",
|
|
31
|
+
"full_attention",
|
|
32
|
+
"full_attention",
|
|
33
|
+
"full_attention",
|
|
34
|
+
"full_attention",
|
|
35
|
+
"full_attention",
|
|
36
|
+
"full_attention",
|
|
37
|
+
"full_attention",
|
|
38
|
+
"full_attention",
|
|
39
|
+
"full_attention",
|
|
40
|
+
"full_attention",
|
|
41
|
+
"full_attention",
|
|
42
|
+
"full_attention",
|
|
43
|
+
"full_attention",
|
|
44
|
+
"full_attention",
|
|
45
|
+
"full_attention",
|
|
46
|
+
"full_attention",
|
|
47
|
+
"full_attention",
|
|
48
|
+
"full_attention",
|
|
49
|
+
"full_attention",
|
|
50
|
+
"full_attention",
|
|
51
|
+
"full_attention"
|
|
52
|
+
],
|
|
53
|
+
"max_position_embeddings": 40960,
|
|
54
|
+
"max_window_layers": 36,
|
|
55
|
+
"model_type": "qwen3",
|
|
56
|
+
"num_attention_heads": 32,
|
|
57
|
+
"num_hidden_layers": 36,
|
|
58
|
+
"num_key_value_heads": 8,
|
|
59
|
+
"rms_norm_eps": 1e-06,
|
|
60
|
+
"rope_scaling": null,
|
|
61
|
+
"rope_theta": 1000000,
|
|
62
|
+
"sliding_window": null,
|
|
63
|
+
"tie_word_embeddings": false,
|
|
64
|
+
"transformers_version": "4.56.1",
|
|
65
|
+
"use_cache": true,
|
|
66
|
+
"use_sliding_window": false,
|
|
67
|
+
"vocab_size": 151936
|
|
68
|
+
}
|
|
@@ -11,6 +11,7 @@ from .pipeline import (
|
|
|
11
11
|
QwenImagePipelineConfig,
|
|
12
12
|
HunyuanPipelineConfig,
|
|
13
13
|
ZImagePipelineConfig,
|
|
14
|
+
Flux2KleinPipelineConfig,
|
|
14
15
|
BaseStateDicts,
|
|
15
16
|
SDStateDicts,
|
|
16
17
|
SDXLStateDicts,
|
|
@@ -19,6 +20,7 @@ from .pipeline import (
|
|
|
19
20
|
WanS2VStateDicts,
|
|
20
21
|
QwenImageStateDicts,
|
|
21
22
|
ZImageStateDicts,
|
|
23
|
+
Flux2StateDicts,
|
|
22
24
|
AttnImpl,
|
|
23
25
|
SpargeAttentionParams,
|
|
24
26
|
VideoSparseAttentionParams,
|
|
@@ -44,6 +46,7 @@ __all__ = [
|
|
|
44
46
|
"QwenImagePipelineConfig",
|
|
45
47
|
"HunyuanPipelineConfig",
|
|
46
48
|
"ZImagePipelineConfig",
|
|
49
|
+
"Flux2KleinPipelineConfig",
|
|
47
50
|
"BaseStateDicts",
|
|
48
51
|
"SDStateDicts",
|
|
49
52
|
"SDXLStateDicts",
|
|
@@ -52,6 +55,7 @@ __all__ = [
|
|
|
52
55
|
"WanS2VStateDicts",
|
|
53
56
|
"QwenImageStateDicts",
|
|
54
57
|
"ZImageStateDicts",
|
|
58
|
+
"Flux2StateDicts",
|
|
55
59
|
"AttnImpl",
|
|
56
60
|
"SpargeAttentionParams",
|
|
57
61
|
"VideoSparseAttentionParams",
|
|
@@ -3,6 +3,7 @@ import torch
|
|
|
3
3
|
from enum import Enum
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
5
|
from typing import List, Dict, Tuple, Optional
|
|
6
|
+
from typing_extensions import Literal
|
|
6
7
|
|
|
7
8
|
from diffsynth_engine.configs.controlnet import ControlType
|
|
8
9
|
|
|
@@ -339,6 +340,47 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
|
|
|
339
340
|
init_parallel_config(self)
|
|
340
341
|
|
|
341
342
|
|
|
343
|
+
@dataclass
|
|
344
|
+
class Flux2KleinPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
|
|
345
|
+
model_path: str | os.PathLike | List[str | os.PathLike]
|
|
346
|
+
model_dtype: torch.dtype = torch.bfloat16
|
|
347
|
+
vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
|
|
348
|
+
vae_dtype: torch.dtype = torch.bfloat16
|
|
349
|
+
encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
|
|
350
|
+
encoder_dtype: torch.dtype = torch.bfloat16
|
|
351
|
+
image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
|
|
352
|
+
image_encoder_dtype: torch.dtype = torch.bfloat16
|
|
353
|
+
model_size: Literal["4B", "9B"] = "4B"
|
|
354
|
+
|
|
355
|
+
@classmethod
|
|
356
|
+
def basic_config(
|
|
357
|
+
cls,
|
|
358
|
+
model_path: str | os.PathLike | List[str | os.PathLike],
|
|
359
|
+
encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
|
|
360
|
+
vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
|
|
361
|
+
image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
|
|
362
|
+
device: str = "cuda",
|
|
363
|
+
parallelism: int = 1,
|
|
364
|
+
offload_mode: Optional[str] = None,
|
|
365
|
+
offload_to_disk: bool = False,
|
|
366
|
+
) -> "Flux2KleinPipelineConfig":
|
|
367
|
+
return cls(
|
|
368
|
+
model_path=model_path,
|
|
369
|
+
device=device,
|
|
370
|
+
encoder_path=encoder_path,
|
|
371
|
+
vae_path=vae_path,
|
|
372
|
+
image_encoder_path=image_encoder_path,
|
|
373
|
+
parallelism=parallelism,
|
|
374
|
+
use_cfg_parallel=True if parallelism > 1 else False,
|
|
375
|
+
use_fsdp=True if parallelism > 1 else False,
|
|
376
|
+
offload_mode=offload_mode,
|
|
377
|
+
offload_to_disk=offload_to_disk,
|
|
378
|
+
)
|
|
379
|
+
|
|
380
|
+
def __post_init__(self):
|
|
381
|
+
init_parallel_config(self)
|
|
382
|
+
|
|
383
|
+
|
|
342
384
|
@dataclass
|
|
343
385
|
class BaseStateDicts:
|
|
344
386
|
pass
|
|
@@ -398,7 +440,14 @@ class ZImageStateDicts:
|
|
|
398
440
|
image_encoder: Optional[Dict[str, torch.Tensor]] = None
|
|
399
441
|
|
|
400
442
|
|
|
401
|
-
|
|
443
|
+
@dataclass
|
|
444
|
+
class Flux2StateDicts:
|
|
445
|
+
model: Dict[str, torch.Tensor]
|
|
446
|
+
vae: Dict[str, torch.Tensor]
|
|
447
|
+
encoder: Dict[str, torch.Tensor]
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig | Flux2KleinPipelineConfig):
|
|
402
451
|
assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
|
|
403
452
|
config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg
|
|
404
453
|
|