diffsynth-engine 0.7.0__py3-none-any.whl → 0.7.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,12 +7,14 @@ from .configs import (
7
7
  QwenImagePipelineConfig,
8
8
  HunyuanPipelineConfig,
9
9
  ZImagePipelineConfig,
10
+ Flux2KleinPipelineConfig,
10
11
  SDStateDicts,
11
12
  SDXLStateDicts,
12
13
  FluxStateDicts,
13
14
  WanStateDicts,
14
15
  QwenImageStateDicts,
15
16
  ZImageStateDicts,
17
+ Flux2StateDicts,
16
18
  AttnImpl,
17
19
  SpargeAttentionParams,
18
20
  VideoSparseAttentionParams,
@@ -26,6 +28,7 @@ from .pipelines import (
26
28
  SDImagePipeline,
27
29
  SDXLImagePipeline,
28
30
  FluxImagePipeline,
31
+ Flux2KleinPipeline,
29
32
  WanVideoPipeline,
30
33
  WanDMDPipeline,
31
34
  QwenImagePipeline,
@@ -59,12 +62,14 @@ __all__ = [
59
62
  "QwenImagePipelineConfig",
60
63
  "HunyuanPipelineConfig",
61
64
  "ZImagePipelineConfig",
65
+ "Flux2KleinPipelineConfig",
62
66
  "SDStateDicts",
63
67
  "SDXLStateDicts",
64
68
  "FluxStateDicts",
65
69
  "WanStateDicts",
66
70
  "QwenImageStateDicts",
67
71
  "ZImageStateDicts",
72
+ "Flux2StateDicts",
68
73
  "AttnImpl",
69
74
  "SpargeAttentionParams",
70
75
  "VideoSparseAttentionParams",
@@ -78,6 +83,7 @@ __all__ = [
78
83
  "SDXLImagePipeline",
79
84
  "SDXLControlNetUnion",
80
85
  "FluxImagePipeline",
86
+ "Flux2KleinPipeline",
81
87
  "FluxControlNet",
82
88
  "FluxIPAdapter",
83
89
  "FluxRedux",
@@ -0,0 +1,68 @@
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 12288,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention"
52
+ ],
53
+ "max_position_embeddings": 40960,
54
+ "max_window_layers": 36,
55
+ "model_type": "qwen3",
56
+ "num_attention_heads": 32,
57
+ "num_hidden_layers": 36,
58
+ "num_key_value_heads": 8,
59
+ "rms_norm_eps": 1e-06,
60
+ "rope_scaling": null,
61
+ "rope_theta": 1000000,
62
+ "sliding_window": null,
63
+ "tie_word_embeddings": false,
64
+ "transformers_version": "4.56.1",
65
+ "use_cache": true,
66
+ "use_sliding_window": false,
67
+ "vocab_size": 151936
68
+ }
@@ -11,6 +11,7 @@ from .pipeline import (
11
11
  QwenImagePipelineConfig,
12
12
  HunyuanPipelineConfig,
13
13
  ZImagePipelineConfig,
14
+ Flux2KleinPipelineConfig,
14
15
  BaseStateDicts,
15
16
  SDStateDicts,
16
17
  SDXLStateDicts,
@@ -19,6 +20,7 @@ from .pipeline import (
19
20
  WanS2VStateDicts,
20
21
  QwenImageStateDicts,
21
22
  ZImageStateDicts,
23
+ Flux2StateDicts,
22
24
  AttnImpl,
23
25
  SpargeAttentionParams,
24
26
  VideoSparseAttentionParams,
@@ -44,6 +46,7 @@ __all__ = [
44
46
  "QwenImagePipelineConfig",
45
47
  "HunyuanPipelineConfig",
46
48
  "ZImagePipelineConfig",
49
+ "Flux2KleinPipelineConfig",
47
50
  "BaseStateDicts",
48
51
  "SDStateDicts",
49
52
  "SDXLStateDicts",
@@ -52,6 +55,7 @@ __all__ = [
52
55
  "WanS2VStateDicts",
53
56
  "QwenImageStateDicts",
54
57
  "ZImageStateDicts",
58
+ "Flux2StateDicts",
55
59
  "AttnImpl",
56
60
  "SpargeAttentionParams",
57
61
  "VideoSparseAttentionParams",
@@ -3,6 +3,7 @@ import torch
3
3
  from enum import Enum
4
4
  from dataclasses import dataclass, field
5
5
  from typing import List, Dict, Tuple, Optional
6
+ from typing_extensions import Literal
6
7
 
7
8
  from diffsynth_engine.configs.controlnet import ControlType
8
9
 
@@ -339,6 +340,47 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
339
340
  init_parallel_config(self)
340
341
 
341
342
 
343
@dataclass
class Flux2KleinPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
    """Configuration dataclass for the Flux2 Klein pipeline.

    Fields not declared here (for example ``device``, ``parallelism``,
    ``use_cfg_parallel``, ``use_fsdp``, ``offload_mode`` and
    ``offload_to_disk``) are expected to come from the base config classes,
    which are defined outside this block.
    """

    # Path, or list of paths, to the main model weights (required).
    model_path: str | os.PathLike | List[str | os.PathLike]
    model_dtype: torch.dtype = torch.bfloat16
    # Optional path(s) for the VAE weights.
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_dtype: torch.dtype = torch.bfloat16
    # Optional path(s) for the encoder weights.
    encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    encoder_dtype: torch.dtype = torch.bfloat16
    # Optional path(s) for the image encoder weights.
    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    image_encoder_dtype: torch.dtype = torch.bfloat16
    # Model variant tag; presumably selects the architecture size downstream
    # (the consumer is outside this block -- confirm there).
    model_size: Literal["4B", "9B"] = "4B"

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
        vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
        image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
        device: str = "cuda",
        parallelism: int = 1,
        offload_mode: Optional[str] = None,
        offload_to_disk: bool = False,
        model_size: Literal["4B", "9B"] = "4B",
    ) -> "Flux2KleinPipelineConfig":
        """Build a config from the most common options, leaving dtypes at their defaults.

        Args:
            model_path: Path(s) to the main model weights.
            encoder_path: Optional path(s) to the encoder weights.
            vae_path: Optional path(s) to the VAE weights.
            image_encoder_path: Optional path(s) to the image encoder weights.
            device: Forwarded unchanged to the config's ``device`` field.
            parallelism: Forwarded to the ``parallelism`` field. Any value
                greater than 1 also switches on ``use_cfg_parallel`` and
                ``use_fsdp``.
            offload_mode: Forwarded unchanged to the ``offload_mode`` field.
            offload_to_disk: Forwarded unchanged to the ``offload_to_disk`` field.
            model_size: Value for the ``model_size`` field. Defaults to
                ``"4B"``, matching the field default, so existing callers
                get the same config as before.

        Returns:
            A new ``Flux2KleinPipelineConfig`` (or subclass) instance.
        """
        # Both parallel strategies are enabled together as soon as more than
        # one device is requested.
        multi_device = parallelism > 1
        return cls(
            model_path=model_path,
            device=device,
            encoder_path=encoder_path,
            vae_path=vae_path,
            image_encoder_path=image_encoder_path,
            model_size=model_size,
            parallelism=parallelism,
            use_cfg_parallel=multi_device,
            use_fsdp=multi_device,
            offload_mode=offload_mode,
            offload_to_disk=offload_to_disk,
        )

    def __post_init__(self):
        """Finish initialisation by delegating to ``init_parallel_config``.

        That helper (defined later in this file) asserts that ``parallelism``
        is 1, 2, 4 or 8 and adjusts parallel-related fields such as
        ``batch_cfg``.
        """
        init_parallel_config(self)
382
+
383
+
342
384
  @dataclass
343
385
  class BaseStateDicts:
344
386
  pass
@@ -398,7 +440,14 @@ class ZImageStateDicts:
398
440
  image_encoder: Optional[Dict[str, torch.Tensor]] = None
399
441
 
400
442
 
401
- def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
443
@dataclass
class Flux2StateDicts:
    """Bundle of state dicts, one per component, each mapping a name to a tensor.

    All three entries are required; none has a default value.
    """

    # State dict stored under the name ``model``.
    model: Dict[str, torch.Tensor]
    # State dict stored under the name ``vae``.
    vae: Dict[str, torch.Tensor]
    # State dict stored under the name ``encoder``.
    encoder: Dict[str, torch.Tensor]
448
+
449
+
450
+ def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig | Flux2KleinPipelineConfig):
402
451
  assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
403
452
  config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg
404
453
 
@@ -0,0 +1,7 @@
1
+ from .flux2_dit import Flux2DiT
2
+ from .flux2_vae import Flux2VAE
3
+
4
# Public names of this package: the classes re-exported from the
# ``flux2_dit`` and ``flux2_vae`` submodules.
__all__ = ["Flux2DiT", "Flux2VAE"]