mmgp 3.0.3__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mmgp might be problematic.

mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
1
- # ------------------ Memory Management 3.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
1
+ # ------------------ Memory Management 3.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
2
2
  #
3
3
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
4
4
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -61,13 +61,25 @@ import sys
61
61
  import os
62
62
  import json
63
63
  import psutil
64
+ try:
65
+ from diffusers.utils.peft_utils import set_weights_and_activate_adapters, get_peft_kwargs
66
+ except:
67
+ set_weights_and_activate_adapters = None
68
+ get_peft_kwargs = None
69
+ pass
70
+ try:
71
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
72
+ except:
73
+ inject_adapter_in_model = None
74
+ pass
75
+
64
76
  from mmgp import safetensors2
65
77
  from mmgp import profile_type
66
78
 
67
- from optimum.quanto import freeze, qfloat8, qint8, quantize, QModuleMixin, QTensor, WeightQBytesTensor, quantize_module
68
-
79
+ from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QTensor, quantize_module
69
80
 
70
81
 
82
+ shared_state = {}
71
83
 
72
84
  mmm = safetensors2.mmm
73
85
 
@@ -127,6 +139,9 @@ def move_tensors(obj, device):
127
139
  return _list
128
140
  else:
129
141
  raise TypeError("Tensor or list / dict of tensors expected")
142
+ def _get_module_name(v):
143
+ return v.__module__.lower()
144
+
130
145
 
131
146
  def _compute_verbose_level(level):
132
147
  if level <0:
@@ -139,33 +154,75 @@ def _get_max_reservable_memory(perc_reserved_mem_max):
139
154
  perc_reserved_mem_max = 0.40 if os.name == 'nt' else 0.5
140
155
  return perc_reserved_mem_max * physical_memory
141
156
 
142
- def _detect_main_towers(model, verboseLevel=1):
157
+ def _detect_main_towers(model, min_floors = 5, verboseLevel=1):
143
158
  cur_blocks_prefix = None
144
159
  towers_modules= []
145
160
  towers_names= []
146
161
 
162
+ floors_modules= []
163
+ tower_name = None
164
+
165
+
147
166
  for submodule_name, submodule in model.named_modules():
167
+
148
168
  if submodule_name=='':
149
169
  continue
150
170
 
151
- if isinstance(submodule, torch.nn.ModuleList):
152
- newList =False
153
- if cur_blocks_prefix == None:
154
- cur_blocks_prefix = submodule_name + "."
155
- newList = True
156
- else:
157
- if not submodule_name.startswith(cur_blocks_prefix):
158
- cur_blocks_prefix = submodule_name + "."
159
- newList = True
171
+ if cur_blocks_prefix != None:
172
+ if submodule_name.startswith(cur_blocks_prefix):
173
+ depth_prefix = cur_blocks_prefix.split(".")
174
+ depth_name = submodule_name.split(".")
175
+ level = depth_name[len(depth_prefix)-1]
176
+ pre , num = _extract_num_from_str(level)
177
+
178
+ if num != cur_blocks_seq:
179
+ floors_modules.append(submodule)
160
180
 
161
- if newList and len(submodule)>=5:
162
- towers_names.append(submodule_name)
163
- towers_modules.append(submodule)
181
+ cur_blocks_seq = num
182
+ else:
183
+ if len(floors_modules) >= min_floors:
184
+ towers_modules += floors_modules
185
+ towers_names.append(tower_name)
186
+ tower_name = None
187
+ floors_modules= []
188
+ cur_blocks_prefix, cur_blocks_seq = None, -1
189
+
190
+ if cur_blocks_prefix == None:
191
+ pre , num = _extract_num_from_str(submodule_name)
192
+ if isinstance(submodule, (torch.nn.ModuleList)):
193
+ cur_blocks_prefix, cur_blocks_seq = pre + ".", -1
194
+ tower_name = submodule_name + ".*"
195
+ elif num >=0:
196
+ cur_blocks_prefix, cur_blocks_seq = pre, num
197
+ tower_name = submodule_name[ :-1] + "*"
198
+ floors_modules.append(submodule)
199
+
200
+ if len(floors_modules) >= min_floors:
201
+ towers_modules += floors_modules
202
+ towers_names.append(tower_name)
203
+
204
+ # for submodule_name, submodule in model.named_modules():
205
+ # if submodule_name=='':
206
+ # continue
207
+
208
+ # if isinstance(submodule, torch.nn.ModuleList):
209
+ # newList =False
210
+ # if cur_blocks_prefix == None:
211
+ # cur_blocks_prefix = submodule_name + "."
212
+ # newList = True
213
+ # else:
214
+ # if not submodule_name.startswith(cur_blocks_prefix):
215
+ # cur_blocks_prefix = submodule_name + "."
216
+ # newList = True
217
+
218
+ # if newList and len(submodule)>=5:
219
+ # towers_names.append(submodule_name)
220
+ # towers_modules.append(submodule)
164
221
 
165
- else:
166
- if cur_blocks_prefix is not None:
167
- if not submodule_name.startswith(cur_blocks_prefix):
168
- cur_blocks_prefix = None
222
+ # else:
223
+ # if cur_blocks_prefix is not None:
224
+ # if not submodule_name.startswith(cur_blocks_prefix):
225
+ # cur_blocks_prefix = None
169
226
 
170
227
  return towers_names, towers_modules
171
228
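The rewritten detector now walks numbered "floors" under a common prefix instead of only accepting a bare ModuleList; a minimal sketch of what it reports, assuming mmgp 3.1 is installed (the helper is private, so this is illustrative only):

    import torch
    from mmgp.offload import _detect_main_towers  # private helper, used here for illustration

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # six repeated "floors" stacked in a ModuleList form one detectable tower
            self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(6)])
            self.head = torch.nn.Linear(8, 2)  # standalone module, not part of any tower

    names, modules = _detect_main_towers(ToyModel())
    print(names)         # expected: ['blocks.*']
    print(len(modules))  # expected: 6, one entry per floor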
 
@@ -179,7 +236,7 @@ def _get_model(model_path):
179
236
  _path = Path(model_path).parts
180
237
  _filename = _path[-1]
181
238
  _path = _path[:-1]
182
- if len(_path)==1:
239
+ if len(_path)<=1:
183
240
  raise("file not found")
184
241
  else:
185
242
  from huggingface_hub import hf_hub_download #snapshot_download,
@@ -263,7 +320,13 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma
263
320
  # print(f"num params to pin {model_id}: {len(params_list)}")
264
321
  for p in params_list:
265
322
  if isinstance(p, QTensor):
266
- length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
323
+ if p._qtype == qint4:
324
+ if hasattr(p,"_scale_shift"):
325
+ length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale_shift) * p._scale_shift.element_size()
326
+ else:
327
+ length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale) * p._scale.element_size() + torch.numel(p._shift) * p._shift.element_size()
328
+ else:
329
+ length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
267
330
  else:
268
331
  length = torch.numel(p.data) * p.data.element_size()
269
332
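The extra branches account for quanto's qint4 layout, where the packed payload sits in _data._data and the dequantization constants in either a fused _scale_shift tensor or separate _scale / _shift tensors; a rough footprint sketch, assuming two 4-bit values are packed per byte:

    # hypothetical 4096 x 4096 qint4 weight (sizes chosen for illustration)
    out_features, in_features = 4096, 4096
    packed_bytes = out_features * in_features // 2            # int4 payload, two values per byte
    print(f"{packed_bytes / (1024 * 1024):.1f} MiB packed")   # 8.0 MiB, before scale/shift tensors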
 
@@ -306,10 +369,22 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma
306
369
  if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
307
370
  current_big_tensor = big_tensors[big_tensor_no]
308
371
  if isinstance(p, QTensor):
309
- length1 = torch.numel(p._data) * p._data.element_size()
310
- p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
311
- length2 = torch.numel(p._scale) * p._scale.element_size()
312
- p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
372
+ if p._qtype == qint4:
373
+ length1 = torch.numel(p._data._data) * p._data._data.element_size()
374
+ p._data._data = _move_to_pinned_tensor(p._data._data, current_big_tensor, offset, length1)
375
+ if hasattr(p,"_scale_shift"):
376
+ length2 = torch.numel(p._scale_shift) * p._scale_shift.element_size()
377
+ p._scale_shift = _move_to_pinned_tensor(p._scale_shift, current_big_tensor, offset + length1, length2)
378
+ else:
379
+ length2 = torch.numel(p._scale) * p._scale.element_size()
380
+ p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
381
+ length3 = torch.numel(p._shift) * p._shift.element_size()
382
+ p._shift = _move_to_pinned_tensor(p._shift, current_big_tensor, offset + length1 + length2, length3)
383
+ else:
384
+ length1 = torch.numel(p._data) * p._data.element_size()
385
+ p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
386
+ length2 = torch.numel(p._scale) * p._scale.element_size()
387
+ p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
313
388
  else:
314
389
  length = torch.numel(p.data) * p.data.element_size()
315
390
  p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
@@ -336,100 +411,16 @@ def _welcome():
336
411
  if welcome_displayed:
337
412
  return
338
413
  welcome_displayed = True
339
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
340
-
341
-
342
- # def _pin_to_memory_sd(model, sd, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
343
- # if verboseLevel>=1 :
344
- # if partialPinning:
345
- # print(f"Partial pinning to reserved RAM of data of file '{model_id}' while loading it")
346
- # else:
347
- # print(f"Pinning data to reserved RAM of file '{model_id}' while loading it")
348
-
349
- # max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
350
- # if partialPinning:
351
- # towers_names, _ = _detect_main_towers(model)
352
- # towers_names = [n +"." for n in towers_names]
353
-
354
- # BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
355
- # current_big_tensor_size = 0
356
- # big_tensor_no = 0
357
- # big_tensors_sizes = []
358
- # tensor_map_indexes = []
359
- # total_tensor_bytes = 0
360
-
361
- # for k,t in sd.items():
362
- # include = True
363
- # # if isinstance(p, QTensor):
364
- # # length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
365
- # # else:
366
- # # length = torch.numel(p.data) * p.data.element_size()
367
- # length = torch.numel(t) * t.data.element_size()
368
-
369
- # if partialPinning:
370
- # include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
371
-
372
- # if include:
373
- # if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
374
- # big_tensors_sizes.append(current_big_tensor_size)
375
- # current_big_tensor_size = 0
376
- # big_tensor_no += 1
377
- # tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
378
- # current_big_tensor_size += length
379
- # else:
380
- # tensor_map_indexes.append((-1, 0, 0 ))
381
- # total_tensor_bytes += length
382
-
383
- # big_tensors_sizes.append(current_big_tensor_size)
384
-
385
- # big_tensors = []
386
- # last_big_tensor = 0
387
- # total = 0
388
-
414
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")
389
415
 
390
- # for size in big_tensors_sizes:
391
- # try:
392
- # currrent_big_tensor = torch.empty( size, dtype= torch.uint8, pin_memory=True)
393
- # big_tensors.append(currrent_big_tensor)
394
- # except:
395
- # print(f"Unable to pin more tensors for this model as the maximum reservable memory has been reached ({total/ONE_MB:.2f})")
396
- # break
397
-
398
- # last_big_tensor += 1
399
- # total += size
400
-
401
-
402
- # tensor_no = 0
403
- # for k,t in sd.items():
404
- # big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
405
- # if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
406
- # current_big_tensor = big_tensors[big_tensor_no]
407
- # # if isinstance(p, QTensor):
408
- # # length1 = torch.numel(p._data) * p._data.element_size()
409
- # # p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
410
- # # length2 = torch.numel(p._scale) * p._scale.element_size()
411
- # # p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
412
- # # else:
413
- # # length = torch.numel(p.data) * p.data.element_size()
414
- # # p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
415
- # length = torch.numel(t) * t.data.element_size()
416
- # t = _move_to_pinned_tensor(t, current_big_tensor, offset, length)
417
- # sd[k] = t
418
- # tensor_no += 1
419
-
420
- # global total_pinned_bytes
421
- # total_pinned_bytes += total
422
-
423
- # if verboseLevel >=1:
424
- # if total_tensor_bytes == total:
425
- # print(f"The whole model was pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
426
- # else:
427
- # print(f"{total/ONE_MB:.2f} MB were pinned to reserved RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
428
-
429
- # model._already_pinned = True
430
-
431
-
432
- # return
416
+ def _extract_num_from_str(num_in_str):
417
+ for i in range(len(num_in_str)):
418
+ if not num_in_str[-i-1:].isnumeric():
419
+ if i == 0:
420
+ return num_in_str, -1
421
+ else:
422
+ return num_in_str[: -i], int(num_in_str[-i:])
423
+ return "", int(num_in_str)
433
424
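The new helper splits a trailing block index off a module name, which the reworked tower and block-budget detection in this file relies on; a quick sketch of its behaviour, assuming it stays importable as a module-level function:

    from mmgp.offload import _extract_num_from_str

    print(_extract_num_from_str("double_blocks"))     # ('double_blocks', -1): no trailing digits
    print(_extract_num_from_str("double_blocks.12"))  # ('double_blocks.', 12): prefix plus floor number
    print(_extract_num_from_str("7"))                 # ('', 7): the whole name is numeric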
 
434
425
  def _quantize_dirty_hack(model):
435
426
  # dirty hack: add a hook on state_dict() to return a fake non quantized state_dict if called by Lora Diffusers initialization functions
@@ -536,10 +527,14 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
536
527
  prev_blocks_prefix = None
537
528
 
538
529
  if hasattr(model_to_quantize, "_quanto_map"):
530
+ for k, entry in model_to_quantize._quanto_map.items():
531
+ weights = entry["weights"]
532
+ print(f"Model '{model_id}' is already quantized to format '{weights}'")
533
+ return False
539
534
  print(f"Model '{model_id}' is already quantized")
540
535
  return False
541
-
542
- print(f"Quantization of model '{model_id}' started")
536
+
537
+ print(f"Quantization of model '{model_id}' started to format '{weights}'")
543
538
 
544
539
  for submodule_name, submodule in model_to_quantize.named_modules():
545
540
  if isinstance(submodule, QModuleMixin):
@@ -594,18 +589,18 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
594
589
  if verboseLevel >=2:
595
590
  print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB oF {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
596
591
  if perc_excluded >= 0.10:
597
- print(f"Too many many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
592
+ print(f"Too many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
598
593
  exclude_list = None
599
594
 
600
595
 
601
596
  #quantize(model_to_quantize,weights, exclude= exclude_list)
602
- pass
597
+
603
598
  for name, m in model_to_quantize.named_modules():
604
599
  if exclude_list is None or not any( name == module_name for module_name in exclude_list):
605
600
  _quantize_submodule(model_to_quantize, name, m, weights=weights, activations=None, optimizer=None)
606
601
 
607
- # force read non quantized parameters so that their lazy tensors and corresponding mmap are released
608
- # otherwise we may end up to keep in memory both the quantized and the non quantize model
602
+ # force to read non quantized parameters so that their lazy tensors and corresponding mmap are released
603
+ # otherwise we may end up keeping in memory both the quantized and the non quantize model
609
604
  for m in model_to_quantize.modules():
610
605
  # do not read quantized weights (detected them directly or behind an adapter)
611
606
  if isinstance(m, QModuleMixin) or hasattr(m, "base_layer") and isinstance(m.base_layer, QModuleMixin):
@@ -620,18 +615,271 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
620
615
  b.data = b.data + 0
621
616
 
622
617
 
618
+
623
619
  freeze(model_to_quantize)
624
620
  torch.cuda.empty_cache()
625
621
  gc.collect()
626
622
  quantization_map = _quantization_map(model_to_quantize)
627
623
  model_to_quantize._quanto_map = quantization_map
628
624
 
625
+ if hasattr(model_to_quantize, "_already_pinned"):
626
+ delattr(model_to_quantize, "_already_pinned")
627
+
629
628
  _quantize_dirty_hack(model_to_quantize)
630
629
 
631
630
  print(f"Quantization of model '{model_id}' done")
632
631
 
633
632
  return True
634
633
 
634
+ def load_loras_into_model(model, lora_path, lora_multi = None, verboseLevel = -1):
635
+ verboseLevel = _compute_verbose_level(verboseLevel)
636
+
637
+ if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
638
+ raise Exception("Unable to load Lora, missing 'peft' and / or 'diffusers' modules")
639
+
640
+ if not isinstance(lora_path, list):
641
+ lora_path = [lora_path]
642
+
643
+ if lora_multi is None:
644
+ lora_multi = [1. for _ in lora_path]
645
+
646
+ for i, path in enumerate(lora_path):
647
+ adapter_name = str(i)
648
+
649
+ state_dict = safetensors2.torch_load_file(path)
650
+
651
+ keys = list(state_dict.keys())
652
+ if len(keys) == 0:
653
+ raise Exception(f"Empty Lora '{path}'")
654
+
655
+
656
+ network_alphas = {}
657
+ for k in keys:
658
+ if "alpha" in k:
659
+ alpha_value = state_dict.pop(k)
660
+ if not ( (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
661
+ alpha_value, float
662
+ )):
663
+ network_alphas[k] = torch.tensor( float(alpha_value.item() ) )
664
+
665
+ pos = keys[0].find(".")
666
+ prefix = keys[0][0:pos]
667
+ if not any( prefix.startswith(some_prefix) for some_prefix in ["diffusion_model", "transformer"]):
668
+ msg = f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format."
669
+ raise Exception(msg)
670
+
671
+ transformer = model
672
+
673
+ transformer_keys = [k for k in keys if k.startswith(prefix)]
674
+ state_dict = {
675
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys
676
+ }
677
+
678
+ sd_keys = state_dict.keys()
679
+ if len(sd_keys) == 0:
680
+ print(f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format.")
681
+ return
682
+
683
+ # is_correct_format = all("lora" in key for key in state_dict.keys())
684
+
685
+
686
+
687
+
688
+ # check with first key if is not in peft format
689
+ # first_key = next(iter(state_dict.keys()))
690
+ # if "lora_A" not in first_key:
691
+ # state_dict = convert_unet_state_dict_to_peft(state_dict)
692
+
693
+ if adapter_name in getattr(transformer, "peft_config", {}):
694
+ raise ValueError(
695
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
696
+ )
697
+
698
+ rank = {}
699
+ for key, val in state_dict.items():
700
+ if "lora_B" in key:
701
+ rank[key] = val.shape[1]
702
+
703
+ if network_alphas is not None and len(network_alphas) >= 1:
704
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
705
+ network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
706
+
707
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
708
+
709
+ lora_config = LoraConfig(**lora_config_kwargs)
710
+ peft_kwargs = {}
711
+ peft_kwargs["low_cpu_mem_usage"] = True
712
+ inject_adapter_in_model(lora_config, model, adapter_name=adapter_name, **peft_kwargs)
713
+
714
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
715
+
716
+ warn_msg = ""
717
+ if incompatible_keys is not None:
718
+ # Check only for unexpected keys.
719
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
720
+ if unexpected_keys:
721
+ pass
722
+ if verboseLevel >=1:
723
+ print(f"Lora '{path}' was loaded in model '{_get_module_name(model)}'")
724
+ set_weights_and_activate_adapters(model,[ str(i) for i in range(len(lora_multi))], lora_multi)
725
+
726
+
727
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
728
+ """
729
+ quick version of .LoadfromPretrained of the transformers library
730
+ used to build a model and load the corresponding weights (quantized or not)
731
+ """
732
+
733
+
734
+ import os.path
735
+ from accelerate import init_empty_weights
736
+
737
+ if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
738
+ raise Exception("full model path to file expected")
739
+
740
+ model_path = _get_model(model_path)
741
+ verboseLevel = _compute_verbose_level(verboseLevel)
742
+
743
+ with safetensors2.safe_open(model_path) as f:
744
+ metadata = f.metadata()
745
+
746
+ if metadata is None:
747
+ transformer_config = None
748
+ else:
749
+ transformer_config = metadata.get("config", None)
750
+
751
+ if transformer_config == None:
752
+ config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
753
+
754
+ if not os.path.isfile(config_fullpath):
755
+ raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
756
+
757
+ with open(config_fullpath, "r", encoding="utf-8") as reader:
758
+ text = reader.read()
759
+ transformer_config= json.loads(text)
760
+
761
+
762
+ if "architectures" in transformer_config:
763
+ architectures = transformer_config["architectures"]
764
+ class_name = architectures[0]
765
+
766
+ module = __import__("transformers")
767
+ map = { "T5WithLMHeadModel" : "T5EncoderModel"}
768
+ class_name = map.get(class_name, class_name)
769
+ transfomer_class = getattr(module, class_name)
770
+ from transformers import AutoConfig
771
+
772
+ import tempfile
773
+ with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
774
+ fp.write(json.dumps(transformer_config))
775
+ fp.close()
776
+ config_obj = AutoConfig.from_pretrained(fp.name)
777
+ os.remove(fp.name)
778
+
779
+ #needed to keep inits of non persistent buffers
780
+ with init_empty_weights():
781
+ model = transfomer_class(config_obj)
782
+
783
+ model = model.base_model
784
+
785
+ elif "_class_name" in transformer_config:
786
+ class_name = transformer_config["_class_name"]
787
+
788
+ module = __import__("diffusers")
789
+ transfomer_class = getattr(module, class_name)
790
+
791
+ with init_empty_weights():
792
+ model = transfomer_class.from_config(transformer_config)
793
+
794
+
795
+ torch.set_default_device('cpu')
796
+
797
+ model._config = transformer_config
798
+
799
+ load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verboseLevel )
800
+
801
+ return model
802
+
803
+
804
+
805
+ def load_model_data(model, file_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
806
+ """
807
+ Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
808
+ """
809
+
810
+ file_path = _get_model(file_path)
811
+ verboseLevel = _compute_verbose_level(verboseLevel)
812
+
813
+ model = _remove_model_wrapper(model)
814
+
815
+ # if pinToMemory and do_quantize:
816
+ # raise Exception("Pinning and Quantization can not be used at the same time")
817
+
818
+ if not (".safetensors" in file_path or ".sft" in file_path):
819
+ if pinToMemory:
820
+ raise Exception("Pinning to memory while loading only supported for safe tensors files")
821
+ state_dict = torch.load(file_path, weights_only=True)
822
+ if "module" in state_dict:
823
+ state_dict = state_dict["module"]
824
+ else:
825
+ state_dict, metadata = _safetensors_load_file(file_path)
826
+
827
+ if metadata is None:
828
+ quantization_map = None
829
+ else:
830
+ quantization_map = metadata.get("quantization_map", None)
831
+ config = metadata.get("config", None)
832
+ if config is not None:
833
+ model._config = config
834
+
835
+
836
+
837
+ if quantization_map is None:
838
+ pos = str.rfind(file_path, ".")
839
+ if pos > 0:
840
+ quantization_map_path = file_path[:pos]
841
+ quantization_map_path += "_map.json"
842
+
843
+ if os.path.isfile(quantization_map_path):
844
+ with open(quantization_map_path, 'r') as f:
845
+ quantization_map = json.load(f)
846
+
847
+
848
+
849
+ if quantization_map is None :
850
+ if "quanto" in file_path and not do_quantize:
851
+ print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
852
+ else:
853
+ _requantize(model, state_dict, quantization_map)
854
+
855
+ missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
856
+ # if len(missing_keys) > 0:
857
+ # sd_crap = { k : None for k in missing_keys}
858
+ # missing_keys , unexpected_keys = model.load_state_dict(sd_crap, strict =False, assign = True )
859
+ del state_dict
860
+
861
+ for k,p in model.named_parameters():
862
+ if p.is_meta:
863
+ txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since parameter '{k}' has no data"
864
+ raise Exception(txt)
865
+ for k,b in model.named_buffers():
866
+ if b.is_meta:
867
+ txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since buffer '{k}' has no data"
868
+ raise Exception(txt)
869
+
870
+ if do_quantize:
871
+ if quantization_map is None:
872
+ if _quantize(model, quantizationType, verboseLevel=verboseLevel, model_id=file_path):
873
+ quantization_map = model._quanto_map
874
+ else:
875
+ if verboseLevel >=1:
876
+ print("Model already quantized")
877
+
878
+ if pinToMemory:
879
+ _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
880
+
881
+ return
882
+
635
883
  def get_model_name(model):
636
884
  return model.name
637
885
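Taken together, the relocated loaders and the new LoRA entry point can be driven like this; a minimal sketch in which the checkpoint and LoRA file names are placeholders, and peft plus diffusers must be installed for the LoRA step:

    from mmgp import offload
    from optimum.quanto import qint8

    # build the model and load its weights straight from a safetensors file,
    # optionally quantizing and pinning them to reserved RAM along the way
    transformer = offload.fast_load_transformers_model(
        "flux-transformer.safetensors",   # placeholder path
        do_quantize=True,
        quantizationType=qint8,
        pinToMemory=True,
        partialPinning=True,
    )

    # inject one or more LoRAs in Diffusers/peft format, each with its own multiplier
    offload.load_loras_into_model(
        transformer,
        ["my_style_lora.safetensors"],    # placeholder path
        lora_multi=[0.8],
    )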
 
@@ -663,6 +911,7 @@ class offload:
663
911
  self.async_transfers = False
664
912
  global last_offload_obj
665
913
  last_offload_obj = self
914
+
666
915
 
667
916
  def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
668
917
 
@@ -684,15 +933,25 @@ class offload:
684
933
 
685
934
  for k,p in submodule.named_parameters(recurse=False):
686
935
  if isinstance(p, QTensor):
687
- blocks_params.append( (submodule, k, p._data, p._scale) )
688
- blocks_params_size += p._data.nbytes
689
- blocks_params_size += p._scale.nbytes
936
+ blocks_params.append( (submodule, k, p ) )
937
+
938
+ if p._qtype == qint4:
939
+ if hasattr(p,"_scale_shift"):
940
+ blocks_params_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
941
+ blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
942
+ else:
943
+ blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
944
+ blocks_params_size += torch.numel(p._shift) * p._shift.element_size()
945
+ blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
946
+ else:
947
+ blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
948
+ blocks_params_size += torch.numel(p._data) * p._data.element_size()
690
949
  else:
691
- blocks_params.append( (submodule, k, p.data, None) )
692
- blocks_params_size += p.data.nbytes
950
+ blocks_params.append( (submodule, k, p ) )
951
+ blocks_params_size += torch.numel(p.data) * p.data.element_size()
693
952
 
694
953
  for k, p in submodule.named_buffers(recurse=False):
695
- blocks_params.append( (submodule, k, p.data, None) )
954
+ blocks_params.append( (submodule, k, p) )
696
955
  blocks_params_size += p.data.nbytes
697
956
 
698
957
 
@@ -710,34 +969,28 @@ class offload:
710
969
  return False
711
970
  return True
712
971
 
713
- def gpu_load_blocks(self, model_id, blocks_name, async_load = False):
972
+ @torch.compiler.disable()
973
+ def gpu_load_blocks(self, model_id, blocks_name):
714
974
  # cl = clock.start()
715
975
 
716
976
  if blocks_name != None:
717
977
  self.loaded_blocks[model_id] = blocks_name
718
978
 
719
979
  entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
720
-
721
- def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream = None):
980
+
981
+ def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None
722
982
  with torch.cuda.stream(stream_to_use):
723
983
  for param in blocks_params:
724
- parent_module, n, data, scale = param
725
- p = getattr(parent_module, n)
726
- if isinstance(p, QTensor):
727
- q = WeightQBytesTensor.create(p.qtype, p.axis, p.size(), p.stride(), data.cuda(non_blocking=True), scale.cuda(non_blocking=True), activation_qtype=p.activation_qtype, requires_grad=p.requires_grad )
728
- #q = p.to("cuda", non_blocking=True)
729
- q = torch.nn.Parameter(q , requires_grad=False)
730
- setattr(parent_module, n , q)
731
- del p
732
- else:
733
- p.data = p.data.cuda(non_blocking=True)
734
-
735
- if record_for_stream != None:
736
- if isinstance(p, QTensor):
737
- q._data.record_stream(record_for_stream)
738
- q._scale.record_stream(record_for_stream)
739
- else:
740
- p.data.record_stream(record_for_stream)
984
+ parent_module, n, p = param
985
+ q = p.to("cuda", non_blocking=True)
986
+ q = torch.nn.Parameter(q , requires_grad=False)
987
+ setattr(parent_module, n , q)
988
+ # if record_for_stream != None:
989
+ # if isinstance(p, QTensor):
990
+ # q._data.record_stream(record_for_stream)
991
+ # q._scale.record_stream(record_for_stream)
992
+ # else:
993
+ # p.data.record_stream(record_for_stream)
741
994
 
742
995
 
743
996
  if self.verboseLevel >=2:
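The per-parameter copies above rely on non_blocking=True, which only overlaps with computation when the source tensors live in pinned host memory (what _pin_to_memory arranges); the same pattern in isolation, assuming a CUDA device is available:

    import torch

    weights = torch.empty(1024, 1024, pin_memory=True)   # page-locked host tensor
    copy_stream = torch.cuda.Stream()

    with torch.cuda.stream(copy_stream):
        weights_gpu = weights.to("cuda", non_blocking=True)   # async host-to-device copy on copy_stream

    # make the default stream wait for the transfer before consuming the tensor
    torch.cuda.current_stream().wait_stream(copy_stream)
    print(weights_gpu.device)   # cuda:0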
@@ -762,7 +1015,7 @@ class offload:
762
1015
  # cl.stop()
763
1016
  # print(f"load time: {cl.format_time_gap()}")
764
1017
 
765
-
1018
+ @torch.compiler.disable()
766
1019
  def gpu_unload_blocks(self, model_id, blocks_name):
767
1020
  # cl = clock.start()
768
1021
  if blocks_name != None:
@@ -776,23 +1029,14 @@ class offload:
776
1029
  print(f"Unloading model {blocks_name} ({model_name}) from GPU")
777
1030
 
778
1031
  blocks_params = self.blocks_of_modules[blocks_name]
779
-
780
1032
  for param in blocks_params:
781
- parent_module, n, data, scale = param
782
- p = getattr(parent_module, n)
783
- if isinstance(p, QTensor):
784
- # need to change the parameter directly from the module as it can't be swapped in place due to a memory leak in the pytorch compiler
785
- q = WeightQBytesTensor.create(p.qtype, p.axis, p.size(), p.stride(), data, scale, activation_qtype=p.activation_qtype, requires_grad=p.requires_grad )
786
- q = torch.nn.Parameter(q , requires_grad=False)
787
- setattr(parent_module, n , q)
788
- del p
789
- else:
790
- p.data = data
791
-
1033
+ parent_module, n, p = param
1034
+ q = torch.nn.Parameter(p , requires_grad=False)
1035
+ setattr(parent_module, n , q)
792
1036
  # cl.stop()
793
1037
  # print(f"unload time: {cl.format_time_gap()}")
794
1038
 
795
-
1039
+ # @torch.compiler.disable()
796
1040
  def gpu_load(self, model_id):
797
1041
  model = self.models[model_id]
798
1042
  self.active_models.append(model)
@@ -824,8 +1068,8 @@ class offload:
824
1068
  if torch.is_tensor(arg):
825
1069
  if arg.dtype == torch.float32:
826
1070
  arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
827
- else:
828
- arg = arg.cuda(non_blocking=True)
1071
+ elif not arg.is_cuda:
1072
+ arg = arg.cuda(non_blocking=True)
829
1073
  new_args.append(arg)
830
1074
 
831
1075
  for k in kwargs:
@@ -833,7 +1077,7 @@ class offload:
833
1077
  if torch.is_tensor(arg):
834
1078
  if arg.dtype == torch.float32:
835
1079
  arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
836
- else:
1080
+ elif not arg.is_cuda:
837
1081
  arg = arg.cuda(non_blocking=True)
838
1082
  new_kwargs[k]= arg
839
1083
 
@@ -874,10 +1118,10 @@ class offload:
874
1118
 
875
1119
  return False
876
1120
 
877
- def hook_load_data_if_needed(self, target_module, model_id,blocks_name, context):
1121
+ def hook_preload_blocks_for_compilation(self, target_module, model_id,blocks_name, context):
878
1122
 
879
- @torch.compiler.disable()
880
- def load_data_if_needed(module, *args, **kwargs):
1123
+ # @torch.compiler.disable()
1124
+ def preload_blocks_for_compile(module, *args, **kwargs):
881
1125
  some_context = context #for debugging
882
1126
  if blocks_name == None:
883
1127
  if self.ready_to_check_mem():
@@ -891,12 +1135,17 @@ class offload:
891
1135
  self.empty_cache_if_needed()
892
1136
  self.loaded_blocks[model_id] = blocks_name
893
1137
  self.gpu_load_blocks(model_id, blocks_name)
894
-
895
- target_module.register_forward_pre_hook(load_data_if_needed)
1138
+ # need to be registered before the forward not to be break the efficiency of the compilation chain
1139
+ # it should be at the top of the compilation as this type of hook in the middle of a chain seems to break memory performance
1140
+ target_module.register_forward_pre_hook(preload_blocks_for_compile)
896
1141
 
897
1142
 
898
1143
  def hook_check_empty_cache_needed(self, target_module, model_id,blocks_name, previous_method, context):
899
1144
 
1145
+ qint4quantization = isinstance(target_module, QModuleMixin) and target_module.weight!= None and target_module.weight.qtype == qint4
1146
+ if qint4quantization:
1147
+ pass
1148
+
900
1149
  def check_empty_cuda_cache(module, *args, **kwargs):
901
1150
  # if self.ready_to_check_mem():
902
1151
  # self.empty_cache_if_needed()
@@ -912,6 +1161,8 @@ class offload:
912
1161
  self.empty_cache_if_needed()
913
1162
  self.loaded_blocks[model_id] = blocks_name
914
1163
  self.gpu_load_blocks(model_id, blocks_name)
1164
+ if qint4quantization:
1165
+ args, kwargs = self.move_args_to_gpu(*args, **kwargs)
915
1166
 
916
1167
  return previous_method(*args, **kwargs)
917
1168
 
@@ -959,177 +1210,18 @@ class offload:
959
1210
  print(f"Hooked in model '{model_id}' ({model_name})")
960
1211
 
961
1212
 
962
- # Not implemented yet, but why would one want to get rid of these features ?
963
- # def unhook_module(module: torch.nn.Module):
964
- # if not hasattr(module,"_mm_id"):
965
- # return
966
-
967
- # delattr(module, "_mm_id")
968
-
969
- # def unhook_all(parent_module: torch.nn.Module):
970
- # for module in parent_module.components.items():
971
- # self.unhook_module(module)
972
-
973
- def fast_load_transformers_model(model_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
974
- """
975
- quick version of .LoadfromPretrained of the transformers library
976
- used to build a model and load the corresponding weights (quantized or not)
977
- """
978
-
979
-
980
- import os.path
981
- from accelerate import init_empty_weights
982
-
983
- if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
984
- raise Exception("full model path to file expected")
985
-
986
- model_path = _get_model(model_path)
987
- verboseLevel = _compute_verbose_level(verboseLevel)
988
-
989
- with safetensors2.safe_open(model_path) as f:
990
- metadata = f.metadata()
991
-
992
- if metadata is None:
993
- transformer_config = None
994
- else:
995
- transformer_config = metadata.get("config", None)
996
-
997
- if transformer_config == None:
998
- config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
999
-
1000
- if not os.path.isfile(config_fullpath):
1001
- raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
1002
-
1003
- with open(config_fullpath, "r", encoding="utf-8") as reader:
1004
- text = reader.read()
1005
- transformer_config= json.loads(text)
1006
-
1007
-
1008
- if "architectures" in transformer_config:
1009
- architectures = transformer_config["architectures"]
1010
- class_name = architectures[0]
1011
-
1012
- module = __import__("transformers")
1013
- transfomer_class = getattr(module, class_name)
1014
- from transformers import AutoConfig
1015
-
1016
- import tempfile
1017
- with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
1018
- fp.write(json.dumps(transformer_config))
1019
- fp.close()
1020
- config_obj = AutoConfig.from_pretrained(fp.name)
1021
- os.remove(fp.name)
1022
-
1023
- #needed to keep inits of non persistent buffers
1024
- with init_empty_weights():
1025
- model = transfomer_class(config_obj)
1026
-
1027
- model = model.base_model
1028
-
1029
- elif "_class_name" in transformer_config:
1030
- class_name = transformer_config["_class_name"]
1031
-
1032
- module = __import__("diffusers")
1033
- transfomer_class = getattr(module, class_name)
1034
-
1035
- with init_empty_weights():
1036
- model = transfomer_class.from_config(transformer_config)
1037
-
1038
-
1039
- torch.set_default_device('cpu')
1040
-
1041
- model._config = transformer_config
1042
-
1043
- load_model_data(model,model_path, do_quantize = do_quantize, quantization_type = quantization_type, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verboseLevel )
1044
-
1045
- return model
1046
-
1047
-
1048
-
1049
- def load_model_data(model, file_path: str, do_quantize = False, quantization_type = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
1050
- """
1051
- Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
1052
- """
1053
-
1054
- file_path = _get_model(file_path)
1055
- verboseLevel = _compute_verbose_level(verboseLevel)
1056
-
1057
- model = _remove_model_wrapper(model)
1058
-
1059
- # if pinToMemory and do_quantize:
1060
- # raise Exception("Pinning and Quantization can not be used at the same time")
1061
-
1062
- if not (".safetensors" in file_path or ".sft" in file_path):
1063
- if pinToMemory:
1064
- raise Exception("Pinning to memory while loading only supported for safe tensors files")
1065
- state_dict = torch.load(file_path, weights_only=True)
1066
- if "module" in state_dict:
1067
- state_dict = state_dict["module"]
1068
- else:
1069
- state_dict, metadata = _safetensors_load_file(file_path)
1070
-
1071
-
1072
- # if pinToMemory:
1073
- # _pin_to_memory_sd(model,state_dict, file_path, partialPinning = partialPinning, perc_reserved_mem_max = perc_reserved_mem_max, verboseLevel = verboseLevel)
1074
-
1075
- # with safetensors2.safe_open(file_path) as f:
1076
- # metadata = f.metadata()
1077
-
1078
-
1079
- if metadata is None:
1080
- quantization_map = None
1081
- else:
1082
- quantization_map = metadata.get("quantization_map", None)
1083
- config = metadata.get("config", None)
1084
- if config is not None:
1085
- model._config = config
1086
-
1087
-
1088
-
1089
- if quantization_map is None:
1090
- pos = str.rfind(file_path, ".")
1091
- if pos > 0:
1092
- quantization_map_path = file_path[:pos]
1093
- quantization_map_path += "_map.json"
1094
-
1095
- if os.path.isfile(quantization_map_path):
1096
- with open(quantization_map_path, 'r') as f:
1097
- quantization_map = json.load(f)
1098
-
1099
-
1100
-
1101
- if quantization_map is None :
1102
- if "quanto" in file_path and not do_quantize:
1103
- print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
1104
- else:
1105
- _requantize(model, state_dict, quantization_map)
1106
-
1107
- missing_keys , unexpected_keys = model.load_state_dict(state_dict, strict = quantization_map is None, assign = True )
1108
- del state_dict
1109
-
1110
- if do_quantize:
1111
- if quantization_map is None:
1112
- if _quantize(model, quantization_type, verboseLevel=verboseLevel, model_id=file_path):
1113
- quantization_map = model._quanto_map
1114
- else:
1115
- if verboseLevel >=1:
1116
- print("Model already quantized")
1117
-
1118
- if pinToMemory:
1119
- _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
1120
-
1121
- return
1122
-
1123
- def save_model(model, file_path, do_quantize = False, quantization_type = qint8, verboseLevel = -1 ):
1213
+ def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1, config_file_path = None ):
1124
1214
  """save the weights of a model and quantize them if requested
1125
1215
  These weights can be loaded again using 'load_model_data'
1126
1216
  """
1127
1217
 
1128
1218
  config = None
1129
-
1130
1219
  verboseLevel = _compute_verbose_level(verboseLevel)
1131
-
1132
- if hasattr(model, "_config"):
1220
+ if config_file_path !=None:
1221
+ with open(config_file_path, "r", encoding="utf-8") as reader:
1222
+ text = reader.read()
1223
+ config= json.loads(text)
1224
+ elif hasattr(model, "_config"):
1133
1225
  config = model._config
1134
1226
  elif hasattr(model, "config"):
1135
1227
  config_fullpath = None
@@ -1147,7 +1239,7 @@ def save_model(model, file_path, do_quantize = False, quantization_type = qint8,
1147
1239
  config= json.loads(text)
1148
1240
 
1149
1241
  if do_quantize:
1150
- _quantize(model, weights=quantization_type, model_id=file_path)
1242
+ _quantize(model, weights=quantizationType, model_id=file_path)
1151
1243
 
1152
1244
  quantization_map = getattr(model, "_quanto_map", None)
1153
1245
 
@@ -1155,12 +1247,12 @@ def save_model(model, file_path, do_quantize = False, quantization_type = qint8,
1155
1247
  print(f"Saving file '{file_path}")
1156
1248
  safetensors2.torch_write_file(model.state_dict(), file_path , quantization_map = quantization_map, config = config)
1157
1249
  if verboseLevel >=1:
1158
- print(f"File '{file_path} saved")
1250
+ print(f"File '{file_path}' saved")
1159
1251
 
1160
1252
 
1161
1253
 
1162
1254
 
1163
- def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, budgets= 0, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, verboseLevel = -1):
1255
+ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, verboseLevel = -1):
1164
1256
  """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
1165
1257
  pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
1166
1258
  quantizeTransformer: set True by default will quantize on the fly the video / image model
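The new quantizationType argument is forwarded to the on-the-fly quantization step, so qint4 can now be requested instead of the qint8 default; a sketch, assuming pipe is a Diffusers pipeline that has already been built:

    from mmgp import offload
    from optimum.quanto import qint4

    offload.all(
        pipe,
        pinnedMemory=True,
        quantizeTransformer=True,
        quantizationType=qint4,   # new in 3.1; qint8 remains the default
        asyncTransfers=True,
    )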
@@ -1238,13 +1330,14 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1238
1330
  self.anyCompiledModule = compileAllModels or len(modelsToCompile)>0
1239
1331
  if self.anyCompiledModule:
1240
1332
  torch._dynamo.config.cache_size_limit = 10000
1333
+ torch.compiler.reset()
1334
+
1241
1335
  # torch._logging.set_logs(recompiles=True)
1242
1336
  # torch._inductor.config.realize_opcount_threshold = 100 # workaround bug "AssertionError: increase TRITON_MAX_BLOCK['X'] to 4096."
1243
1337
 
1244
1338
  max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)
1245
1339
 
1246
1340
  estimatesBytesToPin = 0
1247
-
1248
1341
  for model_id in models:
1249
1342
  current_model: torch.nn.Module = models[model_id]
1250
1343
  # make sure that no RAM or GPU memory is not allocated for gradiant / training
@@ -1252,19 +1345,30 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1252
1345
 
1253
1346
  # if the model has just been quantized so there is no need to quantize it again
1254
1347
  if model_id in models_to_quantize:
1255
- _quantize(current_model, weights=qint8, verboseLevel = self.verboseLevel, model_id=model_id)
1348
+ _quantize(current_model, weights=quantizationType, verboseLevel = self.verboseLevel, model_id=model_id)
1256
1349
 
1257
1350
  modelPinned = (pinAllModels or model_id in modelsToPin) and not hasattr(current_model,"_already_pinned")
1258
1351
 
1259
- current_model_size = 0
1260
- # load all the remaining unread lazy safetensors in RAM to free open cache files
1261
- for p in current_model.parameters():
1352
+ current_model_size = 0
1353
+
1354
+ for n, p in current_model.named_parameters():
1355
+ p.requires_grad = False
1262
1356
  if isinstance(p, QTensor):
1263
1357
  # # fix quanto bug (seems to have been fixed)
1264
1358
  # if not modelPinned and p._scale.dtype == torch.float32:
1265
1359
  # p._scale = p._scale.to(torch.bfloat16)
1266
- current_model_size += torch.numel(p._scale) * p._scale.element_size()
1267
- current_model_size += torch.numel(p._data) * p._data.element_size()
1360
+ if p._qtype == qint4:
1361
+ if hasattr(p,"_scale_shift"):
1362
+ current_model_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
1363
+ else:
1364
+ current_model_size += torch.numel(p._scale) * p._shift.element_size() + torch.numel(p._scale) * p._shift.element_size()
1365
+
1366
+ current_model_size += torch.numel(p._data._data) * p._data._data.element_size()
1367
+
1368
+ else:
1369
+ current_model_size += torch.numel(p._scale) * p._scale.element_size()
1370
+ current_model_size += torch.numel(p._data) * p._data.element_size()
1371
+
1268
1372
  else:
1269
1373
  if p.data.dtype == torch.float32:
1270
1374
  # convert any left overs float32 weight to bloat16 to divide by 2 the model memory footprint
@@ -1272,7 +1376,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1272
1376
  current_model_size += torch.numel(p.data) * p.data.element_size()
1273
1377
 
1274
1378
  for b in current_model.buffers():
1275
- if b.data.dtype == torch.float32:
1379
+ if b.data.dtype == torch.float32:
1276
1380
  # convert any left overs float32 weight to bloat16 to divide by 2 the model memory footprint
1277
1381
  b.data = b.data.to(torch.bfloat16)
1278
1382
  current_model_size += torch.numel(b.data) * b.data.element_size()
@@ -1298,22 +1402,21 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1298
1402
  # Hook forward methods of modules
1299
1403
  for model_id in models:
1300
1404
  current_model: torch.nn.Module = models[model_id]
1301
- current_budget = model_budgets[model_id]
1302
- current_size = 0
1303
- cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1304
- self.loaded_blocks[model_id] = None
1305
1405
  towers_names, towers_modules = _detect_main_towers(current_model)
1306
- towers_names = [n +"." for n in towers_names]
1307
1406
  if self.verboseLevel>=2 and len(towers_names)>0:
1308
1407
  print(f"Potential iterative blocks found in model '{model_id}':{towers_names}")
1309
1408
  # compile main iterative modules stacks ("towers")
1310
- if compileAllModels or model_id in modelsToCompile :
1311
- #torch.compiler.reset()
1409
+ compilationInThisOne = compileAllModels or model_id in modelsToCompile
1410
+ if compilationInThisOne:
1312
1411
  if self.verboseLevel>=1:
1313
- print(f"Pytorch compilation of model '{model_id}' is scheduled.")
1314
- for tower in towers_modules:
1315
- for submodel in tower:
1316
- submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
1412
+ if len(towers_modules)>0:
1413
+ print(f"Pytorch compilation of model '{model_id}' is scheduled.")
1414
+ else:
1415
+ print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
1416
+
1417
+ for submodel in towers_modules:
1418
+ # for submodel in tower:
1419
+ submodel.forward= torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
1317
1420
  #dynamic=True,
1318
1421
 
1319
1422
  if pinAllModels or model_id in modelsToPin:
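Compilation is now applied floor by floor by swapping each submodule's bound forward for a compiled wrapper; the same pattern in isolation, assuming PyTorch 2.x with a working inductor backend:

    import torch

    block = torch.nn.Linear(8, 8)
    # replace the bound forward with a compiled version, as the loop above does for each floor
    block.forward = torch.compile(block.forward, backend="inductor", mode="default")
    out = block(torch.randn(2, 8))
    print(out.shape)   # torch.Size([2, 8])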
@@ -1323,6 +1426,11 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1323
1426
  else:
1324
1427
  _pin_to_memory(current_model, model_id, partialPinning= partialPinning, perc_reserved_mem_max=perc_reserved_mem_max, verboseLevel=verboseLevel)
1325
1428
 
1429
+ current_budget = model_budgets[model_id]
1430
+ current_size = 0
1431
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1432
+ self.loaded_blocks[model_id] = None
1433
+
1326
1434
  for submodule_name, submodule in current_model.named_modules():
1327
1435
  # create a fake 'accelerate' parameter so that the _execution_device property returns always "cuda"
1328
1436
  # (it is queried in many pipelines even if offloading is not properly implemented)
@@ -1331,44 +1439,43 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1331
1439
 
1332
1440
  if submodule_name=='':
1333
1441
  continue
1334
- newListItem = False
1442
+
1335
1443
  if current_budget > 0:
1336
- if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): #
1337
- if cur_blocks_prefix == None:
1338
- cur_blocks_prefix = submodule_name + "."
1444
+ if cur_blocks_prefix != None:
1445
+ if submodule_name.startswith(cur_blocks_prefix):
1446
+ depth_prefix = cur_blocks_prefix.split(".")
1447
+ depth_name = submodule_name.split(".")
1448
+ level = depth_name[len(depth_prefix)-1]
1449
+ pre , num = _extract_num_from_str(level)
1450
+ if num != cur_blocks_seq and (cur_blocks_seq == -1 or current_size > current_budget):
1451
+ prev_blocks_name = cur_blocks_name
1452
+ cur_blocks_name = cur_blocks_prefix + str(num)
1453
+ # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
1454
+ cur_blocks_seq = num
1339
1455
  else:
1340
- #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
1341
- if not submodule_name.startswith(cur_blocks_prefix):
1342
- cur_blocks_prefix = submodule_name + "."
1343
- cur_blocks_name,cur_blocks_seq = None, -1
1344
- else:
1345
-
1346
- if cur_blocks_prefix is not None:
1347
- if submodule_name.startswith(cur_blocks_prefix):
1348
- num = int(submodule_name[len(cur_blocks_prefix):].split(".")[0])
1349
- newListItem= num != cur_blocks_seq
1350
- if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
1351
- prev_blocks_name = cur_blocks_name
1352
- cur_blocks_name = cur_blocks_prefix + str(num)
1353
- # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
1354
- cur_blocks_seq = num
1355
- else:
1356
- cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1357
-
1456
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1457
+
1458
+ if cur_blocks_prefix == None:
1459
+ pre , num = _extract_num_from_str(submodule_name)
1460
+ if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
1461
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre + ".", None, -1
1462
+ elif num >=0:
1463
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre, None, num
1464
+ cur_blocks_name = submodule_name
1465
+ # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
1466
+
1467
+
1358
1468
  if hasattr(submodule, "forward"):
1359
1469
  submodule_method = getattr(submodule, "forward")
1360
1470
  if callable(submodule_method):
1361
1471
  if len(submodule_name.split("."))==1:
1362
1472
  self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method)
1363
- elif newListItem:
1364
- self.hook_load_data_if_needed(submodule, model_id, cur_blocks_name, context = submodule_name )
1473
+ elif compilationInThisOne and submodule in towers_modules:
1474
+ self.hook_preload_blocks_for_compilation(submodule, model_id, cur_blocks_name, context = submodule_name )
1365
1475
  else:
1366
1476
  self.hook_check_empty_cache_needed(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name )
1367
1477
 
1368
-
1369
- current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
1370
-
1371
-
1478
+ current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
1372
1479
 
1373
1480
 
1374
1481
  if self.verboseLevel >=2:
@@ -1406,7 +1513,7 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
1406
1513
  modules= modules.components
1407
1514
 
1408
1515
  modules = {k: _remove_model_wrapper(v) for k, v in modules.items() if isinstance(v, torch.nn.Module)}
1409
- module_names = {k: v.__module__.lower() for k, v in modules.items() }
1516
+ module_names = {k: _get_module_name(v) for k, v in modules.items() }
1410
1517
 
1411
1518
  default_extraModelsToQuantize = []
1412
1519
  quantizeTransformer = True
@@ -1414,11 +1521,12 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
1414
1521
  models_to_scan = ("text_encoder", "text_encoder_2")
1415
1522
  candidates_to_quantize = ("t5", "llama", "llm")
1416
1523
  for model_id in models_to_scan:
1417
- name = module_names[model_id]
1418
- for candidate in candidates_to_quantize:
1419
- if candidate in name:
1420
- default_extraModelsToQuantize.append(model_id)
1421
- break
1524
+ if model_id in module_names:
1525
+ name = module_names[model_id]
1526
+ for candidate in candidates_to_quantize:
1527
+ if candidate in name:
1528
+ default_extraModelsToQuantize.append(model_id)
1529
+ break
1422
1530
 
1423
1531
 
1424
1532
  # transformer (video or image generator) should be as small as possible not to occupy space that could be used by actual image data
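The profile() wrapper remains the one-call entry point, and the guard above lets it handle pipelines that lack one of the scanned text encoders; a sketch, assuming pipe is a Diffusers pipeline:

    from mmgp import offload, profile_type

    # picks pinning, quantization and budget defaults suited to machines with plenty of RAM and VRAM
    offload.profile(pipe, profile_type.HighRAM_HighVRAM)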
@@ -1427,6 +1535,7 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
1427
1535
  default_budgets = { "transformer" : 600 , "text_encoder": 3000, "text_encoder_2": 3000 }
1428
1536
  extraModelsToQuantize = None
1429
1537
  asyncTransfers = True
1538
+ budgets = None
1430
1539
 
1431
1540
  if profile_no == profile_type.HighRAM_HighVRAM:
1432
1541
  pinnedMemory= True