mmgp 3.0.9__py3-none-any.whl → 3.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mmgp/offload.py +386 -330
- mmgp/safetensors2.py +33 -24
- {mmgp-3.0.9.dist-info → mmgp-3.1.0.dist-info}/METADATA +3 -3
- mmgp-3.1.0.dist-info/RECORD +9 -0
- mmgp-3.0.9.dist-info/RECORD +0 -9
- {mmgp-3.0.9.dist-info → mmgp-3.1.0.dist-info}/LICENSE.md +0 -0
- {mmgp-3.0.9.dist-info → mmgp-3.1.0.dist-info}/WHEEL +0 -0
- {mmgp-3.0.9.dist-info → mmgp-3.1.0.dist-info}/top_level.txt +0 -0
mmgp/offload.py
CHANGED
@@ -1,4 +1,4 @@
-# ------------------ Memory Management 3.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+# ------------------ Memory Management 3.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
 #
 # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
 # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -79,7 +79,7 @@ from mmgp import profile_type
 from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QTensor, quantize_module
 
 
-
+shared_state = {}
 
 mmm = safetensors2.mmm
 
@@ -154,33 +154,75 @@ def _get_max_reservable_memory(perc_reserved_mem_max):
     perc_reserved_mem_max = 0.40 if os.name == 'nt' else 0.5
     return perc_reserved_mem_max * physical_memory
 
-def _detect_main_towers(model, verboseLevel=1):
+def _detect_main_towers(model, min_floors = 5, verboseLevel=1):
     cur_blocks_prefix = None
     towers_modules= []
     towers_names= []
 
+    floors_modules= []
+    tower_name = None
+
+
     for submodule_name, submodule in model.named_modules():
+
         if submodule_name=='':
             continue
 
-        if [… old detection logic, truncated in the source diff …]
+        if cur_blocks_prefix != None:
+            if submodule_name.startswith(cur_blocks_prefix):
+                depth_prefix = cur_blocks_prefix.split(".")
+                depth_name = submodule_name.split(".")
+                level = depth_name[len(depth_prefix)-1]
+                pre , num = _extract_num_from_str(level)
+
+                if num != cur_blocks_seq:
+                    floors_modules.append(submodule)
 
+                    cur_blocks_seq = num
+            else:
+                if len(floors_modules) >= min_floors:
+                    towers_modules += floors_modules
+                    towers_names.append(tower_name)
+                tower_name = None
+                floors_modules= []
+                cur_blocks_prefix, cur_blocks_seq = None, -1
+
+        if cur_blocks_prefix == None:
+            pre , num = _extract_num_from_str(submodule_name)
+            if isinstance(submodule, (torch.nn.ModuleList)):
+                cur_blocks_prefix, cur_blocks_seq = pre + ".", -1
+                tower_name = submodule_name + ".*"
+            elif num >=0:
+                cur_blocks_prefix, cur_blocks_seq = pre, num
+                tower_name = submodule_name[ :-1] + "*"
+                floors_modules.append(submodule)
+
+    if len(floors_modules) >= min_floors:
+        towers_modules += floors_modules
+        towers_names.append(tower_name)
+
+    # for submodule_name, submodule in model.named_modules():
+    #     if submodule_name=='':
+    #         continue
+
+    #     if isinstance(submodule, torch.nn.ModuleList):
+    #         newList =False
+    #         if cur_blocks_prefix == None:
+    #             cur_blocks_prefix = submodule_name + "."
+    #             newList = True
+    #         else:
+    #             if not submodule_name.startswith(cur_blocks_prefix):
+    #                 cur_blocks_prefix = submodule_name + "."
+    #                 newList = True
+
+    #         if newList and len(submodule)>=5:
+    #             towers_names.append(submodule_name)
+    #             towers_modules.append(submodule)
 
-        [… old tower collection logic, truncated in the source diff …]
+    #     else:
+    #         if cur_blocks_prefix is not None:
+    #             if not submodule_name.startswith(cur_blocks_prefix):
+    #                 cur_blocks_prefix = None
 
     return towers_names, towers_modules
 
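The reworked _detect_main_towers collects the individual "floors" (numerically indexed submodules) of each repeated stack and only reports a tower once it holds at least min_floors floors. A minimal sketch of the intended behaviour, assuming the private helper is imported straight from mmgp.offload (it is an internal function and may change):

import torch
from mmgp.offload import _detect_main_towers

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # eight repeated blocks: enough floors to clear the default min_floors=5
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(4, 4) for _ in range(8))
        self.head = torch.nn.Linear(4, 4)

names, modules = _detect_main_towers(Toy())
print(names)         # ['blocks.*']  wildcard name of the detected tower
print(len(modules))  # 8             one entry per floor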
@@ -194,7 +236,7 @@ def _get_model(model_path):
     _path = Path(model_path).parts
     _filename = _path[-1]
     _path = _path[:-1]
-    if len(_path) [… truncated in the source diff …]
+    if len(_path)<=1:
         raise("file not found")
     else:
         from huggingface_hub import hf_hub_download #snapshot_download,
@@ -369,8 +411,16 @@ def _welcome():
     if welcome_displayed:
         return
     welcome_displayed = True
-    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")
 
+def _extract_num_from_str(num_in_str):
+    for i in range(len(num_in_str)):
+        if not num_in_str[-i-1:].isnumeric():
+            if i == 0:
+                return num_in_str, -1
+            else:
+                return num_in_str[: -i], int(num_in_str[-i:])
+    return "", int(num_in_str)
 
 def _quantize_dirty_hack(model):
     # dirty hack: add a hook on state_dict() to return a fake non quantized state_dict if called by Lora Diffusers initialization functions
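The new _extract_num_from_str helper splits a trailing integer off a (sub)module name and underpins both the tower detection above and the budget-driven block splitting further down. Its behaviour, for illustration:

from mmgp.offload import _extract_num_from_str

print(_extract_num_from_str("blocks.11"))  # ('blocks.', 11)
print(_extract_num_from_str("norm_out"))   # ('norm_out', -1)  no trailing digits
print(_extract_num_from_str("42"))         # ('', 42)          all digits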
@@ -581,6 +631,255 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
 
     return True
 
+def load_loras_into_model(model, lora_path, lora_multi = None, verboseLevel = -1):
+    verboseLevel = _compute_verbose_level(verboseLevel)
+
+    if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
+        raise Exception("Unable to load Lora, missing 'peft' and / or 'diffusers' modules")
+
+    if not isinstance(lora_path, list):
+        lora_path = [lora_path]
+
+    if lora_multi is None:
+        lora_multi = [1. for _ in lora_path]
+
+    for i, path in enumerate(lora_path):
+        adapter_name = str(i)
+
+        state_dict = safetensors2.torch_load_file(path)
+
+        keys = list(state_dict.keys())
+        if len(keys) == 0:
+            raise Exception(f"Empty Lora '{path}'")
+
+
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                alpha_value = state_dict.pop(k)
+                if not ( (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                )):
+                    network_alphas[k] = torch.tensor( float(alpha_value.item() ) )
+
+        pos = keys[0].find(".")
+        prefix = keys[0][0:pos]
+        if not any( prefix.startswith(some_prefix) for some_prefix in ["diffusion_model", "transformer"]):
+            msg = f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format."
+            raise Exception(msg)
+
+        transformer = model
+
+        transformer_keys = [k for k in keys if k.startswith(prefix)]
+        state_dict = {
+            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        }
+
+        sd_keys = state_dict.keys()
+        if len(sd_keys) == 0:
+            print(f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format.")
+            return
+
+        # is_correct_format = all("lora" in key for key in state_dict.keys())
+
+
+
+        # check with first key if is not in peft format
+        # first_key = next(iter(state_dict.keys()))
+        # if "lora_A" not in first_key:
+        #     state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+        if adapter_name in getattr(transformer, "peft_config", {}):
+            raise ValueError(
+                f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+            )
+
+        rank = {}
+        for key, val in state_dict.items():
+            if "lora_B" in key:
+                rank[key] = val.shape[1]
+
+        if network_alphas is not None and len(network_alphas) >= 1:
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+
+        lora_config = LoraConfig(**lora_config_kwargs)
+        peft_kwargs = {}
+        peft_kwargs["low_cpu_mem_usage"] = True
+        inject_adapter_in_model(lora_config, model, adapter_name=adapter_name, **peft_kwargs)
+
+        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
+
+        warn_msg = ""
+        if incompatible_keys is not None:
+            # Check only for unexpected keys.
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                pass
+        if verboseLevel >=1:
+            print(f"Lora '{path}' was loaded in model '{_get_module_name(model)}'")
+    set_weights_and_activate_adapters(model,[ str(i) for i in range(len(lora_multi))], lora_multi)
+
+
+def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
+    """
+    quick version of .LoadfromPretrained of the transformers library
+    used to build a model and load the corresponding weights (quantized or not)
+    """
+
+
+    import os.path
+    from accelerate import init_empty_weights
+
+    if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
+        raise Exception("full model path to file expected")
+
+    model_path = _get_model(model_path)
+    verboseLevel = _compute_verbose_level(verboseLevel)
+
+    with safetensors2.safe_open(model_path) as f:
+        metadata = f.metadata()
+
+    if metadata is None:
+        transformer_config = None
+    else:
+        transformer_config = metadata.get("config", None)
+
+    if transformer_config == None:
+        config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
+
+        if not os.path.isfile(config_fullpath):
+            raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
+
+        with open(config_fullpath, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        transformer_config= json.loads(text)
+
+
+    if "architectures" in transformer_config:
+        architectures = transformer_config["architectures"]
+        class_name = architectures[0]
+
+        module = __import__("transformers")
+        map = { "T5WithLMHeadModel" : "T5EncoderModel"}
+        class_name = map.get(class_name, class_name)
+        transfomer_class = getattr(module, class_name)
+        from transformers import AutoConfig
+
+        import tempfile
+        with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
+            fp.write(json.dumps(transformer_config))
+            fp.close()
+        config_obj = AutoConfig.from_pretrained(fp.name)
+        os.remove(fp.name)
+
+        #needed to keep inits of non persistent buffers
+        with init_empty_weights():
+            model = transfomer_class(config_obj)
+
+        model = model.base_model
+
+    elif "_class_name" in transformer_config:
+        class_name = transformer_config["_class_name"]
+
+        module = __import__("diffusers")
+        transfomer_class = getattr(module, class_name)
+
+        with init_empty_weights():
+            model = transfomer_class.from_config(transformer_config)
+
+
+    torch.set_default_device('cpu')
+
+    model._config = transformer_config
+
+    load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verboseLevel )
+
+    return model
+
+
+
+def load_model_data(model, file_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
+    """
+    Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
+    """
+
+    file_path = _get_model(file_path)
+    verboseLevel = _compute_verbose_level(verboseLevel)
+
+    model = _remove_model_wrapper(model)
+
+    # if pinToMemory and do_quantize:
+    #     raise Exception("Pinning and Quantization can not be used at the same time")
+
+    if not (".safetensors" in file_path or ".sft" in file_path):
+        if pinToMemory:
+            raise Exception("Pinning to memory while loading only supported for safe tensors files")
+        state_dict = torch.load(file_path, weights_only=True)
+        if "module" in state_dict:
+            state_dict = state_dict["module"]
+    else:
+        state_dict, metadata = _safetensors_load_file(file_path)
+
+        if metadata is None:
+            quantization_map = None
+        else:
+            quantization_map = metadata.get("quantization_map", None)
+            config = metadata.get("config", None)
+            if config is not None:
+                model._config = config
+
+
+        if quantization_map is None:
+            pos = str.rfind(file_path, ".")
+            if pos > 0:
+                quantization_map_path = file_path[:pos]
+                quantization_map_path += "_map.json"
+
+                if os.path.isfile(quantization_map_path):
+                    with open(quantization_map_path, 'r') as f:
+                        quantization_map = json.load(f)
+
+
+        if quantization_map is None :
+            if "quanto" in file_path and not do_quantize:
+                print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
+        else:
+            _requantize(model, state_dict, quantization_map)
+
+    missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
+    # if len(missing_keys) > 0:
+    #     sd_crap = { k : None for k in missing_keys}
+    #     missing_keys , unexpected_keys = model.load_state_dict(sd_crap, strict =False, assign = True )
+    del state_dict
+
+    for k,p in model.named_parameters():
+        if p.is_meta:
+            txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since parameter '{k}' has no data"
+            raise Exception(txt)
+    for k,b in model.named_buffers():
+        if b.is_meta:
+            txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since buffer '{k}' has no data"
+            raise Exception(txt)
+
+    if do_quantize:
+        if quantization_map is None:
+            if _quantize(model, quantizationType, verboseLevel=verboseLevel, model_id=file_path):
+                quantization_map = model._quanto_map
+        else:
+            if verboseLevel >=1:
+                print("Model already quantized")
+
+    if pinToMemory:
+        _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
+
+    return
+
 def get_model_name(model):
     return model.name
 
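These three functions are not new logic: they are moved up from further down in the module (their old copies are deleted later in this diff). Typical usage, with purely illustrative file paths:

from mmgp import offload

# Build a diffusers / transformers model from a single safetensors file; the config
# is read from the file's embedded metadata or from a sibling config.json.
transformer = offload.fast_load_transformers_model("ckpts/transformer_quanto_int8.safetensors")

# Attach two LoRAs with individual multipliers.
offload.load_loras_into_model(
    transformer,
    ["loras/a.safetensors", "loras/b.safetensors"],
    lora_multi=[1.0, 0.7],
)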
@@ -612,6 +911,7 @@ class offload:
         self.async_transfers = False
         global last_offload_obj
         last_offload_obj = self
+
 
     def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
 
@@ -669,7 +969,7 @@ class offload:
             return False
         return True
 
-
+    @torch.compiler.disable()
     def gpu_load_blocks(self, model_id, blocks_name):
         # cl = clock.start()
 
@@ -715,7 +1015,7 @@ class offload:
         # cl.stop()
         # print(f"load time: {cl.format_time_gap()}")
 
-
+    @torch.compiler.disable()
     def gpu_unload_blocks(self, model_id, blocks_name):
         # cl = clock.start()
         if blocks_name != None:
@@ -736,7 +1036,7 @@ class offload:
         # cl.stop()
         # print(f"unload time: {cl.format_time_gap()}")
 
-
+    # @torch.compiler.disable()
     def gpu_load(self, model_id):
         model = self.models[model_id]
         self.active_models.append(model)
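Decorating gpu_load_blocks and gpu_unload_blocks with @torch.compiler.disable() keeps the offloading bookkeeping out of any torch.compile graph that calls into them. An isolated illustration of the mechanism (plain PyTorch 2.x, unrelated to mmgp internals):

import torch

@torch.compiler.disable()
def transfer_blocks():
    # side-effectful CPU <-> GPU transfer logic that must not be traced
    pass

@torch.compile
def step(x):
    transfer_blocks()  # runs eagerly, outside the compiled graph
    return x * 2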
@@ -818,10 +1118,10 @@ class offload:
 
         return False
 
-    def [… old hook name, truncated in the source diff …]
+    def hook_preload_blocks_for_compilation(self, target_module, model_id,blocks_name, context):
 
-        @torch.compiler.disable()
-        def [… truncated in the source diff …]
+        # @torch.compiler.disable()
+        def preload_blocks_for_compile(module, *args, **kwargs):
             some_context = context #for debugging
             if blocks_name == None:
                 if self.ready_to_check_mem():
@@ -835,8 +1135,9 @@ class offload:
                     self.empty_cache_if_needed()
             self.loaded_blocks[model_id] = blocks_name
             self.gpu_load_blocks(model_id, blocks_name)
-
-
+        # need to be registered before the forward not to be break the efficiency of the compilation chain
+        # it should be at the top of the compilation as this type of hook in the middle of a chain seems to break memory performance
+        target_module.register_forward_pre_hook(preload_blocks_for_compile)
 
 
     def hook_check_empty_cache_needed(self, target_module, model_id,blocks_name, previous_method, context):
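register_forward_pre_hook fires just before the module's forward, which is what lets the preload run ahead of the compiled region instead of inside it. The hook contract in isolation:

import torch

lin = torch.nn.Linear(2, 2)

def pre_hook(module, args):
    # called before module.forward; returning None leaves the inputs untouched
    print("preloading for", type(module).__name__)

lin.register_forward_pre_hook(pre_hook)
lin(torch.ones(1, 2))  # prints: preloading for Linear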
@@ -909,267 +1210,18 @@ class offload:
         print(f"Hooked in model '{model_id}' ({model_name})")
 
 
-[… roughly 250 deleted lines collapsed: the previous module-level copies of load_loras_into_model, fast_load_transformers_model and load_model_data, which this release moves earlier in the file (see the large +255 hunk above), plus a stray module-level `import torch` and the commented-out unhook_module / unhook_all scaffolding. Besides the move, load_model_data now always calls model.load_state_dict(state_dict, False, assign = True) instead of strict = quantization_map is None and verifies that no parameter or buffer is left on the meta device, and fast_load_transformers_model now maps the T5WithLMHeadModel architecture name to T5EncoderModel. …]
-
-def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1 ):
+def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1, config_file_path = None ):
     """save the weights of a model and quantize them if requested
     These weights can be loaded again using 'load_model_data'
     """
 
     config = None
-
     verboseLevel = _compute_verbose_level(verboseLevel)
-    [… truncated in the source diff …]
+    if config_file_path !=None:
+        with open(config_file_path, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        config= json.loads(text)
+    elif hasattr(model, "_config"):
         config = model._config
     elif hasattr(model, "config"):
         config_fullpath = None
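The new config_file_path argument lets the caller embed an explicit config.json into the saved safetensors file, so fast_load_transformers_model can later rebuild the model without a sidecar config. A hedged sketch, where model is any already-built torch.nn.Module and the paths are illustrative:

from mmgp import offload
from optimum.quanto import qint8

offload.save_model(
    model,                                  # an already-built torch.nn.Module
    "transformer_quanto_int8.safetensors",  # illustrative output path
    do_quantize=True, quantizationType=qint8,
    config_file_path="ckpts/config.json",   # illustrative config to embed
)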
@@ -1195,7 +1247,7 @@ def save_model(model, file_path, do_quantize = False, quantizationType = qint8,
         print(f"Saving file '{file_path}")
     safetensors2.torch_write_file(model.state_dict(), file_path , quantization_map = quantization_map, config = config)
     if verboseLevel >=1:
-        print(f"File '{file_path} saved")
+        print(f"File '{file_path}' saved")
 
 
 
@@ -1286,7 +1338,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
     max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
 
     estimatesBytesToPin = 0
-
     for model_id in models:
         current_model: torch.nn.Module = models[model_id]
         # make sure that no RAM or GPU memory is not allocated for gradiant / training
@@ -1302,7 +1353,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
 
         for n, p in current_model.named_parameters():
             p.requires_grad = False
-            p = p.detach()
             if isinstance(p, QTensor):
                 # # fix quanto bug (seems to have been fixed)
                 # if not modelPinned and p._scale.dtype == torch.float32:
@@ -1352,21 +1402,21 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
     # Hook forward methods of modules
     for model_id in models:
         current_model: torch.nn.Module = models[model_id]
-        current_budget = model_budgets[model_id]
-        current_size = 0
-        cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
-        self.loaded_blocks[model_id] = None
         towers_names, towers_modules = _detect_main_towers(current_model)
-        towers_names = [n +"." for n in towers_names]
         if self.verboseLevel>=2 and len(towers_names)>0:
             print(f"Potential iterative blocks found in model '{model_id}':{towers_names}")
         # compile main iterative modules stacks ("towers")
-        [… old compile scheduling, truncated in the source diff …]
+        compilationInThisOne = compileAllModels or model_id in modelsToCompile
+        if compilationInThisOne:
             if self.verboseLevel>=1:
-                [… truncated in the source diff …]
+                if len(towers_modules)>0:
+                    print(f"Pytorch compilation of model '{model_id}' is scheduled.")
+                else:
+                    print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
+
+            for submodel in towers_modules:
+                # for submodel in tower:
+                submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
                 #dynamic=True,
 
         if pinAllModels or model_id in modelsToPin:
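Compilation is now applied floor by floor: each detected submodel gets its own torch.compile wrapper around forward. Assuming the compile option is exposed on offload.all as described in this release's README (see the METADATA diff below), scheduling it would look like:

from mmgp import offload

# pipe is assumed to be a Diffusers-style pipeline; compile only the heavily reused model.
offload.all(pipe, pinnedMemory=True, asyncTransfers=True, compile=["transformer"])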
@@ -1376,6 +1426,11 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
             else:
                 _pin_to_memory(current_model, model_id, partialPinning= partialPinning, perc_reserved_mem_max=perc_reserved_mem_max, verboseLevel=verboseLevel)
 
+        current_budget = model_budgets[model_id]
+        current_size = 0
+        cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
+        self.loaded_blocks[model_id] = None
+
         for submodule_name, submodule in current_model.named_modules():
             # create a fake 'accelerate' parameter so that the _execution_device property returns always "cuda"
             # (it is queried in many pipelines even if offloading is not properly implemented)
@@ -1384,44 +1439,43 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
 
             if submodule_name=='':
                 continue
 
             if current_budget > 0:
-                if [… truncated in the source diff …]
-                if cur_blocks_prefix [… truncated in the source diff …]
+                if cur_blocks_prefix != None:
+                    if submodule_name.startswith(cur_blocks_prefix):
+                        depth_prefix = cur_blocks_prefix.split(".")
+                        depth_name = submodule_name.split(".")
+                        level = depth_name[len(depth_prefix)-1]
+                        pre , num = _extract_num_from_str(level)
+                        if num != cur_blocks_seq and (cur_blocks_seq == -1 or current_size > current_budget):
+                            prev_blocks_name = cur_blocks_name
+                            cur_blocks_name = cur_blocks_prefix + str(num)
+                            # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
+                            cur_blocks_seq = num
                     else:
-                        [… old block splitting logic, truncated in the source diff …]
-                        cur_blocks_name = cur_blocks_prefix + str(num)
-                        # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
-                        cur_blocks_seq = num
-                else:
-                    cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
-
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
+
+                if cur_blocks_prefix == None:
+                    pre , num = _extract_num_from_str(submodule_name)
+                    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre + ".", None, -1
+                    elif num >=0:
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre, None, num
+                        cur_blocks_name = submodule_name
+                        # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
+
+
             if hasattr(submodule, "forward"):
                 submodule_method = getattr(submodule, "forward")
                 if callable(submodule_method):
                     if len(submodule_name.split("."))==1:
                         self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method)
-                    elif [… truncated in the source diff …]
-                        self.[… truncated in the source diff …]
+                    elif compilationInThisOne and submodule in towers_modules:
+                        self.hook_preload_blocks_for_compilation(submodule, model_id, cur_blocks_name, context = submodule_name )
                     else:
                         self.hook_check_empty_cache_needed(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name )
 
-
-                current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
-
-
+                current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
 
             if self.verboseLevel >=2:
@@ -1467,11 +1521,12 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
     models_to_scan = ("text_encoder", "text_encoder_2")
     candidates_to_quantize = ("t5", "llama", "llm")
     for model_id in models_to_scan:
-        [… old candidate matching logic, truncated in the source diff …]
+        if model_id in module_names:
+            name = module_names[model_id]
+            for candidate in candidates_to_quantize:
+                if candidate in name:
+                    default_extraModelsToQuantize.append(model_id)
+                    break
 
 
     # transformer (video or image generator) should be as small as possible not to occupy space that could be used by actual image data
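The rewritten scan only auto-quantizes a text encoder when its underlying module name contains one of the candidate substrings (t5, llama, llm). For reference, a profile is applied in a single call (pipe as in the previous sketch):

from mmgp import offload, profile_type

offload.profile(pipe, profile_type.HighRAM_HighVRAM)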
@@ -1480,6 +1535,7 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
     default_budgets = { "transformer" : 600 , "text_encoder": 3000, "text_encoder_2": 3000 }
     extraModelsToQuantize = None
     asyncTransfers = True
+    budgets = None
 
     if profile_no == profile_type.HighRAM_HighVRAM:
         pinnedMemory= True
mmgp/safetensors2.py
CHANGED
@@ -156,19 +156,32 @@ def torch_write_file(sd, file_path, quantization_map = None, config = None):
     pos = 0
     i = 0
     mx = 100000
+    metadata = dict()
     for k , t in sd.items():
-        [… old header-building loop body, truncated in the source diff …]
+        if torch.is_tensor(t):
+            entry = {}
+            dtypestr= map[t.dtype]
+            entry["dtype"] = dtypestr
+            entry["shape"] = list(t.shape)
+            size = torch.numel(t) * t.element_size()
+            if size == 0:
+                pass
+            entry["data_offsets"] = [pos, pos + size]
+            pos += size
+            sf_sd[k] = entry
+        else:
+            if isinstance(t, str):
+                metadata[k] = t
+            else:
+                try:
+                    b64 = base64.b64encode(json.dumps(t, ensure_ascii=False).encode('utf8')).decode('utf8')
+                    metadata[k + "_base64"] = b64
+                except:
+                    pass
+
         i+=1
         if i==mx:
             break
-    metadata = dict()
     if not quantization_map is None:
         metadata["quantization_format"] = "quanto"
         metadata["quantization_map_base64"] = base64.b64encode(json.dumps(quantization_map, ensure_ascii=False).encode('utf8')).decode('utf8')
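torch_write_file now serializes non-tensor entries of the state dict into the safetensors header metadata: strings are stored directly under their key, while any other JSON-serializable value is stored base64-encoded under key + "_base64". A sketch of the round trip (the exact metadata layout is internal to mmgp):

import torch
from mmgp import safetensors2

sd = {
    "weight": torch.ones(2, 2),
    "note": "hello",                 # stored directly in the header metadata
    "config": {"hidden_size": 128},  # stored as 'config_base64' (base64-encoded JSON)
}
safetensors2.torch_write_file(sd, "demo.safetensors")

with safetensors2.safe_open("demo.safetensors") as f:
    print(f.metadata())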
@@ -192,9 +205,9 @@ def torch_write_file(sd, file_path, quantization_map = None, config = None):
 
     i = 0
     for k , t in sd.items():
-        [… old per-tensor size check, truncated in the source diff …]
+        if torch.is_tensor(t):
+            size = torch.numel(t) * t.element_size()
+            if size != 0:
                 dtype = t.dtype
                 # convert in a friendly format, scalars types not supported by numpy
                 if dtype == torch.bfloat16:
@@ -202,11 +215,8 @@ def torch_write_file(sd, file_path, quantization_map = None, config = None):
                 elif dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn:
                     t = t.view(torch.uint8)
                 buffer = t.numpy().tobytes()
-                [… truncated in the source diff …]
-                bytes_written = writer.write(buffer)
-                assert bytes_written == size
-
+                bytes_written = writer.write(buffer)
+                assert bytes_written == size
         i+=1
         if i==mx:
             break
@@ -297,13 +307,12 @@ class SafeTensorFile:
             length = data_offsets[1]-data_offsets[0]
             map_idx = next(iter_tensor_no)
             offset = current_pos - maps[map_idx][1]
-            if [… old scalar handling, truncated in the source diff …]
-                t = t.view(dtype)
+            if length == 0:
+                t = torch.empty(shape, dtype=dtype)
+            elif len(shape) == 0:
+                # don't waste a memory view for a scalar
+                t = torch.frombuffer(bytearray(maps[map_idx][0][offset:offset + length]), dtype=torch.uint8)
+                t = t.view(dtype)
             else:
                 mv = memoryview(maps[map_idx][0])[offset:offset + length]
                 t = torch.frombuffer(mv, dtype=dtype)
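Together with the writer-side size checks above, the reader now round-trips zero-length tensors and scalars. An illustrative check:

import torch
from mmgp import safetensors2

sd = {"empty": torch.empty(0, 4), "scalar": torch.tensor(1.5)}
safetensors2.torch_write_file(sd, "edge_cases.safetensors")
loaded = safetensors2.torch_load_file("edge_cases.safetensors")
print(loaded["empty"].shape)  # torch.Size([0, 4])
print(loaded["scalar"])       # tensor(1.5000)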
{mmgp-3.0.9.dist-info → mmgp-3.1.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: mmgp
-Version: 3.0.9
+Version: 3.1.0
 Summary: Memory Management for the GPU Poor
 Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
 License: GNU GENERAL PUBLIC LICENSE
@@ -17,7 +17,7 @@ Requires-Dist: peft
 
 
 <p align="center">
-<H2>Memory Management 3.0.9 for the GPU Poor by DeepBeepMeep</H2>
+<H2>Memory Management 3.1.0 for the GPU Poor by DeepBeepMeep</H2>
 </p>
 
 
@@ -100,7 +100,7 @@ For example:
 The smaller this number, the more VRAM left for image data / longer video but also the slower because there will be lots of loading / unloading between the RAM and the VRAM. If model is too big to fit in a budget, it will be broken down in multiples parts that will be unloaded / loaded consequently. The speed of low budget can be increased (up to 2 times) by turning on the options pinnedMemory and asyncTransfers.
 - asyncTransfers: boolean, load to the GPU the next model part while the current part is being processed. This requires twice the budget if any is defined. This may increase speed by 20% (mostly visible on fast modern GPUs).
 - verboseLevel: number between 0 and 2 (1 by default), provides various level of feedback of the different processes
-- compile: list of model ids to compile, may accelerate up x2 depending on the type of GPU. As of 01/01/2025 it will work only on Linux or WSL since compilation relies on Triton which is not yet supported on Windows
+- compile: list of model ids to compile, may accelerate up x2 depending on the type of GPU. It makes sens to compile only the model that is frequently used such as the "transformer" model in the case of video or image generation. As of 01/01/2025 it will work only on Linux or WSL since compilation relies on Triton which is not yet supported on Windows
 
 If you are short on RAM and plan to work with quantized models, it is recommended to load pre-quantized models direclty rather than using on the fly quantization, it will be faster and consume slightly less RAM.
 
mmgp-3.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
+mmgp/offload.py,sha256=VDau0VCAWHnS40swGuqxn7LIyZJdI0qYI58iGCRyw3Y,67352
+mmgp/safetensors2.py,sha256=mTXL-rZ2lZwYKRujNAc8lUJoqQjq6lpD2XrkuZjA_2Y,16138
+mmgp-3.1.0.dist-info/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
+mmgp-3.1.0.dist-info/METADATA,sha256=A5Tvc-FGxjk3FuzNHlQ6g6ztJg7hqIwPKvL5EK1pXTc,12708
+mmgp-3.1.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+mmgp-3.1.0.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
+mmgp-3.1.0.dist-info/RECORD,,
mmgp-3.0.9.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
-mmgp/offload.py,sha256=bYjpbAHbVX2Vf3nBJXYEc1u9B5JIYvJxv4eMS8L5Tco,64209
-mmgp/safetensors2.py,sha256=G6uzvpGauJLPEvN74MX1ib4YK0E4wzNMyrZO5wOX2k0,15812
-mmgp-3.0.9.dist-info/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
-mmgp-3.0.9.dist-info/METADATA,sha256=0vNt8lNKfMkyBrFUN8pOfkDRf8i_jmndgH2ePIekmdg,12570
-mmgp-3.0.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-mmgp-3.0.9.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
-mmgp-3.0.9.dist-info/RECORD,,
{mmgp-3.0.9.dist-info → mmgp-3.1.0.dist-info}/LICENSE.md
File without changes
{mmgp-3.0.9.dist-info → mmgp-3.1.0.dist-info}/WHEEL
File without changes
{mmgp-3.0.9.dist-info → mmgp-3.1.0.dist-info}/top_level.txt
File without changes