mmgp 3.1.4.post1592653__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
- # ------------------ Memory Management 3.1.4-159265 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+ # ------------------ Memory Management 3.2.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
  #
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -81,6 +81,7 @@ from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMix
  # support for Embedding module quantization that is not supported by default by quanto
  @register_qmodule(torch.nn.Embedding)
  class QEmbedding(QModuleMixin, torch.nn.Embedding):
+ bias = None
  @classmethod
  def qcreate(cls, module, weights, activations = None, optimizer = None, device = None):
  module.bias = None
@@ -331,13 +332,19 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
 
  ref_cache = {}
  tied_weights = {}
+ tied_weights_count = 0
+ tied_weights_total = 0
+ tied_weights_last = None
+
  for n, (p, _) in params_dict.items():
  ref = _get_tensor_ref(p)
  match = ref_cache.get(ref, None)
  if match != None:
  match_name, match_size = match
+ tied_weights_count += 1
+ tied_weights_total += match_size
  if verboseLevel >=1:
- print(f"Tied weights of {match_size/ONE_MB:0.2f} MB detected: {match_name} <-> {n}")
+ tied_weights_last = f"{match_name} <-> {n}"
  tied_weights[n] = match_name
  else:
  if isinstance(p, QTensor):
@@ -366,6 +373,13 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
 
  total_tensor_bytes += length
 
+ if verboseLevel >=1 and tied_weights_count > 0:
+ if tied_weights_count == 1:
+ print(f"Tied weights of {tied_weights_total/ONE_MB:0.2f} MB detected: {tied_weights_last}")
+ else:
+ print(f"Found {tied_weights_count} tied weights for a total of {tied_weights_total/ONE_MB:0.2f} MB, last : {tied_weights_last}")
+
+
  big_tensors_sizes.append(current_big_tensor_size)
 
  big_tensors = []
@@ -392,14 +406,22 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
  tensor_no = 0
  # prev_big_tensor = 0
  for n, (p, is_buffer) in params_dict.items():
- if n in tied_weights:
+ q_name = tied_weights.get(n,None)
+ if q_name != None:
+ q , _ = params_dict[q_name]
  if isinstance(p, QTensor):
  if p._qtype == qint4:
+ p._data._data = q._data._data
+ p._scale_shift = q._scale_shift
  assert p._data._data.data.is_pinned()
  else:
+ p._data = q._data
+ p._scale = q._scale
  assert p._data.is_pinned()
  else:
+ p.data = q.data
  assert p.data.is_pinned()
+ q = None
  else:
  big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
  # if big_tensor_no != prev_big_tensor:
@@ -432,7 +454,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
  else:
  length = torch.numel(p.data) * p.data.element_size()
  p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
-
+ p.aaaaa = n
  tensor_no += 1
  del p
  global total_pinned_bytes
@@ -457,7 +479,7 @@ def _welcome():
  if welcome_displayed:
  return
  welcome_displayed = True
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.1.4-1592653) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.2.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
 
  def _extract_num_from_str(num_in_str):
  size = len(num_in_str)
@@ -553,7 +575,7 @@ def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict
 
 
 
- def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 1000000000, model_id = 'Unknown'):
+ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2**31, model_id = 'Unknown'):
 
  total_size =0
  total_excluded = 0
@@ -581,6 +603,8 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
  tied_weights= {}
 
  for submodule_name, submodule in model_to_quantize.named_modules():
+ if "embed_token" in submodule_name:
+ pass
  if isinstance(submodule, QModuleMixin):
  if verboseLevel>=1:
  print("No quantization to do as model is already quantized")
@@ -603,25 +627,14 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
 
  if not any(submodule_name.startswith(pre) for pre in tower_names):
  flush = False
- if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
- if cur_blocks_prefix == None:
- cur_blocks_prefix = submodule_name + "."
- flush = True
- else:
- if not submodule_name.startswith(cur_blocks_prefix):
- cur_blocks_prefix = submodule_name + "."
- flush = True
- else:
- if cur_blocks_prefix is not None:
- #if not cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
- if not submodule_name.startswith(cur_blocks_prefix):
- cur_blocks_prefix = None
- flush = True
+ if cur_blocks_prefix == None or not submodule_name.startswith(cur_blocks_prefix):
+ cur_blocks_prefix = submodule_name + "."
+ flush = True
 
  if flush :
  if submodule_size <= threshold :
  exclude_list += submodule_names
- if verboseLevel >=2:
+ if verboseLevel >=2 and submodule_size >0:
  print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
  total_excluded += submodule_size
 
@@ -632,7 +645,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
  submodule_names.append(submodule_name)
  total_size += size
 
- if submodule_size >0 and submodule_size <= threshold :
+ if submodule_size >0 :
  exclude_list += submodule_names
  if verboseLevel >=2:
  print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
@@ -645,7 +658,7 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
  print(f"Can't find any module to exclude from quantization, full model ({total_size/ONE_MB:.1f} MB) will be quantized")
  else:
  print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB of {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
- if perc_excluded >= 0.10:
+ if perc_excluded >= 0.20:
  if verboseLevel >=2:
  print(f"Too many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
  exclude_list = None
@@ -709,6 +722,48 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
 
  return True
 
+ def split_linear_modules(model, map ):
+ from optimum.quanto import QModuleMixin, WeightQBytesTensor, QLinear
+ from accelerate import init_empty_weights
+
+ modules_dict = { k: m for k, m in model.named_modules()}
+ for module_suffix, split_info in map.items():
+ mapped_modules = split_info["mapped_modules"]
+ split_sizes = split_info["split_sizes"]
+ for k, module in modules_dict.items():
+ if k.endswith("." + module_suffix):
+ parent_module = modules_dict[k[:len(k)-len(module_suffix)-1]]
+ weight = module.weight
+ bias = getattr(module, "bias", None)
+ if isinstance(module, QModuleMixin):
+ _data = weight._data
+ _scale = weight._scale
+ sub_data = torch.split(_data, split_sizes, dim=0)
+ sub_scale = torch.split(_scale, split_sizes, dim=0)
+ sub_bias = torch.split(bias, split_sizes, dim=0)
+ for sub_name, _subdata, _subbias, _subscale in zip(mapped_modules, sub_data, sub_bias, sub_scale):
+ with init_empty_weights():
+ sub_module = QLinear(_subdata.shape[1], _subdata.shape[0], bias=bias != None, device ="cpu", dtype=torch.bfloat16)
+ sub_module.weight = torch.nn.Parameter(WeightQBytesTensor.create(weight.qtype, weight.axis, _subdata.size(), weight.stride(), _subdata, _subscale, activation_qtype=weight.activation_qtype, requires_grad=weight.requires_grad ))
+ if bias != None:
+ sub_module.bias = torch.nn.Parameter(_subbias)
+ sub_module.optimizer = module.optimizer
+ sub_module.weight_qtype = module.weight_qtype
+ setattr(parent_module, sub_name, sub_module)
+ # del _data, _scale, _subdata, sub_d
+ else:
+ sub_data = torch.split(weight, split_sizes, dim=0)
+ sub_bias = torch.split(bias, split_sizes, dim=0)
+ for sub_name, subdata, subbias in zip(mapped_modules, sub_data, sub_bias):
+ with init_empty_weights():
+ sub_module = torch.nn.Linear( subdata.shape[1], subdata.shape[0], bias=bias != None, device ="cpu", dtype=torch.bfloat16)
+ sub_module.weight = torch.nn.Parameter(subdata , requires_grad=False)
+ if bias != None:
+ sub_module.bias = torch.nn.Parameter(subbias)
+ setattr(parent_module, sub_name, sub_module)
+
+ delattr(parent_module, module_suffix)
+
  def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
  self._check_forward_args(x, *args, **kwargs)
  adapter_names = kwargs.pop("adapter_names", None)
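The new `split_linear_modules(model, map)` helper breaks a fused linear projection into separate sub-modules, driven by a map that names the replacement modules and gives the number of output rows each one takes from the fused weight. Below is a minimal, self-contained sketch of the expected map shape and of what the split does for a non-quantized layer; the `qkv`/`to_q`/`to_k`/`to_v` names and the sizes are made up for illustration and do not come from this diff.

```python
import torch

# Hypothetical map in the format split_linear_modules expects: the key is the
# suffix of the fused Linear, "mapped_modules" names the new sub-modules and
# "split_sizes" says how many output rows each one receives.
split_map = {
    "qkv": {
        "mapped_modules": ["to_q", "to_k", "to_v"],
        "split_sizes": [64, 64, 64],
    },
}

# Toy fused projection: 3*64 output features packed into one weight.
fused = torch.nn.Linear(128, 3 * 64, bias=True)
fused.requires_grad_(False)

# Essentially what the non-quantized branch above does: split weight and bias
# along dim 0 and wrap each piece in its own Linear module.
sub_weights = torch.split(fused.weight, split_map["qkv"]["split_sizes"], dim=0)
sub_biases = torch.split(fused.bias, split_map["qkv"]["split_sizes"], dim=0)
for name, w, b in zip(split_map["qkv"]["mapped_modules"], sub_weights, sub_biases):
    sub = torch.nn.Linear(w.shape[1], w.shape[0])
    sub.weight = torch.nn.Parameter(w, requires_grad=False)
    sub.bias = torch.nn.Parameter(b, requires_grad=False)
    print(name, tuple(sub.weight.shape))  # to_q (64, 128), to_k (64, 128), ...
```

The same map format is reused by the LoRA loading path further down (see the sketch after the `split_linear_modules_map` hunk).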
@@ -721,8 +776,20 @@ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor
  elif self.merged:
  result = self.base_layer(x, *args, **kwargs)
  else:
+ def get_scaling(active_adapter):
+ scaling_dict = shared_state.get("_lora_scaling", None)
+ if scaling_dict == None:
+ return self.scaling[active_adapter]
+ scaling_list = scaling_dict[active_adapter]
+ if isinstance(scaling_list, list):
+ step_no =shared_state.get("_lora_step_no", 0)
+ return scaling_list[step_no]
+ else:
+ return float(scaling_list)
+
  base_weight = self.base_layer.weight
- if base_weight.shape[-1] < x.shape[-2]: # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
+ new_weights = not isinstance(self.base_layer, QModuleMixin)
+ if base_weight.shape[-1] < x.shape[-2] : # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
  for active_adapter in self.active_adapters:
  if active_adapter not in self.lora_A.keys():
  continue
@@ -732,11 +799,16 @@ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor
  lora_A = self.lora_A[active_adapter]
  lora_B = self.lora_B[active_adapter]
  dropout = self.lora_dropout[active_adapter]
- scaling = self.scaling[active_adapter]
+ scaling = get_scaling(active_adapter)
  lora_A_weight = lora_A.weight
  lora_B_weight = lora_B.weight
- lora_BA = lora_B_weight @ lora_A_weight
- base_weight += scaling * lora_BA
+ if new_weights:
+ base_weight = torch.addmm(base_weight, lora_B_weight, lora_A_weight, alpha= scaling )
+ # base_weight = base_weight + scaling * lora_B_weight @ lora_A_weight
+ else:
+ base_weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )
+ # base_weight += scaling * lora_B_weight @ lora_A_weight
+ new_weights = False
 
  if self.training:
  result = torch.nn.functional.linear(dropout(x), base_weight, bias=self.base_layer.bias)
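The rewritten summation no longer materialises `lora_B_weight @ lora_A_weight` as a separate tensor: `torch.addmm(base, B, A, alpha=s)` computes `base + s * (B @ A)` in a single call, and `addmm_` is its in-place variant. A small self-contained check of that equivalence:

```python
import torch

# torch.addmm(base, B, A, alpha=s) == base + s * (B @ A),
# which is the substitution made in _lora_linear_forward above.
out_f, in_f, r = 8, 16, 4
base = torch.randn(out_f, in_f)
B = torch.randn(out_f, r)   # lora_B weight
A = torch.randn(r, in_f)    # lora_A weight
s = 0.7                     # scaling

fused = torch.addmm(base, B, A, alpha=s)
naive = base + s * (B @ A)
assert torch.allclose(fused, naive, atol=1e-5)
```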
@@ -755,7 +827,7 @@ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor
  lora_A = self.lora_A[active_adapter]
  lora_B = self.lora_B[active_adapter]
  dropout = self.lora_dropout[active_adapter]
- scaling = self.scaling[active_adapter]
+ scaling = get_scaling(active_adapter)
  x = x.to(lora_A.weight.dtype)
 
  if not self.use_dora[active_adapter]:
@@ -788,7 +860,7 @@ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor
  result = result.to(torch_result_dtype)
  return result
 
- def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_loras = True, verboseLevel = -1,):
+ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_loras = True, split_linear_modules_map = None,verboseLevel = -1,):
  verboseLevel = _compute_verbose_level(verboseLevel)
 
  if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
@@ -806,16 +878,45 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  for i, path in enumerate(lora_path):
  adapter_name = str(i)
 
-
-
-
  state_dict = safetensors2.torch_load_file(path)
 
+
+ if split_linear_modules_map != None:
+ new_state_dict = {}
+ targets_A = { "."+k+".lora_A.weight" : k for k in split_linear_modules_map }
+ targets_B = { "."+k+".lora_B.weight" : k for k in split_linear_modules_map }
+ for module_name, module_data in state_dict.items():
+ if any(module_name.endswith(suffix) for suffix in targets_B):
+ for suffix, target_module in targets_B.items():
+ if module_name.endswith(suffix):
+ break
+ parent_module_name = module_name[:-len(suffix)]
+ map = split_linear_modules_map[target_module]
+ mapped_modules = map["mapped_modules"]
+ split_sizes = map["split_sizes"]
+ sub_data = torch.split(module_data, split_sizes, dim=0)
+ for sub_name, subdata, in zip(mapped_modules, sub_data):
+ new_module_name = parent_module_name + "." + sub_name + ".lora_B.weight"
+ new_state_dict[new_module_name] = subdata
+ elif any(module_name.endswith(suffix) for suffix in targets_A):
+ for suffix, target_module in targets_A.items():
+ if module_name.endswith(suffix):
+ break
+ parent_module_name = module_name[:-len(suffix)]
+ map = split_linear_modules_map[target_module]
+ mapped_modules = map["mapped_modules"]
+ for sub_name in mapped_modules :
+ new_module_name = parent_module_name + "." + sub_name + ".lora_A.weight"
+ new_state_dict[new_module_name] = module_data
+ else:
+ new_state_dict[module_name] = module_data
+ state_dict = new_state_dict
+ del new_state_dict
+
  keys = list(state_dict.keys())
  if len(keys) == 0:
  raise Exception(f"Empty Lora '{path}'")
 
-
  network_alphas = {}
  for k in keys:
  if "alpha" in k:
@@ -884,13 +985,26 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  if activate_all_loras:
  set_weights_and_activate_adapters(model,[ str(i) for i in range(len(lora_multi))], lora_multi)
 
+ def set_step_no_for_lora(step_no):
+ shared_state["_lora_step_no"] = step_no
+
  def activate_loras(model, lora_nos, lora_multi = None ):
  if not isinstance(lora_nos, list):
  lora_nos = [lora_nos]
  lora_nos = [str(l) for l in lora_nos]
+
  if lora_multi is None:
  lora_multi = [1. for _ in lora_nos]
- set_weights_and_activate_adapters(model, lora_nos, lora_multi)
+
+ lora_fake_scaling = [1. if isinstance(mult, list) else mult for mult in lora_multi ]
+ lora_scaling_dict = {}
+ for no, multi in zip(lora_nos, lora_multi):
+ lora_scaling_dict[no] = multi
+
+ shared_state["_lora_scaling"] = lora_scaling_dict
+ shared_state["_lora_step_no"] = 0
+
+ set_weights_and_activate_adapters(model, lora_nos, lora_fake_scaling)
 
 
  def move_loras_to_device(model, device="cpu" ):
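With this change, `activate_loras` accepts either a single multiplier per LoRA or a list of multipliers, one entry per inference step, and `get_scaling` picks the entry selected by `set_step_no_for_lora`. A minimal usage sketch, assuming these helpers are exposed on `mmgp.offload`; `transformer` is a placeholder for a model whose LoRAs were loaded with `load_loras_into_model`, and the schedule values are illustrative:

```python
from mmgp import offload

num_steps = 30
# LoRA "0" ramps down across the run, LoRA "1" keeps a constant weight of 0.8.
ramp_down = [1.0 - i / (num_steps - 1) for i in range(num_steps)]
offload.activate_loras(transformer, [0, 1], [ramp_down, 0.8])

for step_no in range(num_steps):
    # Tell the LoRA forward pass which entry of the per-step schedule to use.
    offload.set_step_no_for_lora(step_no)
    # ... run one denoising step with `transformer` here ...
```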
@@ -903,7 +1017,7 @@ def move_loras_to_device(model, device="cpu" ):
  if ".lora_" in k:
  m.to(device)
 
- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, verboseLevel = -1):
  """
  quick version of .LoadfromPretrained of the transformers library
  used to build a model and load the corresponding weights (quantized or not)
@@ -927,8 +1041,11 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  else:
  transformer_config = metadata.get("config", None)
 
- if transformer_config == None:
- config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
+ if transformer_config == None or forcedConfigPath != None:
+ if forcedConfigPath != None:
+ config_fullpath = forcedConfigPath
+ else:
+ config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
 
  if not os.path.isfile(config_fullpath):
  raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
@@ -1036,9 +1153,27 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
  _requantize(model, state_dict, quantization_map)
 
  missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
- # if len(missing_keys) > 0:
- # sd_crap = { k : None for k in missing_keys}
- # missing_keys , unexpected_keys = model.load_state_dict(sd_crap, strict =False, assign = True )
+ if len(missing_keys) > 0 and hasattr(model, "base_model_prefix"):
+ # if there is a key mismatch maybe we forgot to remove some prefix or we are trying to load just a sub part of a larger model
+ base_model_prefix = model.base_model_prefix + "."
+ new_state_dict= {}
+ start = -1
+ for k,v in state_dict.items():
+ if k.startswith(base_model_prefix):
+ new_start = len(base_model_prefix)
+ else:
+ pos = k.find("." + base_model_prefix)
+ if pos < 0:
+ continue
+ new_start = pos + len(base_model_prefix) +1
+ if start != -1 and start != new_start:
+ new_state_dict = state_dict
+ break
+ start = new_start
+ new_state_dict[k[ start:]] = v
+ state_dict = new_state_dict
+ del new_state_dict
+ missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
  del state_dict
 
  for k,p in model.named_parameters():
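The new fallback in `load_model_data` retries `load_state_dict` after stripping the model's `base_model_prefix` from checkpoint keys, which covers checkpoints saved from a wrapping model. A simplified, self-contained illustration of the renaming it performs; the real code also handles the prefix appearing mid-key and falls back to the original dict if the offsets are inconsistent:

```python
# Toy checkpoint whose keys carry a wrapping "model." prefix.
state_dict = {
    "model.layers.0.weight": 0,
    "model.layers.1.weight": 1,
}
base_model_prefix = "model."

# Drop the prefix so the sub-model's own parameter names line up.
stripped = {
    k[len(base_model_prefix):]: v
    for k, v in state_dict.items()
    if k.startswith(base_model_prefix)
}
assert list(stripped) == ["layers.0.weight", "layers.1.weight"]
```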
@@ -1095,7 +1230,7 @@ def save_model(model, file_path, do_quantize = False, quantizationType = qint8,
  config= json.loads(text)
 
  if do_quantize:
- _quantize(model, weights=quantizationType, model_id=file_path)
+ _quantize(model, weights=quantizationType, model_id=file_path, verboseLevel=verboseLevel)
 
  quantization_map = getattr(model, "_quanto_map", None)
 
@@ -1194,6 +1329,7 @@ class offload:
  self.loaded_blocks = {}
  self.prev_blocks_names = {}
  self.next_blocks_names = {}
+ self.lora_parents = {}
  self.preloaded_blocks_per_model = {}
  self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
  self.transfer_stream = torch.cuda.Stream()
@@ -1219,13 +1355,17 @@ class offload:
  if not prev_block_name == None:
  self.next_blocks_names[prev_entry_name] = entry_name
  bef = blocks_params_size
+
+ lora_name = None
+ if self.lora_parents.get(submodule, None) != None:
+ lora_name = str(submodule_name[ submodule_name.rfind(".") + 1: ] )
+
  for k,p in submodule.named_parameters(recurse=False):
  param_size = 0
  ref = _get_tensor_ref(p)
  tied_param = self.parameters_ref.get(ref, None)
-
  if isinstance(p, QTensor):
- blocks_params.append( (submodule, k, p, False, tied_param ) )
+ blocks_params.append( (submodule, k, p, False, tied_param, lora_name ) )
 
  if p._qtype == qint4:
  if hasattr(p,"_scale_shift"):
@@ -1239,7 +1379,7 @@ class offload:
  param_size += torch.numel(p._scale) * p._scale.element_size()
  param_size += torch.numel(p._data) * p._data.element_size()
  else:
- blocks_params.append( (submodule, k, p, False, tied_param) )
+ blocks_params.append( (submodule, k, p, False, tied_param, lora_name) )
  param_size += torch.numel(p.data) * p.data.element_size()
 
 
@@ -1248,7 +1388,7 @@ class offload:
  self.parameters_ref[ref] = (submodule, k)
 
  for k, p in submodule.named_buffers(recurse=False):
- blocks_params.append( (submodule, k, p, True, None) )
+ blocks_params.append( (submodule, k, p, True, None, lora_name) )
  blocks_params_size += p.data.nbytes
 
  aft = blocks_params_size
@@ -1283,7 +1423,11 @@ class offload:
  def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None
  with torch.cuda.stream(stream_to_use):
  for param in blocks_params:
- parent_module, n, p, is_buffer, tied_param = param
+ parent_module, n, p, is_buffer, tied_param, lora_name = param
+ if lora_name != None:
+ if not lora_name in self.lora_parents[parent_module].active_adapters:
+ continue
+
  if tied_param != None:
  tied_p = getattr( tied_param[0], tied_param[1])
  if tied_p.is_cuda:
@@ -1353,8 +1497,8 @@ class offload:
  if blocks_name != None:
  self.loaded_blocks[model_id] = None
 
- blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
 
+ blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
  if self.verboseLevel >=2:
  model = self.models[model_id]
  model_name = model._get_name()
@@ -1362,7 +1506,7 @@ class offload:
 
  blocks_params = self.blocks_of_modules[blocks_name]
  for param in blocks_params:
- parent_module, n, p, is_buffer, _ = param
+ parent_module, n, p, is_buffer, _, _ = param
  if is_buffer:
  q = torch.nn.Buffer(p)
  else:
@@ -1889,6 +2033,14 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
 
  self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name, submodule_name)
 
+ if hasattr(submodule, "active_adapters"):
+ for dictmodule in ["lora_A","lora_B"]:
+ ssubmod = getattr(submodule, dictmodule, None)
+ if ssubmod !=None:
+ for k, loramod in ssubmod._modules.items():
+ self.lora_parents[loramod] = submodule
+
+
  self.tune_preloading(model_id, current_budget, towers_names)
 
 
mmgp-3.2.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: mmgp
- Version: 3.1.4.post1592653
+ Version: 3.2.0
  Summary: Memory Management for the GPU Poor
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
  License: GNU GENERAL PUBLIC LICENSE
@@ -17,7 +17,7 @@ Requires-Dist: peft
 
 
  <p align="center">
- <H2>Memory Management 3.1.4-1592653 for the GPU Poor by DeepBeepMeep</H2>
+ <H2>Memory Management 3.2.0 for the GPU Poor by DeepBeepMeep</H2>
  </p>
 
 
mmgp-3.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
+ mmgp/offload.py,sha256=tvByWwhjzGpGAlbefQQ2D5sS-fyVVnLXrNAD3u7qLiU,95236
+ mmgp/safetensors2.py,sha256=DCdlRH3769CTyraAmWAB3b0XrVua7z6ygQ-OyKgJN6A,16453
+ mmgp-3.2.0.dist-info/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
+ mmgp-3.2.0.dist-info/METADATA,sha256=Q4Rfjsz_M4fDVGuPGUdodQ4N1HUuKqCizCKO3CHzubg,15934
+ mmgp-3.2.0.dist-info/WHEEL,sha256=nn6H5-ilmfVryoAQl3ZQ2l8SH5imPWFpm1A5FgEuFV4,91
+ mmgp-3.2.0.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
+ mmgp-3.2.0.dist-info/RECORD,,
mmgp-3.2.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.0)
+ Generator: setuptools (75.8.1)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
mmgp-3.1.4.post1592653.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
- __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
- mmgp/offload.py,sha256=M8TpqTbDT8T0uFVo2bJlcww_hh9unCqFcUlioC6B-3E,87183
- mmgp/safetensors2.py,sha256=DCdlRH3769CTyraAmWAB3b0XrVua7z6ygQ-OyKgJN6A,16453
- mmgp-3.1.4.post1592653.dist-info/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
- mmgp-3.1.4.post1592653.dist-info/METADATA,sha256=zjGqw7sj5EzuLUqLFDOnt6k6pyy2ynX47fK3fbXZiPE,15954
- mmgp-3.1.4.post1592653.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- mmgp-3.1.4.post1592653.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
- mmgp-3.1.4.post1592653.dist-info/RECORD,,