diffsynth-engine 0.6.1.dev29__py3-none-any.whl → 0.6.1.dev30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -74,9 +74,9 @@ class BasePipeline:
74
74
  component.load_state_dict(state_dict, assign=True)
75
75
  component.to(device=device, dtype=dtype, non_blocking=True)
76
76
 
77
- def load_loras(
77
+ def _load_lora_state_dicts(
78
78
  self,
79
- lora_list: List[Tuple[str, Union[float, LoraConfig]]],
79
+ lora_state_dict_list: List[Tuple[Dict[str, torch.Tensor], Union[float, LoraConfig], str]],
80
80
  fused: bool = True,
81
81
  save_original_weight: bool = False,
82
82
  lora_converter: Optional[LoRAStateDictConverter] = None,
@@ -84,29 +84,30 @@ class BasePipeline:
84
84
  if not lora_converter:
85
85
  lora_converter = self.lora_converter
86
86
 
87
- for lora_path, lora_item in lora_list:
87
+ for state_dict, lora_item, lora_name in lora_state_dict_list:
88
88
  if isinstance(lora_item, float):
89
89
  lora_scale = lora_item
90
90
  scheduler_config = None
91
- if isinstance(lora_item, LoraConfig):
91
+ elif isinstance(lora_item, LoraConfig):
92
92
  lora_scale = lora_item.scale
93
93
  scheduler_config = lora_item.scheduler_config
94
+ else:
95
+ raise ValueError(f"lora_item must be float or LoraConfig, got {type(lora_item)}")
94
96
 
95
- logger.info(f"loading lora from {lora_path} with LoraConfig (scale={lora_scale})")
96
- state_dict = load_file(lora_path, device=self.device)
97
+ logger.info(f"loading lora from state_dict '{lora_name}' with scale={lora_scale}")
97
98
 
98
99
  if scheduler_config is not None:
99
100
  self.apply_scheduler_config(scheduler_config)
100
101
  logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
101
102
 
102
103
  lora_state_dict = lora_converter.convert(state_dict)
103
- for model_name, state_dict in lora_state_dict.items():
104
+ for model_name, model_state_dict in lora_state_dict.items():
104
105
  model = getattr(self, model_name)
105
106
  lora_args = []
106
- for key, param in state_dict.items():
107
+ for key, param in model_state_dict.items():
107
108
  lora_args.append(
108
109
  {
109
- "name": lora_path,
110
+ "name": lora_name,
110
111
  "key": key,
111
112
  "scale": lora_scale,
112
113
  "rank": param["rank"],
@@ -120,6 +121,26 @@ class BasePipeline:
120
121
  )
121
122
  model.load_loras(lora_args, fused=fused)
122
123
 
124
+ def load_loras(
125
+ self,
126
+ lora_list: List[Tuple[str, Union[float, LoraConfig]]],
127
+ fused: bool = True,
128
+ save_original_weight: bool = False,
129
+ lora_converter: Optional[LoRAStateDictConverter] = None,
130
+ ):
131
+ lora_state_dict_list = []
132
+ for lora_path, lora_item in lora_list:
133
+ logger.info(f"loading lora from {lora_path}")
134
+ state_dict = load_file(lora_path, device=self.device)
135
+ lora_state_dict_list.append((state_dict, lora_item, lora_path))
136
+
137
+ self._load_lora_state_dicts(
138
+ lora_state_dict_list=lora_state_dict_list,
139
+ fused=fused,
140
+ save_original_weight=save_original_weight,
141
+ lora_converter=lora_converter,
142
+ )
143
+
123
144
  def load_lora(self, path: str, scale: float, fused: bool = True, save_original_weight: bool = False):
124
145
  self.load_loras([(path, scale)], fused, save_original_weight)
125
146
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.6.1.dev29
3
+ Version: 0.6.1.dev30
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -143,7 +143,7 @@ diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-
143
143
  diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
144
144
  diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
145
145
  diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
146
- diffsynth_engine/pipelines/base.py,sha256=BNMNL-OU-9ilUv7O60trA3_rjHA21d6Oc5PKzKYBa80,16347
146
+ diffsynth_engine/pipelines/base.py,sha256=ShRiX5MY6bUkRKfuGrA1aalAqeHyeZxhzT87Mwc30b4,17231
147
147
  diffsynth_engine/pipelines/flux_image.py,sha256=L0ggxpthLD8a5-zdPHu9z668uWBei9YzPb4PFVypDNU,50707
148
148
  diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
149
149
  diffsynth_engine/pipelines/qwen_image.py,sha256=ktOirdU2ljgb6vHhXosC0tWgXI3gwvsoAtrYKYvMwzI,35719
@@ -190,8 +190,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
190
190
  diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
191
191
  diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
192
192
  diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
193
- diffsynth_engine-0.6.1.dev29.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
194
- diffsynth_engine-0.6.1.dev29.dist-info/METADATA,sha256=8A5q0qhRMxeJi7IOvP3dcqk58BsgIBxy16ndlnDM_6I,1164
195
- diffsynth_engine-0.6.1.dev29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
196
- diffsynth_engine-0.6.1.dev29.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
197
- diffsynth_engine-0.6.1.dev29.dist-info/RECORD,,
193
+ diffsynth_engine-0.6.1.dev30.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
194
+ diffsynth_engine-0.6.1.dev30.dist-info/METADATA,sha256=z-j4fdSyJwgilKYRl-MrSlhicE8MJP9uvoGYYTFrYKk,1164
195
+ diffsynth_engine-0.6.1.dev30.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
196
+ diffsynth_engine-0.6.1.dev30.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
197
+ diffsynth_engine-0.6.1.dev30.dist-info/RECORD,,