diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.30.3"
1
+ __version__ = "0.31.0"
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
@@ -31,6 +31,7 @@ _import_structure = {
31
31
  "loaders": ["FromOriginalModelMixin"],
32
32
  "models": [],
33
33
  "pipelines": [],
34
+ "quantizers.quantization_config": ["BitsAndBytesConfig"],
34
35
  "schedulers": [],
35
36
  "utils": [
36
37
  "OptionalDependencyNotAvailable",
@@ -84,10 +85,13 @@ else:
84
85
  "AutoencoderOobleck",
85
86
  "AutoencoderTiny",
86
87
  "CogVideoXTransformer3DModel",
88
+ "CogView3PlusTransformer2DModel",
87
89
  "ConsistencyDecoderVAE",
88
90
  "ControlNetModel",
89
91
  "ControlNetXSAdapter",
90
92
  "DiTTransformer2DModel",
93
+ "FluxControlNetModel",
94
+ "FluxMultiControlNetModel",
91
95
  "FluxTransformer2DModel",
92
96
  "HunyuanDiT2DControlNetModel",
93
97
  "HunyuanDiT2DModel",
@@ -121,7 +125,6 @@ else:
121
125
  "VQModel",
122
126
  ]
123
127
  )
124
-
125
128
  _import_structure["optimization"] = [
126
129
  "get_constant_schedule",
127
130
  "get_constant_schedule_with_warmup",
@@ -153,6 +156,7 @@ else:
153
156
  "StableDiffusionMixin",
154
157
  ]
155
158
  )
159
+ _import_structure["quantizers"] = ["DiffusersQuantizer"]
156
160
  _import_structure["schedulers"].extend(
157
161
  [
158
162
  "AmusedScheduler",
@@ -243,6 +247,7 @@ else:
243
247
  "AnimateDiffPipeline",
244
248
  "AnimateDiffSDXLPipeline",
245
249
  "AnimateDiffSparseControlNetPipeline",
250
+ "AnimateDiffVideoToVideoControlNetPipeline",
246
251
  "AnimateDiffVideoToVideoPipeline",
247
252
  "AudioLDM2Pipeline",
248
253
  "AudioLDM2ProjectionModel",
@@ -252,10 +257,17 @@ else:
252
257
  "BlipDiffusionControlNetPipeline",
253
258
  "BlipDiffusionPipeline",
254
259
  "CLIPImageProjection",
260
+ "CogVideoXFunControlPipeline",
255
261
  "CogVideoXImageToVideoPipeline",
256
262
  "CogVideoXPipeline",
257
263
  "CogVideoXVideoToVideoPipeline",
264
+ "CogView3PlusPipeline",
258
265
  "CycleDiffusionPipeline",
266
+ "FluxControlNetImg2ImgPipeline",
267
+ "FluxControlNetInpaintPipeline",
268
+ "FluxControlNetPipeline",
269
+ "FluxImg2ImgPipeline",
270
+ "FluxInpaintPipeline",
259
271
  "FluxPipeline",
260
272
  "HunyuanDiTControlNetPipeline",
261
273
  "HunyuanDiTPAGPipeline",
@@ -310,6 +322,7 @@ else:
310
322
  "StableCascadeCombinedPipeline",
311
323
  "StableCascadeDecoderPipeline",
312
324
  "StableCascadePriorPipeline",
325
+ "StableDiffusion3ControlNetInpaintingPipeline",
313
326
  "StableDiffusion3ControlNetPipeline",
314
327
  "StableDiffusion3Img2ImgPipeline",
315
328
  "StableDiffusion3InpaintPipeline",
@@ -319,6 +332,7 @@ else:
319
332
  "StableDiffusionAttendAndExcitePipeline",
320
333
  "StableDiffusionControlNetImg2ImgPipeline",
321
334
  "StableDiffusionControlNetInpaintPipeline",
335
+ "StableDiffusionControlNetPAGInpaintPipeline",
322
336
  "StableDiffusionControlNetPAGPipeline",
323
337
  "StableDiffusionControlNetPipeline",
324
338
  "StableDiffusionControlNetXSPipeline",
@@ -334,6 +348,7 @@ else:
334
348
  "StableDiffusionLatentUpscalePipeline",
335
349
  "StableDiffusionLDM3DPipeline",
336
350
  "StableDiffusionModelEditingPipeline",
351
+ "StableDiffusionPAGImg2ImgPipeline",
337
352
  "StableDiffusionPAGPipeline",
338
353
  "StableDiffusionPanoramaPipeline",
339
354
  "StableDiffusionParadigmsPipeline",
@@ -345,6 +360,7 @@ else:
345
360
  "StableDiffusionXLAdapterPipeline",
346
361
  "StableDiffusionXLControlNetImg2ImgPipeline",
347
362
  "StableDiffusionXLControlNetInpaintPipeline",
363
+ "StableDiffusionXLControlNetPAGImg2ImgPipeline",
348
364
  "StableDiffusionXLControlNetPAGPipeline",
349
365
  "StableDiffusionXLControlNetPipeline",
350
366
  "StableDiffusionXLControlNetXSPipeline",
@@ -523,6 +539,7 @@ else:
523
539
 
524
540
  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
525
541
  from .configuration_utils import ConfigMixin
542
+ from .quantizers.quantization_config import BitsAndBytesConfig
526
543
 
527
544
  try:
528
545
  if not is_onnx_available():
@@ -547,10 +564,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
547
564
  AutoencoderOobleck,
548
565
  AutoencoderTiny,
549
566
  CogVideoXTransformer3DModel,
567
+ CogView3PlusTransformer2DModel,
550
568
  ConsistencyDecoderVAE,
551
569
  ControlNetModel,
552
570
  ControlNetXSAdapter,
553
571
  DiTTransformer2DModel,
572
+ FluxControlNetModel,
573
+ FluxMultiControlNetModel,
554
574
  FluxTransformer2DModel,
555
575
  HunyuanDiT2DControlNetModel,
556
576
  HunyuanDiT2DModel,
@@ -614,6 +634,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
614
634
  ScoreSdeVePipeline,
615
635
  StableDiffusionMixin,
616
636
  )
637
+ from .quantizers import DiffusersQuantizer
617
638
  from .schedulers import (
618
639
  AmusedScheduler,
619
640
  CMStochasticIterativeScheduler,
@@ -686,6 +707,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
686
707
  AnimateDiffPipeline,
687
708
  AnimateDiffSDXLPipeline,
688
709
  AnimateDiffSparseControlNetPipeline,
710
+ AnimateDiffVideoToVideoControlNetPipeline,
689
711
  AnimateDiffVideoToVideoPipeline,
690
712
  AudioLDM2Pipeline,
691
713
  AudioLDM2ProjectionModel,
@@ -693,10 +715,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
693
715
  AudioLDMPipeline,
694
716
  AuraFlowPipeline,
695
717
  CLIPImageProjection,
718
+ CogVideoXFunControlPipeline,
696
719
  CogVideoXImageToVideoPipeline,
697
720
  CogVideoXPipeline,
698
721
  CogVideoXVideoToVideoPipeline,
722
+ CogView3PlusPipeline,
699
723
  CycleDiffusionPipeline,
724
+ FluxControlNetImg2ImgPipeline,
725
+ FluxControlNetInpaintPipeline,
726
+ FluxControlNetPipeline,
727
+ FluxImg2ImgPipeline,
728
+ FluxInpaintPipeline,
700
729
  FluxPipeline,
701
730
  HunyuanDiTControlNetPipeline,
702
731
  HunyuanDiTPAGPipeline,
@@ -760,6 +789,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
760
789
  StableDiffusionAttendAndExcitePipeline,
761
790
  StableDiffusionControlNetImg2ImgPipeline,
762
791
  StableDiffusionControlNetInpaintPipeline,
792
+ StableDiffusionControlNetPAGInpaintPipeline,
763
793
  StableDiffusionControlNetPAGPipeline,
764
794
  StableDiffusionControlNetPipeline,
765
795
  StableDiffusionControlNetXSPipeline,
@@ -775,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
775
805
  StableDiffusionLatentUpscalePipeline,
776
806
  StableDiffusionLDM3DPipeline,
777
807
  StableDiffusionModelEditingPipeline,
808
+ StableDiffusionPAGImg2ImgPipeline,
778
809
  StableDiffusionPAGPipeline,
779
810
  StableDiffusionPanoramaPipeline,
780
811
  StableDiffusionParadigmsPipeline,
@@ -786,6 +817,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
786
817
  StableDiffusionXLAdapterPipeline,
787
818
  StableDiffusionXLControlNetImg2ImgPipeline,
788
819
  StableDiffusionXLControlNetInpaintPipeline,
820
+ StableDiffusionXLControlNetPAGImg2ImgPipeline,
789
821
  StableDiffusionXLControlNetPAGPipeline,
790
822
  StableDiffusionXLControlNetPipeline,
791
823
  StableDiffusionXLControlNetXSPipeline,
diffusers/configuration_utils.py CHANGED
@@ -510,6 +510,9 @@ class ConfigMixin:
510
510
  # remove private attributes
511
511
  config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
512
512
 
513
+ # remove quantization_config
514
+ config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
515
+
513
516
  # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
514
517
  init_dict = {}
515
518
  for key in expected_keys:
@@ -586,10 +589,19 @@ class ConfigMixin:
586
589
  value = value.as_posix()
587
590
  return value
588
591
 
592
+ if "quantization_config" in config_dict:
593
+ config_dict["quantization_config"] = (
594
+ config_dict.quantization_config.to_dict()
595
+ if not isinstance(config_dict.quantization_config, dict)
596
+ else config_dict.quantization_config
597
+ )
598
+
589
599
  config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
590
600
  # Don't save "_ignore_files" or "_use_default_values"
591
601
  config_dict.pop("_ignore_files", None)
592
602
  config_dict.pop("_use_default_values", None)
603
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
604
+ _ = config_dict.pop("_pre_quantization_dtype", None)
593
605
 
594
606
  return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
595
607
 
diffusers/dependency_versions_table.py CHANGED
@@ -38,7 +38,7 @@ deps = {
38
38
  "regex": "regex!=2019.12.17",
39
39
  "requests": "requests",
40
40
  "tensorboard": "tensorboard",
41
- "torch": "torch>=1.4",
41
+ "torch": "torch>=1.4,<2.5.0",
42
42
  "torchvision": "torchvision",
43
43
  "transformers": "transformers>=4.41.2",
44
44
  "urllib3": "urllib3<=2.0.0",