diffsynth-engine 0.3.6.dev6__tar.gz → 0.3.6.dev8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161) hide show
  1. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/PKG-INFO +1 -1
  2. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/basic/lora.py +39 -19
  3. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/basic/transformer_helper.py +0 -11
  4. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/flux/flux_dit.py +1 -1
  5. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/flux/flux_vae.py +4 -2
  6. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/text_encoder/t5.py +2 -4
  7. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/pipelines/base.py +25 -5
  8. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/pipelines/flux_image.py +22 -12
  9. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/fp8_linear.py +39 -0
  10. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/offload.py +1 -1
  11. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine.egg-info/PKG-INFO +1 -1
  12. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/.gitignore +0 -0
  13. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/.pre-commit-config.yaml +0 -0
  14. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/LICENSE +0 -0
  15. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/MANIFEST.in +0 -0
  16. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/README.md +0 -0
  17. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/assets/dingtalk.png +0 -0
  18. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/assets/showcase.jpeg +0 -0
  19. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/__init__.py +0 -0
  20. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/__init__.py +0 -0
  21. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/__init__.py +0 -0
  22. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +0 -0
  23. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +0 -0
  24. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +0 -0
  25. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +0 -0
  26. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +0 -0
  27. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  28. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +0 -0
  29. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +0 -0
  30. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +0 -0
  31. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +0 -0
  32. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +0 -0
  33. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +0 -0
  34. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/__init__.py +0 -0
  35. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  36. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +0 -0
  37. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  38. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +0 -0
  39. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +0 -0
  40. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +0 -0
  41. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +0 -0
  42. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +0 -0
  43. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +0 -0
  44. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +0 -0
  45. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +0 -0
  46. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +0 -0
  47. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/components/vae.json +0 -0
  48. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/flux/flux_dit.json +0 -0
  49. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/flux/flux_text_encoder.json +0 -0
  50. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/flux/flux_vae.json +0 -0
  51. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/sd/sd_text_encoder.json +0 -0
  52. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/sd/sd_unet.json +0 -0
  53. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/sd3/sd3_dit.json +0 -0
  54. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +0 -0
  55. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +0 -0
  56. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/sdxl/sdxl_unet.json +0 -0
  57. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +0 -0
  58. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/wan/dit/14b-flf2v.json +0 -0
  59. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/wan/dit/14b-i2v.json +0 -0
  60. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/models/wan/dit/14b-t2v.json +0 -0
  61. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +0 -0
  62. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +0 -0
  63. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +0 -0
  64. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +0 -0
  65. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +0 -0
  66. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  67. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +0 -0
  68. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +0 -0
  69. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +0 -0
  70. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +0 -0
  71. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +0 -0
  72. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +0 -0
  73. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +0 -0
  74. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +0 -0
  75. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +0 -0
  76. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +0 -0
  77. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +0 -0
  78. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  79. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +0 -0
  80. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +0 -0
  81. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/kernels/__init__.py +0 -0
  82. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/__init__.py +0 -0
  83. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/base.py +0 -0
  84. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/basic/__init__.py +0 -0
  85. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/basic/attention.py +0 -0
  86. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/basic/relative_position_emb.py +0 -0
  87. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/basic/timestep.py +0 -0
  88. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/basic/unet_helper.py +0 -0
  89. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/flux/__init__.py +0 -0
  90. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/flux/flux_controlnet.py +0 -0
  91. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/flux/flux_dit_fbcache.py +0 -0
  92. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/flux/flux_ipadapter.py +0 -0
  93. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/flux/flux_redux.py +0 -0
  94. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/flux/flux_text_encoder.py +0 -0
  95. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd/__init__.py +0 -0
  96. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd/sd_controlnet.py +0 -0
  97. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd/sd_text_encoder.py +0 -0
  98. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd/sd_unet.py +0 -0
  99. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd/sd_vae.py +0 -0
  100. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd3/__init__.py +0 -0
  101. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd3/sd3_dit.py +0 -0
  102. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd3/sd3_text_encoder.py +0 -0
  103. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sd3/sd3_vae.py +0 -0
  104. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sdxl/__init__.py +0 -0
  105. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sdxl/sdxl_controlnet.py +0 -0
  106. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sdxl/sdxl_text_encoder.py +0 -0
  107. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sdxl/sdxl_unet.py +0 -0
  108. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/sdxl/sdxl_vae.py +0 -0
  109. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/text_encoder/__init__.py +0 -0
  110. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/text_encoder/clip.py +0 -0
  111. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/text_encoder/siglip.py +0 -0
  112. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/utils.py +0 -0
  113. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/vae/__init__.py +0 -0
  114. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/vae/vae.py +0 -0
  115. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/wan/__init__.py +0 -0
  116. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/wan/wan_dit.py +0 -0
  117. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/wan/wan_image_encoder.py +0 -0
  118. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/wan/wan_text_encoder.py +0 -0
  119. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/models/wan/wan_vae.py +0 -0
  120. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/pipelines/__init__.py +0 -0
  121. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/pipelines/controlnet_helper.py +0 -0
  122. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/pipelines/sd_image.py +0 -0
  123. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/pipelines/sdxl_image.py +0 -0
  124. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/pipelines/wan_video.py +0 -0
  125. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/processor/__init__.py +0 -0
  126. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/processor/canny_processor.py +0 -0
  127. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/processor/depth_processor.py +0 -0
  128. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tokenizers/__init__.py +0 -0
  129. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tokenizers/base.py +0 -0
  130. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tokenizers/clip.py +0 -0
  131. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tokenizers/t5.py +0 -0
  132. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tokenizers/wan.py +0 -0
  133. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tools/__init__.py +0 -0
  134. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tools/flux_inpainting_tool.py +0 -0
  135. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tools/flux_outpainting_tool.py +0 -0
  136. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tools/flux_reference_tool.py +0 -0
  137. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/tools/flux_replace_tool.py +0 -0
  138. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/__init__.py +0 -0
  139. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/constants.py +0 -0
  140. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/download.py +0 -0
  141. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/env.py +0 -0
  142. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/flag.py +0 -0
  143. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/gguf.py +0 -0
  144. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/image.py +0 -0
  145. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/loader.py +0 -0
  146. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/lock.py +0 -0
  147. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/logging.py +0 -0
  148. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/onnx.py +0 -0
  149. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/parallel.py +0 -0
  150. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/platform.py +0 -0
  151. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/prompt.py +0 -0
  152. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine/utils/video.py +0 -0
  153. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine.egg-info/SOURCES.txt +0 -0
  154. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine.egg-info/dependency_links.txt +0 -0
  155. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine.egg-info/requires.txt +0 -0
  156. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/diffsynth_engine.egg-info/top_level.txt +0 -0
  157. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/docs/tutorial.md +0 -0
  158. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/docs/tutorial_zh.md +0 -0
  159. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/pyproject.toml +0 -0
  160. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/setup.cfg +0 -0
  161. {diffsynth_engine-0.3.6.dev6 → diffsynth_engine-0.3.6.dev8}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.3.6.dev6
3
+ Version: 0.3.6.dev8
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -37,14 +37,23 @@ class LoRA(nn.Module):
37
37
  else:
38
38
  delta_w = self.scale * (self.alpha / self.rank) * (self.up.weight @ self.down.weight)
39
39
  if isinstance(w, (nn.Linear, nn.Conv2d)):
40
- delta_w = delta_w.to(device=w.weight.data.device, dtype=w.weight.data.dtype)
40
+ delta_w = delta_w.to(device=w.weight.data.device, dtype=self.dtype)
41
+ w_dtype = w.weight.data.dtype
42
+ w.weight.data = w.weight.data.to(self.dtype)
41
43
  w.weight.data.add_(delta_w)
44
+ w.weight.data = w.weight.data.to(w_dtype)
42
45
  elif isinstance(w, nn.Parameter):
43
- delta_w = delta_w.to(device=w.data.device, dtype=w.data.dtype)
46
+ delta_w = delta_w.to(device=w.data.device, dtype=self.dtype)
47
+ w_dtype = w.data.dtype
48
+ w.data = w.data.to(self.dtype)
44
49
  w.data.add_(delta_w)
50
+ w.data = w.data.to(w_dtype)
45
51
  elif isinstance(w, torch.Tensor):
46
- delta_w = delta_w.to(device=w.device, dtype=w.dtype)
52
+ delta_w = delta_w.to(device=w.device, dtype=self.dtype)
53
+ w_dtype = w.dtype
54
+ w = w.to(self.dtype)
47
55
  w.add_(delta_w)
56
+ w = w.to(w_dtype)
48
57
 
49
58
 
50
59
  class LoRALinear(nn.Linear):
@@ -60,8 +69,8 @@ class LoRALinear(nn.Linear):
60
69
  # LoRA
61
70
  self._lora_dict = OrderedDict()
62
71
  # Frozen LoRA
63
- self._frozen_lora_list = []
64
- self.register_buffer("_original_weight", None)
72
+ self.patched_frozen_lora = False
73
+ self._original_weight = None
65
74
 
66
75
  @staticmethod
67
76
  def from_linear(linear: nn.Linear):
@@ -118,20 +127,27 @@ class LoRALinear(nn.Linear):
118
127
  save_original_weight: bool = True,
119
128
  ):
120
129
  if save_original_weight and self._original_weight is None:
121
- self._original_weight = self.weight.clone()
130
+ if self.weight.dtype == torch.float8_e4m3fn:
131
+ self._original_weight = self.weight.to(dtype=torch.bfloat16, device="cpu", copy=True).pin_memory()
132
+ else:
133
+ self._original_weight = self.weight.to(device="cpu", copy=True).pin_memory()
122
134
  lora = LoRA(scale, rank, alpha, up, down, device, dtype)
123
135
  lora.apply_to(self)
124
- self._frozen_lora_list.append(lora)
136
+ self.patched_frozen_lora = True
125
137
 
126
- def clear(self):
127
- if self._original_weight is None and len(self._frozen_lora_list) > 0:
138
+ def clear(self, release_all_cpu_memory: bool = False):
139
+ if self.patched_frozen_lora and self._original_weight is None:
128
140
  raise RuntimeError(
129
141
  "Current LoRALinear has patched by frozen LoRA, but original weight is not saved, so you cannot clear LoRA."
130
142
  )
131
143
  self._lora_dict.clear()
132
- self._frozen_lora_list = []
133
144
  if self._original_weight is not None:
134
- self.weight.data.copy_(self._original_weight)
145
+ self.weight.data.copy_(
146
+ self._original_weight.to(device=self.weight.data.device, dtype=self.weight.data.dtype)
147
+ )
148
+ if release_all_cpu_memory:
149
+ del self._original_weight
150
+ self.patched_frozen_lora = False
135
151
 
136
152
  def forward(self, x):
137
153
  w_x = super().forward(x)
@@ -161,8 +177,8 @@ class LoRAConv2d(nn.Conv2d):
161
177
  # LoRA
162
178
  self._lora_dict = OrderedDict()
163
179
  # Frozen LoRA
164
- self._frozen_lora_list = []
165
180
  self._original_weight = None
181
+ self.patched_frozen_lora = False
166
182
 
167
183
  @staticmethod
168
184
  def from_conv2d(conv2d: nn.Conv2d):
@@ -257,21 +273,25 @@ class LoRAConv2d(nn.Conv2d):
257
273
  save_original_weight: bool = True,
258
274
  ):
259
275
  if save_original_weight and self._original_weight is None:
260
- self._original_weight = self.weight.clone()
276
+ if self.weight.dtype == torch.float8_e4m3fn:
277
+ self._original_weight = self.weight.to(dtype=torch.bfloat16, device="cpu", copy=True).pin_memory()
278
+ else:
279
+ self._original_weight = self.weight.to(device="cpu", copy=True).pin_memory()
261
280
  lora = self._construct_lora(name, scale, rank, alpha, up, down, device, dtype)
262
281
  lora.apply_to(self)
263
- self._frozen_lora_list.append(lora)
282
+ self.patched_frozen_lora = True
264
283
 
265
- def clear(self):
266
- if self._original_weight is None and len(self._frozen_lora_list) > 0:
284
+ def clear(self, release_all_cpu_memory: bool = False):
285
+ if self.patched_frozen_lora and self._original_weight is None:
267
286
  raise RuntimeError(
268
287
  "Current LoRALinear has patched by frozen LoRA, but original weight is not saved, so you cannot clear LoRA."
269
288
  )
270
289
  self._lora_dict.clear()
271
- self._frozen_lora_list = []
272
290
  if self._original_weight is not None:
273
- self.weight.copy_(self._original_weight)
274
- self._original_weight = None
291
+ self.weight.copy_(self._original_weight.to(device=self.weight.device, dtype=self.weight.dtype))
292
+ if release_all_cpu_memory:
293
+ del self._original_weight
294
+ self.patched_frozen_lora = False
275
295
 
276
296
  def forward(self, x):
277
297
  w_x = super().forward(x)
@@ -1,6 +1,5 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
- import math
4
3
 
5
4
 
6
5
  def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
@@ -83,13 +82,3 @@ class RMSNorm(nn.Module):
83
82
  if self.elementwise_affine:
84
83
  return norm_result * self.weight
85
84
  return norm_result
86
-
87
-
88
- class NewGELUActivation(nn.Module):
89
- """
90
- Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
91
- the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
92
- """
93
-
94
- def forward(self, input: "torch.Tensor") -> "torch.Tensor":
95
- return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
@@ -435,7 +435,7 @@ class FluxDiT(PreTrainedModel):
435
435
  # addition of floating point numbers does not meet commutative law
436
436
  conditioning = self.time_embedder(timestep, hidden_states.dtype)
437
437
  if self.guidance_embedder is not None:
438
- guidance = guidance * 1000
438
+ guidance = (guidance.to(torch.float32) * 1000).to(hidden_states.dtype)
439
439
  conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
440
440
  conditioning += self.pooled_text_embedder(pooled_prompt_emb)
441
441
  rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
@@ -53,7 +53,8 @@ class FluxVAEEncoder(VAEEncoder):
53
53
  def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
54
54
  with no_init_weights():
55
55
  model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
56
- model.load_state_dict(state_dict)
56
+ model.load_state_dict(state_dict, assign=True)
57
+ model.to(device=device, dtype=dtype, non_blocking=True)
57
58
  return model
58
59
 
59
60
 
@@ -74,5 +75,6 @@ class FluxVAEDecoder(VAEDecoder):
74
75
  def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
75
76
  with no_init_weights():
76
77
  model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
77
- model.load_state_dict(state_dict)
78
+ model.load_state_dict(state_dict, assign=True)
79
+ model.to(device=device, dtype=dtype, non_blocking=True)
78
80
  return model
@@ -4,7 +4,7 @@ from typing import Dict, Optional
4
4
 
5
5
  from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
6
6
  from diffsynth_engine.models.basic.relative_position_emb import RelativePositionEmbedding
7
- from diffsynth_engine.models.basic.transformer_helper import RMSNorm, NewGELUActivation
7
+ from diffsynth_engine.models.basic.transformer_helper import RMSNorm
8
8
  from diffsynth_engine.models.basic.attention import Attention
9
9
  from diffsynth_engine.models.utils import no_init_weights
10
10
  from diffsynth_engine.utils.gguf import gguf_inference
@@ -21,14 +21,12 @@ class T5FeedForward(nn.Module):
21
21
  self.wi_1 = nn.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
22
22
  self.wo = nn.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
23
23
  self.dropout = nn.Dropout(dropout_rate)
24
- self.act = NewGELUActivation()
24
+ self.act = nn.GELU(approximate="tanh")
25
25
 
26
26
  def forward(self, hidden_states):
27
27
  hidden_gelu = self.act(self.wi_0(hidden_states))
28
28
  hidden_linear = self.wi_1(hidden_states)
29
29
  hidden_states = self.dropout(hidden_gelu * hidden_linear)
30
-
31
- hidden_states = hidden_states.to(self.wo.weight.dtype)
32
30
  hidden_states = self.wo(hidden_states)
33
31
  return hidden_states
34
32
 
@@ -5,6 +5,7 @@ from typing import Dict, List, Tuple
5
5
  from PIL import Image
6
6
  from dataclasses import dataclass
7
7
  from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
8
+ from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
8
9
  from diffsynth_engine.utils.gguf import load_gguf_checkpoint
9
10
  from diffsynth_engine.utils import logging
10
11
  from diffsynth_engine.utils.loader import load_file
@@ -42,6 +43,7 @@ class BasePipeline:
42
43
  self.dtype = dtype
43
44
  self.offload_mode = None
44
45
  self.model_names = []
46
+ self._models_offload_params = {}
45
47
 
46
48
  @classmethod
47
49
  def from_pretrained(
@@ -100,7 +102,10 @@ class BasePipeline:
100
102
  if not os.path.isfile(path):
101
103
  raise FileNotFoundError(f"{path} is not a file")
102
104
  elif path.endswith(".safetensors"):
103
- state_dict.update(**load_file(path, device=device))
105
+ state_dict_ = load_file(path, device=device)
106
+ for key, value in state_dict_.items():
107
+ state_dict[key] = value.to(dtype)
108
+
104
109
  elif path.endswith(".gguf"):
105
110
  state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype))
106
111
  else:
@@ -154,7 +159,7 @@ class BasePipeline:
154
159
  @staticmethod
155
160
  def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
156
161
  generator = None if seed is None else torch.Generator(device).manual_seed(seed)
157
- noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
162
+ noise = torch.randn(shape, generator=generator, device=device).to(dtype)
158
163
  return noise
159
164
 
160
165
  def encode_image(
@@ -284,6 +289,10 @@ class BasePipeline:
284
289
  model = getattr(self, model_name)
285
290
  if model is not None:
286
291
  model.to("cpu")
292
+ self._models_offload_params[model_name] = {}
293
+ for name, param in model.named_parameters(recurse=True):
294
+ param.data = param.data.pin_memory()
295
+ self._models_offload_params[model_name][name] = param.data
287
296
  self.offload_mode = "cpu_offload"
288
297
 
289
298
  def _enable_sequential_cpu_offload(self):
@@ -294,6 +303,15 @@ class BasePipeline:
294
303
  enable_sequential_cpu_offload(model, self.device)
295
304
  self.offload_mode = "sequential_cpu_offload"
296
305
 
306
+ def enable_fp8_autocast(
307
+ self, model_names: List[str], compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False
308
+ ):
309
+ for model_name in model_names:
310
+ model = getattr(self, model_name)
311
+ if model is not None:
312
+ enable_fp8_autocast(model, compute_dtype, use_fp8_linear)
313
+ self.fp8_autocast_enabled = True
314
+
297
315
  def load_models_to_device(self, load_model_names: List[str] | None = None):
298
316
  load_model_names = load_model_names if load_model_names else []
299
317
  # only load models to device if offload_mode is set
@@ -308,12 +326,14 @@ class BasePipeline:
308
326
  for model_name in self.model_names:
309
327
  if model_name not in load_model_names:
310
328
  model = getattr(self, model_name)
311
- if model is not None and (p := next(model.parameters(), None)) is not None and p.device != "cpu":
312
- model.to("cpu")
329
+ if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device("cpu"):
330
+ param_cache = self._models_offload_params[model_name]
331
+ for name, param in model.named_parameters(recurse=True):
332
+ param.data = param_cache[name]
313
333
  # load the needed models to device
314
334
  for model_name in load_model_names:
315
335
  model = getattr(self, model_name)
316
- if model is not None and (p := next(model.parameters(), None)) is not None and p.device != self.device:
336
+ if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device(self.device):
317
337
  model.to(self.device)
318
338
  # fresh the cuda cache
319
339
  empty_cache()
@@ -526,29 +526,29 @@ class FluxImagePipeline(BasePipeline):
526
526
  model_config = (
527
527
  model_path_or_config
528
528
  if isinstance(model_path_or_config, FluxModelConfig)
529
- else FluxModelConfig(dit_path=model_path_or_config, dit_dtype=dtype, t5_dtype=dtype, clip_dtype=dtype)
529
+ else FluxModelConfig(dit_path=model_path_or_config, dit_dtype=dtype, t5_dtype=dtype)
530
530
  )
531
531
  if model_config.vae_path is None:
532
- model_config.vae_path = fetch_model("muse/flux_vae", revision="20241015120836", path="ae.safetensors")
532
+ model_config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
533
533
 
534
534
  if model_config.clip_path is None and load_text_encoder:
535
- model_config.clip_path = fetch_model(
536
- "muse/flux_clip_l", revision="20241209", path="clip_l_bf16.safetensors"
537
- )
535
+ model_config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
538
536
  if model_config.t5_path is None and load_text_encoder:
539
537
  model_config.t5_path = fetch_model(
540
- "muse/google_t5_v1_1_xxl", revision="20241024105236", path="t5xxl_v1_1_bf16.safetensors"
538
+ "muse/FLUX.1-dev-fp8", path=["t5-fp8-00001-of-00002.safetensors", "t5-fp8-00002-of-00002.safetensors"]
541
539
  )
542
540
 
543
541
  logger.info(f"loading state dict from {model_config.dit_path} ...")
544
- dit_state_dict = cls.load_model_checkpoint(model_config.dit_path, device="cpu", dtype=dtype)
542
+ dit_state_dict = cls.load_model_checkpoint(model_config.dit_path, device="cpu", dtype=model_config.dit_dtype)
545
543
  logger.info(f"loading state dict from {model_config.vae_path} ...")
546
- vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=dtype)
544
+ vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=model_config.vae_dtype)
547
545
  if load_text_encoder:
548
546
  logger.info(f"loading state dict from {model_config.clip_path} ...")
549
- clip_state_dict = cls.load_model_checkpoint(model_config.clip_path, device="cpu", dtype=dtype)
547
+ clip_state_dict = cls.load_model_checkpoint(
548
+ model_config.clip_path, device="cpu", dtype=model_config.clip_dtype
549
+ )
550
550
  logger.info(f"loading state dict from {model_config.t5_path} ...")
551
- t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=dtype)
551
+ t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=model_config.t5_dtype)
552
552
 
553
553
  init_device = "cpu" if parallelism > 1 or offload_mode is not None else device
554
554
  if load_text_encoder:
@@ -602,10 +602,21 @@ class FluxImagePipeline(BasePipeline):
602
602
  vae_tile_stride=vae_tile_stride,
603
603
  control_type=control_type,
604
604
  device=device,
605
- dtype=dtype,
605
+ dtype=model_config.dit_dtype,
606
606
  )
607
607
  if offload_mode is not None:
608
608
  pipe.enable_cpu_offload(offload_mode)
609
+ if model_config.dit_dtype == torch.float8_e4m3fn:
610
+ pipe.dtype = torch.bfloat16 # running dtype
611
+ pipe.enable_fp8_autocast(
612
+ model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=model_config.use_fp8_linear
613
+ )
614
+
615
+ if model_config.t5_dtype == torch.float8_e4m3fn:
616
+ pipe.dtype = torch.bfloat16 # running dtype
617
+ pipe.enable_fp8_autocast(
618
+ model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=model_config.use_fp8_linear
619
+ )
609
620
 
610
621
  if parallelism > 1:
611
622
  parallel_config = cls.init_parallel_config(parallelism, use_cfg_parallel, model_config)
@@ -803,7 +814,6 @@ class FluxImagePipeline(BasePipeline):
803
814
  current_step=current_step,
804
815
  total_step=total_step,
805
816
  )
806
-
807
817
  self.load_models_to_device(["dit"])
808
818
 
809
819
  noise_pred = self.dit(
@@ -4,6 +4,45 @@ import torch.nn.functional as F
4
4
  from contextlib import contextmanager
5
5
 
6
6
 
7
def enable_fp8_autocast(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False):
    """Recursively register fp8 autocast hooks on every parameter-owning module.

    Leaf modules with parameters get a hook; non-leaf modules get a hook only for
    parameters they hold directly, then recursion continues into their children.
    When ``use_fp8_linear`` is True, ``nn.Linear`` children are skipped entirely
    because the dedicated fp8 linear path handles their casting itself.
    """
    children = list(module.children())
    if not children:
        # Leaf module: hook it only if it actually owns parameters.
        if next(module.parameters(), None) is not None:
            add_fp8_autocast_hook(module, compute_dtype)
        return
    # Non-leaf module: hook parameters held directly on this module (not in children).
    if next(module.parameters(recurse=False), None) is not None:
        add_fp8_autocast_hook(module, compute_dtype)
    for child in children:
        if use_fp8_linear and isinstance(child, nn.Linear):
            continue
        enable_fp8_autocast(child, compute_dtype, use_fp8_linear)
19
+
20
+
21
+ def add_fp8_autocast_hook(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16):
22
+ def _fp8_autocast_pre_hook(module: nn.Module, input_):
23
+ for name, param in module.named_parameters():
24
+ if param.dtype == torch.float8_e4m3fn:
25
+ param.data = param.data.to(compute_dtype)
26
+ new_inputs = []
27
+ for x in input_:
28
+ if isinstance(x, torch.Tensor) and x.dtype in [torch.float8_e4m3fn, torch.float16, torch.bfloat16]:
29
+ new_inputs.append(x.to(compute_dtype))
30
+ else:
31
+ new_inputs.append(x)
32
+ return tuple(new_inputs)
33
+
34
+ def _fp8_autocast_hook(module: nn.Module, input_, output_):
35
+ for name, param in module.named_parameters():
36
+ if param.dtype == compute_dtype:
37
+ param.data = param.data.to(torch.float8_e4m3fn)
38
+
39
+ if getattr(module, "_fp8_autocast_enabled", False):
40
+ return
41
+ module.register_forward_pre_hook(_fp8_autocast_pre_hook)
42
+ module.register_forward_hook(_fp8_autocast_hook)
43
+ setattr(module, "_fp8_autocast_enabled", True)
44
+
45
+
7
46
  def enable_fp8_linear(module: nn.Module):
8
47
  _enable_fp8_linear(module)
9
48
  setattr(module, "fp8_linear_enabled", True)
@@ -18,7 +18,7 @@ def add_cpu_offload_hook(module: nn.Module, device: str = "cuda", recurse: bool
18
18
  def _forward_pre_hook(module: nn.Module, input):
19
19
  offload_params = {}
20
20
  for name, param in module.named_parameters(recurse=recurse):
21
- offload_params[name] = param.data
21
+ offload_params[name] = param.data.pin_memory()
22
22
  param.data = param.data.to(device=device)
23
23
  setattr(module, "_offload_params", offload_params)
24
24
  return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.3.6.dev6
3
+ Version: 0.3.6.dev8
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent