mmgp 3.4.0__tar.gz → 3.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mmgp
-Version: 3.4.0
+Version: 3.4.1
 Summary: Memory Management for the GPU Poor
 Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
 License: GNU GENERAL PUBLIC LICENSE
@@ -17,7 +17,7 @@ Dynamic: license-file
 
 
 <p align="center">
-<H2>Memory Management 3.3.1 for the GPU Poor by DeepBeepMeep</H2>
+<H2>Memory Management 3.4.1 for the GPU Poor by DeepBeepMeep</H2>
 </p>
 
 
@@ -1,6 +1,6 @@
 
 <p align="center">
-<H2>Memory Management 3.3.1 for the GPU Poor by DeepBeepMeep</H2>
+<H2>Memory Management 3.4.1 for the GPU Poor by DeepBeepMeep</H2>
 </p>
 
 
@@ -1,6 +1,6 @@
 [project]
 name = "mmgp"
-version = "3.4.0"
+version = "3.4.1"
 authors = [
   { name = "deepbeepmeep", email = "deepbeepmeep@yahoo.com" },
 ]
@@ -1,4 +1,4 @@
-# ------------------ Memory Management 3.4.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+# ------------------ Memory Management 3.4.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
 #
 # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
 # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
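Note: the header comment above describes mmgp as an alternative to accelerate's offloading for running large models on a single 24 GB card. As a rough orientation, a minimal usage sketch follows; it assumes this diffed file is the package's offload module and that a diffusers-style pipeline object is available (the pipeline loader and option value are illustrative, not taken from this diff):

    # Hypothetical usage sketch; only the names offload.all and pinnedMemory appear in this diff.
    from mmgp import offload

    pipe = load_my_pipeline()              # placeholder: any pipeline / dict of torch modules
    offload.all(pipe, pinnedMemory=False)  # attach offloading hooks so the model fits in limited VRAM
    result = pipe(prompt="a cat")          # run inference as usual; weights are streamed to the GPU on demand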
@@ -618,8 +618,23 @@ def _welcome():
     if welcome_displayed:
         return
     welcome_displayed = True
-    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.4.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+    print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.4.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")
 
+def change_dtype(model, new_dtype, exclude_buffers = False):
+    for submodule_name, submodule in model.named_modules():
+        if hasattr(submodule, "_lock_dtype"):
+            continue
+        for n, p in submodule.named_parameters(recurse = False):
+            if p.data.dtype != new_dtype:
+                p.data = p.data.to(new_dtype)
+
+        if not exclude_buffers:
+            for p in submodule.buffers(recurse=False):
+                if p.data.dtype != new_dtype:
+                    p.data = p.data.to(new_dtype)
+
+    return model
+
 def _extract_num_from_str(num_in_str):
     size = len(num_in_str)
     for i in range(size):
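The change_dtype helper added in this hunk walks model.named_modules(), casts every parameter (and, unless exclude_buffers is set, every buffer) to new_dtype, and skips any submodule that carries a _lock_dtype attribute. A minimal sketch of calling it, assuming the function is importable from this module (the tiny model below is purely illustrative):

    # Illustrative only: cast a model to bfloat16 while leaving a flagged submodule untouched.
    import torch
    import torch.nn as nn
    from mmgp import offload   # assumption: the diffed file is mmgp/offload.py

    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    model[1]._lock_dtype = torch.float32      # hypothetical marker: keep the LayerNorm in float32
    offload.change_dtype(model, torch.bfloat16, exclude_buffers=True)
    print(model[0].weight.dtype, model[1].weight.dtype)   # bfloat16, float32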
@@ -760,7 +775,11 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
         for p in submodule.buffers(recurse=False):
             size += torch.numel(p.data) * sizeofhalffloat
 
-
+        already_added = False
+        if hasattr(submodule, "_lock_dtype"):
+            submodule_size += size
+            submodule_names.append(submodule_name)
+            already_added = True
 
         if not any(submodule_name.startswith(pre) for pre in tower_names):
             flush = False
@@ -778,8 +797,9 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 2*
             submodule_size = 0
             submodule_names = []
             prev_blocks_prefix = cur_blocks_prefix
-        submodule_size += size
-        submodule_names.append(submodule_name)
+        if not already_added:
+            submodule_size += size
+            submodule_names.append(submodule_name)
         total_size += size
 
         if submodule_size >0 :
@@ -1347,9 +1367,13 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
             if k.endswith(missing_keys[0]):
                 base_model_prefix = k[:-len(missing_keys[0])]
                 break
-        state_dict = filter_state_dict(state_dict,base_model_prefix)
+        if base_model_prefix == None:
+            raise Exception("Missing keys: {missing_keys}")
+        state_dict = filter_state_dict(state_dict, base_model_prefix)
     missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
     del state_dict
+    if len(unexpected_keys) > 0 and verboseLevel >=2:
+        print(f"Unexpected keys while loading '{file_path}': {unexpected_keys}")
 
     for k,p in model.named_parameters():
         if p.is_meta:
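The lines added here tighten load_model_data's handling of prefixed checkpoints: when missing keys are detected, a common prefix is inferred from the first missing key and stripped with filter_state_dict, an exception is now raised if no such prefix can be found, and leftover unexpected keys are printed at verbose level 2 and above. A standalone illustration of the prefix-detection idea (example key names only, not mmgp code):

    # Toy illustration of inferring a checkpoint key prefix from a missing model key.
    missing_keys = ["blocks.0.attn.qkv.weight"]                      # reported as missing by the model
    checkpoint_keys = ["model.diffusion_model.blocks.0.attn.qkv.weight"]

    base_model_prefix = None
    for k in checkpoint_keys:
        if k.endswith(missing_keys[0]):
            base_model_prefix = k[:-len(missing_keys[0])]            # -> "model.diffusion_model."
            break
    print(base_model_prefix)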
@@ -1962,7 +1986,10 @@ class offload:
 
 
     def hook_change_module(self, target_module, model, model_id, module_id, previous_method):
-        dtype = model._dtype
+        if hasattr(target_module, "_lock_dtype"):
+            dtype = target_module._lock_dtype
+        else:
+            dtype = model._dtype
 
         def check_change_module(module, *args, **kwargs):
             self.ensure_model_loaded(model_id)
@@ -2207,35 +2234,39 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, pinnedPEFTLora = False, p
 
         modelPinned = (pinAllModels or model_id in modelsToPin) and not hasattr(current_model,"_already_pinned")
 
-        model_dtype = None
-        for n, p in current_model.named_parameters():
-            p.requires_grad = False
-            if isinstance(p, QTensor):
-                if p._qtype == qint4:
-                    if hasattr(p,"_scale_shift"):
-                        current_model_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
-                    else:
-                        current_model_size += torch.numel(p._scale) * p._shift.element_size() + torch.numel(p._scale) * p._shift.element_size()
+        model_dtype = None
+
+        for _ , m in current_model.named_modules():
+            ignore_dtype = hasattr(m, "_lock_dtype")
+            for n, p in m.named_parameters(recurse = False):
+                p.requires_grad = False
+                if isinstance(p, QTensor):
+                    if p._qtype == qint4:
+                        if hasattr(p,"_scale_shift"):
+                            current_model_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
+                        else:
+                            current_model_size += torch.numel(p._scale) * p._shift.element_size() + torch.numel(p._scale) * p._shift.element_size()
 
-                    current_model_size += torch.numel(p._data._data) * p._data._data.element_size()
+                        current_model_size += torch.numel(p._data._data) * p._data._data.element_size()
 
-                else:
-                    current_model_size += torch.numel(p._scale) * p._scale.element_size()
-                    current_model_size += torch.numel(p._data) * p._data.element_size()
-                    dtype = p._scale.dtype
+                    else:
+                        current_model_size += torch.numel(p._scale) * p._scale.element_size()
+                        current_model_size += torch.numel(p._data) * p._data.element_size()
+                        dtype = p._scale.dtype
 
-            else:
-                dtype = p.data.dtype
-                if convertWeightsFloatTo != None and dtype == torch.float32:
-                    # convert any left overs float32 weight to bfloat16 / float16 to divide by 2 the model memory footprint
-                    dtype = convertWeightsFloatTo if model_dtype == None else model_dtype
-                    p.data = p.data.to(dtype)
-                if model_dtype== None:
-                    model_dtype = dtype
                 else:
-                    assert model_dtype == dtype
-                current_model_size += torch.numel(p.data) * p.data.element_size()
-        current_model._dtype = model_dtype
+                    if not ignore_dtype:
+                        dtype = p.data.dtype
+                        if convertWeightsFloatTo != None and dtype == torch.float32 :
+                            # convert any left overs float32 weight to bfloat16 / float16 to divide by 2 the model memory footprint
+                            dtype = convertWeightsFloatTo if model_dtype == None else model_dtype
+                            p.data = p.data.to(dtype)
+                        if model_dtype== None:
+                            model_dtype = dtype
+                        else:
+                            assert model_dtype == dtype
+                    current_model_size += torch.numel(p.data) * p.data.element_size()
+        current_model._dtype = model_dtype
         for b in current_model.buffers():
             # do not convert 32 bits float to 16 bits since buffers are few (and potential gain low) and usually they are needed for precision calculation (for instance Rope)
             current_model_size += torch.numel(b.data) * b.data.element_size()
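Together with the hook_change_module change above, this hunk makes _lock_dtype act as a per-submodule dtype override: the size/dtype scan in all() now walks named_modules, any submodule flagged with _lock_dtype is excluded from the float32 downcast driven by convertWeightsFloatTo, and the per-module hook uses the locked dtype instead of model._dtype. A minimal sketch of flagging a submodule before installing the offload hooks (the pipeline loader, the submodule path and the keyword argument are assumptions inferred from names used in this hunk):

    # Hypothetical sketch: pin one submodule to float32 while the rest of the model follows model._dtype.
    import torch
    from mmgp import offload    # assumption: the diffed file is mmgp/offload.py

    pipe = load_my_pipeline()                                # placeholder pipeline exposing .transformer
    pipe.transformer.norm_out._lock_dtype = torch.float32    # assumed submodule name, illustration only
    offload.all(pipe, convertWeightsFloatTo=torch.bfloat16)  # leftover float32 weights are downcast, except flagged ones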
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mmgp
-Version: 3.4.0
+Version: 3.4.1
 Summary: Memory Management for the GPU Poor
 Author-email: deepbeepmeep <deepbeepmeep@yahoo.com>
 License: GNU GENERAL PUBLIC LICENSE
@@ -17,7 +17,7 @@ Dynamic: license-file
 
 
 <p align="center">
-<H2>Memory Management 3.3.1 for the GPU Poor by DeepBeepMeep</H2>
+<H2>Memory Management 3.4.1 for the GPU Poor by DeepBeepMeep</H2>
 </p>
 
 