mmgp 3.2.5__py3-none-any.whl → 3.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mmgp might be problematic.

mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
1
- # ------------------ Memory Management 3.2.5 for the GPU Poor by DeepBeepMeep (mmgp)------------------
1
+ # ------------------ Memory Management 3.2.7 for the GPU Poor by DeepBeepMeep (mmgp)------------------
2
2
  #
3
3
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
4
4
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -61,22 +61,13 @@ import sys
61
61
  import os
62
62
  import json
63
63
  import psutil
64
- try:
65
- from diffusers.utils.peft_utils import set_weights_and_activate_adapters, get_peft_kwargs
66
- except:
67
- set_weights_and_activate_adapters = None
68
- get_peft_kwargs = None
69
- pass
70
- try:
71
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
72
- except:
73
- inject_adapter_in_model = None
74
- pass
64
+ from accelerate import init_empty_weights
65
+
75
66
 
76
67
  from mmgp import safetensors2
77
68
  from mmgp import profile_type
78
69
 
79
- from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QTensor, quantize_module, register_qmodule
70
+ from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QLinear, QTensor, quantize_module, register_qmodule
80
71
 
81
72
  # support for Embedding module quantization that is not supported by default by quanto
82
73
  @register_qmodule(torch.nn.Embedding)
@@ -297,12 +288,115 @@ def _get_tensor_ref(p):
297
288
  return p.data_ptr()
298
289
 
299
290
 
300
- def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
291
+ # BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
292
+ BIG_TENSOR_MAX_SIZE = 2**27 # 128 MB
293
+
294
+ def _extract_tie_weights_from_sd(sd , sd_name, verboseLevel =1):
295
+ tied_weights = {}
296
+ tied_weights_count = 0
297
+ tied_weights_total = 0
298
+ tied_weights_last = None
299
+ ref_cache = {}
300
+
301
+ for n, p in sd.items():
302
+ ref = _get_tensor_ref(p)
303
+ match = ref_cache.get(ref, None)
304
+ if match != None:
305
+ match_name, match_size = match
306
+ tied_weights_count += 1
307
+ tied_weights_total += match_size
308
+ if verboseLevel >=1:
309
+ tied_weights_last = f"{match_name} <-> {n}"
310
+ tied_weights[n] = match_name
311
+ else:
312
+ length = torch.numel(p.data) * p.data.element_size()
313
+ ref_cache[ref] = (n, length)
314
+
315
+ if verboseLevel >=1 and tied_weights_count > 0:
316
+ if tied_weights_count == 1:
317
+ print(f"Tied weights of {tied_weights_total/ONE_MB:0.2f} MB detected: {tied_weights_last}")
318
+ else:
319
+ print(f"Found {tied_weights_count} tied weights for a total of {tied_weights_total/ONE_MB:0.2f} MB, last : {tied_weights_last}")
320
+
321
+ def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TENSOR_MAX_SIZE, verboseLevel = 1):
322
+ current_big_tensor_size = 0
323
+ big_tensor_no = 0
324
+ big_tensors_sizes = []
325
+ tensor_map_indexes = []
326
+ total_tensor_bytes = 0
327
+
328
+ for n, p in sd.items():
329
+ if tied_weights == None or not n in tied_weights :
330
+ length = torch.numel(p.data) * p.data.element_size()
331
+
332
+ if current_big_tensor_size + length > gig_tensor_size :
333
+ big_tensors_sizes.append(current_big_tensor_size)
334
+ current_big_tensor_size = 0
335
+ big_tensor_no += 1
336
+
337
+ itemsize = p.data.dtype.itemsize
338
+ if current_big_tensor_size % itemsize:
339
+ current_big_tensor_size += itemsize - current_big_tensor_size % itemsize
340
+ tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
341
+ current_big_tensor_size += length
342
+
343
+ total_tensor_bytes += length
344
+
345
+ big_tensors_sizes.append(current_big_tensor_size)
346
+
347
+ big_tensors = []
348
+ last_big_tensor = 0
349
+ total = 0
350
+
351
+ for size in big_tensors_sizes:
352
+ try:
353
+ current_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True, device="cpu")
354
+ big_tensors.append(current_big_tensor)
355
+ except:
356
+ print(f"Unable to pin more tensors for '{sd_name}' as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
357
+ break
358
+
359
+ last_big_tensor += 1
360
+ total += size
361
+
362
+
363
+ tensor_no = 0
364
+ # prev_big_tensor = 0
365
+ q_name = None
366
+ for n, p in sd.items():
367
+ if tied_weights != None:
368
+ q_name = tied_weights.get(n,None)
369
+ if q_name != None:
370
+ q = sd[q_name]
371
+ p.data = q.data
372
+ assert p.data.is_pinned()
373
+ q = None
374
+ else:
375
+ big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
376
+
377
+ if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
378
+ current_big_tensor = big_tensors[big_tensor_no]
379
+ length = torch.numel(p.data) * p.data.element_size()
380
+ q = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
381
+ torch.utils.swap_tensors(p, q)
382
+ del q
383
+ tensor_no += 1
384
+ del p
385
+ # global total_pinned_bytes
386
+ # total_pinned_bytes += total
387
+ gc.collect()
388
+
389
+ if verboseLevel >=1:
390
+ print(f"'{sd_name}' was pinned entirely to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
391
+
392
+ return
393
+
394
+
395
+ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = True, gig_tensor_size = BIG_TENSOR_MAX_SIZE, verboseLevel = 1):
301
396
  if partialPinning:
302
397
  towers_names, _ = _detect_main_towers(model)
303
398
 
304
399
 
305
- BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
306
400
  current_big_tensor_size = 0
307
401
  big_tensor_no = 0
308
402
  big_tensors_sizes = []
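
The new `_pin_sd_to_memory` helper above packs the tensors of a state dict into a few large page-locked buffers (capped by `BIG_TENSOR_MAX_SIZE`), aligning each tensor to its element size so the byte slices can later be reinterpreted with the right dtype, instead of pinning every small tensor separately; fewer, larger pinned blocks mean fewer pinning calls and faster asynchronous CPU→GPU copies. A minimal sketch of that packing idea, assuming a CUDA build where pinned host memory is available (`pack_into_pinned` is a hypothetical name, not mmgp's API):

```python
import torch

BIG_TENSOR_MAX_SIZE = 2**27  # 128 MB, same cap as the constant introduced above

def pack_into_pinned(sd, max_block_size=BIG_TENSOR_MAX_SIZE):
    """Repack a state dict's tensors into a few large pinned (page-locked) CPU buffers."""
    plan, sizes, current = [], [], 0
    # Pass 1: plan a (block index, byte offset, byte length) slot for every tensor,
    # aligning each slot to the tensor's element size so the slice can be re-viewed
    # with that dtype later.
    for p in sd.values():
        nbytes = p.numel() * p.element_size()
        if current + nbytes > max_block_size and current > 0:
            sizes.append(current)
            current = 0
        align = p.element_size()
        if current % align:
            current += align - current % align
        plan.append((len(sizes), current, nbytes))
        current += nbytes
    sizes.append(current)
    # Pass 2: allocate the pinned blocks and remap every tensor onto a slice of them.
    blocks = [torch.empty(s, dtype=torch.uint8, pin_memory=True) for s in sizes]
    for (blk, off, nbytes), (name, p) in zip(plan, sd.items()):
        view = blocks[blk][off:off + nbytes].view(p.dtype).view(p.shape)
        view.copy_(p)
        sd[name] = view  # the tensor now lives in pinned RAM, ready for async H2D copies
    return blocks  # keep a reference alive so the pinned memory is not freed
```
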
@@ -314,6 +408,9 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
314
408
  include = True
315
409
  if partialPinning:
316
410
  include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
411
+ if include and not pinnedPEFTLora and ".lora_" in k:
412
+ include = False
413
+
317
414
  if include:
318
415
  params_dict.update( { k + '.' + n : (p, False) for n, p in sub_module.named_parameters(recurse=False) } )
319
416
  params_dict.update( { k + '.' + n : (b, True) for n, b in sub_module.named_buffers(recurse=False) } )
@@ -359,7 +456,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
359
456
  length = torch.numel(p.data) * p.data.element_size()
360
457
 
361
458
  ref_cache[ref] = (n, length)
362
- if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
459
+ if current_big_tensor_size + length > gig_tensor_size :
363
460
  big_tensors_sizes.append(current_big_tensor_size)
364
461
  current_big_tensor_size = 0
365
462
  big_tensor_no += 1
@@ -454,7 +551,6 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
454
551
  else:
455
552
  length = torch.numel(p.data) * p.data.element_size()
456
553
  p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
457
- p.aaaaa = n
458
554
  tensor_no += 1
459
555
  del p
460
556
  global total_pinned_bytes
@@ -479,7 +575,7 @@ def _welcome():
479
575
  if welcome_displayed:
480
576
  return
481
577
  welcome_displayed = True
482
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.2.5) by DeepBeepMeep ************{ENDC}{UNBOLD}")
578
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.2.7) by DeepBeepMeep ************{ENDC}{UNBOLD}")
483
579
 
484
580
  def _extract_num_from_str(num_in_str):
485
581
  size = len(num_in_str)
@@ -762,126 +858,63 @@ def split_linear_modules(model, map ):
762
858
 
763
859
  delattr(parent_module, module_suffix)
764
860
 
765
- def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
766
- self._check_forward_args(x, *args, **kwargs)
767
- adapter_names = kwargs.pop("adapter_names", None)
768
- if self.disable_adapters:
769
- if self.merged:
770
- self.unmerge()
771
- result = self.base_layer(x, *args, **kwargs)
772
- elif adapter_names is not None:
773
- result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
774
- elif self.merged:
775
- result = self.base_layer(x, *args, **kwargs)
776
- else:
777
- def get_scaling(active_adapter):
778
- scaling_dict = shared_state.get("_lora_scaling", None)
779
- if scaling_dict == None:
780
- return self.scaling[active_adapter]
781
- scaling_list = scaling_dict[active_adapter]
782
- if isinstance(scaling_list, list):
783
- step_no =shared_state.get("_lora_step_no", 0)
784
- return scaling_list[step_no]
785
- else:
786
- return float(scaling_list)
787
861
 
788
- base_weight = self.base_layer.weight
789
- new_weights = not isinstance(self.base_layer, QModuleMixin)
790
- if base_weight.shape[-1] < x.shape[-2] : # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
791
- for active_adapter in self.active_adapters:
792
- if active_adapter not in self.lora_A.keys():
793
- continue
794
- if self.use_dora[active_adapter]:
795
- raise Exception("Dora not yet supported by mmgp")
796
-
797
- lora_A = self.lora_A[active_adapter]
798
- lora_B = self.lora_B[active_adapter]
799
- dropout = self.lora_dropout[active_adapter]
800
- scaling = get_scaling(active_adapter)
801
- lora_A_weight = lora_A.weight
802
- lora_B_weight = lora_B.weight
803
- if new_weights:
804
- base_weight = torch.addmm(base_weight, lora_B_weight, lora_A_weight, alpha= scaling )
805
- # base_weight = base_weight + scaling * lora_B_weight @ lora_A_weight
806
- else:
807
- base_weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )
808
- # base_weight += scaling * lora_B_weight @ lora_A_weight
809
- new_weights = False
810
-
811
- if self.training:
812
- result = torch.nn.functional.linear(dropout(x), base_weight, bias=self.base_layer.bias)
813
- else:
814
- result = torch.nn.functional.linear(x, base_weight, bias=self.base_layer.bias)
815
- torch_result_dtype = result.dtype
862
+ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_loras = True, check_only = False, ignore_model_variations = False, pinnedLora = False, split_linear_modules_map = None, preprocess_sd = None, verboseLevel = -1,):
863
+ verboseLevel = _compute_verbose_level(verboseLevel)
864
+ modules_dict = {k: v for k,v in model.named_modules()}
865
+
866
+ if not check_only:
867
+ loras_model_data = dict()
868
+ model._loras_model_data = loras_model_data
869
+ loras_active_adapters = set()
870
+ model._loras_active_adapters = loras_active_adapters
871
+ loras_scaling = dict()
872
+ model._loras_scaling = loras_scaling
873
+ loras_tied_weights = dict()
874
+ model._loras_tied_weights = loras_tied_weights
816
875
 
876
+ CrLf = '\r\n'
877
+ error_msg = ""
878
+ def append(source, text ):
879
+ if len(source) == 0:
880
+ return text
817
881
  else:
818
- result = self.base_layer(x, *args, **kwargs)
819
- torch_result_dtype = result.dtype
820
- x = x.to(torch.bfloat16)
821
-
822
- for active_adapter in self.active_adapters:
823
- if active_adapter not in self.lora_A.keys():
824
- continue
825
- lora_A = self.lora_A[active_adapter]
826
- lora_B = self.lora_B[active_adapter]
827
- dropout = self.lora_dropout[active_adapter]
828
- scaling = get_scaling(active_adapter)
829
- x = x.to(lora_A.weight.dtype)
830
-
831
- if not self.use_dora[active_adapter]:
832
- if self.training:
833
- y = lora_A(dropout(x))
834
- else:
835
- y = lora_A(x)
836
-
837
- y = lora_B(y)
838
- y*= scaling
839
- result+= y
840
- del lora_A, lora_B, y
841
- # result = result + lora_B(lora_A(dropout(x))) * scaling
842
- else:
843
- if isinstance(dropout, torch.nn.Identity) or not self.training:
844
- base_result = result
845
- else:
846
- x = dropout(x)
847
- base_result = None
848
-
849
- result = result + self.lora_magnitude_vector[active_adapter](
850
- x,
851
- lora_A=lora_A,
852
- lora_B=lora_B,
853
- scaling=scaling,
854
- base_layer=self.get_base_layer(),
855
- base_result=base_result,
856
- )
857
-
858
- result = result.to(torch_result_dtype)
859
- return result
882
+ return source + CrLf + text
860
883
 
861
- def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_loras = True, split_linear_modules_map = None, preprocess_sd = None, verboseLevel = -1,):
862
- verboseLevel = _compute_verbose_level(verboseLevel)
863
-
864
- if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
865
- raise Exception("Unable to load Lora, missing 'peft' and / or 'diffusers' modules")
866
-
867
- from peft.tuners.lora import Linear
868
- Linear.forward = _lora_linear_forward
884
+ def trunc(text, sz):
885
+ if len(text) < sz:
886
+ return str(text)
887
+ else:
888
+ return str(text)[0:sz] + '...'
869
889
 
870
890
  if not isinstance(lora_path, list):
871
891
  lora_path = [lora_path]
872
892
 
873
893
  if lora_multi is None:
874
894
  lora_multi = [1. for _ in lora_path]
875
-
895
+ loras_nos = []
896
+ loras_multi = []
897
+ new_lora_path = []
898
+ errors = []
899
+ adapters = {}
900
+ adapter_no = 0
876
901
  for i, path in enumerate(lora_path):
877
- adapter_name = str(i)
878
-
902
+ adapter_name = str(adapter_no)
903
+ error_msg = ""
904
+ if not os.path.isfile(path):
905
+ error_msg = f"Lora '{path}' was not found"
906
+ errors.append((path, error_msg))
907
+ print(error_msg)
908
+ continue
909
+ fail = False
910
+ skip = False
879
911
  state_dict = safetensors2.torch_load_file(path)
912
+
880
913
  if preprocess_sd != None:
881
914
  state_dict = preprocess_sd(state_dict)
882
915
 
883
916
  if split_linear_modules_map != None:
884
- new_state_dict = {}
917
+ new_state_dict = dict()
885
918
  targets_A = { "."+k+".lora_A.weight" : k for k in split_linear_modules_map }
886
919
  targets_B = { "."+k+".lora_B.weight" : k for k in split_linear_modules_map }
887
920
  for module_name, module_data in state_dict.items():
@@ -911,82 +944,162 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
911
944
  new_state_dict[module_name] = module_data
912
945
  state_dict = new_state_dict
913
946
  del new_state_dict
947
+ # tied_weights = _extract_tie_weights_from_sd(state_dict, path) # to do
914
948
 
949
+ clean_up = False
915
950
  keys = list(state_dict.keys())
916
951
  if len(keys) == 0:
917
- raise Exception(f"Empty Lora '{path}'")
918
-
919
- network_alphas = {}
920
- for k in keys:
921
- if "alpha" in k:
922
- alpha_value = state_dict.pop(k)
923
- if not ( (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
924
- alpha_value, float
925
- )):
926
- network_alphas[k] = torch.tensor( float(alpha_value.item() ) )
927
-
928
- pos = keys[0].find(".")
929
- prefix = keys[0][0:pos]
930
- if not any( prefix.startswith(some_prefix) for some_prefix in ["diffusion_model", "transformer"]):
931
- msg = f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format."
932
- raise Exception(msg)
933
-
934
- transformer = model
935
-
936
- transformer_keys = [k for k in keys if k.startswith(prefix)]
937
- state_dict = {
938
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys
939
- }
940
-
941
- sd_keys = state_dict.keys()
942
- if len(sd_keys) == 0:
943
- print(f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format.")
944
- return
945
-
946
- # is_correct_format = all("lora" in key for key in state_dict.keys())
952
+ msg = f"Empty Lora '{path}'"
953
+ error_msg = append(error_msg, msg)
954
+ fail = True
955
+
956
+ if not fail:
957
+ network_alphas = {}
958
+ for k in keys:
959
+ if "alpha" in k:
960
+ alpha_value = state_dict.pop(k)
961
+ if not ( (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
962
+ alpha_value, float
963
+ )):
964
+ network_alphas[k] = torch.tensor( float(alpha_value.item() ) )
965
+
966
+ pos = keys[0].find(".")
967
+ prefix = keys[0][0:pos]
968
+ if prefix not in ["diffusion_model", "transformer"]:
969
+ msg = f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format."
970
+ error_msg = append(error_msg, msg)
971
+ fail = True
972
+
973
+ if not fail:
974
+ state_dict = { k[ len(prefix) + 1:]: v for k, v in state_dict.items() if k.startswith(prefix) }
975
+ rank = {}
976
+ clean_up = True
977
+
978
+ # for key, val in state_dict.items():
979
+ # if "lora_B" in key:
980
+ # rank[key] = val.shape[1]
981
+
982
+ # if network_alphas is not None and len(network_alphas) >= 1:
983
+ # alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
984
+ # network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
985
+ network_alphas = None
986
+
987
+ invalid_keys = []
988
+ unexpected_keys = []
989
+ for k, v in state_dict.items():
990
+ pos = k.rfind(".lora_")
991
+ if pos <=0:
992
+ invalid_keys.append(k)
993
+ continue
994
+ module_name = k[ : pos]
995
+ lora_key = k[ pos+1:]
996
+ lora_A = None
997
+ lora_B = None
998
+ if lora_key == "lora_A.weight":
999
+ lora_A = v
1000
+ elif lora_key == "lora_B.weight":
1001
+ lora_B = v
1002
+ else:
1003
+ invalid_keys.append(k)
1004
+ continue
947
1005
 
948
- # check with first key if is not in peft format
949
- # first_key = next(iter(state_dict.keys()))
950
- # if "lora_A" not in first_key:
951
- # state_dict = convert_unet_state_dict_to_peft(state_dict)
1006
+ module = modules_dict.get(module_name, None)
1007
+ if module == None:
1008
+ unexpected_keys.append(k)
1009
+ continue
1010
+ if not isinstance(module, (QLinear, torch.nn.Linear)):
1011
+ msg = f"Lora '{path}' contains a non linear layer '{k}'"
1012
+ error_msg = append(error_msg, msg)
1013
+ fail = True
1014
+ break
1015
+ module_shape = module.weight.shape
1016
+ if lora_A != None:
1017
+ if module_shape[1] != v.shape[1]:
1018
+ if ignore_model_variations:
1019
+ skip = True
1020
+ else:
1021
+ msg = f"Lora '{path}': Lora A dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[1]}, lora A = {v.shape[1]}). It is likely this Lora has been made for another version of this model."
1022
+ error_msg = append(error_msg, msg)
1023
+ fail = True
1024
+ break
1025
+ if lora_B != None:
1026
+ if module_shape[0] != v.shape[0]:
1027
+ if ignore_model_variations:
1028
+ skip = True
1029
+ else:
1030
+ msg = f"Lora '{path}': Lora B dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, lora B = {v.shape[0]}). It is likely this Lora has been made for another version of this model."
1031
+ error_msg = append(error_msg, msg)
1032
+ fail = True
1033
+ break
1034
+ if not check_only:
1035
+ loras_module_data = loras_model_data.get(module, None)
1036
+ if loras_module_data == None:
1037
+ loras_module_data = dict()
1038
+ loras_model_data[module] = loras_module_data
1039
+ loras_adapter_data = loras_module_data.get(adapter_name, None)
1040
+ if loras_adapter_data == None:
1041
+ loras_adapter_data = [lora_A, lora_B]
1042
+ loras_module_data[adapter_name] = loras_adapter_data
1043
+ elif lora_A != None:
1044
+ loras_adapter_data[0] = lora_A
1045
+ else:
1046
+ loras_adapter_data[1] = lora_B
1047
+ lora_A, lora_B, v, loras_module_data, loras_adapter_data = None, None, None, None, None
1048
+
1049
+ if len(invalid_keys) > 0:
1050
+ msg = "Lora '{path}' contains non Lora keys '{trunc(invalid_keys,200)}'"
1051
+ error_msg = append(error_msg, msg)
1052
+ fail = True
1053
+ if len(unexpected_keys) > 0:
1054
+ msg = f"Lora '{path}' contains unexpected module keys, it is likely that this Lora is for a different model : '{trunc(unexpected_keys,200)}'"
1055
+ error_msg = append(error_msg, msg)
1056
+ fail = True
1057
+ if fail or skip:
1058
+ if fail:
1059
+ errors.append((path, error_msg))
1060
+ print(error_msg)
1061
+ if clean_up and not check_only:
1062
+ for m,loras_module_data in loras_model_data.items():
1063
+ if adapter_name in loras_module_data:
1064
+ del loras_module_data[adapter_name]
952
1065
 
953
- if adapter_name in getattr(transformer, "peft_config", {}):
954
- raise ValueError(
955
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
956
- )
1066
+ else:
1067
+ if not check_only:
1068
+ # model._loras_tied_weights[adapter_name] = tied_weights
1069
+ if pinnedLora:
1070
+ _pin_sd_to_memory(state_dict, path)
957
1071
 
958
- rank = {}
959
- for key, val in state_dict.items():
960
- if "lora_B" in key:
961
- rank[key] = val.shape[1]
1072
+ del state_dict
962
1073
 
963
- if network_alphas is not None and len(network_alphas) >= 1:
964
- alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
965
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
966
1074
 
967
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
968
-
969
- lora_config = LoraConfig(**lora_config_kwargs)
970
- peft_kwargs = {}
971
- peft_kwargs["low_cpu_mem_usage"] = True
972
- inject_adapter_in_model(lora_config, model, adapter_name=adapter_name, **peft_kwargs)
973
-
974
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
975
-
976
- warn_msg = ""
977
- if incompatible_keys is not None:
978
- # Check only for unexpected keys.
979
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
980
- if unexpected_keys:
981
- raise Exception(f"Lora '{path}' contains invalid keys '{unexpected_keys}'")
982
-
983
- if verboseLevel >=1:
984
- print(f"Lora '{path}' was loaded in model '{_get_module_name(model)}'")
1075
+ adapters[adapter_name] = path
1076
+ loras_nos.append(adapter_name)
1077
+ new_lora_path.append(path)
1078
+ loras_multi.append(1.0 if i > (len(lora_multi) -1) else lora_multi[i])
1079
+ pass
1080
+ adapter_no += 1
1081
+ if verboseLevel >=1:
1082
+ if check_only:
1083
+ print(f"Lora '{path}' was found for model '{_get_module_name(model)}'")
1084
+ else:
1085
+ print(f"Lora '{path}' was loaded in model '{_get_module_name(model)}'")
1086
+
1087
+ model._loras_errors = errors
1088
+ if not check_only:
1089
+ model._loras_adapters = adapters
985
1090
  if activate_all_loras:
986
- set_weights_and_activate_adapters(model,[ str(i) for i in range(len(lora_multi))], lora_multi)
1091
+ activate_loras(model, loras_nos, loras_multi)
1092
+ return new_lora_path
987
1093
 
988
- def set_step_no_for_lora(step_no):
989
- shared_state["_lora_step_no"] = step_no
1094
+ def unload_loras_from_model(model):
1095
+ model._loras_model_data = None
1096
+ model._loras_errors = None
1097
+ model._loras_adapters = None
1098
+ model._loras_active_adapters = None
1099
+ model._loras_scaling = None
1100
+
1101
+ def set_step_no_for_lora(model, step_no):
1102
+ model._lora_step_no = step_no
990
1103
 
991
1104
  def activate_loras(model, lora_nos, lora_multi = None ):
992
1105
  if not isinstance(lora_nos, list):
@@ -996,15 +1109,13 @@ def activate_loras(model, lora_nos, lora_multi = None ):
996
1109
  if lora_multi is None:
997
1110
  lora_multi = [1. for _ in lora_nos]
998
1111
 
999
- lora_fake_scaling = [1. if isinstance(mult, list) else mult for mult in lora_multi ]
1000
1112
  lora_scaling_dict = {}
1001
1113
  for no, multi in zip(lora_nos, lora_multi):
1002
1114
  lora_scaling_dict[no] = multi
1003
1115
 
1004
- shared_state["_lora_scaling"] = lora_scaling_dict
1005
- shared_state["_lora_step_no"] = 0
1006
-
1007
- set_weights_and_activate_adapters(model, lora_nos, lora_fake_scaling)
1116
+ model._lora_step_no = 0
1117
+ model._loras_active_adapters = set(lora_nos)
1118
+ model._loras_scaling = lora_scaling_dict
1008
1119
 
1009
1120
 
1010
1121
  def move_loras_to_device(model, device="cpu" ):
@@ -1025,7 +1136,6 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
1025
1136
 
1026
1137
 
1027
1138
  import os.path
1028
- from accelerate import init_empty_weights
1029
1139
 
1030
1140
  if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
1031
1141
  raise Exception("full model path to file expected")
@@ -1350,12 +1460,12 @@ class offload:
1350
1460
  self.loaded_blocks = {}
1351
1461
  self.prev_blocks_names = {}
1352
1462
  self.next_blocks_names = {}
1353
- self.lora_parents = {}
1354
1463
  self.preloaded_blocks_per_model = {}
1355
1464
  self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
1356
1465
  self.transfer_stream = torch.cuda.Stream()
1357
1466
  self.async_transfers = False
1358
1467
  self.parameters_ref = {}
1468
+
1359
1469
  global last_offload_obj
1360
1470
  last_offload_obj = self
1361
1471
 
@@ -1379,15 +1489,12 @@ class offload:
1379
1489
  self.next_blocks_names[prev_entry_name] = entry_name
1380
1490
  bef = blocks_params_size
1381
1491
 
1382
- lora_name = None
1383
- if self.lora_parents.get(submodule, None) != None:
1384
- lora_name = str(submodule_name[ submodule_name.rfind(".") + 1: ] )
1385
1492
  for k,p in submodule.named_parameters(recurse=False):
1386
1493
  param_size = 0
1387
1494
  ref = _get_tensor_ref(p)
1388
1495
  tied_param = self.parameters_ref.get(ref, None)
1389
1496
  if isinstance(p, QTensor):
1390
- blocks_params.append( (submodule, k, p, False, tied_param, lora_name ) )
1497
+ blocks_params.append( (submodule, k, p, False, tied_param ) )
1391
1498
 
1392
1499
  if p._qtype == qint4:
1393
1500
  if hasattr(p,"_scale_shift"):
@@ -1401,7 +1508,7 @@ class offload:
1401
1508
  param_size += torch.numel(p._scale) * p._scale.element_size()
1402
1509
  param_size += torch.numel(p._data) * p._data.element_size()
1403
1510
  else:
1404
- blocks_params.append( (submodule, k, p, False, tied_param, lora_name) )
1511
+ blocks_params.append( (submodule, k, p, False, tied_param) )
1405
1512
  param_size += torch.numel(p.data) * p.data.element_size()
1406
1513
 
1407
1514
 
@@ -1410,7 +1517,7 @@ class offload:
1410
1517
  self.parameters_ref[ref] = (submodule, k)
1411
1518
 
1412
1519
  for k, p in submodule.named_buffers(recurse=False):
1413
- blocks_params.append( (submodule, k, p, True, None, lora_name) )
1520
+ blocks_params.append( (submodule, k, p, True, None) )
1414
1521
  blocks_params_size += p.data.nbytes
1415
1522
 
1416
1523
  aft = blocks_params_size
@@ -1435,6 +1542,19 @@ class offload:
1435
1542
  return False
1436
1543
  return True
1437
1544
 
1545
+ def _move_loras(self, loras_active_adapters, loras_modules, to_GPU):
1546
+ for name, lora_module in loras_modules.items():
1547
+ for adapter in loras_active_adapters:
1548
+ lora_data = lora_module.get(adapter, None)
1549
+ if lora_data == None:
1550
+ continue
1551
+ lora_A, lora_B = lora_data
1552
+ key = adapter + '_GPU'
1553
+ if to_GPU:
1554
+ lora_module[key] = [lora_A.cuda(), lora_B.cuda()]
1555
+ elif key in lora_module:
1556
+ del lora_module[key]
1557
+
1438
1558
  @torch.compiler.disable()
1439
1559
  def gpu_load_blocks(self, model_id, blocks_name, preload = False):
1440
1560
  # cl = clock.start()
@@ -1443,12 +1563,17 @@ class offload:
1443
1563
  entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
1444
1564
 
1445
1565
  def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None
1566
+ model = self.models[model_id]
1567
+ loras_modules = {}
1568
+ loras_active_adapters = getattr(model ,"_loras_active_adapters", None)
1569
+ if loras_active_adapters == None or len(loras_active_adapters) == 0:
1570
+ loras_model_data = None
1571
+ else:
1572
+ loras_model_data = getattr(model, "_loras_model_data", None)
1573
+
1446
1574
  with torch.cuda.stream(stream_to_use):
1447
1575
  for param in blocks_params:
1448
- parent_module, n, p, is_buffer, tied_param, lora_name = param
1449
- if lora_name != None:
1450
- if not lora_name in self.lora_parents[parent_module].active_adapters:
1451
- continue
1576
+ parent_module, n, p, is_buffer, tied_param = param
1452
1577
 
1453
1578
  if tied_param != None:
1454
1579
  tied_p = getattr( tied_param[0], tied_param[1])
@@ -1466,6 +1591,12 @@ class offload:
1466
1591
  if tied_param != None:
1467
1592
  setattr( tied_param[0], tied_param[1], q)
1468
1593
  del p, q
1594
+ if loras_model_data != None:
1595
+ lora_data = loras_model_data.get(parent_module, None)
1596
+ if lora_data != None:
1597
+ loras_modules[parent_module]= lora_data
1598
+ if len(loras_modules) > 0:
1599
+ self._move_loras(loras_active_adapters, loras_modules, True)
1469
1600
 
1470
1601
  loaded_block = self.loaded_blocks[model_id]
1471
1602
 
@@ -1526,14 +1657,31 @@ class offload:
1526
1657
  print(f"Unloading model {blocks_name} ({model_name}) from GPU")
1527
1658
 
1528
1659
  blocks_params = self.blocks_of_modules[blocks_name]
1660
+ model = self.models[model_id]
1661
+ loras_modules = {}
1662
+ loras_active_adapters = getattr(model ,"_loras_active_adapters", None)
1663
+ if loras_active_adapters == None or len(loras_active_adapters) == 0 :
1664
+ loras_model_data = None
1665
+ else:
1666
+ loras_model_data = getattr(model, "_loras_model_data", None)
1667
+
1529
1668
  for param in blocks_params:
1530
- parent_module, n, p, is_buffer, _, _ = param
1669
+ parent_module, n, p, is_buffer, _ = param
1531
1670
  if is_buffer:
1532
1671
  q = torch.nn.Buffer(p)
1533
1672
  else:
1534
1673
  q = torch.nn.Parameter(p , requires_grad=False)
1535
1674
  setattr(parent_module, n , q)
1536
1675
  del p, q
1676
+
1677
+ if loras_model_data != None:
1678
+ lora_data = loras_model_data.get(parent_module, None)
1679
+ if lora_data != None:
1680
+ loras_modules[parent_module]= lora_data
1681
+
1682
+ if len(loras_modules) > 0:
1683
+ self._move_loras(loras_active_adapters, loras_modules, False)
1684
+
1537
1685
  # cl.stop()
1538
1686
  # print(f"unload time: {cl.format_time_gap()}")
1539
1687
 
@@ -1621,6 +1769,92 @@ class offload:
1621
1769
 
1622
1770
  return False
1623
1771
 
1772
+ def _lora_linear_forward(self, model, submodule, loras_data, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1773
+
1774
+ def get_scaling(active_adapter):
1775
+ scaling_list = loras_scaling[active_adapter]
1776
+ if isinstance(scaling_list, list):
1777
+ step_no =getattr(model, "_lora_step_no", 0)
1778
+ return scaling_list[step_no]
1779
+ else:
1780
+ return float(scaling_list)
1781
+
1782
+ weight = submodule.weight
1783
+
1784
+ if loras_data == None:
1785
+ return torch.nn.functional.linear(x, weight, bias=submodule.bias)
1786
+
1787
+ active_adapters = model._loras_active_adapters
1788
+ loras_scaling = model._loras_scaling
1789
+ training = False
1790
+
1791
+
1792
+ if weight.shape[-1] < x.shape[-2] : # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
1793
+ if len(active_adapters) > 0:
1794
+ if isinstance(submodule, QModuleMixin):
1795
+ weight = weight.view(weight.shape) # get a persistent copy of the on the fly dequantized weights
1796
+ else:
1797
+ weight = weight.clone()
1798
+
1799
+
1800
+ for active_adapter in active_adapters:
1801
+ data = loras_data.get(active_adapter + '_GPU', None)
1802
+ if data == None:
1803
+ continue
1804
+ lora_A_weight, lora_B_weight = data
1805
+ scaling = get_scaling(active_adapter)
1806
+ weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )
1807
+ # base_weight += scaling * lora_B_weight @ lora_A_weight
1808
+
1809
+ if training:
1810
+ pass
1811
+ # result = torch.nn.functional.linear(dropout(x), base_weight, bias=submodule.bias)
1812
+ else:
1813
+ result = torch.nn.functional.linear(x, weight, bias=submodule.bias)
1814
+
1815
+ else:
1816
+ result = torch.nn.functional.linear(x, weight, bias=submodule.bias)
1817
+
1818
+ if len(active_adapters) > 0:
1819
+ x = x.to(torch.bfloat16)
1820
+
1821
+ for active_adapter in active_adapters:
1822
+ data = loras_data.get(active_adapter + '_GPU', None)
1823
+ if data == None:
1824
+ continue
1825
+ lora_A, lora_B = data
1826
+ # dropout = self.lora_dropout[active_adapter]
1827
+ scaling = get_scaling(active_adapter)
1828
+ x = x.to(lora_A.dtype)
1829
+
1830
+ if training:
1831
+ pass
1832
+ # y = lora_A(dropout(x))
1833
+ else:
1834
+ y = torch.nn.functional.linear(x, lora_A, bias=None)
1835
+
1836
+ y = torch.nn.functional.linear(y, lora_B, bias=None)
1837
+ y*= scaling
1838
+ result+= y
1839
+ del y
1840
+
1841
+ return result
1842
+
1843
+
1844
+ def hook_lora_linear(self, submodule, current_model, model_id, submodule_name):
1845
+ old_forward = submodule.forward
1846
+ def lora_linear_forward(module, *args, **kwargs):
1847
+ loras_model_data = getattr(current_model, "_loras_model_data", None)
1848
+ loras_data = None
1849
+ if loras_model_data != None:
1850
+ loras_data = loras_model_data.get(submodule, None)
1851
+ if loras_data == None:
1852
+ return old_forward(*args, **kwargs)
1853
+ else:
1854
+ return self._lora_linear_forward(current_model, submodule, loras_data, *args, **kwargs)
1855
+
1856
+ return functools.update_wrapper(functools.partial(lora_linear_forward, submodule), old_forward)
1857
+
1624
1858
  def ensure_model_loaded(self, model_id):
1625
1859
  if model_id in self.active_models_ids:
1626
1860
  return
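
The comment in `_lora_linear_forward` above captures the key trade-off: when the input carries more tokens than the layer has input features, it is cheaper to fold `scaling * lora_B @ lora_A` into a copy of the base weight once than to push the activations through each low-rank pair, and the two paths are mathematically identical. A self-contained sketch of that equivalence, assuming plain float weights (the `lora_linear` helper and its `fold_threshold` argument are illustrative, not mmgp's API):

```python
import torch

def lora_linear(x, weight, bias, lora_A, lora_B, scaling, fold_threshold=None):
    """y = x @ (W + scaling * B @ A)^T + bias, computed one of two equivalent ways.

    weight: (out, in), lora_A: (r, in), lora_B: (out, r).
    Folding costs one (out, r) x (r, in) matmul regardless of batch size, so it wins
    once the token count x.shape[-2] exceeds roughly the layer's input width.
    """
    if fold_threshold is None:
        fold_threshold = weight.shape[-1]  # same heuristic as the forward above
    if x.shape[-2] > fold_threshold:
        w = weight.clone()                       # keep the base weight untouched
        w.addmm_(lora_B, lora_A, alpha=scaling)  # W += scaling * B @ A
        return torch.nn.functional.linear(x, w, bias)
    y = torch.nn.functional.linear(x, weight, bias)
    y += scaling * torch.nn.functional.linear(torch.nn.functional.linear(x, lora_A), lora_B)
    return y

# Both paths agree up to floating-point error:
x = torch.randn(2, 8, 64)                        # (batch, tokens, in_features)
W, b = torch.randn(32, 64), torch.randn(32)
A, B = torch.randn(4, 64), torch.randn(32, 4)    # a rank-4 adapter
y_fold = lora_linear(x, W, b, A, B, 0.5, fold_threshold=0)      # force the folded path
y_act  = lora_linear(x, W, b, A, B, 0.5, fold_threshold=10**9)  # force the activation path
assert torch.allclose(y_fold, y_act, atol=1e-4)
```
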
@@ -1802,6 +2036,8 @@ class offload:
1802
2036
 
1803
2037
  for model_id, model in self.models.items():
1804
2038
  move_loras_to_device(model, "cpu")
2039
+ if hasattr(model, "_loras_model_data"):
2040
+ unload_loras_from_model(model)
1805
2041
 
1806
2042
  self.models = None
1807
2043
 
@@ -1811,7 +2047,7 @@ class offload:
1811
2047
 
1812
2048
 
1813
2049
 
1814
- def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, coTenantsMap = None, verboseLevel = -1):
2050
+ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, loras = None, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, coTenantsMap = None, verboseLevel = -1):
1815
2051
  """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
1816
2052
  pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
1817
2053
  quantizeTransformer: set True by default will quantize on the fly the video / image model
@@ -1863,7 +2099,8 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1863
2099
  _welcome()
1864
2100
  if coTenantsMap != None:
1865
2101
  self.cotenants_map = coTenantsMap
1866
-
2102
+ if loras != None and isinstance(loras, str):
2103
+ loras = [loras]
1867
2104
  self.models = models
1868
2105
 
1869
2106
  extraModelsToQuantize = extraModelsToQuantize if extraModelsToQuantize is not None else []
@@ -2010,12 +2247,12 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
2010
2247
  if self.verboseLevel >=1:
2011
2248
  print(f"Model '{model_id}' already pinned to reserved memory")
2012
2249
  else:
2013
- _pin_to_memory(current_model, model_id, partialPinning= partialPinning, verboseLevel=verboseLevel)
2250
+ _pin_to_memory(current_model, model_id, partialPinning= partialPinning, pinnedPEFTLora = pinnedPEFTLora, verboseLevel=verboseLevel)
2014
2251
 
2015
2252
  current_budget = model_budgets[model_id]
2016
2253
  cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq, is_mod_seq = None, None, None, -1, False
2017
2254
  self.loaded_blocks[model_id] = None
2018
-
2255
+ any_lora = loras !=None and model_id in loras or getattr(current_model, "_loras_model_data", False)
2019
2256
  for submodule_name, submodule in current_model.named_modules():
2020
2257
  # create a fake 'accelerate' parameter so that the _execution_device property returns always "cuda"
2021
2258
  # (it is queried in many pipelines even if offloading is not properly implemented)
@@ -2047,7 +2284,10 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
2047
2284
 
2048
2285
 
2049
2286
  if hasattr(submodule, "forward"):
2050
- submodule_method = getattr(submodule, "forward")
2287
+ if any_lora and isinstance(submodule, torch.nn.Linear):
2288
+ submodule_method = self.hook_lora_linear(submodule, current_model, model_id, submodule_name)
2289
+ else:
2290
+ submodule_method = getattr(submodule, "forward")
2051
2291
  if callable(submodule_method):
2052
2292
  if len(submodule_name.split("."))==1:
2053
2293
  self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method)
@@ -2058,13 +2298,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
2058
2298
 
2059
2299
  self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name, submodule_name)
2060
2300
 
2061
- if hasattr(submodule, "active_adapters"):
2062
- for dictmodule in ["lora_A","lora_B"]:
2063
- ssubmod = getattr(submodule, dictmodule, None)
2064
- if ssubmod !=None:
2065
- for k, loramod in ssubmod._modules.items():
2066
- self.lora_parents[loramod] = submodule
2067
-
2068
2301
 
2069
2302
  self.tune_preloading(model_id, current_budget, towers_names)
2070
2303
 
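
Taken together, the module-level LoRA functions in this release replace the previous peft/diffusers-based path (note the dropped `peft` dependency in the METADATA diff below). A hedged usage sketch based only on the signatures visible in this diff; the module dictionary, the "transformer" model name and the LoRA file paths are placeholders, and a CUDA device plus real Diffusers-format LoRA files are assumed:

```python
import torch
from mmgp import offload

# Toy stand-in: offload.all() accepts either a pipeline object or a dictionary
# of modules. Requires a CUDA device.
transformer = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
offload.all({"transformer": transformer}, quantizeTransformer=False, loras="transformer")

# Load one or more Diffusers-format LoRA files; rejected files are reported in
# model._loras_errors and the accepted paths are returned.
accepted = offload.load_loras_into_model(
    transformer,
    ["style_lora.safetensors", "detail_lora.safetensors"],  # placeholder paths
    lora_multi=[1.0, 0.8],
    activate_all_loras=True,
)

# Adapters are numbered "0", "1", ... in load order (assuming both files loaded).
# Passing a list as a multiplier gives a per-denoising-step scaling schedule.
offload.activate_loras(transformer, ["0", "1"], [1.0, [1.0, 0.9, 0.8]])
offload.set_step_no_for_lora(transformer, 0)  # update before each sampling step
```
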
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: mmgp
3
- Version: 3.2.5
3
+ Version: 3.2.7
4
4
  Summary: Memory Management for the GPU Poor
5
5
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -13,11 +13,10 @@ Requires-Dist: optimum-quanto
13
13
  Requires-Dist: accelerate
14
14
  Requires-Dist: safetensors
15
15
  Requires-Dist: psutil
16
- Requires-Dist: peft
17
16
 
18
17
 
19
18
  <p align="center">
20
- <H2>Memory Management 3.2.4 for the GPU Poor by DeepBeepMeep</H2>
19
+ <H2>Memory Management 3.2.7 for the GPU Poor by DeepBeepMeep</H2>
21
20
  </p>
22
21
 
23
22
 
@@ -0,0 +1,9 @@
1
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
3
+ mmgp/offload.py,sha256=6qJrxM3EPqUHC04njZetVY2sr2x9DQwh13CZIM5oLIA,105417
4
+ mmgp/safetensors2.py,sha256=DCdlRH3769CTyraAmWAB3b0XrVua7z6ygQ-OyKgJN6A,16453
5
+ mmgp-3.2.7.dist-info/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
6
+ mmgp-3.2.7.dist-info/METADATA,sha256=zu_MxYB3j6sYNqQShyKnNwJkv0_j-fO6qOHoO8PUUfY,16131
7
+ mmgp-3.2.7.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
8
+ mmgp-3.2.7.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
9
+ mmgp-3.2.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.2)
2
+ Generator: setuptools (76.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,9 +0,0 @@
1
- __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
3
- mmgp/offload.py,sha256=XQOTMMp5UQku3byZwDr_dYgD3tK4DNTZkwotVyPg-Lk,96434
4
- mmgp/safetensors2.py,sha256=DCdlRH3769CTyraAmWAB3b0XrVua7z6ygQ-OyKgJN6A,16453
5
- mmgp-3.2.5.dist-info/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
6
- mmgp-3.2.5.dist-info/METADATA,sha256=s6c1X2ar9DQH1CiLAHdO5X60fuNfKqfmqu-xL_W6j5s,16151
7
- mmgp-3.2.5.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
8
- mmgp-3.2.5.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
9
- mmgp-3.2.5.dist-info/RECORD,,