mmgp-3.4.6-py3-none-any.whl → mmgp-3.4.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mmgp has been flagged as potentially problematic.

mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
- # ------------------ Memory Management 3.4.5 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+ # ------------------ Memory Management 3.4.8 for the GPU Poor by DeepBeepMeep (mmgp)------------------
  #
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -331,12 +331,35 @@ def _extract_tie_weights_from_sd(sd , sd_name, verboseLevel =1):

  def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TENSOR_MAX_SIZE, verboseLevel = 1):
  global max_pinnable_bytes, total_pinned_bytes
+
+
+ names_list = sd_name if isinstance(sd, list) else [sd_name]
+
  if max_pinnable_bytes > 0 and total_pinned_bytes >= max_pinnable_bytes:

  if verboseLevel>=1 :
- print(f"Unable pin data of '{sd_name}' to reserved RAM as there is no reserved RAM left")
+ print(f"Unable pin data of '{','.join(names_list)}' to reserved RAM as there is no reserved RAM left")
  return

+
+ if isinstance(sd, list):
+ new_sd = {}
+ for i, sub_sd, in enumerate(sd):
+ for k, v in sub_sd.items():
+ new_sd[str(i) + "#" + k] =v
+ sd = new_sd
+ del new_sd
+ sub_sd = None
+
+ if isinstance(tied_weights, list):
+ new_tied_weights = {}
+ for i, sub_tied_weights, in enumerate(tied_weights):
+ for k, v in sub_tied_weights.items():
+ new_tied_weights[str(i) + "#" + k] =v
+ sd = new_tied_weights
+ del new_tied_weights
+ sub_tied_weights = None
+
  current_big_tensor_size = 0
  big_tensor_no = 0
  big_tensors_sizes = []
@@ -365,11 +388,14 @@ def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TE
  big_tensors = []
  last_big_tensor = 0
  total = 0
+ incomplete_pinning = False

  try:
  dummy_pinned_tensor = torch.empty( RESERVED_RAM_MIN_AVAILABLE, dtype= torch.uint8, pin_memory=True, device="cpu")
  except:
  print("There isn't any Reserved RAM left, you may need to choose a profile with a higher number that requires less Reserved RAM or set OS env 'perc_reserved_mem_max' to a value less 0.3")
+ gc.collect()
+ torch.cuda.empty_cache()
  return

  for size in big_tensors_sizes:
@@ -377,6 +403,7 @@ def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TE
  current_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True, device="cpu")
  big_tensors.append(current_big_tensor)
  except:
+ incomplete_pinning = True
  print(f"Unable to pin more tensors for '{sd_name}' as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
  break

@@ -410,9 +437,21 @@ def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TE
  # global total_pinned_bytes
  # total_pinned_bytes += total
  gc.collect()
+ torch.cuda.empty_cache()
+

  if verboseLevel >=1:
- print(f"'{sd_name}' was pinned entirely to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
+ if incomplete_pinning :
+ if len(names_list) > 0:
+ print(f"'{','.join(names_list)}' were partially pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
+ else:
+ print(f"'{','.join(names_list)}' was partially pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
+ else:
+ if len(names_list) > 0:
+ print(f"'{','.join(names_list)}' were pinned entirely to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
+ else:
+ print(f"'{','.join(names_list)}' was pinned entirely to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
+

  return

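Note (not part of the diff): _pin_sd_to_memory can now take either a single state dict plus a name, or a list of state dicts plus a matching list of names. A minimal sketch of the key-flattening scheme the list branch appears to use, with an illustrative helper name:

# Illustrative sketch only: merge several state dicts into one flat dict whose
# keys are prefixed with the list index and '#', so they can be pinned in one pass.
def flatten_state_dicts(sd_list):
    flat = {}
    for i, sub_sd in enumerate(sd_list):
        for k, v in sub_sd.items():
            flat[str(i) + "#" + k] = v
    return flat

# flatten_state_dicts([{"w": 1}, {"w": 2}]) -> {"0#w": 1, "1#w": 2}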
@@ -619,7 +658,7 @@ def _welcome():
  if welcome_displayed:
  return
  welcome_displayed = True
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.4.5) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.4.8) by DeepBeepMeep ************{ENDC}{UNBOLD}")

  def change_dtype(model, new_dtype, exclude_buffers = False):
  for submodule_name, submodule in model.named_modules():
@@ -961,6 +1000,8 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  errors = []
  adapters = {}
  adapter_no = 0
+ pinned_sd_list = []
+ pinned_names_list = []
  for i, path in enumerate(lora_path):
  adapter_name = str(adapter_no)
  error_msg = ""
@@ -1042,28 +1083,37 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  invalid_keys = []
  unexpected_keys = []
  for k, v in state_dict.items():
- pos = k.rfind(".lora_")
- if pos <=0:
- invalid_keys.append(k)
- continue
- module_name = k[ : pos]
- lora_key = k[ pos+1:]
  lora_A = None
  lora_B = None
- if lora_key == "lora_A.weight":
- lora_A = v
- elif lora_key == "lora_B.weight":
- lora_B = v
+ diff_b = None
+ diff = None
+ if k.endswith(".diff"):
+ diff = v
+ module_name = k[ : -5]
+ elif k.endswith(".diff_b"):
+ diff_b = v
+ module_name = k[ : -7]
  else:
- invalid_keys.append(k)
- continue
+ pos = k.rfind(".lora_")
+ if pos <=0:
+ invalid_keys.append(k)
+ continue
+ module_name = k[ : pos]
+ lora_key = k[ pos+1:]
+ if lora_key in ("lora_A.weight", "lora_down.weight"):
+ lora_A = v
+ elif lora_key in ("lora_B.weight", "lora_up.weight"):
+ lora_B = v
+ else:
+ invalid_keys.append(k)
+ continue

  module = modules_dict.get(module_name, None)
  if module == None:
  unexpected_keys.append(k)
  continue
- if not isinstance(module, (QLinear, torch.nn.Linear)):
- msg = f"Lora '{path}' contains a non linear layer '{k}'"
+ if False: #not isinstance(module, (QLinear, torch.nn.Linear, torch.nn.Conv3d, torch.nn.LayerNorm)):
+ msg = f"Lora '{path}' contains a non supported type of layer '{k}'"
  error_msg = append(error_msg, msg)
  fail = True
  break
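Note (not part of the diff): the parsing loop above now recognises '.diff' (full weight delta) and '.diff_b' (bias delta) keys in addition to the usual low-rank pairs, and accepts 'lora_down'/'lora_up' as aliases for 'lora_A'/'lora_B'. A hedged sketch of that key classification, with a hypothetical helper name:

# Hypothetical helper mirroring the key classification in the loop above.
def classify_lora_key(k: str):
    if k.endswith(".diff"):        # full weight delta
        return k[:-5], "diff"
    if k.endswith(".diff_b"):      # bias delta
        return k[:-7], "diff_b"
    pos = k.rfind(".lora_")
    if pos <= 0:
        return None, "invalid"
    module_name, lora_key = k[:pos], k[pos + 1:]
    if lora_key in ("lora_A.weight", "lora_down.weight"):
        return module_name, "lora_A"
    if lora_key in ("lora_B.weight", "lora_up.weight"):
        return module_name, "lora_B"
    return None, "invalid"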
@@ -1077,7 +1127,7 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  error_msg = append(error_msg, msg)
  fail = True
  break
- if lora_B != None:
+ elif lora_B != None:
  if module_shape[0] != v.shape[0]:
  if ignore_model_variations:
  skip = True
@@ -1086,28 +1136,56 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  error_msg = append(error_msg, msg)
  fail = True
  break
+ elif diff != None:
+ lora_B = diff
+ if module_shape != v.shape:
+ if ignore_model_variations:
+ skip = True
+ else:
+ msg = f"Lora '{path}': Lora shape is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, lora = {v.shape[0]}). It is likely this Lora has been made for another version of this model."
+ error_msg = append(error_msg, msg)
+ fail = True
+ break
+ elif diff_b != None:
+ if module.bias == None:
+ msg = f"Lora '{path}': Lora Basis is defined while it doesnt exist in model '{_get_module_name(model)}'. It is likely this Lora has been made for another version of this model."
+ fail = True
+ break
+ else:
+ module_shape = module.bias.shape
+ if module_shape != v.shape:
+ if ignore_model_variations:
+ skip = True
+ else:
+ msg = f"Lora '{path}': Lora Basis dimension is not compatible with model '{_get_module_name(model)}' (model = {module_shape[0]}, lora Basis = {v.shape[0]}). It is likely this Lora has been made for another version of this model."
+ error_msg = append(error_msg, msg)
+ fail = True
+ break
+
  if not check_only:
  loras_module_data = loras_model_data.get(module, None)
+ if loras_module_data == None:
+ pass
  assert loras_module_data != None
- # if loras_module_data == None:
- # loras_module_data = dict()
- # loras_model_data[module] = loras_module_data
  loras_adapter_data = loras_module_data.get(adapter_name, None)
  lora_A = None if lora_A == None else lora_A.to(module.weight.dtype)
  lora_B = None if lora_B == None else lora_B.to(module.weight.dtype)
+ diff_b = None if diff_b == None else diff_b.to(module.weight.dtype)
  if loras_adapter_data == None:
  alpha = lora_alphas.get(k[:-len("lora_X.weight")] + "alpha", 1.)
- loras_adapter_data = [lora_A, lora_B, alpha]
+ loras_adapter_data = [lora_A, lora_B, diff_b, alpha]
  loras_module_data[adapter_name] = loras_adapter_data
  elif lora_A != None:
  loras_adapter_data[0] = lora_A
- else:
+ elif lora_B != None:
  loras_adapter_data[1] = lora_B
- lora_A, lora_B, v, loras_module_data, loras_adapter_data = None, None, None, None, None
+ else:
+ loras_adapter_data[2] = diff_b
+ lora_A, lora_B, diff, diff_b, v, loras_module_data, loras_adapter_data = None, None, None, None, None, None, None
  lora_alphas = None

  if len(invalid_keys) > 0:
- msg = "Lora '{path}' contains non Lora keys '{trunc(invalid_keys,200)}'"
+ msg = f"Lora '{path}' contains non Lora keys '{trunc(invalid_keys,200)}'"
  error_msg = append(error_msg, msg)
  fail = True
  if len(unexpected_keys) > 0:
@@ -1127,7 +1205,9 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  if not check_only:
  # model._loras_tied_weights[adapter_name] = tied_weights
  if pinnedLora:
- _pin_sd_to_memory(state_dict, path)
+ pinned_sd_list.append(state_dict)
+ pinned_names_list.append(path)
+ # _pin_sd_to_memory(state_dict, path)

  del state_dict

@@ -1146,6 +1226,8 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora

  model._loras_errors = errors
  if not check_only:
+ if pinnedLora and len(pinned_sd_list) > 0:
+ _pin_sd_to_memory(pinned_sd_list, pinned_names_list)
  model._loras_adapters = adapters
  if activate_all_loras:
  activate_loras(model, loras_nos, loras_multi)
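Note (not part of the diff): pinning of LoRA state dicts is now deferred and batched; each state dict is collected while loading and everything is pinned with one call at the end. A short sketch of the pattern, under the assumption that loaded_loras is a hypothetical iterable of (path, state_dict) pairs:

# Collect-then-pin pattern (illustrative; loaded_loras is a hypothetical iterable).
pinned_sd_list, pinned_names_list = [], []
for path, state_dict in loaded_loras:
    pinned_sd_list.append(state_dict)
    pinned_names_list.append(path)
if pinnedLora and len(pinned_sd_list) > 0:
    _pin_sd_to_memory(pinned_sd_list, pinned_names_list)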
@@ -1193,7 +1275,7 @@ def move_loras_to_device(model, device="cpu" ):
  if ".lora_" in k:
  m.to(device)

- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1):
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, configKwargs ={}):
  """
  quick version of .LoadfromPretrained of the transformers library
  used to build a model and load the corresponding weights (quantized or not)
@@ -1235,6 +1317,7 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  text = reader.read()
  transformer_config= json.loads(text)

+ transformer_config.update( configKwargs )

  if "architectures" in transformer_config:
  architectures = transformer_config["architectures"]
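Note (not part of the diff): the new configKwargs argument is merged into the config dict read from disk before the model class is instantiated, so individual config entries can be overridden at load time. A hedged usage sketch; the model path and the overridden key are made up for illustration:

from mmgp import offload

# Illustrative only: override a config entry when loading a transformers model.
model = offload.fast_load_transformers_model(
    "ckpts/some_text_encoder/model.safetensors",   # hypothetical path
    do_quantize=True,
    configKwargs={"num_hidden_layers": 32},        # hypothetical override
)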
@@ -1254,7 +1337,6 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  fp.close()
  config_obj = AutoConfig.from_pretrained(fp.name)
  os.remove(fp.name)
-
  #needed to keep inits of non persistent buffers
  with init_empty_weights():
  model = transfomer_class(config_obj)
@@ -1270,7 +1352,7 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  transfomer_class = getattr(module, class_name)

  with init_empty_weights():
- model = transfomer_class.from_config(transformer_config)
+ model = transfomer_class.from_config(transformer_config )


  torch.set_default_device('cpu')
@@ -1325,14 +1407,14 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
  if not (".safetensors" in file or ".sft" in file):
  if pinToMemory:
  raise Exception("Pinning to memory while loading only supported for safe tensors files")
- state_dict = torch.load(file, weights_only=True)
+ state_dict = torch.load(file, weights_only=True, map_location="cpu")
  if "module" in state_dict:
  state_dict = state_dict["module"]

  else:
  basename = os.path.basename(file)

- if "model-0" in basename:
+ if "-of-" in basename:
  metadata = None
  file_parts= basename.split("-")
  parts_max = int(file_parts[-1][:5])
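Note (not part of the diff): sharded checkpoints are now detected by the '-of-' marker instead of the 'model-0' prefix, so shard sets with arbitrary base names are recognised. A small sketch of that filename convention; the parser name is illustrative:

# Illustrative parser for shard names such as "model-00001-of-00004.safetensors".
def parse_shard_name(basename: str):
    if "-of-" not in basename:
        return None
    parts = basename.split("-")
    total = int(parts[-1][:5])   # "00004.safetensors" -> 4
    index = int(parts[-3])       # "00001" -> 1
    return index, total

# parse_shard_name("model-00001-of-00004.safetensors") -> (1, 4)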
@@ -1539,9 +1621,12 @@ class HfHook:
  def __init__(self):
  self.execution_device = "cuda"

- def detach_hook(self, module):
- pass
+ def init_hook(self, module):
+ return module

+ def detach_hook(self, module):
+ return module
+
  last_offload_obj = None
  class offload:
  def __init__(self):
@@ -1650,10 +1735,9 @@ class offload:
  lora_data = lora_module.get(adapter, None)
  if lora_data == None:
  continue
- lora_A, lora_B, alpha = lora_data
  key = adapter + '_GPU'
  if to_GPU:
- lora_module[key] = [lora_A.cuda(non_blocking=True), lora_B.cuda(non_blocking=True), alpha]
+ lora_module[key] = [None if item == None else item.cuda(non_blocking=True) for item in lora_data[ :-1] ] + lora_data[ -1:]
  elif key in lora_module:
  del lora_module[key]

@@ -1876,27 +1960,64 @@ class offload:

  return False

- def _lora_linear_forward(self, model, submodule, loras_data, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ def _get_lora_scaling(self, loras_scaling, model, active_adapter):
+ scaling_list = loras_scaling[active_adapter]
+ if isinstance(scaling_list, list):
+ step_no =getattr(model, "_lora_step_no", 0)
+ return scaling_list[step_no]
+ else:
+ return float(scaling_list)

- def get_scaling(active_adapter):
- scaling_list = loras_scaling[active_adapter]
- if isinstance(scaling_list, list):
- step_no =getattr(model, "_lora_step_no", 0)
- return scaling_list[step_no]
- else:
- return float(scaling_list)

- weight = submodule.weight

- if loras_data == None:
- return torch.nn.functional.linear(x, weight, bias=submodule.bias)
+ def _lora_generic_forward(self, model, submodule, loras_data, func, *args, **kwargs) -> torch.Tensor:
+
+ weight = submodule.weight
+ bias = getattr(submodule, "bias", None)
+ original_weight = None
+ original_bias = None
+ active_adapters = model._loras_active_adapters
+ loras_scaling = model._loras_scaling
+ first_weight = True
+ first_bias = True
+ for active_adapter in active_adapters:
+ data = loras_data.get(active_adapter + '_GPU', None)
+ if data == None:
+ continue
+ diff_w , _ , diff_b, alpha = data
+ if first_weight:
+ original_weight= weight.clone() if weight != None else None
+ first_weight = False
+ if first_bias:
+ original_bias= bias.clone() if bias != None else None
+ first_bias = False
+ scaling = self._get_lora_scaling( loras_scaling, model, active_adapter) * alpha
+ if diff_w != None:
+ weight.add_(diff_w, alpha= scaling)
+ diff_w = None
+ if diff_b != None:
+ bias.add_(diff_b, alpha= scaling)
+ diff_b = None
+
+ ret = func(*args, **kwargs )
+
+ weight.data = original_weight if original_weight != None else None
+ if original_bias != None:
+ bias.data = original_bias
+
+ return ret
+

+ def _lora_linear_forward(self, model, submodule, loras_data, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ weight = submodule.weight
  active_adapters = model._loras_active_adapters
  loras_scaling = model._loras_scaling
  training = False

  dtype = weight.dtype
  if weight.shape[-1] < x.shape[-2] : # sum base weight and lora matrices instead of applying input on each sub lora matrice if input is too large. This will save a lot VRAM and compute
+ bias = submodule.bias
+ original_bias = True
  if len(active_adapters) > 0:
  if isinstance(submodule, QModuleMixin):
  weight = weight.view(weight.shape) # get a persistent copy of the on the fly dequantized weights
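Note (not part of the diff): for non-Linear layers the new _lora_generic_forward adds the scaled weight/bias deltas to the module's parameters in place, calls the original forward, then restores the saved copies. A standalone, hedged sketch of that patch-call-restore idea, not the mmgp implementation itself:

import torch

# Illustrative patch/call/restore pattern for a module exposing .weight / .bias,
# given (diff_w, diff_b, scale) deltas already on the right device and dtype.
def forward_with_deltas(module: torch.nn.Module, deltas, *args, **kwargs):
    saved_w = module.weight.data.clone()
    saved_b = module.bias.data.clone() if module.bias is not None else None
    try:
        for diff_w, diff_b, scale in deltas:
            if diff_w is not None:
                module.weight.data.add_(diff_w, alpha=scale)
            if diff_b is not None and module.bias is not None:
                module.bias.data.add_(diff_b, alpha=scale)
        return module(*args, **kwargs)   # run the original forward on patched params
    finally:
        module.weight.data = saved_w
        if saved_b is not None:
            module.bias.data = saved_b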
@@ -1908,16 +2029,27 @@ class offload:
  data = loras_data.get(active_adapter + '_GPU', None)
  if data == None:
  continue
- lora_A_weight, lora_B_weight, alpha = data
- scaling = get_scaling(active_adapter) * alpha
- weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )
+ lora_A_weight, lora_B_weight, diff_b, alpha = data
+ scaling = self._get_lora_scaling(loras_scaling, model, active_adapter) * alpha
+ if lora_A_weight != None:
+ weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )
+
+ if diff_b != None:
+ if bias == None:
+ bias = diff_b.clone()
+ original_bias = False
+ elif original_bias:
+ bias = bias.clone()
+ original_bias = False
+ bias.add_(diff_b, alpha=scaling)
+
  # base_weight += scaling * lora_B_weight @ lora_A_weight

  if training:
  pass
  # result = torch.nn.functional.linear(dropout(x), base_weight, bias=submodule.bias)
  else:
- result = torch.nn.functional.linear(x, weight, bias=submodule.bias)
+ result = torch.nn.functional.linear(x, weight, bias=bias)

  else:
  result = torch.nn.functional.linear(x, weight, bias=submodule.bias)
@@ -1929,38 +2061,48 @@ class offload:
  data = loras_data.get(active_adapter + '_GPU', None)
  if data == None:
  continue
- lora_A, lora_B, alpha = data
+ lora_A, lora_B, diff_b, alpha = data
  # dropout = self.lora_dropout[active_adapter]
- scaling = get_scaling(active_adapter) * alpha
- x = x.to(lora_A.dtype)
-
- if training:
- pass
- # y = lora_A(dropout(x))
+ scaling = self._get_lora_scaling(loras_scaling, model, active_adapter) * alpha
+ if lora_A == None:
+ result.add_(diff_b, alpha=scaling)
  else:
- y = torch.nn.functional.linear(x, lora_A, bias=None)
+ x = x.to(lora_A.dtype)

- y = torch.nn.functional.linear(y, lora_B, bias=None)
- y*= scaling
- result+= y
- del y
+ if training:
+ pass
+ # y = lora_A(dropout(x))
+ else:
+ y = torch.nn.functional.linear(x, lora_A, bias=None)
+ y = torch.nn.functional.linear(y, lora_B, bias=diff_b)
+ y*= scaling
+ result+= y
+ del y

  return result


- def hook_lora_linear(self, submodule, current_model, model_id, loras_model_data, submodule_name):
+ def hook_lora(self, submodule, current_model, model_id, loras_model_data, submodule_name):
  old_forward = submodule.forward

  loras_data = {}
  loras_model_data[submodule] = loras_data

- def lora_linear_forward(module, *args, **kwargs):
- if len(loras_data) == 0:
- return old_forward(*args, **kwargs)
- else:
- return self._lora_linear_forward(current_model, submodule, loras_data, *args, **kwargs)
-
- return functools.update_wrapper(functools.partial(lora_linear_forward, submodule), old_forward)
+ if isinstance(submodule, torch.nn.Linear):
+ def lora_linear_forward(module, *args, **kwargs):
+ if len(loras_data) == 0:
+ return old_forward(*args, **kwargs)
+ else:
+ return self._lora_linear_forward(current_model, submodule, loras_data, *args, **kwargs)
+ target_fn = lora_linear_forward
+ else:
+ def lora_generic_forward(module, *args, **kwargs):
+ if len(loras_data) == 0:
+ return old_forward(*args, **kwargs)
+ else:
+ return self._lora_generic_forward(current_model, submodule, loras_data, old_forward, *args, **kwargs)
+ target_fn = lora_generic_forward
+ return functools.update_wrapper(functools.partial(target_fn, submodule), old_forward)

  def ensure_model_loaded(self, model_id):
  if model_id in self.active_models_ids:
@@ -2271,7 +2413,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
  model_dtype = getattr(current_model, "_model_dtype", None)
  # if model_dtype == None:
  # model_dtype = getattr(current_model, "dtype", None)
-
  for _ , m in current_model.named_modules():
  ignore_dtype = hasattr(m, "_lock_dtype")
  for n, p in m.named_parameters(recurse = False):
@@ -2413,8 +2554,9 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p


  if hasattr(submodule, "forward"):
- if any_lora and isinstance(submodule, torch.nn.Linear):
- submodule_method = self.hook_lora_linear(submodule, current_model, model_id, loras_model_data, submodule_name)
+ # if any_lora and isinstance(submodule, ( torch.nn.Linear, torch.nn.Conv3d, torch.nn.LayerNorm)):
+ if any_lora and hasattr(submodule,"weight"):
+ submodule_method = self.hook_lora(submodule, current_model, model_id, loras_model_data, submodule_name)
  else:
  submodule_method = getattr(submodule, "forward")
  if callable(submodule_method):
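Note (not part of the diff): LoRA hooking is now applied to any submodule that exposes a weight, with Linear layers keeping the optimized low-rank path and everything else falling back to the generic forward. A hedged sketch of the forward-wrapping pattern used above, with illustrative names:

import functools

# Illustrative only: wrap a submodule's forward while preserving its metadata,
# in the same style as hook_lora (functools.partial + update_wrapper).
def wrap_forward(submodule, apply_deltas):
    old_forward = submodule.forward
    def wrapped(module, *args, **kwargs):
        apply_deltas(module)               # hypothetical hook, e.g. add LoRA deltas
        return old_forward(*args, **kwargs)
    return functools.update_wrapper(functools.partial(wrapped, submodule), old_forward)

# usage: submodule.forward = wrap_forward(submodule, lambda m: None)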

mmgp-3.4.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mmgp
- Version: 3.4.6
+ Version: 3.4.8
  Summary: Memory Management for the GPU Poor
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
  Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file


  <p align="center">
- <H2>Memory Management 3.4.6 for the GPU Poor by DeepBeepMeep</H2>
+ <H2>Memory Management 3.4.8 for the GPU Poor by DeepBeepMeep</H2>
  </p>



mmgp-3.4.8.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
+ mmgp/offload.py,sha256=UhILpsjJdWDv0IzOeis9KMgmPzcwZFsfPU04BLk_3To,121471
+ mmgp/safetensors2.py,sha256=4nKV13qCMabnNEB1TA_ueFbfGYYmiQ9racR_C6SsGug,18693
+ mmgp-3.4.8.dist-info/licenses/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
+ mmgp-3.4.8.dist-info/METADATA,sha256=Ux77MBs2BZl3fDw5BeJyOPZgyra7eyk4c4PFpmQGhtk,16309
+ mmgp-3.4.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ mmgp-3.4.8.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
+ mmgp-3.4.8.dist-info/RECORD,,

mmgp-3.4.8.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.8.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

@@ -1,9 +0,0 @@
1
- __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
3
- mmgp/offload.py,sha256=2oWFiDcwIx3lGOb_6_aac1zzIIF-nhP8bwOA-G9HxsU,114594
4
- mmgp/safetensors2.py,sha256=4nKV13qCMabnNEB1TA_ueFbfGYYmiQ9racR_C6SsGug,18693
5
- mmgp-3.4.6.dist-info/licenses/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
6
- mmgp-3.4.6.dist-info/METADATA,sha256=kv9OfYHAAHKyiv9p9vrf4guU3tNd0I7vUgQ6xm7dkk8,16309
7
- mmgp-3.4.6.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
8
- mmgp-3.4.6.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
9
- mmgp-3.4.6.dist-info/RECORD,,