mmgp 3.0.9__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of mmgp might be problematic.

mmgp/offload.py CHANGED
@@ -1,4 +1,4 @@
- # ------------------ Memory Management 3.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
+ # ------------------ Memory Management 3.1 for the GPU Poor by DeepBeepMeep (mmgp)------------------
  #
  # This module contains multiples optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a 24 GB GPU limited card.
  # This a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
@@ -79,7 +79,7 @@ from mmgp import profile_type
  from optimum.quanto import freeze, qfloat8, qint4 , qint8, quantize, QModuleMixin, QTensor, quantize_module


-
+ shared_state = {}

  mmm = safetensors2.mmm

@@ -154,33 +154,75 @@ def _get_max_reservable_memory(perc_reserved_mem_max):
  perc_reserved_mem_max = 0.40 if os.name == 'nt' else 0.5
  return perc_reserved_mem_max * physical_memory

- def _detect_main_towers(model, verboseLevel=1):
+ def _detect_main_towers(model, min_floors = 5, verboseLevel=1):
  cur_blocks_prefix = None
  towers_modules= []
  towers_names= []

+ floors_modules= []
+ tower_name = None
+
+
  for submodule_name, submodule in model.named_modules():
+
  if submodule_name=='':
  continue

- if isinstance(submodule, torch.nn.ModuleList):
- newList =False
- if cur_blocks_prefix == None:
- cur_blocks_prefix = submodule_name + "."
- newList = True
- else:
- if not submodule_name.startswith(cur_blocks_prefix):
- cur_blocks_prefix = submodule_name + "."
- newList = True
+ if cur_blocks_prefix != None:
+ if submodule_name.startswith(cur_blocks_prefix):
+ depth_prefix = cur_blocks_prefix.split(".")
+ depth_name = submodule_name.split(".")
+ level = depth_name[len(depth_prefix)-1]
+ pre , num = _extract_num_from_str(level)

- if newList and len(submodule)>=5:
- towers_names.append(submodule_name)
- towers_modules.append(submodule)
+ if num != cur_blocks_seq:
+ floors_modules.append(submodule)
+
+ cur_blocks_seq = num
+ else:
+ if len(floors_modules) >= min_floors:
+ towers_modules += floors_modules
+ towers_names.append(tower_name)
+ tower_name = None
+ floors_modules= []
+ cur_blocks_prefix, cur_blocks_seq = None, -1
+
+ if cur_blocks_prefix == None:
+ pre , num = _extract_num_from_str(submodule_name)
+ if isinstance(submodule, (torch.nn.ModuleList)):
+ cur_blocks_prefix, cur_blocks_seq = pre + ".", -1
+ tower_name = submodule_name + ".*"
+ elif num >=0:
+ cur_blocks_prefix, cur_blocks_seq = pre, num
+ tower_name = submodule_name[ :-1] + "*"
+ floors_modules.append(submodule)
+
+ if len(floors_modules) >= min_floors:
+ towers_modules += floors_modules
+ towers_names.append(tower_name)
+
+ # for submodule_name, submodule in model.named_modules():
+ # if submodule_name=='':
+ # continue
+
+ # if isinstance(submodule, torch.nn.ModuleList):
+ # newList =False
+ # if cur_blocks_prefix == None:
+ # cur_blocks_prefix = submodule_name + "."
+ # newList = True
+ # else:
+ # if not submodule_name.startswith(cur_blocks_prefix):
+ # cur_blocks_prefix = submodule_name + "."
+ # newList = True
+
+ # if newList and len(submodule)>=5:
+ # towers_names.append(submodule_name)
+ # towers_modules.append(submodule)

- else:
- if cur_blocks_prefix is not None:
- if not submodule_name.startswith(cur_blocks_prefix):
- cur_blocks_prefix = None
+ # else:
+ # if cur_blocks_prefix is not None:
+ # if not submodule_name.startswith(cur_blocks_prefix):
+ # cur_blocks_prefix = None

  return towers_names, towers_modules
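
The reworked _detect_main_towers above no longer keys on ModuleList instances alone: it groups consecutive submodules whose names share a prefix and end in an increasing number ("floors"), and only keeps groups with at least min_floors entries ("towers"). A minimal sketch of the expected behaviour on a toy model (not part of the diff; it assumes the private helper keeps the signature shown above):

import torch
from mmgp import offload

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(16, 16)                                # not repeated, ignored
        self.blocks = torch.nn.ModuleList([ToyBlock() for _ in range(8)])   # eight numbered "floors"

names, modules = offload._detect_main_towers(ToyModel(), min_floors=5)
print(names)    # expected: ['blocks.*']  (the only group with at least 5 numbered floors)
print(modules)  # expected: the ToyBlock instances gathered from blocks.0 ... blocks.7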
 
@@ -194,7 +236,7 @@ def _get_model(model_path):
  _path = Path(model_path).parts
  _filename = _path[-1]
  _path = _path[:-1]
- if len(_path)==1:
+ if len(_path)<=1:
  raise("file not found")
  else:
  from huggingface_hub import hf_hub_download #snapshot_download,
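
For context on the relaxed len(_path)<=1 check: _get_model treats a path with at least two leading components as a Hugging Face Hub reference and downloads the file, otherwise it raises. A rough sketch of that pattern (illustrative only, with a made-up repository name; the exact download arguments are not shown in this hunk):

from pathlib import Path
from huggingface_hub import hf_hub_download

model_path = "DeepBeepMeep/SomeRepo/model.safetensors"   # hypothetical Hub-style reference
parts = Path(model_path).parts
filename, repo_parts = parts[-1], parts[:-1]
if len(repo_parts) <= 1:
    raise Exception("file not found")                     # a bare file name cannot be resolved
local_file = hf_hub_download(repo_id="/".join(repo_parts), filename=filename)
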
@@ -219,6 +261,29 @@ def _remove_model_wrapper(model):
  return sub_module
  return model

+ # def force_load_tensor(t):
+ # c = torch.nn.Parameter(t + 0)
+ # torch.utils.swap_tensors(t, c)
+ # del c
+
+
+ # for n,m in model_to_quantize.named_modules():
+ # # do not read quantized weights (detected them directly or behind an adapter)
+ # if isinstance(m, QModuleMixin) or hasattr(m, "base_layer") and isinstance(m.base_layer, QModuleMixin):
+ # if hasattr(m, "bias") and m.bias is not None:
+ # force_load_tensor(m.bias.data)
+ # # m.bias.data = m.bias.data + 0
+ # else:
+ # for n, p in m.named_parameters(recurse = False):
+ # data = getattr(m, n)
+ # force_load_tensor(data)
+ # # setattr(m,n, torch.nn.Parameter(data + 0 ) )
+
+ # for b in m.buffers(recurse = False):
+ # # b.data = b.data + 0
+ # b.data = torch.nn.Buffer(b.data + 0)
+ # force_load_tensor(b.data)
+


  def _move_to_pinned_tensor(source_tensor, big_tensor, offset, length):
@@ -248,6 +313,17 @@ def _safetensors_load_file(file_path):

  return sd, metadata

+ def _force_load_buffer(p):
+ # To do : check if buffer was persistent and transfer state, or maybe swap keep already this property ?
+ q = torch.nn.Buffer(p + 0)
+ torch.utils.swap_tensors(p, q)
+ del q
+
+ def _force_load_parameter(p):
+ q = torch.nn.Parameter(p + 0)
+ torch.utils.swap_tensors(p, q)
+ del q
+
  def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_max = 0, verboseLevel = 1):
  if verboseLevel>=1 :
  if partialPinning:
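
_force_load_buffer and _force_load_parameter exist to materialize tensors that are still lazily memory-mapped from a safetensors file: "p + 0" builds a real in-RAM copy and torch.utils.swap_tensors moves that copy into the tensor object the module already references, so the mmapped storage can be released. A small illustration of the mechanism (not from the package; torch.utils.swap_tensors needs a recent PyTorch, and torch.nn.Buffer an even more recent one):

import torch

lin = torch.nn.Linear(4, 4)
p = lin.weight                     # stand-in for a lazily mmapped parameter
q = torch.nn.Parameter(p + 0)      # "+ 0" forces a real, detached copy in RAM
torch.utils.swap_tensors(p, q)     # swap contents in place: lin.weight (still the
                                   # same Python object) now holds the fresh copy
del q                              # the old storage is only referenced by q; dropping
                                   # it lets the mmapped backing be released
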
@@ -260,6 +336,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma
  towers_names, _ = _detect_main_towers(model)
  towers_names = [n +"." for n in towers_names]

+
  BIG_TENSOR_MAX_SIZE = 2**28 # 256 MB
  current_big_tensor_size = 0
  big_tensor_no = 0
@@ -273,10 +350,10 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma
  if partialPinning:
  include = any(k.startswith(pre) for pre in towers_names) if partialPinning else True
  if include:
- params_list = params_list + list(sub_module.buffers(recurse=False)) + list(sub_module.parameters(recurse=False))
+ params_list = params_list + [ (k + '.' + n, p, False) for n, p in sub_module.named_parameters(recurse=False)] + [ (k + '.' + n, p, True) for n, p in sub_module.named_buffers(recurse=False)]

- # print(f"num params to pin {model_id}: {len(params_list)}")
- for p in params_list:
+
+ for n, p, _ in params_list:
  if isinstance(p, QTensor):
  if p._qtype == qint4:
  if hasattr(p,"_scale_shift"):
@@ -288,10 +365,16 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma
  else:
  length = torch.numel(p.data) * p.data.element_size()

+
  if current_big_tensor_size + length > BIG_TENSOR_MAX_SIZE:
  big_tensors_sizes.append(current_big_tensor_size)
  current_big_tensor_size = 0
  big_tensor_no += 1
+
+
+ itemsize = p.data.dtype.itemsize
+ if current_big_tensor_size % itemsize:
+ current_big_tensor_size += itemsize - current_big_tensor_size % itemsize
  tensor_map_indexes.append((big_tensor_no, current_big_tensor_size, length ))
  current_big_tensor_size += length

@@ -320,12 +403,18 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma

  gc.collect()

+
  tensor_no = 0
- for p in params_list:
+ # prev_big_tensor = 0
+ for n, p, is_buffer in params_list:
  big_tensor_no, offset, length = tensor_map_indexes[tensor_no]
-
+ # if big_tensor_no != prev_big_tensor:
+ # gc.collect()
+ # prev_big_tensor = big_tensor_no
  if big_tensor_no>=0 and big_tensor_no < last_big_tensor:
  current_big_tensor = big_tensors[big_tensor_no]
+ if is_buffer :
+ _force_load_buffer(p) # otherwise potential memory leak
  if isinstance(p, QTensor):
  if p._qtype == qint4:
  length1 = torch.numel(p._data._data) * p._data._data.element_size()
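
The pinning loop above packs every parameter and buffer into a handful of large page-locked ("pinned") host buffers, and now aligns each offset to the element size of the tensor that will land there. Pinned memory is what makes the later non_blocking copies to the GPU genuinely asynchronous. A condensed illustration of the packing and alignment idea (not the package's code; shapes and sizes are made up):

import torch

BIG = torch.empty(1 << 20, dtype=torch.uint8, pin_memory=True)   # one pinned block

def align_up(offset, itemsize):
    # round the running offset up to the next multiple of the element size
    return offset if offset % itemsize == 0 else offset + itemsize - offset % itemsize

src = torch.randn(1024, dtype=torch.float16)
offset = align_up(123, src.element_size())                        # 123 -> 124 for fp16
length = src.numel() * src.element_size()
view = BIG[offset:offset + length].view(torch.float16).view(src.shape)
view.copy_(src)                                                   # tensor now lives in pinned RAM
if torch.cuda.is_available():
    gpu_tensor = view.to("cuda", non_blocking=True)               # true async host-to-device copy
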
@@ -353,7 +442,7 @@ def _pin_to_memory(model, model_id, partialPinning = False, perc_reserved_mem_ma
  gc.collect()

  if verboseLevel >=1:
- if total_tensor_bytes == total:
+ if total_tensor_bytes <= total:
  print(f"The whole model was pinned to reserved RAM: {last_big_tensor} large blocks spread across {total/ONE_MB:.2f} MB")
  else:
  print(f"{total/ONE_MB:.2f} MB were pinned to reserved RAM out of {total_tensor_bytes/ONE_MB:.2f} MB")
@@ -369,8 +458,16 @@ def _welcome():
  if welcome_displayed:
  return
  welcome_displayed = True
- print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 3.1) by DeepBeepMeep ************{ENDC}{UNBOLD}")

+ def _extract_num_from_str(num_in_str):
+ for i in range(len(num_in_str)):
+ if not num_in_str[-i-1:].isnumeric():
+ if i == 0:
+ return num_in_str, -1
+ else:
+ return num_in_str[: -i], int(num_in_str[-i:])
+ return "", int(num_in_str)

  def _quantize_dirty_hack(model):
  # dirty hack: add a hook on state_dict() to return a fake non quantized state_dict if called by Lora Diffusers initialization functions
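
_extract_num_from_str is the small parser that the new tower detection (and, through it, the quantizer) relies on: it splits a trailing integer off a name segment and returns -1 when there is none. Expected behaviour (illustrative calls, not part of the diff):

from mmgp import offload

print(offload._extract_num_from_str("blocks.12"))      # ('blocks.', 12)
print(offload._extract_num_from_str("double_block7"))  # ('double_block', 7)
print(offload._extract_num_from_str("norm_out"))       # ('norm_out', -1)  no trailing number
print(offload._extract_num_from_str("12"))             # ('', 12)          all digits
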
@@ -479,55 +576,56 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
  if hasattr(model_to_quantize, "_quanto_map"):
  for k, entry in model_to_quantize._quanto_map.items():
  weights = entry["weights"]
- print(f"Model '{model_id}' is already quantized to format '{weights}'")
+ print(f"Model '{model_id}' is already quantized in format '{weights}'")
  return False
  print(f"Model '{model_id}' is already quantized")
  return False

  print(f"Quantization of model '{model_id}' started to format '{weights}'")

+ tower_names ,_ = _detect_main_towers(model_to_quantize)
+ tower_names = [ n[:-1] for n in tower_names]
+
  for submodule_name, submodule in model_to_quantize.named_modules():
  if isinstance(submodule, QModuleMixin):
  if verboseLevel>=1:
  print("No quantization to do as model is already quantized")
  return False

-
  if submodule_name=='':
  continue

-
- flush = False
- if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
- if cur_blocks_prefix == None:
- cur_blocks_prefix = submodule_name + "."
- flush = True
- else:
- #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
- if not submodule_name.startswith(cur_blocks_prefix):
+ size = compute_submodule_size(submodule)
+ if not any(submodule_name.startswith(pre) for pre in tower_names):
+ flush = False
+ if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+ if cur_blocks_prefix == None:
  cur_blocks_prefix = submodule_name + "."
  flush = True
- else:
- if cur_blocks_prefix is not None:
- #if not cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
- if not submodule_name.startswith(cur_blocks_prefix):
- cur_blocks_prefix = None
- flush = True
-
- if flush:
- if submodule_size <= threshold:
- exclude_list += submodule_names
- if verboseLevel >=2:
- print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
- total_excluded += submodule_size
-
- submodule_size = 0
- submodule_names = []
- prev_blocks_prefix = cur_blocks_prefix
- size = compute_submodule_size(submodule)
- submodule_size += size
+ else:
+ if not submodule_name.startswith(cur_blocks_prefix):
+ cur_blocks_prefix = submodule_name + "."
+ flush = True
+ else:
+ if cur_blocks_prefix is not None:
+ #if not cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
+ if not submodule_name.startswith(cur_blocks_prefix):
+ cur_blocks_prefix = None
+ flush = True
+
+ if flush :
+ if submodule_size <= threshold :
+ exclude_list += submodule_names
+ if verboseLevel >=2:
+ print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
+ total_excluded += submodule_size
+
+ submodule_size = 0
+ submodule_names = []
+ prev_blocks_prefix = cur_blocks_prefix
+ submodule_size += size
+ submodule_names.append(submodule_name)
  total_size += size
- submodule_names.append(submodule_name)

  if submodule_size > 0 and submodule_size <= threshold:
  exclude_list += submodule_names
@@ -543,28 +641,29 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
  exclude_list = None


- #quantize(model_to_quantize,weights, exclude= exclude_list)
+ quantize(model_to_quantize,weights, exclude= exclude_list)
+ # quantize(model_to_quantize,weights, include= [ "*1.block.attn.to_out*"]) #"
+
+ # for name, m in model_to_quantize.named_modules():
+ # if exclude_list is None or not any( name == module_name for module_name in exclude_list):
+ # _quantize_submodule(model_to_quantize, name, m, weights=weights, activations=None, optimizer=None)

- for name, m in model_to_quantize.named_modules():
- if exclude_list is None or not any( name == module_name for module_name in exclude_list):
- _quantize_submodule(model_to_quantize, name, m, weights=weights, activations=None, optimizer=None)

  # force to read non quantized parameters so that their lazy tensors and corresponding mmap are released
  # otherwise we may end up keeping in memory both the quantized and the non quantize model
- for m in model_to_quantize.modules():
+ for n,m in model_to_quantize.named_modules():
  # do not read quantized weights (detected them directly or behind an adapter)
  if isinstance(m, QModuleMixin) or hasattr(m, "base_layer") and isinstance(m.base_layer, QModuleMixin):
  if hasattr(m, "bias") and m.bias is not None:
- m.bias.data = m.bias.data + 0
+ _force_load_parameter(m.bias)
  else:
- for n, p in m.named_parameters(recurse = False):
- data = getattr(m, n)
- setattr(m,n, torch.nn.Parameter(data + 0 ) )
+ for p in m.parameters(recurse = False):
+ _force_load_parameter(p)

  for b in m.buffers(recurse = False):
- b.data = b.data + 0
+ _force_load_buffer(b)
+

-

  freeze(model_to_quantize)
  torch.cuda.empty_cache()
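
Net effect of the quantizer changes: groups of small submodules that sit outside the detected towers and whose accumulated size is below the threshold are collected into exclude_list, which is now passed to optimum-quanto's module-level quantize() (the per-module _quantize_submodule loop is commented out), and freeze() then materializes the quantized weights. The underlying optimum-quanto pattern, reduced to a toy example (the exclude name below is just the Sequential index of the small layer):

import torch
from optimum.quanto import quantize, freeze, qint8

model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),   # large layer: gets qint8 weights
    torch.nn.Linear(256, 4),     # small layer: kept in full precision
)
quantize(model, weights=qint8, exclude=["1"])   # exclude takes module-name patterns
freeze(model)                                   # replace the original weights with quantized ones
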
@@ -581,595 +680,609 @@ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 10
581
680
 
582
681
  return True
583
682
 
584
- def get_model_name(model):
585
- return model.name
683
+ def load_loras_into_model(model, lora_path, lora_multi = None, verboseLevel = -1):
684
+ verboseLevel = _compute_verbose_level(verboseLevel)
586
685
 
587
- class HfHook:
588
- def __init__(self):
589
- self.execution_device = "cuda"
686
+ if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
687
+ raise Exception("Unable to load Lora, missing 'peft' and / or 'diffusers' modules")
688
+
689
+ if not isinstance(lora_path, list):
690
+ lora_path = [lora_path]
691
+
692
+ if lora_multi is None:
693
+ lora_multi = [1. for _ in lora_path]
590
694
 
591
- def detach_hook(self, module):
592
- pass
695
+ for i, path in enumerate(lora_path):
696
+ adapter_name = str(i)
593
697
 
594
- last_offload_obj = None
595
- class offload:
596
- def __init__(self):
597
- self.active_models = []
598
- self.active_models_ids = []
599
- self.active_subcaches = {}
600
- self.models = {}
601
- self.verboseLevel = 0
602
- self.blocks_of_modules = {}
603
- self.blocks_of_modules_sizes = {}
604
- self.anyCompiledModule = False
605
- self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
606
- self.last_reserved_mem_check =0
607
- self.loaded_blocks = {}
608
- self.prev_blocks_names = {}
609
- self.next_blocks_names = {}
610
- self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
611
- self.transfer_stream = torch.cuda.Stream()
612
- self.async_transfers = False
613
- global last_offload_obj
614
- last_offload_obj = self
698
+ state_dict = safetensors2.torch_load_file(path)
615
699
 
616
- def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
700
+ keys = list(state_dict.keys())
701
+ if len(keys) == 0:
702
+ raise Exception(f"Empty Lora '{path}'")
617
703
 
618
- entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
619
- if entry_name in self.blocks_of_modules:
620
- blocks_params = self.blocks_of_modules[entry_name]
621
- blocks_params_size = self.blocks_of_modules_sizes[entry_name]
622
- else:
623
- blocks_params = []
624
- self.blocks_of_modules[entry_name] = blocks_params
625
- blocks_params_size = 0
626
- if blocks_name !=None:
627
704
 
628
- prev_entry_name = None if prev_block_name == None else model_id + "/" + prev_block_name
629
- self.prev_blocks_names[entry_name] = prev_entry_name
630
- if not prev_block_name == None:
631
- self.next_blocks_names[prev_entry_name] = entry_name
705
+ network_alphas = {}
706
+ for k in keys:
707
+ if "alpha" in k:
708
+ alpha_value = state_dict.pop(k)
709
+ if not ( (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
710
+ alpha_value, float
711
+ )):
712
+ network_alphas[k] = torch.tensor( float(alpha_value.item() ) )
632
713
 
714
+ pos = keys[0].find(".")
715
+ prefix = keys[0][0:pos]
716
+ if not any( prefix.startswith(some_prefix) for some_prefix in ["diffusion_model", "transformer"]):
717
+ msg = f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format."
718
+ raise Exception(msg)
633
719
 
634
- for k,p in submodule.named_parameters(recurse=False):
635
- if isinstance(p, QTensor):
636
- blocks_params.append( (submodule, k, p ) )
720
+ transformer = model
637
721
 
638
- if p._qtype == qint4:
639
- if hasattr(p,"_scale_shift"):
640
- blocks_params_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
641
- blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
642
- else:
643
- blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
644
- blocks_params_size += torch.numel(p._shift) * p._shift.element_size()
645
- blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
646
- else:
647
- blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
648
- blocks_params_size += torch.numel(p._data) * p._data.element_size()
649
- else:
650
- blocks_params.append( (submodule, k, p ) )
651
- blocks_params_size += torch.numel(p.data) * p.data.element_size()
722
+ transformer_keys = [k for k in keys if k.startswith(prefix)]
723
+ state_dict = {
724
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys
725
+ }
652
726
 
653
- for k, p in submodule.named_buffers(recurse=False):
654
- blocks_params.append( (submodule, k, p) )
655
- blocks_params_size += p.data.nbytes
727
+ sd_keys = state_dict.keys()
728
+ if len(sd_keys) == 0:
729
+ print(f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format.")
730
+ return
656
731
 
732
+ # is_correct_format = all("lora" in key for key in state_dict.keys())
657
733
 
658
- self.blocks_of_modules_sizes[entry_name] = blocks_params_size
659
734
 
660
- return blocks_params_size
661
735
 
662
736
 
663
- def can_model_be_cotenant(self, model_id):
664
- potential_cotenants= cotenants_map.get(model_id, None)
665
- if potential_cotenants is None:
666
- return False
667
- for existing_cotenant in self.active_models_ids:
668
- if existing_cotenant not in potential_cotenants:
669
- return False
670
- return True
737
+ # check with first key if is not in peft format
738
+ # first_key = next(iter(state_dict.keys()))
739
+ # if "lora_A" not in first_key:
740
+ # state_dict = convert_unet_state_dict_to_peft(state_dict)
671
741
 
742
+ if adapter_name in getattr(transformer, "peft_config", {}):
743
+ raise ValueError(
744
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
745
+ )
672
746
 
673
- def gpu_load_blocks(self, model_id, blocks_name):
674
- # cl = clock.start()
747
+ rank = {}
748
+ for key, val in state_dict.items():
749
+ if "lora_B" in key:
750
+ rank[key] = val.shape[1]
675
751
 
676
- if blocks_name != None:
677
- self.loaded_blocks[model_id] = blocks_name
752
+ if network_alphas is not None and len(network_alphas) >= 1:
753
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
754
+ network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
678
755
 
679
- entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
756
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
680
757
 
681
- def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None
682
- with torch.cuda.stream(stream_to_use):
683
- for param in blocks_params:
684
- parent_module, n, p = param
685
- q = p.to("cuda", non_blocking=True)
686
- q = torch.nn.Parameter(q , requires_grad=False)
687
- setattr(parent_module, n , q)
688
- # if record_for_stream != None:
689
- # if isinstance(p, QTensor):
690
- # q._data.record_stream(record_for_stream)
691
- # q._scale.record_stream(record_for_stream)
692
- # else:
693
- # p.data.record_stream(record_for_stream)
758
+ lora_config = LoraConfig(**lora_config_kwargs)
759
+ peft_kwargs = {}
760
+ peft_kwargs["low_cpu_mem_usage"] = True
761
+ inject_adapter_in_model(lora_config, model, adapter_name=adapter_name, **peft_kwargs)
694
762
 
763
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
695
764
 
696
- if self.verboseLevel >=2:
697
- model = self.models[model_id]
698
- model_name = model._get_name()
699
- print(f"Loading model {entry_name} ({model_name}) in GPU")
700
-
765
+ warn_msg = ""
766
+ if incompatible_keys is not None:
767
+ # Check only for unexpected keys.
768
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
769
+ if unexpected_keys:
770
+ pass
771
+ if verboseLevel >=1:
772
+ print(f"Lora '{path}' was loaded in model '{_get_module_name(model)}'")
773
+ set_weights_and_activate_adapters(model,[ str(i) for i in range(len(lora_multi))], lora_multi)
701
774
 
702
- if self.async_transfers and blocks_name != None:
703
- first = self.prev_blocks_names[entry_name] == None
704
- next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
705
- if first:
706
- cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
707
- torch.cuda.synchronize()
775
+ def move_loras_to_device(model, device="cpu" ):
776
+ if hasattr( model, "_lora_loadable_modules"):
777
+ for k in model._lora_loadable_modules:
778
+ move_loras_to_device(getattr(model,k), device)
779
+ return
780
+
781
+ for k, m in model.named_modules():
782
+ if ".lora_" in k:
783
+ m.to(device)
708
784
 
709
- if next_blocks_entry != None:
710
- cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry]) #, self.default_stream
785
+ def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
786
+ """
787
+ quick version of .LoadfromPretrained of the transformers library
788
+ used to build a model and load the corresponding weights (quantized or not)
789
+ """
711
790
 
712
- else:
713
- cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
714
- torch.cuda.synchronize()
715
- # cl.stop()
716
- # print(f"load time: {cl.format_time_gap()}")
791
+
792
+ import os.path
793
+ from accelerate import init_empty_weights
794
+
795
+ if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
796
+ raise Exception("full model path to file expected")
717
797
 
798
+ model_path = _get_model(model_path)
799
+ verboseLevel = _compute_verbose_level(verboseLevel)
718
800
 
719
- def gpu_unload_blocks(self, model_id, blocks_name):
720
- # cl = clock.start()
721
- if blocks_name != None:
722
- self.loaded_blocks[model_id] = None
801
+ with safetensors2.safe_open(model_path) as f:
802
+ metadata = f.metadata()
723
803
 
724
- blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
804
+ if metadata is None:
805
+ transformer_config = None
806
+ else:
807
+ transformer_config = metadata.get("config", None)
725
808
 
726
- if self.verboseLevel >=2:
727
- model = self.models[model_id]
728
- model_name = model._get_name()
729
- print(f"Unloading model {blocks_name} ({model_name}) from GPU")
730
-
731
- blocks_params = self.blocks_of_modules[blocks_name]
732
- for param in blocks_params:
733
- parent_module, n, p = param
734
- q = torch.nn.Parameter(p , requires_grad=False)
735
- setattr(parent_module, n , q)
736
- # cl.stop()
737
- # print(f"unload time: {cl.format_time_gap()}")
809
+ if transformer_config == None:
810
+ config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
738
811
 
812
+ if not os.path.isfile(config_fullpath):
813
+ raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
739
814
 
740
- def gpu_load(self, model_id):
741
- model = self.models[model_id]
742
- self.active_models.append(model)
743
- self.active_models_ids.append(model_id)
815
+ with open(config_fullpath, "r", encoding="utf-8") as reader:
816
+ text = reader.read()
817
+ transformer_config= json.loads(text)
744
818
 
745
- self.gpu_load_blocks(model_id, None)
746
819
 
747
- # torch.cuda.current_stream().synchronize()
820
+ if "architectures" in transformer_config:
821
+ architectures = transformer_config["architectures"]
822
+ class_name = architectures[0]
748
823
 
749
- def unload_all(self):
750
- for model_id in self.active_models_ids:
751
- self.gpu_unload_blocks(model_id, None)
752
- loaded_block = self.loaded_blocks[model_id]
753
- if loaded_block != None:
754
- self.gpu_unload_blocks(model_id, loaded_block)
755
- self.loaded_blocks[model_id] = None
756
-
757
- self.active_models = []
758
- self.active_models_ids = []
759
- self.active_subcaches = []
760
- torch.cuda.empty_cache()
761
- gc.collect()
762
- self.last_reserved_mem_check = time.time()
824
+ module = __import__("transformers")
825
+ map = { "T5WithLMHeadModel" : "T5EncoderModel"}
826
+ class_name = map.get(class_name, class_name)
827
+ transfomer_class = getattr(module, class_name)
828
+ from transformers import AutoConfig
763
829
 
764
- def move_args_to_gpu(self, *args, **kwargs):
765
- new_args= []
766
- new_kwargs={}
767
- for arg in args:
768
- if torch.is_tensor(arg):
769
- if arg.dtype == torch.float32:
770
- arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
771
- elif not arg.is_cuda:
772
- arg = arg.cuda(non_blocking=True)
773
- new_args.append(arg)
830
+ import tempfile
831
+ with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
832
+ fp.write(json.dumps(transformer_config))
833
+ fp.close()
834
+ config_obj = AutoConfig.from_pretrained(fp.name)
835
+ os.remove(fp.name)
774
836
 
775
- for k in kwargs:
776
- arg = kwargs[k]
777
- if torch.is_tensor(arg):
778
- if arg.dtype == torch.float32:
779
- arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
780
- elif not arg.is_cuda:
781
- arg = arg.cuda(non_blocking=True)
782
- new_kwargs[k]= arg
783
-
784
- return new_args, new_kwargs
837
+ #needed to keep inits of non persistent buffers
838
+ with init_empty_weights():
839
+ model = transfomer_class(config_obj)
840
+
841
+ model = model.base_model
785
842
 
786
- def ready_to_check_mem(self):
787
- if self.anyCompiledModule:
788
- return
789
- cur_clock = time.time()
790
- # can't check at each call if we can empty the cuda cache as quering the reserved memory value is a time consuming operation
791
- if (cur_clock - self.last_reserved_mem_check)<0.200:
792
- return False
793
- self.last_reserved_mem_check = cur_clock
794
- return True
843
+ elif "_class_name" in transformer_config:
844
+ class_name = transformer_config["_class_name"]
795
845
 
846
+ module = __import__("diffusers")
847
+ transfomer_class = getattr(module, class_name)
796
848
 
797
- def empty_cache_if_needed(self):
798
- mem_reserved = torch.cuda.memory_reserved()
799
- mem_threshold = 0.9*self.device_mem_capacity
800
- if mem_reserved >= mem_threshold:
801
- mem_allocated = torch.cuda.memory_allocated()
802
- if mem_allocated <= 0.70 * mem_reserved:
803
- # print(f"Cuda empty cache triggered as Allocated Memory ({mem_allocated/1024000:0f} MB) is lot less than Cached Memory ({mem_reserved/1024000:0f} MB) ")
804
- torch.cuda.empty_cache()
805
- tm= time.time()
806
- if self.verboseLevel >=2:
807
- print(f"Empty Cuda cache at {tm}")
808
- # print(f"New cached memory after purge is {torch.cuda.memory_reserved()/1024000:0f} MB) ")
849
+ with init_empty_weights():
850
+ model = transfomer_class.from_config(transformer_config)
809
851
 
810
852
 
811
- def any_param_or_buffer(self, target_module: torch.nn.Module):
812
-
813
- for _ in target_module.parameters(recurse= False):
814
- return True
815
-
816
- for _ in target_module.buffers(recurse= False):
817
- return True
818
-
819
- return False
853
+ torch.set_default_device('cpu')
820
854
 
821
- def hook_load_data_if_needed(self, target_module, model_id,blocks_name, context):
855
+ model._config = transformer_config
856
+
857
+ load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verboseLevel )
822
858
 
823
- @torch.compiler.disable()
824
- def load_data_if_needed(module, *args, **kwargs):
825
- some_context = context #for debugging
826
- if blocks_name == None:
827
- if self.ready_to_check_mem():
828
- self.empty_cache_if_needed()
829
- else:
830
- loaded_block = self.loaded_blocks[model_id]
831
- if (loaded_block == None or loaded_block != blocks_name) :
832
- if loaded_block != None:
833
- self.gpu_unload_blocks(model_id, loaded_block)
834
- if self.ready_to_check_mem():
835
- self.empty_cache_if_needed()
836
- self.loaded_blocks[model_id] = blocks_name
837
- self.gpu_load_blocks(model_id, blocks_name)
859
+ return model
838
860
 
839
- target_module.register_forward_pre_hook(load_data_if_needed)
840
861
 
841
862
 
842
- def hook_check_empty_cache_needed(self, target_module, model_id,blocks_name, previous_method, context):
863
+ def load_model_data(model, file_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
864
+ """
865
+ Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
866
+ """
843
867
 
844
- qint4quantization = isinstance(target_module, QModuleMixin) and target_module.weight!= None and target_module.weight.qtype == qint4
845
- if qint4quantization:
846
- pass
868
+ file_path = _get_model(file_path)
869
+ verboseLevel = _compute_verbose_level(verboseLevel)
847
870
 
848
- def check_empty_cuda_cache(module, *args, **kwargs):
849
- # if self.ready_to_check_mem():
850
- # self.empty_cache_if_needed()
851
- if blocks_name == None:
852
- if self.ready_to_check_mem():
853
- self.empty_cache_if_needed()
854
- else:
855
- loaded_block = self.loaded_blocks[model_id]
856
- if (loaded_block == None or loaded_block != blocks_name) :
857
- if loaded_block != None:
858
- self.gpu_unload_blocks(model_id, loaded_block)
859
- if self.ready_to_check_mem():
860
- self.empty_cache_if_needed()
861
- self.loaded_blocks[model_id] = blocks_name
862
- self.gpu_load_blocks(model_id, blocks_name)
863
- if qint4quantization:
864
- args, kwargs = self.move_args_to_gpu(*args, **kwargs)
871
+ model = _remove_model_wrapper(model)
865
872
 
866
- return previous_method(*args, **kwargs)
873
+ if not (".safetensors" in file_path or ".sft" in file_path):
874
+ if pinToMemory:
875
+ raise Exception("Pinning to memory while loading only supported for safe tensors files")
876
+ state_dict = torch.load(file_path, weights_only=True)
877
+ if "module" in state_dict:
878
+ state_dict = state_dict["module"]
879
+ else:
880
+ state_dict, metadata = _safetensors_load_file(file_path)
881
+
882
+ if metadata is None:
883
+ quantization_map = None
884
+ else:
885
+ quantization_map = metadata.get("quantization_map", None)
886
+ config = metadata.get("config", None)
887
+ if config is not None:
888
+ model._config = config
867
889
 
868
890
 
869
- if hasattr(target_module, "_mm_id"):
870
- orig_model_id = getattr(target_module, "_mm_id")
871
- if self.verboseLevel >=2:
872
- print(f"Model '{model_id}' shares module '{target_module._get_name()}' with module '{orig_model_id}' ")
873
- assert not self.any_param_or_buffer(target_module)
874
891
 
875
- return
876
- setattr(target_module, "_mm_id", model_id)
877
- setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method) )
892
+ if quantization_map is None:
893
+ pos = str.rfind(file_path, ".")
894
+ if pos > 0:
895
+ quantization_map_path = file_path[:pos]
896
+ quantization_map_path += "_map.json"
878
897
 
879
-
880
- def hook_change_module(self, target_module, model, model_id, module_id, previous_method):
881
- def check_change_module(module, *args, **kwargs):
882
- performEmptyCacheTest = False
883
- if not model_id in self.active_models_ids:
884
- new_model_id = getattr(module, "_mm_id")
885
- # do not always unload existing models if it is more efficient to keep in them in the GPU
886
- # (e.g: small modules whose calls are text encoders)
887
- if not self.can_model_be_cotenant(new_model_id) :
888
- self.unload_all()
889
- performEmptyCacheTest = False
890
- self.gpu_load(new_model_id)
891
- # transfer leftovers inputs that were incorrectly created in the RAM (mostly due to some .device tests that returned incorrectly "cpu")
892
- args, kwargs = self.move_args_to_gpu(*args, **kwargs)
893
- if performEmptyCacheTest:
894
- self.empty_cache_if_needed()
895
-
896
- return previous_method(*args, **kwargs)
897
-
898
- if hasattr(target_module, "_mm_id"):
899
- return
900
- setattr(target_module, "_mm_id", model_id)
898
+ if os.path.isfile(quantization_map_path):
899
+ with open(quantization_map_path, 'r') as f:
900
+ quantization_map = json.load(f)
901
901
 
902
- setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method) )
903
902
 
904
- if not self.verboseLevel >=1:
905
- return
906
903
 
907
- if module_id == None or module_id =='':
908
- model_name = model._get_name()
909
- print(f"Hooked in model '{model_id}' ({model_name})")
904
+ if quantization_map is None :
905
+ if "quanto" in file_path and not do_quantize:
906
+ print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
907
+ else:
908
+ _requantize(model, state_dict, quantization_map)
910
909
 
910
+ missing_keys , unexpected_keys = model.load_state_dict(state_dict, False, assign = True )
911
+ # if len(missing_keys) > 0:
912
+ # sd_crap = { k : None for k in missing_keys}
913
+ # missing_keys , unexpected_keys = model.load_state_dict(sd_crap, strict =False, assign = True )
914
+ del state_dict
911
915
 
912
- # Not implemented yet, but why would one want to get rid of these features ?
913
- # def unhook_module(module: torch.nn.Module):
914
- # if not hasattr(module,"_mm_id"):
915
- # return
916
-
917
- # delattr(module, "_mm_id")
918
-
919
- # def unhook_all(parent_module: torch.nn.Module):
920
- # for module in parent_module.components.items():
921
- # self.unhook_module(module)
916
+ for k,p in model.named_parameters():
917
+ if p.is_meta:
918
+ txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since parameter '{k}' has no data"
919
+ raise Exception(txt)
920
+ for k,b in model.named_buffers():
921
+ if b.is_meta:
922
+ txt = f"Incompatible State Dictionary or 'Init_Empty_Weights' not set since buffer '{k}' has no data"
923
+ raise Exception(txt)
922
924
 
923
- import torch
925
+ if do_quantize:
926
+ if quantization_map is None:
927
+ if _quantize(model, quantizationType, verboseLevel=verboseLevel, model_id=file_path):
928
+ quantization_map = model._quanto_map
929
+ else:
930
+ if verboseLevel >=1:
931
+ print("Model already quantized")
924
932
 
933
+ if pinToMemory:
934
+ _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
925
935
 
936
+ return
926
937
 
938
+ def get_model_name(model):
939
+ return model.name
927
940
 
928
- def load_loras_into_model(model, lora_path, lora_multi = None, verboseLevel = -1):
929
- verboseLevel = _compute_verbose_level(verboseLevel)
941
+ class HfHook:
942
+ def __init__(self):
943
+ self.execution_device = "cuda"
930
944
 
931
- if inject_adapter_in_model == None or set_weights_and_activate_adapters == None or get_peft_kwargs == None:
932
- raise Exception("Unable to load Lora, missing 'peft' and / or 'diffusers' modules")
933
-
934
- if not isinstance(lora_path, list):
935
- lora_path = [lora_path]
936
-
937
- if lora_multi is None:
938
- lora_multi = [1. for _ in lora_path]
945
+ def detach_hook(self, module):
946
+ pass
939
947
 
940
- for i, path in enumerate(lora_path):
941
- adapter_name = str(i)
948
+ last_offload_obj = None
949
+ class offload:
950
+ def __init__(self):
951
+ self.active_models = []
952
+ self.active_models_ids = []
953
+ self.active_subcaches = {}
954
+ self.models = {}
955
+ self.verboseLevel = 0
956
+ self.blocks_of_modules = {}
957
+ self.blocks_of_modules_sizes = {}
958
+ self.anyCompiledModule = False
959
+ self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
960
+ self.last_reserved_mem_check =0
961
+ self.loaded_blocks = {}
962
+ self.prev_blocks_names = {}
963
+ self.next_blocks_names = {}
964
+ self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
965
+ self.transfer_stream = torch.cuda.Stream()
966
+ self.async_transfers = False
967
+ global last_offload_obj
968
+ last_offload_obj = self
942
969
 
943
- state_dict = safetensors2.torch_load_file(path)
944
970
 
945
- keys = list(state_dict.keys())
946
- if len(keys) == 0:
947
- raise Exception(f"Empty Lora '{path}'")
948
-
971
+ def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
949
972
 
950
- network_alphas = {}
951
- for k in keys:
952
- if "alpha" in k:
953
- alpha_value = state_dict.pop(k)
954
- if not ( (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
955
- alpha_value, float
956
- )):
957
- network_alphas[k] = torch.tensor( float(alpha_value.item() ) )
973
+ entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
974
+ if entry_name in self.blocks_of_modules:
975
+ blocks_params = self.blocks_of_modules[entry_name]
976
+ blocks_params_size = self.blocks_of_modules_sizes[entry_name]
977
+ else:
978
+ blocks_params = []
979
+ self.blocks_of_modules[entry_name] = blocks_params
980
+ blocks_params_size = 0
981
+ if blocks_name !=None:
958
982
 
959
- pos = keys[0].find(".")
960
- prefix = keys[0][0:pos]
961
- if not any( prefix.startswith(some_prefix) for some_prefix in ["diffusion_model", "transformer"]):
962
- msg = f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format."
963
- raise Exception(msg)
983
+ prev_entry_name = None if prev_block_name == None else model_id + "/" + prev_block_name
984
+ self.prev_blocks_names[entry_name] = prev_entry_name
985
+ if not prev_block_name == None:
986
+ self.next_blocks_names[prev_entry_name] = entry_name
964
987
 
965
- transformer = model
966
988
 
967
- transformer_keys = [k for k in keys if k.startswith(prefix)]
968
- state_dict = {
969
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys
970
- }
989
+ for k,p in submodule.named_parameters(recurse=False):
990
+ if isinstance(p, QTensor):
991
+ blocks_params.append( (submodule, k, p, False ) )
971
992
 
972
- sd_keys = state_dict.keys()
973
- if len(sd_keys) == 0:
974
- print(f"No compatible weight was found in Lora file '{path}'. Please check that it is compatible with the Diffusers format.")
975
- return
993
+ if p._qtype == qint4:
994
+ if hasattr(p,"_scale_shift"):
995
+ blocks_params_size += torch.numel(p._scale_shift) * p._scale_shift.element_size()
996
+ blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
997
+ else:
998
+ blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
999
+ blocks_params_size += torch.numel(p._shift) * p._shift.element_size()
1000
+ blocks_params_size += torch.numel(p._data._data) * p._data._data.element_size()
1001
+ else:
1002
+ blocks_params_size += torch.numel(p._scale) * p._scale.element_size()
1003
+ blocks_params_size += torch.numel(p._data) * p._data.element_size()
1004
+ else:
1005
+ blocks_params.append( (submodule, k, p, False) )
1006
+ blocks_params_size += torch.numel(p.data) * p.data.element_size()
976
1007
 
977
- # is_correct_format = all("lora" in key for key in state_dict.keys())
1008
+ for k, p in submodule.named_buffers(recurse=False):
1009
+ blocks_params.append( (submodule, k, p, True) )
1010
+ blocks_params_size += p.data.nbytes
978
1011
 
979
1012
 
1013
+ self.blocks_of_modules_sizes[entry_name] = blocks_params_size
980
1014
 
1015
+ return blocks_params_size
981
1016
 
982
- # check with first key if is not in peft format
983
- # first_key = next(iter(state_dict.keys()))
984
- # if "lora_A" not in first_key:
985
- # state_dict = convert_unet_state_dict_to_peft(state_dict)
986
1017
 
987
- if adapter_name in getattr(transformer, "peft_config", {}):
988
- raise ValueError(
989
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
990
- )
1018
+ def can_model_be_cotenant(self, model_id):
1019
+ potential_cotenants= cotenants_map.get(model_id, None)
1020
+ if potential_cotenants is None:
1021
+ return False
1022
+ for existing_cotenant in self.active_models_ids:
1023
+ if existing_cotenant not in potential_cotenants:
1024
+ return False
1025
+ return True
991
1026
 
992
- rank = {}
993
- for key, val in state_dict.items():
994
- if "lora_B" in key:
995
- rank[key] = val.shape[1]
1027
+ @torch.compiler.disable()
1028
+ def gpu_load_blocks(self, model_id, blocks_name):
1029
+ # cl = clock.start()
996
1030
 
997
- if network_alphas is not None and len(network_alphas) >= 1:
998
- alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
999
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
1031
+ if blocks_name != None:
1032
+ self.loaded_blocks[model_id] = blocks_name
1000
1033
 
1001
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
1034
+ entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
1002
1035
 
1003
- lora_config = LoraConfig(**lora_config_kwargs)
1004
- peft_kwargs = {}
1005
- peft_kwargs["low_cpu_mem_usage"] = True
1006
- inject_adapter_in_model(lora_config, model, adapter_name=adapter_name, **peft_kwargs)
1007
-
1008
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
1009
-
1010
- warn_msg = ""
1011
- if incompatible_keys is not None:
1012
- # Check only for unexpected keys.
1013
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1014
- if unexpected_keys:
1015
- pass
1016
- if verboseLevel >=1:
1017
- print(f"Lora '{path}' was loaded in model '{_get_module_name(model)}'")
1018
- set_weights_and_activate_adapters(model,[ str(i) for i in range(len(lora_multi))], lora_multi)
1019
-
1036
+ def cpu_to_gpu(stream_to_use, blocks_params): #, record_for_stream = None
1037
+ with torch.cuda.stream(stream_to_use):
1038
+ for param in blocks_params:
1039
+ parent_module, n, p, is_buffer = param
1040
+ q = p.to("cuda", non_blocking=True)
1041
+ if is_buffer:
1042
+ q = torch.nn.Buffer(q)
1043
+ else:
1044
+ q = torch.nn.Parameter(q , requires_grad=False)
1045
+ setattr(parent_module, n , q)
1046
+ # if record_for_stream != None:
1047
+ # if isinstance(p, QTensor):
1048
+ # q._data.record_stream(record_for_stream)
1049
+ # q._scale.record_stream(record_for_stream)
1050
+ # else:
1051
+ # p.data.record_stream(record_for_stream)
1020
1052
 
1021
- def fast_load_transformers_model(model_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
1022
- """
1023
- quick version of .LoadfromPretrained of the transformers library
1024
- used to build a model and load the corresponding weights (quantized or not)
1025
- """
1026
1053
 
1027
-
1028
- import os.path
1029
- from accelerate import init_empty_weights
1054
+ if self.verboseLevel >=2:
1055
+ model = self.models[model_id]
1056
+ model_name = model._get_name()
1057
+ print(f"Loading model {entry_name} ({model_name}) in GPU")
1030
1058
 
1031
- if not (model_path.endswith(".sft") or model_path.endswith(".safetensors")):
1032
- raise Exception("full model path to file expected")
1033
-
1034
- model_path = _get_model(model_path)
1035
- verboseLevel = _compute_verbose_level(verboseLevel)
1036
-
1037
- with safetensors2.safe_open(model_path) as f:
1038
- metadata = f.metadata()
1039
-
1040
- if metadata is None:
1041
- transformer_config = None
1042
- else:
1043
- transformer_config = metadata.get("config", None)
1044
1059
 
1045
- if transformer_config == None:
1046
- config_fullpath = os.path.join(os.path.dirname(model_path), "config.json")
1047
-
1048
- if not os.path.isfile(config_fullpath):
1049
- raise Exception("a 'config.json' that describes the model is required in the directory of the model or inside the safetensor file")
1050
-
1051
- with open(config_fullpath, "r", encoding="utf-8") as reader:
1052
- text = reader.read()
1053
- transformer_config= json.loads(text)
1060
+ if self.async_transfers and blocks_name != None:
1061
+ first = self.prev_blocks_names[entry_name] == None
1062
+ next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
1063
+ if first:
1064
+ cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
1065
+ torch.cuda.synchronize()
1054
1066
 
1067
+ if next_blocks_entry != None:
1068
+ cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry]) #, self.default_stream
1055
1069
 
1056
- if "architectures" in transformer_config:
1057
- architectures = transformer_config["architectures"]
1058
- class_name = architectures[0]
1070
+ else:
1071
+ cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
1072
+ torch.cuda.synchronize()
1073
+ # cl.stop()
1074
+ # print(f"load time: {cl.format_time_gap()}")
1059
1075
 
1060
- module = __import__("transformers")
1061
- transfomer_class = getattr(module, class_name)
1062
- from transformers import AutoConfig
1076
+ @torch.compiler.disable()
1077
+ def gpu_unload_blocks(self, model_id, blocks_name):
1078
+ # cl = clock.start()
1079
+ if blocks_name != None:
1080
+ self.loaded_blocks[model_id] = None
1063
1081
 
1064
- import tempfile
1065
- with tempfile.NamedTemporaryFile("w", delete = False, encoding ="utf-8") as fp:
1066
- fp.write(json.dumps(transformer_config))
1067
- fp.close()
1068
- config_obj = AutoConfig.from_pretrained(fp.name)
1069
- os.remove(fp.name)
1082
+ blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
1070
1083
 
1071
- #needed to keep inits of non persistent buffers
1072
- with init_empty_weights():
1073
- model = transfomer_class(config_obj)
1074
-
1075
- model = model.base_model
1084
+ if self.verboseLevel >=2:
1085
+ model = self.models[model_id]
1086
+ model_name = model._get_name()
1087
+ print(f"Unloading model {blocks_name} ({model_name}) from GPU")
1088
+
1089
+ blocks_params = self.blocks_of_modules[blocks_name]
1090
+ for param in blocks_params:
1091
+ parent_module, n, p, is_buffer = param
1092
+ if is_buffer:
1093
+ q = torch.nn.Buffer(p)
1094
+ else:
1095
+ q = torch.nn.Parameter(p , requires_grad=False)
1096
+ setattr(parent_module, n , q)
1097
+ # cl.stop()
1098
+ # print(f"unload time: {cl.format_time_gap()}")
1076
1099
 
1077
- elif "_class_name" in transformer_config:
1078
- class_name = transformer_config["_class_name"]
1100
+ # @torch.compiler.disable()
1101
+ def gpu_load(self, model_id):
1102
+ model = self.models[model_id]
1103
+ self.active_models.append(model)
1104
+ self.active_models_ids.append(model_id)
1079
1105
 
1080
- module = __import__("diffusers")
1081
- transfomer_class = getattr(module, class_name)
1106
+ self.gpu_load_blocks(model_id, None)
1082
1107
 
1083
- with init_empty_weights():
1084
- model = transfomer_class.from_config(transformer_config)
1108
+ # torch.cuda.current_stream().synchronize()
1085
1109
 
1110
+ def unload_all(self):
1111
+ for model_id in self.active_models_ids:
1112
+ self.gpu_unload_blocks(model_id, None)
1113
+ loaded_block = self.loaded_blocks[model_id]
1114
+ if loaded_block != None:
1115
+ self.gpu_unload_blocks(model_id, loaded_block)
1116
+ self.loaded_blocks[model_id] = None
1117
+
1118
+ self.active_models = []
1119
+ self.active_models_ids = []
1120
+ self.active_subcaches = []
1121
+ torch.cuda.empty_cache()
1122
+ gc.collect()
1123
+ self.last_reserved_mem_check = time.time()
1086
1124
 
1087
- torch.set_default_device('cpu')
1125
+ def move_args_to_gpu(self, *args, **kwargs):
1126
+ new_args= []
1127
+ new_kwargs={}
1128
+ for arg in args:
1129
+ if torch.is_tensor(arg):
1130
+ if arg.dtype == torch.float32:
1131
+ arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
1132
+ elif not arg.is_cuda:
1133
+ arg = arg.cuda(non_blocking=True)
1134
+ new_args.append(arg)
1088
1135
 
1089
- model._config = transformer_config
1090
-
1091
- load_model_data(model,model_path, do_quantize = do_quantize, quantizationType = quantizationType, pinToMemory= pinToMemory, partialPinning= partialPinning, verboseLevel=verboseLevel )
1136
+ for k in kwargs:
1137
+ arg = kwargs[k]
1138
+ if torch.is_tensor(arg):
1139
+ if arg.dtype == torch.float32:
1140
+ arg = arg.to(torch.bfloat16).cuda(non_blocking=True)
1141
+ elif not arg.is_cuda:
1142
+ arg = arg.cuda(non_blocking=True)
1143
+ new_kwargs[k]= arg
1144
+
1145
+ return new_args, new_kwargs
1146
+
1147
+ def ready_to_check_mem(self):
1148
+ if self.anyCompiledModule:
1149
+ return
1150
+ cur_clock = time.time()
1151
+ # can't check at each call if we can empty the cuda cache as quering the reserved memory value is a time consuming operation
1152
+ if (cur_clock - self.last_reserved_mem_check)<0.200:
1153
+ return False
1154
+ self.last_reserved_mem_check = cur_clock
1155
+ return True
1092
1156
 
1093
- return model
1094
1157
 
1158
+ def empty_cache_if_needed(self):
1159
+ mem_reserved = torch.cuda.memory_reserved()
1160
+ mem_threshold = 0.9*self.device_mem_capacity
1161
+ if mem_reserved >= mem_threshold:
1162
+ mem_allocated = torch.cuda.memory_allocated()
1163
+ if mem_allocated <= 0.70 * mem_reserved:
1164
+ # print(f"Cuda empty cache triggered as Allocated Memory ({mem_allocated/1024000:0f} MB) is lot less than Cached Memory ({mem_reserved/1024000:0f} MB) ")
1165
+ torch.cuda.empty_cache()
1166
+ tm= time.time()
1167
+ if self.verboseLevel >=2:
1168
+ print(f"Empty Cuda cache at {tm}")
1169
+ # print(f"New cached memory after purge is {torch.cuda.memory_reserved()/1024000:0f} MB) ")
1095
1170
 
1096
1171
 
1097
- def load_model_data(model, file_path: str, do_quantize = False, quantizationType = qint8, pinToMemory = False, partialPinning = False, verboseLevel = -1):
1098
- """
1099
- Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
1100
- """
1172
+ def any_param_or_buffer(self, target_module: torch.nn.Module):
1173
+
1174
+ for _ in target_module.parameters(recurse= False):
1175
+ return True
1176
+
1177
+ for _ in target_module.buffers(recurse= False):
1178
+ return True
1179
+
1180
+ return False
1101
1181
 
1102
- file_path = _get_model(file_path)
1103
- verboseLevel = _compute_verbose_level(verboseLevel)
1182
+ def hook_preload_blocks_for_compilation(self, target_module, model_id,blocks_name, context):
1104
1183
 
1105
- model = _remove_model_wrapper(model)
1184
+ # @torch.compiler.disable()
1185
+ def preload_blocks_for_compile(module, *args, **kwargs):
1186
+ some_context = context #for debugging
1187
+ if blocks_name == None:
1188
+ if self.ready_to_check_mem():
1189
+ self.empty_cache_if_needed()
1190
+ else:
1191
+ loaded_block = self.loaded_blocks[model_id]
1192
+ if (loaded_block == None or loaded_block != blocks_name) :
1193
+ if loaded_block != None:
1194
+ self.gpu_unload_blocks(model_id, loaded_block)
1195
+ if self.ready_to_check_mem():
1196
+ self.empty_cache_if_needed()
1197
+ self.loaded_blocks[model_id] = blocks_name
1198
+ self.gpu_load_blocks(model_id, blocks_name)
1199
+ # need to be registered before the forward not to be break the efficiency of the compilation chain
1200
+ # it should be at the top of the compilation as this type of hook in the middle of a chain seems to break memory performance
1201
+ target_module.register_forward_pre_hook(preload_blocks_for_compile)
1106
1202
 
1107
- # if pinToMemory and do_quantize:
1108
- # raise Exception("Pinning and Quantization can not be used at the same time")
1109
1203
 
1110
- if not (".safetensors" in file_path or ".sft" in file_path):
1111
- if pinToMemory:
1112
- raise Exception("Pinning to memory while loading only supported for safe tensors files")
1113
- state_dict = torch.load(file_path, weights_only=True)
1114
- if "module" in state_dict:
1115
- state_dict = state_dict["module"]
1116
- else:
1117
- state_dict, metadata = _safetensors_load_file(file_path)
1118
-
1119
- if metadata is None:
1120
- quantization_map = None
1121
- else:
1122
- quantization_map = metadata.get("quantization_map", None)
1123
- config = metadata.get("config", None)
1124
- if config is not None:
1125
- model._config = config
+    def hook_check_empty_cache_needed(self, target_module, model_id, blocks_name, previous_method, context):

+        qint4quantization = isinstance(target_module, QModuleMixin) and target_module.weight != None and target_module.weight.qtype == qint4
+        if qint4quantization:
+            pass

+        def check_empty_cuda_cache(module, *args, **kwargs):
+            # if self.ready_to_check_mem():
+            #     self.empty_cache_if_needed()
+            if blocks_name == None:
+                if self.ready_to_check_mem():
+                    self.empty_cache_if_needed()
+            else:
+                loaded_block = self.loaded_blocks[model_id]
+                if (loaded_block == None or loaded_block != blocks_name):
+                    if loaded_block != None:
+                        self.gpu_unload_blocks(model_id, loaded_block)
+                        if self.ready_to_check_mem():
+                            self.empty_cache_if_needed()
+                    self.loaded_blocks[model_id] = blocks_name
+                    self.gpu_load_blocks(model_id, blocks_name)
+            if qint4quantization:
+                args, kwargs = self.move_args_to_gpu(*args, **kwargs)
 
-    if quantization_map is None:
-        pos = str.rfind(file_path, ".")
-        if pos > 0:
-            quantization_map_path = file_path[:pos]
-            quantization_map_path += "_map.json"
+            return previous_method(*args, **kwargs)

-            if os.path.isfile(quantization_map_path):
-                with open(quantization_map_path, 'r') as f:
-                    quantization_map = json.load(f)

+        if hasattr(target_module, "_mm_id"):
+            orig_model_id = getattr(target_module, "_mm_id")
+            if self.verboseLevel >= 2:
+                print(f"Model '{model_id}' shares module '{target_module._get_name()}' with module '{orig_model_id}' ")
+            assert not self.any_param_or_buffer(target_module)

+            return
+        setattr(target_module, "_mm_id", model_id)
+        setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method) )
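For context, an illustrative sketch (not mmgp code) of the forward-wrapping pattern used above: functools.partial binds the module into the wrapper and functools.update_wrapper copies the original forward's metadata onto it before it is assigned back with setattr.

import functools
import torch

def traced_forward(module, previous_forward, *args, **kwargs):
    # "module" and "previous_forward" are bound via functools.partial; the wrapper can
    # inspect the module and then delegate to the original bound method.
    print(f"forward called on {module._get_name()}")
    return previous_forward(*args, **kwargs)

layer = torch.nn.Linear(4, 2)
previous_method = layer.forward
wrapper = functools.update_wrapper(
    functools.partial(traced_forward, layer, previous_method), previous_method
)
setattr(layer, "forward", wrapper)   # shadows the bound method on this instance
out = layer(torch.randn(1, 4))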
 
-    if quantization_map is None :
-        if "quanto" in file_path and not do_quantize:
-            print("Model seems to be quantized by quanto but no quantization map was found whether inside the model or in a separate '{file_path[:json]}_map.json' file")
-    else:
-        _requantize(model, state_dict, quantization_map)
+
+    def hook_change_module(self, target_module, model, model_id, module_id, previous_method):
+        def check_change_module(module, *args, **kwargs):
+            performEmptyCacheTest = False
+            if not model_id in self.active_models_ids:
+                new_model_id = getattr(module, "_mm_id")
+                # do not always unload the existing models if it is more efficient to keep them in the GPU
+                # (e.g. small modules such as text encoders)
+                if not self.can_model_be_cotenant(new_model_id):
+                    self.unload_all()
+                    performEmptyCacheTest = False
+                self.gpu_load(new_model_id)
+            # transfer leftover inputs that were incorrectly created in RAM (mostly due to some .device tests that incorrectly returned "cpu")
+            args, kwargs = self.move_args_to_gpu(*args, **kwargs)
+            if performEmptyCacheTest:
+                self.empty_cache_if_needed()
+
+            return previous_method(*args, **kwargs)
+
+        if hasattr(target_module, "_mm_id"):
+            return
+        setattr(target_module, "_mm_id", model_id)
-    missing_keys , unexpected_keys = model.load_state_dict(state_dict, strict = quantization_map is None, assign = True )
-    del state_dict
+        setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method) )

-    if do_quantize:
-        if quantization_map is None:
-            if _quantize(model, quantizationType, verboseLevel=verboseLevel, model_id=file_path):
-                quantization_map = model._quanto_map
-        else:
-            if verboseLevel >=1:
-                print("Model already quantized")
+        if not self.verboseLevel >= 1:
+            return

-    if pinToMemory:
-        _pin_to_memory(model, file_path, partialPinning = partialPinning, verboseLevel = verboseLevel)
+        if module_id == None or module_id == '':
+            model_name = model._get_name()
+            print(f"Hooked in model '{model_id}' ({model_name})")

-    return
 
-def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1 ):
+def save_model(model, file_path, do_quantize = False, quantizationType = qint8, verboseLevel = -1, config_file_path = None ):
    """save the weights of a model and quantize them if requested

    These weights can be loaded again using 'load_model_data'
    """

    config = None
-
    verboseLevel = _compute_verbose_level(verboseLevel)
-
-    if hasattr(model, "_config"):
+    if config_file_path != None:
+        with open(config_file_path, "r", encoding="utf-8") as reader:
+            text = reader.read()
+            config = json.loads(text)
+    elif hasattr(model, "_config"):
        config = model._config
    elif hasattr(model, "config"):
        config_fullpath = None
@@ -1195,7 +1308,7 @@ def save_model(model, file_path, do_quantize = False, quantizationType = qint8,
        print(f"Saving file '{file_path}")
    safetensors2.torch_write_file(model.state_dict(), file_path , quantization_map = quantization_map, config = config)
    if verboseLevel >=1:
-        print(f"File '{file_path} saved")
+        print(f"File '{file_path}' saved")
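A hedged usage sketch of the extended save_model signature follows; the import path, the toy module and the file names are assumptions, only the new config_file_path parameter comes from this change.

import json
import torch
from mmgp import offload          # assumed import path for this module
from optimum.quanto import qint8

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))

with open("toy_config.json", "w", encoding="utf-8") as f:
    json.dump({"hidden_size": 64}, f)   # hypothetical config to embed in the safetensors metadata

offload.save_model(
    model,
    "toy_quanto_int8.safetensors",       # hypothetical output file
    do_quantize=True,
    quantizationType=qint8,
    config_file_path="toy_config.json",
)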
 
@@ -1286,7 +1399,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
    max_reservable_memory = _get_max_reservable_memory(perc_reserved_mem_max)

    estimatesBytesToPin = 0
-
    for model_id in models:
        current_model: torch.nn.Module = models[model_id]
        # make sure that no RAM or GPU memory is allocated for gradient / training
@@ -1302,7 +1414,6 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru

        for n, p in current_model.named_parameters():
            p.requires_grad = False
-            p = p.detach()
            if isinstance(p, QTensor):
                # # fix quanto bug (seems to have been fixed)
                # if not modelPinned and p._scale.dtype == torch.float32:
@@ -1352,21 +1463,18 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
    # Hook forward methods of modules
    for model_id in models:
        current_model: torch.nn.Module = models[model_id]
-        current_budget = model_budgets[model_id]
-        current_size = 0
-        cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
-        self.loaded_blocks[model_id] = None
        towers_names, towers_modules = _detect_main_towers(current_model)
-        towers_names = [n + "." for n in towers_names]
-        if self.verboseLevel>=2 and len(towers_names)>0:
-            print(f"Potential iterative blocks found in model '{model_id}':{towers_names}")
        # compile the main iterative module stacks ("towers")
-        if compileAllModels or model_id in modelsToCompile :
+        compilationInThisOne = compileAllModels or model_id in modelsToCompile
+        if compilationInThisOne:
            if self.verboseLevel>=1:
-                print(f"Pytorch compilation of model '{model_id}' is scheduled.")
-            for tower in towers_modules:
-                for submodel in tower:
-                    submodel.forward = torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
+                if len(towers_modules)>0:
+                    print(f"Pytorch compilation of '{model_id}' is scheduled for these modules: {towers_names}.")
+                else:
+                    print(f"Pytorch compilation of model '{model_id}' is not yet supported.")
+
+            for submodel in towers_modules:
+                submodel.forward = torch.compile(submodel.forward, backend= "inductor", mode="default" ) # , fullgraph= True, mode= "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs",
                    #dynamic=True,
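For context, an illustrative sketch (toy model, not mmgp code) of the per-block compilation strategy above: each repeated block of a ModuleList "tower" gets its own compiled forward instead of compiling the whole model as one graph.

import torch

class ToyTransformer(torch.nn.Module):
    def __init__(self, depth=4, dim=32):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = torch.nn.functional.relu(block(x))
        return x

model = ToyTransformer()
for block in model.blocks:   # the "floors" of one detected tower
    block.forward = torch.compile(block.forward, backend="inductor", mode="default")
out = model(torch.randn(2, 32))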
 
        if pinAllModels or model_id in modelsToPin:
@@ -1376,6 +1484,11 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru
            else:
                _pin_to_memory(current_model, model_id, partialPinning= partialPinning, perc_reserved_mem_max=perc_reserved_mem_max, verboseLevel=verboseLevel)

+        current_budget = model_budgets[model_id]
+        current_size = 0
+        cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
+        self.loaded_blocks[model_id] = None
+
        for submodule_name, submodule in current_model.named_modules():
            # create a fake 'accelerate' parameter so that the _execution_device property always returns "cuda"
            # (it is queried in many pipelines even if offloading is not properly implemented)
@@ -1384,44 +1497,43 @@ def all(pipe_or_dict_of_modules, pinnedMemory = False, quantizeTransformer = Tru

            if submodule_name=='':
                continue
-            newListItem = False
+
            if current_budget > 0:
-                if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): #
-                    if cur_blocks_prefix == None:
-                        cur_blocks_prefix = submodule_name + "."
+                if cur_blocks_prefix != None:
+                    if submodule_name.startswith(cur_blocks_prefix):
+                        depth_prefix = cur_blocks_prefix.split(".")
+                        depth_name = submodule_name.split(".")
+                        level = depth_name[len(depth_prefix)-1]
+                        pre, num = _extract_num_from_str(level)
+                        if num != cur_blocks_seq and (cur_blocks_seq == -1 or current_size > current_budget):
+                            prev_blocks_name = cur_blocks_name
+                            cur_blocks_name = cur_blocks_prefix + str(num)
+                            # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
+                        cur_blocks_seq = num
                    else:
-                        #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
-                        if not submodule_name.startswith(cur_blocks_prefix):
-                            cur_blocks_prefix = submodule_name + "."
-                            cur_blocks_name, cur_blocks_seq = None, -1
-                else:
-
-                    if cur_blocks_prefix is not None:
-                        if submodule_name.startswith(cur_blocks_prefix):
-                            num = int(submodule_name[len(cur_blocks_prefix):].split(".")[0])
-                            newListItem = num != cur_blocks_seq
-                            if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
-                                prev_blocks_name = cur_blocks_name
-                                cur_blocks_name = cur_blocks_prefix + str(num)
-                                # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
-                            cur_blocks_seq = num
-                        else:
-                            cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
-
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
+
+                if cur_blocks_prefix == None:
+                    pre, num = _extract_num_from_str(submodule_name)
+                    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre + ".", None, -1
+                    elif num >= 0:
+                        cur_blocks_prefix, prev_blocks_name, cur_blocks_seq = pre, None, num
+                        cur_blocks_name = submodule_name
+                        # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
+
+
            if hasattr(submodule, "forward"):
                submodule_method = getattr(submodule, "forward")
                if callable(submodule_method):
                    if len(submodule_name.split("."))==1:
                        self.hook_change_module(submodule, current_model, model_id, submodule_name, submodule_method)
-                    elif newListItem:
-                        self.hook_load_data_if_needed(submodule, model_id, cur_blocks_name, context = submodule_name )
+                    elif compilationInThisOne and submodule in towers_modules:
+                        self.hook_preload_blocks_for_compilation(submodule, model_id, cur_blocks_name, context = submodule_name )
                    else:
                        self.hook_check_empty_cache_needed(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name )

-
-                    current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
-
-
+                    current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)


        if self.verboseLevel >=2:
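For context, a self-contained sketch of the block-naming idea above. _extract_num_from_str is not shown in this diff, so extract_num below is an assumed stand-in that splits a trailing integer off a name segment; block_of only illustrates how a submodule name such as "blocks.7.attn" is grouped under the block "blocks.7".

def extract_num(segment: str):
    # Assumed helper: "7" -> ("", 7), "layer12" -> ("layer", 12), "norm" -> ("norm", -1)
    digits = ""
    while segment and segment[-1].isdigit():
        digits = segment[-1] + digits
        segment = segment[:-1]
    return segment, int(digits) if digits else -1

def block_of(submodule_name: str, blocks_prefix: str):
    # e.g. blocks_prefix = "blocks." and submodule_name = "blocks.7.attn"
    if not submodule_name.startswith(blocks_prefix):
        return None
    level = submodule_name.split(".")[len(blocks_prefix.split(".")) - 1]
    _, num = extract_num(level)
    return blocks_prefix + str(num) if num >= 0 else None

print(block_of("blocks.7.attn", "blocks."))   # blocks.7
print(block_of("blocks.12", "blocks."))       # blocks.12
print(block_of("norm_out", "blocks."))        # None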
@@ -1467,11 +1579,12 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
    models_to_scan = ("text_encoder", "text_encoder_2")
    candidates_to_quantize = ("t5", "llama", "llm")
    for model_id in models_to_scan:
-        name = module_names[model_id]
-        for candidate in candidates_to_quantize:
-            if candidate in name:
-                default_extraModelsToQuantize.append(model_id)
-                break
+        if model_id in module_names:
+            name = module_names[model_id]
+            for candidate in candidates_to_quantize:
+                if candidate in name:
+                    default_extraModelsToQuantize.append(model_id)
+                    break


    # the transformer (video or image generator) should be as small as possible so as not to occupy space that could be used by actual image data
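A hedged usage sketch of profile follows; the import path, the toy module dict and the availability of a CUDA device are assumptions, while profile_type and the parameter names appear in this module.

import torch
from mmgp import offload, profile_type   # offload import path assumed

# A dict of modules is accepted in place of a full pipeline (parameter: pipe_or_dict_of_modules).
modules = {
    "text_encoder": torch.nn.Sequential(torch.nn.Linear(64, 64)),
    "transformer": torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(6)]),
}
offload.profile(modules, profile_no=profile_type.HighRAM_HighVRAM)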
@@ -1480,6 +1593,7 @@ def profile(pipe_or_dict_of_modules, profile_no: profile_type = profile_type.Ve
    default_budgets = { "transformer" : 600 , "text_encoder": 3000, "text_encoder_2": 3000 }
    extraModelsToQuantize = None
    asyncTransfers = True
+    budgets = None

    if profile_no == profile_type.HighRAM_HighVRAM:
        pinnedMemory= True