diffusers-0.30.2-py3-none-any.whl → diffusers-0.31.0-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.30.2"
1
+ __version__ = "0.31.0"
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
@@ -31,6 +31,7 @@ _import_structure = {
31
31
  "loaders": ["FromOriginalModelMixin"],
32
32
  "models": [],
33
33
  "pipelines": [],
34
+ "quantizers.quantization_config": ["BitsAndBytesConfig"],
34
35
  "schedulers": [],
35
36
  "utils": [
36
37
  "OptionalDependencyNotAvailable",
@@ -84,10 +85,13 @@ else:
84
85
  "AutoencoderOobleck",
85
86
  "AutoencoderTiny",
86
87
  "CogVideoXTransformer3DModel",
88
+ "CogView3PlusTransformer2DModel",
87
89
  "ConsistencyDecoderVAE",
88
90
  "ControlNetModel",
89
91
  "ControlNetXSAdapter",
90
92
  "DiTTransformer2DModel",
93
+ "FluxControlNetModel",
94
+ "FluxMultiControlNetModel",
91
95
  "FluxTransformer2DModel",
92
96
  "HunyuanDiT2DControlNetModel",
93
97
  "HunyuanDiT2DModel",
@@ -121,7 +125,6 @@ else:
121
125
  "VQModel",
122
126
  ]
123
127
  )
124
-
125
128
  _import_structure["optimization"] = [
126
129
  "get_constant_schedule",
127
130
  "get_constant_schedule_with_warmup",
@@ -153,6 +156,7 @@ else:
153
156
  "StableDiffusionMixin",
154
157
  ]
155
158
  )
159
+ _import_structure["quantizers"] = ["DiffusersQuantizer"]
156
160
  _import_structure["schedulers"].extend(
157
161
  [
158
162
  "AmusedScheduler",
@@ -243,6 +247,7 @@ else:
243
247
  "AnimateDiffPipeline",
244
248
  "AnimateDiffSDXLPipeline",
245
249
  "AnimateDiffSparseControlNetPipeline",
250
+ "AnimateDiffVideoToVideoControlNetPipeline",
246
251
  "AnimateDiffVideoToVideoPipeline",
247
252
  "AudioLDM2Pipeline",
248
253
  "AudioLDM2ProjectionModel",
@@ -252,8 +257,17 @@ else:
252
257
  "BlipDiffusionControlNetPipeline",
253
258
  "BlipDiffusionPipeline",
254
259
  "CLIPImageProjection",
260
+ "CogVideoXFunControlPipeline",
261
+ "CogVideoXImageToVideoPipeline",
255
262
  "CogVideoXPipeline",
263
+ "CogVideoXVideoToVideoPipeline",
264
+ "CogView3PlusPipeline",
256
265
  "CycleDiffusionPipeline",
266
+ "FluxControlNetImg2ImgPipeline",
267
+ "FluxControlNetInpaintPipeline",
268
+ "FluxControlNetPipeline",
269
+ "FluxImg2ImgPipeline",
270
+ "FluxInpaintPipeline",
257
271
  "FluxPipeline",
258
272
  "HunyuanDiTControlNetPipeline",
259
273
  "HunyuanDiTPAGPipeline",
@@ -308,6 +322,7 @@ else:
308
322
  "StableCascadeCombinedPipeline",
309
323
  "StableCascadeDecoderPipeline",
310
324
  "StableCascadePriorPipeline",
325
+ "StableDiffusion3ControlNetInpaintingPipeline",
311
326
  "StableDiffusion3ControlNetPipeline",
312
327
  "StableDiffusion3Img2ImgPipeline",
313
328
  "StableDiffusion3InpaintPipeline",
@@ -317,6 +332,7 @@ else:
317
332
  "StableDiffusionAttendAndExcitePipeline",
318
333
  "StableDiffusionControlNetImg2ImgPipeline",
319
334
  "StableDiffusionControlNetInpaintPipeline",
335
+ "StableDiffusionControlNetPAGInpaintPipeline",
320
336
  "StableDiffusionControlNetPAGPipeline",
321
337
  "StableDiffusionControlNetPipeline",
322
338
  "StableDiffusionControlNetXSPipeline",
@@ -332,6 +348,7 @@ else:
332
348
  "StableDiffusionLatentUpscalePipeline",
333
349
  "StableDiffusionLDM3DPipeline",
334
350
  "StableDiffusionModelEditingPipeline",
351
+ "StableDiffusionPAGImg2ImgPipeline",
335
352
  "StableDiffusionPAGPipeline",
336
353
  "StableDiffusionPanoramaPipeline",
337
354
  "StableDiffusionParadigmsPipeline",
@@ -343,6 +360,7 @@ else:
343
360
  "StableDiffusionXLAdapterPipeline",
344
361
  "StableDiffusionXLControlNetImg2ImgPipeline",
345
362
  "StableDiffusionXLControlNetInpaintPipeline",
363
+ "StableDiffusionXLControlNetPAGImg2ImgPipeline",
346
364
  "StableDiffusionXLControlNetPAGPipeline",
347
365
  "StableDiffusionXLControlNetPipeline",
348
366
  "StableDiffusionXLControlNetXSPipeline",
@@ -521,6 +539,7 @@ else:
521
539
 
522
540
  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
523
541
  from .configuration_utils import ConfigMixin
542
+ from .quantizers.quantization_config import BitsAndBytesConfig
524
543
 
525
544
  try:
526
545
  if not is_onnx_available():
@@ -545,10 +564,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
545
564
  AutoencoderOobleck,
546
565
  AutoencoderTiny,
547
566
  CogVideoXTransformer3DModel,
567
+ CogView3PlusTransformer2DModel,
548
568
  ConsistencyDecoderVAE,
549
569
  ControlNetModel,
550
570
  ControlNetXSAdapter,
551
571
  DiTTransformer2DModel,
572
+ FluxControlNetModel,
573
+ FluxMultiControlNetModel,
552
574
  FluxTransformer2DModel,
553
575
  HunyuanDiT2DControlNetModel,
554
576
  HunyuanDiT2DModel,
@@ -612,6 +634,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
612
634
  ScoreSdeVePipeline,
613
635
  StableDiffusionMixin,
614
636
  )
637
+ from .quantizers import DiffusersQuantizer
615
638
  from .schedulers import (
616
639
  AmusedScheduler,
617
640
  CMStochasticIterativeScheduler,
@@ -684,6 +707,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
684
707
  AnimateDiffPipeline,
685
708
  AnimateDiffSDXLPipeline,
686
709
  AnimateDiffSparseControlNetPipeline,
710
+ AnimateDiffVideoToVideoControlNetPipeline,
687
711
  AnimateDiffVideoToVideoPipeline,
688
712
  AudioLDM2Pipeline,
689
713
  AudioLDM2ProjectionModel,
@@ -691,8 +715,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
691
715
  AudioLDMPipeline,
692
716
  AuraFlowPipeline,
693
717
  CLIPImageProjection,
718
+ CogVideoXFunControlPipeline,
719
+ CogVideoXImageToVideoPipeline,
694
720
  CogVideoXPipeline,
721
+ CogVideoXVideoToVideoPipeline,
722
+ CogView3PlusPipeline,
695
723
  CycleDiffusionPipeline,
724
+ FluxControlNetImg2ImgPipeline,
725
+ FluxControlNetInpaintPipeline,
726
+ FluxControlNetPipeline,
727
+ FluxImg2ImgPipeline,
728
+ FluxInpaintPipeline,
696
729
  FluxPipeline,
697
730
  HunyuanDiTControlNetPipeline,
698
731
  HunyuanDiTPAGPipeline,
@@ -756,6 +789,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
756
789
  StableDiffusionAttendAndExcitePipeline,
757
790
  StableDiffusionControlNetImg2ImgPipeline,
758
791
  StableDiffusionControlNetInpaintPipeline,
792
+ StableDiffusionControlNetPAGInpaintPipeline,
759
793
  StableDiffusionControlNetPAGPipeline,
760
794
  StableDiffusionControlNetPipeline,
761
795
  StableDiffusionControlNetXSPipeline,
@@ -771,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
771
805
  StableDiffusionLatentUpscalePipeline,
772
806
  StableDiffusionLDM3DPipeline,
773
807
  StableDiffusionModelEditingPipeline,
808
+ StableDiffusionPAGImg2ImgPipeline,
774
809
  StableDiffusionPAGPipeline,
775
810
  StableDiffusionPanoramaPipeline,
776
811
  StableDiffusionParadigmsPipeline,
@@ -782,6 +817,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
782
817
  StableDiffusionXLAdapterPipeline,
783
818
  StableDiffusionXLControlNetImg2ImgPipeline,
784
819
  StableDiffusionXLControlNetInpaintPipeline,
820
+ StableDiffusionXLControlNetPAGImg2ImgPipeline,
785
821
  StableDiffusionXLControlNetPAGPipeline,
786
822
  StableDiffusionXLControlNetPipeline,
787
823
  StableDiffusionXLControlNetXSPipeline,
diffusers/configuration_utils.py CHANGED
@@ -510,6 +510,9 @@ class ConfigMixin:
510
510
  # remove private attributes
511
511
  config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
512
512
 
513
+ # remove quantization_config
514
+ config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
515
+
513
516
  # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
514
517
  init_dict = {}
515
518
  for key in expected_keys:
@@ -586,10 +589,19 @@ class ConfigMixin:
586
589
  value = value.as_posix()
587
590
  return value
588
591
 
592
+ if "quantization_config" in config_dict:
593
+ config_dict["quantization_config"] = (
594
+ config_dict.quantization_config.to_dict()
595
+ if not isinstance(config_dict.quantization_config, dict)
596
+ else config_dict.quantization_config
597
+ )
598
+
589
599
  config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
590
600
  # Don't save "_ignore_files" or "_use_default_values"
591
601
  config_dict.pop("_ignore_files", None)
592
602
  config_dict.pop("_use_default_values", None)
603
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
604
+ _ = config_dict.pop("_pre_quantization_dtype", None)
593
605
 
594
606
  return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
595
607
 
diffusers/dependency_versions_table.py CHANGED
@@ -38,7 +38,7 @@ deps = {
38
38
  "regex": "regex!=2019.12.17",
39
39
  "requests": "requests",
40
40
  "tensorboard": "tensorboard",
41
- "torch": "torch>=1.4",
41
+ "torch": "torch>=1.4,<2.5.0",
42
42
  "torchvision": "torchvision",
43
43
  "transformers": "transformers>=4.41.2",
44
44
  "urllib3": "urllib3<=2.0.0",