mmgp 3.5.1__py3-none-any.whl → 3.5.5__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their public registry. It reflects only the changes between those versions and is provided for informational purposes.

This version of mmgp has been flagged as potentially problematic.

mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
- # ------------------ Memory Management 3.5.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+ # ------------------ Memory Management 3.5.5 for the GPU Poor by DeepBeepMeep (mmgp)------------------
  #
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -253,17 +253,17 @@ def _remove_model_wrapper(model):
  def _move_to_pinned_tensor(source_tensor, big_tensor, offset, length):
  dtype= source_tensor.dtype
  shape = source_tensor.shape
- if len(shape) == 0:
- return source_tensor
- else:
+ if len(shape) > 0 :
  t = source_tensor.view(torch.uint8)
  t = torch.reshape(t, (length,))
- # magic swap !
- big_tensor[offset: offset + length] = t
- t = big_tensor[offset: offset + length]
- t = t.view(dtype)
- t = torch.reshape(t, shape)
- assert t.is_pinned()
+ else:
+ t = source_tensor
+ # magic swap !
+ big_tensor[offset: offset + length] = t
+ t = big_tensor[offset: offset + length]
+ t = t.view(dtype)
+ t = torch.reshape(t, shape)
+ assert t.is_pinned()
  return t

  def _safetensors_load_file(file_path, writable_tensors = True):
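
The rework of `_move_to_pinned_tensor` above routes every parameter, including scalars, through the same pinned-buffer swap. As an editorial sketch (helper name and sizes are illustrative, not part of mmgp), the underlying technique is: reinterpret a tensor's bytes, copy them into a slice of one large page-locked uint8 pool, and view the slice back with the original dtype and shape so the parameter ends up living in pinned RAM.

```python
import torch

def move_to_pinned(source: torch.Tensor, big_tensor: torch.Tensor, offset: int) -> torch.Tensor:
    # Reinterpret the tensor as raw bytes and copy them into the pinned pool.
    length = source.numel() * source.element_size()
    flat = source.view(torch.uint8).reshape(length)
    big_tensor[offset: offset + length] = flat
    # View the pinned slice back with the original dtype/shape: the result
    # shares storage with the page-locked buffer.
    pinned = big_tensor[offset: offset + length].view(source.dtype).reshape(source.shape)
    assert pinned.is_pinned()
    return pinned

# Requires a CUDA-enabled PyTorch build for pin_memory=True.
pool = torch.empty(1 << 20, dtype=torch.uint8, pin_memory=True)
w = torch.randn(64, 64)
w = move_to_pinned(w, pool, 0)
```

Because the returned tensor is backed by page-locked memory, later `.to("cuda", non_blocking=True)` transfers can run asynchronously, which is what the offloading hooks further down rely on.
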
@@ -336,9 +336,8 @@ def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TE
  names_list = sd_name if isinstance(sd, list) else [sd_name]

  if max_pinnable_bytes > 0 and total_pinned_bytes >= max_pinnable_bytes:
-
  if verboseLevel>=1 :
- print(f"Unable pin data of '{','.join(names_list)}' to reserved RAM as there is no reserved RAM left")
+ print(f"Unable to pin data of '{','.join(names_list)}' to reserved RAM as there is no reserved RAM left. Transfer speed from RAM to VRAM will may be slower.")
  return


@@ -404,7 +403,7 @@ def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TE
  big_tensors.append(current_big_tensor)
  except:
  incomplete_pinning = True
- print(f"Unable to pin more tensors for '{sd_name}' as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
+ print(f"Unable to pin more tensors for '{sd_name}' as the maximum reservable memory has been reached ({total/ONE_MB:.2f}). Transfer speed from RAM to VRAM may be slower.")
  break

  last_big_tensor += 1
@@ -442,12 +441,12 @@ def _pin_sd_to_memory(sd, sd_name, tied_weights = None, gig_tensor_size = BIG_TE

  if verboseLevel >=1:
  if incomplete_pinning :
- if len(names_list) > 0:
+ if len(names_list) > 1:
  print(f"'{','.join(names_list)}' were partially pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
  else:
  print(f"'{','.join(names_list)}' was partially pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
  else:
- if len(names_list) > 0:
+ if len(names_list) > 1:
  print(f"'{','.join(names_list)}' were pinned entirely to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
  else:
  print(f"'{','.join(names_list)}' was pinned entirely to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
@@ -462,7 +461,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
  if max_pinnable_bytes > 0 and total_pinned_bytes >= max_pinnable_bytes:

  if verboseLevel>=1 :
- print(f"Unable pin data of '{model_id}' to reserved RAM as there is no reserved RAM left")
+ print(f"Unable to pin data of '{model_id}' to reserved RAM as there is no reserved RAM left. Transfer speed from RAM to VRAM may be slower.")
  return

  if partialPinning:
@@ -499,7 +498,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
  else:
  print(f"Pinning data of '{model_id}' to reserved RAM")

- if partialPinning and len(params_dict) == 0:
+ if len(params_dict) == 0:
  return

  ref_cache = {}
@@ -521,13 +520,22 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
  else:
  if isinstance(p, QTensor):
  if p._qtype == qint4:
+ if p._data._data.is_pinned():
+ params_dict[n] = (None, False)
+ continue
  if hasattr(p,"_scale_shift"):
  length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale_shift) * p._scale_shift.element_size()
  else:
  length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale) * p._scale.element_size() + torch.numel(p._shift) * p._shift.element_size()
  else:
  length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
+ if p._data.is_pinned():
+ params_dict[n] = (None, False)
+ continue
  else:
+ if p.data.is_pinned():
+ params_dict[n] = (None, False)
+ continue
  length = torch.numel(p.data) * p.data.element_size()

  ref_cache[ref] = (n, length)
@@ -544,7 +552,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
  current_big_tensor_size += length

  total_tensor_bytes += length
-
+ p = None
  if verboseLevel >=1 and tied_weights_count > 0:
  if tied_weights_count == 1:
  print(f"Tied weights of {tied_weights_total/ONE_MB:0.2f} MB detected: {tied_weights_last}")
@@ -570,6 +578,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
  tensor_no = 0
  # prev_big_tensor = 0
  for n, (p, is_buffer) in params_dict.items():
+ if p is None: continue
  q_name = tied_weights.get(n,None)
  if q_name != None:
  q , _ = params_dict[q_name]
@@ -633,6 +642,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, pinnedPEFTLora = Tru
  else:
  length = torch.numel(p.data) * p.data.element_size()
  p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
+
  tensor_no += 1
  del p
  del dummy_pinned_tensor
@@ -658,7 +668,7 @@ def _welcome():
  if welcome_displayed:
  return
  welcome_displayed = True
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.5.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.5.5) by DeepBeepMeep ************{ENDC}{UNBOLD}")

  def change_dtype(model, new_dtype, exclude_buffers = False):
  for submodule_name, submodule in model.named_modules():
@@ -1136,6 +1146,8 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  break
  elif diff_b != None:
  rank = diff_b.shape[0]
+ if not hasattr(module, "bias"):
+ pass
  if module.bias == None:
  msg = f"Lora '{path}': Lora Basis is defined while it doesnt exist in model '{_get_module_name(model)}'. It is likely this Lora has been made for another version of this model."
  fail = True
@@ -1220,9 +1232,30 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
  activate_loras(model, loras_nos, loras_multi)
  return new_lora_path

+
+ def merge_dicts(A, B):
+ for key, value in A.items():
+ if isinstance(value, dict):
+ if key not in B or not isinstance(B[key], dict):
+ B[key] = value # Copy entire dict reference from A
+ else:
+ merge_dicts(value, B[key]) # Recurse into both dicts
+ else:
+ B[key] = value # Copy non-dict value from A to B
+
+
+ def sync_models_loras(model, model2):
+ merge_dicts(model._loras_model_shortcuts , model2._loras_model_shortcuts)
+ model2._loras_active_adapters = model._loras_active_adapters
+ model2._loras_adapters = model._loras_adapters
+ model2._loras_scaling = model._loras_scaling
+
  def unload_loras_from_model(model):
+ if model is None: return
  for _, v in model._loras_model_data.items():
  v.clear()
+ for _, v in model._loras_model_shortcuts.items():
+ v.clear()

  model._loras_active_adapters = []
  model._loras_scaling = dict()
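
The new `merge_dicts` / `sync_models_loras` helpers mirror one model's LoRA bookkeeping onto a second model, and `unload_loras_from_model` now also clears the shortcut table. A hypothetical call sequence, assuming the usual `from mmgp import offload` import, two models already hooked by the same offload setup, and a made-up LoRA file name (argument shapes are assumptions):

```python
from mmgp import offload

# Load LoRA weights into the first model...
offload.load_loras_into_model(model, ["my_lora.safetensors"], lora_multi=[1.0])
# ...then share its adapters, scalings and shortcut tables with the second model.
offload.sync_models_loras(model, model2)

# Later: drop the adapters again; passing None is now a no-op.
offload.unload_loras_from_model(model2)
offload.unload_loras_from_model(None)
```
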
@@ -1262,7 +1295,7 @@ def move_loras_to_device(model, device="cpu" ):
  if ".lora_" in k:
  m.to(device)

- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, configKwargs ={}):
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, forcedConfigPath = None, defaultConfigPath = None, modelClass=None, modelPrefix = None, writable_tensors = True, verboseLevel = -1, modules = None, return_shared_modules = None, configKwargs ={}):
  """
  quick version of .LoadfromPretrained of the transformers library
  used to build a model and load the corresponding weights (quantized or not)
@@ -1331,42 +1364,36 @@ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizat
  model = transfomer_class(config_obj)


- elif "_class_name" in transformer_config:
- class_name = transformer_config["_class_name"]
-
+ else:
  if modelClass !=None:
  transfomer_class = modelClass
- else:
+ elif "_class_name" in transformer_config:
+ class_name = 'Transformer3DModel'
  module = __import__("diffusers")
  transfomer_class = getattr(module, class_name)
+ else:
+ raise Exception("class not defined")

  with init_empty_weights():
  model = transfomer_class.from_config(transformer_config )


  torch.set_default_device('cpu')
+ model.eval().requires_grad_(False)

  model._config = transformer_config

- load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors ,verboseLevel=verboseLevel )
+ load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, modelPrefix = modelPrefix, writable_tensors =writable_tensors, modules = modules, return_shared_modules = return_shared_modules, verboseLevel=verboseLevel )

  return model



- def load_model_data(model, file_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, verboseLevel = -1):
+ def load_model_data(model, file_path, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, modelPrefix = None, writable_tensors = True, modules = None, return_shared_modules = None, verboseLevel = -1):
  """
  Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
  """
- if not isinstance(file_path, list):
- file_path = [file_path]

- file_path = [ _get_model(file) for file in file_path]
- if any( file == None for file in file_path):
- raise Exception("Unable to find file")
- verboseLevel = _compute_verbose_level(verboseLevel)
-
- model = _remove_model_wrapper(model)

  def filter_state_dict(state_dict, base_model_prefix):
  new_state_dict= {}
@@ -1387,10 +1414,34 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
  new_state_dict[k[ start:]] = v
  return new_state_dict

+
+
+ if not isinstance(file_path, list):
+ file_path = [file_path]
+
+ file_count = len(file_path)
+ if isinstance(modules, (list,str)):
+ if isinstance(modules, str): modules = [modules]
+ file_path += modules
+ modules = None
+
+ file_path = [ _get_model(file) for file in file_path]
+ if any( file == None for file in file_path):
+ raise Exception("Unable to find file")
+ verboseLevel = _compute_verbose_level(verboseLevel)
+
+ model = _remove_model_wrapper(model)
+
+ if return_shared_modules is not None:
+ return_state_dict ={}
+ return_quantization_map ={}
+ return_shared_modules["state_dict"] = return_state_dict
+ return_shared_modules["quantization_map"] = return_quantization_map
+
  full_quantization_map = {}
  full_tied_weights_map = {}
  full_state_dict = {}
- for file in file_path:
+ for no, file in enumerate(file_path):
  quantization_map = None
  tied_weights_map = None
  if not (".safetensors" in file or ".sft" in file):
@@ -1443,6 +1494,13 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
  full_quantization_map.update(quantization_map)
  if tied_weights_map != None:
  full_tied_weights_map.update(tied_weights_map)
+ if return_shared_modules is not None and no >= file_count:
+ return_state_dict.update(state_dict)
+ if quantization_map is not None: return_quantization_map.update(quantization_map)
+
+ if isinstance(modules, dict) :
+ full_state_dict.update(modules["state_dict"])
+ full_quantization_map.update(modules["quantization_map"])

  state_dict, quantization_map, tied_weights_map = full_state_dict, full_quantization_map, full_tied_weights_map
  full_state_dict, full_quantization_map, full_tied_weights_map = None, None, None
@@ -1463,7 +1521,7 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType


  missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
- if len(missing_keys) > 0 :
+ if len(missing_keys) > 0 :
  # if there is a key mismatch maybe we forgot to remove some prefix
  base_model_prefix = None
  for k,v in state_dict.items():
@@ -1474,18 +1532,53 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
  raise Exception(f"Missing keys: {missing_keys}")
  state_dict = filter_state_dict(state_dict, base_model_prefix)
  missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
+
  del state_dict
+
  if len(unexpected_keys) > 0 and verboseLevel >=2:
  print(f"Unexpected keys while loading '{file_path}': {unexpected_keys}")

  for k,p in model.named_parameters():
- if p.is_meta:
+ if p.is_meta :
  txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since parameter '{k}' has no data"
  raise Exception(txt)
  for k,b in model.named_buffers():
- if b.is_meta:
+ if b.is_meta :
  txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since buffer '{k}' has no data"
  raise Exception(txt)
+
+ if return_shared_modules is not None:
+ mods = { k : v for k,v in model.named_modules()}
+ return_parameters = {}
+ return_shared_modules["parameters"] = return_parameters
+ for k in return_state_dict:
+ if k.endswith("._data"):
+ k = k[:-6]
+ pos = k.rfind(".")
+ mod_name = k[:pos]
+ param_name = k[pos +1:]
+ mod = mods.get(mod_name, None)
+ if mod is not None:
+ p = mod._parameters.get(param_name, None)
+ if p is None: p = mod._buffers.get(param_name, None)
+ if p is not None:
+ return_parameters[k] = p
+ del mods
+
+ if isinstance(modules, dict) :
+ mods = { k : v for k,v in model.named_modules()}
+ # replace Parameter outer shell so that both models parameters are tied
+ for k, rep_p in modules["parameters"].items():
+ pos = k.rfind(".")
+ mod_name = k[:pos]
+ param_name = k[pos +1:]
+ mod = mods.get(mod_name, None)
+ if mod is not None:
+ setattr(mod, param_name, rep_p)
+ del mods
+ modules["parameters"].clear()
+ modules["state_dict"].clear()
+ rep_p = p = None

  if do_quantize:
  if quantization_map != None and len(quantization_map) > 0 :
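
The `return_shared_modules` / `modules` plumbing added above lets one load capture the tensors of extra module files and a later load reuse them, so the matching parameters end up tied between both models. A sketch of the round trip, inferred from this diff rather than from documented API (file names are placeholders):

```python
from mmgp import offload

shared = {}  # will receive "state_dict", "quantization_map" and "parameters"

# First load: extra files passed via `modules` are appended to the load and
# their tensors are recorded in `shared`.
model_a = offload.fast_load_transformers_model(
    "model_a.safetensors",
    modules=["shared_blocks.safetensors"],
    return_shared_modules=shared,
)

# Second load: passing the captured dict as `modules` injects the same tensors
# and re-points the matching parameters, so both models share that storage.
model_b = offload.fast_load_transformers_model(
    "model_b.safetensors",
    modules=shared,
)
```
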
@@ -1500,7 +1593,7 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType

  return

- def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1, config_file_path = None ):
+ def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1, config_file_path = None, filter_sd =None ):
  """save the weights of a model and quantize them if requested
  These weights can be loaded again using 'load_model_data'
  """
@@ -1541,6 +1634,24 @@ def save_model(model, file_path, do_quantize = False, quantizationType = qint8,
  cache_ref = {}
  tied_weights_map = {}
  sd = model.state_dict()
+ if filter_sd != None:
+ new_sd = {}
+ new_quantization_map = {}
+ for k_k in filter_sd:
+ for s in [".weight", ".bias", ".weight._data", ".weight._scale"]:
+ if k_k.endswith(s):
+ k_k= k_k[:-len(s)]
+ break
+ for k,v in sd.items():
+ if k.startswith(k_k):
+ new_sd[k] = v
+ if quantization_map != None:
+ for k,v in quantization_map.items():
+ if k.startswith(k_k):
+ new_quantization_map[k] = v
+ sd = new_sd
+ if quantization_map != None: quantization_map = new_quantization_map
+
  out_sd = OrderedDict()


@@ -1755,6 +1866,12 @@ class offload:
  if tied_p.is_cuda:
  setattr(parent_module, n , tied_p)
  continue
+ # if hasattr(p,'_data'):
+ # if not p._data.is_pinned() or not p._scale.is_pinned():
+ # pass
+ # else:
+ # if not p.data.is_pinned():
+ # pass

  q = p.to("cuda", non_blocking=True)
  if is_buffer:
@@ -1974,13 +2091,16 @@ class offload:
  if data == None:
  continue
  diff_w , _ , diff_b, alpha = data
+ scaling = self._get_lora_scaling( loras_scaling, model, active_adapter) * alpha
+ if scaling == 0:
+ continue
  if first_weight:
  original_weight= weight.clone() if weight != None else None
  first_weight = False
  if first_bias:
  original_bias= bias.clone() if bias != None else None
  first_bias = False
- scaling = self._get_lora_scaling( loras_scaling, model, active_adapter) * alpha
+
  if diff_w != None:
  weight.add_(diff_w, alpha= scaling)
  diff_w = None
@@ -2018,6 +2138,8 @@ class offload:
  continue
  lora_A_weight, lora_B_weight, diff_b, alpha = data
  scaling = self._get_lora_scaling(loras_scaling, model, active_adapter) * alpha
+ if scaling == 0:
+ continue
  if lora_A_weight != None:
  weight.addmm_(lora_B_weight, lora_A_weight, alpha= scaling )

@@ -2049,6 +2171,8 @@ class offload:
  lora_A, lora_B, diff_b, alpha = data
  # dropout = self.lora_dropout[active_adapter]
  scaling = self._get_lora_scaling(loras_scaling, model, active_adapter) * alpha
+ if scaling == 0:
+ continue
  if lora_A == None:
  result.add_(diff_b, alpha=scaling)
  else:
@@ -2067,10 +2191,12 @@ class offload:
  return result


- def hook_lora(self, submodule, current_model, model_id, loras_model_data, submodule_name):
+ def hook_lora(self, submodule, current_model, model_id, loras_model_data, loras_model_shortcuts, submodule_name):
  old_forward = submodule.forward

  loras_data = {}
+ assert submodule_name not in loras_model_shortcuts
+ loras_model_shortcuts[submodule_name] = loras_data
  loras_model_data[submodule] = loras_data

  if isinstance(submodule, torch.nn.Linear):
@@ -2078,7 +2204,7 @@ class offload:
  if len(loras_data) == 0:
  return old_forward(*args, **kwargs)
  else:
- # submodule.aaa = submodule_name
+ submodule.aaa = submodule_name
  return self._lora_linear_forward(current_model, submodule, loras_data, *args, **kwargs)
  target_fn = lora_linear_forward
  else:
@@ -2188,7 +2314,6 @@ class offload:

  if current_budget == 0 or towers_names is None or len(towers_names) == 0 or not self.async_transfers:
  return
- # current_budget = 5000 * ONE_MB
  base_size = self.blocks_of_modules_sizes[model_id]
  current_budget -= base_size
  current_budget = max(0, current_budget)
@@ -2501,14 +2626,14 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
  print(f"Model '{model_id}' already pinned to reserved memory")
  else:
  _pin_to_memory(current_model, model_id, partialPinning= partialPinning, pinnedPEFTLora = pinnedPEFTLora, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel=verboseLevel)
-
  current_budget = model_budgets[model_id]
  cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq, is_mod_seq = None, None, None, -1, False
  self.loaded_blocks[model_id] = None
  any_lora = loras !=None and model_id in loras
  if any_lora:
- loras_model_data = {}
+ loras_model_data, loras_model_shortcuts = {}, {}
  current_model._loras_model_data = loras_model_data
+ current_model._loras_model_shortcuts = loras_model_shortcuts
  for submodule_name, submodule in current_model.named_modules():
  # create a fake 'accelerate' parameter so that the _execution_device property returns always "cuda"
  # (it is queried in many pipelines even if offloading is not properly implemented)
@@ -2542,7 +2667,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
  if hasattr(submodule, "forward"):
  # if any_lora and isinstance(submodule, ( torch.nn.Linear, torch.nn.Conv3d, torch.nn.LayerNorm)):
  if any_lora and hasattr(submodule,"weight"):
- submodule_method = self.hook_lora(submodule, current_model, model_id, loras_model_data, submodule_name)
+ submodule_method = self.hook_lora(submodule, current_model, model_id, loras_model_data, loras_model_shortcuts, submodule_name)
  else:
  submodule_method = getattr(submodule, "forward")
  if callable(submodule_method):
@@ -2552,11 +2677,12 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
  self.hook_preload_blocks_for_compilation(submodule, model_id, cur_blocks_name, context = submodule_name )
  else:
  self.hook_check_empty_cache_needed(submodule, current_model, model_id, cur_blocks_name, submodule_method, context = submodule_name )
-
+
  self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name, submodule_name)


  self.tune_preloading(model_id, current_budget, towers_names)
+ self.parameters_ref = {}


  if self.verboseLevel >=2:
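
For context, the LoRA shortcut table above is populated when offloading is installed on a pipeline. A minimal hookup might look like the following; only `pipe_or_dict_of_modules`, `pinnedMemory` and `pinnedPEFTLora` are visible in this diff, and the dict keys and variable names are assumptions:

```python
from mmgp import offload

# Install the offloading hooks on a dict of models: weights can be pinned to
# reserved RAM, and models set up for LoRA get per-submodule forward wrappers
# that fill _loras_model_data / _loras_model_shortcuts.
offload.all(
    {"transformer": transformer, "text_encoder": text_encoder},
    pinnedMemory=True,
    pinnedPEFTLora=False,
)
```
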

mmgp-3.5.1.dist-info/METADATA → mmgp-3.5.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mmgp
- Version: 3.5.1
+ Version: 3.5.5
  Summary: Memory Management for the GPU Poor
  Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
  Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Dynamic: license-file


  <p align="center">
- <H2>Memory Management 3.5.1 for the GPU Poor by DeepBeepMeep</H2>
+ <H2>Memory Management 3.5.5 for the GPU Poor by DeepBeepMeep</H2>
  </p>



mmgp-3.5.5.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
+ mmgp/offload.py,sha256=EnkDZp__eKmqWhAS9vM1hKGdXizeLYPTN5bdFbmVmlc,126301
+ mmgp/safetensors2.py,sha256=4nKV13qCMabnNEB1TA_ueFbfGYYmiQ9racR_C6SsGug,18693
+ mmgp-3.5.5.dist-info/licenses/LICENSE.md,sha256=DD-WIS0BkPoWJ_8hQO3J8hMP9K_1-dyrYv1YCbkxcDU,94
+ mmgp-3.5.5.dist-info/METADATA,sha256=1IiXOasc93ZBWOCEPZTQ-0ajzR_pP74cTuqS_ROcZsU,16309
+ mmgp-3.5.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ mmgp-3.5.5.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
+ mmgp-3.5.5.dist-info/RECORD,,

mmgp-3.5.1.dist-info/licenses/LICENSE.md → mmgp-3.5.5.dist-info/licenses/LICENSE.md CHANGED
@@ -1,2 +1,2 @@
- GNU GENERAL PUBLIC LICENSE
+ GNU GENERAL PUBLIC LICENSE
  Version 3, 29 June 2007

mmgp-3.5.1.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
- __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- mmgp/__init__.py,sha256=A9qBwyQMd1M7vshSTOBnFGP1MQvS2hXmTcTCMUcmyzE,509
- mmgp/offload.py,sha256=3TpjzT7DJ2yGHIm-u2O-U2wAR_V2ZH1NqmC1bMYhfso,120962
- mmgp/safetensors2.py,sha256=4nKV13qCMabnNEB1TA_ueFbfGYYmiQ9racR_C6SsGug,18693
- mmgp-3.5.1.dist-info/licenses/LICENSE.md,sha256=HjzvY2grdtdduZclbZ46B2M-XpT4MDCxFub5ZwTWq2g,93
- mmgp-3.5.1.dist-info/METADATA,sha256=IOEJfRTedEF4lv9A9DEHJF9-kr3UK1yhlo6Nsdhca10,16309
- mmgp-3.5.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- mmgp-3.5.1.dist-info/top_level.txt,sha256=waGaepj2qVfnS2yAOkaMu4r9mJaVjGbEi6AwOUogU_U,14
- mmgp-3.5.1.dist-info/RECORD,,

Files without changes (identical sha256 in RECORD): __init__.py, mmgp/__init__.py, mmgp/safetensors2.py, WHEEL, top_level.txt.