mmgp 1.1.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


mmgp.py CHANGED
@@ -1,22 +1,28 @@
1
- # ------------------ Memory Management for the GPU Poor by DeepBeepMeep (mmgp)------------------
1
+ # ------------------ Memory Management 2.0 for the GPU Poor by DeepBeepMeep (mmgp)------------------
2
2
  #
3
3
  # This module contains multiple optimisations so that models such as Flux (and derived), Mochi, CogView, HunyuanVideo, ... can run smoothly on a GPU card limited to 24 GB of VRAM.
4
4
  # This is a replacement for the accelerate library that should in theory manage offloading, but doesn't work properly with models that are loaded / unloaded several
5
5
  # times in a pipe (eg VAE).
6
6
  #
7
7
  # Requirements:
8
- # - GPU: RTX 3090/ RTX 4090 (24 GB of VRAM)
9
- # - RAM: minimum 48 GB, recommended 64 GB
8
+ # - VRAM: minimum 12 GB, recommended 24 GB (RTX 3090/ RTX 4090)
9
+ # - RAM: minimum 24 GB, recommended 48 - 64 GB
10
10
  #
11
11
  # It is almost plug and play and just needs to be invoked from the main app just after the model pipeline has been created.
12
12
  # 1) First make sure that the pipeline explicitly loads the models in the CPU device
13
13
  # for instance: pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cpu")
14
14
  # 2) Once every potential Lora has been loaded and merged, add the following lines:
15
+ # For a quick setup, you may want to choose between 5 profiles depending on your hardware, for instance:
16
+ # from mmgp import offload, profile_type
17
+ # offload.profile(pipe, profile_type.HighRAM_LowVRAM_Fast)
18
+ # Alternatively you may want to use your own parameters, for instance:
15
19
  # from mmgp import offload
16
- # offload.me(pipe)
17
- # The 'transformer' model that contains usually the video or image generator is quantized on the fly by default to 8 bits. If you want to save time on disk and reduce the loading time, you may want to load directly a prequantized model. In that case you need to set the option quantizeTransformer to False to turn off on the fly quantization.
18
- #
19
- # If you have more than 64GB RAM you may want to enable RAM pinning with the option pinInRAM = True. You will get in return super fast loading / unloading of models
20
+ # offload.all(pipe, pinInRAM=True, modelsToQuantize = ["text_encoder_2"])
21
+ # The 'transformer' model that usually contains the video or image generator is quantized on the fly by default to 8 bits so that it can fit into 24 GB of VRAM.
22
+ # If you want to save disk space and reduce the loading time, you may want to load a prequantized model directly. In that case you need to set the option quantizeTransformer to False to turn off on-the-fly quantization.
23
+ # You can specify a list of additional model string ids to quantize (for instance the text_encoder) using the optional argument modelsToQuantize. This may be useful if you have less than 48 GB of RAM.
24
+ # Note that there is little advantage on the GPU / VRAM side to quantize text encoders as their inputs are usually quite light.
25
+ # Conversely if you have more than 48 GB of RAM you may want to enable RAM pinning with the option pinInRAM = True. In return you will get super fast loading / unloading of models
20
26
  # (this can save significant time if the same pipeline is run multiple times in a row)
21
27
  #
22
28
  # Sometimes there isn't an explicit pipe object as each submodel is loaded separately in the main app. If this is the case, you need to create a dictionary that manually maps all the models.
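For orientation, a minimal quick-start sketch assembled from the header comments above (illustrative only, not part of the package diff). The checkpoint id is the one quoted in the comments; the prompt and the final pipeline call follow the standard diffusers API and are purely illustrative.

import torch
from diffusers import FluxPipeline
from mmgp import offload, profile_type

# 1) load the pipeline explicitly on the CPU
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
                                    torch_dtype=torch.bfloat16).to("cpu")

# 2) either pick one of the predefined profiles...
offload.profile(pipe, profile_type.HighRAM_LowVRAM_Fast)

# ...or set the parameters yourself (do not do both):
# offload.all(pipe, pinInRAM=True, modelsToQuantize=["text_encoder_2"])

# the pipeline is then used as usual; models are moved to the GPU on demand
image = pipe("a cat playing chess", num_inference_steps=4).images[0]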
@@ -51,10 +57,15 @@ import torch
51
57
  import gc
52
58
  import time
53
59
  import functools
60
+ import sys
61
+ import json
62
+
54
63
  from optimum.quanto import freeze, qfloat8, qint8, quantize, QModuleMixin, QTensor
55
64
 
56
65
 
57
66
 
67
+ ONE_MB = 1048576
68
+
58
69
  cotenants_map = {
59
70
  "text_encoder": ["vae", "text_encoder_2"],
60
71
  "text_encoder_2": ["vae", "text_encoder"],
@@ -77,10 +88,107 @@ def move_tensors(obj, device):
77
88
  else:
78
89
  raise TypeError("Tensor or list / dict of tensors expected")
79
90
 
91
+ def _quantize(model_to_quantize, weights=qint8, verboseLevel = 1, threshold = 1000000000, model_id = None):
92
+
93
+ sizeofbfloat16 = torch.bfloat16.itemsize
94
+
95
+ def compute_submodule_size(submodule):
96
+ size = 0
97
+ for p in submodule.parameters(recurse=False):
98
+ size += torch.numel(p.data) * sizeofbfloat16
99
+
100
+ for p in submodule.buffers(recurse=False):
101
+ size += torch.numel(p.data) * sizeofbfloat16
102
+
103
+ return size
104
+
105
+ total_size =0
106
+ total_excluded = 0
107
+ exclude_list = []
108
+ submodule_size = 0
109
+ submodule_names = []
110
+ cur_blocks_prefix = None
111
+ prev_blocks_prefix = None
112
+
113
+ print(f"Quantization of model '{model_id}' started")
114
+
115
+ for submodule_name, submodule in model_to_quantize.named_modules():
116
+ if isinstance(submodule, QModuleMixin):
117
+ if verboseLevel>=1:
118
+ print("No quantization to do as model is already quantized")
119
+ return False
120
+
121
+
122
+ if submodule_name=='':
123
+ continue
124
+
125
+
126
+ flush = False
127
+ if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
128
+ if cur_blocks_prefix == None:
129
+ cur_blocks_prefix = submodule_name + "."
130
+ flush = True
131
+ else:
132
+ #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
133
+ if not submodule_name.startswith(cur_blocks_prefix):
134
+ cur_blocks_prefix = submodule_name + "."
135
+ flush = True
136
+ else:
137
+ if cur_blocks_prefix is not None:
138
+ #if not cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
139
+ if not submodule_name.startswith(cur_blocks_prefix):
140
+ cur_blocks_prefix = None
141
+ flush = True
142
+
143
+ if flush:
144
+ if submodule_size <= threshold:
145
+ exclude_list += submodule_names
146
+ if verboseLevel >=2:
147
+ print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
148
+ total_excluded += submodule_size
149
+
150
+ submodule_size = 0
151
+ submodule_names = []
152
+ prev_blocks_prefix = cur_blocks_prefix
153
+ size = compute_submodule_size(submodule)
154
+ submodule_size += size
155
+ total_size += size
156
+ submodule_names.append(submodule_name)
157
+
158
+ if submodule_size > 0 and submodule_size <= threshold:
159
+ exclude_list += submodule_names
160
+ if verboseLevel >=2:
161
+ print(f"Excluded size {submodule_size/ONE_MB:.1f} MB: {prev_blocks_prefix} : {submodule_names}")
162
+ total_excluded += submodule_size
163
+
164
+ perc_excluded =total_excluded/ total_size if total_size >0 else 1
165
+ if verboseLevel >=2:
166
+ print(f"Total Excluded {total_excluded/ONE_MB:.1f} MB oF {total_size/ONE_MB:.1f} that is {perc_excluded*100:.2f}%")
167
+ if perc_excluded >= 0.10:
168
+ print(f"Too many many modules are excluded, there is something wrong with the selection, switch back to full quantization.")
169
+ exclude_list = None
170
+
171
+ # we are obviously loading a model that has been already quantized
172
+
173
+ quantize(model_to_quantize,weights, exclude= exclude_list)
174
+ freeze(model_to_quantize)
175
+ torch.cuda.empty_cache()
176
+ gc.collect()
177
+ print(f"Quantization of model '{model_id}' done")
178
+
179
+ return True
80
180
 
81
181
  def get_model_name(model):
82
182
  return model.name
83
183
 
184
+ import enum
185
+ class profile_type(int, enum.Enum):
186
+ HighRAM_HighVRAM_Fastest = 1
187
+ HighRAM_LowVRAM_Fast = 2
188
+ LowRAM_HighVRAM_Medium = 3
189
+ LowRAM_LowVRAM_Slow = 4
190
+ VerylowRAM_LowVRAM_Slowest = 5
191
+
84
192
  class HfHook:
85
193
  def __init__(self):
86
194
  self.execution_device = "cuda"
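For reference, a small sketch of the optimum.quanto primitives that _quantize() above builds on; the toy two-layer model and the excluded module name are invented for illustration only.

import torch
from optimum.quanto import quantize, freeze, qint8

toy = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),   # large block: worth quantizing to int8
    torch.nn.Linear(8, 8),         # tiny block: cheaper to keep in full precision
)

# quantize() accepts an exclude list of module names; _quantize() uses it to skip
# every group of submodules whose cumulated size stays below its threshold
quantize(toy, weights=qint8, exclude=["1"])
freeze(toy)   # replace the original weights with their quantized counterparts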
@@ -92,28 +200,57 @@ class offload:
92
200
  def __init__(self):
93
201
  self.active_models = []
94
202
  self.active_models_ids = []
203
+ self.active_subcaches = {}
95
204
  self.models = {}
96
- self.verbose = False
205
+ self.verboseLevel = 0
97
206
  self.models_to_quantize = []
98
207
  self.pinned_modules_data = {}
99
- self.params_of_modules = {}
100
- self.pinTensors = False
208
+ self.blocks_of_modules = {}
209
+ self.blocks_of_modules_sizes = {}
210
+ self.compile = False
101
211
  self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
102
212
  self.last_reserved_mem_check =0
213
+ self.loaded_blocks = {}
214
+ self.prev_blocks_names = {}
215
+ self.next_blocks_names = {}
216
+ self.default_stream = torch.cuda.default_stream(torch.device("cuda")) # torch.cuda.current_stream()
217
+ self.transfer_stream = torch.cuda.Stream()
218
+ self.async_transfers = False
103
219
 
104
- def collect_module_parameters(self, module: torch.nn.Module, module_params):
105
- if isinstance(module, (torch.nn.ModuleList, torch.nn.Sequential)):
106
- for i in range(len(module)):
107
- current_layer = module[i]
108
- module_params.extend(current_layer.parameters())
109
- module_params.extend(current_layer.buffers())
220
+
221
+ def add_module_to_blocks(self, model_id, blocks_name, submodule, prev_block_name):
222
+
223
+ entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
224
+ if entry_name in self.blocks_of_modules:
225
+ blocks_params = self.blocks_of_modules[entry_name]
226
+ blocks_params_size = self.blocks_of_modules_sizes[entry_name]
110
227
  else:
111
- for p in module.parameters(recurse=False):
112
- module_params.append(p)
113
- for p in module.buffers(recurse=False):
114
- module_params.append(p)
115
- for sub_module in module.children():
116
- self.collect_module_parameters(sub_module, module_params)
228
+ blocks_params = []
229
+ self.blocks_of_modules[entry_name] = blocks_params
230
+ blocks_params_size = 0
231
+ if blocks_name !=None:
232
+ prev_entry_name = None if prev_block_name == None else model_id + "/" + prev_block_name
233
+ self.prev_blocks_names[entry_name] = prev_entry_name
234
+ if not prev_block_name == None:
235
+ self.next_blocks_names[prev_entry_name] = entry_name
236
+
237
+ for p in submodule.parameters(recurse=False):
238
+ blocks_params.append(p)
239
+ if isinstance(p, QTensor):
240
+ blocks_params_size += p._data.nbytes
241
+ blocks_params_size += p._scale.nbytes
242
+ else:
243
+ blocks_params_size += p.data.nbytes
244
+
245
+ for p in submodule.buffers(recurse=False):
246
+ blocks_params.append(p)
247
+ blocks_params_size += p.data.nbytes
248
+
249
+
250
+ self.blocks_of_modules_sizes[entry_name] = blocks_params_size
251
+
252
+ return blocks_params_size
253
+
117
254
 
118
255
  def can_model_be_cotenant(self, model_id):
119
256
  potential_cotenants= cotenants_map.get(model_id, None)
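The transfer_stream / async_transfers attributes introduced above enable a double-buffered upload: while the GPU computes on the current block, the next block is copied on a separate CUDA stream. A self-contained sketch of that pattern (block shapes, names and the dummy matmul are arbitrary; it only runs on a CUDA machine):

import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.current_stream()
    transfer_stream = torch.cuda.Stream()

    cpu_blocks = [torch.randn(1024, 1024).pin_memory() for _ in range(4)]
    gpu_blocks = [torch.empty(1024, 1024, device="cuda") for _ in range(4)]

    # upload the first block on the compute stream, as gpu_load_blocks() below does
    gpu_blocks[0].copy_(cpu_blocks[0], non_blocking=True)
    for i in range(len(cpu_blocks)):
        if i + 1 < len(cpu_blocks):
            with torch.cuda.stream(transfer_stream):           # prefetch the next block
                gpu_blocks[i + 1].copy_(cpu_blocks[i + 1], non_blocking=True)
        out = gpu_blocks[i] @ gpu_blocks[i]                     # "compute" on the current block
        compute_stream.wait_stream(transfer_stream)             # block i+1 must be resident before use
    torch.cuda.synchronize()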
@@ -124,45 +261,113 @@ class offload:
124
261
  return False
125
262
  return True
126
263
 
127
- def gpu_load(self, model_id):
128
- model = self.models[model_id]
129
- self.active_models.append(model)
130
- self.active_models_ids.append(model_id)
131
- if self.verbose:
264
+ @torch.compiler.disable()
265
+ def gpu_load_blocks(self, model_id, blocks_name, async_load = False):
266
+ if blocks_name != None:
267
+ self.loaded_blocks[model_id] = blocks_name
268
+
269
+ def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream = None):
270
+ with torch.cuda.stream(stream_to_use):
271
+ for p in blocks_params:
272
+ if isinstance(p, QTensor):
273
+ p._data = p._data.cuda(non_blocking=True)
274
+ p._scale = p._scale.cuda(non_blocking=True)
275
+ else:
276
+ p.data = p.data.cuda(non_blocking=True)
277
+
278
+ if record_for_stream != None:
279
+ if isinstance(p, QTensor):
280
+ p._data.record_stream(record_for_stream)
281
+ p._scale.record_stream(record_for_stream)
282
+ else:
283
+ p.data.record_stream(record_for_stream)
284
+
285
+
286
+ entry_name = model_id if blocks_name is None else model_id + "/" + blocks_name
287
+ if self.verboseLevel >=2:
288
+ model = self.models[model_id]
132
289
  model_name = model._get_name()
133
- print(f"Loading model {model_name} ({model_id}) in GPU")
134
- if not self.pinInRAM:
135
- model.to("cuda")
290
+ print(f"Loading model {entry_name} ({model_name}) in GPU")
291
+
292
+
293
+ if self.async_transfers and blocks_name != None:
294
+ first = self.prev_blocks_names[entry_name] == None
295
+ next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
296
+ if first:
297
+ cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
298
+ # if next_blocks_entry != None:
299
+ # self.transfer_stream.wait_stream(self.default_stream)
300
+ # else:
301
+ # self.transfer_stream.wait_stream(self.default_stream)
302
+ torch.cuda.synchronize()
303
+
304
+ if next_blocks_entry != None:
305
+ cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry]) #, self.default_stream
306
+
136
307
  else:
137
- module_params = self.params_of_modules[model_id]
138
- for p in module_params:
308
+ # if self.async_transfers:
309
+ # self.transfer_stream.wait_stream(self.default_stream)
310
+ cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
311
+ torch.cuda.synchronize()
312
+
313
+
314
+ @torch.compiler.disable()
315
+ def gpu_unload_blocks(self, model_id, blocks_name):
316
+ if blocks_name != None:
317
+ self.loaded_blocks[model_id] = None
318
+
319
+ blocks_name = model_id if blocks_name is None else model_id + "/" + blocks_name
320
+
321
+ if self.verboseLevel >=2:
322
+ model = self.models[model_id]
323
+ model_name = model._get_name()
324
+ print(f"Unloading model {blocks_name} ({model_name}) from GPU")
325
+
326
+ blocks_params = self.blocks_of_modules[blocks_name]
327
+
328
+ if model_id in self.pinned_modules_data:
329
+ pinned_parameters_data = self.pinned_modules_data[model_id]
330
+ for p in blocks_params:
139
331
  if isinstance(p, QTensor):
140
- p._data = p._data.cuda(non_blocking=True)
141
- p._scale = p._scale.cuda(non_blocking=True)
332
+ data = pinned_parameters_data[p]
333
+ p._data = data[0]
334
+ p._scale = data[1]
142
335
  else:
143
- p.data = p.data.cuda(non_blocking=True) #
144
- # torch.cuda.current_stream().synchronize()
336
+ p.data = pinned_parameters_data[p]
337
+ else:
338
+ for p in blocks_params:
339
+ if isinstance(p, QTensor):
340
+ p._data = p._data.cpu()
341
+ p._scale = p._scale.cpu()
342
+ else:
343
+ p.data = p.data.cpu()
344
+
345
+
346
+
145
347
  @torch.compiler.disable()
348
+ def gpu_load(self, model_id):
349
+ model = self.models[model_id]
350
+ self.active_models.append(model)
351
+ self.active_models_ids.append(model_id)
352
+
353
+ self.gpu_load_blocks(model_id, None)
354
+
355
+ # torch.cuda.current_stream().synchronize()
356
+
146
357
  def unload_all(self):
147
- for model, model_id in zip(self.active_models, self.active_models_ids):
148
- if not self.pinInRAM:
149
- model.to("cpu")
150
- else:
151
- module_params = self.params_of_modules[model_id]
152
- pinned_parameters_data = self.pinned_modules_data[model_id]
153
- for p in module_params:
154
- if isinstance(p, QTensor):
155
- data = pinned_parameters_data[p]
156
- p._data = data[0]
157
- p._scale = data[1]
158
- else:
159
- p.data = pinned_parameters_data[p]
160
-
358
+ for model_id in self.active_models_ids:
359
+ self.gpu_unload_blocks(model_id, None)
360
+ loaded_block = self.loaded_blocks[model_id]
361
+ if loaded_block != None:
362
+ self.gpu_unload_blocks(model_id, loaded_block)
363
+ self.loaded_blocks[model_id] = None
161
364
 
162
365
  self.active_models = []
163
366
  self.active_models_ids = []
367
+ self.active_subcaches = []
164
368
  torch.cuda.empty_cache()
165
369
  gc.collect()
370
+ self.last_reserved_mem_check = time.time()
166
371
 
167
372
  def move_args_to_gpu(self, *args, **kwargs):
168
373
  new_args= []
@@ -186,10 +391,12 @@ class offload:
186
391
 
187
392
  return new_args, new_kwargs
188
393
 
189
- def ready_to_check_mem(self, forceMemoryCheck):
394
+ def ready_to_check_mem(self):
395
+ if self.compile:
396
+ return
190
397
  cur_clock = time.time()
191
398
  # can't check at each call if we can empty the cuda cache as querying the reserved memory value is a time consuming operation
192
- if not forceMemoryCheck and (cur_clock - self.last_reserved_mem_check)<0.200:
399
+ if (cur_clock - self.last_reserved_mem_check)<0.200:
193
400
  return False
194
401
  self.last_reserved_mem_check = cur_clock
195
402
  return True
@@ -197,20 +404,70 @@ class offload:
197
404
 
198
405
  def empty_cache_if_needed(self):
199
406
  mem_reserved = torch.cuda.memory_reserved()
200
- if mem_reserved >= 0.9*self.device_mem_capacity:
407
+ mem_threshold = 0.9*self.device_mem_capacity
408
+ if mem_reserved >= mem_threshold:
201
409
  mem_allocated = torch.cuda.memory_allocated()
202
410
  if mem_allocated <= 0.70 * mem_reserved:
203
411
  # print(f"Cuda empty cache triggered as Allocated Memory ({mem_allocated/1024000:0f} MB) is lot less than Cached Memory ({mem_reserved/1024000:0f} MB) ")
204
412
  torch.cuda.empty_cache()
413
+ tm= time.time()
414
+ if self.verboseLevel >=2:
415
+ print(f"Empty Cuda cache at {tm}")
205
416
  # print(f"New cached memory after purge is {torch.cuda.memory_reserved()/1024000:0f} MB) ")
206
417
 
207
- def hook_me_light(self, target_module, forceMemoryCheck, previous_method):
208
- def check_empty_cache(module, *args, **kwargs):
209
- if self.ready_to_check_mem(forceMemoryCheck):
418
+
419
+ def any_param_or_buffer(self, target_module: torch.nn.Module):
420
+
421
+ for _ in target_module.parameters(recurse= False):
422
+ return True
423
+
424
+ for _ in target_module.buffers(recurse= False):
425
+ return True
426
+
427
+ return False
428
+
429
+
430
+
431
+ def hook_me_light(self, target_module, model_id,blocks_name, previous_method, context):
432
+
433
+ anyParam = self.any_param_or_buffer(target_module)
434
+
435
+ def check_empty_cuda_cache(module, *args, **kwargs):
436
+ if self.ready_to_check_mem():
210
437
  self.empty_cache_if_needed()
211
438
  return previous_method(*args, **kwargs)
212
-
213
- setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cache, target_module), previous_method) )
439
+
440
+
441
+ def load_module_blocks(module, *args, **kwargs):
442
+ #some_context = context #for debugging
443
+ if blocks_name == None:
444
+ if self.ready_to_check_mem():
445
+ self.empty_cache_if_needed()
446
+ else:
447
+ loaded_block = self.loaded_blocks[model_id]
448
+ if (loaded_block == None or loaded_block != blocks_name) :
449
+ if loaded_block != None:
450
+ self.gpu_unload_blocks(model_id, loaded_block)
451
+ if self.ready_to_check_mem():
452
+ self.empty_cache_if_needed()
453
+ self.loaded_blocks[model_id] = blocks_name
454
+ self.gpu_load_blocks(model_id, blocks_name)
455
+ return previous_method(*args, **kwargs)
456
+
457
+ if hasattr(target_module, "_mm_id"):
458
+ orig_model_id = getattr(target_module, "_mm_id")
459
+ if self.verboseLevel >=2:
460
+ print(f"Model '{model_id}' shares module '{target_module._get_name()}' with module '{orig_model_id}' ")
461
+ assert not anyParam
462
+ return
463
+ setattr(target_module, "_mm_id", model_id)
464
+
465
+
466
+ if blocks_name != None and anyParam:
467
+ setattr(target_module, "forward", functools.update_wrapper(functools.partial(load_module_blocks, target_module), previous_method) )
468
+ #print(f"new cache:{blocks_name}")
469
+ else:
470
+ setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method) )
214
471
 
215
472
 
216
473
  def hook_me(self, target_module, model, model_id, module_id, previous_method):
@@ -234,13 +491,9 @@ class offload:
234
491
  return
235
492
  setattr(target_module, "_mm_id", model_id)
236
493
 
237
- # create a fake accelerate parameter so that the _execution_device property returns always "cuda"
238
- # (it is queried in many pipelines even if offloading is not properly implemented)
239
- if not hasattr(target_module, "_hf_hook"):
240
- setattr(target_module, "_hf_hook", HfHook())
241
494
  setattr(target_module, "forward", functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method) )
242
495
 
243
- if not self.verbose:
496
+ if not self.verboseLevel >=1:
244
497
  return
245
498
 
246
499
  if module_id == None or module_id =='':
@@ -260,22 +513,185 @@ class offload:
260
513
  # self.unhook_module(module)
261
514
 
262
515
 
516
+ @staticmethod
517
+ def fast_load_transformers_model(model_path: str):
518
+ """
519
+ quick version of .from_pretrained of the transformers library
520
+ used to build a model and load the corresponding weights (quantized or not)
521
+ """
522
+
523
+ from transformers import AutoConfig
524
+
525
+ if model_path.endswith(".sft") or model_path.endswith(".safetensors"):
526
+ config_path = model_path[ : model_path.rfind("/")]
527
+ else:
528
+ raise("full model path expected")
529
+ config_fullpath = config_path +"/config.json"
530
+
531
+ import os.path
532
+ if not os.path.isfile(config_fullpath):
533
+ raise("a 'config.json' that describes the model is required in the directory of the model")
534
+
535
+ with open(config_fullpath, "r", encoding="utf-8") as reader:
536
+ text = reader.read()
537
+ transformer_config= json.loads(text)
538
+ architectures = transformer_config["architectures"]
539
+ class_name = architectures[0]
540
+
541
+ module = __import__("transformers")
542
+ transfomer_class = getattr(module, class_name)
543
+
544
+ config = AutoConfig.from_pretrained(config_path)
545
+
546
+ from accelerate import init_empty_weights
547
+ #needed to keep inits of non persistent buffers
548
+ with init_empty_weights():
549
+ model = transfomer_class(config)
550
+
551
+ model = model.base_model
552
+ torch.set_default_device('cpu')
553
+ model.apply(model._initialize_weights)
554
+
555
+ #missing_keys, unexpected_keys =
556
+ offload.load_model_data(model,model_path, strict = True )
557
+
558
+ return model
559
+ # # text_encoder.final_layer_norm = text_encoder.norm
560
+ # model = model.base_model
561
+ # model.final_layer_norm = model.norm
562
+ # self.model = model
563
+
564
+
565
+
566
+ @staticmethod
567
+ def load_model_data(model, file_path: str, device=torch.device('cpu'), strict = True):
568
+ """
569
+ Load a model, detect if it has been previously quantized using quanto and do the extra setup if necessary
570
+ """
571
+ from optimum.quanto import requantize
572
+ import safetensors.torch
573
+
574
+ if "quanto" in file_path.lower():
575
+ pos = str.rfind(file_path, ".")
576
+ if pos > 0:
577
+ quantization_map_path = file_path[:pos]
578
+ quantization_map_path += "_map.json"
579
+
580
+
581
+ with open(quantization_map_path, 'r') as f:
582
+ quantization_map = json.load(f)
583
+
584
+ state_dict = safetensors.torch.load_file(file_path)
585
+
586
+ # change dtype of current meta model parameters because 'requantize' won't update the dtype on non quantized parameters
587
+ for k, p in model.named_parameters():
588
+ if not k in quantization_map and k in state_dict:
589
+ p_in_sd = state_dict[k]
590
+ if p.data.dtype != p_in_sd.data.dtype:
591
+ p.data = p.data.to(p_in_sd.data.dtype)
592
+
593
+ requantize(model, state_dict, quantization_map, device)
594
+
595
+ # for k, p in model.named_parameters():
596
+ # if p.data.dtype == torch.float32:
597
+ # pass
598
+
599
+
600
+ # del state_dict
601
+ return
602
+
603
+ else:
604
+ if ".safetensors" in file_path or ".sft" in file_path:
605
+ state_dict = safetensors.torch.load_file(file_path)
606
+
607
+ else:
608
+
609
+ state_dict = torch.load(file_path, weights_only=True)
610
+ if "module" in state_dict:
611
+ state_dict = state_dict["module"]
612
+
613
+
614
+ model.load_state_dict(state_dict, strict = strict, assign = True ) #strict=True,
615
+
616
+
617
+ return
618
+
619
+ @staticmethod
620
+ def save_model(model, file_path, do_quantize = False, quantization_type = qint8 ):
621
+ """save the weights of a model and quantize them if requested
622
+ These weights can be loaded again using 'load_model_data'
623
+ """
624
+ import safetensors.torch
625
+ pos = str.rfind(file_path, ".")
626
+ if pos > 0:
627
+ file_path = file_path[:pos]
628
+
629
+ if do_quantize:
630
+ _quantize(model, weights=quantization_type)
631
+
632
+ # # state_dict = {k: v.clone().contiguous() for k, v in model.state_dict().items()}
633
+ # state_dict = {k: v for k, v in model.state_dict().items()}
634
+
635
+
636
+
637
+ safetensors.torch.save_file(model.state_dict(), file_path + '.safetensors')
638
+
639
+ if do_quantize:
640
+ from optimum.quanto import quantization_map
641
+
642
+ with open(file_path + '_map.json', 'w') as f:
643
+ json.dump(quantization_map(model), f)
644
+
263
645
 
264
646
 
265
647
  @classmethod
266
- def all(cls, pipe_or_dict_of_modules, quantizeTransformer = True, pinInRAM = False, verbose = True):
648
+ def all(cls, pipe_or_dict_of_modules, quantizeTransformer = True, pinInRAM = False, verboseLevel = 1, modelsToQuantize = None, budgets= 0, info = None):
649
+ """Hook to a pipeline or a group of modules in order to reduce their VRAM requirements:
650
+ pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
651
+ quantizeTransformer: True by default, quantizes the video / image model on the fly
652
+ pinInRAM: move the models to pinned (non pageable) memory. This allows very fast loading / unloading of models but requires 50% extra RAM (usually >=64 GB)
653
+ modelsToQuantize: a list of models to be also quantized on the fly (e.g. the text_encoder), useful to reduce both RAM and VRAM consumption
654
+ budgets: 0 by default (unlimited). If non 0, it corresponds to the maximum size in MB that every model will occupy at any moment
655
+ (in fact the real usage is twice this number). It is very efficient to reduce VRAM consumption but this feature may be very slow
656
+ if pinInRAM is not enabled
657
+ """
658
+
267
659
  self = cls()
268
- self.verbose = verbose
660
+ self.verboseLevel = verboseLevel
269
661
  self.pinned_modules_data = {}
662
+ model_budgets = {}
270
663
 
664
+ # model_budgets = {"text_encoder_2": 3400 }
665
+ HEADER = '\033[95m'
666
+ ENDC = '\033[0m'
667
+ BOLD ='\033[1m'
668
+ UNBOLD ='\033[0m'
669
+
670
+ print(f"{BOLD}{HEADER}************ Memory Management for the GPU Poor (mmgp 2.0) by DeepBeepMeep ************{ENDC}{UNBOLD}")
671
+ if info != None:
672
+ print(info)
673
+ budget = 0
674
+ if not budgets is None:
675
+ if isinstance(budgets , dict):
676
+ model_budgets = budgets
677
+ else:
678
+ budget = int(budgets) * ONE_MB
679
+
680
+ if (budgets!= None or budget >0) :
681
+ self.async_transfers = True
682
+
683
+ #pinInRAM = True
271
684
  # compile not working yet or slower
272
685
  compile = False
273
- self.pinInRAM = pinInRAM
686
+ #quantizeTransformer = False
687
+ #self.async_transfers = False
688
+ self.compile = compile
689
+
274
690
  pipe = None
275
- preloadInRAM = True
276
691
  torch.set_default_device('cuda')
277
692
  if hasattr(pipe_or_dict_of_modules, "components"):
278
- pipe_or_dict_of_modules.to("cpu") #XXXX
693
+ # commented out as it is not very useful and generates warnings
694
+ #pipe_or_dict_of_modules.to("cpu") #XXXX
279
695
  # create a fake Accelerate parameter so that lora loading doesn't change the device
280
696
  pipe_or_dict_of_modules.hf_device_map = torch.device("cuda")
281
697
  pipe = pipe_or_dict_of_modules
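A usage sketch for the save / reload helpers defined in this hunk (save_model, load_model_data and fast_load_transformers_model). The two-layer stand-in model and the file name are invented for illustration; a real text encoder module is used the same way.

import torch
from mmgp import offload
from optimum.quanto import qint8

# stand-in for a real text encoder: any torch.nn.Module works the same way
text_encoder = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))

# one-off: quantize the weights and store them; keep "quanto" in the file name,
# load_model_data() uses that substring to detect a quantized checkpoint
offload.save_model(text_encoder, "text_encoder_quanto_int8.safetensors",
                   do_quantize=True, quantization_type=qint8)

# later runs: reload the prequantized weights into a freshly built instance of the same model
text_encoder = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))
offload.load_model_data(text_encoder, "text_encoder_quanto_int8.safetensors")
# (for real transformers models with a config.json next to the weights,
#  offload.fast_load_transformers_model(path) rebuilds the model and loads them in one call)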
@@ -284,115 +700,186 @@ class offload:
284
700
 
285
701
  models = {k: v for k, v in pipe_or_dict_of_modules.items() if isinstance(v, torch.nn.Module)}
286
702
 
703
+ modelsToQuantize = modelsToQuantize if modelsToQuantize is not None else []
704
+ if not isinstance(modelsToQuantize, list):
705
+ modelsToQuantize = [modelsToQuantize]
287
706
  if quantizeTransformer:
288
- self.models_to_quantize = ["transformer"]
707
+ modelsToQuantize.append("transformer")
708
+
709
+ self.models_to_quantize = modelsToQuantize
710
+ models_already_loaded = []
711
+
712
+ modelsToPin = None
713
+ pinAllModels = False
714
+ if isinstance(pinInRAM, bool):
715
+ pinAllModels = pinInRAM
716
+ elif isinstance(pinInRAM, list):
717
+ modelsToPin = pinInRAM
718
+ else:
719
+ modelsToPin = [pinInRAM]
720
+
289
721
  # del models["transformer"] # to test everything but the transformer that has a much longer loading
290
- # models = { 'transformer': pipe_or_dict_of_modules["transformer"]} # to test only the transformer
722
+ sizeofbfloat16 = torch.bfloat16.itemsize
723
+ #
724
+ # models = { 'transformer': pipe_or_dict_of_modules["transformer"]} # to test only the transformer
725
+
726
+
291
727
  for model_id in models:
292
728
  current_model: torch.nn.Module = models[model_id]
729
+ modelPinned = pinAllModels or (modelsToPin != None and model_id in modelsToPin)
293
730
  # make sure that no RAM or GPU memory is allocated for gradient / training
294
- current_model.to("cpu").eval() #XXXXX
295
-
731
+ current_model.to("cpu").eval()
732
+ already_loaded = False
296
733
  # Quantize model just before transferring it to the RAM to keep OS cache file
297
734
  # open as short as possible. Indeed it seems that as long as the lazy safetensors
298
735
  # are not fully loaded, the OS won't be able to release the corresponding cache file in RAM.
299
736
  if model_id in self.models_to_quantize:
300
- print(f"Quantization of model '{model_id}' started")
301
- quantize(current_model, weights=qint8)
302
- freeze(current_model)
303
- print(f"Quantization of model '{model_id}' done")
304
- torch.cuda.empty_cache()
305
- gc.collect()
306
737
 
738
+ already_quantized = _quantize(current_model, weights=qint8, verboseLevel = self.verboseLevel, model_id=model_id)
739
+ if not already_quantized:
740
+ already_loaded = True
741
+ models_already_loaded.append(model_id)
307
742
 
308
-
309
- if preloadInRAM: #
310
- # load all the remaining unread lazy safetensors in RAM to free open cache files
311
- for p in current_model.parameters():
312
- # Preread every tensor in RAM except tensors that have just been quantified
313
- # and are no longer needed
314
- if isinstance(p, QTensor):
315
- # fix quanto bug (see below) now as he won't have any opportunity to do it during RAM pinning
316
- if not pinInRAM and p._scale.dtype == torch.float32:
317
- p._scale = p._scale.to(torch.bfloat16)
318
743
 
744
+ current_model_size = 0
745
+ # load all the remaining unread lazy safetensors in RAM to free open cache files
746
+ for p in current_model.parameters():
747
+ # Preread every tensor in RAM except tensors that have just been quantized
748
+ # and are no longer needed
749
+ if isinstance(p, QTensor):
750
+ # fix quanto bug (see below) now as it won't have any opportunity to do it during RAM pinning
751
+ if not modelPinned and p._scale.dtype == torch.float32:
752
+ p._scale = p._scale.to(torch.bfloat16)
753
+ current_model_size += torch.numel(p._scale) * sizeofbfloat16
754
+ current_model_size += torch.numel(p._data) * sizeofbfloat16 / 2
755
+ if pinInRAM and not already_loaded:
756
+ # Force flushing the lazy load so that reserved memory can be freed when we are ready to pin
757
+ p._scale = p._scale + 0
758
+ p._data = p._data + 0
759
+ else:
760
+ if p.data.dtype == torch.float32:
761
+ # convert any leftover float32 weights to bfloat16 to halve the model memory footprint
762
+ p.data = p.data.to(torch.bfloat16)
319
763
  else:
320
- if p.data.dtype == torch.float32:
321
- # convert any left overs float32 weight to bloat16 to divide by 2 the model memory footprint
322
- p.data = p.data.to(torch.bfloat16)
323
- else:
324
- # force reading the tensors from the disk by pretending to modify them
325
- p.data = p.data + 0
326
-
764
+ # force reading the tensors from the disk by pretending to modify them
765
+ p.data = p.data + 0
766
+
767
+ current_model_size += torch.numel(p.data) * p.data.element_size()
768
+
769
+ for b in current_model.buffers():
770
+ if b.data.dtype == torch.float32:
771
+ # convert any leftover float32 weights to bfloat16 to halve the model memory footprint
772
+ b.data = b.data.to(torch.bfloat16)
773
+ else:
774
+ # force reading the tensors from the disk by pretending to modify them
775
+ b.data = b.data + 0
776
+
777
+ current_model_size += torch.numel(b.data) * b.data.element_size()
778
+
779
+ if model_id not in self.models:
780
+ self.models[model_id] = current_model
781
+
782
+
783
+ model_budget = model_budgets[model_id] * ONE_MB if model_id in model_budgets else budget
784
+
785
+ if model_budget > 0 and model_budget > current_model_size:
786
+ model_budget = 0
787
+
788
+ model_budgets[model_id] = model_budget
789
+
790
+ # Pin in RAM models only once they have been fully loaded otherwise there will be some contention (at least on Linux OS) in the non pageable memory
791
+ # between partially loaded lazy safetensors and pinned tensors
792
+ for model_id in models:
793
+ current_model: torch.nn.Module = models[model_id]
794
+ if not (pinAllModels or modelsToPin != None and model_id in modelsToPin):
795
+ continue
796
+ if verboseLevel>=1:
797
+ print(f"Pinning tensors of '{model_id}' in RAM")
798
+ gc.collect()
799
+ pinned_parameters_data = {}
800
+ for p in current_model.parameters():
801
+ if isinstance(p, QTensor):
802
+ # pin in memory both quantized data and scales of quantized parameters
803
+ # but don't pin .data as it corresponds to the original tensor that we don't want to reload
804
+ p._data = p._data.pin_memory()
805
+ # fix quanto bug (that seems to have been fixed since) that allows _scale to be float32 if the original weight was float32
806
+ # (this may cause type mismatch between dequantified bfloat16 weights and float32 scales)
807
+ if p._scale.dtype == torch.float32:
808
+ pass
809
+
810
+ p._scale = p._scale.to(torch.bfloat16).pin_memory() if p._scale.dtype == torch.float32 else p._scale.pin_memory()
811
+ pinned_parameters_data[p]=[p._data, p._scale]
812
+ else:
813
+ p.data = p.data.pin_memory()
814
+ pinned_parameters_data[p]=p.data
815
+ for b in current_model.buffers():
816
+ b.data = b.data.pin_memory()
817
+
818
+ pinned_buffers_data = {b: b.data for b in current_model.buffers()}
819
+ pinned_parameters_data.update(pinned_buffers_data)
820
+ self.pinned_modules_data[model_id]=pinned_parameters_data
327
821
 
328
- addModelFlag = False
329
822
 
330
- current_block_sequence = None
823
+ # Hook forward methods of modules
824
+ for model_id in models:
825
+ current_model: torch.nn.Module = models[model_id]
826
+ current_budget = model_budgets[model_id]
827
+ current_size = 0
828
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
829
+ self.loaded_blocks[model_id] = None
830
+
331
831
  for submodule_name, submodule in current_model.named_modules():
832
+ # create a fake accelerate parameter so that the _execution_device property returns always "cuda"
833
+ # (it is queried in many pipelines even if offloading is not properly implemented)
834
+ if not hasattr(submodule, "_hf_hook"):
835
+ setattr(submodule, "_hf_hook", HfHook())
836
+
837
+ if submodule_name=='':
838
+ continue
839
+
840
+ if current_budget > 0:
841
+ if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
842
+ if cur_blocks_prefix == None:
843
+ cur_blocks_prefix = submodule_name + "."
844
+ else:
845
+ #if cur_blocks_prefix != submodule_name[:len(cur_blocks_prefix)]:
846
+ if not submodule_name.startswith(cur_blocks_prefix):
847
+ cur_blocks_prefix = submodule_name + "."
848
+ cur_blocks_name,cur_blocks_seq = None, -1
849
+ else:
850
+
851
+ if cur_blocks_prefix is not None:
852
+ #if cur_blocks_prefix == submodule_name[0:len(cur_blocks_prefix)]:
853
+ if submodule_name.startswith(cur_blocks_prefix):
854
+ num = int(submodule_name[len(cur_blocks_prefix):].split(".")[0])
855
+ if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
856
+ prev_blocks_name = cur_blocks_name
857
+ cur_blocks_name = cur_blocks_prefix + str(num)
858
+ # print(f"new block: {model_id}/{cur_blocks_name} - {submodule_name}")
859
+ cur_blocks_seq = num
860
+ else:
861
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name,cur_blocks_seq = None, None, None, -1
862
+
332
863
  if hasattr(submodule, "forward"):
333
864
  submodule_method = getattr(submodule, "forward")
334
865
  if callable(submodule_method):
335
- addModelFlag = True
336
- if submodule_name=='' or len(submodule_name.split("."))==1:
337
- # hook only the first two levels of modules with the full suite of processing
866
+ if len(submodule_name.split("."))==1:
867
+ # hook only the first level of modules with the full suite of processing
338
868
  self.hook_me(submodule, current_model, model_id, submodule_name, submodule_method)
339
- else:
340
- forceMemoryCheck = False
341
- pos = submodule_name.find(".0.")
342
- if pos > 0:
343
- if current_block_sequence == None:
344
- new_candidate = submodule_name[0:pos+3]
345
- if len(new_candidate.split("."))<=4:
346
- current_block_sequence = new_candidate
347
- # force a memory check when initiating a new sequence of blocks as the shapes of tensor will certainly change
348
- # and memory reusability is less likely
349
- # we limit this check to the first level of blocks as quering the cuda cache is time consuming
350
- forceMemoryCheck = True
351
- else:
352
- if current_block_sequence != submodule_name[0:len(current_block_sequence)]:
353
- current_block_sequence = None
354
- self.hook_me_light(submodule, forceMemoryCheck, submodule_method)
355
-
356
-
357
- if addModelFlag:
358
- if model_id not in self.models:
359
- self.models[model_id] = current_model
360
-
361
- # Pin in RAM models only once they have been fully loaded otherwise there may be some contention in the non pageable memory
362
- # between partially loaded lazy safetensors and pinned tensors
363
- if pinInRAM:
364
- if verbose:
365
- print("Pinning model tensors in RAM")
366
- torch.cuda.empty_cache()
367
- gc.collect()
368
- for model_id in models:
369
- pinned_parameters_data = {}
370
- current_model: torch.nn.Module = models[model_id]
371
- for p in current_model.parameters():
372
- if isinstance(p, QTensor):
373
- # pin in memory both quantized data and scales of quantized parameters
374
- # but don't pin .data as it corresponds to the original tensor that we don't want to reload
375
- p._data = p._data.pin_memory()
376
- # fix quanto bug that allows _scale to be float32 if the original weight was float32
377
- # (this may cause type mismatch between dequantified bfloat16 weights and float32 scales)
378
- p._scale = p._scale.to(torch.bfloat16).pin_memory() if p._scale.dtype == torch.float32 else p._scale.pin_memory()
379
- pinned_parameters_data[p]=[p._data, p._scale]
380
- else:
381
- p.data = p.data.pin_memory()
382
- pinned_parameters_data[p]=p.data
383
- for b in current_model.buffers():
384
- b.data = b.data.pin_memory()
869
+ else:
870
+ # force a memory check when initiating a new sequence of blocks as the shapes of tensor will certainly change
871
+ # and memory reusability is less likely
872
+ # we limit this check to the first level of blocks as querying the cuda cache is time consuming
873
+ self.hook_me_light(submodule, model_id, cur_blocks_name, submodule_method, context = submodule_name)
385
874
 
386
- pinned_buffers_data = {b: b.data for b in current_model.buffers()}
387
- pinned_parameters_data.update(pinned_buffers_data)
388
- self.pinned_modules_data[model_id]=pinned_parameters_data
875
+ # if compile and cur_blocks_name != None and model_id == "transformer" and "_blocks" in submodule_name:
876
+ # submodule.compile(mode="reduce-overhead" ) #mode= "max-autotune"
877
+
878
+ current_size = self.add_module_to_blocks(model_id, cur_blocks_name, submodule, prev_blocks_name)
389
879
 
390
- module_params = []
391
- self.params_of_modules[model_id] = module_params
392
- self.collect_module_parameters(current_model,module_params)
393
880
 
394
881
  if compile:
395
- if verbose:
882
+ if verboseLevel>=1:
396
883
  print("Torch compilation started")
397
884
  torch._dynamo.config.cache_size_limit = 10000
398
885
  # if pipe != None and hasattr(pipe, "__call__"):
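A sketch of the fine-grained options handled above: pinInRAM also accepts a list of model ids, budgets accepts a per-model dictionary in MB, and verboseLevel=2 prints the resulting block sizes. pipe is assumed to be a pipeline already loaded on the CPU; the numbers are only examples.

from mmgp import offload

offload.all(
    pipe,
    pinInRAM=["transformer"],                               # pin only the transformer in RAM
    modelsToQuantize=["text_encoder_2"],                    # quantize the T5 encoder as well
    budgets={"transformer": 600, "text_encoder_2": 3000},   # per-model budgets, in MB
    verboseLevel=2,                                         # report block sizes and transfers
)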
@@ -403,13 +890,65 @@ class offload:
403
890
  current_model.compile(mode= "max-autotune")
404
891
  #models["transformer"].compile()
405
892
 
406
- if verbose:
893
+ if verboseLevel>=1:
407
894
  print("Torch compilation done")
408
895
 
896
+ if verboseLevel >=2:
897
+ for n,b in self.blocks_of_modules_sizes.items():
898
+ print(f"Size of submodel '{n}': {b/ONE_MB:.1f} MB")
899
+
409
900
  torch.cuda.empty_cache()
410
901
  gc.collect()
411
902
 
412
-
413
903
  return self
414
904
 
415
-
905
+
906
+
907
+ @staticmethod
908
+ def profile(pipe_or_dict_of_modules,profile_no: profile_type, quantizeTransformer = True):
909
+ """Apply a configuration profile that depends on your hardware:
910
+ pipe_or_dict_of_modules : the pipeline object or a dictionary of modules of the model
911
+ profile_no : number of the profile:
912
+ HighRAM_HighVRAM_Fastest (=1): at least 48 GB of RAM and 24 GB of VRAM : the fastest, well suited for an RTX 3090 / RTX 4090
913
+ HighRAM_LowVRAM_Fast (=2): at least 48 GB of RAM and 12 GB of VRAM : a bit slower, better suited for RTX 3070/3080/4070/4080
914
+ or for RTX 3090 / RTX 4090 with large pictures batches or long videos
915
+ LowRAM_HighVRAM_Medium (=3): at least 32 GB of RAM and 24 GB of VRAM : average speed, but adapted for an RTX 3090 / RTX 4090 with limited RAM
916
+ LowRAM_LowVRAM_Slow (=4): at least 32 GB of RAM and 12 GB of VRAM : if you have little VRAM or generate longer videos
917
+ VerylowRAM_LowVRAM_Slowest (=5): at least 24 GB of RAM and 10 GB of VRAM : if you don't have much hardware it won't be fast but it may still work
918
+ quantizeTransformer: bool = True, the main model is quantized by default for all the profiles, you may want to disable that to get the best image quality
919
+ """
920
+
921
+
922
+ modules = pipe_or_dict_of_modules
923
+ if hasattr(modules, "components"):
924
+ modules= modules.components
925
+ any_T5 = False
926
+ if "text_encoder_2" in modules:
927
+ text_encoder_2 = modules["text_encoder_2"]
928
+ any_T5 = "t5" in text_encoder_2.__module__.lower()
929
+ extra_mod_to_quantize = ("text_encoder_2" if any_T5 else "text_encoder")
930
+
931
+ # transformer (video or image generator) should be as small as possible to not occupy space that could be used by actual image data
932
+ # on the other hand the text encoder should be quite large (as long as it fits in 10 GB of VRAM) to reduce sequence offloading
933
+
934
+ budgets = { "transformer" : 600 , "text_encoder": 3000, "text_encoder_2": 3000 }
935
+
936
+ if profile_no == profile_type.HighRAM_HighVRAM_Fastest:
937
+ info = "You have chosen a Very Fast profile that requires at least 48 GB of RAM and 24 GB of VRAM."
938
+ return offload.all(pipe_or_dict_of_modules, pinInRAM= True, info = info, quantizeTransformer= quantizeTransformer)
939
+ elif profile_no == profile_type.HighRAM_LowVRAM_Fast:
940
+ info = "You have chosen a Fast profile that requires at least 48 GB of RAM and 12 GB of VRAM."
941
+ return offload.all(pipe_or_dict_of_modules, pinInRAM= True, budgets=budgets, info = info, quantizeTransformer= quantizeTransformer )
942
+ elif profile_no == profile_type.LowRAM_HighVRAM_Medium:
943
+ info = "You have chosen a Medium speed profile that requires at least 32 GB of RAM and 24 GB of VRAM."
944
+ return offload.all(pipe_or_dict_of_modules, pinInRAM= "transformer", modelsToQuantize= extra_mod_to_quantize , info = info, quantizeTransformer= quantizeTransformer)
945
+ elif profile_no == profile_type.LowRAM_LowVRAM_Slow:
946
+ info = "You have chosen the Slowest profile that requires at least 32 GB of RAM and 12 GB of VRAM."
947
+ return offload.all(pipe_or_dict_of_modules, pinInRAM= "transformer", modelsToQuantize= extra_mod_to_quantize , budgets=budgets, info = info, quantizeTransformer= quantizeTransformer)
948
+ elif profile_no == profile_type.VerylowRAM_LowVRAM_Slowest:
949
+ budgets["transformer"] = 400
950
+ info = "You have chosen the Slowest profile that requires at least 24 GB of RAM and 10 GB of VRAM."
951
+ return offload.all(pipe_or_dict_of_modules, pinInRAM= False, modelsToQuantize= extra_mod_to_quantize , budgets=budgets, info = info, quantizeTransformer= quantizeTransformer)
952
+ else:
953
+ raise("Unknown profile")
954
+
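Finally, when there is no pipeline object (as mentioned in the header comments), a plain dictionary of modules can be passed instead; the transformer / text_encoder / vae variables below are placeholders for already loaded sub-models, and the ids are the ones used by cotenants_map, budgets and modelsToQuantize.

from mmgp import offload, profile_type

models = {
    "transformer": transformer,      # video / image generator
    "text_encoder": text_encoder,
    "vae": vae,
}
offload.profile(models, profile_type.VerylowRAM_LowVRAM_Slowest)
# or, with explicit parameters:
# offload.all(models, pinInRAM=False, modelsToQuantize=["text_encoder"])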