diffsynth-engine 0.3.6.dev4__tar.gz → 0.3.6.dev6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161):
  1. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/flux/__init__.py +2 -0
  3. diffsynth_engine-0.3.6.dev6/diffsynth_engine/models/flux/flux_dit_fbcache.py +205 -0
  4. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd/sd_controlnet.py +167 -85
  5. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sdxl/__init__.py +1 -1
  6. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +118 -73
  7. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sdxl/sdxl_unet.py +1 -2
  8. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/wan/wan_dit.py +3 -2
  9. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/wan/wan_vae.py +14 -15
  10. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/pipelines/controlnet_helper.py +4 -2
  11. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/pipelines/flux_image.py +25 -9
  12. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/pipelines/sd_image.py +20 -15
  13. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/pipelines/sdxl_image.py +44 -19
  14. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  15. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine.egg-info/SOURCES.txt +1 -0
  16. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/.gitignore +0 -0
  17. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/.pre-commit-config.yaml +0 -0
  18. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/LICENSE +0 -0
  19. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/MANIFEST.in +0 -0
  20. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/README.md +0 -0
  21. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/assets/dingtalk.png +0 -0
  22. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/assets/showcase.jpeg +0 -0
  23. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/__init__.py +0 -0
  24. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/__init__.py +0 -0
  25. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  26. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  27. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  28. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  29. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  30. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  31. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  32. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  33. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  34. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  35. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  36. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  37. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  38. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  39. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  40. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  41. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  42. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  43. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  44. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  45. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  46. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  47. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  48. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  49. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  50. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  51. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/components/vae.json +0 -0
  52. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  53. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  54. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  55. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  56. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  57. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  58. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  59. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  60. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  61. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
  62. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/wan/dit/14b-flf2v.json +0 -0
  63. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
  64. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
  65. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  66. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  67. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  68. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  69. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  70. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  71. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  72. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  73. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  74. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  75. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  76. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  77. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  78. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  79. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  80. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  81. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  82. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  83. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  84. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  85. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/kernels/__init__.py +0 -0
  86. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/__init__.py +0 -0
  87. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/base.py +0 -0
  88. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/basic/__init__.py +0 -0
  89. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/basic/attention.py +0 -0
  90. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/basic/lora.py +0 -0
  91. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  92. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/basic/timestep.py +0 -0
  93. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/basic/transformer_helper.py +0 -0
  94. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  95. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  96. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/flux/flux_dit.py +0 -0
  97. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  98. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  99. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  100. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/flux/flux_vae.py +0 -0
  101. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd/__init__.py +0 -0
  102. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  103. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  104. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  105. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd3/__init__.py +0 -0
  106. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  107. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  108. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  109. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  110. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  111. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  112. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  113. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  114. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/text_encoder/t5.py +0 -0
  115. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/utils.py +0 -0
  116. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/vae/__init__.py +0 -0
  117. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/vae/vae.py +0 -0
  118. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/wan/__init__.py +0 -0
  119. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  120. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  121. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/pipelines/__init__.py +0 -0
  122. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/pipelines/base.py +0 -0
  123. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/pipelines/wan_video.py +0 -0
  124. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/processor/__init__.py +0 -0
  125. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/processor/canny_processor.py +0 -0
  126. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/processor/depth_processor.py +0 -0
  127. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tokenizers/__init__.py +0 -0
  128. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tokenizers/base.py +0 -0
  129. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tokenizers/clip.py +0 -0
  130. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tokenizers/t5.py +0 -0
  131. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tokenizers/wan.py +0 -0
  132. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tools/__init__.py +0 -0
  133. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  134. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  135. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  136. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  137. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/__init__.py +0 -0
  138. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/constants.py +0 -0
  139. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/download.py +0 -0
  140. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/env.py +0 -0
  141. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/flag.py +0 -0
  142. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/fp8_linear.py +0 -0
  143. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/gguf.py +0 -0
  144. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/image.py +0 -0
  145. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/loader.py +0 -0
  146. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/lock.py +0 -0
  147. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/logging.py +0 -0
  148. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/offload.py +0 -0
  149. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/onnx.py +0 -0
  150. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/parallel.py +0 -0
  151. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/platform.py +0 -0
  152. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/prompt.py +0 -0
  153. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine/utils/video.py +0 -0
  154. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  155. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine.egg-info/requires.txt +0 -0
  156. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/diffsynth_engine.egg-info/top_level.txt +0 -0
  157. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/docs/tutorial.md +0 -0
  158. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/docs/tutorial_zh.md +0 -0
  159. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/pyproject.toml +0 -0
  160. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/setup.cfg +0 -0
  161. {diffsynth_engine-0.3.6.dev4 → diffsynth_engine-0.3.6.dev6}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.3.6.dev4
3
+ Version: 0.3.6.dev6
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -4,6 +4,7 @@ from .flux_vae import FluxVAEDecoder, FluxVAEEncoder, config as flux_vae_config
4
4
  from .flux_controlnet import FluxControlNet
5
5
  from .flux_ipadapter import FluxIPAdapter
6
6
  from .flux_redux import FluxRedux
7
+ from .flux_dit_fbcache import FluxDiTFBCache
7
8
 
8
9
  __all__ = [
9
10
  "FluxRedux",
@@ -14,6 +15,7 @@ __all__ = [
14
15
  "FluxTextEncoder2",
15
16
  "FluxVAEDecoder",
16
17
  "FluxVAEEncoder",
18
+ "FluxDiTFBCache",
17
19
  "flux_dit_config",
18
20
  "flux_text_encoder_config",
19
21
  "flux_vae_config",
@@ -0,0 +1,205 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import Dict, Optional
4
+
5
+ from diffsynth_engine.models.utils import no_init_weights
6
+ from diffsynth_engine.utils.gguf import gguf_inference
7
+ from diffsynth_engine.utils.fp8_linear import fp8_inference
8
+ from diffsynth_engine.utils.parallel import (
9
+ cfg_parallel,
10
+ cfg_parallel_unshard,
11
+ sequence_parallel,
12
+ sequence_parallel_unshard,
13
+ )
14
+ from diffsynth_engine.utils import logging
15
+ from diffsynth_engine.models.flux.flux_dit import FluxDiT
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
class FluxDiTFBCache(FluxDiT):
    """FluxDiT variant that can skip most blocks by reusing a cached residual.

    On every call the first transformer block is always executed. The change it
    makes to ``hidden_states`` (the "first-block residual") is compared with the
    one recorded at the last fully computed step. When the relative L1
    difference is below ``relative_l1_threshold`` the remaining double and
    single blocks are skipped and the residual cached from that step is added
    to ``hidden_states`` instead.

    ``refresh_cache_status`` must be called before each sampling run so that
    the first and last step of the run are always fully computed.
    """

    def __init__(
        self,
        in_channel: int = 64,
        attn_impl: Optional[str] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        relative_l1_threshold: float = 0.05,
    ):
        """Build the underlying FluxDiT and reset the cache bookkeeping.

        Args:
            in_channel: forwarded to ``FluxDiT.__init__``.
            attn_impl: forwarded to ``FluxDiT.__init__``.
            device: forwarded to ``FluxDiT.__init__``.
            dtype: forwarded to ``FluxDiT.__init__``.
            relative_l1_threshold: skip threshold for the first-block residual
                comparison; a value ``<= 0`` disables skipping.
        """
        super().__init__(in_channel=in_channel, attn_impl=attn_impl, device=device, dtype=dtype)
        self.relative_l1_threshold = relative_l1_threshold
        self.step_count = 0
        self.num_inference_steps = 0
        # Cached tensors from the last fully computed step. Initialised to None
        # so the skip path can never read an attribute that was never written.
        self.prev_first_hidden_states_residual = None
        self.previous_residual = None

    def is_relative_l1_below_threshold(self, prev_residual, residual, threshold):
        """Return True when ``residual`` is close enough to ``prev_residual``.

        The measure is ``mean(|prev - cur|) / mean(|prev|)``. A non-positive
        ``threshold`` or mismatching shapes always yield False.
        """
        if threshold <= 0.0:
            return False

        if prev_residual.shape != residual.shape:
            return False

        mean_diff = (prev_residual - residual).abs().mean()
        mean_prev_residual = prev_residual.abs().mean()
        diff = mean_diff / mean_prev_residual
        return diff.item() < threshold

    def refresh_cache_status(self, num_inference_steps):
        """Reset the step counter and cached residuals for a new sampling run."""
        self.step_count = 0
        self.num_inference_steps = num_inference_steps
        # Step 0 is always fully computed, so dropping stale tensors here is
        # safe and avoids keeping the previous run's activations alive.
        self.prev_first_hidden_states_residual = None
        self.previous_residual = None

    def forward(
        self,
        hidden_states,
        timestep,
        prompt_emb,
        pooled_prompt_emb,
        image_emb,
        guidance,
        text_ids,
        image_ids=None,
        controlnet_double_block_output=None,
        controlnet_single_block_output=None,
        **kwargs,
    ):
        """Run one denoising step, skipping blocks 1..N when the cache applies.

        ``image_ids`` defaults to ``self.prepare_image_ids(hidden_states)``;
        the two ControlNet arguments default to empty tuples. Extra ``kwargs``
        are accepted and ignored. Returns the unpatchified ``hidden_states``.
        """
        h, w = hidden_states.shape[-2:]
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)
        if controlnet_double_block_output is None:
            controlnet_double_block_output = ()
        if controlnet_single_block_output is None:
            controlnet_single_block_output = ()

        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = hidden_states.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel(
                (
                    hidden_states,
                    timestep,
                    prompt_emb,
                    pooled_prompt_emb,
                    image_emb,
                    guidance,
                    text_ids,
                    image_ids,
                    *controlnet_double_block_output,
                    *controlnet_single_block_output,
                ),
                use_cfg=use_cfg,
            ),
        ):
            # warning: keep the order time_embedding + guidance_embedding + pooled_text_embedding;
            # floating-point addition is order-sensitive, so reordering changes the result
            conditioning = self.time_embedder(timestep, hidden_states.dtype)
            if self.guidance_embedder is not None:
                guidance = guidance * 1000
                conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
            conditioning += self.pooled_text_embedder(pooled_prompt_emb)
            rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
            text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
            image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
            hidden_states = self.patchify(hidden_states)

            with sequence_parallel(
                (
                    hidden_states,
                    prompt_emb,
                    text_rope_emb,
                    image_rope_emb,
                    *controlnet_double_block_output,
                    *controlnet_single_block_output,
                ),
                seq_dims=(
                    1,
                    1,
                    2,
                    2,
                    *(1 for _ in controlnet_double_block_output),
                    *(1 for _ in controlnet_single_block_output),
                ),
            ):
                hidden_states = self.x_embedder(hidden_states)
                prompt_emb = self.context_embedder(prompt_emb)
                rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)

                # first block: always executed, its residual drives the skip decision
                original_hidden_states = hidden_states
                hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
                first_hidden_states_residual = hidden_states - original_hidden_states

                (first_hidden_states_residual,) = sequence_parallel_unshard(
                    (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,)
                )

                prev_first_residual = getattr(self, "prev_first_hidden_states_residual", None)
                cached_residual = getattr(self, "previous_residual", None)
                is_boundary_step = self.step_count == 0 or self.step_count == (self.num_inference_steps - 1)
                if is_boundary_step or prev_first_residual is None or cached_residual is None:
                    # first/last step of a run, or nothing cached yet: compute everything
                    should_calc = True
                else:
                    # keyword arguments keep the previous residual as the
                    # normalisation base, as the helper's signature intends
                    skip = self.is_relative_l1_below_threshold(
                        prev_residual=prev_first_residual,
                        residual=first_hidden_states_residual,
                        threshold=self.relative_l1_threshold,
                    )
                    should_calc = not skip
                self.step_count += 1

                if not should_calc:
                    # reuse the contribution of blocks 1..N from the last computed step
                    hidden_states += cached_residual
                else:
                    self.prev_first_hidden_states_residual = first_hidden_states_residual

                    first_hidden_states = hidden_states.clone()

                    # NOTE(review): block 0 runs above without a ControlNet residual and
                    # `i` restarts at 0 for blocks[1:], so the block -> ControlNet output
                    # mapping is shifted by one relative to iterating all blocks; confirm
                    # against FluxDiT.forward whether that is intended.
                    if len(controlnet_double_block_output) > 0:
                        double_interval = int(np.ceil(len(self.blocks) / len(controlnet_double_block_output)))
                    for i, block in enumerate(self.blocks[1:]):
                        hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
                        if len(controlnet_double_block_output) > 0:
                            hidden_states = hidden_states + controlnet_double_block_output[i // double_interval]

                    hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)

                    # the interval must be derived from the *single* block outputs;
                    # dividing by the double-block count mis-indexes (or divides by zero)
                    if len(controlnet_single_block_output) > 0:
                        single_interval = int(np.ceil(len(self.single_blocks) / len(controlnet_single_block_output)))
                    for i, block in enumerate(self.single_blocks):
                        hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
                        if len(controlnet_single_block_output) > 0:
                            hidden_states = hidden_states + controlnet_single_block_output[i // single_interval]

                    # drop the prompt tokens that were concatenated in front
                    hidden_states = hidden_states[:, prompt_emb.shape[1] :]

                    # what blocks 1..N added on top of the first block's output
                    self.previous_residual = hidden_states - first_hidden_states

                hidden_states = self.final_norm_out(hidden_states, conditioning)
                hidden_states = self.final_proj_out(hidden_states)
                (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))

            hidden_states = self.unpatchify(hidden_states, h, w)
            (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)

        return hidden_states

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        in_channel: int = 64,
        attn_impl: Optional[str] = None,
        fb_cache_relative_l1_threshold: float = 0.05,
    ):
        """Instantiate the model without weight init and load ``state_dict``.

        ``fb_cache_relative_l1_threshold`` is passed to the constructor as
        ``relative_l1_threshold``.
        """
        with no_init_weights():
            model = torch.nn.utils.skip_init(
                cls,
                device=device,
                dtype=dtype,
                in_channel=in_channel,
                attn_impl=attn_impl,
                # __init__ names this parameter `relative_l1_threshold`; passing the
                # factory's own keyword through would raise TypeError
                relative_l1_threshold=fb_cache_relative_l1_threshold,
            )
            model = model.requires_grad_(False)  # for loading gguf
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model
@@ -12,18 +12,29 @@ from diffsynth_engine.models.basic.unet_helper import (
12
12
  DownSampler,
13
13
  )
14
14
 
15
+
15
16
  class ControlNetConditioningLayer(nn.Module):
16
- def __init__(self, channels = (3, 16, 32, 96, 256, 320), device = "cuda:0", dtype=torch.float16):
17
+ def __init__(self, channels=(3, 16, 32, 96, 256, 320), device="cuda:0", dtype=torch.float16):
17
18
  super().__init__()
18
19
  self.blocks = torch.nn.ModuleList([])
19
- self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, device=device, dtype=dtype))
20
+ self.blocks.append(
21
+ torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, device=device, dtype=dtype)
22
+ )
20
23
  self.blocks.append(torch.nn.SiLU())
21
24
  for i in range(1, len(channels) - 2):
22
- self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1, device=device, dtype=dtype))
25
+ self.blocks.append(
26
+ torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1, device=device, dtype=dtype)
27
+ )
23
28
  self.blocks.append(torch.nn.SiLU())
24
- self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2, device=device, dtype=dtype))
29
+ self.blocks.append(
30
+ torch.nn.Conv2d(
31
+ channels[i], channels[i + 1], kernel_size=3, padding=1, stride=2, device=device, dtype=dtype
32
+ )
33
+ )
25
34
  self.blocks.append(torch.nn.SiLU())
26
- self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1, device=device, dtype=dtype))
35
+ self.blocks.append(
36
+ torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1, device=device, dtype=dtype)
37
+ )
27
38
 
28
39
  def forward(self, conditioning):
29
40
  for block in self.blocks:
@@ -38,15 +49,73 @@ class SDControlNetStateDictConverter(StateDictConverter):
38
49
  def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
39
50
  # architecture
40
51
  block_types = [
41
- 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
42
- 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
43
- 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
44
- 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
45
- 'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
46
- 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
47
- 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
48
- 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
49
- 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
52
+ "ResnetBlock",
53
+ "AttentionBlock",
54
+ "PushBlock",
55
+ "ResnetBlock",
56
+ "AttentionBlock",
57
+ "PushBlock",
58
+ "DownSampler",
59
+ "PushBlock",
60
+ "ResnetBlock",
61
+ "AttentionBlock",
62
+ "PushBlock",
63
+ "ResnetBlock",
64
+ "AttentionBlock",
65
+ "PushBlock",
66
+ "DownSampler",
67
+ "PushBlock",
68
+ "ResnetBlock",
69
+ "AttentionBlock",
70
+ "PushBlock",
71
+ "ResnetBlock",
72
+ "AttentionBlock",
73
+ "PushBlock",
74
+ "DownSampler",
75
+ "PushBlock",
76
+ "ResnetBlock",
77
+ "PushBlock",
78
+ "ResnetBlock",
79
+ "PushBlock",
80
+ "ResnetBlock",
81
+ "AttentionBlock",
82
+ "ResnetBlock",
83
+ "PopBlock",
84
+ "ResnetBlock",
85
+ "PopBlock",
86
+ "ResnetBlock",
87
+ "PopBlock",
88
+ "ResnetBlock",
89
+ "UpSampler",
90
+ "PopBlock",
91
+ "ResnetBlock",
92
+ "AttentionBlock",
93
+ "PopBlock",
94
+ "ResnetBlock",
95
+ "AttentionBlock",
96
+ "PopBlock",
97
+ "ResnetBlock",
98
+ "AttentionBlock",
99
+ "UpSampler",
100
+ "PopBlock",
101
+ "ResnetBlock",
102
+ "AttentionBlock",
103
+ "PopBlock",
104
+ "ResnetBlock",
105
+ "AttentionBlock",
106
+ "PopBlock",
107
+ "ResnetBlock",
108
+ "AttentionBlock",
109
+ "UpSampler",
110
+ "PopBlock",
111
+ "ResnetBlock",
112
+ "AttentionBlock",
113
+ "PopBlock",
114
+ "ResnetBlock",
115
+ "AttentionBlock",
116
+ "PopBlock",
117
+ "ResnetBlock",
118
+ "AttentionBlock",
50
119
  ]
51
120
 
52
121
  # controlnet_rename_dict
@@ -66,7 +135,7 @@ class SDControlNetStateDictConverter(StateDictConverter):
66
135
  "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
67
136
  "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
68
137
  "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
69
- "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
138
+ "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
70
139
  }
71
140
 
72
141
  # Rename each parameter
@@ -91,7 +160,12 @@ class SDControlNetStateDictConverter(StateDictConverter):
91
160
  elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
92
161
  if names[0] == "mid_block":
93
162
  names.insert(1, "0")
94
- block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
163
+ block_type = {
164
+ "resnets": "ResnetBlock",
165
+ "attentions": "AttentionBlock",
166
+ "downsamplers": "DownSampler",
167
+ "upsamplers": "UpSampler",
168
+ }[names[2]]
95
169
  block_type_with_id = ".".join(names[:4])
96
170
  if block_type_with_id != last_block_type_with_id[block_type]:
97
171
  block_id[block_type] += 1
@@ -102,9 +176,9 @@ class SDControlNetStateDictConverter(StateDictConverter):
102
176
  names = ["blocks", str(block_id[block_type])] + names[4:]
103
177
  if "ff" in names:
104
178
  ff_index = names.index("ff")
105
- component = ".".join(names[ff_index:ff_index+3])
179
+ component = ".".join(names[ff_index : ff_index + 3])
106
180
  component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
107
- names = names[:ff_index] + [component] + names[ff_index+3:]
181
+ names = names[:ff_index] + [component] + names[ff_index + 3 :]
108
182
  if "to_out" in names:
109
183
  names.pop(names.index("to_out") + 1)
110
184
  else:
@@ -117,13 +191,21 @@ class SDControlNetStateDictConverter(StateDictConverter):
117
191
  if ".proj_in." in name or ".proj_out." in name:
118
192
  param = param.squeeze()
119
193
  if rename_dict[name] in [
120
- "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
121
- "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
194
+ "controlnet_blocks.1.bias",
195
+ "controlnet_blocks.2.bias",
196
+ "controlnet_blocks.3.bias",
197
+ "controlnet_blocks.5.bias",
198
+ "controlnet_blocks.6.bias",
199
+ "controlnet_blocks.8.bias",
200
+ "controlnet_blocks.9.bias",
201
+ "controlnet_blocks.10.bias",
202
+ "controlnet_blocks.11.bias",
203
+ "controlnet_blocks.12.bias",
122
204
  ]:
123
205
  continue
124
206
  state_dict_[rename_dict[name]] = param
125
207
  return state_dict_
126
-
208
+
127
209
  def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
128
210
  rename_dict = {
129
211
  "control_model.time_embed.0.weight": "time_embedding.timestep_embedder.0.weight",
@@ -496,69 +578,71 @@ class SDControlNet(PreTrainedModel):
496
578
  self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)
497
579
  self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)
498
580
 
499
- self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype)
581
+ self.controlnet_conv_in = ControlNetConditioningLayer(
582
+ channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype
583
+ )
500
584
 
501
- self.blocks = torch.nn.ModuleList([
502
- # CrossAttnDownBlock2D
503
- ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
504
- AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
505
- PushBlock(),
506
- ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
507
- AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
508
- PushBlock(),
509
- DownSampler(320, device=device, dtype=dtype),
510
- PushBlock(),
511
- # CrossAttnDownBlock2D
512
- ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
513
- AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
514
- PushBlock(),
515
- ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
516
- AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
517
- PushBlock(),
518
- DownSampler(640, device=device, dtype=dtype),
519
- PushBlock(),
520
- # CrossAttnDownBlock2D
521
- ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
522
- AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
523
- PushBlock(),
524
- ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
525
- AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
526
- PushBlock(),
527
- DownSampler(1280, device=device, dtype=dtype),
528
- PushBlock(),
529
- # DownBlock2D
530
- ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
531
- PushBlock(),
532
- ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
533
- PushBlock(),
534
- # UNetMidBlock2DCrossAttn
535
- ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
536
- AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
537
- ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
538
- PushBlock()
539
- ])
585
+ self.blocks = torch.nn.ModuleList(
586
+ [
587
+ # CrossAttnDownBlock2D
588
+ ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
589
+ AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
590
+ PushBlock(),
591
+ ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
592
+ AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
593
+ PushBlock(),
594
+ DownSampler(320, device=device, dtype=dtype),
595
+ PushBlock(),
596
+ # CrossAttnDownBlock2D
597
+ ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
598
+ AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
599
+ PushBlock(),
600
+ ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
601
+ AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
602
+ PushBlock(),
603
+ DownSampler(640, device=device, dtype=dtype),
604
+ PushBlock(),
605
+ # CrossAttnDownBlock2D
606
+ ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
607
+ AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
608
+ PushBlock(),
609
+ ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
610
+ AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
611
+ PushBlock(),
612
+ DownSampler(1280, device=device, dtype=dtype),
613
+ PushBlock(),
614
+ # DownBlock2D
615
+ ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
616
+ PushBlock(),
617
+ ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
618
+ PushBlock(),
619
+ # UNetMidBlock2DCrossAttn
620
+ ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
621
+ AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
622
+ ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
623
+ PushBlock(),
624
+ ]
625
+ )
540
626
 
541
- self.controlnet_blocks = torch.nn.ModuleList([
542
- torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
543
- torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
544
- torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
545
- torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
546
- torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
547
- torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
548
- torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
549
- torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
550
- torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
551
- torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
552
- torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
553
- torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
554
- torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
555
- ])
627
+ self.controlnet_blocks = torch.nn.ModuleList(
628
+ [
629
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
630
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
631
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
632
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
633
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
634
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
635
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
636
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
637
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
638
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
639
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
640
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
641
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
642
+ ]
643
+ )
556
644
 
557
- def forward(
558
- self,
559
- sample, timestep, encoder_hidden_states, conditioning,
560
- **kwargs
561
- ):
645
+ def forward(self, sample, timestep, encoder_hidden_states, conditioning, **kwargs):
562
646
  # 1. time
563
647
  time_emb = self.time_embedding(timestep, dtype=sample.dtype)
564
648
 
@@ -585,9 +669,7 @@ class SDControlNet(PreTrainedModel):
585
669
  attn_impl: Optional[str] = None,
586
670
  ):
587
671
  with no_init_weights():
588
- model = torch.nn.utils.skip_init(
589
- cls, attn_impl=attn_impl, device=device, dtype=dtype
590
- )
672
+ model = torch.nn.utils.skip_init(cls, attn_impl=attn_impl, device=device, dtype=dtype)
591
673
  model.load_state_dict(state_dict)
592
674
  model.to(device=device, dtype=dtype, non_blocking=True)
593
- return model
675
+ return model
@@ -9,7 +9,7 @@ __all__ = [
9
9
  "SDXLUNet",
10
10
  "SDXLVAEDecoder",
11
11
  "SDXLVAEEncoder",
12
- "SDXLControlNetUnion",
12
+ "SDXLControlNetUnion",
13
13
  "sdxl_text_encoder_config",
14
14
  "sdxl_unet_config",
15
15
  ]