mmgp 3.1.4.post1592653__py3-none-any.whl → 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
- # ------------------ Memory Management 3.1.4-159265 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+ # ------------------ Memory Management 3.2.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
  #
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -81,6 +81,7 @@ from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMix
  # support for Embedding module quantization that is not supported by default by quanto
  @register_qmodule(torch.nn.Embedding)
  class QEmbedding(QModuleMixin, torch.nn.Embedding):
+ bias = None
  @classmethod
  def qcreate(cls, module, weights, activations = None, optimizer = None, device = None):
  module.bias = None
@@ -331,13 +332,19 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):

  ref_cache = {}
  tied_weights = {}
+ tied_weights_count = 0
+ tied_weights_total = 0
+ tied_weights_last = None
+
  for n, (p, _) in params_dict.items():
  ref = _get_tensor_ref(p)
  match = ref_cache.get(ref, None)
  if match != None:
  match_name, match_size = match
+ tied_weights_count += 1
+ tied_weights_total += match_size
  if verboseLevel >=1:
- print(f"Tied weights of {match_size/ONE_MB:0.2f} MB detected: {match_name} <-> {n}")
+ tied_weights_last = f"{match_name} <-> {n}"
  tied_weights[n] = match_name
  else:
  if isinstance(p, QTensor):
@@ -366,6 +373,13 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):

  total_tensor_bytes += length

+ if verboseLevel >=1 and tied_weights_count > 0:
+ if tied_weights_count == 1:
+ print(f"Tied weights of {tied_weights_total/ONE_MB:0.2f} MB detected: {tied_weights_last}")
+ else:
+ print(f"Found {tied_weights_count} tied weights for a total of {tied_weights_total/ONE_MB:0.2f} MB, last : {tied_weights_last}")
+
+
  big_tensors_sizes.append(current_big_tensor_size)

  big_tensors = []
@@ -392,14 +406,22 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
  tensor_no = 0
  # prev_big_tensor = 0
  for n, (p, is_buffer) in params_dict.items():
- if n in tied_weights:
+ q_name = tied_weights.get(n,None)
+ if q_name != None:
+ q , _ = params_dict[q_name]
  if isinstance(p, QTensor):
  if p._qtype == qint4:
+ p._data._data = q._data._data
+ p._scale_shift = q._scale_shift
  assert p._data._data.data.is_pinned()
  else:
+ p._data = q._data
+ p._scale = q._scale
  assert p._data.is_pinned()
  else:
+ p.data = q.data
  assert p.data.is_pinned()
+ q = None
  else:
  big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
  # if big_tensor_no != prev_big_tensor:
@@ -432,7 +454,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
  else:
  length = torch.numel(p.data) * p.data.element_size()
  p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
-
+ p.aaaaa = n
  tensor_no += 1
  del p
  global total_pinned_bytes
@@ -457,7 +479,7 @@ def _welcome():
  if welcome_displayed:
  return
  welcome_displayed = True
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.1.4-1592653) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.2.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")

  def _extract_num_from_str(num_in_str):
  size = len(num_in_str)
@@ -553,7 +575,7 @@ def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict



- def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 1000000000, model_id = 'Unknown'):
+ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2**31, model_id = 'Unknown'):

  total_size =0
  total_excluded = 0
@@ -603,25 +625,14 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10

  if not any(submodule_name.startswith(pre) for pre in tower_names):
  flush = False
- if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
- if cur_blocks_prefix == None:
- cur_blocks_prefix = submodule_name + "."
- flush = True
- else:
- if not submodule_name.startswith(cur_blocks_prefix):
- cur_blocks_prefix = submodule_name + "."
- flush = True
- else:
- if cur_blocks_prefix is not None:
- #if not cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
- if not submodule_name.startswith(cur_blocks_prefix):
- cur_blocks_prefix = None
- flush = True
+ if cur_blocks_prefix == None or not submodule_name.startswith(cur_blocks_prefix):
+ cur_blocks_prefix = submodule_name + "."
+ flush = True

  if flush :
  if submodule_size <= threshold :
  exclude_list += submodule_names
- if verboseLevel >=2:
+ if verboseLevel >=2 and submodule_size >0:
  print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
  total_excluded += submodule_size

@@ -632,7 +643,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
  submodule_names.append(submodule_name)
  total_size += size

- if submodule_size >0 and submodule_size <= threshold :
+ if submodule_size >0 :
  exclude_list += submodule_names
  if verboseLevel >=2:
  print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
@@ -645,7 +656,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
  print(f"Can't find any module to exclude from quantization, full model ({total_size/ONE_MB:.1f} MB) will be quantized")
  else:
  print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB of {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
- if perc_excluded >= 0.10:
+ if perc_excluded >= 0.20:
  if verboseLevel >=2:
  print(f"Too many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
  exclude_list = None
@@ -709,6 +720,48 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10

  return True

+ def split_linear_modules(model, map ):
+ from optimum.quanto import QModuleMixin, WeightQBytesTensor, QLinear
+ from accelerate import init_empty_weights
+
+ modules_dict = { k: m for k, m in model.named_modules()}
+ for module_suffix, split_info in map.items():
+ mapped_modules = split_info["mapped_modules"]
+ split_sizes = split_info["split_sizes"]
+ for k, module in modules_dict.items():
+ if k.endswith("." + module_suffix):
+ parent_module = modules_dict[k[:len(k)-len(module_suffix)-1]]
+ weight = module.weight
+ bias = getattr(module, "bias", None)
+ if isinstance(module, QModuleMixin):
+ _data = weight._data
+ _scale = weight._scale
+ sub_data = torch.split(_data, split_sizes, dim=0)
+ sub_scale = torch.split(_scale, split_sizes, dim=0)
+ sub_bias = torch.split(bias, split_sizes, dim=0)
+ for sub_name, _subdata, _subbias, _subscale in zip(mapped_modules, sub_data, sub_bias, sub_scale):
+ with init_empty_weights():
+ sub_module = QLinear(_subdata.shape[1], _subdata.shape[0], bias=bias != None, device ="cpu", dtype=torch.bfloat16)
+ sub_module.weight = torch.nn.Parameter(WeightQBytesTensor.create(weight.qtype, weight.axis, _subdata.size(), weight.stride(), _subdata, _subscale, activation_qtype=weight.activation_qtype, requires_grad=weight.requires_grad ))
+ if bias != None:
+ sub_module.bias = torch.nn.Parameter(_subbias)
+ sub_module.optimizer = module.optimizer
+ sub_module.weight_qtype = module.weight_qtype
+ setattr(parent_module, sub_name, sub_module)
+ # del _data, _scale, _subdata, sub_d
+ else:
+ sub_data = torch.split(weight, split_sizes, dim=0)
+ sub_bias = torch.split(bias, split_sizes, dim=0)
+ for sub_name, subdata, subbias in zip(mapped_modules, sub_data, sub_bias):
+ with init_empty_weights():
+ sub_module = torch.nn.Linear( subdata.shape[1], subdata.shape[0], bias=bias != None, device ="cpu", dtype=torch.bfloat16)
+ sub_module.weight = torch.nn.Parameter(subdata , requires_grad=False)
+ if bias != None:
+ sub_module.bias = torch.nn.Parameter(subbias)
+ setattr(parent_module, sub_name, sub_module)
+
+ delattr(parent_module, module_suffix)
+
  def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
  self._check_forward_args(x, *args, **kwargs)
  adapter_names = kwargs.pop("adapter_names", None)
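The split_linear_modules helper added above replaces one fused linear layer with several smaller ones, slicing its weight (and bias or quantization scale) along dim 0. A minimal sketch of a call, assuming the helper is exposed at module level like the other offload functions; the model, sub-module names and split sizes below are illustrative only:

# Illustrative only (not from the package): split a fused "qkv" projection into
# three separate linear layers, assuming offload.split_linear_modules is callable
# as a module-level helper.
import torch
from mmgp import offload

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = torch.nn.Linear(64, 3 * 64, bias=True, dtype=torch.bfloat16)

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = TinyBlock()

model = TinyModel()

split_map = {
    "qkv": {                                         # suffix of the fused module to split
        "mapped_modules": ["to_q", "to_k", "to_v"],  # new sub-module names on the parent
        "split_sizes": [64, 64, 64],                 # rows of the fused weight per sub-module
    }
}

# Every module whose name ends with ".qkv" is sliced along dim 0 (weight, bias and,
# for quantized layers, scale) and replaced by to_q / to_k / to_v on its parent.
offload.split_linear_modules(model, split_map)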
@@ -721,8 +774,20 @@ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor
  elif self.merged:
  result = self.base_layer(x, *args, **kwargs)
  else:
+ def get_scaling(active_adapter):
+ scaling_dict = shared_state.get("_lora_scaling", None)
+ if scaling_dict == None:
+ return self.scaling[active_adapter]
+ scaling_list = scaling_dict[active_adapter]
+ if isinstance(scaling_list, list):
+ step_no =shared_state.get("_lora_step_no", 0)
+ return scaling_list[step_no]
+ else:
+ return float(scaling_list)
+
  base_weight = self.base_layer.weight
- if base_weight.shape[-1] < x.shape[-2]: # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
+ new_weights = not isinstance(self.base_layer, QModuleMixin)
+ if base_weight.shape[-1] < x.shape[-2] : # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
  for active_adapter in self.active_adapters:
  if active_adapter not in self.lora_A.keys():
  continue
@@ -732,11 +797,16 @@ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor
  lora_A = self.lora_A[active_adapter]
  lora_B = self.lora_B[active_adapter]
  dropout = self.lora_dropout[active_adapter]
- scaling = self.scaling[active_adapter]
+ scaling = get_scaling(active_adapter)
  lora_A_weight = lora_A.weight
  lora_B_weight = lora_B.weight
- lora_BA = lora_B_weight @ lora_A_weight
- base_weight += scaling * lora_BA
+ if new_weights:
+ base_weight = torch.addmm(base_weight, lora_B_weight, lora_A_weight, alpha= scaling )
+ # base_weight = base_weight + scaling * lora_B_weight @ lora_A_weight
+ else:
+ base_weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )
+ # base_weight += scaling * lora_B_weight @ lora_A_weight
+ new_weights = False

  if self.training:
  result = torch.nn.functional.linear(dropout(x), base_weight, bias=self.base_layer.bias)
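The commented-out lines in the hunk above spell out the identity this change relies on: torch.addmm(W, B, A, alpha=s) equals W + s * (B @ A). A standalone check, independent of mmgp:

# Standalone check of the identity used above (shapes follow a LoRA adapter:
# W is out x in, B is out x rank, A is rank x in).
import torch

torch.manual_seed(0)
W = torch.randn(16, 8)
B = torch.randn(16, 4)
A = torch.randn(4, 8)
s = 0.7

reference = W + s * (B @ A)
assert torch.allclose(torch.addmm(W, B, A, alpha=s), reference, atol=1e-5)  # out-of-place form
W_copy = W.clone()
W_copy.addmm_(B, A, alpha=s)                                                # in-place form
assert torch.allclose(W_copy, reference, atol=1e-5)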
@@ -755,7 +825,7 @@ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor
  lora_A = self.lora_A[active_adapter]
  lora_B = self.lora_B[active_adapter]
  dropout = self.lora_dropout[active_adapter]
- scaling = self.scaling[active_adapter]
+ scaling = get_scaling(active_adapter)
  x = x.to(lora_A.weight.dtype)

  if not self.use_dora[active_adapter]:
@@ -788,7 +858,7 @@ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor
  result = result.to(torch_result_dtype)
  return result

- def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_loras = True, verboseLevel = -1,):
+ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_loras = True, split_linear_modules_map = None,verboseLevel = -1,):
  verboseLevel = _compute_verbose_level(verboseLevel)

  if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
@@ -806,16 +876,45 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  for i, path in enumerate(lora_path):
  adapter_name = str(i)

-
-
-
  state_dict = safetensors2.torch_load_file(path)

+
+ if split_linear_modules_map != None:
+ new_state_dict = {}
+ targets_A = { "."+k+".lora_A.weight" : k for k in split_linear_modules_map }
+ targets_B = { "."+k+".lora_B.weight" : k for k in split_linear_modules_map }
+ for module_name, module_data in state_dict.items():
+ if any(module_name.endswith(suffix) for suffix in targets_B):
+ for suffix, target_module in targets_B.items():
+ if module_name.endswith(suffix):
+ break
+ parent_module_name = module_name[:-len(suffix)]
+ map = split_linear_modules_map[target_module]
+ mapped_modules = map["mapped_modules"]
+ split_sizes = map["split_sizes"]
+ sub_data = torch.split(module_data, split_sizes, dim=0)
+ for sub_name, subdata, in zip(mapped_modules, sub_data):
+ new_module_name = parent_module_name + "." + sub_name + ".lora_B.weight"
+ new_state_dict[new_module_name] = subdata
+ elif any(module_name.endswith(suffix) for suffix in targets_A):
+ for suffix, target_module in targets_A.items():
+ if module_name.endswith(suffix):
+ break
+ parent_module_name = module_name[:-len(suffix)]
+ map = split_linear_modules_map[target_module]
+ mapped_modules = map["mapped_modules"]
+ for sub_name in mapped_modules :
+ new_module_name = parent_module_name + "." + sub_name + ".lora_A.weight"
+ new_state_dict[new_module_name] = module_data
+ else:
+ new_state_dict[module_name] = module_data
+ state_dict = new_state_dict
+ del new_state_dict
+
  keys = list(state_dict.keys())
  if len(keys) == 0:
  raise Exception(f"Empty Lora '{path}'")

-
  network_alphas = {}
  for k in keys:
  if "alpha" in k:
@@ -884,13 +983,26 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  if activate_all_loras:
  set_weights_and_activate_adapters(model,[ str(i) for i in range(len(lora_multi))], lora_multi)

+ def set_step_no_for_lora(step_no):
+ shared_state["_lora_step_no"] = step_no
+
  def activate_loras(model, lora_nos, lora_multi = None ):
  if not isinstance(lora_nos, list):
  lora_nos = [lora_nos]
  lora_nos = [str(l) for l in lora_nos]
+
  if lora_multi is None:
  lora_multi = [1. for _ in lora_nos]
- set_weights_and_activate_adapters(model, lora_nos, lora_multi)
+
+ lora_fake_scaling = [1. if isinstance(mult, list) else mult for mult in lora_multi ]
+ lora_scaling_dict = {}
+ for no, multi in zip(lora_nos, lora_multi):
+ lora_scaling_dict[no] = multi
+
+ shared_state["_lora_scaling"] = lora_scaling_dict
+ shared_state["_lora_step_no"] = 0
+
+ set_weights_and_activate_adapters(model, lora_nos, lora_fake_scaling)


  def move_loras_to_device(model, device="cpu" ):
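The hunks above add per-step LoRA scaling: activate_loras now accepts a list as a multiplier and set_step_no_for_lora selects which entry applies at each step. A hedged usage sketch; the LoRA file names, the step callback and the three-step loop are placeholders, not values from the package:

# Illustrative only: file names and the number of steps are placeholders;
# `transformer` is whatever model the loras were trained for.
from mmgp import offload

def run_with_scheduled_loras(transformer, denoise_one_step):
    lora_paths = ["lora_style.safetensors", "lora_detail.safetensors"]
    offload.load_loras_into_model(transformer, lora_paths, activate_all_loras=False)

    # A plain float keeps a constant multiplier; a list is indexed by the current
    # step (shared_state["_lora_step_no"]) inside the patched _lora_linear_forward.
    offload.activate_loras(transformer, [0, 1], [1.0, [0.9, 0.6, 0.3]])

    for step_no in range(3):
        offload.set_step_no_for_lora(step_no)  # pick which entry of the list applies
        denoise_one_step(transformer)          # placeholder for one denoising step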
@@ -903,7 +1015,7 @@ def move_loras_to_device(model, device="cpu" ):
  if ".lora_" in k:
  m.to(device)

- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, modelClass=None, verboseLevel = -1):
  """
  quick version of .LoadfromPretrained of the transformers library
  used to build a model and load the corresponding weights (quantized or not)
@@ -917,6 +1029,7 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  raise Exception("full model path to file expected")

  model_path = _get_model(model_path)
+
  verboseLevel = _compute_verbose_level(verboseLevel)

  with safetensors2.safe_open(model_path) as f:
@@ -927,8 +1040,11 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  else:
  transformer_config = metadata.get("config", None)

- if transformer_config == None:
- config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
+ if transformer_config == None or forcedConfigPath != None:
+ if forcedConfigPath != None:
+ config_fullpath = forcedConfigPath
+ else:
+ config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")

  if not os.path.isfile(config_fullpath):
  raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
@@ -941,11 +1057,13 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  if "architectures" in transformer_config:
  architectures = transformer_config["architectures"]
  class_name = architectures[0]
-
- module = __import__("transformers")
- map = { "T5WithLMHeadModel" : "T5EncoderModel"}
- class_name = map.get(class_name, class_name)
- transfomer_class = getattr(module, class_name)
+ if modelClass !=None:
+ transfomer_class = modelClass
+ else:
+ module = __import__("transformers")
+ map = { "T5WithLMHeadModel" : "T5EncoderModel"}
+ class_name = map.get(class_name, class_name)
+ transfomer_class = getattr(module, class_name)
  from transformers import AutoConfig

  import tempfile
@@ -964,8 +1082,11 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  elif "_class_name" in transformer_config:
  class_name = transformer_config["_class_name"]

- module = __import__("diffusers")
- transfomer_class = getattr(module, class_name)
+ if modelClass !=None:
+ transfomer_class = modelClass
+ else:
+ module = __import__("diffusers")
+ transfomer_class = getattr(module, class_name)

  with init_empty_weights():
  model = transfomer_class.from_config(transformer_config)
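The hunks above add two optional parameters to fast_load_transformers_model: forcedConfigPath, which overrides the config.json lookup, and modelClass, which bypasses the transformers/diffusers class resolution. A hedged sketch; the paths and the class used here are examples only:

# Illustrative paths and class: forcedConfigPath points at a config.json to use
# instead of any config found next to (or inside) the safetensors file, and
# modelClass skips the architectures-based class lookup.
from mmgp import offload
from transformers import T5EncoderModel

text_encoder = offload.fast_load_transformers_model(
    "ckpts/text_encoder.safetensors",                     # placeholder model file
    forcedConfigPath="configs/text_encoder_config.json",  # placeholder config
    modelClass=T5EncoderModel,
)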
@@ -987,6 +1108,8 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
  """

  file_path = _get_model(file_path)
+ if file_path == None:
+ raise Exception("Unable to find file")
  verboseLevel = _compute_verbose_level(verboseLevel)

  model = _remove_model_wrapper(model)
@@ -1036,9 +1159,34 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
  _requantize(model, state_dict, quantization_map)

  missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
- # if len(missing_keys) > 0:
- # sd_crap = { k : None for k in missing_keys}
- # missing_keys , unexpected_keys = model.load_state_dict(sd_crap, strict =False, assign = True )
+ if len(missing_keys) > 0 :
+ # if there is a key mismatch maybe we forgot to remove some prefix or we are trying to load just a sub part of a larger model
+ if hasattr(model, "base_model_prefix"):
+ base_model_prefix = model.base_model_prefix + "."
+ else:
+ for k,v in state_dict.items():
+ if k.endswith(missing_keys[0]):
+ base_model_prefix = k[:-len(missing_keys[0])]
+ break
+
+ new_state_dict= {}
+ start = -1
+ for k,v in state_dict.items():
+ if k.startswith(base_model_prefix):
+ new_start = len(base_model_prefix)
+ else:
+ pos = k.find("." + base_model_prefix)
+ if pos < 0:
+ continue
+ new_start = pos + len(base_model_prefix) +1
+ if start != -1 and start != new_start:
+ new_state_dict = state_dict
+ break
+ start = new_start
+ new_state_dict[k[ start:]] = v
+ state_dict = new_state_dict
+ del new_state_dict
+ missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
  del state_dict

  for k,p in model.named_parameters():
@@ -1095,7 +1243,7 @@ def save_model(model, file_path, do_quantize = False, quantizationType = qint8,
  config= json.loads(text)

  if do_quantize:
- _quantize(model, weights=quantizationType, model_id=file_path)
+ _quantize(model, weights=quantizationType, model_id=file_path, verboseLevel=verboseLevel)

  quantization_map = getattr(model, "_quanto_map", None)

@@ -1194,6 +1342,7 @@ class offload:
  self.loaded_blocks = {}
  self.prev_blocks_names = {}
  self.next_blocks_names = {}
+ self.lora_parents = {}
  self.preloaded_blocks_per_model = {}
  self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
  self.transfer_stream = torch.cuda.Stream()
@@ -1219,13 +1368,17 @@ class offload:
  if not prev_block_name == None:
  self.next_blocks_names[prev_entry_name] = entry_name
  bef = blocks_params_size
+
+ lora_name = None
+ if self.lora_parents.get(submodule, None) != None:
+ lora_name = str(submodule_name[ submodule_name.rfind(".") + 1: ] )
+
  for k,p in submodule.named_parameters(recurse=False):
  param_size = 0
  ref = _get_tensor_ref(p)
  tied_param = self.parameters_ref.get(ref, None)
-
  if isinstance(p, QTensor):
- blocks_params.append( (submodule, k, p, False, tied_param ) )
+ blocks_params.append( (submodule, k, p, False, tied_param, lora_name ) )

  if p._qtype == qint4:
  if hasattr(p,"_scale_shift"):
@@ -1239,7 +1392,7 @@ class offload:
  param_size += torch.numel(p._scale) * p._scale.element_size()
  param_size += torch.numel(p._data) * p._data.element_size()
  else:
- blocks_params.append( (submodule, k, p, False, tied_param) )
+ blocks_params.append( (submodule, k, p, False, tied_param, lora_name) )
  param_size += torch.numel(p.data) * p.data.element_size()


@@ -1248,7 +1401,7 @@ class offload:
  self.parameters_ref[ref] = (submodule, k)

  for k, p in submodule.named_buffers(recurse=False):
- blocks_params.append( (submodule, k, p, True, None) )
+ blocks_params.append( (submodule, k, p, True, None, lora_name) )
  blocks_params_size += p.data.nbytes

  aft = blocks_params_size
@@ -1283,7 +1436,11 @@ class offload:
  def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None
  with torch.cuda.stream(stream_to_use):
  for param in blocks_params:
- parent_module, n, p, is_buffer, tied_param = param
+ parent_module, n, p, is_buffer, tied_param, lora_name = param
+ if lora_name != None:
+ if not lora_name in self.lora_parents[parent_module].active_adapters:
+ continue
+
  if tied_param != None:
  tied_p = getattr( tied_param[0], tied_param[1])
  if tied_p.is_cuda:
@@ -1353,8 +1510,8 @@ class offload:
  if blocks_name != None:
  self.loaded_blocks[model_id] = None

- blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name

+ blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
  if self.verboseLevel >=2:
  model = self.models[model_id]
  model_name = model._get_name()
@@ -1362,7 +1519,7 @@ class offload:

  blocks_params = self.blocks_of_modules[blocks_name]
  for param in blocks_params:
- parent_module, n, p, is_buffer, _ = param
+ parent_module, n, p, is_buffer, _, _ = param
  if is_buffer:
  q = torch.nn.Buffer(p)
  else:
@@ -1377,7 +1534,6 @@ class offload:
  model = self.models[model_id]
  self.active_models.append(model)
  self.active_models_ids.append(model_id)
-
  self.gpu_load_blocks(model_id, None, True)
  for block_name in self.preloaded_blocks_per_model[model_id]:
  self.gpu_load_blocks(model_id, block_name, True)
@@ -1889,6 +2045,14 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru

  self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name, submodule_name)

+ if hasattr(submodule, "active_adapters"):
+ for dictmodule in ["lora_A","lora_B"]:
+ ssubmod = getattr(submodule, dictmodule, None)
+ if ssubmod !=None:
+ for k, loramod in ssubmod._modules.items():
+ self.lora_parents[loramod] = submodule
+
+
  self.tune_preloading(model_id, current_budget, towers_names)

mmgp-3.1.4.post1592653.dist-info/METADATA → mmgp-3.2.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: mmgp
- Version: 3.1.4.post1592653
+ Version: 3.2.1
  Summary: Memory Management for the GPU Poor
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
  License: GNU GENERAL PUBLIC LICENSE
@@ -17,7 +17,7 @@ Requires-Dist: peft


  <p align="center">
- <H2>Memory Management 3.1.4-1592653 for the GPU Poor by DeepBeepMeep</H2>
+ <H2>Memory Management 3.2.0 for the GPU Poor by DeepBeepMeep</H2>
  </p>


mmgp-3.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
+ mmgp/offload.py,sha256=hzirru31j78E88OIT38GJ46iMvddEFM2c3_CCn4N4K4,95676
+ mmgp/safetensors2.py,sha256=DCdlRH3769CTyraAmWAB3b0XrVua7z6ygQ-OyKgJN6A,16453
+ mmgp-3.2.1.dist-info/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
+ mmgp-3.2.1.dist-info/METADATA,sha256=1gHy9pcQrpOKhpKwn3dbayJGlYzjVJ54lJpIGW9GXxE,15934
+ mmgp-3.2.1.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
+ mmgp-3.2.1.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
+ mmgp-3.2.1.dist-info/RECORD,,

mmgp-3.1.4.post1592653.dist-info/WHEEL → mmgp-3.2.1.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.0)
+ Generator: setuptools (75.8.2)
  Root-Is-Purelib: true
  Tag: py3-none-any

mmgp-3.1.4.post1592653.dist-info/RECORD REMOVED
@@ -1,9 +0,0 @@
- __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
- mmgp/offload.py,sha256=M8TpqTbDT8T0uFVo2bJlcww_hh9unCqFcUlioC6B-3E,87183
- mmgp/safetensors2.py,sha256=DCdlRH3769CTyraAmWAB3b0XrVua7z6ygQ-OyKgJN6A,16453
- mmgp-3.1.4.post1592653.dist-info/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
- mmgp-3.1.4.post1592653.dist-info/METADATA,sha256=zjGqw7sj5EzuLUqLFDOnt6k6pyy2ynX47fK3fbXZiPE,15954
- mmgp-3.1.4.post1592653.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- mmgp-3.1.4.post1592653.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
- mmgp-3.1.4.post1592653.dist-info/RECORD,,