mmgp 3.1.4.post1__py3-none-any.whl → 3.1.4.post151__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mmgp might be problematic.

mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
1
- # ------------------ Memory Management 3.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
1
+ # ------------------ Memory Management 3.1.4-151 for the GPU Poor by DeepBeepMeep (mmgp)------------------
2
2
  #
3
3
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
4
4
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -76,7 +76,18 @@ except:
76
76
  from mmgp import safetensors2
77
77
  from mmgp import profile_type
78
78
 
79
- from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QTensor, quantize_module
79
+ from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QTensor, quantize_module, register_qmodule
80
+
81
+ # support for quantizing Embedding modules, which quanto does not handle by default
82
+ @register_qmodule(torch.nn.Embedding)
83
+ class QEmbedding(QModuleMixin, torch.nn.Embedding):
84
+ @classmethod
85
+ def qcreate(cls, module, weights, activations = None, optimizer = None, device = None):
86
+ module.bias = None
87
+ return cls( module.num_embeddings, module.embedding_dim, module.padding_idx , module.max_norm, module.norm_type, module.scale_grad_by_freq, module.sparse, dtype=module.weight.dtype, device=device, weights=weights,
88
+ activations=activations, optimizer=optimizer, quantize_input=True)
89
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
90
+ return torch.nn.functional.embedding( input, self.qweight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse )
80
91
 
81
92
 
82
93
  shared_state = {}
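Note: the block above registers an `Embedding` qmodule so that optimum-quanto, which only ships qmodules for layers such as `Linear`, can also quantize embedding tables. A minimal sketch of the effect, assuming a recent optimum-quanto is installed and `mmgp.offload` has been imported so the registration has run:

```python
import torch
from optimum.quanto import quantize, freeze, qint8
import mmgp.offload  # noqa: F401  (importing runs the @register_qmodule above)

model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 64),   # would normally be skipped by quantize()
    torch.nn.Linear(64, 64),
)
quantize(model, weights=qint8)
freeze(model)
print(type(model[0]).__name__)      # expected: QEmbedding
```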
@@ -96,11 +107,6 @@ ENDC = '\033[0m'
96
107
  BOLD ='\033[1m'
97
108
  UNBOLD ='\033[0m'
98
109
 
99
- cotenants_map = {
100
- "text_encoder": ["vae", "text_encoder_2"],
101
- "text_encoder_2": ["vae", "text_encoder"],
102
- }
103
-
104
110
  class clock:
105
111
  def __init__(self):
106
112
  self.start_time = 0
@@ -191,10 +197,10 @@ def _detect_main_towers(model, min_floors = 5):
191
197
  pre , num = _extract_num_from_str(submodule_name)
192
198
  if isinstance(submodule, (torch.nn.ModuleList)):
193
199
  cur_blocks_prefix, cur_blocks_seq = pre + ".", -1
194
- tower_name = submodule_name #+ ".*"
200
+ tower_name = submodule_name + "."
195
201
  elif num >=0:
196
202
  cur_blocks_prefix, cur_blocks_seq = pre, num
197
- tower_name = submodule_name[ :-1] #+ "*"
203
+ tower_name = submodule_name[ :-1]
198
204
  floors_modules.append(submodule)
199
205
 
200
206
  if len(floors_modules) >= min_floors:
@@ -216,15 +222,17 @@ def _get_model(model_path):
216
222
  if len(_path)<=1:
217
223
  raise("file not found")
218
224
  else:
219
- from huggingface_hub import hf_hub_download #snapshot_download,
220
- repoId= os.path.join(*_path[0:2] ).replace("\\", "/")
221
-
222
- if len(_path) > 2:
223
- _subfolder = os.path.join(*_path[2:] )
224
- model_path = hf_hub_download(repo_id=repoId, filename=_filename, subfolder=_subfolder)
225
- else:
226
- model_path = hf_hub_download(repo_id=repoId, filename=_filename)
225
+ try:
226
+ from huggingface_hub import hf_hub_download #snapshot_download,
227
+ repoId= os.path.join(*_path[0:2] ).replace("\\", "/")
227
228
 
229
+ if len(_path) > 2:
230
+ _subfolder = os.path.join(*_path[2:] )
231
+ model_path = hf_hub_download(repo_id=repoId, filename=_filename, subfolder=_subfolder)
232
+ else:
233
+ model_path = hf_hub_download(repo_id=repoId, filename=_filename)
234
+ except:
235
+ model_path = None
228
236
  return model_path
229
237
 
230
238
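Note: the change above wraps the Hugging Face Hub download in a try/except so that a missing file or an offline environment yields `model_path = None` instead of raising. For reference, a hedged sketch of the underlying call (the repo and file names are placeholders, not taken from mmgp); the sketch catches `Exception` rather than using a bare `except:` so keyboard interrupts still propagate:

```python
from huggingface_hub import hf_hub_download

try:
    # "org/repo" plus an optional subfolder, mirroring how _get_model splits the path
    local_path = hf_hub_download(repo_id="some-org/some-repo",
                                 filename="model.safetensors",
                                 subfolder="transformer")
except Exception:
    local_path = None   # same fallback the new code uses
print(local_path)
```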
 
@@ -278,9 +286,17 @@ def _force_load_parameter(p):
278
286
  torch.utils.swap_tensors(p, q)
279
287
  del q
280
288
 
281
- def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
289
+ def _get_tensor_ref(p):
290
+ if isinstance(p, QTensor):
291
+ if p._qtype == qint4:
292
+ return p._data._data.data_ptr()
293
+ else:
294
+ return p._data.data_ptr()
295
+ else:
296
+ return p.data_ptr()
282
297
 
283
298
 
299
+ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
284
300
  if partialPinning:
285
301
  towers_names, _ = _detect_main_towers(model)
286
302
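Note: `_get_tensor_ref` returns the address of a tensor's underlying storage (unwrapping quanto's `QTensor` containers), which the pinning loop in the next hunk uses to spot tied weights: two parameters that alias the same storage report the same pointer. A minimal sketch of the idea with plain tensors:

```python
import torch

emb = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10, bias=False)
head.weight = emb.weight                      # tie weights, as LM heads often do

seen, tied = {}, {}
for name, p in [("emb.weight", emb.weight), ("head.weight", head.weight)]:
    ref = p.data_ptr()
    if ref in seen:
        tied[name] = seen[ref]                # alias of an earlier parameter
    else:
        seen[ref] = name
print(tied)                                   # {'head.weight': 'emb.weight'}
```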
 
@@ -292,56 +308,63 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
292
308
  tensor_map_indexes = []
293
309
  total_tensor_bytes = 0
294
310
 
295
- params_list = []
311
+ params_dict = {} # OrderedDict
296
312
  for k, sub_module in model.named_modules():
297
313
  include = True
298
314
  if partialPinning:
299
315
  include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
300
316
  if include:
301
- params_list = params_list + [ (k + '.' + n, p, False) for n, p in sub_module.named_parameters(recurse=False)] + [ (k + '.' + n, p, True) for n, p in sub_module.named_buffers(recurse=False)]
302
-
317
+ params_dict.update( { k + '.' + n : (p, False) for n, p in sub_module.named_parameters(recurse=False) } )
318
+ params_dict.update( { k + '.' + n : (b, True) for n, b in sub_module.named_buffers(recurse=False) } )
303
319
 
304
320
  if verboseLevel>=1 :
305
321
  if partialPinning:
306
- if len(params_list) == 0:
322
+ if len(params_dict) == 0:
307
323
  print(f"Unable to apply Partial of '{model_id}' as no isolated main structures were found")
308
324
  else:
309
325
  print(f"Partial pinning of data of '{model_id}' to reserved RAM")
310
326
  else:
311
327
  print(f"Pinning data of '{model_id}' to reserved RAM")
312
328
 
313
- if partialPinning and len(params_list) == 0:
329
+ if partialPinning and len(params_dict) == 0:
314
330
  return
315
331
 
316
-
317
-
318
- for n, p, _ in params_list:
319
- if isinstance(p, QTensor):
320
- if p._qtype == qint4:
321
- if hasattr(p,"_scale_shift"):
322
- length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale_shift) * p._scale_shift.element_size()
332
+ ref_cache = {}
333
+ tied_weights = {}
334
+ for n, (p, _) in params_dict.items():
335
+ ref = _get_tensor_ref(p)
336
+ match = ref_cache.get(ref, None)
337
+ if match != None:
338
+ match_name, match_size = match
339
+ if verboseLevel >=1:
340
+ print(f"Tied weights of {match_size/ONE_MB:0.2f} MB detected: {match_name} <-> {n}")
341
+ tied_weights[n] = match_name
342
+ else:
343
+ if isinstance(p, QTensor):
344
+ if p._qtype == qint4:
345
+ if hasattr(p,"_scale_shift"):
346
+ length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale_shift) * p._scale_shift.element_size()
347
+ else:
348
+ length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale) * p._scale.element_size() + torch.numel(p._shift) * p._shift.element_size()
323
349
  else:
324
- length = torch.numel(p._data._data) * p._data._data.element_size() + torch.numel(p._scale) * p._scale.element_size() + torch.numel(p._shift) * p._shift.element_size()
350
+ length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
325
351
  else:
326
- length = torch.numel(p._data) * p._data.element_size() + torch.numel(p._scale) * p._scale.element_size()
327
- else:
328
- length = torch.numel(p.data) * p.data.element_size()
329
-
330
-
331
- if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
332
- big_tensors_sizes.append(current_big_tensor_size)
333
- current_big_tensor_size = 0
334
- big_tensor_no += 1
352
+ length = torch.numel(p.data) * p.data.element_size()
335
353
 
354
+ ref_cache[ref] = (n, length)
355
+ if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
356
+ big_tensors_sizes.append(current_big_tensor_size)
357
+ current_big_tensor_size = 0
358
+ big_tensor_no += 1
336
359
 
337
- itemsize = p.data.dtype.itemsize
338
- if current_big_tensor_size % itemsize:
339
- current_big_tensor_size += itemsize - current_big_tensor_size % itemsize
340
- tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
341
- current_big_tensor_size += length
342
360
 
343
- total_tensor_bytes += length
361
+ itemsize = p.data.dtype.itemsize
362
+ if current_big_tensor_size % itemsize:
363
+ current_big_tensor_size += itemsize - current_big_tensor_size % itemsize
364
+ tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
365
+ current_big_tensor_size += length
344
366
 
367
+ total_tensor_bytes += length
345
368
 
346
369
  big_tensors_sizes.append(current_big_tensor_size)
347
370
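Note: the pinning path groups many small parameters into a few large page-locked buffers and re-points each parameter at an aligned slice, so later host-to-device copies can be issued with `non_blocking=True`. A simplified sketch of that idea (requires a CUDA build of PyTorch for `pin_memory`; mmgp's real code also handles quantized tensors, alignment, and the `BIG_TENSOR_MAX_SIZE` cap):

```python
import torch

params = [torch.randn(256, 256), torch.randn(128, 128)]
total = sum(p.numel() * p.element_size() for p in params)
big = torch.empty(total, dtype=torch.uint8, pin_memory=True)

offset = 0
for i, p in enumerate(params):
    nbytes = p.numel() * p.element_size()
    slot = big[offset: offset + nbytes].view(p.dtype).view(p.shape)
    slot.copy_(p)
    params[i] = slot          # the tensor's data now lives in pinned RAM
    offset += nbytes

assert all(t.is_pinned() for t in params)
```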
 
@@ -368,39 +391,53 @@ def _pin_to_memory(model, model_id, partialPinning = False, verboseLevel = 1):
368
391
 
369
392
  tensor_no = 0
370
393
  # prev_big_tensor = 0
371
- for n, p, is_buffer in params_list:
372
- big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
373
- # if big_tensor_no != prev_big_tensor:
374
- # gc.collect()
375
- # prev_big_tensor = big_tensor_no
376
- if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
377
- current_big_tensor = big_tensors[big_tensor_no]
378
- if is_buffer :
379
- _force_load_buffer(p) # otherwise potential memory leak
394
+ for n, (p, is_buffer) in params_dict.items():
395
+ if n in tied_weights:
380
396
  if isinstance(p, QTensor):
381
- if p._qtype == qint4:
382
- length1 = torch.numel(p._data._data) * p._data._data.element_size()
383
- p._data._data = _move_to_pinned_tensor(p._data._data, current_big_tensor, offset, length1)
384
- if hasattr(p,"_scale_shift"):
385
- length2 = torch.numel(p._scale_shift) * p._scale_shift.element_size()
386
- p._scale_shift = _move_to_pinned_tensor(p._scale_shift, current_big_tensor, offset + length1, length2)
397
+ if p._qtype == qint4:
398
+ assert p._data._data.data.is_pinned()
399
+ else:
400
+ assert p._data.is_pinned()
401
+ else:
402
+ assert p.data.is_pinned()
403
+ else:
404
+ big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
405
+ # if big_tensor_no != prev_big_tensor:
406
+ # gc.collect()
407
+ # prev_big_tensor = big_tensor_no
408
+ # match_param, match_isbuffer = tied_weights.get(n, (None, False))
409
+ # if match_param != None:
410
+
411
+ if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
412
+ current_big_tensor = big_tensors[big_tensor_no]
413
+ if is_buffer :
414
+ _force_load_buffer(p) # otherwise potential memory leak
415
+ if isinstance(p, QTensor):
416
+ if p._qtype == qint4:
417
+ length1 = torch.numel(p._data._data) * p._data._data.element_size()
418
+ p._data._data = _move_to_pinned_tensor(p._data._data, current_big_tensor, offset, length1)
419
+ if hasattr(p,"_scale_shift"):
420
+ length2 = torch.numel(p._scale_shift) * p._scale_shift.element_size()
421
+ p._scale_shift = _move_to_pinned_tensor(p._scale_shift, current_big_tensor, offset + length1, length2)
422
+ else:
423
+ length2 = torch.numel(p._scale) * p._scale.element_size()
424
+ p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
425
+ length3 = torch.numel(p._shift) * p._shift.element_size()
426
+ p._shift = _move_to_pinned_tensor(p._shift, current_big_tensor, offset + length1 + length2, length3)
387
427
  else:
428
+ length1 = torch.numel(p._data) * p._data.element_size()
429
+ p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
388
430
  length2 = torch.numel(p._scale) * p._scale.element_size()
389
431
  p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
390
- length3 = torch.numel(p._shift) * p._shift.element_size()
391
- p._shift = _move_to_pinned_tensor(p._shift, current_big_tensor, offset + length1 + length2, length3)
392
432
  else:
393
- length1 = torch.numel(p._data) * p._data.element_size()
394
- p._data = _move_to_pinned_tensor(p._data, current_big_tensor, offset, length1)
395
- length2 = torch.numel(p._scale) * p._scale.element_size()
396
- p._scale = _move_to_pinned_tensor(p._scale, current_big_tensor, offset + length1, length2)
397
- else:
398
- length = torch.numel(p.data) * p.data.element_size()
399
- p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
433
+ length = torch.numel(p.data) * p.data.element_size()
434
+ p.data = _move_to_pinned_tensor(p.data, current_big_tensor, offset, length)
400
435
 
401
- tensor_no += 1
436
+ tensor_no += 1
437
+ del p
402
438
  global total_pinned_bytes
403
439
  total_pinned_bytes += total
440
+ del params_dict
404
441
  gc.collect()
405
442
 
406
443
  if verboseLevel >=1:
@@ -420,7 +457,7 @@ def _welcome():
420
457
  if welcome_displayed:
421
458
  return
422
459
  welcome_displayed = True
423
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")
460
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.1.4-151) by DeepBeepMeep ************{ENDC}{UNBOLD}")
424
461
 
425
462
  def _extract_num_from_str(num_in_str):
426
463
  size = len(num_in_str)
@@ -518,16 +555,6 @@ def _requantize(model: torch.nn.Module, state_dict: dict, quantization_map: dict
518
555
 
519
556
  def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 1000000000, model_id = 'Unknown'):
520
557
 
521
- def compute_submodule_size(submodule):
522
- size = 0
523
- for p in submodule.parameters(recurse=False):
524
- size += torch.numel(p.data) * sizeofbfloat16
525
-
526
- for p in submodule.buffers(recurse=False):
527
- size += torch.numel(p.data) * sizeofbfloat16
528
-
529
- return size
530
-
531
558
  total_size =0
532
559
  total_excluded = 0
533
560
  exclude_list = []
@@ -549,16 +576,31 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
549
576
  tower_names ,_ = _detect_main_towers(model_to_quantize)
550
577
  tower_names = [ n[:-1] for n in tower_names]
551
578
 
579
+
580
+ cache_ref = {}
581
+ tied_weights= {}
582
+
552
583
  for submodule_name, submodule in model_to_quantize.named_modules():
553
584
  if isinstance(submodule, QModuleMixin):
554
585
  if verboseLevel>=1:
555
586
  print("No quantization to do as model is already quantized")
556
587
  return False
557
588
 
558
- if submodule_name=='':
559
- continue
589
+ size = 0
590
+ for n, p in submodule.named_parameters(recurse = False):
591
+ ref = _get_tensor_ref(p)
592
+ match = cache_ref.get(ref, None)
593
+ if match != None:
594
+ tied_weights[submodule_name]= (n, ) + match
595
+ else:
596
+ cache_ref[ref] = (submodule_name, n)
597
+ size += torch.numel(p.data) * sizeofbfloat16
598
+
599
+ for p in submodule.buffers(recurse=False):
600
+ size += torch.numel(p.data) * sizeofbfloat16
601
+
602
+
560
603
 
561
- size = compute_submodule_size(submodule)
562
604
  if not any(submodule_name.startswith(pre) for pre in tower_names):
563
605
  flush = False
564
606
  if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
@@ -590,21 +632,29 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
590
632
  submodule_names.append(submodule_name)
591
633
  total_size += size
592
634
 
593
- if submodule_size > 0 and submodule_size <= threshold:
635
+ if submodule_size >0 and submodule_size <= threshold :
594
636
  exclude_list += submodule_names
595
637
  if verboseLevel >=2:
596
638
  print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
597
639
  total_excluded += submodule_size
598
640
 
641
+
599
642
  perc_excluded =total_excluded/ total_size if total_size >0 else 1
600
643
  if verboseLevel >=2:
601
- print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB oF {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
644
+ if total_excluded == 0:
645
+ print(f"Can't find any module to exclude from quantization, full model ({total_size/ONE_MB:.1f} MB) will be quantized")
646
+ else:
647
+ print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB of {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
602
648
  if perc_excluded >= 0.10:
603
- print(f"Too many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
649
+ if verboseLevel >=2:
650
+ print(f"Too many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
604
651
  exclude_list = None
605
652
 
606
653
 
607
- quantize(model_to_quantize,weights, exclude= exclude_list)
654
+ exclude_list += list(tied_weights)
655
+ quantize(model_to_quantize, weights= weights, exclude= exclude_list)
656
+
657
+
608
658
  # quantize(model_to_quantize,weights, include= [ "*1.block.attn.to_out*"]) #"
609
659
 
610
660
  # for name, m in model_to_quantize.named_modules():
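Note: the call above now passes `weights=` explicitly and extends `exclude=` with the tied-weight modules. For context, a small sketch of how optimum-quanto's `exclude` argument behaves (the module names are just the positional names of a toy `Sequential`):

```python
import torch
from optimum.quanto import quantize, freeze, qint8, QModuleMixin

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))
quantize(model, weights=qint8, exclude=["1"])   # keep module "1" (the head) in full precision
freeze(model)
print(isinstance(model[0], QModuleMixin), isinstance(model[1], QModuleMixin))  # True False
```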
@@ -614,24 +664,40 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
614
664
 
615
665
  # force to read non quantized parameters so that their lazy tensors and corresponding mmap are released
616
666
  # otherwise we may end up keeping in memory both the quantized and the non quantize model
617
- for n,m in model_to_quantize.named_modules():
667
+ named_modules = {n:m for n,m in model_to_quantize.named_modules()}
668
+ for module_name, module in named_modules.items():
618
669
  # do not read quantized weights (detected them directly or behind an adapter)
619
- if isinstance(m, QModuleMixin) or hasattr(m, "base_layer") and isinstance(m.base_layer, QModuleMixin):
620
- if hasattr(m, "bias") and m.bias is not None:
621
- _force_load_parameter(m.bias)
670
+ if isinstance(module, QModuleMixin) or hasattr(module, "base_layer") and isinstance(module.base_layer, QModuleMixin):
671
+ if hasattr(module, "bias") and module.bias is not None:
672
+ _force_load_parameter(module.bias)
622
673
  else:
623
- for p in m.parameters(recurse = False):
624
- _force_load_parameter(p)
625
-
626
- for b in m.buffers(recurse = False):
674
+ tied_w = tied_weights.get(module_name, None)
675
+ for n, p in module.named_parameters(recurse = False):
676
+ if tied_w != None and n == tied_w[0]:
677
+ if isinstance( named_modules[tied_w[1]], QModuleMixin) :
678
+ setattr(module, n, None) # release refs of tied weights if source is going to be quantized
679
+ # otherwise don't force load as it will be loaded in the source anyway
680
+ else:
681
+ _force_load_parameter(p)
682
+ del p # del p if not it will still contain a ref to a tensor when leaving the loop
683
+ for b in module.buffers(recurse = False):
627
684
  _force_load_buffer(b)
628
-
685
+ del b
629
686
 
630
687
 
631
688
  freeze(model_to_quantize)
632
689
  torch.cuda.empty_cache()
633
- gc.collect()
690
+ gc.collect()
691
+
692
+ for tied_module, (tied_weight, src_module, src_weight) in tied_weights.items():
693
+ p = getattr(named_modules[src_module], src_weight)
694
+ if isinstance(p, QTensor):
695
+ setattr(named_modules[tied_module], tied_weight, p ) # copy refs to quantized sources
696
+
697
+ del named_modules
698
+
634
699
  quantization_map = _quantization_map(model_to_quantize)
700
+
635
701
  model_to_quantize._quanto_map = quantization_map
636
702
 
637
703
  if hasattr(model_to_quantize, "_already_pinned"):
@@ -643,12 +709,81 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
643
709
 
644
710
  return True
645
711
 
712
+ def _lora_linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
713
+ self._check_forward_args(x, *args, **kwargs)
714
+ adapter_names = kwargs.pop("adapter_names", None)
715
+ if self.disable_adapters:
716
+ if self.merged:
717
+ self.unmerge()
718
+ result = self.base_layer(x, *args, **kwargs)
719
+ elif adapter_names is not None:
720
+ result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
721
+ elif self.merged:
722
+ result = self.base_layer(x, *args, **kwargs)
723
+ else:
724
+ base_weight = self.base_layer.weight
725
+ if base_weight.shape[-1] < x.shape[-2]: # if the input is large, fold the LoRA matrices into the base weight instead of applying the input to each LoRA factor; this saves a lot of VRAM and compute
726
+ for active_adapter in self.active_adapters:
727
+ lora_A = self.lora_A[active_adapter]
728
+ lora_B = self.lora_B[active_adapter]
729
+ scaling = self.scaling[active_adapter]
730
+ lora_A_weight = lora_A.weight
731
+ lora_B_weight = lora_B.weight
732
+ lora_BA = lora_B_weight @ lora_A_weight
733
+ base_weight += scaling * lora_BA
734
+
735
+ result = torch.nn.functional.linear(x, base_weight, bias=self.base_layer.bias)
736
+ torch_result_dtype = result.dtype
737
+
738
+ else:
739
+ result = self.base_layer(x, *args, **kwargs)
740
+ torch_result_dtype = result.dtype
741
+ x = x.to(torch.bfloat16)
742
+
743
+ for active_adapter in self.active_adapters:
744
+ if active_adapter not in self.lora_A.keys():
745
+ continue
746
+ lora_A = self.lora_A[active_adapter]
747
+ lora_B = self.lora_B[active_adapter]
748
+ dropout = self.lora_dropout[active_adapter]
749
+ scaling = self.scaling[active_adapter]
750
+ x = x.to(lora_A.weight.dtype)
751
+
752
+ if not self.use_dora[active_adapter]:
753
+ y = lora_A(x)
754
+ y = lora_B(y)
755
+ y*= scaling
756
+ result+= y
757
+ del lora_A, lora_B, y
758
+ # result = result + lora_B(lora_A(dropout(x))) * scaling
759
+ else:
760
+ if isinstance(dropout, nn.Identity) or not self.training:
761
+ base_result = result
762
+ else:
763
+ x = dropout(x)
764
+ base_result = None
765
+
766
+ result = result + self.lora_magnitude_vector[active_adapter](
767
+ x,
768
+ lora_A=lora_A,
769
+ lora_B=lora_B,
770
+ scaling=scaling,
771
+ base_layer=self.get_base_layer(),
772
+ base_result=base_result,
773
+ )
774
+
775
+ result = result.to(torch_result_dtype)
776
+ return result
777
+
646
778
  def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_loras = True, verboseLevel = -1,):
647
779
  verboseLevel = _compute_verbose_level(verboseLevel)
648
780
 
649
781
  if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
650
782
  raise Exception("Unable to load Lora, missing 'peft' and / or 'diffusers' modules")
651
-
783
+
784
+ from peft.tuners.lora import Linear
785
+ Linear.forward = _lora_linear_forward
786
+
652
787
  if not isinstance(lora_path, list):
653
788
  lora_path = [lora_path]
654
789
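Note: the branch above folds each adapter's `B @ A` product into the base weight when the number of input rows exceeds the weight's input dimension, trading the two per-token LoRA matmuls for one small weight update. A hedged numerical sketch of why the two paths agree (random shapes, float32 tolerance):

```python
import torch

x = torch.randn(4096, 64)      # many tokens, small hidden size
W = torch.randn(64, 64)        # base weight
A = torch.randn(8, 64)         # lora_A (rank 8)
B = torch.randn(64, 8)         # lora_B
scaling = 0.5

y_folded = torch.nn.functional.linear(x, W + scaling * (B @ A))
y_split = torch.nn.functional.linear(x, W) + scaling * torch.nn.functional.linear(
    torch.nn.functional.linear(x, A), B)
assert torch.allclose(y_folded, y_split, atol=1e-3)
```

As the hunk also shows, `load_loras_into_model` installs this forward on peft's `Linear` layer class, so the optimization applies to every LoRA linear in the process.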
 
@@ -658,6 +793,9 @@ def load_loras_into_model(model, lora_path, lora_multi = None, activate_all_lora
658
793
  for i, path in enumerate(lora_path):
659
794
  adapter_name = str(i)
660
795
 
796
+
797
+
798
+
661
799
  state_dict = safetensors2.torch_load_file(path)
662
800
 
663
801
  keys = list(state_dict.keys())
@@ -839,7 +977,6 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
839
977
  verboseLevel = _compute_verbose_level(verboseLevel)
840
978
 
841
979
  model = _remove_model_wrapper(model)
842
-
843
980
  if not (".safetensors" in file_path or ".sft" in file_path):
844
981
  if pinToMemory:
845
982
  raise Exception("Pinning to memory while loading only supported for safe tensors files")
@@ -851,12 +988,20 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
851
988
 
852
989
  if metadata is None:
853
990
  quantization_map = None
991
+ tied_weights_map = None
854
992
  else:
855
993
  quantization_map = metadata.get("quantization_map", None)
856
994
  config = metadata.get("config", None)
857
995
  if config is not None:
858
996
  model._config = config
859
997
 
998
+ tied_weights_map = metadata.get("tied_weights_map", None)
999
+ if tied_weights_map != None:
1000
+ for name, tied_weights_list in tied_weights_map.items():
1001
+ mapped_weight = state_dict[name]
1002
+ for tied_weights in tied_weights_list:
1003
+ state_dict[tied_weights] = mapped_weight
1004
+
860
1005
 
861
1006
 
862
1007
  if quantization_map is None:
@@ -905,6 +1050,107 @@ def load_model_data(model, file_path: str, do_quantize = False, quantizationType
905
1050
 
906
1051
  return
907
1052
 
1053
+ def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1, config_file_path = None ):
1054
+ """save the weights of a model and quantize them if requested
1055
+ These weights can be loaded again using 'load_model_data'
1056
+ """
1057
+
1058
+ config = None
1059
+ extra_meta = None
1060
+ verboseLevel = _compute_verbose_level(verboseLevel)
1061
+ if config_file_path !=None:
1062
+ with open(config_file_path, "r", encoding="utf-8") as reader:
1063
+ text = reader.read()
1064
+ config= json.loads(text)
1065
+ elif hasattr(model, "_config"):
1066
+ config = model._config
1067
+ elif hasattr(model, "config"):
1068
+ config_fullpath = None
1069
+ config_obj = getattr(model,"config")
1070
+ config_path = getattr(config_obj,"_name_or_path", None)
1071
+ if config_path != None:
1072
+ config_fullpath = os.path.join(config_path, "config.json")
1073
+ config_fullpath = _get_model(config_fullpath)
1074
+
1075
+ # if not os.path.isfile(config_fullpath):
1076
+ # config_fullpath = None
1077
+ if config_fullpath is None:
1078
+ config_fullpath = os.path.join(os.path.dirname(file_path), "config.json")
1079
+ if os.path.isfile(config_fullpath):
1080
+ with open(config_fullpath, "r", encoding="utf-8") as reader:
1081
+ text = reader.read()
1082
+ config= json.loads(text)
1083
+
1084
+ if do_quantize:
1085
+ _quantize(model, weights=quantizationType, model_id=file_path)
1086
+
1087
+ quantization_map = getattr(model, "_quanto_map", None)
1088
+
1089
+ from collections import OrderedDict
1090
+
1091
+ cache_ref = {}
1092
+ tied_weights_map = {}
1093
+ sd = model.state_dict()
1094
+ out_sd = OrderedDict()
1095
+
1096
+
1097
+ for name, weight in sd.items():
1098
+ ref = _get_tensor_ref(weight)
1099
+ match = cache_ref.get(ref, None)
1100
+ if match != None:
1101
+ tied_list = tied_weights_map.get(match, [])
1102
+ tied_list.append(name)
1103
+ tied_weights_map[match] = tied_list
1104
+ else:
1105
+ out_sd[name] = weight
1106
+ cache_ref[ref] = name
1107
+
1108
+ if len(tied_weights_map) > 0:
1109
+ extra_meta = { "tied_weights_map" : tied_weights_map }
1110
+
1111
+ if verboseLevel >=1:
1112
+ print(f"Saving file '{file_path}")
1113
+
1114
+ safetensors2.torch_write_file(out_sd, file_path , quantization_map = quantization_map, config = config, extra_meta= extra_meta)
1115
+ if verboseLevel >=1:
1116
+ print(f"File '{file_path}' saved")
1117
+
1118
+
1119
+ def extract_models(obj = None, prefix = None):
1120
+ if isinstance(obj, str): # for compatibility as the two args were switched
1121
+ bkp = prefix
1122
+ prefix = obj
1123
+ obj = bkp
1124
+
1125
+ pipe = {}
1126
+ if obj == None:
1127
+ raise Exception("an object to analyze must be provided")
1128
+ if prefix==None or len(prefix)==0:
1129
+ prefix = ""
1130
+ elif prefix[ -1:] != "/":
1131
+ prefix + "/"
1132
+
1133
+ for name in dir(obj):
1134
+ element = getattr(obj,name)
1135
+ if name in ("pipeline", "pipe"):
1136
+ pipeline = element
1137
+ if hasattr(pipeline , "components") and isinstance(pipeline.components, dict):
1138
+ for k, model in pipeline.components.items():
1139
+ if model != None:
1140
+ pipe[prefix + k ] = model
1141
+ elif isinstance(element, torch.nn.Module) and name!="base_model":
1142
+ if prefix + name in pipe:
1143
+ pipe[prefix + "_" + name ] = element
1144
+ else:
1145
+ pipe[prefix + name ] = element
1146
+ elif isinstance(element, dict):
1147
+ for k, element in element.items():
1148
+ if hasattr(element , "pipeline"):
1149
+ pipe.update( extract_models(prefix + k,element ))
1150
+
1151
+
1152
+ return pipe
1153
+
908
1154
  def get_model_name(model):
909
1155
  return model.name
910
1156
 
@@ -922,6 +1168,10 @@ class offload:
922
1168
  self.active_models_ids = []
923
1169
  self.active_subcaches = {}
924
1170
  self.models = {}
1171
+ self.cotenants_map = {
1172
+ "text_encoder": ["vae", "text_encoder_2"],
1173
+ "text_encoder_2": ["vae", "text_encoder"],
1174
+ }
925
1175
  self.verboseLevel = 0
926
1176
  self.blocks_of_modules = {}
927
1177
  self.blocks_of_modules_sizes = {}
@@ -931,14 +1181,16 @@ class offload:
931
1181
  self.loaded_blocks = {}
932
1182
  self.prev_blocks_names = {}
933
1183
  self.next_blocks_names = {}
1184
+ self.preloaded_blocks_per_model = {}
934
1185
  self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
935
1186
  self.transfer_stream = torch.cuda.Stream()
936
1187
  self.async_transfers = False
1188
+ self.parameters_ref = {}
937
1189
  global last_offload_obj
938
1190
  last_offload_obj = self
939
1191
 
940
1192
 
941
- def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
1193
+ def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name, submodule_name):
942
1194
 
943
1195
  entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
944
1196
  if entry_name in self.blocks_of_modules:
@@ -953,39 +1205,54 @@ class offload:
953
1205
  self.prev_blocks_names[entry_name] = prev_entry_name
954
1206
  if not prev_block_name == None:
955
1207
  self.next_blocks_names[prev_entry_name] = entry_name
956
-
1208
+ bef = blocks_params_size
957
1209
  for k,p in submodule.named_parameters(recurse=False):
1210
+ param_size = 0
1211
+ ref = _get_tensor_ref(p)
1212
+ tied_param = self.parameters_ref.get(ref, None)
958
1213
 
959
1214
  if isinstance(p, QTensor):
960
- blocks_params.append( (submodule, k, p, False ) )
1215
+ blocks_params.append( (submodule, k, p, False, tied_param ) )
961
1216
 
962
1217
  if p._qtype == qint4:
963
1218
  if hasattr(p,"_scale_shift"):
964
- blocks_params_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
965
- blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
1219
+ param_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
1220
+ param_size += torch.numel(p._data._data) * p._data._data.element_size()
966
1221
  else:
967
- blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
968
- blocks_params_size += torch.numel(p._shift) * p._shift.element_size()
969
- blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
1222
+ param_size += torch.numel(p._scale) * p._scale.element_size()
1223
+ param_size += torch.numel(p._shift) * p._shift.element_size()
1224
+ param_size += torch.numel(p._data._data) * p._data._data.element_size()
970
1225
  else:
971
- blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
972
- blocks_params_size += torch.numel(p._data) * p._data.element_size()
1226
+ param_size += torch.numel(p._scale) * p._scale.element_size()
1227
+ param_size += torch.numel(p._data) * p._data.element_size()
973
1228
  else:
974
- blocks_params.append( (submodule, k, p, False) )
975
- blocks_params_size += torch.numel(p.data) * p.data.element_size()
1229
+ blocks_params.append( (submodule, k, p, False, tied_param) )
1230
+ param_size += torch.numel(p.data) * p.data.element_size()
1231
+
1232
+
1233
+ if tied_param == None:
1234
+ blocks_params_size += param_size
1235
+ self.parameters_ref[ref] = (submodule, k)
976
1236
 
977
1237
  for k, p in submodule.named_buffers(recurse=False):
978
- blocks_params.append( (submodule, k, p, True) )
1238
+ blocks_params.append( (submodule, k, p, True, None) )
979
1239
  blocks_params_size += p.data.nbytes
980
1240
 
1241
+ aft = blocks_params_size
1242
+
1243
+ # if blocks_name is None:
1244
+ # print(f"Default: {model_id}/{submodule_name} : {(aft-bef)/ONE_MB:0.2f} MB")
1245
+ # pass
1246
+
981
1247
 
982
1248
  self.blocks_of_modules_sizes[entry_name] = blocks_params_size
983
1249
 
1250
+
984
1251
  return blocks_params_size
985
1252
 
986
1253
 
987
1254
  def can_model_be_cotenant(self, model_id):
988
- potential_cotenants= cotenants_map.get(model_id, None)
1255
+ potential_cotenants= self.cotenants_map.get(model_id, None)
989
1256
  if potential_cotenants is None:
990
1257
  return False
991
1258
  for existing_cotenant in self.active_models_ids:
@@ -994,51 +1261,76 @@ class offload:
994
1261
  return True
995
1262
 
996
1263
  @torch.compiler.disable()
997
- def gpu_load_blocks(self, model_id, blocks_name):
1264
+ def gpu_load_blocks(self, model_id, blocks_name, preload = False):
998
1265
  # cl = clock.start()
999
1266
 
1000
- if blocks_name != None:
1001
- self.loaded_blocks[model_id] = blocks_name
1002
1267
 
1003
1268
  entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
1004
1269
 
1005
1270
  def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None
1006
1271
  with torch.cuda.stream(stream_to_use):
1007
1272
  for param in blocks_params:
1008
- parent_module, n, p, is_buffer = param
1273
+ parent_module, n, p, is_buffer, tied_param = param
1274
+ if tied_param != None:
1275
+ tied_p = getattr( tied_param[0], tied_param[1])
1276
+ if tied_p.is_cuda:
1277
+ setattr(parent_module, n , tied_p)
1278
+ continue
1279
+
1009
1280
  q = p.to("cuda", non_blocking=True)
1010
1281
  if is_buffer:
1011
1282
  q = torch.nn.Buffer(q)
1012
1283
  else:
1013
1284
  q = torch.nn.Parameter(q , requires_grad=False)
1014
1285
  setattr(parent_module, n , q)
1015
- # if record_for_stream != None:
1016
- # if isinstance(p, QTensor):
1017
- # q._data.record_stream(record_for_stream)
1018
- # q._scale.record_stream(record_for_stream)
1019
- # else:
1020
- # p.data.record_stream(record_for_stream)
1286
+
1287
+ if tied_param != None:
1288
+ setattr( tied_param[0], tied_param[1], q)
1289
+ del p, q
1290
+ any_past_block = False
1291
+
1292
+ loaded_block = self.loaded_blocks[model_id]
1293
+ if not preload and loaded_block != None:
1294
+ any_past_block = True
1295
+ self.gpu_unload_blocks(model_id, loaded_block)
1296
+ if self.ready_to_check_mem():
1297
+ self.empty_cache_if_needed()
1021
1298
 
1022
1299
 
1023
1300
  if self.verboseLevel >=2:
1024
1301
  model = self.models[model_id]
1025
1302
  model_name = model._get_name()
1026
- print(f"Loading model {entry_name} ({model_name}) in GPU")
1027
-
1303
+ # if not preload:
1304
+ # print(f"Request to load model {entry_name} ({model_name}) in GPU")
1305
+
1028
1306
 
1029
1307
  if self.async_transfers and blocks_name != None:
1030
- first = self.prev_blocks_names[entry_name] == None
1308
+ first = self.prev_blocks_names[entry_name] == None or not any_past_block
1031
1309
  next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
1032
1310
  if first:
1311
+ if self.verboseLevel >=2:
1312
+ if preload:
1313
+ print(f"Preloading model {entry_name} ({model_name}) in GPU")
1314
+ else:
1315
+ print(f"Loading model {entry_name} ({model_name}) in GPU")
1033
1316
  cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
1317
+
1034
1318
  torch.cuda.synchronize()
1035
1319
 
1036
1320
  if next_blocks_entry != None:
1321
+ if self.verboseLevel >=2:
1322
+ print(f"Prefetching model {next_blocks_entry} ({model_name}) in GPU")
1037
1323
  cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry]) #, self.default_stream
1038
1324
 
1039
1325
  else:
1326
+ if self.verboseLevel >=2:
1327
+ print(f"Loading model {entry_name} ({model_name}) in GPU")
1040
1328
  cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
1041
1329
  torch.cuda.synchronize()
1330
+
1331
+ if not preload:
1332
+ self.loaded_blocks[model_id] = blocks_name
1333
+
1042
1334
  # cl.stop()
1043
1335
  # print(f"load time: {cl.format_time_gap()}")
1044
1336
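Note: `gpu_load_blocks` now distinguishes a synchronous first load from prefetching the next block on a side stream while the current block computes. A stripped-down sketch of that overlap (requires a CUDA device, and the host tensors must be pinned for the copies to be truly asynchronous):

```python
import torch

transfer_stream = torch.cuda.Stream()

def prefetch(cpu_block):
    # issue host->device copies on the side stream
    with torch.cuda.stream(transfer_stream):
        return [t.to("cuda", non_blocking=True) for t in cpu_block]

cpu_blocks = [[torch.randn(1024, 1024).pin_memory() for _ in range(4)] for _ in range(2)]

gpu_block0 = prefetch(cpu_blocks[0])
torch.cuda.current_stream().wait_stream(transfer_stream)   # block 0 is ready
gpu_block1 = prefetch(cpu_blocks[1])                        # overlaps with the compute below

x = torch.randn(1024, 1024, device="cuda")
for w in gpu_block0:
    x = x @ w                                               # compute while block 1 transfers
torch.cuda.current_stream().wait_stream(transfer_stream)   # block 1 has arrived
```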
 
@@ -1057,12 +1349,13 @@ class offload:
1057
1349
 
1058
1350
  blocks_params = self.blocks_of_modules[blocks_name]
1059
1351
  for param in blocks_params:
1060
- parent_module, n, p, is_buffer = param
1352
+ parent_module, n, p, is_buffer, _ = param
1061
1353
  if is_buffer:
1062
1354
  q = torch.nn.Buffer(p)
1063
1355
  else:
1064
1356
  q = torch.nn.Parameter(p , requires_grad=False)
1065
1357
  setattr(parent_module, n , q)
1358
+ del p, q
1066
1359
  # cl.stop()
1067
1360
  # print(f"unload time: {cl.format_time_gap()}")
1068
1361
 
@@ -1072,13 +1365,16 @@ class offload:
1072
1365
  self.active_models.append(model)
1073
1366
  self.active_models_ids.append(model_id)
1074
1367
 
1075
- self.gpu_load_blocks(model_id, None)
1076
-
1077
- # torch.cuda.current_stream().synchronize()
1368
+ self.gpu_load_blocks(model_id, None, True)
1369
+ for block_name in self.preloaded_blocks_per_model[model_id]:
1370
+ self.gpu_load_blocks(model_id, block_name, True)
1078
1371
 
1079
1372
  def unload_all(self):
1080
1373
  for model_id in self.active_models_ids:
1081
1374
  self.gpu_unload_blocks(model_id, None)
1375
+ for block_name in self.preloaded_blocks_per_model[model_id]:
1376
+ self.gpu_unload_blocks(model_id, block_name)
1377
+
1082
1378
  loaded_block = self.loaded_blocks[model_id]
1083
1379
  if loaded_block != None:
1084
1380
  self.gpu_unload_blocks(model_id, loaded_block)
@@ -1148,82 +1444,72 @@ class offload:
1148
1444
 
1149
1445
  return False
1150
1446
 
1447
+ def ensure_model_loaded(self, model_id):
1448
+ if model_id in self.active_models_ids:
1449
+ return
1450
+ # new_model_id = getattr(module, "_mm_id")
1451
+ # do not always unload existing models if it is more efficient to keep in them in the GPU
1452
+ # (e.g: small modules whose calls are text encoders)
1453
+ if not self.can_model_be_cotenant(model_id) :
1454
+ self.unload_all()
1455
+ self.gpu_load(model_id)
1456
+
1151
1457
  def hook_preload_blocks_for_compilation(self, target_module, model_id,blocks_name, context):
1152
1458
 
1153
1459
  # @torch.compiler.disable()
1154
1460
  def preload_blocks_for_compile(module, *args, **kwargs):
1155
- some_context = context #for debugging
1156
- if blocks_name == None:
1157
- if self.ready_to_check_mem():
1158
- self.empty_cache_if_needed()
1159
- else:
1160
- loaded_block = self.loaded_blocks[model_id]
1161
- if (loaded_block == None or loaded_block != blocks_name) :
1162
- if loaded_block != None:
1163
- self.gpu_unload_blocks(model_id, loaded_block)
1164
- if self.ready_to_check_mem():
1165
- self.empty_cache_if_needed()
1166
- self.loaded_blocks[model_id] = blocks_name
1167
- self.gpu_load_blocks(model_id, blocks_name)
1461
+ # some_context = context #for debugging
1462
+ if blocks_name != None and blocks_name != self.loaded_blocks[model_id] and blocks_name not in self.preloaded_blocks_per_model[model_id]:
1463
+ self.gpu_load_blocks(model_id, blocks_name)
1464
+
1168
1465
  # need to be registered before the forward not to be break the efficiency of the compilation chain
1169
1466
  # it should be at the top of the compilation as this type of hook in the middle of a chain seems to break memory performance
1170
1467
  target_module.register_forward_pre_hook(preload_blocks_for_compile)
1171
1468
 
1172
1469
 
1173
- def hook_check_empty_cache_needed(self, target_module, model_id,blocks_name, previous_method, context):
1470
+ def hook_check_empty_cache_needed(self, target_module, model_id, blocks_name, previous_method, context):
1174
1471
 
1175
1472
  qint4quantization = isinstance(target_module, QModuleMixin) and target_module.weight!= None and target_module.weight.qtype == qint4
1176
1473
  if qint4quantization:
1177
1474
  pass
1178
1475
 
1476
+ if hasattr(target_module, "_mm_id"):
1477
+ # no hook for a shared module with no weights (otherwise this will cause models loading / unloading for nothing)
1478
+ orig_model_id = getattr(target_module, "_mm_id")
1479
+ if self.verboseLevel >=2:
1480
+ print(f"Model '{model_id}' shares module '{target_module._get_name()}' with module(s) '{orig_model_id}' ")
1481
+ assert not self.any_param_or_buffer(target_module)
1482
+ if not isinstance(orig_model_id, list):
1483
+ orig_model_id = [orig_model_id]
1484
+ orig_model_id.append(model_id)
1485
+ setattr(target_module, "_mm_id", orig_model_id)
1486
+ target_module.forward = target_module._mm_forward
1487
+ return
1488
+
1179
1489
  def check_empty_cuda_cache(module, *args, **kwargs):
1180
- # if self.ready_to_check_mem():
1181
- # self.empty_cache_if_needed()
1490
+ self.ensure_model_loaded(model_id)
1182
1491
  if blocks_name == None:
1183
1492
  if self.ready_to_check_mem():
1184
1493
  self.empty_cache_if_needed()
1185
- else:
1186
- loaded_block = self.loaded_blocks[model_id]
1187
- if (loaded_block == None or loaded_block != blocks_name) :
1188
- if loaded_block != None:
1189
- self.gpu_unload_blocks(model_id, loaded_block)
1190
- if self.ready_to_check_mem():
1191
- self.empty_cache_if_needed()
1192
- self.loaded_blocks[model_id] = blocks_name
1193
- self.gpu_load_blocks(model_id, blocks_name)
1494
+ elif blocks_name != self.loaded_blocks[model_id] and blocks_name not in self.preloaded_blocks_per_model[model_id]:
1495
+ self.gpu_load_blocks(model_id, blocks_name)
1194
1496
  if qint4quantization:
1195
1497
  args, kwargs = self.move_args_to_gpu(*args, **kwargs)
1196
1498
 
1197
1499
  return previous_method(*args, **kwargs)
1198
1500
 
1199
-
1200
- if hasattr(target_module, "_mm_id"):
1201
- orig_model_id = getattr(target_module, "_mm_id")
1202
- if self.verboseLevel >=2:
1203
- print(f"Model '{model_id}' shares module '{target_module._get_name()}' with module '{orig_model_id}' ")
1204
- assert not self.any_param_or_buffer(target_module)
1205
-
1206
- return
1207
1501
  setattr(target_module, "_mm_id", model_id)
1502
+ setattr(target_module, "_mm_forward", previous_method)
1503
+
1208
1504
  setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method) )
1209
1505
 
1210
1506
 
1211
1507
  def hook_change_module(self, target_module, model, model_id, module_id, previous_method):
1212
- def check_change_module(module, *args, **kwargs):
1213
- performEmptyCacheTest = False
1214
- if not model_id in self.active_models_ids:
1215
- new_model_id = getattr(module, "_mm_id")
1216
- # do not always unload existing models if it is more efficient to keep in them in the GPU
1217
- # (e.g: small modules whose calls are text encoders)
1218
- if not self.can_model_be_cotenant(new_model_id) :
1219
- self.unload_all()
1220
- performEmptyCacheTest = False
1221
- self.gpu_load(new_model_id)
1508
+
1509
+ def check_change_module(module, *args, **kwargs):
1510
+ self.ensure_model_loaded(model_id)
1222
1511
  # transfer leftovers inputs that were incorrectly created in the RAM (mostly due to some .device tests that returned incorrectly "cpu")
1223
1512
  args, kwargs = self.move_args_to_gpu(*args, **kwargs)
1224
- if performEmptyCacheTest:
1225
- self.empty_cache_if_needed()
1226
-
1227
1513
  return previous_method(*args, **kwargs)
1228
1514
 
1229
1515
  if hasattr(target_module, "_mm_id"):
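Note: both hook paths ultimately rely on `register_forward_pre_hook` (for compiled models) or on wrapping `forward` (otherwise) so that `ensure_model_loaded` runs before any submodule executes. A tiny sketch of the pre-hook variant, with the loading calls replaced by a print:

```python
import torch

def make_pre_hook(model_id, blocks_name):
    def pre_hook(module, args):
        # stand-in for: self.ensure_model_loaded(model_id); self.gpu_load_blocks(...)
        print(f"about to run {model_id}/{blocks_name}: {module._get_name()}")
    return pre_hook

layer = torch.nn.Linear(8, 8)
layer.register_forward_pre_hook(make_pre_hook("transformer", "blocks.0"))
layer(torch.randn(2, 8))
```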
@@ -1240,71 +1526,90 @@ class offload:
1240
1526
  print(f"Hooked to model '{model_id}' ({model_name})")
1241
1527
 
1242
1528
 
1243
- def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1, config_file_path = None ):
1244
- """save the weights of a model and quantize them if requested
1245
- These weights can be loaded again using 'load_model_data'
1246
- """
1247
-
1248
- config = None
1249
- verboseLevel = _compute_verbose_level(verboseLevel)
1250
- if config_file_path !=None:
1251
- with open(config_file_path, "r", encoding="utf-8") as reader:
1252
- text = reader.read()
1253
- config= json.loads(text)
1254
- elif hasattr(model, "_config"):
1255
- config = model._config
1256
- elif hasattr(model, "config"):
1257
- config_fullpath = None
1258
- config_obj = getattr(model,"config")
1259
- config_path = getattr(config_obj,"_name_or_path", None)
1260
- if config_path != None:
1261
- config_fullpath = os.path.join(config_path, "config.json")
1262
- if not os.path.isfile(config_fullpath):
1263
- config_fullpath = None
1264
- if config_fullpath is None:
1265
- config_fullpath = os.path.join(os.path.dirname(file_path), "config.json")
1266
- if os.path.isfile(config_fullpath):
1267
- with open(config_fullpath, "r", encoding="utf-8") as reader:
1268
- text = reader.read()
1269
- config= json.loads(text)
1270
1529
 
1271
- if do_quantize:
1272
- _quantize(model, weights=quantizationType, model_id=file_path)
1273
-
1274
- quantization_map = getattr(model, "_quanto_map", None)
1530
+ def tune_preloading(self, model_id, current_budget, towers_names):
1531
+ preloaded_blocks = {}
1532
+ preload_total = 0
1533
+ max_blocks_fetch = 0
1275
1534
 
1276
- if verboseLevel >=1:
1277
- print(f"Saving file '{file_path}")
1278
- safetensors2.torch_write_file(model.state_dict(), file_path , quantization_map = quantization_map, config = config)
1279
- if verboseLevel >=1:
1280
- print(f"File '{file_path}' saved")
1535
+ self.preloaded_blocks_per_model[model_id] = preloaded_blocks
1281
1536
 
1537
+ if current_budget == 0 or towers_names is None or len(towers_names) == 0 or not self.async_transfers:
1538
+ return
1539
+ # current_budget = 5000 * ONE_MB
1540
+ base_size = self.blocks_of_modules_sizes[model_id]
1541
+ current_budget -= base_size
1542
+ if current_budget <= 0:
1543
+ if self.verboseLevel >=1:
1544
+ print(f"Async loading plan for model '{model_id}' : due to limited budget, beside the async shuttle only only base model ({(base_size)/ONE_MB:0.2f} MB) will be preloaded")
1545
+ return
1546
+
1547
+ towers = []
1548
+ total_size = 0
1549
+ for tower_name in towers_names:
1550
+ max_floor_size = 0
1551
+ tower_size = 0
1552
+ floors = []
1553
+ prefix = model_id + "/" + tower_name
1554
+ for name, size in self.blocks_of_modules_sizes.items():
1555
+ if name.startswith(prefix):
1556
+ tower_size += size
1557
+ floor_no = int( name[len(prefix): ] )
1558
+ floors.append( (name, floor_no, size))
1559
+ max_floor_size = max(max_floor_size, size)
1560
+
1561
+ towers.append( (floors, max_floor_size, tower_size) )
1562
+ total_size += tower_size
1563
+ current_budget -= 2 * max_floor_size
1564
+ if current_budget <= 0:
1565
+ if self.verboseLevel >=1:
1566
+ print(f"Async loading plan for model '{model_id}' : due to limited budget, beside the async shuttle only the base model ({(base_size)/ONE_MB:0.2f} MB) will be preloaded")
1567
+ return
1282
1568
 
1283
- def extract_models(prefix, obj):
1284
- pipe = {}
1285
- for name in dir(obj):
1286
- element = getattr(obj,name)
1287
- if name in ("pipeline", "pipe"):
1288
- pipeline = element
1289
- if hasattr(pipeline , "components") and isinstance(pipeline.components, dict):
1290
- for k, model in pipeline.components.items():
1291
- if model != None:
1292
- pipe[prefix + "/" + k ] = model
1293
- elif isinstance(element, torch.nn.Module):
1294
- if prefix + "/" + name in pipe:
1295
- pipe[prefix + "/_" + name ] = element
1569
+
1570
+ for floors, max_floor_size, tower_size in towers:
1571
+ tower_budget = tower_size / total_size * current_budget
1572
+ preload_blocks_count = int( tower_budget / max_floor_size)
1573
+ preload_total += preload_blocks_count * max_floor_size
1574
+ max_blocks_fetch = max(max_floor_size, max_blocks_fetch)
1575
+ if preload_blocks_count <= 0:
1576
+ if self.verboseLevel >=1:
1577
+ print(f"Async loading plan for model '{model_id}' : due to limited budget, beside the async shuttle only the base model ({(base_size)/ONE_MB:0.2f} MB) will be preloaded")
1578
+ return
1579
+
1580
+ nb_blocks= len(floors)
1581
+ space_between = (nb_blocks - preload_blocks_count) / preload_blocks_count
1582
+ cursor = space_between
1583
+ first_non_preloaded = None
1584
+ prev_non_preloaded = None
1585
+ for block in floors:
1586
+ name, i, size = block
1587
+ if i < cursor:
1588
+ if prev_non_preloaded == None:
1589
+ first_non_preloaded = name
1590
+ else:
1591
+ self.next_blocks_names[prev_non_preloaded] = name
1592
+ self.prev_blocks_names[name] = prev_non_preloaded
1593
+ prev_non_preloaded = name
1594
+ else:
1595
+ self.next_blocks_names[name] = None
1596
+ self.prev_blocks_names[name] = None
1597
+ preloaded_blocks[name[ len(model_id) + 1 : ] ] = size
1598
+ cursor += 1 + space_between
1599
+
1600
+ if prev_non_preloaded != None and len(towers) == 1 :
1601
+ self.next_blocks_names[prev_non_preloaded] = first_non_preloaded
1602
+ self.prev_blocks_names[first_non_preloaded] = prev_non_preloaded
1296
1603
  else:
1297
- pipe[prefix + "/" + name ] = element
1298
- elif isinstance(element, dict):
1299
- for k, element in element.items():
1300
- if hasattr(element , "pipeline"):
1301
- pipe.update( extract_models(prefix + "/" + k,element ))
1604
+ self.next_blocks_names[prev_non_preloaded] = None
1302
1605
 
1606
+ self.preloaded_blocks_per_model[model_id] = preloaded_blocks
1607
+
1608
+ if self.verboseLevel >=1:
1609
+ print(f"Async loading plan for model '{model_id}' : {(preload_total+base_size)/ONE_MB:0.2f} MB will be preloaded (base size of {base_size/ONE_MB:0.2f} MB + {preload_total/total_size*100:0.1f}% of recurrent layers data) with a {max_blocks_fetch/ONE_MB:0.2f} MB async" + (" circular" if len(towers) == 1 else "") + " shuttle")
1303
1610
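Note: the spacing logic in `tune_preloading` spreads the blocks that stay resident in VRAM evenly across a tower, leaving the remaining floors to the async shuttle. A small sketch of how the cursor walk picks them (40 floors, budget for 10 resident ones):

```python
nb_blocks, preload_count = 40, 10
space_between = (nb_blocks - preload_count) / preload_count

cursor, resident = space_between, []
for i in range(nb_blocks):
    if i < cursor:
        pass                          # floor stays on the async shuttle
    else:
        resident.append(i)            # floor is preloaded and kept in VRAM
        cursor += 1 + space_between

print(resident)                       # [3, 7, 11, 15, 19, 23, 27, 31, 35, 39]
```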
 
1304
- return pipe
1305
-
1306
1611
 
1307
- def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, verboseLevel = -1):
1612
+ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = True, extraModelsToQuantize = None, quantizationType = qint8, budgets= 0, workingVRAM = None, asyncTransfers = True, compile = False, perc_reserved_mem_max = 0, coTenantsMap = None, verboseLevel = -1):
1308
1613
  """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
1309
1614
  pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
1310
1615
  quantizeTransformer: set True by default will quantize on the fly the video / image model
@@ -1321,9 +1626,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1321
1626
  model_budgets = {}
1322
1627
 
1323
1628
  windows_os = os.name == 'nt'
1324
- global total_pinned_bytes
1325
1629
 
1326
-
1327
1630
  budget = 0
1328
1631
  if not budgets is None:
1329
1632
  if isinstance(budgets , dict):
@@ -1352,6 +1655,8 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1352
1655
  verboseLevel = _compute_verbose_level(verboseLevel)
1353
1656
 
1354
1657
  _welcome()
1658
+ if coTenantsMap != None:
1659
+ self.cotenants_map = coTenantsMap
1355
1660
 
1356
1661
  self.models = models
1357
1662
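Note: `all()` gains two knobs: `workingVRAM` (a minimum amount of VRAM, in MB, to keep free for a model's activations, given globally or per model) and `coTenantsMap` (which models may stay loaded together, replacing the old module-level `cotenants_map`). A hedged usage sketch; `pipe` is assumed to be an already-built diffusers-style pipeline:

```python
from mmgp import offload

offload.all(
    pipe,
    pinnedMemory=True,
    workingVRAM={"transformer": 3000},                         # keep ~3 GB free while it runs
    coTenantsMap={"text_encoder": ["vae", "text_encoder_2"]},  # allow these to cohabit in VRAM
)
```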
 
@@ -1382,8 +1687,9 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1382
1687
 
1383
1688
  self.anyCompiledModule = compileAllModels or len(modelsToCompile)>0
1384
1689
  if self.anyCompiledModule:
1385
- torch._dynamo.config.cache_size_limit = 10000
1386
1690
  torch.compiler.reset()
1691
+ torch._dynamo.config.cache_size_limit = 10000
1692
+ #dynamic=True
1387
1693
 
1388
1694
  # torch._logging.set_logs(recompiles=True)
1389
1695
  # torch._inductor.config.realize_opcount_threshold = 100 # workaround bug "AssertionError: increase TRITON_MAX_BLOCK['X'] to 4096."
@@ -1431,9 +1737,7 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1431
1737
  current_model_size += torch.numel(p.data) * p.data.element_size()
1432
1738
 
1433
1739
  for b in current_model.buffers():
1434
- if b.data.dtype == torch.float32:
1435
- # convert any left overs float32 weight to bloat16 to divide by 2 the model memory footprint
1436
- b.data = b.data.to(torch.bfloat16)
1740
+ # do not convert 32 bits float to 16 bits since buffers are few (and potential gain low) and usually they are needed for precision calculation (for instance Rope)
1437
1741
  current_model_size += torch.numel(b.data) * b.data.element_size()
1438
1742
 
1439
1743
  if modelPinned:
@@ -1441,17 +1745,39 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1441
1745
 
1442
1746
 
1443
1747
  model_budget = model_budgets[model_id] * ONE_MB if model_id in model_budgets else budget
1444
-
1748
+ if workingVRAM != None:
1749
+ model_minimumVRAM = -1
1750
+ if isinstance(workingVRAM, dict):
1751
+ if model_id in workingVRAM:
1752
+ model_minimumVRAM = workingVRAM[model_id]
1753
+ elif "*" in model_id in workingVRAM:
1754
+ model_minimumVRAM = workingVRAM["*"]
1755
+ else:
1756
+ model_minimumVRAM = workingVRAM
1757
+ if model_minimumVRAM > 0:
1758
+ new_budget = self.device_mem_capacity - model_minimumVRAM * ONE_MB
1759
+ new_budget = 1 if new_budget < 0 else new_budget
1760
+ model_budget = new_budget if model_budget == 0 or new_budget < model_budget else model_budget
1445
1761
  if model_budget > 0 and model_budget > current_model_size:
1446
1762
  model_budget = 0
1763
+ coef =0.8
1764
+ if current_model_size > coef * self.device_mem_capacity and model_budget == 0 or model_budget > coef * self.device_mem_capacity:
1765
+ if verboseLevel >= 1:
1766
+ if model_budget == 0:
1767
+ print(f"Model '{model_id}' is too large ({current_model_size/ONE_MB:0.1f} MB) to fit entirely in {coef * 100}% of the VRAM (max capacity is {coef * self.device_mem_capacity/ONE_MB}) MB)")
1768
+ else:
1769
+ print(f"Budget ({budget/ONE_MB:0.1f} MB) for Model '{model_id}' is too important so that this model can fit in the VRAM (max capacity is {self.device_mem_capacity/ONE_MB}) MB)")
1770
+ print(f"Budget allocation for this model has been consequently reduced to the 80% of max GPU Memory ({coef * self.device_mem_capacity/ONE_MB:0.1f} MB). This may not leave enough working VRAM and you will probably need to define manually a lower budget for this model.")
1771
+ model_budget = coef * self.device_mem_capacity
1772
+
1447
1773
 
1448
- model_budgets[model_id] = model_budget #/ 2 if asyncTransfers else model_budget
1774
+ model_budgets[model_id] = model_budget
1449
1775
 
1450
1776
  partialPinning = False
1451
1777
 
1452
1778
  if estimatesBytesToPin > 0 and estimatesBytesToPin >= (max_reservable_memory - total_pinned_bytes):
1453
1779
  if self.verboseLevel >=1:
1454
- print(f"Switching to partial pinning since full requirements for pinned models is {estimatesBytesToPin/ONE_MB:0.1f} MB while estimated reservable RAM is {max_reservable_memory/ONE_MB:0.1f} MB. You may increase the value of parameter 'perc_reserved_mem_max' to a value higher than {perc_reserved_mem_max:0.2f} to force full pinnning." )
1780
+ print(f"Switching to partial pinning since full requirements for pinned models is {estimatesBytesToPin/ONE_MB:0.1f} MB while estimated available reservable RAM is {(max_reservable_memory-total_pinned_bytes)/ONE_MB:0.1f} MB. You may increase the value of parameter 'perc_reserved_mem_max' to a value higher than {perc_reserved_mem_max:0.2f} to force full pinnning." )
1455
1781
  partialPinning = True
1456
1782
 
1457
1783
  # Hook forward methods of modules
@@ -1463,7 +1789,8 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1463
1789
  if compilationInThisOne:
1464
1790
  if self.verboseLevel>=1:
1465
1791
  if len(towers_modules)>0:
1466
- print(f"Pytorch compilation of '{model_id}' is scheduled for these modules : {towers_names}*.")
1792
+ formated_tower_names = [name + '*' for name in towers_names]
1793
+ print(f"Pytorch compilation of '{model_id}' is scheduled for these modules : {formated_tower_names}.")
1467
1794
  else:
1468
1795
  print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
1469
1796
 
@@ -1479,20 +1806,14 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1479
1806
  _pin_to_memory(current_model, model_id, partialPinning= partialPinning, verboseLevel=verboseLevel)
1480
1807
 
1481
1808
  current_budget = model_budgets[model_id]
1482
- current_size = 0
1483
- cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1809
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq, is_mod_seq = None, None, None, -1, False
1484
1810
  self.loaded_blocks[model_id] = None
1485
1811
 
1486
1812
  for submodule_name, submodule in current_model.named_modules():
1487
1813
  # create a fake 'accelerate' parameter so that the _execution_device property returns always "cuda"
1488
1814
  # (it is queried in many pipelines even if offloading is not properly implemented)
1489
- if not hasattr(submodule, "_hf_hook"):
1815
+ if not hasattr(submodule, "_hf_hook"):
1490
1816
  setattr(submodule, "_hf_hook", HfHook())
1491
-
1492
- # if submodule_name=='':
1493
- # continue
1494
-
1495
-
1496
1817
  if current_budget > 0 and len(submodule_name) > 0:
1497
1818
  if cur_blocks_prefix != None:
1498
1819
  if submodule_name.startswith(cur_blocks_prefix):
@@ -1500,20 +1821,20 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1500
1821
  depth_name = submodule_name.split(".")
1501
1822
  level = depth_name[len(depth_prefix)-1]
1502
1823
  pre , num = _extract_num_from_str(level)
1503
- if num != cur_blocks_seq and (cur_blocks_seq == -1 or current_size > current_budget):
1824
+ if num != cur_blocks_seq and not (is_mod_seq and cur_blocks_seq>=0):
1504
1825
  prev_blocks_name = cur_blocks_name
1505
1826
  cur_blocks_name = cur_blocks_prefix + str(num)
1506
1827
  # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
1507
1828
  cur_blocks_seq = num
1508
1829
  else:
1509
- cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
1830
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq, is_mod_seq = None, None, None, -1, False
1510
1831
 
1511
1832
  if cur_blocks_prefix == None:
1512
1833
  pre , num = _extract_num_from_str(submodule_name)
1513
1834
  if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
1514
- cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre + ".", None, -1
1835
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_seq, is_mod_seq = pre + ".", None, -1, isinstance(submodule, torch.nn.Sequential)
1515
1836
  elif num >=0:
1516
- cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre, None, num
1837
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_seq, is_mod_seq = pre, None, num, False
1517
1838
  cur_blocks_name = submodule_name
1518
1839
  # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
1519
1840
 
@@ -1528,13 +1849,35 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
1528
1849
  else:
1529
1850
  self.hook_check_empty_cache_needed(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name )
1530
1851
 
1531
- current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
1852
+ self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name, submodule_name)
1853
+
1854
+ self.tune_preloading(model_id, current_budget, towers_names)
1532
1855
 
1533
1856
 
1534
1857
  if self.verboseLevel >=2:
1535
- for n,b in self.blocks_of_modules_sizes.items():
1536
- print(f"Size of submodel '{n}': {b/ONE_MB:.1f} MB")
1858
+ start_num, prev_num, prev_pre, prev_size = -1, -1, None, -1
1859
+
1860
+ def print_size_range(n,start_num,prev_num, prev_size ):
1861
+ if prev_num < 0:
1862
+ print(f"Size of submodel '{n}': {prev_size/ONE_MB:.1f} MB")
1863
+ elif prev_num - start_num <=1:
1864
+ print(f"Size of submodel '{n+ str(start_num)}': {prev_size/ONE_MB:.1f} MB")
1865
+ else:
1866
+ print(f"Size of submodel '{n+ str(start_num) +'-'+ str(prev_num)}': {(prev_num-start_num+1)*prev_size/ONE_MB:.1f} MB ({prev_size/ONE_MB:.1f} MB x {prev_num-start_num+1})")
1867
+
1868
+ for n, size in self.blocks_of_modules_sizes.items():
1869
+ size = int(size / 10000)* 10000
1870
+ pre, num = _extract_num_from_str(n) if "/" in n else (n, -1)
1871
+ if prev_pre == None :
1872
+ start_num = num
1873
+ elif prev_pre != pre or prev_pre == pre and size != prev_size:
1874
+ print_size_range(prev_pre,start_num,prev_num, prev_size )
1875
+ start_num = num
1876
+ prev_num, prev_pre, prev_size = num, pre, size
1877
+ if prev_pre != None:
1878
+ print_size_range(prev_pre,start_num,prev_num, prev_size )
1537
1879
 
1880
+
1538
1881
  torch.set_default_device('cuda')
1539
1882
  torch.cuda.empty_cache()
1540
1883
  gc.collect()
@@ -1595,21 +1938,21 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
1595
1938
  if profile_no == profile_type.HighRAM_HighVRAM:
1596
1939
  pinnedMemory= True
1597
1940
  budgets = None
1598
- info = "You have chosen a profile that requires at least 48 GB of RAM and 24 GB of VRAM. Some VRAM is consumed just to make the model runs faster."
1941
+ info = "You have chosen a profile that may require 48 GB of RAM and up to 24 GB of VRAM on some applications."
1599
1942
  elif profile_no == profile_type.HighRAM_LowVRAM:
1600
1943
  pinnedMemory= True
1601
1944
  budgets["*"] = 3000
1602
- info = "You have chosen a profile that requires at least 48 GB of RAM and 12 GB of VRAM. Some RAM is consumed to reduce VRAM consumption."
1945
+ info = "You have chosen a profile that may require 48 GB of RAM and up to 12 GB of VRAM on some applications."
1603
1946
  elif profile_no == profile_type.LowRAM_HighVRAM:
1604
1947
  pinnedMemory= "transformer"
1605
1948
  extraModelsToQuantize = default_extraModelsToQuantize
1606
1949
  budgets = None
1607
- info = "You have chosen a Medium speed profile that requires at least 32 GB of RAM and 24 GB of VRAM. Some VRAM is consuming just to make the model runs faster"
1950
+ info = "You have chosen a Medium speed profile that may require 32 GB of RAM and up to 24 GB of VRAM on some applications."
1608
1951
  elif profile_no == profile_type.LowRAM_LowVRAM:
1609
1952
  pinnedMemory= "transformer"
1610
1953
  extraModelsToQuantize = default_extraModelsToQuantize
1611
1954
  budgets["*"] = 3000
1612
- info = "You have chosen a profile that requires at least 32 GB of RAM and 12 GB of VRAM. Some RAM is consumed to reduce VRAM consumption. "
1955
+ info = "You have chosen a profile that usually may require 32 GB of RAM and up to 12 GB of VRAM on some applications."
1613
1956
  elif profile_no == profile_type.VerylowRAM_LowVRAM:
1614
1957
  pinnedMemory= False
1615
1958
  extraModelsToQuantize = default_extraModelsToQuantize
@@ -1617,9 +1960,10 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
1617
1960
  if "transformer" in modules:
1618
1961
  budgets["transformer"] = 400
1619
1962
  #asyncTransfers = False
1620
- info = "You have chosen the slowest profile that requires at least 24 GB of RAM and 10 GB of VRAM."
1963
+ info = "You have chosen the slowest profile that may require 24 GB of RAM and up to 10 GB of VRAM on some applications."
1621
1964
  else:
1622
1965
  raise Exception("Unknown profile")
1966
+ info += " Actual requirements may varry depending on the application or on the tuning done to the profile."
1623
1967
 
1624
1968
  if budgets != None and len(budgets) == 0:
1625
1969
  budgets = None
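Note: the reworded messages above describe presets applied by `profile()`, which bundles the pinning, budget, and quantization options for common RAM/VRAM combinations. A hedged usage sketch:

```python
from mmgp import offload, profile_type

# assumes `pipe` is an already-built pipeline; the profile picks pinning,
# per-model budgets and extra quantization to match the stated RAM/VRAM class
offload.profile(pipe, profile_type.HighRAM_LowVRAM)
```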